import torch
import re


class SGNHT:

    def __init__(self, model, N, eta_theta0, c_theta0):
        self.pattern1 = re.compile(r'linear|conv|bn')
        self.pattern2 = re.compile(r'lstm')
        self.pattern3 = re.compile(r'layer')
        self.pattern4 = re.compile(r'shortcut|pool')
        self.N = N
        self.model = model
        self.eta_theta0, self.c_theta0 = eta_theta0, c_theta0

        # for name, module in self.model._modules.items():
        #     if self.pattern1.match(name):
        #         print("This is the module: " + str(name))
        #         size_w = module.weight.data.shape
        #         module.register_buffer('v_w', torch.zeros(size_w))
        #         module.register_buffer('eta_theta', torch.Tensor([eta_theta0 * self.N]))
        #         module.register_buffer('c_theta', torch.Tensor([c_theta0]))
        #         module.register_buffer('z_theta', module.c_theta)
        #         module.register_buffer('n_w', torch.zeros(size_w))
        #
        #         if module.bias is not None:
        #             size_b = module.bias.data.shape
        #             module.register_buffer('v_b', torch.zeros(size_b))
        #             module.register_buffer('n_b', torch.zeros(size_b))
        #     # elif self.pattern2.match(name):
        #     #     print("This is the module: " + str(name))
        #     #     size_wih = module.weight_ih_l0.data.shape
        #     #     size_bih = module.bias_ih_l0.data.shape
        #     #     size_whh = module.weight_hh_l0.data.shape
        #     #     size_bhh = module.bias_hh_l0.data.shape
        #     #     module.register_buffer('v_wih', torch.zeros(size_wih))
        #     #     module.register_buffer('v_bih', torch.zeros(size_bih))
        #     #     module.register_buffer('v_whh', torch.zeros(size_whh))
        #     #     module.register_buffer('v_bhh', torch.zeros(size_bhh))
        #     #     module.register_buffer('eta_theta', torch.Tensor([eta_theta0*self.N]))
        #     #     module.register_buffer('c_theta', torch.Tensor([c_theta0]))
        #     #     module.register_buffer('z_theta', module.c_theta)
        #     #     module.register_buffer('n_wih', torch.zeros(size_wih))
        #     #     module.register_buffer('n_bih', torch.zeros(size_bih))
        #     #     module.register_buffer('n_whh', torch.zeros(size_whh))
        #     #     module.register_buffer('n_bhh', torch.zeros(size_bhh))
        #     elif self.pattern3.match(name):
        #         # the nn.Sequential() layer
        #         for name_, module_ in module._modules.items():
        #             for name__, module__ in module_._modules.items():
        #                 print("This is the module: " + str(name__))
        #                 if self.pattern1.match(name__):
        #                     size_w = module__.weight.data.shape
        #                     module__.register_buffer('v_w', torch.zeros(size_w))
        #                     module__.register_buffer('eta_theta', torch.Tensor([eta_theta0 * self.N]))
        #                     module__.register_buffer('c_theta', torch.Tensor([c_theta0]))
        #                     module__.register_buffer('z_theta', module__.c_theta)
        #                     module__.register_buffer('n_w', torch.zeros(size_w))
        #
        #                     if module__.bias is not None:
        #                         size_b = module__.bias.data.shape
        #                         module__.register_buffer('v_b', torch.zeros(size_b))
        #                         module__.register_buffer('n_b', torch.zeros(size_b))
        #                 # elif self.pattern2.match(name__):
        #                 #     print("This is the module: " + str(name__))
        #                 #     size_wih = module__.weight_ih_l0.data.shape
        #                 #     size_bih = module__.bias_ih_l0.data.shape
        #                 #     size_whh = module__.weight_hh_l0.data.shape
        #                 #     size_bhh = module__.bias_hh_l0.data.shape
        #                 #     module__.register_buffer('v_wih', torch.zeros(size_wih))
        #                 #     module__.register_buffer('v_bih', torch.zeros(size_bih))
        #                 #     module__.register_buffer('v_whh', torch.zeros(size_whh))
        #                 #     module__.register_buffer('v_bhh', torch.zeros(size_bhh))
        #                 #     module__.register_buffer('eta_theta', torch.Tensor([eta_theta0 * self.N]))
        #                 #     module__.register_buffer('c_theta', torch.Tensor([c_theta0]))
        #                 #     module__.register_buffer('z_theta', module__.c_theta)
        #                 #     module__.register_buffer('n_wih', torch.zeros(size_wih))
        #                 #     module__.register_buffer('n_bih', torch.zeros(size_bih))
        #                 #     module__.register_buffer('n_whh', torch.zeros(size_whh))
        #                 #     module__.register_buffer('n_bhh', torch.zeros(size_bhh))
        #                 elif self.pattern4.match(name__):
        #                     # print("Shortcut doesn't have weights")
        #                     pass
        #                 else:
        #                     print('Cannot find the module: '+str(name__))
        #     else:
        #         print('Cannot find the module: ' + str(name))

        for name, module in self.model._modules.items():
            if self.pattern1.match(name):
                self.register_pattern1(module)

            elif self.pattern2.match(name):
                self.register_pattern2(module)

            elif self.pattern3.match(name):
                # the nn.Sequential() layer
                for name_, module_ in module._modules.items():
                    for name__, module__ in module_._modules.items():
                        print("This is the module: " + str(name__))
                        if self.pattern1.match(name__):
                            self.register_pattern1(module__)
                        elif self.pattern4.match(name__):
                            pass
                        else:
                            print('Cannot find the module: ' + str(name__))


            elif self.pattern4.match(name):

                pass
            else:
                print('Cannot find the module: ' + str(name))

    def register_pattern1(self,module):
        size_w = module.weight.data.shape
        module.register_buffer('v_w', torch.zeros(size_w))
        module.register_buffer('n_w', torch.zeros(size_w))
        if module.bias is not None:
            size_b = module.bias.data.shape
            module.register_buffer('v_b', torch.zeros(size_b))
            module.register_buffer('n_b', torch.zeros(size_b))

        module.register_buffer('eta_theta', torch.Tensor([self.eta_theta0 * self.N]))
        module.register_buffer('c_theta', torch.Tensor([self.c_theta0]))
        module.register_buffer('z_theta', module.c_theta)

    def register_pattern2(self,module):
        size_wih = module.weight_ih_l0.data.shape
        size_bih = module.bias_ih_l0.data.shape
        size_whh = module.weight_hh_l0.data.shape
        size_bhh = module.bias_hh_l0.data.shape
        module.register_buffer('v_wih', torch.zeros(size_wih))
        module.register_buffer('v_bih', torch.zeros(size_bih))
        module.register_buffer('v_whh', torch.zeros(size_whh))
        module.register_buffer('v_bhh', torch.zeros(size_bhh))
        module.register_buffer('eta_theta', torch.Tensor([self.eta_theta0 * self.N]))
        module.register_buffer('c_theta', torch.Tensor([self.c_theta0]))
        module.register_buffer('z_theta', module.c_theta)
        module.register_buffer('n_wih', torch.zeros(size_wih))
        module.register_buffer('n_bih', torch.zeros(size_bih))
        module.register_buffer('n_whh', torch.zeros(size_whh))
        module.register_buffer('n_bhh', torch.zeros(size_bhh))


    def get_z_theta(self):
        buffer_ = []
        for name, module in self.model._modules.items():
            if self.pattern1.match(name) or self.pattern2.match(name):
                buffer_.append(torch.norm(module.z_theta))
            elif self.pattern3.match(name):
                for name_, module_ in module._modules.items():
                    for name__, module__ in module_._modules.items():
                        if self.pattern1.match(name__) or self.pattern2.match(name__):
                            buffer_.append(torch.norm(module__.z_theta))
            else:
                print('Cannot find the module: ' + str(name))
        return sum(buffer_)

    def update(self):

        # for name, module in self.model._modules.items():
        #     if self.pattern1.match(name):
        #         w, dw = module.weight.data, module.weight.grad.data
        #         module.v_w.add_(- dw * module.eta_theta - module.z_theta * module.v_w+ module.n_w.normal_() * (
        #             2 * module.c_theta * module.eta_theta / self.N).sqrt_()) #
        #         w.add_(module.v_w)
        #
        #         if module.bias is not None:
        #             b, db = module.bias.data, module.bias.grad.data
        #             module.v_b.add_(- db * module.eta_theta - module.z_theta * module.v_b+ module.n_b.normal_() * (
        #                 2 * module.c_theta * module.eta_theta / self.N).sqrt_()) #
        #             b.add_(module.v_b)
        #             xx = ((module.v_w ** 2).sum() + (module.v_b ** 2).sum()) / (
        #                 w.numel() + b.numel()) - module.eta_theta / self.N
        #             # print('{} z_theta change: {}'.format(module,xx))
        #             module.z_theta.add_(xx)
        #         else:
        #             xx = (module.v_w ** 2).sum() / w.numel() - module.eta_theta / self.N
        #             # print('{} z_theta change: {}'.format(module, xx))
        #             module.z_theta.add_(xx)
        #
        #     # elif self.pattern2.match(name):
        #     #     wih, dwih = module.weight_ih_l0.data, module.weight_ih_l0.grad.data
        #     #     bih, dbih = module.bias_ih_l0.data, module.bias_ih_l0.grad.data
        #     #     whh, dwhh = module.weight_hh_l0.data, module.weight_hh_l0.grad.data
        #     #     bhh, dbhh = module.bias_hh_l0.data, module.bias_hh_l0.grad.data
        #     #     module.v_wih.add_(- dwih * module.eta_theta - module.z_theta * module.v_wih + module.n_wih.normal_() * (2 * module.c_theta * module.eta_theta / self.N).sqrt_())
        #     #     module.v_bih.add_(- dbih * module.eta_theta - module.z_theta * module.v_bih + module.n_bih.normal_() * (2 * module.c_theta * module.eta_theta / self.N).sqrt_())
        #     #     module.v_whh.add_(- dwhh * module.eta_theta - module.z_theta * module.v_whh + module.n_whh.normal_() * (2 * module.c_theta * module.eta_theta / self.N).sqrt_())
        #     #     module.v_bhh.add_(- dbhh * module.eta_theta - module.z_theta * module.v_bhh + module.n_bhh.normal_() * (2 * module.c_theta * module.eta_theta / self.N).sqrt_())
        #     #     wih.add_(module.v_wih)
        #     #     bih.add_(module.v_bih)
        #     #     whh.add_(module.v_whh)
        #     #     bhh.add_(module.v_bhh)
        #     #     module.z_theta.add_(((module.v_wih**2).sum() + (module.v_bih**2).sum() + (module.v_whh**2).sum() + (module.v_bhh**2).sum()) / (wih.numel() + whh.numel() + bih.numel() + bhh.numel()) - module.eta_theta / self.N)
        #     elif self.pattern3.match(name):
        #         # the nn.Sequential() layer
        #         for name_, module_ in module._modules.items():
        #             for name__, module__ in module_._modules.items():
        #                 if self.pattern1.match(name__):
        #                     w, dw = module__.weight.data, module__.weight.grad.data
        #                     module__.v_w.add_(
        #                         - dw * module__.eta_theta - module__.z_theta * module__.v_w + module__.n_w.normal_() * (
        #                             2 * module__.c_theta * module__.eta_theta / self.N).sqrt_())  #
        #                     w.add_(module__.v_w)
        #
        #                     if module__.bias is not None:
        #                         b, db = module__.bias.data, module__.bias.grad.data
        #                         module__.v_b.add_(
        #                             - db * module__.eta_theta - module__.z_theta * module__.v_b + module__.n_b.normal_() * (
        #                                 2 * module__.c_theta * module__.eta_theta / self.N).sqrt_())  #
        #                         b.add_(module__.v_b)
        #                         module__.z_theta.add_(((module__.v_w ** 2).sum() + (module__.v_b ** 2).sum()) / (
        #                             w.numel() + b.numel()) - module__.eta_theta / self.N)
        #                     else:
        #                         module__.z_theta.add_(
        #                             (module__.v_w ** 2).sum() / w.numel() - module__.eta_theta / self.N)
        #
        #                 elif self.pattern4.match(name__):
        #                     pass
        #                 else:
        #                     print('Did not update the parameters of the module: '+str(name__))
        # update the relevant variables of parameters
        for name, module in self.model._modules.items():
            if self.pattern1.match(name):
                self.update_pattern1(module)

            elif self.pattern2.match(name):
                self.update_pattern2(module)

            elif self.pattern3.match(name):
                # the nn.Sequential() layer
                for name_, module_ in module._modules.items():
                    for name__, module__ in module_._modules.items():
                        if self.pattern1.match(name__):
                            self.update_pattern1(module__)
                        elif self.pattern2.match(name__):
                            self.update_pattern2(module__)
                        elif self.pattern4.match(name__):
                            pass
                        else:
                            print('Did not update the parameters of the module: ' + str(name__))
            elif self.pattern4.match(name):

                pass
            else:
                print('Cannot find the module: ' + str(name))

    def update_pattern1(self,module):
        w, dw = module.weight.data, module.weight.grad.data
        module.v_w.add_(
            - dw * module.eta_theta - module.z_theta * module.v_w + module.n_w.normal_() * (
                2 * module.c_theta * module.eta_theta / self.N).sqrt_())
        w.add_(module.v_w)

        if module.bias is not None:
            b, db = module.bias.data, module.bias.grad.data
            module.v_b.add_(
                - db * module.eta_theta - module.z_theta * module.v_b + module.n_b.normal_() * (
                    2 * module.c_theta * module.eta_theta / self.N).sqrt_())
            b.add_(module.v_b)

        if module.bias is not None:
            xx = ((module.v_w ** 2).sum() + (module.v_b ** 2).sum()) / (
            w.numel() + b.numel()) - module.eta_theta / self.N
            module.z_theta.add_(xx)
        else:
            xx = (module.v_w ** 2).sum() / w.numel() - module.eta_theta / self.N
            module.z_theta.add_(xx)

    def update_pattern2(self,modulẻ̉̉):
        #### ???? ####
        wih, dwih = module.weight_ih_l0.data, module.weight_ih_l0.grad.data
        bih, dbih = module.bias_ih_l0.data, module.bias_ih_l0.grad.data
        whh, dwhh = module.weight_hh_l0.data, module.weight_hh_l0.grad.data
        bhh, dbhh = module.bias_hh_l0.data, module.bias_hh_l0.grad.data

        module.v_wih.add_(- dwih * module.eta_theta - module.z_theta * module.v_wih + module.n_wih.normal_() * (
        2 * module.c_theta * module.eta_theta / self.N).sqrt_())
        module.v_bih.add_(- dbih * module.eta_theta - module.z_theta * module.v_bih + module.n_bih.normal_() * (
        2 * module.c_theta * module.eta_theta / self.N).sqrt_())
        module.v_whh.add_(- dwhh * module.eta_theta - module.z_theta * module.v_whh + module.n_whh.normal_() * (
        2 * module.c_theta * module.eta_theta / self.N).sqrt_())
        module.v_bhh.add_(- dbhh * module.eta_theta - module.z_theta * module.v_bhh + module.n_bhh.normal_() * (
        2 * module.c_theta * module.eta_theta / self.N).sqrt_())
        wih.add_(module.v_wih)
        bih.add_(module.v_bih)
        whh.add_(module.v_whh)
        bhh.add_(module.v_bhh)

        xx = ((module.v_wih ** 2).sum() + (module.v_bih ** 2).sum() + (module.v_whh ** 2).sum() + (module.v_bhh ** 2).sum()) \
             /(wih.numel() + whh.numel() + bih.numel() + bhh.numel()) - module.eta_theta / self.N
        module.z_theta.add_(xx)


    def resample_momenta(self):
        for name, module in self.model._modules.items():
            if self.pattern1.match(name):
                module.v_w.normal_().mul_((module.eta_theta / self.N).sqrt_())
                if module.bias is not None:
                    module.v_b.normal_().mul_((module.eta_theta / self.N).sqrt_())
            elif self.pattern2.match(name):
                module.v_wih.normal_().mul_((module.eta_theta / self.N).sqrt_())
                module.v_bih.normal_().mul_((module.eta_theta / self.N).sqrt_())
                module.v_whh.normal_().mul_((module.eta_theta / self.N).sqrt_())
                module.v_bhh.normal_().mul_((module.eta_theta / self.N).sqrt_())
            elif self.pattern3.match(name):
                # the nn.Sequential() layer
                for name_, module_ in module._modules.items():
                    for name__, module__ in module_._modules.items():
                        if self.pattern1.match(name__):
                            module__.v_w.normal_().mul_((module__.eta_theta / self.N).sqrt_())
                            if module__.bias is not None:
                                module__.v_b.normal_().mul_((module__.eta_theta / self.N).sqrt_())
                        elif self.pattern4.match(name__):
                            pass
                        else:
                            print('Did not resample the momenta of the module: '+str(name__))