import random
import pickle5 as pickle
import os
import numpy as np
from PIL import Image
import torchvision

# Input pickle with the recorded trajectories and the output directory for the
# LUSR-format .npz export; both paths are parameterized by the env factor.
file_name = 'train_dataset.pkl'
factor = "XWIND"
data_path = f"/path/to/MMRL/trajdata/RoboMani/reach/{factor}/"
dest_path = f"/path/to/MMRL/trajdata/RoboMani/reach/LUSR/{factor}/"
# makedirs also creates the intermediate LUSR/ directory (os.mkdir raises
# FileNotFoundError if the parent is missing), and exist_ok avoids the
# check-then-create race of os.path.exists + os.mkdir.
os.makedirs(dest_path, exist_ok=True)

import os

# experts = 10
# with open(os.path.join(data_path, file_name), 'rb') as f:
#     data = pickle.load(f)
#     frames_np = []
#     for i in range(experts):
#         aug = torchvision.transforms.Compose([
#             torchvision.transforms.Resize((224,224)),
#             torchvision.transforms.ColorJitter(
#                 # brightness= (0.5 + i*0.1, 0.5 + i*0.1),
#                 # contrast=(0.5 + i*0.2, 0.5 + i*0.2),
#                 # saturation=(0.5 + i*0.2, 0.5 + i*0.2),
#                 hue=(-0.5 + i*0.1, -0.5 + i*0.1)
#             ),
#         ])
#         for ep in data.keys():
#             if isinstance(ep, int):
#                 for j in range(len(data[ep]["reward"])):
#                     frame = np.array(aug(Image.fromarray(data[ep]["frame"][j])))
#                     # frame.save("test.png")
#                     frames_np.append(frame)
    
#     np.savez(os.path.join(dest_path,'0.npz') ,obs=frames_np)
# exit()

# Export every frame of the integer-keyed episodes in the training pickle,
# resized to 224x224, as a single 'obs' array in <dest_path>/0.npz.
# NOTE(review): pickle.load can execute arbitrary code - only load trusted files.
with open(os.path.join(data_path, file_name), 'rb') as f:
    data = pickle.load(f)

aug = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224, 224)),
])
frames_np = []
for mdp, episode in data.items():
    # Only integer keys are treated as episodes; other keys are ignored.
    if not isinstance(mdp, int):
        continue
    episode_frames = episode["frame"]
    # One frame per recorded reward: only the first len(reward) frames are
    # exported (indexing keeps an IndexError if there are fewer frames).
    for j in range(len(episode["reward"])):
        frame = np.array(aug(Image.fromarray(episode_frames[j])))
        frames_np.append(frame)
# Written after the pickle file has been closed.
np.savez(os.path.join(dest_path, '0.npz'), obs=frames_np)

# Stop here: the code below is older experimental code that is not runnable
# as-is. `raise SystemExit` replaces the site-provided exit(), which is meant
# for the interactive shell and is unavailable under `python -S`.
raise SystemExit
# frames = np.load("/home/andykim0723/LUSR/data/main_data/domain0/0.npz")['obs']
# print(frames.shape)
# labels = np.load("/home/andykim0723/LUSR/data/main_data/domain0/0.npz")['labels']
# print(labels)
# exit()
# NOTE(review): unreachable - the script exits before this point (see the exit
# call above), so nothing from here to the end of the file runs.
# NOTE(review): this loop also cannot run as-is: `split`, `target_path` and
# `frame_dict` are not defined earlier in this file (NameError on first use).
# Presumably `split` is ['source', 'target'] (the keys read after the loop)
# and `frame_dict` should start as {} - confirm before reviving this code.
# For each split: read every pickle in <data_path><split>/, turn each file's
# 'frame' sequence into one array, concatenate across files, save it to
# <target_path><split>/0.npz and keep it in frame_dict[split].
for i_split in split:
    episode_frames = []
    directory = data_path+i_split
    # Note: rebinds the module-level `file_name` defined at the top of the file.
    for file_name in os.listdir(directory):
        file_path = directory + '/' + file_name
        with open(file_path, 'rb') as f:
            data = pickle.load(f)
            frames = data['frame']
            # Add a leading axis to each frame, then join them along it.
            frames = [np.expand_dims(frame, axis=0) for frame in frames]
            frames = np.concatenate(frames)
            episode_frames.append(frames)

    total_frames = np.concatenate(episode_frames)
    target_directory = target_path+i_split
    # Passed positionally, so np.savez stores the array under the default key
    # 'arr_0' (the other exports in this file use the key 'obs').
    np.savez(target_directory +'/0.npz',total_frames)
    # save as npz

    frame_dict[i_split] = total_frames

source_frames = frame_dict['source']
target_frames = frame_dict['target']

# Debug check (unreachable - the script exits earlier): print the shape of the
# 'obs' array in a previously exported archive.
import numpy as np  # NOTE(review): redundant, numpy is imported at the top
# NpzFile keeps the archive open; the context manager closes it
# deterministically once 'obs' has been read into memory.
with np.load("/home/andykim0723/LUSR/data/carla_data/weather0/0.npz") as archive:
    frames = archive['obs']
print(frames.shape)






####################

# Build the hue domain-adaptation dataset: stack every 'frame_un' frame from
# the source and target pickle directories and save each domain as
# <dest_path>/domain{0,1}/0.npz under the key 'obs'.
# NOTE(review): unreachable while the script exits earlier; kept for reference.
# NOTE(review): pickle.load can execute arbitrary code - only load trusted files.


def _stack_pickled_frames(directory, key='frame_un'):
    """Load every pickle file in `directory` and stack all frames under `key`.

    Each file is assumed to hold a dict whose `key` entry is a sequence of
    equally-shaped frames (inferred from usage - confirm against the writer).
    Files are visited in sorted name order so the output order is
    reproducible (os.listdir order is arbitrary); subdirectories are skipped.

    Returns:
        np.ndarray with a new leading axis indexing the frames.

    Raises:
        ValueError: if no frames were found (np.stack needs at least one array).
    """
    all_frames = []
    for filename in sorted(os.listdir(directory)):
        p = os.path.join(directory, filename)
        if not os.path.isfile(p):
            continue
        with open(p, 'rb') as f:
            data = pickle.load(f)
        all_frames.extend(data[key])
    # Same result as expand_dims(axis=0) on each frame followed by concatenate.
    return np.stack(all_frames)


source_frames = _stack_pickled_frames('data/1030_domains/hue/source')
print(f"source domain: {source_frames.shape}")

target_frames = _stack_pickled_frames('data/1030_domains/hue/target')
print(f"target domain: {target_frames.shape}")

frames = [source_frames, target_frames]
dest_path = '/home/andykim0723/LUSR/data/da_data/hue/'
splits = ['domain0', 'domain1']

for split, save_frame in zip(splits, frames):
    path = os.path.join(dest_path, split)
    # np.savez does not create directories; make sure domainN/ exists first.
    os.makedirs(path, exist_ok=True)
    np.savez(os.path.join(path, '0.npz'), obs=save_frame)


