import torch
import numpy as np
import sys
sys.path.append(".")

from mh_filters.imh import ZSampler, NoFilter
from mg_filters import MarkovZNoFilter, MmhFilter
from metrics.inception_score_script import inception_score
from metrics.new_fid import calculate_fid


def g_scores_evaluation(N, G, dim_z, batch_size, device, act_path, full=True,
                        *, fid_batch_size=256, num_workers=40):
    """Sample from the generator without any filter and score the samples.

    Latent codes come from ``ZSampler`` and are pushed through ``NoFilter``.
    The resulting chain is scored with the Inception Score and then with FID.

    Args:
        N: First argument to ``NoFilter.sample_chain`` (presumably the
            number of samples to draw -- confirm).
        G: Generator handed to ``NoFilter``.
        dim_z: Latent dimensionality passed to ``ZSampler``.
        batch_size: Second argument to ``NoFilter.sample_chain`` (presumably
            the sampling batch size -- confirm).
        device: Torch device used for sampling and scoring.
        act_path: First entry of the path list given to ``calculate_fid``
            (presumably precomputed reference activations -- confirm).
        full: Forwarded unchanged to ``calculate_fid``.
        fid_batch_size: Batch size of the DataLoader fed to ``calculate_fid``.
            Defaults to 256, the previously hard-coded value.
        num_workers: Worker processes for that DataLoader. Defaults to 40,
            the previously hard-coded value; lower it on machines with fewer
            CPU cores.

    Returns:
        Tuple ``(is_g, fid_g)``. ``is_g`` is the first element returned by
        ``inception_score`` (presumably the mean score). ``fid_g`` is the
        value returned by ``calculate_fid``.
    """
    z_sampler = ZSampler(dim_z, device)
    no_filter = NoFilter(device, G, z_sampler)
    no_filter.sample_chain(N, batch_size)

    dataset_g = torch.utils.data.TensorDataset(no_filter.chain)
    is_g = inception_score(dataset_g, device, resize=True, splits=1)[0]
    # Hand unused cached CUDA memory back before the FID pass allocates more.
    torch.cuda.empty_cache()

    dataloader_g = torch.utils.data.DataLoader(
        dataset_g,
        batch_size=fid_batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    # NOTE(review): the [act_path, ''] / ['', loader] pairing appears to mean
    # "reference stats from act_path, generated stats from the loader";
    # confirm against calculate_fid.
    fid_g = calculate_fid([act_path, ''], ['', dataloader_g], device, full=full)
    torch.cuda.empty_cache()
    print('G score:', is_g, fid_g)
    return is_g, fid_g


def markov_scores_evaluation(MG, device, T, z_init, act_path, full=True,
                             *, fid_batch_size=256, num_workers=40):
    """Run unfiltered Markov chains through ``MG`` and score every sample.

    ``MarkovZNoFilter`` is run from ``z_init``. The first two dimensions of
    its ``x_chain`` are merged into one sample axis, and the result is
    scored with the Inception Score and then with FID.

    Args:
        MG: Model handed to ``MarkovZNoFilter``.
        device: Torch device used for sampling and scoring.
        T: First argument to ``sample_chains`` (presumably the chain length
            -- confirm).
        z_init: Initial latent state. ``z_init.size(0)`` is passed as the
            second argument to ``sample_chains`` (presumably the number of
            parallel chains -- confirm).
        act_path: First entry of the path list given to ``calculate_fid``
            (presumably precomputed reference activations -- confirm).
        full: Forwarded unchanged to ``calculate_fid``.
        fid_batch_size: Batch size of the DataLoader fed to ``calculate_fid``.
            Defaults to 256, the previously hard-coded value.
        num_workers: Worker processes for that DataLoader. Defaults to 40,
            the previously hard-coded value; lower it on machines with fewer
            CPU cores.

    Returns:
        Tuple ``(is_g, fid_g)``. ``is_g`` is the first element returned by
        ``inception_score`` (presumably the mean score). ``fid_g`` is the
        value returned by ``calculate_fid``.
    """
    sampler = MarkovZNoFilter(MG, device)
    sampler.sample_chains(T, z_init.size(0), z_init)

    # Collapse the two leading axes of x_chain (presumably steps x chains, or
    # the reverse -- confirm) into a single sample axis.
    chain = sampler.x_chain.flatten(0, 1)
    print(chain.size(0))

    dataset_g = torch.utils.data.TensorDataset(chain)
    is_g = inception_score(dataset_g, device, resize=True, splits=1)[0]
    # Hand unused cached CUDA memory back before the FID pass allocates more.
    torch.cuda.empty_cache()

    dataloader_g = torch.utils.data.DataLoader(
        dataset_g,
        batch_size=fid_batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    fid_g = calculate_fid([act_path, ''], ['', dataloader_g], device, full=full)
    torch.cuda.empty_cache()
    print('MG score:', is_g, fid_g)
    return is_g, fid_g


def mh_scores_evaluation(D, MG, device, N, T, z_init, x_init, act_path, full=True,
                         *, fid_batch_size=256, num_workers=40):
    """Run Metropolis-Hastings filtered chains and score the accepted samples.

    ``MmhFilter`` is built with ``D`` in eval mode and sampled. The tensors
    in its ``oa`` list are concatenated and scored with the Inception Score
    and then with FID. The mode ``D`` had on entry is restored once sampling
    finishes, so a caller's training loop is not left with an eval-mode
    discriminator.

    Args:
        D: Discriminator (an ``nn.Module``) used by ``MmhFilter``. It is
            temporarily switched to eval mode.
        MG: Model handed to ``MmhFilter``.
        device: Torch device used for sampling and scoring.
        N, T: First two arguments to ``MmhFilter.sample_chains``; their exact
            meaning is defined by ``MmhFilter`` (confirm there).
        z_init, x_init: Initial chain states forwarded to ``sample_chains``.
        act_path: First entry of the path list given to ``calculate_fid``
            (presumably precomputed reference activations -- confirm).
        full: Forwarded unchanged to ``calculate_fid``.
        fid_batch_size: Batch size of the DataLoader fed to ``calculate_fid``.
            Defaults to 256, the previously hard-coded value.
        num_workers: Worker processes for that DataLoader. Defaults to 40,
            the previously hard-coded value.

    Returns:
        Tuple ``(is_g, fid_g)``. ``is_g`` is the first element returned by
        ``inception_score`` (presumably the mean score). ``fid_g`` is the
        value returned by ``calculate_fid``.

    Raises:
        RuntimeError: If ``mh_filters.oa`` is empty, i.e. there is nothing
            to score.
    """
    # Remember D's mode so that switching it to eval for sampling does not
    # leak out of this function. Note: D.train(mode) applies one mode to all
    # submodules, so a mixed per-submodule setup is not preserved.
    was_training = getattr(D, 'training', None)
    try:
        mh_filters = MmhFilter(D.eval(), MG, device)
        mh_filters.sample_chains(N, T, z_init, x_init)
    finally:
        if was_training is not None:
            D.train(was_training)

    # NOTE(review): 'oa' presumably holds the accepted outputs (the log label
    # below says "Acc N") -- confirm in MmhFilter.
    accepted = mh_filters.oa
    if len(accepted) == 0:
        # torch.cat on an empty sequence fails with an opaque message, and
        # IS/FID are undefined on zero samples anyway.
        raise RuntimeError('MH filter produced no accepted samples; cannot compute IS/FID.')
    chain = torch.cat(accepted)
    print('Acc N: ', chain.size(0))

    dataset_g = torch.utils.data.TensorDataset(chain)
    is_g = inception_score(dataset_g, device, resize=True, splits=1)[0]
    # Hand unused cached CUDA memory back before the FID pass allocates more.
    torch.cuda.empty_cache()

    dataloader_g = torch.utils.data.DataLoader(
        dataset_g,
        batch_size=fid_batch_size,
        shuffle=True,
        num_workers=num_workers,
    )
    fid_g = calculate_fid([act_path, ''], ['', dataloader_g], device, full=full)
    torch.cuda.empty_cache()
    print('MG-MH score:', is_g, fid_g)
    return is_g, fid_g
