import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import torch.optim as optim
import matplotlib.pyplot as plt
import time
import math
import numpy.matlib as npm



class Net(nn.Module):
    """Invertible residual network acting on points of R^dim.

    The input columns are split into two halves: ``y`` (first ``dim // 2``
    columns) and ``z`` (remaining columns).  Layer ``i`` (1-based) owns two
    ``nn.Linear(dim // 2, dim)`` modules named ``fc<i>_y`` and ``fc<i>_z`` and
    applies

        y <- y + dh * tanh(fc_y(z)) @ fc_y.weight
        z <- z - dh * tanh(fc_z(y)) @ fc_z.weight

    Each half is updated using only the other half, so the map is inverted
    exactly by running the layers in reverse order with the signs flipped
    (see ``_inverse``).  Keep the attribute names stable: parameters are
    looked up by these names when building an inverse network.
    """

    def __init__(self, dim, num_layer, dh, Ntrain):
        """Create the layers.

        Args:
            dim: total input dimension; must be even.
            num_layer: number of residual layers.
            dh: step size of every residual update.
            Ntrain: number of training samples; only used to size ``self.sample``.

        Raises:
            ValueError: if ``dim`` is odd.
        """
        super(Net, self).__init__()
        if dim % 2 != 0:
            raise ValueError('dim must be even, got %r' % (dim,))

        scalar = 2  # hidden width = scalar * (dim // 2)
        # Integer division: ``dim / 2`` is a float on Python 3, which
        # nn.Linear and torch.split reject.
        self.dim = dim // 2
        self.num_layer = num_layer

        for i in range(self.num_layer):
            setattr(self, 'fc' + str(i + 1) + '_y', nn.Linear(self.dim, self.dim * scalar))
            setattr(self, 'fc' + str(i + 1) + '_z', nn.Linear(self.dim, self.dim * scalar))

        self.dh = dh
        # Random sample kept for backward compatibility; not used by this class.
        self.sample = np.random.rand(Ntrain, dim)
        self.Ntrain = Ntrain

    def _layers(self, i):
        """Return the ``(fc_y, fc_z)`` modules of the 0-based layer ``i``."""
        return (getattr(self, 'fc' + str(i + 1) + '_y'),
                getattr(self, 'fc' + str(i + 1) + '_z'))

    def forward(self, x):
        """Map ``x`` of shape (N, 2 * self.dim) through all layers."""
        y, z = torch.split(x, [self.dim, self.dim], dim=1)
        for i in range(self.num_layer):
            fc_y, fc_z = self._layers(i)
            y = y + self.dh * torch.matmul(torch.tanh(fc_y(z)), fc_y.weight)
            z = z - self.dh * torch.matmul(torch.tanh(fc_z(y)), fc_z.weight)
        return torch.cat((y, z), 1)

    def _inverse(self, x1):
        """Exact inverse of ``forward``: undo the layers last-to-first."""
        y, z = torch.split(x1, [self.dim, self.dim], dim=1)
        for i in range(self.num_layer - 1, -1, -1):
            fc_y, fc_z = self._layers(i)
            z = z + self.dh * torch.matmul(torch.tanh(fc_z(y)), fc_z.weight)
            y = y - self.dh * torch.matmul(torch.tanh(fc_y(z)), fc_y.weight)
        return torch.cat((y, z), 1)

    def customized_loss_diff(self, x, output, grad_data, iter):
        """Loss built from a finite-difference Jacobian of the inverse map.

        For each transformed coordinate ``j`` the method perturbs
        ``output[:, j]`` by 0.001 and maps the result back through the exact
        inverse.  It takes ``dx - x`` as row ``j`` of a per-sample matrix and
        then scales every row to unit length (``JJJ``).  The returned loss is
        ``loss2 + loss6`` with

            loss2 = sqrt(mean_k sum_{j >= 1} (JJJ[k, j, :] . grad_data[k])^2)
            loss6 = prod_k (prod(singular_values(JJJ[k])) - 1)

        In words, ``loss2`` is the part of ``grad_data`` seen along every
        transformed direction except the first.

        Args:
            x: input batch, shape (N, 2 * self.dim).
            output: the network output for ``x`` (expected to be ``self(x)``);
                gradients flow through it.
            grad_data: target-function gradient at ``x``, shape (N, 2 * self.dim).
            iter: iteration counter. When ``iter % 100 == 0`` and
                ``self.dim == 1``, the module-level ``ax1`` axes are redrawn.

        Returns:
            Scalar loss tensor.

        Side effects: prints ``loss2`` and ``loss6`` and appends
        ``loss2``, ``loss6`` and ``loss`` to ``loss.txt``.
        """
        n_dim = 2 * self.dim
        eps = 0.001  # perturbation size in the transformed space

        rows = []
        for j in range(n_dim):
            output_dy = torch.clone(output)
            output_dy[:, j] = output[:, j] + eps
            dx = self._inverse(output_dy)

            # Test the invertibility (diagnostic only, so no graph is needed).
            with torch.no_grad():
                err = torch.mean(torch.abs(self.forward(dx) - output_dy))
            if err > 1e-5:
                print('Something is wrong in Jacobian computation')
                print(err)

            rows.append(dx - x)
        # Jacob[k, j, :] = inverse(output[k] + eps * e_j) - x[k]
        Jacob = torch.stack(rows, dim=1)

        # Normalise every row of every per-sample matrix to unit length.
        JJJ = Jacob / torch.sqrt(torch.sum(Jacob * Jacob, 2, keepdim=True))

        if iter % 100 == 0 and self.dim == 1:
            # Relies on the module-level ``ax1`` axes and ``plt``.
            JJP = JJJ.detach().numpy()
            x_np = x.detach().numpy()
            ax1.cla()
            ax1.quiver(x_np[:, 0], x_np[:, 1], JJP[:, 1, 0], JJP[:, 1, 1], color='black')
            plt.draw()
            plt.pause(0.001)

        # Product of singular values, one per sample in the batch (uses the real
        # batch size rather than self.Ntrain).
        J_det = torch.prod(torch.linalg.svdvals(JJJ), dim=1)
        # NOTE(review): rows of JJJ have unit norm, so by Hadamard's inequality
        # every J_det <= 1 and every factor below is <= 0.  The *product* over
        # samples therefore changes sign with the batch-size parity -- confirm
        # this is the intended penalty (a sum of squares may have been meant).
        loss6 = torch.prod(J_det - 1.0)

        loss0 = torch.squeeze(torch.matmul(JJJ, torch.unsqueeze(grad_data, 2)), 2)
        # Ignore the first transformed direction; penalise the rest.
        others = loss0[:, 1:]
        loss2 = torch.sqrt(torch.mean(torch.sum(others * others, 1)))

        loss = loss2 + loss6

        print(loss2, loss6)

        with open("loss.txt", "a") as fh:
            np.savetxt(fh, (loss2.detach().numpy(), loss6.detach().numpy(), loss.detach().numpy()))

        return loss


