'''
Author: 
Email: 
Date: 2021-12-21 11:57:44
LastEditTime: 2022-12-27 22:21:02
Description: 
'''

import numpy as np
from torch.utils.data import Dataset, DataLoader


from Because.planner import Planner
from Because.discover import Discover
from Because.utils import TrajectoryDataset, segment_trajectories

class Because:
    """Agent that pairs a ``Planner`` with an optional ``Discover`` module.

    Training alternates between two stages stored in ``self.stage``:
    ``'generation'`` and ``'discovery'``. The stage is advanced at the end
    of every :meth:`train` / :meth:`train_offline` call, so the stage used
    by a call is the one chosen at the end of the previous call.
    """

    name = 'Because'

    def __init__(self, args):
        """Build the planner and, when enabled, the discovery module.

        Args:
            args: Configuration mapping. The keys read here are
                ``model_path``, ``model_id``, ``env_params``,
                ``Because_model``, ``planner`` and ``discover`` (which
                must contain ``discovery_interval``). The ``planner`` and
                ``discover`` sub-mappings are updated in place with the
                shared settings before being handed to the sub-modules.
        """
        self.model_path = args['model_path']
        self.model_id = args['model_id']

        # Forward the shared settings to the sub-module configurations.
        args['planner']['env_params'] = args['env_params']
        args['planner']['Because_model'] = args['Because_model']
        args['discover']['env_params'] = args['env_params']

        # Master switch: use the discovered graph, or the ground-truth graph
        # when disabled. Discovery is currently switched off.
        use_discover = False
        # Discovery only makes sense together with the causal model.
        self.use_discover = use_discover and args['Because_model'] == 'causal'

        args['planner']['use_discover'] = self.use_discover
        args['planner']['use_gt'] = not self.use_discover

        # Two modules; the discovery module only exists when it is enabled.
        self.planner = Planner(args['planner'])
        if self.use_discover:
            self.discover = Discover(args['discover'])

        # Ratio between generation and discovery (generation is always longer):
        # one discovery stage every `discovery_interval` scheduler steps.
        self.stage = 'generation'
        self.episode_counter = 0
        self.discovery_interval = args['discover']['discovery_interval']

    def stage_scheduler(self):
        """Select the next stage and advance the episode counter.

        The stage becomes ``'discovery'`` when ``episode_counter + 1`` is a
        multiple of ``discovery_interval`` and ``'generation'`` otherwise.
        """
        is_discovery_turn = (self.episode_counter + 1) % self.discovery_interval == 0
        self.stage = 'discovery' if is_discovery_turn else 'generation'
        self.episode_counter += 1

    def select_action(self, env, state, deterministic):
        """Return the result of ``planner.select_action`` for the given inputs."""
        return self.planner.select_action(env, state, deterministic)

    def select_action_parallel(self, env, state, deterministic):
        """Return the result of ``planner.select_action_parallel`` for the given inputs."""
        return self.planner.select_action_parallel(env, state, deterministic)

    def store_transition(self, data):
        """Store ``data`` in the planner and, when enabled, in the discovery module."""
        self.planner.store_transition(data)
        if self.use_discover:
            self.discover.store_transition(data)

    def _update_causal_graph_if_due(self):
        """Refresh the planner's causal graph when in an enabled discovery stage."""
        if self.stage == 'discovery' and self.use_discover:
            self.discover.update_causal_graph()
            self.planner.set_causal_graph(self.discover.get_adj_matrix_graph())

    def train(self):
        """Run one training step: discovery (if due), planner update, stage update."""
        # discovery
        self._update_causal_graph_if_due()

        # generation
        self.planner.train()

        # in the end, update the stage
        self.stage_scheduler()

    def train_offline(self, dataloader):
        """Offline counterpart of :meth:`train`.

        Args:
            dataloader: Passed unchanged to ``planner.train_offline``.
        """
        # discovery
        self._update_causal_graph_if_due()

        # generation
        self.planner.train_offline(dataloader)

        # in the end, update the stage
        self.stage_scheduler()

    def save_model(self):
        """Save the planner (and discovery module, when enabled) under model_path/model_id."""
        self.planner.save_model(self.model_path, self.model_id)
        if self.use_discover:
            self.discover.save_model(self.model_path, self.model_id)

    def load_model(self):
        """Load the planner (and discovery module, when enabled) from model_path/model_id."""
        self.planner.load_model(self.model_path, self.model_id)
        if self.use_discover:
            self.discover.load_model(self.model_path, self.model_id)

    def load_dataset(self, data_path, *, batch_size=256, shuffle=True, num_workers=4):
        """Build a ``DataLoader`` over the trajectories stored at ``data_path``.

        Args:
            data_path: Path of a NumPy file holding a single pickled object,
                which is unwrapped with ``.item()``.
            batch_size: Batch size given to the ``DataLoader``.
            shuffle: Whether the ``DataLoader`` shuffles the dataset.
            num_workers: Number of ``DataLoader`` worker processes.

        Returns:
            A ``DataLoader`` over a ``TrajectoryDataset`` built from all
            segmented trajectories (``success_only=False``).
        """
        # SECURITY: allow_pickle=True unpickles the file content, which can
        # execute arbitrary code. Only load files from trusted sources.
        offline_data = np.load(data_path, allow_pickle=True).item()
        trajectories = segment_trajectories(offline_data, success_only=False)
        dataset = TrajectoryDataset(trajectories)
        return DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=shuffle,
            num_workers=num_workers,
        )
        
        