# %%
import argparse
import os
import sys

# os.chdir("../")


import accelerate
import auraloss  # freq loss
import lovely_tensors as lt
import matplotlib.pyplot as plt
import matplotlib
import numpy as np
import scipy.io as io
import seaborn as sns
import torch
import torch.nn as nn
import wandb
import yaml
from diffusers.optimization import get_scheduler
from omegaconf import OmegaConf
from scipy.signal import welch
from tqdm.auto import tqdm
from einops import rearrange

from ntldm.networks import AutoEncoder, CountWrapper
from ntldm.utils.plotting_utils import *
from ntldm.losses import latent_regularizer
from ntldm.networks import Denoiser
from diffusers.training_utils import EMAModel
from diffusers.schedulers import DDPMScheduler

# Always run from the ../ntldm directory: the conf/, exp/ and matplotlibrc paths below are relative to it.


# Enable lovely_tensors' tensor pretty-printing.
lt.monkey_patch()
# Plotting defaults are read from a "matplotlibrc" file in the working directory.
matplotlib.rc_file("matplotlibrc")

# %%
## load config and model path

# Config of the pretrained autoencoder. Its model/dataset sections and exp_name
# are used below to rebuild the model and locate its checkpoint.
cfg_ae = OmegaConf.load(
    "conf/sweeps_new/Phoneme_autoencoder-count_s4-phoneme_v2loss.yaml"
)


# Inline config for the latent denoiser and its training run.
# NOTE(review): denoiser_model.C_in is used as the latent channel count when
# sampling; presumably it must equal cfg_ae.model.C_latent -- TODO confirm.
cfg_yaml = """
denoiser_model:
  C_in: 32
  C: 384
  kernel: s4
  num_blocks: 8
  bidirectional: True
  num_train_timesteps: 1000
training:
  lr: 0.001
  num_epochs: 2000
  num_warmup_epochs: 100
  batch_size: 256
  random_seed: 42
  precision: "no"
exp_name: diffusion_s4-phoneme_newbiggestlessema
"""

# Parse the inline YAML and reuse the autoencoder's dataset section, so both
# models share settings such as dataset.max_seqlen.
cfg = OmegaConf.create(yaml.safe_load(cfg_yaml))
cfg.dataset = cfg_ae.dataset


# %%


import math
from ntldm.data.phoneme import get_phoneme_dataloaders

# set seed
torch.manual_seed(cfg.training.random_seed)
np.random.seed(cfg.training.random_seed)

# Spike dataloaders for all three splits. NOTE: they use the autoencoder's batch
# size (cfg_ae.training.batch_size), not the diffusion batch size.
train_dataloader, val_dataloader, test_dataloader = get_phoneme_dataloaders(
    cfg_ae.dataset.datapath, batch_size=cfg_ae.training.batch_size
)

# %%
# Rebuild the autoencoder with the architecture from its config so the saved
# weights can be loaded into it; optional keys fall back to the defaults below.
ae_model = AutoEncoder(
    C_in=cfg_ae.model.C_in,
    C=cfg_ae.model.C,
    C_latent=cfg_ae.model.C_latent,
    L=cfg_ae.dataset.max_seqlen,
    kernel=cfg_ae.model.kernel,
    num_blocks=cfg_ae.model.num_blocks,
    num_blocks_decoder=cfg_ae.model.get("num_blocks_decoder", cfg_ae.model.num_blocks),
    num_lin_per_mlp=cfg_ae.model.get("num_lin_per_mlp", 2),  # default 2
    bidirectional=cfg_ae.model.get("bidirectional", False),
)

ae_model = CountWrapper(ae_model, use_sin_enc=cfg_ae.model.get("use_sin_enc", False))


# One accelerator for the whole script (the autoencoder here, the denoiser
# later). Mixed precision is disabled regardless of cfg.training.precision.
accelerator = accelerate.Accelerator(
    mixed_precision="no",
    log_with="wandb",
)

# prepare the ae model and dataset

ae_model = accelerator.prepare(ae_model)

print(cfg_ae.exp_name)
# Restore the saved accelerator state (expected to contain the autoencoder weights).
accelerator.load_state(f"exp/{cfg_ae.exp_name}/epoch_400")  # best checkpoint

(
    train_dataloader,
    val_dataloader,
    test_dataloader,
) = accelerator.prepare(
    train_dataloader,
    val_dataloader,
    test_dataloader,
)