class Inv_Net(nn.Module):
    """Inverse of ``Net`` for the same ``dim``, ``num_layer`` and ``dh``.

    Layers are registered under the same names as in ``Net``
    (``fc<i>_y`` / ``fc<i>_z``), so the forward network's parameters can be
    copied in by name.  ``forward`` then undoes ``Net.forward`` by running
    the layers in reverse order with the update signs flipped:

        z <- z + dh * tanh(fc_z(y)) @ fc_z.weight
        y <- y - dh * tanh(fc_y(z)) @ fc_y.weight
    """

    def __init__(self, dim, num_layer, dh):
        """Create the layers.

        Args:
            dim: total input dimension; must be even.
            num_layer: number of residual layers.
            dh: step size of every residual update.

        Raises:
            ValueError: if ``dim`` is odd.
        """
        super(Inv_Net, self).__init__()
        if dim % 2 != 0:
            raise ValueError('dim must be even, got %r' % (dim,))

        scalar = 2  # hidden width = scalar * (dim // 2)
        # Integer division: ``dim / 2`` is a float on Python 3.
        self.dim = dim // 2
        self.num_layer = num_layer

        for i in range(self.num_layer):
            setattr(self, 'fc' + str(i + 1) + '_y', nn.Linear(self.dim, self.dim * scalar))
            setattr(self, 'fc' + str(i + 1) + '_z', nn.Linear(self.dim, self.dim * scalar))

        self.dh = dh

    def forward(self, x):
        """Map ``x`` of shape (N, 2 * self.dim) back through the layers, last first."""
        y, z = torch.split(x, [self.dim, self.dim], dim=1)

        for i in range(self.num_layer - 1, -1, -1):
            fc_y = getattr(self, 'fc' + str(i + 1) + '_y')
            fc_z = getattr(self, 'fc' + str(i + 1) + '_z')
            z = z + self.dh * torch.matmul(torch.tanh(fc_z(y)), fc_z.weight)
            y = y - self.dh * torch.matmul(torch.tanh(fc_y(z)), fc_y.weight)

        return torch.cat((y, z), 1)



