import torch
import numpy as np
import sys
sys.path.append(".")

from mh_filters.imh import ZSampler, NoFilter, ImhFilter
from metrics.inception_score_script import inception_score
from metrics.new_fid import calculate_fid


def g_scores_evaluation(G, dim_z, init, device, true_activation_path,
                        chain_size=int(10e3), full=True, *,
                        batch_size=256, num_workers=40):
    """Compute the Inception Score and FID of unfiltered generator samples.

    Samples are drawn through a ``NoFilter`` sampler built from ``G`` and a
    ``ZSampler``, wrapped in a ``TensorDataset``, and scored with
    ``inception_score`` and ``calculate_fid``. Both scores are printed
    before being returned.

    Args:
        G: Generator handed to ``NoFilter``.
        dim_z: Latent dimension handed to ``ZSampler``.
        init: Tensor of which only ``init.size(0)`` is used here; it is the
            second argument to ``sample_chain`` (presumably the sampling
            batch size -- confirm against ``NoFilter``).
        device: Device forwarded to the samplers and to both metrics.
        true_activation_path: Forwarded to ``calculate_fid`` as the first
            entry of its first list argument (presumably precomputed
            activations of the real data -- confirm against
            ``metrics.new_fid``).
        chain_size: Number of samples requested from the sampler.
        full: Forwarded unchanged to ``calculate_fid``.
        batch_size: Batch size of the DataLoader fed to ``calculate_fid``.
        num_workers: Worker-process count of that DataLoader. Lower it on
            machines with few cores; ``0`` loads in the main process.

    Returns:
        Tuple ``(is_g, fid_g)``: the first element of the
        ``inception_score`` result and the value returned by
        ``calculate_fid``.

    Raises:
        AssertionError: If the sampler produced fewer than
            ``chain_size - init.size(0)`` samples.
    """
    z = ZSampler(dim_z, device)
    sampler = NoFilter(device, G, z)
    sampler.sample_chain(chain_size, init.size(0))
    dataset_g = torch.utils.data.TensorDataset(sampler.chain)
    # Sanity check: a shortfall of up to init.size(0) samples is tolerated.
    assert len(dataset_g) >= chain_size - init.size(0)
    is_g = inception_score(dataset_g, device, resize=True, splits=1)[0]
    # Hand cached GPU memory back before the FID pass starts.
    torch.cuda.empty_cache()

    dataloader_g = torch.utils.data.DataLoader(
        dataset_g, batch_size=batch_size, shuffle=True,
        num_workers=num_workers)
    fid_g = calculate_fid([true_activation_path, ''], ['', dataloader_g], device, full=full)
    torch.cuda.empty_cache()
    print(is_g, fid_g)
    return is_g, fid_g


def imh_scores_evaluation(D, G, dim_z, init, device, true_activation_path,
                          chain_size=int(10e3), full=True, *,
                          batch_size=256, num_workers=40):
    """Compute the Inception Score and FID of ``ImhFilter`` samples.

    Samples are collected with ``ImhFilter.sample_chains_only_accepts``,
    concatenated into one tensor, wrapped in a ``TensorDataset``, and scored
    with ``inception_score`` and ``calculate_fid``. Progress, the average
    acceptance ratio per batch and both scores are printed.

    Args:
        D: Discriminator handed to ``ImhFilter``.
        G: Generator handed to ``ImhFilter``.
        dim_z: Latent dimension handed to ``ZSampler``.
        init: Tensor passed to ``sample_chains_only_accepts``; its
            first-dimension size also normalizes the per-batch acceptance
            counts.
        device: Device forwarded to the samplers and to both metrics.
        true_activation_path: Forwarded to ``calculate_fid`` as the first
            entry of its first list argument (presumably precomputed
            activations of the real data -- confirm against
            ``metrics.new_fid``).
        chain_size: Number of samples requested from the sampler.
        full: Forwarded unchanged to ``calculate_fid``.
        batch_size: Batch size of the DataLoader fed to ``calculate_fid``.
        num_workers: Worker-process count of that DataLoader. Lower it on
            machines with few cores; ``0`` loads in the main process.

    Returns:
        Tuple ``(is_mh, fid_mh)``: the first element of the
        ``inception_score`` result and the value returned by
        ``calculate_fid``.

    Raises:
        AssertionError: If fewer than ``chain_size`` samples were collected.
    """
    print('Start sampling')
    z = ZSampler(dim_z, device)
    sampler = ImhFilter(device, D, G, z)
    sampler.sample_chains_only_accepts(chain_size, init)
    # sampler.chain is iterated as a sequence of tensors; the size of each
    # along dim 0 is that entry's sample count.
    n_acc = [i.size(0) for i in sampler.chain]
    print('Avg. accep per batch:', np.mean(np.array(n_acc) / init.size(0)))

    chain = torch.cat(sampler.chain)
    dataset_mh = torch.utils.data.TensorDataset(chain)
    # Sanity check: the sampler must deliver at least chain_size samples.
    assert len(dataset_mh) >= chain_size

    is_mh = inception_score(dataset_mh, device, resize=True, splits=1)[0]
    # Hand cached GPU memory back before the FID pass starts.
    torch.cuda.empty_cache()

    dataloader_mh = torch.utils.data.DataLoader(
        dataset_mh, batch_size=batch_size, shuffle=True,
        num_workers=num_workers)
    fid_mh = calculate_fid([true_activation_path, ''], ['', dataloader_mh], device, full=full)
    torch.cuda.empty_cache()

    print(is_mh, fid_mh)
    return is_mh, fid_mh
