import torch
import os
import torch.optim as optim
import torch.utils.data as data
import math

from solver_base import Solver
from torchvision.utils import save_image

class OurSolver(Solver):
    """Solver that trains the model on original images and transformed copies.

    Each step adds three terms: the base VAE loss on the original images, the
    base VAE loss on the transformed images, and ``lambda_`` times a distance
    between the latent statistics of each transformed image and those of the
    original image it is paired with.

    The data loader, model, optimizer, loss helpers and logging/testing hooks
    used here (``self.train_data``, ``self.model``, ``self.optim``,
    ``self.cuda``, ``self.base_vae_loss``, ``self.wasserstein_distance``,
    ``self.kl_from_another_gaussian``, ``self.log_stats``, ``self.test``) are
    not defined in this class; presumably they come from ``Solver``.
    """

    def __init__(self, args):
        super(OurSolver, self).__init__(args)

    def solve(self):
        """Run the training loop for ``self.args.num_epochs`` epochs.

        Statistics are logged every 50 iterations and ``self.test`` is called
        with the epoch index at the end of every epoch.

        Raises:
            ValueError: if ``self.args.trans_distance`` is neither ``'wass'``
                nor ``'kl'``.
        """
        # Resolve the latent-distance function once, up front. Previously an
        # unsupported value left ``trans_distance_loss`` unbound and the loop
        # died with an UnboundLocalError only after a full forward pass.
        distance_name = self.args.trans_distance
        if distance_name == 'wass':
            distance_fn = self.wasserstein_distance
        elif distance_name == 'kl':
            distance_fn = self.kl_from_another_gaussian
        else:
            raise ValueError(
                "Unsupported trans_distance {!r}; expected 'wass' or 'kl'".format(
                    distance_name))

        num_samples = self.args.num_samples
        num_channels = self.args.num_channels

        num_iter = 0
        for epoch_count in range(self.args.num_epochs):
            for images, _ in self.train_data:
                orig_images, trans_images = images
                # Fold the per-image transformed samples into the batch
                # dimension so the model sees them as ordinary images.
                # NOTE(review): the spatial size is hard-coded to 32x32.
                trans_images = trans_images.view(
                    trans_images.size(0) * num_samples, num_channels, 32, 32)

                if self.cuda:
                    orig_images = orig_images.cuda()
                    trans_images = trans_images.cuda()

                recon, mu, logvar = self.model(orig_images)
                trans_recon, trans_mu, trans_logvar = self.model(trans_images)

                vae_loss, recon_loss, kl_loss = self.base_vae_loss(
                    recon, orig_images, mu, logvar)
                trans_vae_loss, trans_recon_loss, trans_kl_loss = self.base_vae_loss(
                    trans_recon, trans_images, trans_mu, trans_logvar)

                # Repeat each original's statistics ``num_samples`` times in a
                # row so they match the transformed batch along dim 0.
                # Assumes the transformed copies of one image are contiguous
                # after the reshape above -- TODO confirm against the loader.
                mu_repeat = torch.repeat_interleave(mu, repeats=num_samples, dim=0)
                logvar_repeat = torch.repeat_interleave(
                    logvar, repeats=num_samples, dim=0)

                trans_distance_loss = distance_fn(
                    trans_mu, trans_logvar, mu_repeat, logvar_repeat)

                total_loss = (vae_loss + trans_vae_loss
                              + self.args.lambda_ * trans_distance_loss)

                self.optim.zero_grad()
                total_loss.backward()
                self.optim.step()

                if num_iter % 50 == 0:
                    self.log_stats(
                        num_iter,
                        total_loss=total_loss,
                        kl_loss=kl_loss,
                        recon_loss=recon_loss,
                        trans_kl_loss=trans_kl_loss,
                        trans_recon_loss=trans_recon_loss,
                        trans_distance_loss=trans_distance_loss,
                    )
                num_iter += 1

            self.test(epoch_count)