def test_func(x, example):

    size = x.shape
    Ns = size[0]
    dim = size[1]

    if example == 1:
        f = 0.5*(np.sin(2.0* math.pi * np.sum(x,axis=1)) + 1)
        df = npm.repmat(np.expand_dims(0.5 *2.0 * math.pi * np.cos(2.0* math.pi * np.sum(x,axis=1)), axis=1),1,dim)
        
    elif example == 2:
        x1 = np.copy(x)
        x1[:,0] = x1[:,0] - 0.5
        f = np.exp(-1.0 * np.sum(x1**2, axis=1))
        df = -2.0 * x1 * npm.repmat(np.expand_dims(f, axis=1),1,dim)

    elif example == 3:
        f = x[:,0]**3 + x[:,1]**3 + x[:,0] * 0.2 + 0.6 * x[:,1]
        df = npm.repmat(np.expand_dims(f,axis=1),1,dim)
        df[:,0] = 3.0*x[:,0]**2.0 + 0.2
        df[:,1] = 3.0*x[:,1]**2.0 + 0.6    

    elif example == 4:
        f1 = 2.0*np.exp(-1 * np.sum((x-0.0) * (x-0.0), axis=1) * 2.0)
        f2 = 2.0*np.exp(-1 * np.sum((x-1.0) * (x-1.0), axis=1) * 2.0)
        f = f1 + f2
        df = -8.0 * (x-0.0) * npm.repmat(np.expand_dims(f1, axis=1),1,dim) -8.0 * (x-1.0) * npm.repmat(np.expand_dims(f2, axis=1),1,dim)

    elif example == 5:
        cc = 0.1
        ww = 0.0
        f = np.cos(np.sum(x*cc,axis=1))
        df = npm.repmat(np.expand_dims(-1 * cc * np.cos(np.sum(x,axis=1)), axis=1),1,dim)

    elif example == 6:  
        cc = 1.2
        ww = 0.0
        f = np.prod((cc**(-2.0) + (x-ww)**2.0)**(-1.0), axis=1)    
        df = npm.repmat(np.expand_dims(f,axis=1),1,dim) * -1.0 * (cc**-2.0 + (x-ww)**2.0)**-1.0 * 2.0 * (x-ww)

    elif example == 7:
        cc = 0.1
        ww = 0.0
        f = (1.0 + np.sum(x*cc,axis=1))**-(dim+1)
        df = -(dim+1) * cc * npm.repmat(np.expand_dims((1.0 + np.sum(x*cc,axis=1))**-(dim+2),axis=1),1,dim)

    elif example == 8:
        center = 0.0
        f = 0.5 * np.sum((x-center) * (x-center), axis=1)
        df = 1.0 * (x-center) 

    elif example == 9:
        center = 0.0
        f = np.sin(np.sum((x-center) * (x-center), axis=1))
        ff  = np.cos(np.sum((x-center) * (x-center), axis=1))
        df = 2.0 * (x-center) * npm.repmat(np.expand_dims(ff, axis=1),1,dim)

    elif example == 10:
        center = 0.0
        f  = np.sin(np.sum((x-center)**3.0, axis=1))
        ff = np.cos(np.sum((x-center)**3.0, axis=1))
        df = 3.0 * (x-center)**2.0 * npm.repmat(np.expand_dims(ff, axis=1),1,dim)    

    else:
        print('Wrong example number!')

    return (f,df)



