from transformers import CLIPTextModel, CLIPTokenizer, logging
from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline
from diffusers.utils.import_utils import is_xformers_available
from os.path import isfile
from pathlib import Path

# suppress partial model loading warning
logging.set_verbosity_error()

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.utils import save_image

from torch.cuda.amp import custom_bwd, custom_fwd
#from .perpneg_utils import weighted_perpendicular_aggregator

import torch.jit as jit


def symlog(x):
    """Symmetric logarithm: ``sign(x) * log(1 + |x|)``.

    Uses ``torch.log1p`` so that inputs with very small magnitude keep their
    value instead of collapsing to zero when ``1 + |x|`` rounds to ``1``.

    Args:
        x: input tensor.

    Returns:
        Tensor with the same shape as ``x``.
    """
    return torch.sign(x) * torch.log1p(torch.abs(x))

def symexp(x):
    """Inverse of the symmetric logarithm: ``sign(x) * (exp(|x|) - 1)``.

    Uses ``torch.expm1`` so that inputs with very small magnitude are not lost
    to cancellation in ``exp(|x|) - 1``.

    Args:
        x: input tensor.

    Returns:
        Tensor with the same shape as ``x``.
    """
    return torch.sign(x) * torch.expm1(torch.abs(x))

def get_perpendicular_component(x, y):
    """Remove from ``x`` its projection onto ``y``.

    Both tensors are treated as single flat vectors: the dot product and the
    norm are taken over all elements. The squared norm of ``y`` is floored at
    ``1e-6`` so a (near-)zero ``y`` does not cause a division by zero.

    Args:
        x: tensor to project.
        y: tensor of the same shape defining the direction to remove.

    Returns:
        Tensor shaped like ``x``.
    """
    assert x.shape == y.shape
    overlap = torch.mul(x, y).sum()
    y_norm_sq = max(torch.norm(y)**2, 1e-6)
    projection_coeff = overlap / y_norm_sq
    return x - projection_coeff * y


def batch_get_perpendicular_component(x, y):
    """Per-sample perpendicular component of ``x`` with respect to ``y``.

    For every index ``i`` along dim 0, the projection of ``x[i]`` onto ``y[i]``
    (both flattened) is removed from ``x[i]``. The squared norm of ``y[i]`` is
    floored at ``1e-6``, matching the single-sample helper in this file.

    The computation is vectorized over the batch, so it performs no per-sample
    Python loop and also accepts an empty batch.

    Args:
        x: tensor of shape ``[B, ...]``.
        y: tensor with the same shape as ``x``.

    Returns:
        Tensor with the same shape as ``x``.
    """
    assert x.shape == y.shape
    batch = x.shape[0]

    # Collapse everything after the batch dim so reductions are per sample.
    if x.dim() > 1:
        x_flat, y_flat = x.flatten(1), y.flatten(1)
    else:
        x_flat, y_flat = x.unsqueeze(1), y.unsqueeze(1)

    dot = (x_flat * y_flat).sum(dim=1)
    norm_sq = (y_flat.norm(dim=1) ** 2).clamp_min(1e-6)
    coeff = dot / norm_sq

    # Reshape to [B, 1, 1, ...] so the coefficient broadcasts over each sample.
    coeff = coeff.view(batch, *([1] * (x.dim() - 1)))
    return x - coeff * y

