
import os
import sys
import json
import numpy as np
import random
import argparse

from tqdm import tqdm

# Fix the stdlib RNG seed so any use of `random` in this script is repeatable.
random.seed(42)

# Command-line interface. Every option has a default, so the script can be
# run with no arguments.
parser = argparse.ArgumentParser(
    description='Merge per-subject SFT conversation JSON files into one combined file.',
)

parser.add_argument(
    "--dataset",
    type=str,
    default='nsd',
    help='Dataset name; selects the directory under /mnt/NSD_dataset/datasets.',
)

parser.add_argument(
    "--subjects",
    nargs='+',
    type=str,
    default=['subj01', 'subj02', 'subj05', 'subj07'],
    help='One or more subject IDs whose per-subject SFT files are merged, in order.',
)

parser.add_argument(
    "--train",
    action='store_true',
    help="Merge the training split ('tr') instead of the default test split ('te').",
)

# NOTE: arguments are parsed at import time; main() reads this module-level
# namespace rather than taking parameters.
args = parser.parse_args()


def _load_json(path):
    """Read and parse the JSON file at *path*, closing the handle afterwards."""
    with open(path, 'r', encoding='utf-8') as f:
        return json.load(f)


def main():
    """Merge per-subject SFT conversation files into one combined JSON file.

    For every subject in ``args.subjects`` this reads
    ``<root>/sft_data/<subject>/sft_<subject>_<tr|te>.json``, appends its
    ``conversations`` list to a single combined list (in subject order), and
    stores its ``info['atlas']`` entry in a dict keyed by subject. The merged
    result is written to ``<root>/sft_data/all/sft_all_<tr|te>.json``.

    Configuration comes from the module-level ``args`` namespace
    (``dataset``, ``subjects``, ``train``); the function takes no parameters
    and returns nothing.

    Raises:
        FileNotFoundError: if a per-subject input file does not exist.
        KeyError: if an input file lacks ``conversations``, ``info`` or
            ``info['atlas']``.
    """
    root_dir = f'/mnt/NSD_dataset/datasets/{args.dataset}'
    sft_dir = f'{root_dir}/sft_data'
    # Split tag used in both the input and the output file names.
    train_flag = 'tr' if args.train else 'te'

    conversations = []
    atlas = {}

    for subject in args.subjects:
        data = _load_json(f'{sft_dir}/{subject}/sft_{subject}_{train_flag}.json')

        conversations.extend(data['conversations'])

        # Echo each subject's metadata so the run log shows what was merged.
        print(data['info'])
        atlas[subject] = data['info']['atlas']

    # The fMRI statistics entries are recorded as paths only; the files are
    # not opened or checked here.
    info = {
        'dataset': args.dataset,
        "subject": "all",
        "train": train_flag,
        "fmri_mean": f"{root_dir}/fmris/mean.npy",
        "fmri_std": f"{root_dir}/fmris/std.npy",
        "atlas": atlas,
    }

    results = {
        "info": info,
        "conversations": conversations,
    }

    out_dir = f'{sft_dir}/all'
    os.makedirs(out_dir, exist_ok=True)
    with open(f'{out_dir}/sft_all_{train_flag}.json', 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=4)


# Run the merge only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
