import argparse
import os

import torch as t
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm

from cfg.dataloader_pickle import PickleDataset, transback, load_data
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import get_rank, init_process_group, destroy_process_group, all_gather, get_world_size

import torch.nn.functional as F

from path_constant import project_root

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from cfg.unet import Unet
import copy
from cfg.embedding import JointEmbedding2, JointConditionalEmbedding
import torch

from cfg.dataloader_pickle import PickleDataset, transback, load_data
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.distributed import get_rank, init_process_group, destroy_process_group, all_gather, get_world_size


from path_constant import project_root

from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
from cfg.unet import Unet
import copy
from cfg.embedding import JointEmbedding2, JointConditionalEmbedding


def load_data(dataset: PickleDataset, batchsize: int) -> DataLoader:
    """Wrap ``dataset`` in a shuffling ``DataLoader``.

    NOTE: this definition replaces the ``load_data`` name imported from
    ``cfg.dataloader_pickle`` at the top of the file. Unlike what the previous
    return annotation suggested, no ``DistributedSampler`` is created here and
    a single loader (not a tuple) is returned.

    Args:
        dataset: Dataset to iterate over.
        batchsize: Number of samples per batch.

    Returns:
        A ``DataLoader`` that reshuffles every epoch and drops the final
        incomplete batch (``drop_last=True``), so a dataset smaller than
        ``batchsize`` yields no batches at all.
    """
    return DataLoader(
        dataset,
        batch_size=batchsize,
        shuffle=True,
        drop_last=True,
    )



