import time
import numpy as np
from typing import Any, Callable, Dict, List, Optional, Union
from diffusers import AnimateDiffPipeline, LCMScheduler, MotionAdapter, DDIMScheduler
from diffusers.pipelines.animatediff import AnimateDiffPipelineOutput
from diffusers.image_processor import PipelineImageInput
from diffusers.utils import (USE_PEFT_BACKEND,
    deprecate,
    logging,
    replace_example_docstring,
    scale_lora_layers,
    unscale_lora_layers,
    export_to_gif)

from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F

#from dreamfusion.guidance.resnet_wider import resnet50x4

def symlog(x):
    """Symmetric logarithm: ``sign(x) * log(1 + |x|)``, applied element-wise.

    Compresses large magnitudes while keeping the sign; ``symlog(0) == 0``.

    Args:
        x: input tensor.

    Returns:
        Tensor of the same shape as ``x``.
    """
    # log1p avoids the precision loss of forming ``1 + |x|`` when |x| is tiny
    # (in float32, 1 + 1e-10 rounds to exactly 1 and the log collapses to 0).
    return torch.sign(x) * torch.log1p(torch.abs(x))

def symexp(x):
    """Symmetric exponential: ``sign(x) * (exp(|x|) - 1)``, applied element-wise.

    This is the algebraic inverse of ``sign(x) * log(1 + |x|)``.

    Args:
        x: input tensor.

    Returns:
        Tensor of the same shape as ``x``.
    """
    # expm1 avoids the cancellation in ``exp(|x|) - 1`` when |x| is tiny
    # (in float32, exp(1e-10) rounds to exactly 1 and the difference becomes 0).
    return torch.sign(x) * torch.expm1(torch.abs(x))