def sensitivity(net, inv_net, x, example):
    """Compare finite-difference sensitivities of ``test_func`` in two coordinate systems.

    With ``dx = 0.001`` and ``f1 = f(x)``, for every coordinate ``i``:

    * ``sen_ind_old[i] = mean |f(x + dx e_i) - f1| / dx``  (original space)
    * ``sen_ind_new[i] = mean |f(inv_net(net(x) + dx e_i)) - f1| / dx``
      (perturbation in the transformed space).  This assumes
      ``inv_net(net(x)) ~= x``.

    When ``x`` has more than 1000 rows and exactly 2 columns, the function
    also redraws the module-level ``ax`` axes with ``f1`` against the first
    transformed coordinate.

    Args:
        net: forward network; called on ``torch.from_numpy(x)``.
        inv_net: its inverse.
        x: array of shape (Ns, dim).
        example: test-function index forwarded to ``test_func``.

    Returns:
        (sen_ind_old, sen_ind_new), each of shape (dim,).
    """
    Ns, dim = x.shape
    dx = 0.001

    sen_ind_new = np.zeros(dim)
    sen_ind_old = np.zeros(dim)

    f1, _ = test_func(x, example)

    # Pure evaluation: no autograd graph is needed for these passes.
    with torch.no_grad():
        x1 = net(torch.from_numpy(x))

        for i in range(dim):
            # Perturb coordinate i in the transformed space and map back.
            x2 = torch.clone(x1)
            x2[:, i] = x2[:, i] + dx
            f2, _ = test_func(inv_net(x2).numpy(), example)
            sen_ind_new[i] = np.mean(np.abs(f2 - f1) / dx)

            # Perturb coordinate i in the original space.
            x3 = np.copy(x)
            x3[:, i] = x3[:, i] + dx
            f3, _ = test_func(x3, example)
            sen_ind_old[i] = np.mean(np.abs(f3 - f1) / dx)

    if Ns > 1000 and dim == 2:
        x1_np = x1.numpy()
        # Sort by the first transformed coordinate (one index per point).
        index = np.argsort(x1_np[:, 0])
        ax.cla()
        ax.plot(x1_np[index, 0], f1[index], 'r+')
        plt.draw()
        plt.pause(0.001)

    return sen_ind_old, sen_ind_new



def to_numpy(arr):
    """Return ``arr`` as a NumPy array.

    * NumPy arrays are returned unchanged (same object).
    * PyCUDA ``GPUArray`` objects are copied to the host, if pycuda is importable.
    * torch tensors are detached and moved to the CPU first, so tensors that
      require grad are handled too.
    * Anything else is converted with ``np.asarray``.
    """
    if isinstance(arr, np.ndarray):
        return arr

    try:
        from pycuda import gpuarray
    except ImportError:
        gpuarray = None
    if gpuarray is not None and isinstance(arr, gpuarray.GPUArray):
        return arr.get()

    if isinstance(arr, torch.Tensor):
        # .numpy() refuses tensors that require grad, hence detach().
        return arr.detach().cpu().numpy()

    return np.asarray(arr)