class Net(nn.Module):
    """Small fully connected classifier over flattened 3x32x32 inputs.

    Architecture: 3072 -> 100 -> 50 -> ``outdim`` with ReLU after the first
    two linear layers; the last layer returns raw (un-normalised) scores.
    """

    def __init__(self, outdim):
        """Build the three linear layers; ``outdim`` is the output width."""
        super().__init__()
        flat_features = 3 * 32 * 32
        self.linear1 = nn.Linear(flat_features, 100)
        self.linear2 = nn.Linear(100, 50)
        self.final = nn.Linear(50, outdim)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Flatten ``img`` to rows of 3*32*32 values and return the scores."""
        flat = img.view(-1, 3 * 32 * 32)
        hidden = self.relu(self.linear1(flat))
        hidden = self.relu(self.linear2(hidden))
        return self.final(hidden)


def train(net1, net2, dataloader, params, num_epochs=20, lr=0.001):
    """Jointly train two classifiers on the same input and save a checkpoint.

    Every batch's ``'W1'`` tensor is flattened to rows of 3*32*32 values and
    fed to both networks. ``net1`` is trained against ``batch['W2a']``
    (one-hot encoded with 10 classes) and ``net2`` against ``batch['W2b']``
    (6 classes). The two cross-entropy losses are summed and minimised with a
    single Adam optimizer that covers the parameters of both networks.

    Args:
        net1: Module mapping a ``(N, 3072)`` tensor to 10 scores per sample.
        net2: Module mapping a ``(N, 3072)`` tensor to 6 scores per sample.
        dataloader: Iterable of dict batches with keys ``'W1'``, ``'W2a'``
            and ``'W2b'``. The label tensors are passed to ``F.one_hot`` and
            therefore have to hold integer class indices.
        params: Namespace-like object; only ``params.moddir`` is read.
        num_epochs: Number of passes over ``dataloader`` (default 20, the
            value that used to be hard-coded).
        lr: Adam learning rate (default 0.001, as before).

    Side effects:
        Creates ``<params.moddir>/W2`` if needed and writes ``last_epoch.pt``
        and ``ckpt_<num_epochs>_checkpoint.pt`` into it.
    """
    # Use the device the model already lives on instead of relying on a
    # module-level ``device`` global, which only exists when this file is run
    # as a script.
    device = next(net1.parameters()).device

    # CrossEntropyLoss holds no state, so one instance serves both heads.
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        list(net1.parameters()) + list(net2.parameters()), lr=lr)

    for _ in range(num_epochs):
        net1.train()
        net2.train()

        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for batch in tqdmDataLoader:
                x = batch['W1'].to(device).view(-1, 3 * 32 * 32)
                lab1 = F.one_hot(batch['W2a'].to(device), num_classes=10).float()
                lab2 = F.one_hot(batch['W2b'].to(device), num_classes=6).float()

                optimizer.zero_grad()
                loss = criterion(net1(x), lab1) + criterion(net2(x), lab2)
                loss.backward()
                optimizer.step()

    checkpoint = {
        'netW2a': net1.state_dict(),
        'netW2b': net2.state_dict(),
        'optimizer': optimizer.state_dict(),
    }

    moddir = os.path.join(params.moddir, 'W2')
    os.makedirs(moddir, exist_ok=True)
    # The recorded epoch is the number of completed epochs; it no longer
    # depends on a loop variable leaking out of the training loop.
    torch.save({'last_epoch': num_epochs}, os.path.join(moddir, 'last_epoch.pt'))
    torch.save(checkpoint,
               os.path.join(moddir, f'ckpt_{num_epochs}_checkpoint.pt'))





def _str2bool(value):
    """Interpret a command-line string as a bool ('true', '1', 'yes' -> True)."""
    return str(value).lower() in ['true', '1', 'yes']


def _build_arg_parser():
    """Create the command-line parser used when this file runs as a script.

    Most options mirror the diffusion-model training options; the script
    below reads only ``train_pkl``, ``val_pkl``, ``batchsize``, ``genbatch``
    and (through ``train``) ``moddir``.
    """
    parser = argparse.ArgumentParser(description='test for diffusion model')
    parser.add_argument('--train_pkl', type=str,
                        default=f"{project_root}/napkin_mnist/base_data/napkin_mnist_train.pkl")
    # NOTE(review): the validation default points at the *training* pickle;
    # confirm whether a separate validation file should be used.
    parser.add_argument('--val_pkl', type=str,
                        default=f"{project_root}/napkin_mnist/base_data/napkin_mnist_train.pkl")
    parser.add_argument('--datakey', type=str, help='which of the data keys is the one we want to generate')
    parser.add_argument('--condkey', type=str, help='which of the data keys is the one we use for conditioning')
    parser.add_argument('--batchsize', type=int, default=256, help='batch size per device for training Unet model')
    parser.add_argument('--numworkers', type=int, default=4, help='num workers for training Unet model')
    parser.add_argument('--inch', type=int, default=3, help='input channels for Unet model')
    parser.add_argument('--modch', type=int, default=64, help='model channels for Unet model')
    parser.add_argument('--T', type=int, default=1000, help='timesteps for Unet model')
    parser.add_argument('--outch', type=int, default=3, help='output channels for Unet model')
    # ``type=list`` would split the raw string into single characters; parse
    # a whitespace-separated list of ints instead.
    parser.add_argument('--chmul', type=int, nargs='+', default=[1, 2, 2, 2],
                        help='architecture parameters training Unet model')
    parser.add_argument('--numres', type=int, default=2, help='number of resblocks for each block in Unet model')
    parser.add_argument('--cdim', type=int, default=64, help='dimension of conditional embedding')
    # ``type=bool`` treats every non-empty string (including "False") as
    # True; parse it the same way ``--ddim`` is parsed.
    parser.add_argument('--useconv', type=_str2bool, default=True, help='whether use convlution in downsample')
    parser.add_argument('--droprate', type=float, default=0.1, help='dropout rate for model')
    parser.add_argument('--dtype', default=torch.float32)
    parser.add_argument('--lr', type=float, default=2e-4, help='learning rate')
    parser.add_argument('--w', type=float, default=1.8,
                        help='hyperparameters for classifier-free guidance strength')
    parser.add_argument('--v', type=float, default=0.3,
                        help='hyperparameters for the variance of posterior distribution')
    parser.add_argument('--epoch', type=int, default=50, help='epochs for training')
    parser.add_argument('--multiplier', type=float, default=2.5, help='multiplier for warmup')
    parser.add_argument('--threshold', type=float, default=0.1, help='threshold for classifier-free guidance')
    parser.add_argument('--interval', type=int, default=10, help='epoch interval between two evaluations')
    parser.add_argument('--moddir', type=str,
                        default=f'{project_root}/Baselines/DiffusionBasedCausalModels/imgcond_model',
                        help='model addresses')
    parser.add_argument('--samdir', type=str,
                        default=f'{project_root}/Baselines/DiffusionBasedCausalModels/imgcond_sample',
                        help='sample addresses')
    parser.add_argument('--genbatch', type=int, default=80, help='batch size for sampling process')
    parser.add_argument('--num_steps', type=int, default=50, help='sampling steps for DDIM')
    parser.add_argument('--eta', type=float, default=0, help='eta for variance during DDIM sampling process')
    parser.add_argument('--select', type=str, default='linear', help='selection stragies for DDIM')
    parser.add_argument('--ddim', type=_str2bool, default=False,
                        help='whether to use ddim')
    parser.add_argument('--local_rank', default=-1, type=int, help='node rank for distributed training')
    return parser


if __name__ == '__main__':
    # Unknown arguments are tolerated so shared launch commands keep working.
    params, _ = _build_arg_parser().parse_known_args()

    # Assigned at module scope on purpose: other code in this file may look
    # ``device`` up as a global.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load the data once. The previous extra pass divided the batch size by
    # torch.cuda.device_count(), which is 0 on CPU-only hosts and raised
    # ZeroDivisionError before training could start; its result was
    # overwritten anyway.
    train_data = PickleDataset(params.train_pkl)
    val_data = PickleDataset(params.val_pkl)
    dataloader = load_data(train_data, params.batchsize)
    val_loader = load_data(val_data, params.genbatch)

    net1 = Net(outdim=10).to(device)
    net2 = Net(outdim=6).to(device)

    train(net1, net2, dataloader, params)

    # ---- evaluation -----------------------------------------------------
    net1.eval()
    net2.eval()

    total = 0
    correct1 = 0
    correct2 = 0
    x = lab1 = lab2 = None  # last evaluated batch, reused for the demo below

    with t.no_grad():
        with tqdm(val_loader, dynamic_ncols=True) as tqdmDataLoader:
            for batch in tqdmDataLoader:
                x = batch['W1'].to(device)
                lab1 = F.one_hot(batch['W2a'].to(device), num_classes=10).float()
                lab2 = F.one_hot(batch['W2b'].to(device), num_classes=6).float()

                flat = x.view(-1, 3 * 32 * 32)
                output1 = net1(flat)
                output2 = net2(flat)

                # Count matches for the whole batch at once instead of
                # looping over samples in Python.
                correct1 += (output1.argmax(dim=1) == lab1.argmax(dim=1)).sum().item()
                correct2 += (output2.argmax(dim=1) == lab2.argmax(dim=1)).sum().item()
                total += output1.shape[0]

    if total == 0:
        # drop_last=True yields no batches when the validation set is smaller
        # than the batch size; avoid dividing by zero / indexing nothing.
        print('accuracy: n/a (validation loader produced no batches)')
    else:
        print(f'accuracy: {round(correct1 / total, 3)},  {round(correct2 / total, 3)}')

        # Show one sample of the last batch with both predictions; clamp the
        # index so a small --genbatch cannot run past the end of the batch.
        n = min(18, x.shape[0] - 1)
        plt.imshow(x[n].permute(1, 2, 0).cpu())
        plt.show()
        with t.no_grad():
            print(t.argmax(net1(x[n].view(-1, 3 * 32 * 32))[0]), lab1[n])
            print(t.argmax(net2(x[n].view(-1, 3 * 32 * 32))[0]), lab2[n])



