import torch
import torch.nn.functional as F


class BaseFlow:
    """Base class for time-parameterised warps ("flows") on an N x N grid.

    Subclasses provide the sampling grid via :meth:`get_flow` and, when
    :meth:`warp` is to be called without explicit inputs, the defaults via
    :meth:`get_default_image` / :meth:`get_default_framesteps`.
    :meth:`get_default_prompt` additionally expects subclasses to set
    ``self.pos_prompt`` (the base class never assigns it).
    """

    def __init__(self, N: int):
        """Store the grid size and precompute integer pixel coordinates.

        Args:
            N: Side length of the square grid.
        """
        self.N = N
        # With indexing="xy", X[i, j] == j and Y[i, j] == i, so XY[i, j] == (j, i):
        # the last axis holds (column, row) integer coordinates.
        X, Y = torch.meshgrid(torch.arange(self.N), torch.arange(self.N), indexing="xy")
        self.XY = torch.stack([X, Y], dim=-1)  # shape (N, N, 2)
        self.neg_prompt = "poorly drawn,cartoon, 2d, disfigured, bad art, deformed, poorly drawn, extra limbs, close up, b&w, weird colors, blurry"

    def get_default_prompt(self) -> "tuple[str, str]":
        """Return the ``(positive, negative)`` prompt pair.

        Raises:
            AttributeError: if the subclass has not set ``self.pos_prompt``.
        """
        return (self.pos_prompt, self.neg_prompt)

    def get_default_image(self) -> torch.Tensor:
        """Return the image :meth:`warp` uses when none is given. Must be overridden."""
        raise NotImplementedError

    def get_default_framesteps(self) -> torch.Tensor:
        """Return the framesteps :meth:`warp` uses when none are given. Must be overridden."""
        raise NotImplementedError

    def get_flow(self, t) -> torch.Tensor:
        """Return the sampling grid for time ``t``. Must be overridden.

        The result is passed to ``F.grid_sample`` after ``unsqueeze(0)``, so for
        2-D warping it should have shape ``(H_out, W_out, 2)`` holding
        normalized sampling coordinates (grid_sample's convention is [-1, 1]).
        """
        raise NotImplementedError

    def warp(self, should_apply_to_last_frame=False, **kwargs):
        """Warp an image once per framestep and stack the resulting frames.

        Args:
            should_apply_to_last_frame: If True, each step warps the previous
                step's output, so warps compound; otherwise every step warps
                the original image.
            **kwargs: Optional ``framesteps`` (iterable of times passed to
                :meth:`get_flow`) and ``image`` (a tensor or array-like that
                gets a batch axis via ``unsqueeze(0)`` — typically
                ``(C, H, W)``). Each default is computed only when its key is
                absent. All remaining kwargs are forwarded to
                ``F.grid_sample`` (e.g. ``mode``, ``padding_mode``,
                ``align_corners``).

        Returns:
            Tensor stacking one warped frame per framestep along dim 0. When
            ``framesteps`` is empty, an empty tensor of shape
            ``(0, *image.shape)`` is returned.
        """
        # Only build defaults when the caller did not supply a value: the base
        # implementations raise NotImplementedError, and subclass defaults may
        # be expensive to construct.
        if "framesteps" in kwargs:
            framesteps = kwargs.pop("framesteps")
        else:
            framesteps = self.get_default_framesteps()
        if "image" in kwargs:
            image = kwargs.pop("image")
        else:
            image = self.get_default_image()

        if not isinstance(framesteps, torch.Tensor):
            framesteps = torch.tensor(framesteps)
        if not isinstance(image, torch.Tensor):
            image = torch.tensor(image)

        frames = []
        current = image.clone()

        for time in framesteps:
            # grid_sample requires the grid to share the input's dtype and device.
            flow = self.get_flow(time).to(device=image.device, dtype=image.dtype)
            base = current if should_apply_to_last_frame else image
            current = F.grid_sample(base.unsqueeze(0), flow.unsqueeze(0), **kwargs).squeeze(0)
            frames.append(current)

        if not frames:
            # torch.stack rejects an empty list; return an empty, image-shaped batch.
            return image.new_empty((0, *image.shape))
        return torch.stack(frames)

    def get_spatial_eta(self, t):
        """Return the spatial eta for time ``t``; the base implementation is always 0.0."""
        return 0.0

    def warp_latent_and_correct(
        self,
        t,
        original_frame,
        alphabar_tau,
        previous_frame,
        mode="nearest",
        padding_mode="reflection",
    ):
        """Warp ``previous_frame`` with the flow at time ``t``.

        Despite the name, this base implementation applies no correction; it
        only warps.

        Args:
            t: Time passed to :meth:`get_flow`.
            original_frame: Not used by this implementation.
            alphabar_tau: Not used by this implementation.
            previous_frame: Batched input for ``F.grid_sample`` — for 2-D
                flows ``(B, C, H, W)``. The same flow is applied to every
                batch element.
            mode: Interpolation mode forwarded to ``F.grid_sample``.
            padding_mode: Padding mode forwarded to ``F.grid_sample``.

        Returns:
            The warped tensor, sampled with ``align_corners=True``.
        """
        # grid_sample requires the grid to share the input's dtype and device.
        flow = self.get_flow(t).to(device=previous_frame.device, dtype=previous_frame.dtype)
        # grid_sample also requires grid and input batch sizes to match, so
        # broadcast the single flow across the batch (a view, no copy).
        grid = flow.unsqueeze(0).expand(previous_frame.shape[0], *flow.shape)
        warped_image = F.grid_sample(
            previous_frame,
            grid,
            align_corners=True,
            mode=mode,
            padding_mode=padding_mode,
        )
        return warped_image