def reconstruct_spikes(model, dataloader):
    """Pass every batch of ``dataloader`` through ``model`` and gather CPU results.

    The model is put in eval mode and called as ``model(batch["signal"])``,
    which must return ``(rates, latents)``. Reconstructed spikes are one Poisson
    draw per predicted rate, zeroed wherever ``batch["mask"]`` is zero.

    Returns:
        dict with keys ``"latents"``, ``"spikes"``, ``"rec_spikes"`` and
        ``"signal_masks"``; each value is the per-batch tensors concatenated
        along dim 0.
    """
    model.eval()
    gathered = {"latents": [], "spikes": [], "rec_spikes": [], "signal_masks": []}

    for batch in dataloader:
        observed = batch["signal"]
        valid = batch["mask"]
        with torch.no_grad():
            rates, latent = model(observed)
        valid_cpu = valid.cpu()
        gathered["latents"].append(latent.cpu())
        gathered["spikes"].append(observed.cpu())
        # stochastic reconstruction; padded positions are forced to zero
        gathered["rec_spikes"].append(torch.poisson(rates.cpu()) * valid_cpu)
        gathered["signal_masks"].append(valid_cpu)

    return {name: torch.cat(parts, 0) for name, parts in gathered.items()}


# Reconstruct the test set with the autoencoder (Poisson samples of the
# predicted rates, masked to the valid region).
rec_dict = reconstruct_spikes(ae_model, test_dataloader)


# plot reconstructed spikes
# Counts summed over axis 2 (presumably time, giving one total per trial and
# channel -- TODO confirm layout): ground truth vs. AE reconstruction.
plt.figure(figsize=cm2inch((6, 4)))
bins = np.linspace(0, 1000, 100)
plt.hist(
    (rec_dict["spikes"] * rec_dict["signal_masks"]).sum(2).flatten(),
    density=True,
    color="grey",
    bins=bins,
    alpha=0.5,
)
plt.hist(
    (rec_dict["rec_spikes"] * rec_dict["signal_masks"]).sum(2).flatten(),
    density=True,
    color="darkblue",
    bins=bins,
    alpha=0.5,
)

plt.legend(["gt", "ae"])
plt.title("spike count distribution (test set)")

# %%

# plot reconstructed spikes
# Same comparison, but summed over axis 1 (presumably channels, giving a
# population count per trial and time bin). The bins start at 1, so values
# below 1 (e.g. padded, fully masked bins) fall outside the histogram.
# NOTE: the title is identical to the previous figure's although the quantity differs.
plt.figure(figsize=cm2inch((6, 4)))
bins = np.linspace(1, 200, 199)
counts, bins, patches = plt.hist(
    (rec_dict["spikes"] * rec_dict["signal_masks"]).sum(1).flatten(),
    density=True,
    color="grey",
    bins=bins,
    alpha=0.5,
)
plt.hist(
    (rec_dict["rec_spikes"] * rec_dict["signal_masks"]).sum(1).flatten(),
    density=True,
    color="darkblue",
    bins=bins,
    alpha=0.5,
)
plt.xlim(20, 150)
plt.yticks([])
plt.legend(["gt", "ae"])
plt.gca().spines["left"].set_visible(False)
plt.title("spike count distribution (test set)")