class Txt2VidDiffusion(nn.Module):
    """Shared encoding and reward helpers for text-to-video diffusion guidance.

    This base class loads nothing itself. Its methods read the attributes
    ``device``, ``pipe``, ``vae``, ``unet``, ``scheduler``, ``encode_prompt``
    and ``image_processor``, which must be assigned on the instance (the
    subclasses in this file do so in ``__init__``).

    NOTE(review): ``encode_simclr`` and ``get_simclr_alignment`` also need
    ``self.simclr_model``. The code that used to create it is disabled in
    ``__init__``, so it must be attached externally before those are called.
    """

    def __init__(self):
        super().__init__()
        # SimCLR encoder loading is disabled; the original lines are kept for
        # reference on how `self.simclr_model` was built.
        #simclr_model = resnet50x4()
        #sd = torch.load('src/dreamfusion/guidance/resnet50-4x.pth', map_location='cpu')
        #simclr_model.load_state_dict(sd['state_dict'])
        #self.simclr_model = torch.nn.DataParallel(simclr_model).to('cuda')
        #self.simclr_model.eval()

    # ------------------------------------------------------------------
    # Private helpers
    # ------------------------------------------------------------------
    def _sample_timestep(self, noise_level):
        """Draw one integer timestep uniformly from [noise_level, noise_level + 100)."""
        return torch.randint(noise_level, noise_level + 100, [1], dtype=torch.long, device=self.device)

    def _predict_noise(self, batch_input, t, text_embeddings):
        """Run the UNet on ``batch_input`` duplicated along dim 0 and return ``.sample``.

        The duplication pairs with ``text_embeddings`` as built by
        ``get_text_embeds`` (negative embeddings first, then positive), so the
        first half of the output is the unconditional prediction.
        """
        latent_model_input = torch.cat([batch_input] * 2)
        return self.unet(
            latent_model_input,
            t,
            encoder_hidden_states=text_embeddings,
        ).sample

    @staticmethod
    def _sds_terms(noise_pred, source_noise):
        """Return ``(alignment, recon)`` statistics, one value per index of dim 2.

        ``alignment`` is the mean squared gap between the text-conditioned and
        unconditional predictions. ``recon`` is how much smaller the
        text-conditioned error against ``source_noise`` is than the
        unconditional one. Means are taken over dims 0, 1, 3 and 4.
        """
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        reduce_dims = [0, 1, 3, 4]
        alignment = ((noise_pred_text - noise_pred_uncond) ** 2).mean(reduce_dims)
        text_err = ((noise_pred_text - source_noise) ** 2).mean(reduce_dims)
        uncond_err = ((noise_pred_uncond - source_noise) ** 2).mean(reduce_dims)
        return alignment, uncond_err - text_err

    # ------------------------------------------------------------------
    # Encoding
    # ------------------------------------------------------------------
    @torch.no_grad()
    def get_text_embeds(self, prompt, negative_prompt):
        """Encode prompts and return negative and positive embeddings concatenated.

        Args:
            prompt: positive prompt(s), a list of str per the original comment.
            negative_prompt: negative prompt(s), same form.

        Returns:
            ``torch.cat([negative_prompt_embeds, prompt_embeds])`` along dim 0.
        """
        # Positional args 3 and 4 are presumably num_images_per_prompt=1 and
        # do_classifier_free_guidance=True in diffusers' `encode_prompt`
        # -- TODO confirm against the installed diffusers version.
        prompt_embeds, negative_prompt_embeds = self.encode_prompt(
            prompt,
            self.device,
            1,
            True,
            negative_prompt,
        )
        return torch.cat([negative_prompt_embeds, prompt_embeds])

    @torch.no_grad()
    def encode_imgs(self, rgb_input):
        """Encode RGB images into scaled VAE latents.

        Args:
            rgb_input: image batch, [B, 3, H, W] per the original comment.

        Returns:
            A sample from the VAE posterior of the 512x512-resized, preprocessed
            input, multiplied by ``vae.config.scaling_factor``.
        """
        resized = F.interpolate(rgb_input, (512, 512), mode='bilinear', align_corners=False)
        normalized = self.image_processor.preprocess(resized)
        posterior = self.vae.encode(normalized).latent_dist
        return posterior.sample() * self.vae.config.scaling_factor

    @torch.no_grad()
    def encode_simclr(self, rgb_input):
        """Resize ``rgb_input`` to 256x256 and pass it through ``self.simclr_model``.

        NOTE(review): ``self.simclr_model`` is not assigned anywhere in this
        class; calling this without attaching one raises ``AttributeError``.
        """
        resized = F.interpolate(rgb_input, (256, 256), mode='bilinear', align_corners=False)
        return self.simclr_model(resized)

    @torch.no_grad()
    def tensor2vid(self, video: torch.Tensor):
        """Post-process a decoded 5-D video tensor into 256x256 frames.

        Args:
            video: 5-D tensor unpacked as (batch, channels, frames, height, width).

        Returns:
            Stack over the batch of per-clip outputs. Each clip is moved to
            frames-first layout, run through ``image_processor.postprocess``
            with output type ``'pt'`` and resized to 256x256.
        """
        # The unpacking also enforces that `video` is exactly 5-D.
        batch_size, _channels, _num_frames, _height, _width = video.shape
        clips = []
        for batch_idx in range(batch_size):
            frames_first = video[batch_idx].permute(1, 0, 2, 3)
            frames = self.image_processor.postprocess(frames_first, 'pt')
            clips.append(F.interpolate(frames, (256, 256), mode='bilinear', align_corners=False))
        return torch.stack(clips)

    # ------------------------------------------------------------------
    # Rewards
    # ------------------------------------------------------------------
    @torch.no_grad()
    def get_simclr_alignment(self, text_embeddings, latents, simclr_latents, guidance_scale=2.5, noise_level=400, noise=None, length=150):
        """Reward frames by SimCLR similarity of their one-step denoised decode.

        The latents are noised at a random timestep, denoised with a single
        guided scheduler step in windows of 16 along dim 0, decoded to frames,
        encoded with ``encode_simclr`` and compared with ``simclr_latents``.

        Args:
            text_embeddings: conditioning passed to the UNet (negative first,
                then positive, as produced by ``get_text_embeds``).
            latents: 4-D latents indexed along dim 0 (presumably
                [frames, channels, H, W] -- TODO confirm with caller).
            simclr_latents: 2-D reference embeddings, one row per frame.
            guidance_scale: classifier-free guidance weight.
            noise_level: lower bound of the sampled timestep; the upper bound
                is ``noise_level + 100`` (exclusive).
            noise: optional noise tensor shaped like the truncated latents;
                sampled with ``torch.randn_like`` when ``None``.
            length: only the first ``length`` entries of ``latents`` are used.

        Returns:
            1-D tensor: the diagonal of the cosine-similarity matrix between
            the rows of ``simclr_latents`` and the predicted-frame embeddings.
        """
        latents = latents[:length]

        # Sample the timestep before the noise so RNG consumption order is unchanged.
        t = self._sample_timestep(noise_level)
        if noise is None:
            noise = torch.randn_like(latents)

        latents_noisy = self.scheduler.add_noise(latents, noise, t)

        # Channels-first buffer on the same device as the latents. A hard-coded
        # 'cuda' here would fail on CPU or on a non-default GPU.
        latent_preds = torch.zeros_like(latents.permute(1, 0, 2, 3))

        window = 16
        num_frames = latents_noisy.shape[0]
        for start in range(0, num_frames, window):
            stop = start + window  # slicing clamps the final, possibly short, window
            batch_input = latents_noisy[start:stop].permute(1, 0, 2, 3).unsqueeze(0)

            noise_pred_uncond, noise_pred_text = self._predict_noise(batch_input, t, text_embeddings).chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # Reset the scheduler's step bookkeeping on every window so each
            # call is treated as the only step of a one-step schedule.
            scheduler = self.pipe.scheduler
            scheduler.num_inference_steps = 1
            scheduler._step_index = 0
            prev_sample = scheduler.step(noise_pred, t, batch_input).prev_sample
            latent_preds[:, start:stop] = prev_sample[0]

        video_tensor = self.pipe.decode_latents(latent_preds.unsqueeze(0))
        vids = self.tensor2vid(video_tensor)
        pred_output = self.encode_simclr(vids[0])

        # Full pairwise cosine similarity; only the diagonal is returned.
        # TODO: make the simclr_sim diagonal instead averaged over overlapping chunks of 16.
        dot_products = torch.mm(simclr_latents, pred_output.T)
        norm_products = torch.mm(
            torch.norm(simclr_latents, dim=1, keepdim=True),
            torch.norm(pred_output, dim=1, keepdim=True).T,
        )
        return torch.diag(dot_products / norm_products)

    @torch.no_grad()
    def get_sds_alignment(self, text_embeddings, latents, noise_level=400, alignment_scale=2000, recon_scale=200, noise=None, length=150):
        """Per-frame SDS-style reward computed in non-overlapping windows of 2.

        Args:
            text_embeddings: conditioning passed to the UNet (negative first,
                then positive, as produced by ``get_text_embeds``).
            latents: 4-D latents indexed along dim 0 (presumably
                [frames, channels, H, W] -- TODO confirm with caller).
            noise_level: lower bound of the sampled timestep; the upper bound
                is ``noise_level + 100`` (exclusive).
            alignment_scale: multiplier applied to the alignment term before symlog.
            recon_scale: multiplier applied to the reconstruction term before symlog.
            noise: optional noise tensor shaped like the truncated latents;
                sampled with ``torch.randn_like`` when ``None``.
            length: only the first ``length`` entries of ``latents`` are used.

        Returns:
            1-D float32 tensor with one reward per retained entry of ``latents``:
            ``symlog(alignment_scale * alignment) + symlog(recon_scale * recon)``.

        NOTE(review): only full windows are evaluated. With an odd number of
        frames the last frame is never passed to the UNet and keeps reward 0.
        """
        latents = latents[:length]

        # Sample the timestep before the noise so RNG consumption order is unchanged.
        t = self._sample_timestep(noise_level)
        if noise is None:
            noise = torch.randn_like(latents)

        latents_noisy = self.scheduler.add_noise(latents, noise, t)

        # Accumulators live on the latents' device. A hard-coded 'cuda' here
        # would fail on CPU or on a non-default GPU.
        num_frames = latents.shape[0]
        alignment_preds = torch.zeros(num_frames, device=latents.device)
        recon_preds = torch.zeros(num_frames, device=latents.device)

        window = 2
        for index in range(latents_noisy.shape[0] // window):
            start = index * window
            stop = start + window
            batch_input = latents_noisy[start:stop].permute(1, 0, 2, 3).unsqueeze(0)
            source_noise = noise[start:stop].permute(1, 0, 2, 3).unsqueeze(0)

            noise_pred = self._predict_noise(batch_input, t, text_embeddings)
            alignment_pred, recon_pred = self._sds_terms(noise_pred, source_noise)

            # Windows do not overlap, so each slot is written exactly once.
            alignment_preds[start:stop] = alignment_pred
            recon_preds[start:stop] = recon_pred

        return symlog(alignment_scale * alignment_preds) + symlog(recon_scale * recon_preds)

    @torch.no_grad()
    def get_loop_sds_alignment(self, text_embeddings, latents, t, alignment_scale=2000, recon_scale=200, noise=None):
        """SDS-style reward for all of ``latents`` in one UNet call at timestep ``t``.

        Unlike ``get_sds_alignment`` this neither truncates nor windows the
        input, and the caller supplies the timestep.

        Args:
            text_embeddings: conditioning passed to the UNet (negative first,
                then positive, as produced by ``get_text_embeds``).
            latents: 4-D latents indexed along dim 0 (presumably
                [frames, channels, H, W] -- TODO confirm with caller).
            t: timestep passed to both ``scheduler.add_noise`` and the UNet.
            alignment_scale: multiplier applied to the alignment term before symlog.
            recon_scale: multiplier applied to the reconstruction term before symlog.
            noise: optional noise tensor shaped like ``latents``; sampled with
                ``torch.randn_like`` when ``None``.

        Returns:
            1-D tensor with one reward per entry along dim 0 of ``latents``.
        """
        if noise is None:
            noise = torch.randn_like(latents)

        latents_noisy = self.scheduler.add_noise(latents, noise, t)

        batch_input = latents_noisy.permute(1, 0, 2, 3).unsqueeze(0)
        source_noise = noise.permute(1, 0, 2, 3).unsqueeze(0)

        noise_pred = self._predict_noise(batch_input, t, text_embeddings)
        alignment_pred, recon_pred = self._sds_terms(noise_pred, source_noise)

        return symlog(alignment_scale * alignment_pred) + symlog(recon_scale * recon_pred)



class VideoAnimateLCM(Txt2VidDiffusion):
    """Txt2VidDiffusion backed by AnimateLCM on the epiCRealism checkpoint.

    Builds an ``AnimateDiffPipeline`` with the ``wangfuyun/AnimateLCM`` motion
    adapter, an ``LCMScheduler`` and the AnimateLCM LoRA (adapter weight 0.8),
    then exposes the pipeline parts that the inherited helpers read:
    ``pipe``, ``vae``, ``encode_prompt``, ``unet``, ``scheduler`` and
    ``image_processor``.
    """

    def __init__(self, device, fp16=True, t_range=(0.02, 0.98)):
        """Load the pipeline and move it to ``device``.

        Args:
            device: target device passed to ``pipe.to`` and stored as ``self.device``.
            fp16: only selects ``self.precision_t``.
                NOTE(review): the weights below are always loaded with
                ``torch.float16`` regardless of this flag.
            t_range: accepted for interface compatibility; not read here.
                The default is a tuple so it is not a shared mutable object.
        """
        super().__init__()

        self.device = device

        print('[INFO] loading video diffusion...')

        self.precision_t = torch.float16 if fp16 else torch.float32

        # Create model
        adapter = MotionAdapter.from_pretrained("wangfuyun/AnimateLCM", torch_dtype=torch.float16)
        pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16)
        # probably not needed
        pipe.scheduler = LCMScheduler.from_config(pipe.scheduler.config, beta_schedule="linear")

        pipe.load_lora_weights("wangfuyun/AnimateLCM", weight_name="AnimateLCM_sd15_t2v_lora.safetensors", adapter_name="lcm-lora")
        pipe.set_adapters(["lcm-lora"], [0.8])

        pipe.enable_vae_slicing()
        # NOTE(review): the pipeline is moved to `device` right after CPU
        # offload is enabled; diffusers advises against combining the two.
        # Left as-is because the inherited helpers call `unet` / `vae` directly.
        pipe.enable_model_cpu_offload()
        pipe.to(device)

        self.pipe = pipe

        # Components used directly by the inherited Txt2VidDiffusion methods.
        self.vae = pipe.vae
        self.encode_prompt = pipe.encode_prompt
        #self.tokenizer = pipe.tokenizer
        #self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet
        self.scheduler = pipe.scheduler

        self.image_processor = pipe.image_processor

        print('[INFO] loaded video diffusion!')


