from datasets.AVideoDataset import AVideoDataset
import numpy as np
import os
import pickle
from scipy.stats import entropy
from sklearn.cluster import KMeans
from sklearn.metrics.cluster import normalized_mutual_info_score, adjusted_mutual_info_score
import torch
import torch.distributed as dist
from torch.utils.data.sampler import Sampler, SubsetRandomSampler
import utils
from models import VideoWrapper


class Subset_Sampler(Sampler):
    """
    Replay a fixed, precomputed sequence of dataset indices.

    Indices are yielded exactly in the order they were given; no shuffling
    or resampling is performed.
    """

    def __init__(self, indices):
        # Any sized iterable works (list, range, 1-D tensor, ...).
        self.indices = indices

    def __iter__(self):
        yield from self.indices

    def __len__(self):
        return len(self.indices)


def accuracy(output, target, topk=(1,)):
    """Compute the top-k accuracy (in percent) for each k in ``topk``.

    Args:
        output: score tensor of shape (batch, num_classes); higher scores
            rank first.
        target: integer class tensor with ``batch`` elements.
        topk: iterable of k values to evaluate.

    Returns:
        A list with one 1-element tensor per k. Each holds the percentage of
        samples whose target is among the k highest-scoring classes. A k
        larger than ``num_classes`` counts every sample as correct.
    """
    with torch.no_grad():
        # Clamp so that a k larger than the number of classes does not make
        # ``topk`` raise; such a k simply considers every class.
        maxk = min(max(topk), output.size(1))
        batch_size = target.size(0)
        _, pred = output.topk(maxk, 1, True, True)
        pred = pred.t()  # (maxk, batch)
        correct = pred.eq(target.view(1, -1).expand_as(pred))
        res = []
        for k in topk:
            # ``correct`` is a transposed (non-contiguous) tensor, so
            # ``view(-1)`` fails for k > 1; ``reshape`` copies when needed.
            correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
            res.append(correct_k.mul_(100.0 / batch_size))
        return res


def _barrier(args, group=None):
    """Block until all ranks arrive; a no-op outside the multi-node DDP setup.

    This file only synchronizes when ``args.distributed`` is set and
    ``args.world_size > 8``. Otherwise it runs as a single process and there
    is nothing to wait for.
    """
    if not (args.distributed and args.world_size > 8):
        return
    if group is not None:
        dist.barrier(group=group)
    else:
        dist.barrier()


def _apply_head(head, x):
    """Apply projection ``head`` to ``x`` after casting ``x`` to the head's dtype.

    Features are accumulated in float64 when ``headcount == 1``. Feeding that
    tensor straight into a module with float32 weights raises a dtype mismatch.
    """
    param = next(head.parameters(), None)
    if param is not None and x.dtype != param.dtype:
        x = x.to(param.dtype)
    return head(x)


