import torch
import numpy as np
import re
import torch.nn as nn
import random
import statistics

class RENHD(nn.Module):
    """Stochastic per-layer updater wrapped around ``model``.

    RENHD is an ``nn.Module`` so that moving it to a device also moves the
    wrapped model together with the state registered on its layers through
    ``register_buffer``; otherwise ``update`` would mix CPU state with GPU
    parameters.

    Direct children of ``model`` are dispatched on the start of their
    attribute name:

    * ``linear`` / ``conv`` / ``bn``: layers with ``weight`` and an optional
      ``bias`` (pattern 1);
    * ``lstm``: modules exposing the ``*_l0`` LSTM parameters (pattern 2);
    * ``layer``: containers whose grandchildren are dispatched with the same
      rules (pattern 3);
    * ``shortcut`` / ``pool``: skipped (pattern 4).

    Every handled layer receives one velocity buffer (``v_*``) and one noise
    buffer (``n_*``) per parameter tensor plus the one-element buffers
    ``eta_theta`` (``eta_theta0 * N``), ``c_theta``, ``mu`` and ``z_theta``
    (initialised to the value of ``c_theta``).

    Args:
        model: module whose children are updated in place.
        N: scale factor; multiplies ``eta_theta0`` and divides the noise and
            target terms (presumably the dataset size -- confirm with caller).
        eta_theta0: base step size.
        c_theta0: initial value of ``c_theta`` and ``z_theta``.
        mu0: value stored in ``mu``; also multiplies the ``z_theta`` increment.
    """

    def __init__(self, model, N, eta_theta0, c_theta0, mu0):
        super(RENHD, self).__init__()
        self.pattern1 = re.compile(r'linear|conv|bn')
        self.pattern2 = re.compile(r'lstm')
        self.pattern3 = re.compile(r'layer')
        self.pattern4 = re.compile(r'shortcut|pool')
        self.N = N
        self.model = model
        self.eta_theta0, self.c_theta0, self.mu0 = eta_theta0, c_theta0, mu0

        for name, module, nested in self._iter_targets():
            if nested:
                print("This is the module: " + str(name))
            if self.pattern1.match(name):
                self.register_pattern1(module)
            elif self.pattern2.match(name):
                # Nested LSTMs are registered too, because update() steps them.
                self.register_pattern2(module)
            elif not self.pattern4.match(name):
                print('Cannot find the module: ' + str(name))

    def _iter_targets(self):
        """Yield ``(name, module, nested)`` for every module to be dispatched.

        Direct children of the model are yielded with ``nested=False``.  For a
        child whose name matches pattern 3 the grandchildren of that child are
        yielded instead, with ``nested=True``.
        """
        for name, module in self.model._modules.items():
            if self.pattern3.match(name):
                # the nn.Sequential() layer: container -> block -> leaf module
                for block in module._modules.values():
                    for leaf_name, leaf in block._modules.items():
                        yield leaf_name, leaf, True
            else:
                yield name, module, False

    def _register_scalars(self, module):
        """Register the one-element state buffers shared by both patterns."""
        module.register_buffer('eta_theta', torch.Tensor([self.eta_theta0 * self.N]))
        module.register_buffer('c_theta', torch.Tensor([self.c_theta0]))
        module.register_buffer('mu', torch.Tensor([self.mu0]))
        # Clone so that the in-place z_theta updates never modify c_theta.
        module.register_buffer('z_theta', module.c_theta.clone())

    def register_pattern1(self, module):
        """Attach velocity/noise/state buffers to a weight(+bias) layer."""
        size_w = module.weight.data.shape
        module.register_buffer('v_w', torch.zeros(size_w))
        module.register_buffer('n_w', torch.zeros(size_w))
        if module.bias is not None:
            size_b = module.bias.data.shape
            module.register_buffer('v_b', torch.zeros(size_b))
            module.register_buffer('n_b', torch.zeros(size_b))

        self._register_scalars(module)

    def register_pattern2(self, module):
        """Attach velocity/noise/state buffers to the ``*_l0`` LSTM tensors."""
        shapes = (
            ('wih', module.weight_ih_l0.data.shape),
            ('bih', module.bias_ih_l0.data.shape),
            ('whh', module.weight_hh_l0.data.shape),
            ('bhh', module.bias_hh_l0.data.shape),
        )
        for suffix, shape in shapes:
            module.register_buffer('v_' + suffix, torch.zeros(shape))
        self._register_scalars(module)
        for suffix, shape in shapes:
            module.register_buffer('n_' + suffix, torch.zeros(shape))

    def get_z_theta(self):
        """Return the sum of ``torch.norm(z_theta)`` over all handled layers.

        Returns the integer ``0`` when no layer is handled.
        """
        total = 0
        for name, module, nested in self._iter_targets():
            if self.pattern1.match(name) or self.pattern2.match(name):
                total = total + torch.norm(module.z_theta)
            elif not nested and not self.pattern4.match(name):
                print('Cannot find the module: ' + str(name))
        return total

    def update(self, xi):
        """Advance velocities, parameters and ``z_theta`` of every handled layer.

        Gradients must already be populated on the parameters.

        Args:
            xi: scalar used in the target term ``xi * eta_theta / N`` of the
                ``z_theta`` update.
        """
        # update the relevant variables of parameters
        for name, module, nested in self._iter_targets():
            if self.pattern1.match(name):
                self.update_pattern1(module, xi)
            elif self.pattern2.match(name):
                self.update_pattern2(module, xi)
            elif self.pattern4.match(name):
                continue
            elif nested:
                print('Did not update the parameters of the module: ' + str(name))
            else:
                print('Cannot find the module: ' + str(name))

    def _advance(self, module, param, velocity, noise):
        """Update one velocity buffer in place and add it to its parameter.

        ``noise`` is refilled with standard-normal samples and scaled by
        ``sqrt(2 * c_theta * eta_theta / N)``.
        """
        grad = param.grad.data
        noise_scale = (2 * module.c_theta * module.eta_theta / self.N).sqrt_()
        velocity.add_(
            - grad * module.eta_theta - module.z_theta * velocity
            + noise.normal_() * noise_scale)
        param.data.add_(velocity)

    def _update_z_theta(self, module, velocities, xi):
        """Move ``z_theta`` by the gap between mean squared velocity and target."""
        squared = sum((v ** 2).sum() for v in velocities)
        count = sum(v.numel() for v in velocities)
        # thermal inertia mu = 1
        # NOTE(review): the increment is scaled by both module.mu and
        # self.mu0, as in the original code -- confirm this is intended.
        xx = module.mu * (squared / count - xi * module.eta_theta / self.N)
        module.z_theta.add_(xx * self.mu0)

    def update_pattern1(self, module, xi):
        """Apply one update step to a weight(+bias) layer."""
        velocities = [module.v_w]
        self._advance(module, module.weight, module.v_w, module.n_w)

        if module.bias is not None:
            self._advance(module, module.bias, module.v_b, module.n_b)
            velocities.append(module.v_b)

        self._update_z_theta(module, velocities, xi)

    def update_pattern2(self, module, xi):
        """Apply one update step to the ``*_l0`` tensors of an LSTM module."""
        groups = (
            (module.weight_ih_l0, module.v_wih, module.n_wih),
            (module.bias_ih_l0, module.v_bih, module.n_bih),
            (module.weight_hh_l0, module.v_whh, module.n_whh),
            (module.bias_hh_l0, module.v_bhh, module.n_bhh),
        )
        for param, velocity, noise in groups:
            self._advance(module, param, velocity, noise)

        self._update_z_theta(module, [velocity for _, velocity, _ in groups], xi)

    def resample_momenta(self, xi):
        """Reset ``z_theta`` to ``c_theta / xi`` and redraw every velocity.

        Velocities are drawn from a normal distribution with standard
        deviation ``sqrt(xi * eta_theta / N)``.
        """
        for name, module, nested in self._iter_targets():
            if self.pattern1.match(name):
                velocities = [module.v_w]
                if module.bias is not None:
                    velocities.append(module.v_b)
            elif self.pattern2.match(name):
                velocities = [module.v_wih, module.v_whh, module.v_bih, module.v_bhh]
            else:
                if nested and not self.pattern4.match(name):
                    print('Did not resample the momenta of the module: '+str(name))
                continue

            module.z_theta = module.c_theta / xi
            for velocity in velocities:
                velocity.normal_().mul_((xi * module.eta_theta / self.N).sqrt_())