# %%
# create the latent dataset
class LatentPhonemeDataset(torch.utils.data.Dataset):
    """Autoencoder latents of a phoneme spike dataset, standardized per channel.

    On construction every batch of ``dataloader`` is encoded with
    ``ae_model.encode`` and the latents, spikes, masks, embeddings and sentences
    are cached on the CPU. Latents are then standardized channel-wise with a
    mean and standard deviation computed over valid (masked) time steps only.

    Args:
        dataloader: iterable of dicts with keys ``"signal"``, ``"mask"``,
            ``"embedding"``, ``"embedding_mask"``, ``"original_sentence"`` and
            ``"phonemized_sentence"``.
        ae_model: model exposing ``eval()`` and ``encode(signal)``.
        clip: if True, clamp the standardized latents to ``[-5, 5]``.
        latent_means, latent_stds: optional precomputed statistics (e.g. from the
            training split), matching the computed ones' shape ``(1, C_latent, 1)``.
            They are only used when BOTH are given; otherwise new statistics are
            computed from this dataset.
    """

    def __init__(
        self, dataloader, ae_model, clip=True, latent_means=None, latent_stds=None
    ):
        self.full_dataloader = dataloader
        self.ae_model = ae_model
        (
            self.latents,
            self.train_spikes,
            self.train_spike_masks,
            self.embeddings,
            self.embedding_masks,
            self.original_sentences,
            self.phonemized_sentences,
        ) = self.create_latents()

        # normalize to N(0, 1)
        if latent_means is None or latent_stds is None:
            print(self.latents.shape, self.train_spike_masks.shape)
            # The first C_latent channels of the spike mask serve as the latent
            # mask (assumes the mask is identical across channels -- TODO confirm).
            mask = self.train_spike_masks[:, : self.latents.shape[1], :]
            masked_count = mask.sum(dim=(0, 2))
            latent_means = (self.latents * mask).sum(dim=(0, 2)) / masked_count

            # Two-pass masked variance: averaging squared deviations from the
            # mean is non-negative by construction. The one-pass form
            # E[x^2] - E[x]^2 cancels catastrophically when |mean| >> std and can
            # even turn negative, which would yield NaN standard deviations.
            centered = (self.latents - latent_means[None, :, None]) * mask
            masked_variance = centered.pow(2).sum(dim=(0, 2)) / masked_count

            self.latent_means = latent_means.unsqueeze(0).unsqueeze(2)
            self.latent_stds = masked_variance.sqrt().unsqueeze(0).unsqueeze(2)
        else:
            self.latent_means = latent_means
            self.latent_stds = latent_stds

        # normalize latents channel-wise to N(0, 1)
        self.latents = (self.latents - self.latent_means) / self.latent_stds

        if clip:
            self.latents = self.latents.clamp(-5, 5)

    def symmetrically_pad_and_expand_embedding(
        self, embedding, embedding_mask, latent_mask
    ):
        """Centre a left-padded embedding and stretch it to the latent length.

        NOTE: due to a bug in the phoneme dataset, the embeddings
        from the original dataset are padded on the left.

        The valid part (the last ``embedding_len`` steps, where ``embedding_len``
        is the number of ones in the first mask channel) is re-padded
        symmetrically by edge replication, then resized along time with
        nearest-neighbour interpolation to ``latent_mask.shape[-1]`` steps.

        Args:
            embedding: tensor of shape (L, C).
            embedding_mask: tensor of shape (L, C); only its first channel is read.
            latent_mask: tensor whose last dimension gives the target length Ll.

        Returns:
            (embedding, embedding_mask) with shapes (C, Ll) and (1, Ll).
        """
        L, C = embedding.shape

        embedding = embedding.permute(1, 0)  # [C, L]
        embedding_mask = embedding_mask.permute(1, 0)  # [C, L]

        # number of valid (unpadded) embedding steps
        embedding_len = int(embedding_mask[0].sum(0).item())

        pad_left = (L - embedding_len) // 2
        pad_right = L - embedding_len - pad_left

        # valid steps sit at the end (left padding); move them to the centre
        embedding = torch.nn.functional.pad(
            embedding[:, L - embedding_len :], (pad_left, pad_right), mode="replicate"
        )
        embedding_mask = torch.zeros_like(embedding[:1])  # [1, L]
        embedding_mask[:, pad_left : pad_left + embedding_len] = 1

        latent_max_len = latent_mask.shape[-1]  # Ll

        # nearest-neighbour resize along time: [C, L] -> [1, C, L] -> [1, C, Ll] -> [C, Ll]
        embedding = torch.nn.functional.interpolate(
            embedding.unsqueeze(0), (latent_max_len), mode="nearest"
        ).squeeze(0)

        # same resize for the mask: [1, L] -> [1, 1, L] -> [1, 1, Ll] -> [1, Ll]
        embedding_mask = torch.nn.functional.interpolate(
            embedding_mask.unsqueeze(0), (latent_max_len), mode="nearest"
        ).squeeze(0)

        return embedding, embedding_mask

    def symmetrically_pad_latent(self, latent, latent_mask):
        """Move the valid prefix of ``latent`` to the centre of the time axis.

        The number of valid steps is the sum of ``latent_mask[0]``; the first
        that-many time steps are kept (i.e. the latent is treated as
        right-padded) and padded on both sides by replicating the edge values.

        Args:
            latent: tensor of shape (C, L).
            latent_mask: mask whose first row flags the valid time steps.

        Returns:
            (latent, mask) with shapes (C, L) and (1, L); the mask is 1 on the
            centred valid region and 0 elsewhere.
        """
        C, L = latent.shape

        latent_len = int(latent_mask[0].sum().item())

        pad_left = (L - latent_len) // 2
        pad_right = L - latent_len - pad_left

        latent = torch.nn.functional.pad(
            latent[:, :latent_len], (pad_left, pad_right), mode="replicate"
        )
        latent_mask = torch.zeros_like(latent[:1])  # [1, L]
        latent_mask[:, pad_left : pad_left + latent_len] = 1

        return latent, latent_mask

    def create_latents(self):
        """Encode every batch of ``self.full_dataloader`` and cache the results.

        Returns:
            7-tuple of latents, spikes, spike masks, embeddings and embedding
            masks (tensors concatenated along dim 0, on the CPU), followed by the
            lists of original and phonemized sentences.
        """
        latent_dataset = []

        train_spikes = []
        train_spike_masks = []

        embeddings = []
        embedding_masks = []

        original_sentences = []
        phonemized_sentences = []

        ## dataset output:
        # def __getitem__(self, idx):
        #     return {
        #         "original_sentence": self.original_sentences[idx],
        #         "phonemized_sentence": self.phonemized_sentences[idx],
        #         "embedding": self.embeddings[idx],
        #         "embedding_mask": self.embedding_masks[idx],
        #         "signal": (self.spikes[idx].T if self.time_last else self.spikes[idx]),
        #         "mask": (self.masks[idx].T if self.time_last else self.masks[idx]),
        #     }

        self.ae_model.eval()
        for i, batch in tqdm(
            enumerate(self.full_dataloader),
            total=len(self.full_dataloader),
            desc="Creating latent dataset",
        ):
            with torch.no_grad():
                z = self.ae_model.encode(batch["signal"])
                latent_dataset.append(z.cpu())

                train_spikes.append(batch["signal"].cpu())
                train_spike_masks.append(batch["mask"].cpu())

                embeddings.append(batch["embedding"].cpu())
                embedding_masks.append(batch["embedding_mask"].cpu())

                original_sentences.extend(batch["original_sentence"])
                phonemized_sentences.extend(batch["phonemized_sentence"])

        return (
            torch.cat(latent_dataset),
            torch.cat(train_spikes),
            torch.cat(train_spike_masks),
            torch.cat(embeddings),
            torch.cat(embedding_masks),
            original_sentences,
            phonemized_sentences,
        )

    def __len__(self):
        """Number of cached trials."""
        return len(self.latents)

    def __getitem__(self, idx):
        """Return one trial with latent and embedding centred along time.

        ``"signal"`` is the raw (not re-padded) spike tensor; ``"mask"`` is the
        centred latent validity mask.
        """
        embedding, embedding_mask = self.symmetrically_pad_and_expand_embedding(
            self.embeddings[idx], self.embedding_masks[idx], self.train_spike_masks[idx]
        )
        latent, latent_mask = self.symmetrically_pad_latent(
            self.latents[idx], self.train_spike_masks[idx]
        )

        return {
            "signal": self.train_spikes[idx],  # not symmetrically padded
            "latent": latent,
            "mask": latent_mask,
            "embedding": embedding,
            "embedding_mask": embedding_mask,
            "original_sentence": self.original_sentences[idx],
            "phonemized_sentence": self.phonemized_sentences[idx],
        }


