import glob
import os
import pickle
import torch


def _parse_kinetics_path(line):
    """Split one line of the Kinetics train list into clip metadata.

    The line is parsed as ``.../<gt_class>/<youtube_id>_<start>_<end>.<ext>``.

    Args:
        line: One raw line of the train list (a trailing newline is fine).

    Returns:
        ``(youtube_id, start_time, end_time, gt_class)`` with integer times.

    Raises:
        IndexError: If the line has no parent directory or too few
            ``_``-separated fields in the file name.
        ValueError: If the start/end fields are not integers.
    """
    # Strip the newline left by readlines() and remove only the final
    # extension, so dots elsewhere in the path (e.g. a leading "./") survive.
    stem = os.path.splitext(line.strip())[0]
    components = stem.split('/')
    vid_name, gt_class = components[-1], components[-2]
    # Only the last two fields are times; everything before them is the id,
    # which may itself contain underscores.
    name_fields = vid_name.split('_')
    youtube_id = '_'.join(name_fields[:-2])
    return (youtube_id, int(name_fields[-2]), int(name_fields[-1]), gt_class)


def get_clusters(path='', jobid='25476687', num_clusters=400, ckpt_num=45):
    """Dump, per cluster, the Kinetics clips a checkpoint assigned to it.

    Loads ``ckpt_{ckpt_num}.pth`` from the ``checkpoints`` folder of the first
    entry listed in ``path``, takes column 0 of its ``selflabels`` tensor as
    the cluster id of each clip, and writes the grouped clips to
    ``cluster_vis/{jobid}.pkl``.

    The pickle holds ``{str(ckpt_num): clusters}`` where ``clusters[i]`` is
    the list of ``(youtube_id, start_time, end_time, gt_class)`` tuples of
    the clips labelled ``i``, in increasing clip-index order.

    Clip paths are read from ``datasets/data/kinetics_train.txt``. When
    ``datasets/data/kinetics_valid.pkl`` exists, it is used as a list of
    indices selecting (and ordering) the lines of that file.

    Args:
        path: Directory whose first entry contains a ``checkpoints`` folder.
        jobid: Base name of the output pickle.
        num_clusters: Cluster ids ``0 .. num_clusters - 1`` are dumped.
        ckpt_num: Epoch number of the checkpoint to read.

    Raises:
        FileNotFoundError: If ``path``, the train list or the checkpoint
            does not exist.
        IndexError: If ``path`` is an empty directory, or a label index has
            no matching clip path.
    """
    # NOTE(review): os.listdir() order is arbitrary; presumably ``path`` holds
    # a single run directory -- confirm against how jobs lay out their output.
    run_dirs = os.listdir(path)
    ckpt_path = os.path.join(path, run_dirs[0], 'checkpoints')

    kinetic_train_path = 'datasets/data/kinetics_train.txt'
    with open(kinetic_train_path, 'r') as f:
        kinetics_paths = f.readlines()

    vid_valid_file = 'datasets/data/kinetics_valid.pkl'
    if os.path.exists(vid_valid_file):
        # NOTE: unpickling can execute arbitrary code; only use a trusted file.
        with open(vid_valid_file, 'rb') as handle:
            valid_indices = pickle.load(handle)
        final_kinetics_paths = [kinetics_paths[ix] for ix in valid_indices]
    else:
        # Previously this case crashed with a NameError on `valid_indices`.
        # NOTE(review): falling back to the full list assumes the checkpoint
        # was trained on the unfiltered train list -- confirm.
        print(f"{vid_valid_file} not found, using all paths")
        final_kinetics_paths = kinetics_paths

    # Only a single checkpoint is processed; the result stays keyed by epoch.
    epoch = ckpt_num
    print(f"Epoch: {epoch}")
    full_path = os.path.join(ckpt_path, f'ckpt_{epoch}.pth')
    # Map to CPU so checkpoints saved on a GPU load on CPU-only machines.
    ckpt = torch.load(full_path, map_location='cpu')
    self_labels = ckpt['selflabels'][:, 0]

    full_list = []
    for cluster_i in range(num_clusters):
        print(f"Epoch: {epoch}, cluster: {cluster_i}")
        member_indices = (self_labels == cluster_i).nonzero()[:, 0].tolist()
        full_list.append(
            [_parse_kinetics_path(final_kinetics_paths[i]) for i in member_indices]
        )
    result_dict = {str(epoch): full_list}

    # Create the output folder on demand instead of failing on a fresh checkout.
    os.makedirs('cluster_vis', exist_ok=True)
    with open(f'cluster_vis/{jobid}.pkl', 'wb') as handle:
        pickle.dump(result_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)


if __name__ == '__main__':
    import argparse

    # Command-line entry point: parse the options and forward them unchanged
    # to get_clusters().
    parser = argparse.ArgumentParser(description='Get Clusters')

    ### Retrieval params
    # NOTE(review): get_clusters() lists --path as a directory and reads
    # '<first entry>/checkpoints/ckpt_<N>.pth' below it, so the file-like
    # default here looks stale -- confirm the intended default.
    parser.add_argument(
        '--path',
        default='vggsound.pth',
        type=str,
        help='path to ckpt'
    )
    parser.add_argument(
        '--jobid',
        default='2374533',
        type=str,
        help='SLURM JOBID'
    )
    parser.add_argument(
        '--num-clusters',
        default=400,
        type=int,
        help='Num clusters'
    )
    parser.add_argument(
        '--ckpt-num',
        default=45,
        type=int,
        help='Ckpt to get labels'
    )

    args = parser.parse_args()
    # The comma after `num_clusters=...` was missing, which made the whole
    # file a SyntaxError.
    get_clusters(
        path=args.path,
        jobid=args.jobid,
        num_clusters=args.num_clusters,
        ckpt_num=args.ckpt_num,
    )
