from algorithms.fedavg import FedAvgSever, FedAvgEvaluation
from algorithms.fl_base import FLDeviceBase
import torch
import copy

from visualization.state_dict import display_state_dict, state_dict_dif

class SubsetDevice(FLDeviceBase):
    """Client device whose round is: check it can train, then train locally.

    The model it trains is whatever the server handed it; this class adds no
    per-round logic beyond the two inherited steps below.
    """

    # Class-level default; not read within this file (presumably consulted or
    # overridden elsewhere — TODO confirm).
    scale_factor = -1.0

    #!overrides
    def _round(self):
        """Run one local round: verify trainability, then train the (sub)network."""
        self._check_trainable()
        self._train()


class CaldasServer(FedAvgSever):
    """Server that trains index-selected sub-models of the global model on devices.

    ``pre_round`` extracts one sub-model per selected device via
    ``self.extract_fnc`` and stores the index sets that were used.
    ``post_round`` scatters every trained sub-model back into a full-size,
    zero-filled copy of the global model. It then averages each entry over
    only those devices that actually trained it.
    """

    _device_class = SubsetDevice
    _device_evaluation_class = FedAvgEvaluation
    # Per-device index dicts produced in ``pre_round`` and consumed (in the same
    # order) by ``post_round``.
    _list_of_indices_dict = None
    scale_factor = -1.0

    @staticmethod
    def embedd_submodel(indices_dict, submodel, global_model):
        """Scatter a trained sub-model back into a full-size, zero-filled state dict.

        Args:
            indices_dict: maps parameter key -> sequence of index collections.
                - 1-D entries: ``[0]`` holds the kept positions.
                - 2-D (linear) weights: ``[0]`` holds the kept *input* columns.
                - 4-D conv weights: ``[0]`` holds the kept output channels and
                  ``[1]`` the kept input channels. These must be tensors,
                  because they go through ``torch.meshgrid``.
                ``'linear.bias'`` is never looked up.
            submodel: state dict of the trained sub-model.
            global_model: state dict of the full global model. It supplies the
                keys, shapes, dtypes and devices, and is not modified.

        Returns:
            ``(model, mask)``.
            - ``model`` has the global model's keys and shapes. It holds the
              sub-model's values at the selected positions and zeros elsewhere.
            - ``mask`` holds 1 at exactly those positions and 0 elsewhere.
              For 4-D conv weights it is 2-D (out, in), one entry per whole
              kernel.
            - Keys containing ``'num_batches_tracked'`` are left all-zero in
              both.

        Raises:
            NotImplementedError: for a non-skipped tensor whose rank is not
                1, 2 or 4.
        """
        exceptions = ['num_batches_tracked']

        # Zero-filled copy of the global model; sub-model values are written into it.
        model = copy.deepcopy(global_model)
        for item in model.values():
            item.zero_()

        # Zero-filled mask with the same keys. Conv weights are reduced to their
        # (out, in) plane because sub-models keep or drop whole kernels.
        mask = copy.deepcopy(model)
        for key, item in mask.items():
            if len(item.shape) == 4:
                mask[key] = item[:, :, 0, 0]

        for key, item in model.items():
            if any(excluded_key in key for excluded_key in exceptions):
                continue

            # Case for bias, running_mean, running_var (all 1-D)
            if len(item.shape) == 1:
                if key == 'linear.bias':
                    # The output bias remains fully sized. Copy and fill in place
                    # so the tensor and its mask keep the global model's dtype and
                    # device. A fresh torch.ones(...) would always be CPU/float32
                    # and break the averaging on GPU.
                    model[key].copy_(submodel[key])
                    mask[key].fill_(1)
                else:
                    indices = indices_dict[key][0]
                    model[key][indices] = submodel[key]
                    mask[key][indices] = 1

            # Case for Conv2d weights
            elif len(item.shape) == 4:
                indices_out = indices_dict[key][0]
                # NOTE(review): a conv with exactly 3 input channels is assumed to
                # be the input layer, whose input channels are never reduced.
                if item.shape[1] == 3:
                    model[key][indices_out, :, :, :] = submodel[key]
                    mask[key][indices_out, :] = 1
                else:
                    indices_in = indices_dict[key][1]
                    # (out, in) index grid so that whole kernels are addressed.
                    grid_out, grid_in = torch.meshgrid(indices_out, indices_in, indexing='ij')
                    model[key][grid_out, grid_in, :, :] = submodel[key]
                    mask[key][grid_out, grid_in] = 1

            # Case for fully connected weights: only input features are reduced.
            elif len(item.shape) == 2:
                indices_in = indices_dict[key][0]
                model[key][:, indices_in] = submodel[key]
                mask[key][:, indices_in] = 1

            else:
                raise NotImplementedError
        return model, mask

    @staticmethod
    def model_averaging(list_of_state_dicts, list_of_masks, eval_device_dict=None, storage_path=None):
        """Average embedded device models entry-wise over the devices that trained each entry.

        Args:
            list_of_state_dicts: full-size state dicts, one per device, zero
                where that device did not train (as built by ``embedd_submodel``).
            list_of_masks: matching masks (1 = trained by that device). Conv
                weights use 2-D (out, in) masks.
            eval_device_dict: the current global state dict. It is required
                despite the ``None`` default. Entries that no device trained
                keep its values.
            storage_path: if not ``None``, the per-entry device counts are
                visualised to ``storage_path + '/averaging_mask.png'``.

        Returns:
            A new dict (``eval_device_dict`` is not modified) that omits every
            key containing ``'num_batches_tracked'``.
        """
        averaging_exceptions = ['num_batches_tracked']

        # Per-key device counts, only collected for visual debugging.
        mask_debug = {}

        averaged_dict = copy.deepcopy(eval_device_dict)
        for key, target in averaged_dict.items():
            if any(module_name in key for module_name in averaging_exceptions):
                continue

            # How many devices trained each entry, and the sum of their values.
            count = torch.stack([m[key] for m in list_of_masks], dim=0).sum(dim=0)
            param_sum = torch.stack([sd[key] for sd in list_of_state_dicts], dim=0).sum(dim=0)
            if len(target.shape) == 4:
                # Conv masks are per kernel; broadcast them over the spatial dims.
                count = count[:, :, None, None].expand_as(target)

            # Divide only where at least one device contributed, so no 0/0 is
            # ever computed. Uncovered entries keep the previous global value.
            covered = count != 0
            target[covered] = param_sum[covered] / count[covered]

            if storage_path is not None:
                # Collapse by rank (not by key name) so conv biases and 4-D
                # weights not named 'conv' are displayed correctly.
                mask_debug[key] = count[:, :, 0, 0] if count.dim() == 4 else count

        averaged_dict = {k: v for k, v in averaged_dict.items()
                         if not any(module_name in k for module_name in averaging_exceptions)}

        if storage_path is not None:
            display_state_dict(mask_debug, storage_path + '/averaging_mask.png')

        return averaged_dict

    #!overrides
    def pre_round(self, round_n, rng):
        """Select this round's devices and extract one sub-model per device.

        Returns:
            ``(rand_selection_idxs, device_model_list)``: the selected device
            indices and the extracted sub-models, in the same order. The index
            dicts returned by ``extract_fnc`` are stored on
            ``self._list_of_indices_dict`` for ``post_round``.
        """
        rand_selection_idxs = self.random_device_selection(self.n_devices, self.n_active_devices, rng)

        # Extraction of trainable model out of the full model
        device_model_list = []
        list_of_indices_dict = []
        for dev_index in rand_selection_idxs:
            device_model, indices_dict = self.extract_fnc(
                self.scale_factor_list[dev_index], self._global_model, round_n=round_n)
            device_model_list.append(device_model)
            list_of_indices_dict.append(indices_dict)

        self._list_of_indices_dict = list_of_indices_dict
        return rand_selection_idxs, device_model_list

    #!overrides
    def post_round(self, round_n, idxs):
        """Merge the trained sub-models of the devices in ``idxs`` into the global model.

        ``idxs`` is expected in the same order ``pre_round`` returned it, so that
        device ``i`` is paired with the ``i``-th stored index dict.
        """
        used_devices = [self._devices_list[i] for i in idxs]
        debug_enabled = self._storage_path is not None

        # Snapshot the old model only when the difference will be visualised.
        old_global_model = copy.deepcopy(self._global_model) if debug_enabled else None

        # Re-embed each device's sub-model using the indices stored in pre_round
        device_models = []
        device_masks = []
        for i, device in enumerate(used_devices):
            model, mask = self.embedd_submodel(
                self._list_of_indices_dict[i], device.get_model_state_dict(), self._global_model)
            device_models.append(model)
            device_masks.append(mask)

        # Model averaging based on the re-embedded local models
        averaged_model = self.model_averaging(device_models, device_masks,
                                              eval_device_dict=self._global_model,
                                              storage_path=self._storage_path)

        # Setting new global model
        self._global_model = averaged_model

        # Visual debugging, guarded like in model_averaging: None + str would raise.
        if debug_enabled:
            display_state_dict(state_dict_dif(averaged_model, old_global_model),
                               self._storage_path + '/state_dict_viz.png')