def seed_everything(seed):
    """Seed PyTorch's random number generators with ``seed``.

    Calls ``torch.manual_seed`` followed by ``torch.cuda.manual_seed``.
    The cuDNN ``deterministic`` / ``benchmark`` flags are not modified.

    Args:
        seed: integer seed.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)

class StableDiffusion(nn.Module):#jit.ScriptModule):#nn.Module):
    def __init__(self, device, fp16, vram_O, sd_version='2.1', hf_key=None, t_range=(0.02, 0.98), text_cond=''):
        """Load a Stable Diffusion pipeline and keep its components on this module.

        Args:
            device: device the pipeline is moved to (and where inputs are sent).
            fp16: load weights as ``torch.float16`` when true, else ``torch.float32``.
            vram_O: when true, enable the pipeline's low-VRAM options (sequential CPU
                offload, VAE slicing, attention slicing) instead of moving it to ``device``.
            sd_version: one of ``'2.1'``, ``'2.0'``, ``'1.5'``; ignored when ``hf_key`` is given.
            hf_key: optional Hugging Face model key that overrides ``sd_version``.
            t_range: ``(low, high)`` fractions of the scheduler's training timesteps,
                stored as ``self.min_step`` / ``self.max_step``.
            text_cond: prompt whose embedding is cached (after the empty-prompt
                embedding) in ``self.c_in``.

        Raises:
            ValueError: if ``sd_version`` is not supported and ``hf_key`` is ``None``.
        """
        super().__init__()

        self.device = device
        self.sd_version = sd_version

        print('[INFO] loading stable diffusion...')

        if hf_key is not None:
            print(f'[INFO] using hugging face custom model key: {hf_key}')
            model_key = hf_key
        elif self.sd_version == '2.1':
            model_key = "stabilityai/stable-diffusion-2-1-base"
        elif self.sd_version == '2.0':
            model_key = "stabilityai/stable-diffusion-2-base"
        elif self.sd_version == '1.5':
            model_key = "runwayml/stable-diffusion-v1-5"
        else:
            raise ValueError(f'Stable-diffusion version {self.sd_version} not supported.')

        self.precision_t = torch.float16 if fp16 else torch.float32

        # Create model
        pipe = StableDiffusionPipeline.from_pretrained(model_key, torch_dtype=self.precision_t)

        if vram_O:
            pipe.enable_sequential_cpu_offload()
            pipe.enable_vae_slicing()
            pipe.unet.to(memory_format=torch.channels_last)
            pipe.enable_attention_slicing(1)
            # pipe.enable_model_cpu_offload()
        else:
            pipe.to(device)

        self.vae = pipe.vae
        self.tokenizer = pipe.tokenizer
        self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet

        # Freeze only after the submodules are attached; before that point
        # self.parameters() is empty and the loop would have no effect.
        self.eval()
        for param in self.parameters():
            param.requires_grad = False

        # Cache [unconditional, conditional] embeddings for the alignment methods.
        conditional_text = self.get_text_embeds([text_cond])
        unconditional_text = self.get_text_embeds([""])
        self.c_in = torch.cat([unconditional_text, conditional_text])

        self.scheduler = DDIMScheduler.from_pretrained(model_key, subfolder="scheduler", torch_dtype=self.precision_t)

        del pipe

        # Timestep bounds read by train_step_perpneg.
        self.num_train_timesteps = self.scheduler.config.num_train_timesteps
        self.min_step = int(self.num_train_timesteps * t_range[0])
        self.max_step = int(self.num_train_timesteps * t_range[1])

        self.alphas = self.scheduler.alphas_cumprod.to(self.device) # for convenience

        print('[INFO] loaded stable diffusion!')

    @torch.no_grad()
    def get_text_embeds(self, prompt):
        # prompt: [str]

        inputs = self.tokenizer(prompt, padding='max_length', max_length=self.tokenizer.model_max_length, return_tensors='pt')
        embeddings = self.text_encoder(inputs.input_ids.to(self.device))[0]

        return embeddings


    def get_new_sds_alignment(self, text_embeddings, pred_rgb, noise_level=650, as_latent=False, noise=None):

        b = pred_rgb.shape[0]

        # interp to 512x512 to be fed into vae.
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) #TODO: unflag if we send in smaller dimension image
        # encode image into latents with vae, requires grad!
        with torch.no_grad():
            latents = self.encode_imgs(pred_rgb_512)

        #t = torch.randint(self.min_step, self.max_step + 1, [b], dtype=torch.long, device=self.device)
        t = torch.randint(noise_level, noise_level + 1, [b], dtype=torch.long, device=self.device)

        if noise is None:
            noise = torch.randn_like(latents)

        sqrt_alpha_prod = self.scheduler.alphas_cumprod.to(self.device)[noise_level] ** 0.5
        sqrt_one_minus_alpha_prod = (1 - self.scheduler.alphas_cumprod.to(self.device)[noise_level]) ** 0.5

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise
            latent_model_input = latents_noisy.repeat_interleave(repeats=2, dim=0)
            new_t = t.repeat_interleave(repeats=2, dim=0)
            noise_pred = self.unet(latent_model_input, new_t, encoder_hidden_states=text_embeddings).sample

            noise_pred_uncond = noise_pred[torch.arange(0, b*2, step=2)]
            noise_pred_pos = noise_pred[torch.arange(1, b*2, step=2)]

            latent_pred_uncond = (latents_noisy - sqrt_one_minus_alpha_prod * noise_pred_uncond) / sqrt_alpha_prod.to(self.device)
            latent_pred_text = (latents_noisy - sqrt_one_minus_alpha_prod * noise_pred_pos) / sqrt_alpha_prod.to(self.device)


        noise_term = torch.exp(-4*((latent_pred_text - latents)**2).mean([1,2,3]))
        relative_term = 1 - torch.exp(-200 * ((latent_pred_text - latent_pred_uncond)**2).mean([1,2,3]))

        # By default grad_scale is set to 0.2 for a 650 noise level
        return noise_term + relative_term

    @torch.no_grad()
    def alternate_sds_alignment(self, pred_rgb, noise_level=400, alignment_scale=2000, noise=None):

        b = pred_rgb.shape[0]

        # interp to 512x512 to be fed into vae.
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
        with torch.no_grad():
            latents = self.encode_imgs(pred_rgb_512)

        t = torch.randint(noise_level, noise_level + 50, [b], dtype=torch.long, device=self.device)

        if noise is None:
            noise = torch.randn_like(latents)

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise
            latent_model_input = latents_noisy.repeat_interleave(repeats=2, dim=0)
            new_t = t.repeat_interleave(repeats=2, dim=0)
            #noise_pred = self.unet(latent_model_input, new_t, encoder_hidden_states=txt_embed).sample
            noise_pred = self.unet(latent_model_input, new_t, encoder_hidden_states=self.c_in).sample

        noise_pred_uncond = noise_pred[torch.arange(0, b*2, step=2)]
        noise_pred_pos = noise_pred[torch.arange(1, b*2, step=2)]

        alignment_pred = ((noise_pred_pos - noise_pred_uncond)**2).mean([1,2,3])
        #pos_natural_pred = ((noise_pred_pos - noise)**2).mean([1,2,3])
        #uncond_natural_pred = ((noise_pred_uncond - noise)**2).mean([1,2,3])
        #recon_pred = uncond_natural_pred - pos_natural_pred

        result = 1 - torch.exp(-alignment_scale * (alignment_pred))

        #np.mean(1 - np.exp(-0.1*10000 * good_deltas) + 0.05 * np.exp(-10*good_source_deltas), axis=0)
        # + 0.02 * np.exp(-100*(good_source_deltas - bad_source_deltas)), axis=0)

        # symlog(-100*(good_source_deltas - bad_source_deltas)) + symlog(2000*good_deltas)

        #result = 1 - torch.exp(-2000 * alignment_pred) + symlog(-200*(pos_natural_pred - uncond_natural_pred))
        #result = symlog(alignment_scale*alignment_pred) + symlog(recon_scale*recon_pred)
        #result = symlog(2000*alignment_pred) + symlog(-200*(pos_natural_pred - uncond_natural_pred))
        #result = symlog(4000*alignment_pred) + symlog(-200*(pos_natural_pred - uncond_natural_pred))
        #result = symlog(2000*alignment_pred) + symlog(-100*(pos_natural_pred - uncond_natural_pred))
        #noise_pred = noise_pred_uncond + 100 * (noise_pred_pos - noise_pred_uncond)
        #result = 1 - torch.exp(-0.5 * 0.4 * ((noise_pred - noise)**2).mean([1,2,3]))

        return result

    @torch.no_grad()
    def augmented_sds_alignment(self, pred_rgb, noise_level=400, alignment_scale=1000, recon_scale=500, noise=None):
        """Alignment score combining a prompt-gap term and a reconstruction term.

        Two latents are encoded: one from ``pred_rgb`` and one from ``pred_rgb / 255``.
        Both are noised with the same noise and timestep (drawn from
        ``[noise_level, noise_level + 100)``) and denoised with ``self.c_in``.

        * alignment term: mean squared gap between conditional and unconditional
          predictions on the first latents.
        * reconstruction term: how much closer the conditional prediction is to the
          true noise than the unconditional one, on the ``/ 255`` latents.

        NOTE(review): the reason for scoring the two terms on differently scaled
        inputs is not visible in this file — confirm the intended input range.

        Args:
            pred_rgb: image batch, resized to 512x512 before VAE encoding.
            noise_level: lower bound of the sampled timestep.
            alignment_scale: multiplier inside ``symlog`` for the alignment term.
            recon_scale: multiplier inside ``symlog`` for the reconstruction term.
            noise: optional noise tensor shaped like the latents; sampled when ``None``.

        Returns:
            Tensor with one score per input image:
            ``symlog(alignment_scale * alignment) + symlog(recon_scale * recon)``.
        """
        b = pred_rgb.shape[0]

        # interp to 512x512 to be fed into vae.
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
        pred_rgb_512_norm = F.interpolate(pred_rgb/255., (512, 512), mode='bilinear', align_corners=False)
        latents = self.encode_imgs(pred_rgb_512)
        norm_latents = self.encode_imgs(pred_rgb_512_norm)

        t = torch.randint(noise_level, noise_level + 100, [b], dtype=torch.long, device=self.device)

        if noise is None:
            noise = torch.randn_like(latents)

        # self.c_in holds one [uncond, cond] pair. The UNet inputs below are laid out as
        # [x0, x0, x1, x1, ...], so the pair has to be tiled once per sample; otherwise
        # the embedding batch (2) does not match the latent batch (2 * b) when b > 1.
        cond = self.c_in
        if cond.shape[0] == 2 and b > 1:
            cond = cond.repeat(b, *([1] * (cond.dim() - 1)))

        new_t = t.repeat_interleave(repeats=2, dim=0)

        latents_noisy = self.scheduler.add_noise(latents, noise, t)
        latent_model_input = latents_noisy.repeat_interleave(repeats=2, dim=0)
        noise_pred = self.unet(latent_model_input, new_t, encoder_hidden_states=cond).sample

        norm_latents_noisy = self.scheduler.add_noise(norm_latents, noise, t)
        norm_latent_model_input = norm_latents_noisy.repeat_interleave(repeats=2, dim=0)
        norm_noise_pred = self.unet(norm_latent_model_input, new_t, encoder_hidden_states=cond).sample

        # Even rows were paired with the unconditional embedding, odd rows with the prompt.
        noise_pred_uncond = noise_pred[0::2]
        noise_pred_pos = noise_pred[1::2]
        norm_noise_pred_uncond = norm_noise_pred[0::2]
        norm_noise_pred_pos = norm_noise_pred[1::2]

        # computed from the un-normalized input
        alignment_pred = ((noise_pred_pos - noise_pred_uncond)**2).mean([1,2,3])

        # computed from the / 255 input
        pos_natural_pred = ((norm_noise_pred_pos - noise)**2).mean([1,2,3])
        uncond_natural_pred = ((norm_noise_pred_uncond - noise)**2).mean([1,2,3])
        recon_pred = uncond_natural_pred - pos_natural_pred

        return symlog(alignment_scale*alignment_pred) + symlog(recon_scale*recon_pred)

    @torch.no_grad()
    def disentangled_sds_alignment(self, pred_rgb, noise_level=400, alignment_scale=200, recon_scale=2000, noise=None):
        """Alignment score combining a prompt-gap term and a reconstruction term.

        The images are encoded, noised at a random timestep in
        ``[noise_level, noise_level + 100)`` and denoised with ``self.c_in``.

        * alignment term: mean squared gap between conditional and unconditional
          noise predictions.
        * reconstruction term: unconditional noise-prediction error minus
          conditional noise-prediction error (positive when the prompt helps).

        Args:
            pred_rgb: image batch, resized to 512x512 before VAE encoding.
            noise_level: lower bound of the sampled timestep.
            alignment_scale: multiplier inside ``symlog`` for the alignment term.
            recon_scale: multiplier inside ``symlog`` for the reconstruction term.
            noise: optional noise tensor shaped like the latents; sampled when ``None``.

        Returns:
            Tensor with one score per input image:
            ``symlog(alignment_scale * alignment) + symlog(recon_scale * recon)``.
        """
        b = pred_rgb.shape[0]

        # interp to 512x512 to be fed into vae.
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
        latents = self.encode_imgs(pred_rgb_512)

        t = torch.randint(noise_level, noise_level + 100, [b], dtype=torch.long, device=self.device)

        if noise is None:
            noise = torch.randn_like(latents)

        # self.c_in holds one [uncond, cond] pair. The UNet input below is laid out as
        # [x0, x0, x1, x1, ...], so the pair has to be tiled once per sample; otherwise
        # the embedding batch (2) does not match the latent batch (2 * b) when b > 1.
        cond = self.c_in
        if cond.shape[0] == 2 and b > 1:
            cond = cond.repeat(b, *([1] * (cond.dim() - 1)))

        latents_noisy = self.scheduler.add_noise(latents, noise, t)
        latent_model_input = latents_noisy.repeat_interleave(repeats=2, dim=0)
        new_t = t.repeat_interleave(repeats=2, dim=0)
        noise_pred = self.unet(latent_model_input, new_t, encoder_hidden_states=cond).sample

        # Even rows were paired with the unconditional embedding, odd rows with the prompt.
        noise_pred_uncond = noise_pred[0::2]
        noise_pred_pos = noise_pred[1::2]

        alignment_pred = ((noise_pred_pos - noise_pred_uncond)**2).mean([1,2,3])
        pos_natural_pred = ((noise_pred_pos - noise)**2).mean([1,2,3])
        uncond_natural_pred = ((noise_pred_uncond - noise)**2).mean([1,2,3])
        recon_pred = uncond_natural_pred - pos_natural_pred

        return symlog(alignment_scale*alignment_pred) + symlog(recon_scale*recon_pred)


    def get_sds_alignment(self, pred_rgb, noise_level=400, guidance_scale=500, other_scale=100, noise=None):

        b = pred_rgb.shape[0]

        # interp to 512x512 to be fed into vae.
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
        with torch.no_grad():
            latents = self.encode_imgs(pred_rgb_512)

        t = torch.randint(noise_level, noise_level + 50, [b], dtype=torch.long, device=self.device)

        if noise is None:
            noise = torch.randn_like(latents)

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise
            latent_model_input = latents_noisy.repeat_interleave(repeats=2, dim=0)
            new_t = t.repeat_interleave(repeats=2, dim=0)
            #noise_pred = self.unet(latent_model_input, new_t, encoder_hidden_states=txt_embed).sample
            noise_pred = self.unet(latent_model_input, new_t, encoder_hidden_states=self.c_in).sample

        noise_pred_uncond = noise_pred[torch.arange(0, b*2, step=2)]
        noise_pred_pos = noise_pred[torch.arange(1, b*2, step=2)]

        #alignment_pred = ((noise_pred_pos - noise_pred_uncond)**2).mean([1,2,3])
        #pos_natural_pred = ((noise_pred_pos - noise)**2).mean([1,2,3])
        #uncond_natural_pred = ((noise_pred_uncond - noise)**2).mean([1,2,3])

        #np.mean(1 - np.exp(-0.1*10000 * good_deltas) + 0.05 * np.exp(-10*good_source_deltas), axis=0)
        # + 0.02 * np.exp(-100*(good_source_deltas - bad_source_deltas)), axis=0)

        #result = 1 - torch.exp(-guidance_scale*alignment_pred) + 0.02*torch.exp(-other_scale*(pos_natural_pred - uncond_natural_pred))
        noise_pred = noise_pred_uncond + 100 * (noise_pred_pos - noise_pred_uncond)
        result = 1 - torch.exp(-0.5 * 0.4 * ((noise_pred - noise)**2).mean([1,2,3]))

        return result


    def old_get_sds_alignment(self, text_embeddings, pred_rgb, noise_level=450, guidance_scale=100, as_latent=False, grad_scale=0.1, noise=None):

        b = pred_rgb.shape[0]

        # interp to 512x512 to be fed into vae.
        pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False) #TODO: unflag if we send in smaller dimension image
        # encode image into latents with vae, requires grad!
        with torch.no_grad():
            latents = self.encode_imgs(pred_rgb_512)

        #t = torch.randint(self.min_step, self.max_step + 1, [b], dtype=torch.long, device=self.device)
        t = torch.randint(noise_level, noise_level + 1, [b], dtype=torch.long, device=self.device)

        if noise is None:
            noise = torch.randn_like(latents)

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise
            latent_model_input = latents_noisy.repeat_interleave(repeats=2, dim=0)
            new_t = t.repeat_interleave(repeats=2, dim=0)
            noise_pred = self.unet(latent_model_input, new_t, encoder_hidden_states=text_embeddings).sample

        # perform guidance (high scale from paper!)
        #noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
        noise_pred_uncond = noise_pred[torch.arange(0, b*2, step=2)]
        noise_pred_pos = noise_pred[torch.arange(1, b*2, step=2)]

        #noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond)
        noise_pred = noise_pred_pos + guidance_scale * (noise_pred_pos - noise_pred_uncond)

        w = (1 - self.alphas[t])**2

        # By default grad_scale is set to 0.2 for a 650 noise level
        return 1 - torch.exp(-grad_scale * w * ((noise_pred - noise)**2).mean([1,2,3]))
        #return 1 - torch.exp(-grad_scale * w * ((noise_pred - noise)**2).mean([1,2,3]))
        #return torch.exp(grad_scale * w * ((noise_pred - noise)**2).mean([1,2,3])), noise

    def get_sds_perpneg(self, text_embeddings, pred_rgb, guidance_scale=100, grad_scale=0.5, perp_neg_scale=2.0, noise=None):
        """Score images with a guided prediction that subtracts a perpendicular negative term.

        The images are encoded, noised at the fixed timestep 750 and denoised three
        times per sample (rows ``0, 1, 2 (mod 3)`` are read as unconditional, positive
        and negative). The negative direction is reduced to its component
        perpendicular to the positive direction and subtracted with weight
        ``perp_neg_scale``. The score is
        ``1 - exp(-grad_scale * w * mse(guided_eps, noise))`` with
        ``w = (1 - alphas[t]) ** 2``.

        Args:
            text_embeddings: embeddings passed to the UNet, aligned with the
                triplicated batch.
            pred_rgb: image batch, resized to 512x512 before VAE encoding.
            guidance_scale: weight of the combined guidance direction.
            grad_scale: multiplier inside the exponential.
            perp_neg_scale: weight of the perpendicular negative component.
            noise: optional noise tensor shaped like the latents; sampled when ``None``.

        Returns:
            Tensor with one score per input image.
        """
        batch = pred_rgb.shape[0]

        # The VAE is fed 512x512 images.
        rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
        with torch.no_grad():
            latents = self.encode_imgs(rgb_512)

        # The half-open range [750, 751) always yields 750.
        t = torch.randint(750, 750 + 1, [batch], dtype=torch.long, device=self.device)

        if noise is None:
            noise = torch.randn_like(latents)

        with torch.no_grad():
            noisy = self.scheduler.add_noise(latents, noise, t)
            # Each sample appears three times in a row: once per embedding.
            unet_in = noisy.repeat_interleave(repeats=3, dim=0)
            unet_t = t.repeat_interleave(repeats=3, dim=0)
            eps = self.unet(unet_in, unet_t, encoder_hidden_states=text_embeddings).sample

        eps_uncond = eps[0::3]
        eps_pos = eps[1::3]
        eps_neg = eps[2::3]

        pos_direction = eps_pos - eps_uncond
        neg_direction = eps_neg - eps_uncond
        neg_perp = batch_get_perpendicular_component(neg_direction, pos_direction)

        guided_eps = eps_uncond + guidance_scale * (pos_direction - perp_neg_scale * neg_perp)

        w = (1 - self.alphas[t])**2
        residual_sq = ((guided_eps - noise)**2).mean([1,2,3])

        return 1 - torch.exp(-grad_scale * w * residual_sq)



    def train_step(self, text_embeddings, pred_rgb, source_noise, t, guidance_scale=100, grad_scale=1):

        latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1

        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
        #t = torch.randint(600, 601, (latents.shape[0],), dtype=torch.long, device=self.device)

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            #noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, source_noise, t)
            ###latent_model_input = latents_noisy.repeat_interleave(repeats=2, dim=0)
            ###text_embeddings = torch.cat([text_embeddings] * 64)
            #text_embeddings = torch.cat([text_embeddings] * 16)
            ###tt = torch.cat([t] * 128)
            #tt = t.repeat_interleave(repeats=2, dim=0)
            #noise_pred = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample
            noise_pred_uncond = self.unet(latents_noisy, torch.cat([t] * 64), encoder_hidden_states=torch.cat([text_embeddings[:1]] * 64)).sample
            noise_pred_pos = self.unet(latents_noisy, torch.cat([t] * 64), encoder_hidden_states=torch.cat([text_embeddings[1:]] * 64)).sample

            #noise_pred_uncond = noise_pred[torch.arange(0, noise_pred.shape[0], step=2)]
            #noise_pred_pos = noise_pred[torch.arange(1, noise_pred.shape[0], step=2)]
            '''
            noise_pred_uncond = self.unet(latents_noisy, t, encoder_hidden_states=text_embeddings[:1])
            noise_pred_pos = self.unet(latents_noisy, t, encoder_hidden_states=text_embeddings[1:])
            '''
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond)

        w = (1 - self.alphas[t])
        grad = grad_scale * w[:, None, None, None] * (noise_pred - source_noise)
        grad = torch.nan_to_num(grad)

        targets = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]

        return loss
    
    '''
    def train_step(self, text_embeddings, pred_rgb, guidance_scale=100, grad_scale=1):

        latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1

        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
        t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            latent_model_input = torch.cat([latents_noisy] * 2)
            tt = torch.cat([t] * 2)
            noise_pred = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample
            noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_pos - noise_pred_uncond)

        w = (1 - self.alphas[t])
        grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        targets = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]

        return loss
    '''
    

    def train_step_perpneg(self, text_embeddings, weights, pred_rgb, guidance_scale=100, as_latent=False, grad_scale=1,
                   save_guidance_path:Path=None):

        B = pred_rgb.shape[0]
        K = (text_embeddings.shape[0] // B) - 1 # maximum number of prompts       

        if as_latent:
            latents = F.interpolate(pred_rgb, (64, 64), mode='bilinear', align_corners=False) * 2 - 1
        else:
            # interp to 512x512 to be fed into vae.
            pred_rgb_512 = F.interpolate(pred_rgb, (512, 512), mode='bilinear', align_corners=False)
            # encode image into latents with vae, requires grad!
            latents = self.encode_imgs(pred_rgb_512)

        # timestep ~ U(0.02, 0.98) to avoid very high/low noise level
        t = torch.randint(self.min_step, self.max_step + 1, (latents.shape[0],), dtype=torch.long, device=self.device)

        # predict the noise residual with unet, NO grad!
        with torch.no_grad():
            # add noise
            noise = torch.randn_like(latents)
            latents_noisy = self.scheduler.add_noise(latents, noise, t)
            # pred noise
            latent_model_input = torch.cat([latents_noisy] * (1 + K))
            tt = torch.cat([t] * (1 + K))
            unet_output = self.unet(latent_model_input, tt, encoder_hidden_states=text_embeddings).sample

            # perform guidance (high scale from paper!)
            noise_pred_uncond, noise_pred_text = unet_output[:B], unet_output[B:]
            delta_noise_preds = noise_pred_text - noise_pred_uncond.repeat(K, 1, 1, 1)
            noise_pred = noise_pred_uncond + guidance_scale #* weighted_perpendicular_aggregator(delta_noise_preds, weights, B)            

        # import kiui
        # latents_tmp = torch.randn((1, 4, 64, 64), device=self.device)
        # latents_tmp = latents_tmp.detach()
        # kiui.lo(latents_tmp)
        # self.scheduler.set_timesteps(30)
        # for i, t in enumerate(self.scheduler.timesteps):
        #     latent_model_input = torch.cat([latents_tmp] * 3)
        #     noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']
        #     noise_pred_uncond, noise_pred_pos = noise_pred.chunk(2)
        #     noise_pred = noise_pred_uncond + 10 * (noise_pred_pos - noise_pred_uncond)
        #     latents_tmp = self.scheduler.step(noise_pred, t, latents_tmp)['prev_sample']
        # imgs = self.decode_latents(latents_tmp)
        # kiui.vis.plot_image(imgs)

        # w(t), sigma_t^2
        w = (1 - self.alphas[t])
        grad = grad_scale * w[:, None, None, None] * (noise_pred - noise)
        grad = torch.nan_to_num(grad)

        if save_guidance_path:
            with torch.no_grad():
                if as_latent:
                    pred_rgb_512 = self.decode_latents(latents)

                # visualize predicted denoised image
                # The following block of code is equivalent to `predict_start_from_noise`...
                # see zero123_utils.py's version for a simpler implementation.
                alphas = self.scheduler.alphas.to(latents)
                total_timesteps = self.max_step - self.min_step + 1
                index = total_timesteps - t.to(latents.device) - 1 
                b = len(noise_pred)
                a_t = alphas[index].reshape(b,1,1,1).to(self.device)
                sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
                sqrt_one_minus_at = sqrt_one_minus_alphas[index].reshape((b,1,1,1)).to(self.device)                
                pred_x0 = (latents_noisy - sqrt_one_minus_at * noise_pred) / a_t.sqrt() # current prediction for x_0
                result_hopefully_less_noisy_image = self.decode_latents(pred_x0.to(latents.type(self.precision_t)))

                # visualize noisier image
                result_noisier_image = self.decode_latents(latents_noisy.to(pred_x0).type(self.precision_t))



                # all 3 input images are [1, 3, H, W], e.g. [1, 3, 512, 512]
                viz_images = torch.cat([pred_rgb_512, result_noisier_image, result_hopefully_less_noisy_image],dim=0)
                save_image(viz_images, save_guidance_path)

        targets = (latents - grad).detach()
        loss = 0.5 * F.mse_loss(latents.float(), targets, reduction='sum') / latents.shape[0]

        return loss


    @torch.no_grad()
    def produce_latents(self, text_embeddings, height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):

        if latents is None:
            latents = torch.randn((text_embeddings.shape[0] // 2, self.unet.in_channels, height // 8, width // 8), device=self.device)

        self.scheduler.set_timesteps(num_inference_steps)

        for i, t in enumerate(self.scheduler.timesteps):
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)
            # predict the noise residual
            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)['sample']

            # perform guidance
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents)['prev_sample']

        return latents

    def decode_latents(self, latents):

        latents = 1 / self.vae.config.scaling_factor * latents

        imgs = self.vae.decode(latents).sample
        imgs = (imgs / 2 + 0.5).clamp(0, 1)

        return imgs

    def encode_imgs(self, imgs):
        # imgs: [B, 3, H, W]

        imgs = 2 * imgs - 1

        posterior = self.vae.encode(imgs).latent_dist
        latents = posterior.sample() * self.vae.config.scaling_factor

        return latents

    def prompt_to_img(self, prompts, negative_prompts='', height=512, width=512, num_inference_steps=50, guidance_scale=7.5, latents=None):

        if isinstance(prompts, str):
            prompts = [prompts]

        if isinstance(negative_prompts, str):
            negative_prompts = [negative_prompts]

        # Prompts -> text embeds
        pos_embeds = self.get_text_embeds(prompts) # [1, 77, 768]
        neg_embeds = self.get_text_embeds(negative_prompts)
        text_embeds = torch.cat([neg_embeds, pos_embeds], dim=0) # [2, 77, 768]

        # Text embeds -> img latents
        latents = self.produce_latents(text_embeds, height=height, width=width, latents=latents, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale) # [1, 4, 64, 64]

        # Img latents -> imgs
        imgs = self.decode_latents(latents) # [1, 3, 512, 512]

        # Img to Numpy
        imgs = imgs.detach().cpu().permute(0, 2, 3, 1).numpy()
        imgs = (imgs * 255).round().astype('uint8')

        return imgs


if __name__ == '__main__':

    import argparse
    import matplotlib.pyplot as plt

    # Command-line demo: generate one image from a prompt and display it.
    cli = argparse.ArgumentParser()
    cli.add_argument('prompt', type=str)
    cli.add_argument('--negative', default='', type=str)
    cli.add_argument('--sd_version', type=str, default='2.1', choices=['1.5', '2.0', '2.1'], help="stable diffusion version")
    cli.add_argument('--hf_key', type=str, default=None, help="hugging face Stable diffusion model key")
    cli.add_argument('--fp16', action='store_true', help="use float16 for training")
    cli.add_argument('--vram_O', action='store_true', help="optimization for low VRAM usage")
    # Integer options share the same shape: flag, type=int, default.
    for flag, default_value in (('-H', 512), ('-W', 512), ('--seed', 0), ('--steps', 50)):
        cli.add_argument(flag, type=int, default=default_value)
    args = cli.parse_args()

    seed_everything(args.seed)

    model = StableDiffusion(torch.device('cuda'), args.fp16, args.vram_O, args.sd_version, args.hf_key)

    # Positional order matches prompt_to_img: prompts, negatives, height, width, steps.
    generated = model.prompt_to_img(args.prompt, args.negative, args.H, args.W, args.steps)

    # visualize image
    plt.imshow(generated[0])
    plt.show()