class VideoAnimateDiffusion(Txt2VidDiffusion):
    """Txt2VidDiffusion backed by AnimateDiff v1.5.2 on Realistic Vision V5.1.

    Builds an ``AnimateDiffPipeline`` with the
    ``guoyww/animatediff-motion-adapter-v1-5-2`` motion adapter and a
    ``DDIMScheduler``, then exposes the pipeline parts that the inherited
    helpers read: ``pipe``, ``vae``, ``encode_prompt``, ``unet``,
    ``scheduler`` and ``image_processor``.
    """

    def __init__(self, device, fp16=True, t_range=(0.02, 0.98)):
        """Load the pipeline and move it to ``device``.

        Args:
            device: target device passed to ``pipe.to`` and stored as ``self.device``.
            fp16: only selects ``self.precision_t``.
                NOTE(review): the weights below are always loaded with
                ``torch.float16`` regardless of this flag.
            t_range: accepted for interface compatibility; not read here.
                The default is a tuple so it is not a shared mutable object.
        """
        super().__init__()

        self.device = device

        print('[INFO] loading video animate diffusion...')

        self.precision_t = torch.float16 if fp16 else torch.float32

        adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
        # load SD 1.5 based finetuned model
        model_id = "SG161222/Realistic_Vision_V5.1_noVAE"
        pipe = AnimateDiffPipeline.from_pretrained(model_id, motion_adapter=adapter, torch_dtype=torch.float16)
        pipe.scheduler = DDIMScheduler.from_pretrained(
            model_id,
            subfolder="scheduler",
            clip_sample=False,
            timestep_spacing="linspace",
            beta_schedule="linear",
            steps_offset=1,
        )

        # enable memory savings
        pipe.enable_vae_slicing()
        #pipe.enable_model_cpu_offload()
        pipe.to(device)

        self.pipe = pipe

        # Components used directly by the inherited Txt2VidDiffusion methods.
        self.vae = pipe.vae
        self.encode_prompt = pipe.encode_prompt
        #self.tokenizer = pipe.tokenizer
        #self.text_encoder = pipe.text_encoder
        self.unet = pipe.unet
        self.scheduler = pipe.scheduler

        self.image_processor = pipe.image_processor

        print('[INFO] loaded video animate diffusion!')