def get_cluster_assignments_gpu(args, dataset, model, logger=None, group=None, device='cuda'):
    """Run ``model`` over ``dataset`` and pickle the collected outputs from rank 0.

    No clustering happens here. The function collects per-sample video and
    audio outputs plus ground-truth labels into matrices on global rank 0. It
    then writes ``[PS_v, labels, PS_a]`` to
    ``cluster_fit_PS_matrices_{args.run_id}.pkl``. When ``args.ours`` is set,
    each of the ``args.headcount`` heads (``mlp_v{h}`` / ``mlp_a{h}``) is
    applied first, and the per-head outputs are stored instead. In ``val`` mode
    with ``corruption == 1``, a list of per-sample visualisation tuples is also
    pickled to ``{args.output_dir}/{args.run_id}.pkl``.

    With ``args.world_size > 8``, each rank processes a contiguous
    ``N // world_size`` slice of the dataset in order, and the results are
    all-gathered. Otherwise a single process iterates the whole dataset in
    shuffled order.

    Args:
        args: parsed CLI namespace, plus fields set in ``__main__``.
        dataset: map-style dataset whose items unpack to
            ``(video, audio, label, _, _)``.
        model: wrapped model; ``model.module`` is accessed for the heads and
            ``return_features``.
        logger: optional logger forwarded to ``utils.print_or_log``.
        group: optional process group used for barriers.
        device: device on which the rank-0 feature matrices are allocated.

    Returns:
        None. Results are written to disk.
    """
    # clear cache at beginning
    torch.cuda.empty_cache()
    model.eval()
    use_ddp = args.world_size > 8
    # ``args.ours`` is only assigned by the OURS branch of ``__main__``.
    ours = getattr(args, 'ours', False)
    N = len(dataset)

    sampler = None
    if use_ddp:
        # this process deals only with a contiguous slice of the dataset
        local_nmb_data = N // args.world_size
        train_indices = torch.arange(
            args.global_rank * local_nmb_data,
            (args.global_rank + 1) * local_nmb_data
        ).int()
        sampler = Subset_Sampler(train_indices)

    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=None,
        shuffle=not use_ddp,  # a sampler and shuffle=True are mutually exclusive
    )

    # Make sure every rank is ready before feature extraction starts.
    _barrier(args, group)

    # use GAP features; the heads are applied explicitly below
    if args.headcount > 1:
        model.module.return_features = True
    dtype = torch.float64 if args.headcount == 1 else torch.float32

    viz_list = []
    PS_v = PS_a = labels = None
    fr = 0  # next free row in the rank-0 matrices
    for batch_idx, batch in enumerate(dataloader):
        video, audio, label, _, _ = batch
        video = video.cuda(non_blocking=True)
        audio = audio.cuda(non_blocking=True)
        label = label.cuda(non_blocking=True)

        # Forward pass. Models without an audio branch get random audio features.
        if args.mil_nce:
            out = model(video)
            feat_v = out['video_embedding']
            feat_a = torch.rand(feat_v.size(0), 512).cuda()
        elif args.xdc:
            feat_v, feat_a = model(video, audio)
        elif args.dpc:
            BS, C, _, H, W = video.shape
            _, feat_v = model(video.view(BS, 8, C, 5, H, W))
            feat_v = feat_v.squeeze()
            feat_a = torch.rand(feat_v.size(0), 256).cuda()
        elif args.sela:
            feat_v, feat_a = model(video[:, :, 0, :, :], audio)
            feat_a = torch.rand(feat_v.size(0), args.num_clusters).cuda()
        else:
            feat_v, feat_a = model(video, audio)

        if args.global_rank == 0 and batch_idx % 10 == 0:
            utils.print_or_log((batch_idx, label.shape, label.max(), video.shape, audio.shape, feat_v.shape, feat_a.shape), logger=logger)

        # save per-sample results for visualisation
        if args.mode == 'val' and args.corruption == 1 and ours:
            # Head 0 is applied unconditionally, matching the final dump below.
            # Previously feat_*_head was undefined when headcount == 1.
            feat_v_head = _apply_head(model.module.mlp_v0, feat_v)
            feat_a_head = _apply_head(model.module.mlp_a0, feat_a)
            PS_v_sk = torch.nn.functional.softmax(feat_v_head, dim=1, dtype=torch.float64)
            PS_a_sk = torch.nn.functional.softmax(feat_a_head, dim=1, dtype=torch.float64)
            self_labels_np = torch.mul(PS_v_sk, PS_a_sk).argmax(1).cpu().numpy()
            # one device->host copy per batch instead of one per sample
            frames_np = video[:, :, 15, :, :].cpu().numpy()
            feat_v_np = feat_v.cpu().numpy()
            feat_a_np = feat_a.cpu().numpy()
            viz_list.extend(
                (frames_np[i], feat_v_np[i], feat_a_np[i], label[i], self_labels_np[i])
                for i in range(len(video))
            )

        if batch_idx % 5 == 0:
            acc = accuracy(feat_v, label)
            utils.print_or_log(f"Batch {batch_idx}: {acc}", logger=logger)

        if use_ddp:
            # all_gather needs receive buffers with the sent tensor's shape AND
            # dtype. The old legacy float64 buffers broke it when headcount == 1.
            feat_v, feat_a, label = feat_v.contiguous(), feat_a.contiguous(), label.contiguous()
            all_feat_v_list = [torch.empty_like(feat_v) for _ in range(args.world_size)]
            all_feat_a_list = [torch.empty_like(feat_a) for _ in range(args.world_size)]
            all_labels_list = [torch.empty_like(label) for _ in range(args.world_size)]

            dist.all_gather(all_feat_v_list, feat_v)
            dist.all_gather(all_feat_a_list, feat_a)
            dist.all_gather(all_labels_list, label)

            # only main process stores all features
            if args.global_rank == 0:
                all_feat_v = torch.cat(all_feat_v_list)
                all_feat_a = torch.cat(all_feat_a_list)
                all_labels = torch.cat(all_labels_list).cpu()
        else:
            all_feat_v = feat_v
            all_feat_a = feat_a
            all_labels = label

        # fill in arrays on main node
        if args.global_rank == 0:
            if PS_v is None:
                K = all_feat_v.size(1)
                K_a = all_feat_a.size(1)  # audio width may differ from video width
                utils.print_or_log(f"storing features of size {K}", logger=logger)
                PS_v = torch.zeros((N, K), dtype=dtype, device=device)
                PS_a = torch.zeros((N, K_a), dtype=dtype, device=device)
                labels = torch.zeros(N, dtype=torch.long)
            to = fr + all_feat_v.shape[0]
            if batch_idx % 10 == 0:
                utils.print_or_log((fr, to, all_labels.shape, all_feat_v.shape), logger=logger)
            PS_v[fr:to] = all_feat_v
            PS_a[fr:to] = all_feat_a
            labels[fr:to] = all_labels.cpu()
            fr = to

        _barrier(args, group)

    # Dump results
    if args.global_rank == 0:
        if PS_v is None:
            utils.print_or_log("No batches were processed; nothing to dump.", logger=logger)
        else:
            # With world_size > 8, the N % world_size trailing samples are never
            # visited. Drop their all-zero rows instead of dumping them as label 0.
            PS_v, PS_a, labels = PS_v[:fr], PS_a[:fr], labels[:fr]
            utils.print_or_log("Labels: ", logger=logger)
            utils.print_or_log(np.unique(labels.cpu().numpy(), return_counts=True), logger=logger)
            if ours:
                PS_v_heads = [
                    _apply_head(getattr(model.module, f'mlp_v{h}'), PS_v)
                    for h in range(args.headcount)
                ]
                PS_a_heads = [
                    _apply_head(getattr(model.module, f'mlp_a{h}'), PS_a)
                    for h in range(args.headcount)
                ]
                PS = [PS_v_heads, labels, PS_a_heads]
            else:
                PS = [PS_v, labels, PS_a]
            with open(f'cluster_fit_PS_matrices_{args.run_id}.pkl', 'wb') as handle:
                pickle.dump(PS, handle, protocol=pickle.HIGHEST_PROTOCOL)
            utils.print_or_log("Finished Dumping!", logger=logger)
        if args.mode == 'val' and args.corruption == 1:
            with open(os.path.join(args.output_dir, args.run_id + '.pkl'), 'wb') as handle:
                pickle.dump(viz_list, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # Make other processes wait until rank 0 has finished writing.
    _barrier(args, group)

def parse_args():
    """Build the command-line parser for feature extraction and parse ``sys.argv``.

    Boolean options use a custom ``'bool'`` type. It accepts
    ``yes/true/t/1`` and ``no/false/f/0`` (case-insensitive) and raises
    ``ValueError`` for anything else. String defaults such as ``'False'`` are
    run through that converter by argparse, so every flag of this type parses
    to a real ``bool``.

    Returns:
        argparse.Namespace with the parsed options.
    """
    def str2bool(v):
        v = v.lower()
        if v in ('yes', 'true', 't', '1'):
            return True
        elif v in ('no', 'false', 'f', '0'):
            return False
        raise ValueError('Boolean argument needs to be true or false. '
                         'Instead, it is %s.' % v)

    import argparse
    parser = argparse.ArgumentParser(description='Video Representation Learning')
    parser.register('type', 'bool', str2bool)

    parser.add_argument(
        '--output-dir',
        default='.',
        help='path where to save'
    )
    parser.add_argument(
        '--run-id',
        default='ours_vggsound_val',
        help='SLURM JOB ID',
    )
    parser.add_argument(
        '--weights-path',
        default='',
        help='Path to weights file',
    )
    parser.add_argument(
        "--pretrained",
        type='bool',
        default='False',
        help="Use pre-trained models from the modelzoo",
    )
    parser.add_argument(
        "--mil-nce",
        type='bool',
        default='False',
        help="Use MIL-NCE model",
    )
    parser.add_argument(
        "--dpc",
        type='bool',
        default='False',
        help="Use DPC model",
    )
    parser.add_argument(
        "--xdc",
        type='bool',
        default='False',
        help="Use XDC model",
    )
    parser.add_argument(
        "--sela",
        type='bool',
        default='False',
        help="Use SELA frame based model",
    )
    parser.add_argument(
        '--dataset',
        default='vggsound',
        help='name of dataset'
    )
    parser.add_argument(
        '--mode',
        default='val',
        help='mode of dataset'
    )
    parser.add_argument(
        '--num-data-samples',
        default=14032, # 230976 / 18968 (Kinetics), 170752 / 14032 (vggsound)
        type=int,
        help='number of samples in dataset'
    )

    # AUDIO UTILS
    parser.add_argument(
        '--aud-sample-rate',
        default=24000,
        type=int,
        help='audio sample rate'
    )
    parser.add_argument(
        '--aud-spec-type',
        default=2,
        type=int,
        help='audio spec type' # 1 : (40, 99), (257, 199)
    )
    parser.add_argument(
        '--use-volume-jittering',
        default='False',
        type='bool',
        help='use volume jittering'
    )
    parser.add_argument(
        '--use-temporal-jittering',
        default='False',
        type='bool',
        help='use temporal jittering'
    )
    parser.add_argument(
        '--num-sec',
        default=1,
        type=int,
        help='Number of seconds'
    )
    parser.add_argument(
        '--z-normalize',
        default='True',
        type='bool',
        help='normalize audio'
    )
    parser.add_argument(
        '--corruption',
        default=1,
        type=int,
        help='corruption factor'
    )

    ### DATA
    parser.add_argument(
        '-b',
        '--batch-size',
        default=96,
        type=int
    )
    parser.add_argument(
        '--clip-len',
        default=30,
        type=int,
        help='number of frames per clip'
    )
    parser.add_argument(
        '-j', '--workers',
        default=10,
        type=int,
        metavar='N',
        help='number of data loading workers (default: 10)'
    )
    parser.add_argument(
        '--train-crop-size',
        default=112,
        type=int,
        help='Size of spatial crops'
    )
    parser.add_argument(
        '--sample-rate',
        default=1,
        type=int,
        help='Subsampling rate: num frames between clips'
    )

    ### MODEL
    parser.add_argument(
        '--model',
        default='avc',
        help='model',
        choices=['avc', 'r2plus1d_18']
    )
    parser.add_argument(
        '--vid-base-arch',
        default='r2plus1d_18',
        help='Video Base Arch for A-V model',
        choices=['r2plus1d_18', 'mc3_18', 's3d', 'r2plus1d_34', 'r2plus1d_50']
    )
    parser.add_argument(
        '--aud-base-arch',
        default='resnet9',
        help='Audio Base Arch for A-V model',
        choices=['resnet9', 'resnet18', 'vgg_audio', 'resnet34', 'resnet50']
    )
    # Without type='bool', "--use-mlp False" would yield the truthy string 'False'.
    parser.add_argument(
        '--use-mlp',
        type='bool',
        default='True',
        help='Use MLP projection head'
    )
    parser.add_argument(
        '--mlptype',
        default=1,
        type=int,
        help='MLP type (default: 1) -1 is linear layer'
    )
    parser.add_argument(
        '--num-clusters',
        default=309,
        type=int,
        help='number of clusters'
    )
    parser.add_argument(
        '--headcount',
        type=int,
        default=10,
        help='how many heads each modality has'
    )
    parser.add_argument(
        "--norm-feat",
        type='bool',
        default='False',
        help="Normalize embeddings",
    )


    # distributed training parameters
    parser.add_argument(
        '--device',
        default='cuda',
        help='device'
    )
    parser.add_argument(
        '--distributed',
        type='bool',
        default='False',
        help="ddp mode",
    )
    parser.add_argument(
        '--dist-backend',
        default='nccl',
        type=str,
        help='distributed backend'
    )
    parser.add_argument(
        '--dist-url',
        default='env://',
        help='url used to set up distributed training'
    )
    parser.add_argument(
        '--world-size',
        default=1,
        type=int,
        help='number of distributed processes'
    )
    parser.add_argument(
        '--debug_slurm',
        type='bool',
        default='False',
        help="Debug SLURM",
    )
    parser.add_argument(
        '--local_rank',
        default=-1,
        type=int,
        help='Local rank of node')
    parser.add_argument(
        '--master_port',
        default=-1,
        type=int,
        help='Master port of Job'
    )

    args = parser.parse_args()
    return args


if __name__ == '__main__':

    # parse args
    args = parse_args()

    # set multi-processing start method
    import torch.multiprocessing as mp

    try:
        mp.set_start_method('forkserver')
    except RuntimeError:
        pass

    # Defaults so `group` and `device` are always bound, even without CUDA.
    group = None
    device = torch.device(args.device)

    # Init distributed mode
    if torch.cuda.is_available():
        group = utils.init_distributed_mode(args, make_communication_groups=True)
        if args.distributed:
            group = group[0]
        if args.world_size <= 8:
            device = 'cuda:0'

    # Set up logger
    logger = None
    if args.distributed:
        filename = str(args.job_id) + '_' + str(args.global_rank) + '_log.out'
        logger = utils.setup_logger(
            "Video_reader, classification",
            args.output_dir,
            True,
            logname=filename
        )

    # Per-dataset overrides: (num_clusters, samples in 'train' mode, samples otherwise).
    # Dataset names not listed keep the values given on the command line.
    dataset_specs = {
        'vggsound': (309, 170752, 14032),
        'kinetics': (400, 230976, 18968),
        'kinetics_sound': (32, 22408, 22408),
        'ave': (28, 3328, 3328),
    }
    if args.dataset in dataset_specs:
        num_clusters, n_train, n_other = dataset_specs[args.dataset]
        args.num_clusters = num_clusters
        args.num_data_samples = n_train if args.mode == 'train' else n_other

    # Input-pipeline settings for each supported model family
    if args.mil_nce:
        args.clip_len = 32
        args.train_crop_size = 224
        args.sample_rate = 1
        args.target_fps = 10
        args.mean = [0, 0, 0]
        args.std = [1, 1, 1]
    elif args.xdc:
        args.clip_len = 32
        args.train_crop_size = 224
        args.sample_rate = 1
        args.target_fps = 30
        args.mean = [0.485, 0.456, 0.406]
        args.std = [0.229, 0.224, 0.225]
        args.aud_spec_type = 1
        args.z_normalize = False
        args.aud_sample_rate = 48000
    elif args.dpc:
        args.clip_len = 40
        args.train_crop_size = 224
        args.sample_rate = 3
        args.target_fps = 30
        args.mean = [0.485, 0.456, 0.406]
        args.std = [0.229, 0.224, 0.225]
    elif args.sela:
        args.clip_len = 30
        args.train_crop_size = 112
        args.sample_rate = 8
        args.target_fps = 30
        args.mean = [0.45, 0.45, 0.45]
        args.std = [0.225, 0.225, 0.225]
    else:
        args.clip_len = 30
        args.train_crop_size = 112
        args.sample_rate = 1
        args.target_fps = 30
        args.mean = [0.45, 0.45, 0.45]
        args.std = [0.225, 0.225, 0.225]

    # Get dataset
    dataset = AVideoDataset(
        ds_name=args.dataset,
        mode=args.mode,
        num_frames=args.clip_len,
        sample_rate=args.sample_rate,
        train_crop_size=args.train_crop_size,
        num_data_samples=args.num_data_samples,
        target_fps=args.target_fps,
        decode_audio=True,
        num_sec=args.num_sec,
        aud_sample_rate=args.aud_sample_rate,
        aud_spec_type=args.aud_spec_type,
        use_volume_jittering=False,
        use_temporal_jittering=args.use_temporal_jittering,
        z_normalize=args.z_normalize,
        center_crop=True,
        temp_jitter=False,
        mean=args.mean,
        std=args.std,
        corruption=args.corruption
    )

    # A string path counts as "given" unless it is the literal 'None'.
    # NOTE(review): the default --weights-path is '' which therefore counts as
    # given, so the OURS branch keeps the full headcount and calls
    # restart_from_checkpoint with an empty path. Confirm that is intended.
    if isinstance(args.weights_path, str):
        weight_path_not_none = args.weights_path != 'None'
    else:
        weight_path_not_none = args.weights_path is not None

    # Only the OURS branch below flips this; get_cluster_assignments_gpu reads it.
    args.ours = False

    # Get model
    if args.mil_nce:
        from MIL_NCE.s3dg import S3D
        utils.print_or_log(f"Loading MIL-NCE pretrained model", logger=logger)
        model = S3D('MIL_NCE/s3d_dict.npy', 512)
        model.load_state_dict(torch.load('MIL_NCE/s3d_howto100m.pth'))
        model.to(device)
        model = torch.nn.DataParallel(model)
    elif args.xdc:
        from XDC.model import build_video_encoder, XDC, AudioEncoder
        import torchvision
        import torch.nn as nn
        utils.print_or_log(f"Loading XDC pretrained model", logger=logger)
        video_model = build_video_encoder(1)
        video_model.load_state_dict(torch.load('XDC/r2.5d_epoch30_inputcount30051712_final.pth'))
        audio_model = AudioEncoder()
        audio_model.load_state_dict(torch.load('XDC/a_resnet_epoch16_inputcount16091136_checkpoint_f134616582.pth'))
        video_model.fc = nn.Identity()
        audio_model.fc = nn.Identity()
        model = XDC(video_model, audio_model)
        model.to(device)
        model = torch.nn.DataParallel(model)
    elif args.dpc:
        from DPC.model_3d_lc import LC
        from DPC.resnet_2d3d import neq_load_customized
        model = LC(
            sample_size=224,
            num_seq=8,
            seq_len=5,
            network='resnet34',
            dropout=0.5
        )
        model.to(device)
        model = torch.nn.DataParallel(model)
        checkpoint = torch.load('DPC/k400_224_r34_dpc-rnn_runningStats.pth.tar')
        try:
            model.load_state_dict(checkpoint['state_dict'])
        except RuntimeError:
            # load_state_dict raises RuntimeError on missing/unexpected keys or
            # size mismatches; fall back to the tolerant loader.
            print('=> [Warning]: weight structure is not equal to test model; Use non-equal load ==')
            model = neq_load_customized(model, checkpoint['state_dict'])
        print("=> loaded testing checkpoint '{}' (epoch {})".format('DPC', checkpoint['epoch']))
    elif args.sela:
        model = utils.load_model(
            model_name='resnet18',
            vid_base_arch=None,
            aud_base_arch=None,
            pretrained=False,
            norm_feat=False,
            use_mlp=True,
            headcount=1,
            num_classes=args.num_clusters,
            return_features=False
        )
        model.mlp_v = None
        model = VideoWrapper(model, model.mlp_v, num_clusters=args.num_clusters, hc=1)
        model.to(device)
        model = torch.nn.DataParallel(model)
        utils.restart_from_checkpoint(
            args,
            ckp_path=args.weights_path,
            model=model
        )
    else: # OURS, Kinetics-pretrained and scratch eval
        args.ours = True
        args.headcount = args.headcount if weight_path_not_none else 1
        model = utils.load_model(
            model_name=args.model,
            vid_base_arch=args.vid_base_arch,
            aud_base_arch=args.aud_base_arch,
            pretrained=args.pretrained,
            norm_feat=args.norm_feat,
            use_mlp=args.use_mlp,
            mlptype=args.mlptype,
            headcount=args.headcount,
            num_classes=args.num_clusters,
            return_features=False if weight_path_not_none else True
        )
        model.to(device)
        if args.world_size <= 8:
            utils.print_or_log(f'World size: {args.world_size}, Loading data parallel', logger=logger)
            model = torch.nn.DataParallel(model)
        elif args.distributed:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                broadcast_buffers=False
            )

        # Load model weights
        to_restore = {'epoch': 0}
        if not args.pretrained:
            if weight_path_not_none:
                utils.print_or_log("Loading model weights", logger=logger)
                utils.restart_from_checkpoint(
                    args,
                    ckp_path=args.weights_path,
                    run_variables=to_restore,
                    model=model
                )
            else:
                utils.print_or_log("Random weights", logger=logger)

    # Get cluster assignments
    with torch.no_grad():
        get_cluster_assignments_gpu(args, dataset, model, logger=logger, group=group)