# Training-split latents; their masked per-channel statistics define the
# normalization used for every split (no clipping).
latent_dataset_train = LatentPhonemeDataset(train_dataloader, ae_model, clip=False)


# Validation/test latents are standardized with the *training* statistics so
# that all splits share the same normalized space.
latent_dataset_val = LatentPhonemeDataset(
    val_dataloader,
    ae_model,
    latent_means=latent_dataset_train.latent_means,
    latent_stds=latent_dataset_train.latent_stds,
    clip=False,
)
latent_dataset_test = LatentPhonemeDataset(
    test_dataloader,
    ae_model,
    latent_means=latent_dataset_train.latent_means,
    latent_stds=latent_dataset_train.latent_stds,
    clip=False,
)
# %%

element = latent_dataset_train[0]

# Visual sanity check of one training item: first latent channel and the
# centred validity mask returned by the dataset.
plt.plot(element["latent"][0])
plt.plot(element["mask"][0])
# %%

print("latent dataset", latent_dataset_train.latents)
print("latent dataset means", latent_dataset_train.latent_means)
print("latent dataset stds", latent_dataset_train.latent_stds)
plt.figure(figsize=cm2inch(5, 3))
# Normalized latent values of the first 10 items per split (padded time steps
# included). Each histogram carries a label so plt.legend() has entries;
# without labels matplotlib draws no legend and only emits a warning.
hist = plt.hist(
    latent_dataset_train.latents[:10].flatten(),
    bins=200,
    density=True,
    alpha=0.5,
    label="train",
)
hist = plt.hist(
    latent_dataset_val.latents[:10].flatten(),
    bins=200,
    density=True,
    alpha=0.5,
    label="val",
)
plt.legend()
plt.title("Latent dataset histogram")
plt.show()
# %%
# Latent dataloaders for the denoiser (batch size from the diffusion config).
train_latent_dataloader = torch.utils.data.DataLoader(
    latent_dataset_train,
    batch_size=cfg.training.batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

val_latent_dataloader = torch.utils.data.DataLoader(
    latent_dataset_val,
    batch_size=cfg.training.batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

num_batches = len(train_latent_dataloader)

# check if signal length is power of 2
# For positive n, n & (n - 1) == 0 exactly when n is a power of two. Python's `&`
# binds tighter than `!=`, so this reads as (n & (n - 1)) != 0.
# NOTE(review): cfg.training.precision appears unused in this script (the
# Accelerator was created with mixed_precision="no"), so this has no effect here.
if cfg.dataset.max_seqlen & (cfg.dataset.max_seqlen - 1) != 0:
    cfg.training.precision = "no"  # torch.fft doesnt support half if L!=2^x

# prepare the denoiser model and dataset
(
    train_latent_dataloader,
    val_latent_dataloader,
) = accelerator.prepare(
    train_latent_dataloader,
    val_latent_dataloader,
)

# %%

## initialize (unconditional) denoiser

from ntldm.networks import Denoiser

# One extra input channel carries the validity mask, which is concatenated to
# the latents before every denoiser call.
denoiser = Denoiser(
    C_in=cfg.denoiser_model.C_in + 1,  # 1 for mask
    C=cfg.denoiser_model.C,
    L=cfg.dataset.max_seqlen,
    kernel=cfg.denoiser_model.kernel,
    num_blocks=cfg.denoiser_model.num_blocks,
    bidirectional=cfg.denoiser_model.get("bidirectional", True),
)

# initial values may be way off so better to scale down the output layer
denoiser.conv_out.weight.data = denoiser.conv_out.weight.data * 0.1
denoiser.conv_out.bias.data = denoiser.conv_out.bias.data * 0.1

start_epoch = 0

# # load previous checkpoint
# from safetensors.torch import load_file
# state_dict = load_file(f'exp/{cfg.exp_name}/epoch_2100/model_1.safetensors')
# display(state_dict)
# denoiser.load_state_dict(state_dict)
# start_epoch = 1370
# When True, the EMA below is created with min_value == max_value == 0.99.
finetune = False


# DDPM noise scheduler, used for add_noise in training and step() in sampling.
scheduler = DDPMScheduler(
    num_train_timesteps=cfg.denoiser_model.num_train_timesteps,
    clip_sample=False,
    beta_schedule="linear",  # ddpm doesnt support cosine
)


optimizer = torch.optim.AdamW(
    denoiser.parameters(), lr=cfg.training.lr
)  # default wd=0.01 for now


num_batches = len(train_latent_dataloader)
# Per-step schedule: warmup over num_warmup_epochs epochs, then cosine decay.
# The decay horizon is 1.3x the actual number of training steps, presumably so
# the learning rate has not reached zero by the final epoch.
lr_scheduler = get_scheduler(
    name="cosine",
    optimizer=optimizer,
    num_warmup_steps=num_batches
    * cfg.training.num_warmup_epochs,  # warmup for num_warmup_epochs epochs
    num_training_steps=num_batches
    * cfg.training.num_epochs
    * 1.3,  # stretched total number of steps (a float)
)

# prepare the denoiser, optimizer and lr scheduler
(
    denoiser,
    optimizer,
    lr_scheduler,
) = accelerator.prepare(
    denoiser,
    optimizer,
    lr_scheduler,
)

# Exponential moving average of the denoiser weights; its averaged_model is
# what gets sampled from and saved.
ema_model = EMAModel(
    denoiser, min_value=(0.99 if finetune else 0), max_value=0.99, power=3 / 4
)


# %%


def sample_spikes_with_mask(
    ema_denoiser,
    scheduler,
    ae,
    cfg,
    lengths=None,
    batch_size=1,
    device="cuda",
    latent_means=None,
    latent_stds=None,
):
    """Sample latents with the EMA denoiser and decode them into spikes.

    Each sample gets a centred validity mask of the requested length. The mask
    is appended to the noisy latents as an extra input channel, and the matching
    output channel is discarded. The reverse process runs over
    ``cfg.denoiser_model.num_train_timesteps`` scheduler steps. The latents are
    then un-standardized and decoded with ``ae.decode``. Spikes are Poisson
    samples of the decoded rates.

    Args:
        ema_denoiser: EMA wrapper exposing ``averaged_model``.
        scheduler: diffusion scheduler (``set_timesteps``, ``timesteps``, ``step``).
        ae: autoencoder exposing ``decode``.
        cfg: config providing ``denoiser_model.C_in``,
            ``denoiser_model.num_train_timesteps`` and ``dataset.max_seqlen``.
        lengths: valid length of each sample.
            - ``None``: lengths evenly spread from 100 to ``min(512, max_seqlen)``.
            - ``int``: the same length for every sample.
            - list or tensor: one length per sample.
        batch_size: number of samples.
        device: device used for sampling.
        latent_means, latent_stds: statistics used to undo the latent
            standardization. They default to those of the global
            ``latent_dataset_train``.

    Returns:
        dict with ``"rates"``, ``"spikes"``, ``"latents"`` (un-standardized) and
        ``"masks"`` (shape (batch, 1, max_seqlen)), all on the CPU, plus
        ``"mask_lengths"``.

    Raises:
        ValueError: if the number of lengths differs from ``batch_size`` or any
            length lies outside ``[0, max_seqlen]``.
    """
    max_seqlen = cfg.dataset.max_seqlen
    z_t = torch.randn((batch_size, cfg.denoiser_model.C_in, max_seqlen)).to(device)

    if lengths is None:
        # Cap the upper end at max_seqlen: a fixed 512 would produce negative
        # left padding (and corrupt masks) whenever max_seqlen < 512.
        lengths = (
            torch.linspace(100, min(512, max_seqlen), batch_size).long().to(device)
        )
    elif isinstance(lengths, int):
        lengths = torch.tensor([lengths] * batch_size).to(device)
    elif isinstance(lengths, list):
        lengths = torch.tensor(lengths).long().to(device)

    length_values = [int(length) for length in lengths]
    if len(length_values) != batch_size:
        raise ValueError(f"expected {batch_size} lengths, got {len(length_values)}")
    if any(length < 0 or length > max_seqlen for length in length_values):
        raise ValueError(f"lengths must lie in [0, {max_seqlen}], got {length_values}")

    # Centred validity masks: ones on [pad_left, pad_left + length).
    masks = torch.zeros(batch_size, max_seqlen).to(device)
    for row, length in enumerate(length_values):
        pad_left = (max_seqlen - length) // 2
        masks[row, pad_left : pad_left + length] = 1.0

    masks = masks.unsqueeze(1)  # (batch, 1, max_seqlen)

    ema_denoiser_avg = ema_denoiser.averaged_model
    ema_denoiser_avg.eval()
    scheduler.set_timesteps(cfg.denoiser_model.num_train_timesteps)

    for t in tqdm(scheduler.timesteps, desc="Sampling DDPM (different masks)"):
        with torch.no_grad():
            # mask goes in as an extra channel; its output channel is dropped
            model_output = ema_denoiser_avg(
                torch.cat([z_t, masks], dim=1),
                torch.tensor([t] * batch_size).to(device).long(),
            )[:, :-1]
        z_t = scheduler.step(model_output, t, z_t, return_dict=False)[0]

    # Fall back to the module-level training statistics so existing callers
    # keep working unchanged.
    if latent_means is None:
        latent_means = latent_dataset_train.latent_means
    if latent_stds is None:
        latent_stds = latent_dataset_train.latent_stds

    # undo the channel-wise standardization of the latent dataset
    z_t = z_t * latent_stds.to(z_t.device) + latent_means.to(z_t.device)

    with torch.no_grad():
        rates = ae.decode(z_t).cpu()

    spikes = torch.poisson(rates)

    return {
        "rates": rates,
        "spikes": spikes,
        "latents": z_t.cpu(),
        "masks": masks.cpu(),
        "mask_lengths": lengths,
    }


# %%
def plot_real_vs_sampled_rates_and_spikes(
    real_rates,
    sampled_rates,
    real_spikes,
    sampled_spikes,
    real_masks,
    sampled_masks,
    batch_idx=0,
):
    """Show real vs. sampled rates and spikes of one trial in a 2x2 grid.

    Rows are rates (viridis) and spikes (Greys); columns are real and sampled
    data, each panel with its own colorbar. Only valid time bins are drawn:
    - real trial: the first ``n`` bins, where ``n`` is the number of non-zero
      entries in channel 0 of its mask (real data is treated as right-padded);
    - sampled trial: exactly the non-zero positions of its mask.
    The four plotted array shapes are printed before the figure is shown.
    """
    # unpacking enforces 3-D inputs; the sizes themselves are not needed
    _n_batch, _n_chan, _n_time = real_rates.shape

    fig, axs = plt.subplots(2, 2, figsize=cm2inch(12, 8), dpi=300)

    real_valid = real_masks[batch_idx][0]
    sampled_valid = sampled_masks[batch_idx][0]
    real_cols = torch.arange(real_valid.nonzero().flatten().numel())
    sampled_cols = sampled_valid.nonzero().flatten()

    panels = (
        (axs[0, 0], real_rates[batch_idx][:, real_cols], "viridis", "Real rates"),
        (axs[0, 1], sampled_rates[batch_idx][:, sampled_cols], "viridis", "Sampled rates"),
        (axs[1, 0], real_spikes[batch_idx][:, real_cols], "Greys", "Real spikes"),
        (axs[1, 1], sampled_spikes[batch_idx][:, sampled_cols], "Greys", "Sampled spikes"),
    )
    for ax, image, cmap, title in panels:
        im = ax.imshow(image, cmap=cmap, alpha=1.0, aspect="auto")
        ax.set_title(title)
        fig.colorbar(im, ax=ax, orientation="vertical", fraction=0.046, pad=0.04)

    print(*(image.shape for _, image, _, _ in panels))

    # hide y ticks on the right-hand (sampled) column
    for ax in axs[:, 1]:
        ax.set_yticks([])

    fig.tight_layout()
    plt.show()


def plot_population_spike_count_distribution(
    train_spikes, train_masks, sampled_spikes, sampled_masks, figsize=(6, 4), dpi=300
):
    """Compare population spike-count histograms of real and sampled data.

    For every trial only the time bins where ``masks[i, 0]`` is non-zero are
    kept. All trials are then joined along time and summed over channels, giving
    one population count per valid time bin. Both distributions are drawn as
    density histograms on shared bins, with dashed lines at their means.

    Args:
        train_spikes, sampled_spikes: per-trial spike tensors indexed as
            ``spikes[i][channel, time]``.
        train_masks, sampled_masks: masks indexed as ``masks[i, 0, time]``.
        figsize: figure size passed through ``cm2inch`` (presumably centimetres).
        dpi: figure resolution.

    Returns:
        The matplotlib figure. It is not closed; the caller owns it.
    """

    def strip_padding(spikes, masks):
        # Keep only the time bins flagged valid in channel 0 of each trial's
        # mask; works for both right-padded and centred layouts.
        return [
            spikes[i][:, masks[i, 0].nonzero().flatten()] for i in range(len(spikes))
        ]

    def concatenate_spikes(spikes):
        # join trials along time, then sum over channels
        return torch.cat(spikes, dim=-1).sum(0).cpu().numpy()

    def plot_histogram(train_data, sampled_data, ax):
        sampled_mean = sampled_data.mean()
        sampled_std = sampled_data.std()
        train_mean = train_data.mean()
        train_std = train_data.std()

        # "\\pm" gives mathtext a literal "\pm"; a bare "\p" is an invalid escape
        # sequence (SyntaxWarning since Python 3.12). A raw f-string is not an
        # option because the "\n" must remain a real newline.
        sampled_counts, bins, _ = ax.hist(
            sampled_data,
            color="darkred",
            label=f"diffusion\n{sampled_mean:.1f}$\\pm${sampled_std:.1f}",
            bins=50,
            alpha=0.5,
            density=True,
        )
        train_counts, _, _ = ax.hist(
            train_data,
            color="grey",
            label=f"gt\n{train_mean:.1f}$\\pm${train_std:.1f}",
            bins=bins,
            alpha=0.5,
            density=True,
        )
        # mean markers span the taller of the two histograms
        y_top = max(sampled_counts.max(), train_counts.max())
        ax.vlines([train_mean], [0], [y_top], color="grey", linestyle="--")
        ax.vlines([sampled_mean], [0], [y_top], color="darkred", linestyle="--")
        ax.set_xlim(20, 180)
        ax.set_yticks([])
        ax.spines["left"].set_visible(False)
        ax.legend()

    train_spikes_trimmed = strip_padding(train_spikes, train_masks)
    sampled_spikes_trimmed = strip_padding(sampled_spikes, sampled_masks)

    summed_spikes_train = concatenate_spikes(train_spikes_trimmed)
    summed_spikes_sampled = concatenate_spikes(sampled_spikes_trimmed)

    fig, ax = plt.subplots(1, 1, figsize=cm2inch(*figsize), dpi=dpi)
    plot_histogram(summed_spikes_train, summed_spikes_sampled, ax)
    fig.suptitle("population spike count distribution")

    return fig


# %%

# Smooth-L1 loss with a small transition point; reduction="none" so padded
# positions can be masked out per element below.
loss_fn = torch.nn.SmoothL1Loss(
    beta=0.04, reduction="none"
)  # faster convergence than mse

start_epoch = 0

wandb.init(project="ntldm-phoneme", entity="anon-project")
print("initialized wandb")
pbar = tqdm(range(start_epoch, cfg.training.num_epochs + start_epoch), desc="epochs")
for epoch in pbar:
    for i, batch in enumerate(train_latent_dataloader):
        denoiser.train()
        optimizer.zero_grad()

        z = batch["latent"]
        z_mask = batch["mask"]
        # conditioning inputs are loaded but unused: the denoiser is unconditional
        embedding = batch["embedding"]
        embedding_mask = batch["embedding_mask"]

        # Timesteps live on the latents' device, matching the sampling routine,
        # which also feeds device-resident timesteps to the denoiser.
        t = torch.randint(
            0, cfg.denoiser_model.num_train_timesteps, (z.shape[0],), device=z.device
        ).long()

        noise = torch.randn_like(z)
        noisy_z = scheduler.add_noise(z, noise, t)

        # the validity mask is appended as an extra input channel
        noise_pred = denoiser(torch.cat([noisy_z, z_mask], dim=1), t)
        noise_pred = noise_pred[
            :, :-1
        ]  # remove the dim corresponding to conditioning mask

        loss = loss_fn(noise, noise_pred)
        loss = loss * z_mask  # mask out the padding
        # NOTE: the mean runs over all positions (padding included), so the
        # logged value is scaled by each batch's fraction of valid steps.
        loss = loss.mean()

        accelerator.backward(loss)
        accelerator.clip_grad_norm_(denoiser.parameters(), 1.0)

        optimizer.step()
        lr_scheduler.step()

        if i % 10 == 0:
            metrics = {
                "loss": loss.item(),
                "lr": lr_scheduler.get_last_lr()[0],
                "epoch": epoch,
            }
            pbar.set_postfix(metrics)
            wandb.log(metrics)

        ema_model.step(denoiser)

    do_plot = (
        ((epoch % 50 == 0) and (epoch <= 200))
        or ((epoch % 100 == 0) and (epoch > 200))
        or (epoch == cfg.training.num_epochs - 1)
    )
    if do_plot:  # plot samples

        denoiser.eval()

        # one sample per 30th training trial, with matching valid length
        eval_masks = train_latent_dataloader.dataset.train_spike_masks[::30]
        ret_dict = sample_spikes_with_mask(
            ema_model,
            scheduler,
            ae_model,
            cfg,
            lengths=[int(n.item()) for n in eval_masks[:, 0].sum(-1)],
            batch_size=eval_masks[:, 0].shape[0],
            device=accelerator.device,
        )

        fig = plot_population_spike_count_distribution(
            train_latent_dataloader.dataset.train_spikes[::30],
            eval_masks,
            ret_dict["spikes"],
            ret_dict["masks"],
        )

        wandb.log({"population_spike_count_distribution": wandb.Image(fig)})
        # pyplot keeps every figure alive until closed; release it so periodic
        # plotting does not accumulate figures over the whole run
        plt.close(fig)

        # save
        accelerator.save_state(f"exp/{cfg.exp_name}/epoch_{epoch}")
        torch.save(
            ema_model.averaged_model.state_dict(),
            f"exp/{cfg.exp_name}/epoch_{epoch}/ema_model.pt",
        )

pbar.close()

# %%

# Final check: draw one sample per 20th training trial, each with that trial's
# valid length, and compare population spike-count distributions.
ret_dict = sample_spikes_with_mask(
    ema_model,
    scheduler,
    ae_model,
    cfg,
    lengths=[
        int(i.item())
        for i in train_latent_dataloader.dataset.train_spike_masks[::20, 0].sum(-1)
    ],
    batch_size=train_latent_dataloader.dataset.train_spike_masks[::20, 0].shape[0],
    device="cuda",
)

fig = plot_population_spike_count_distribution(
    train_latent_dataloader.dataset.train_spikes[::20],
    train_latent_dataloader.dataset.train_spike_masks[::20],
    ret_dict["spikes"],
    ret_dict["masks"],
)
# %%
import matplotlib
import numpy as np

import matplotlib.pyplot as plt

# plot only colorbar (horizontal, greys from 0 to 5, discretized)
fig, ax = plt.subplots(1, 1, figsize=cm2inch(4, 0.5), dpi=300)
cmap = plt.cm.Greys
norm = plt.Normalize(vmin=0, vmax=5)
# Boundaries at -0.5, 0.5, ..., 5.5 give one discrete cell per integer count
# 0..5; only the end ticks (0 and 5) are labelled.
cb1 = matplotlib.colorbar.ColorbarBase(
    ax,
    cmap=cmap,
    norm=norm,
    orientation="horizontal",
    boundaries=np.arange(0, 7) - 0.5,
    ticks=[0, 5],
)
cb1.set_label("spike count")