def f(x):
    """Evaluate a fixed degree-6 polynomial in the logistic sigmoid of ``x``.

    Per the original author's note the coefficients were calculated in Matlab
    for sigma_*^2 = 0.2 and lambda = 0.1, and the minibatch var[loss] must
    stay below 0.2 (i.e. below sigma_*^2).

    Args:
        x: scalar or NumPy array.

    Returns:
        The polynomial value, with the same shape as ``x``.
    """
    sigmoid = 1. / (1 + np.exp(-x))

    # Signed coefficients of sigmoid**1 ... sigmoid**6, in that order.
    coefficients = (0.8950, -0.1450, -2.1000, 2.5500, -1.8000, 0.6000)

    total = 0
    for power, coefficient in enumerate(coefficients, start=1):
        total = total + coefficient * sigmoid ** power
    return total

def exchange_replica(shared_dict,odd_exchange,suc_exchange,
                     num_process,first_replica_xi,xi_ladder,
                     sigma2=0.2):
    """Attempt swaps of ``xi_ladder`` entries between neighbouring replicas.

    Pairs ``(i, i + 1)`` are visited with stride 2, starting at 1 when
    ``odd_exchange`` is true and at 0 otherwise.  For each pair a ``deltaE``
    is computed from the inverse ladder values and the per-replica entries of
    ``shared_dict``, appended to a history of at most 10 values stored under
    ``'deltaE_<i><i+1>'``, and a noisy acceptance test decides whether the two
    ladder entries are swapped.  Nothing happens when ``shared_dict`` is empty.

    Args:
        shared_dict: mapping holding one value per replica id and the
            ``deltaE`` history lists (which must already exist).
        odd_exchange: whether this round starts at replica 1.
        suc_exchange: flag set to ``True`` once any swap succeeds; when true
            at the end of the round ``odd_exchange`` is flipped.
        num_process: number of replicas.
        first_replica_xi: list that receives ``xi_ladder[0]`` whenever replica
            0 is visited (mutated in place).
        xi_ladder: list of per-replica values (mutated in place on a swap).
        sigma2: variance bound for the acceptance test; a pair whose history
            variance reaches it is rejected outright.

    Returns:
        ``(odd_exchange, suc_exchange, xi_ladder, first_replica_xi)``.
    """
    if not shared_dict:
        return odd_exchange,suc_exchange,xi_ladder,first_replica_xi

    attempts = 0
    successes = 0

    ##### exchange temperature #####
    first_id = 1 if odd_exchange else 0
    for lower in range(first_id, num_process, 2):
        upper = lower + 1
        if lower == 0:
            first_replica_xi.append(xi_ladder[lower])
        if not lower < num_process - 1:
            # The last replica has no upper neighbour to exchange with.
            continue

        attempts += 1
        deltaE = (1 / xi_ladder[lower] - 1 / xi_ladder[upper]) * \
                 (- shared_dict[lower] + shared_dict[upper])
        key = 'deltaE_'+str(lower)+str(upper)

        # Keep a sliding window of the 10 most recent deltaE values.
        history = shared_dict[key]
        if len(history) < 10:
            window = history + [deltaE]
        else:
            window = history[-9:] + [deltaE]
        shared_dict[key] = window

        # logistic test
        std = 0 if len(window) <= 1 else statistics.stdev(window)
        print('{} variance: {}'.format(key,std**2))
        if std**2 >= sigma2:
            print('!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!')
            acceptance = -1
        else:
            # Draw order matters for reproducibility: normal first, then uniform.
            z_star = np.sqrt(sigma2 - std**2) * np.random.normal()
            u = np.random.uniform()
            acceptance = deltaE + z_star + f(u)

        if acceptance > 0:
            successes += 1
            xi_ladder[upper], xi_ladder[lower] = xi_ladder[lower], xi_ladder[upper]
            suc_exchange = True

    if suc_exchange:
        odd_exchange = not odd_exchange
    print("Successful exchange ratio: ", str(successes /(attempts+1e-6)))
    return odd_exchange,suc_exchange,xi_ladder,first_replica_xi