def gridplot(u, Nx=64, Ny=64, displacement=True, color='black', **kwargs):
    """Plot a (possibly displaced) grid given a 2-D vector field.

    Args:
        u: array-like of shape (1, 2, H, W).  ``u[0, 0]`` and ``u[0, 1]`` are
            plotted as the first and second coordinate of every grid node.
        Nx, Ny: target number of grid lines along H and W.  The field is
            downsampled with integer strides; ``None`` means no downsampling.
        displacement: if True, ``u`` is treated as a displacement and the
            identity grid (node indices) is added before plotting.
        color: line colour of the grid.
        **kwargs: forwarded to ``plt.plot`` for the grid lines.

    Raises:
        ValueError: if ``u`` holds more than one field (``u.shape[0] != 1``).

    Draws onto the current pyplot axes, adds a legend for the four corner
    nodes and calls ``plt.show()``.
    """
    from matplotlib import pyplot as plt

    u = to_numpy(u)
    if u.shape[0] != 1:
        raise ValueError("Only send one deformation at a time")

    H, W = u.shape[2], u.shape[3]
    if Nx is None:
        Nx = H
    if Ny is None:
        Ny = W

    # Downsample displacements; the stride is at least 1 so that asking for
    # more lines than the field has does not produce a zero slice step.
    step_x = max(1, H // Nx)
    step_y = max(1, W // Ny)
    h = np.array(u[0, :, ::step_x, ::step_y], dtype=float)

    # Reset to the Nx, Ny actually achieved.
    Nx = h.shape[1]
    Ny = h.shape[2]

    # Adjust displacements for downsampling.
    h[0, ...] /= float(H) / Nx
    h[1, ...] /= float(W) / Ny
    if displacement:  # add identity
        h[0, ...] += np.arange(Nx).reshape((Nx, 1))
        h[1, ...] += np.arange(Ny).reshape((1, Ny))
    # Put back into the original index space.
    h[0, ...] *= float(H) / Nx
    h[1, ...] *= float(W) / Ny

    # Grid lines along both index directions.
    for i in range(h.shape[1]):
        plt.plot(h[0, i, :], h[1, i, :], color=color, **kwargs)
    for i in range(h.shape[2]):
        plt.plot(h[0, :, i], h[1, :, i], color=color, **kwargs)

    # Mark the four corner nodes.
    for ix, xn in zip([0, -1], ['B', 'T']):
        for iy, yn in zip([0, -1], ['L', 'R']):
            plt.plot(h[0, ix, iy], h[1, ix, iy], 'o', label='({xn},{yn})'.format(xn=xn, yn=yn))

    plt.legend()
    plt.axis('equal')
    plt.show()




# ------------------------------------------------
#                The main routine
# ------------------------------------------------


def build_inv_net(fwd_net, num_par, num_layer, step):
    """Return an ``Inv_Net`` loaded with the current parameters of ``fwd_net``.

    Both networks register their layers under the same names
    (``fc<i>_y`` / ``fc<i>_z``), so their state dicts line up one-to-one.
    """
    inv = Inv_Net(num_par, num_layer, step)
    inv.load_state_dict(fwd_net.state_dict())
    return inv


def check_invertibility(inv, fwd_net, x_torch, tol=1e-5):
    """Print a warning when mean |inv(fwd_net(x)) - x| exceeds ``tol``."""
    with torch.no_grad():
        err = torch.mean(torch.abs(inv(fwd_net(x_torch)) - x_torch))
    if err > tol:
        print('inv_net is wrong!')
        print(err)


def unit_grid(num):
    """Return the num*num nodes of a uniform grid on [0, 1]^2, shape (num**2, 2)."""
    pts = np.linspace(0.0, 1.0, num=num)
    gx, gy = np.meshgrid(pts, pts)
    return np.column_stack((gx.ravel(), gy.ravel()))


# Interactive figure with three panels:
#   ax  - redrawn by `sensitivity`
#   ax1 - redrawn by `Net.customized_loss_diff` and the training loop
#   ax3 - image of a regular grid under the network (via `gridplot`)
fig = plt.figure(1, figsize=(12, 3.7))
ax = fig.add_subplot(1, 3, 1)
ax1 = fig.add_subplot(1, 3, 2)
# Created once here: newer Matplotlib versions return a *new* Axes on every
# add_subplot call, so creating it inside the loop would pile up axes.
ax3 = fig.add_subplot(1, 3, 3)
plt.ion()
plt.show()


# Define the overall data type: tensors and parameters created below are float64.
torch.set_default_dtype(torch.float64)

# The index of the test function
example = 2

# The dimension of input parameter space
Npar = 2

# The dimension of the output space
Nout = 1

# The number of layers
Nlayer = 7

# The number of training samples (replaced by the grid size when Npar == 2)
NTrain = 121

# The number of validation samples
NValid = 5000

# The weight describing anisotropy
# www = torch.ones(1,Npar)

# Step size of the ResNet
dh = 0.25

# Learning rate
learning_rate = 0.01

# 8 x 8 grid whose image under the network is drawn in the third panel
if Npar == 2:
    testx = unit_grid(8)

# Generate the training set
if Npar == 2:
    xTrain = unit_grid(11)
else:
    xTrain = np.random.rand(NTrain, Npar)
NTrain = xTrain.shape[0]

fTrain, dfTrain = test_func(xTrain, example)
xTrain_torch = torch.as_tensor(xTrain, dtype=torch.double)
fTrain_torch = torch.as_tensor(fTrain, dtype=torch.double)
dfTrain_torch = torch.as_tensor(dfTrain, dtype=torch.double)

# Generate the validation set
xValid = np.random.rand(NValid, Npar)
fValid, dfValid = test_func(xValid, example)
xValid_torch = torch.as_tensor(xValid, dtype=torch.double)
fValid_torch = torch.as_tensor(fValid, dtype=torch.double)
dfValid_torch = torch.as_tensor(dfValid, dtype=torch.double)

# np.savetxt opens and closes each file itself.
np.savetxt("train_x.txt", xTrain)
np.savetxt("train_f.txt", fTrain)
np.savetxt("valid_x.txt", xValid)
np.savetxt("valid_f.txt", fValid)


# Build the forward network
net = Net(Npar, Nlayer, dh, NTrain)

# Initialize the gradient
net.zero_grad()

# Choose the optimizer
optimizer = optim.SGD(net.parameters(), lr=learning_rate)
# optimizer = optim.Adadelta(net.parameters(), lr=learning_rate, weight_decay = 0.01)

# Stopping indicator (currently unused)
counter = 0

# The value of the loss function from last iteration (currently unused)
pre_sen = 100000.0

# Training
for i in range(0, 20000):

    optimizer.zero_grad()

    output = net(xTrain_torch)

    loss = net.customized_loss_diff(xTrain_torch, output, dfTrain_torch, i)

    if i % 100 == 0 and Npar == 2:
        # Overlay the target gradient field (grey) on the ax1 panel.
        ax1.quiver(xTrain[:, 0], xTrain[:, 1], dfTrain[:, 0], dfTrain[:, 1], color='0.6')
        plt.draw()
        plt.pause(0.001)

    if i % 10 == 0:

        # Build the inverse network based on the trained forward network
        inv_net = build_inv_net(net, Npar, Nlayer, dh)
        check_invertibility(inv_net, net, xValid_torch)

        loss_np = loss.detach().numpy()

        sen_ind_old, sen_ind_new = sensitivity(net, inv_net, xValid, example)
        sen_ind1 = np.sum(sen_ind_new)
        sen_ind2 = np.sum(sen_ind_old)
        sen_ind = sen_ind1 / sen_ind2
        print('--------------------------')
        print('Validation data')
        print(sen_ind_new)
        print(i, loss_np, sen_ind, sen_ind1, sen_ind2)
        np.savetxt("sen_valid_new.txt", sen_ind_new)
        np.savetxt("sen_valid_old.txt", sen_ind_old)

        sen_ind_old, sen_ind_new = sensitivity(net, inv_net, xTrain, example)
        sen_ind1 = np.sum(sen_ind_new)
        sen_ind2 = np.sum(sen_ind_old)
        sen_ind = sen_ind1 / sen_ind2
        print('--------------------------')
        print('Training data')
        print(sen_ind_new)
        print(i, loss_np, sen_ind)
        np.savetxt("sen_train_new.txt", sen_ind_new)
        np.savetxt("sen_train_old.txt", sen_ind_old)

        # store the data for 2D plot
        if Npar == 2:
            with torch.no_grad():
                testo = net(torch.from_numpy(testx)).numpy()
            # (64, 2) -> (1, 2, 8, 8): channel c holds transformed coordinate c
            testo3 = testo.T.reshape(1, 2, 8, 8)

            plt.sca(ax3)
            ax3.cla()
            gridplot(testo3, Nx=8, Ny=8, displacement=False)

        if loss < 0.0001:
            break

    loss.backward()

    optimizer.step()

    if i % 10 == 0:
        # Transformed coordinates are only needed when they are saved.
        with torch.no_grad():
            output = net(xTrain_torch)
            output_valid = net(xValid_torch)
        # torch.save(net, './Pnet.pth')
        np.savetxt("train_trans_x.txt", output.numpy())
        np.savetxt("valid_trans_x.txt", output_valid.numpy())


# Build the inverse network based on the trained forward network
inv_net = build_inv_net(net, Npar, Nlayer, dh)

# Test the invertibility of the inv_net
check_invertibility(inv_net, net, xTrain_torch)

# Compare the sensitivities
sen_ind_old, sen_ind_new = sensitivity(net, inv_net, xTrain, example)
print(sen_ind_old)
print(sen_ind_new)

sen_ind_old, sen_ind_new = sensitivity(net, inv_net, xValid, example)
print(sen_ind_old)
print(sen_ind_new)

ffTrain = net(xTrain_torch)



















