import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from numpy.random import random_sample
import matplotlib.pyplot as plt
from scipy.optimize import minimize
import pickle as pkl
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader
import datetime


class Tractrix2DSampler():
    """Sample points on a (standardized) tractrix and measure arc length along it.

    The curve is parameterised by ``t`` as::

        x(t) = (t - tanh(t) - x_mean) / x_std
        y(t) = (sech(t)     - y_mean) / y_std
    """

    def __init__(self, x_mean=0, y_mean=0, x_std=1, y_std=1):
        # Offsets and scales applied to the raw tractrix coordinates.
        self.x_mean = x_mean
        self.y_mean = y_mean
        self.x_std = x_std
        self.y_std = y_std

    def get_from_t(self, t):
        """Return the standardized ``(x, y)`` point(s) for parameter ``t``.

        ``t`` may be a scalar or a NumPy array (evaluated element-wise).
        """
        return ((t - np.tanh(t) - self.x_mean) / self.x_std,
                (1 / np.cosh(t) - self.y_mean) / self.y_std)

    def get_n_samples(self, n=1000, sample_range=(-2, 2)):
        """Draw ``n`` parameters uniformly from ``sample_range`` and map them onto the curve.

        Returns:
            xs: list of x coordinates, in draw order.
            ys: list of y coordinates, in draw order.
            t_samples: NumPy array of the drawn parameters, in draw order.
            all_samples: list of ``[t, x, y]`` lists sorted by ``t``.
        """
        low, high = sample_range[0], sample_range[1]
        t_samples = (high - low) * random_sample(size=n) + low
        xs_arr, ys_arr = self.get_from_t(t_samples)
        xs = list(xs_arr)
        ys = list(ys_arr)
        all_samples = [[t, x, y] for t, x, y in zip(t_samples, xs, ys)]
        all_samples.sort()
        return xs, ys, t_samples, all_samples

    def get_distance(self, t1, t2, dx=0.005):
        """Return the arc length of the standardized curve between parameters ``t1`` and ``t2``.

        The curve's speed is
        ``sqrt((tanh(t)**2 / x_std)**2 + (tanh(t) * sech(t) / y_std)**2)``,
        integrated with the trapezoidal rule on a uniform grid whose spacing
        does not exceed ``dx``. For the unstandardized curve (``x_std = y_std = 1``)
        the speed is ``|tanh(t)|``, so the distance from 0 to ``t`` is ``ln cosh(t)``.

        Raises:
            ValueError: if ``dx`` is not positive.
        """
        if dx <= 0:
            raise ValueError("dx must be positive")
        t_min = float(np.minimum(t1, t2))
        t_max = float(np.maximum(t1, t2))
        if t_max == t_min:
            return 0.0
        # Enough grid points that the spacing never exceeds dx.
        num_points = max(int(np.ceil((t_max - t_min) / dx)) + 1, 2)
        ts = np.linspace(t_min, t_max, num_points)
        tanh_t = np.tanh(ts)
        # d/dt (t - tanh t) = tanh(t)**2 and d/dt sech(t) = -tanh(t) * sech(t)
        # (note: not (1 - tanh t)**2).
        x_rate = tanh_t ** 2 / self.x_std
        y_rate = tanh_t / np.cosh(ts) / self.y_std
        speed = np.sqrt(x_rate ** 2 + y_rate ** 2)
        return float(np.sum(0.5 * (speed[1:] + speed[:-1]) * np.diff(ts)))

    def get_interval_length(self, t1, t2):
        """Return the arc length between ``t1`` and ``t2`` (alias of ``get_distance``)."""
        return self.get_distance(t1, t2)

class Net(nn.Module):
    """Small fully connected regressor mapping a 2-D point to a scalar.

    Architecture: Linear(2, 10) -> ReLU -> Linear(10, 16) -> ReLU -> Linear(16, 1).
    ``sample_mean`` / ``sample_std`` are the per-coordinate (x, y) statistics used
    to standardize tractrix points in ``get_layerwise_activation``.
    """

    def __init__(self, sample_mean, sample_std):
        super(Net, self).__init__()
        # Output width of each entry of ``self.layers``, keyed by layer position.
        self.layer_neurons = {0: 10, 1: 16, 2: 1}
        self.sample_mean = sample_mean
        self.sample_std = sample_std
        self.fc1 = nn.Linear(2, 10, bias=True)
        self.fc2 = nn.Linear(10, 16, bias=True)
        self.fcout = nn.Linear(16, 1, bias=True)
        # A plain list (not nn.ModuleList): the layers are already registered as
        # attributes above, so this keeps the state_dict keys unchanged.
        self.layers = [self.fc1, self.fc2, self.fcout]

    def standardize_data(self, val_arr):
        """Return ``(val_arr - sample_mean) / sample_std`` (element-wise, broadcasting)."""
        return (val_arr - self.sample_mean) / self.sample_std

    def forward(self, x):
        """Map inputs of shape ``(..., 2)`` to outputs of shape ``(..., 1)``."""
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fcout(x)
        return x

    def get_layerwise_activation(self, layer_num, x):
        """Return ``|pre-activation|`` of layer ``layer_num`` for curve parameter(s) ``x``.

        ``x`` is a 1-D tensor of tractrix parameters. In this file it always has
        one element, which gives a single 2-D input point. It is mapped to the
        standardized point ``((x - tanh x - mean_x) / std_x, (sech x - mean_y) / std_y)``.
        That point is passed through layers ``0 .. layer_num - 1`` with ReLU.
        The absolute value of layer ``layer_num``'s affine output is returned, so
        zeros of the result are the neurons' ReLU switching points.
        """
        point = torch.cat([
            (x - torch.tanh(x) - self.sample_mean[0]) / self.sample_std[0],
            (1 / torch.cosh(x) - self.sample_mean[1]) / self.sample_std[1],
        ])
        # Match the layers' dtype (previously forced to float32 via a preallocated
        # torch.Tensor(2) passed as ``out=``).
        h = point.to(self.fc1.weight.dtype)
        for i in range(layer_num + 1):
            layer = self.layers[i]
            h = torch.abs(layer(h)) if i == layer_num else F.relu(layer(h))
        return h

    def num_flat_features(self, x):
        """Return the number of elements per sample (product of all non-batch dims)."""
        size = x.size()[1:]  # all dimensions except the batch dimension
        num_features = 1
        for s in size:
            num_features *= s
        return num_features

def plot_tractrix():
    """Scatter-plot 2000 random points of the unstandardized tractrix (t in [-4, 4]) and show it."""
    tractrix_sampler = Tractrix2DSampler()
    # get_n_samples returns (xs, ys, ts, sorted_samples); only the coordinates are plotted.
    xs, ys, _, _ = tractrix_sampler.get_n_samples(n=2000, sample_range=(-4, 4))
    plt.scatter(xs, ys)
    plt.show()

def duplicate_exists(list_to_check, value_to_check, boundary_thresh):
    """Tell whether ``value_to_check`` is already represented in ``list_to_check``.

    Returns True as soon as some element differs from ``value_to_check`` by
    strictly less than ``boundary_thresh`` in absolute value, otherwise False.
    """
    return any(np.abs(candidate - value_to_check) < boundary_thresh
               for candidate in list_to_check)

def safe_add_linear_boundary(linear_boundaries, layer_number, neuron_number, breakpoint_x):
    """Store ``breakpoint_x`` in ``linear_boundaries[layer_number][neuron_number]``.

    Missing layer/neuron entries are created on demand. If the neuron already
    has breakpoints, the new value is appended only when no stored value lies
    within 0.01 of it. Mutates ``linear_boundaries`` in place; returns None.
    """
    layer_entry = linear_boundaries.setdefault(layer_number, {})
    if neuron_number in layer_entry:
        known = layer_entry[neuron_number]
        if not duplicate_exists(known, breakpoint_x, 0.01):
            known.append(breakpoint_x)
    else:
        layer_entry[neuron_number] = [breakpoint_x]

def count_linear_regions(model_path, net=None, sample_range = (-5, 5)):
    """Locate ReLU breakpoints of ``net`` along the tractrix parameter ``t``.

    For every neuron of every layer in ``net.layers`` except the last, the value
    ``net.get_layerwise_activation(layer, t)[neuron]`` is minimised over ``t``
    with SLSQP. That value is the absolute pre-activation. The search runs from
    9 starting points on the positive side and 9 mirrored starting points on
    the negative side. A minimum ``<= zero_threshold`` is recorded as a
    breakpoint for that neuron. Per-neuron values within 0.01 of an existing one
    are dropped by ``safe_add_linear_boundary``. All breakpoints are then merged
    (values within 1e-4 of an already-kept value are skipped) and sorted.

    Args:
        model_path: path passed to ``torch.load`` when ``net`` is not given.
        net: an already-built model exposing ``layers``, ``layer_neurons`` and
            ``get_layerwise_activation``; if falsy it is loaded from ``model_path``.
        sample_range: ``(low, high)``. The positive search is bounded to
            ``(0, high)`` and the negative one to ``(low, 0)``, so this presumably
            assumes ``low <= 0 <= high`` — TODO confirm.

    Returns:
        ``(count, boundary_values)``, where ``boundary_values`` is the sorted list
        of distinct breakpoints and ``count`` is its length.
        NOTE(review): ``train_model`` stores ``count`` as a number of regions,
        but it is the number of boundaries found.
    """
    if not net:
        net = torch.load(model_path)

    # A minimised |pre-activation| at or below this value counts as a zero crossing.
    zero_threshold = 0.001
    # layer_number -> neuron_number -> list of breakpoints (filled by safe_add_linear_boundary).
    linear_boundaries = {}

    # range(len(net.layers) - 1) skips the final layer.
    for layer_number in range(len(net.layers) - 1):
        for neuron_number in range(net.layer_neurons[layer_number]):
            neuron_linear_regions = []  # NOTE(review): unused
            #print(layer_number, neuron_number, net.layer_neurons[layer_number])
            #print(":::::::::::::::::::::::::::::::::::::::::::::::::::::::::")
            num_init_points = 10
            # Starting points sample_range[1] * i / 10 for i = 1..9.
            # NOTE(review): the negative-side start is the mirror of this point,
            # not a fraction of sample_range[0], so asymmetric ranges are not
            # covered uniformly.
            for i in range(1, num_init_points):

                start_point = sample_range[1] * i/num_init_points

                # Objective: |pre-activation| of this neuron at parameter t.
                # It closes over layer_number / neuron_number. That is safe here
                # because the closure is only used within the current iteration.
                def fun_to_optimize(t):
                    x = torch.from_numpy(np.array(t)).float()
                    output = net.get_layerwise_activation(layer_number, x).detach().numpy()
                    return output[neuron_number]

                # Search the positive half [0, sample_range[1]].
                t_breakpoint = minimize(fun_to_optimize, [start_point], method='SLSQP', tol=1e-6, bounds=[(0, sample_range[1])],
                                        options={'eps': 0.000005, 'maxiter': 1000, 'disp': False})

                if t_breakpoint.fun <= zero_threshold:
                    # .x is the optimizer's solution array (same shape as x0, i.e. 1 element).
                    safe_add_linear_boundary(linear_boundaries, layer_number, neuron_number, t_breakpoint.x)
                # Search the negative half [sample_range[0], 0] from the mirrored start.
                t_breakpoint_2 = minimize(fun_to_optimize, [-1 * start_point], method='SLSQP', tol=1e-6,
                                          bounds=[(sample_range[0], 0)],
                                          options={'eps': 0.000005, 'maxiter': 1000, 'disp': False})

                if t_breakpoint_2.fun <= zero_threshold:
                    safe_add_linear_boundary(linear_boundaries, layer_number, neuron_number, t_breakpoint_2.x)
    # Flatten all neurons' breakpoints into one de-duplicated list.
    boundary_values = []
    for layer_number, layer_boundaries in linear_boundaries.items():
        for neuron_boundaries in layer_boundaries.values():

            for neuron_boundary in neuron_boundaries:
                # Each stored breakpoint is a solution array; iterate it to get the scalar(s).
                for boundary_value in neuron_boundary:
                    if not duplicate_exists(boundary_values, boundary_value, 0.0001):
                        boundary_values.append(boundary_value)

    boundary_values.sort()
    #print(boundary_values)
    #print("Number of linear boundaries: %d "% (len(boundary_values)))
    return len(boundary_values), boundary_values


def create_new_network(x_mean, y_mean, x_std, y_std):
    """Build an untrained ``Net`` using the given tractrix standardization statistics."""
    means = np.array([x_mean, y_mean])
    stds = np.array([x_std, y_std])
    return Net(means, stds)


def generate_training_data(tractrix_data_sampler, data_path, num_samples=200, periodic_freq=1.0, noise_scale = 0.25,
                           sample_range=(-5, 5)):
    """Sample noisy regression targets on the tractrix and pickle them to ``data_path``.

    Points come from ``tractrix_data_sampler.get_n_samples``. Each target is
    ``sin(t * pi / periodic_freq) * (1 + eps)`` with ``eps ~ N(0, noise_scale**2)``,
    i.e. multiplicative noise. Note that ``sin(t * pi / p)`` has period ``2 * p`` in ``t``.

    The pickled object is a ``(3, num_samples)`` NumPy array. Its rows are the
    x coordinates, the y coordinates and the targets, in draw (unsorted) order.
    """
    # get_n_samples returns (xs, ys, ts, sorted_samples); the sorted copy is not needed.
    xs, ys, ts, _ = tractrix_data_sampler.get_n_samples(num_samples, sample_range=sample_range)
    noise_vals = np.random.normal(scale=noise_scale, size=num_samples)
    f_vals = np.sin(np.asarray(ts) * np.pi / periodic_freq) * (1 + noise_vals)

    data_array = np.vstack((np.asarray(xs), np.asarray(ys), f_vals))
    with open(data_path, 'wb') as data_f_out:
        pkl.dump(data_array, data_f_out)

def count_parameters(model):
    """Return the total number of scalar entries across ``model``'s trainable parameters.

    Parameters with ``requires_grad`` set to False are not counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

def get_avg_distance(tractrix_sampler, linear_boundaries, sample_range = (-5, 5), num_samples = 200):
    """Mean arc-length distance from random curve points to their nearest linear boundary.

    Draws ``num_samples`` parameters with ``tractrix_sampler.get_n_samples``. For
    each one, it takes the smallest ``tractrix_sampler.get_distance`` to any value
    in ``linear_boundaries``.

    Returns:
        The mean of those minimum distances. This is ``inf`` when
        ``linear_boundaries`` is empty, since no sample has a nearest boundary.
    """
    # get_n_samples returns (xs, ys, ts, sorted_samples); only the parameters are used.
    _, _, sample_ts, _ = tractrix_sampler.get_n_samples(num_samples, sample_range=sample_range)
    min_dists = [
        min((tractrix_sampler.get_distance(sample_t, boundary) for boundary in linear_boundaries),
            default=np.inf)
        for sample_t in sample_ts
    ]
    return np.mean(np.array(min_dists))



def train_model(tractrix_sampler, data_path, model_path = '', net = None, num_epochs = 60, sample_range = (-5, 5), run_number=0):
    """Train ``net`` on pickled tractrix data and track its linear boundaries during training.

    ``data_path`` holds the ``(3, N)`` array written by ``generate_training_data``
    (x row, y row, target row). The first 80% of the columns are used for
    training (Adam, lr 1e-3, batch size 8, MSE loss). The remainder is used for
    evaluation.

    Every 5 epochs (starting at epoch 0), ReLU breakpoints along the curve are
    recomputed with ``count_linear_regions``. The function then records the mean
    distance from random curve samples to the nearest breakpoint, divided by
    the arc length of ``sample_range``. The series
    ``[epochs, num_boundaries, normalised_avg_distances, summed_test_losses]``
    is pickled to ``data/tractrix_<run_number>.pkl``.

    Args:
        tractrix_sampler: sampler used for distances and the normalising interval length.
        data_path: pickle file produced by ``generate_training_data``.
        model_path: passed to ``torch.load`` when ``net`` is not given.
        net: model trained in place; loaded from ``model_path`` if falsy.
        num_epochs: number of passes over the training split.
        sample_range: parameter interval searched for breakpoints.
        run_number: used in log lines and the output file name.
    """
    if not net:
        net = torch.load(model_path)

    with open(data_path, 'rb') as f_in:
        all_data = pkl.load(f_in)
    fun_vals = np.reshape(all_data[2], (all_data[2].shape[0], 1))
    stacked_input = np.column_stack((all_data[0], all_data[1]))
    frac_train = int(0.8*stacked_input.shape[0])
    frac_test = stacked_input.shape[0] - frac_train
    print(frac_train, frac_test)
    # Slice from frac_train: `[-frac_test:]` would select the whole array when
    # frac_test == 0 and leak training rows into the test split.
    train_inputs = stacked_input[:frac_train]
    test_inputs = stacked_input[frac_train:]
    print(train_inputs.shape, test_inputs.shape)
    train_outputs = fun_vals[:frac_train]
    test_outputs = fun_vals[frac_train:]
    criterion = nn.MSELoss()
    optimizer = optim.Adam(net.parameters(), lr=0.001)
    trainset = torch.utils.data.TensorDataset(torch.Tensor(train_inputs), torch.Tensor(train_outputs))
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=8, shuffle=True)
    testset = torch.utils.data.TensorDataset(torch.Tensor(test_inputs), torch.Tensor(test_outputs))
    testloader = torch.utils.data.DataLoader(testset, batch_size=8,
                                             shuffle=False)
    num_linear_regions = []
    linear_region_epochs = []
    average_distances = []
    total_losses = []

    # Arc length of the whole sampling interval, used to normalise distances.
    max_distance = tractrix_sampler.get_interval_length(sample_range[0], sample_range[1])
    for epoch in range(num_epochs):
        epoch_loss = 0.0
        num_batches = 0

        for inputs, real_fun_vals in trainloader:
            optimizer.zero_grad()

            outputs = net(inputs)
            loss = criterion(outputs, real_fun_vals)
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1
        # Mean over every batch; max(..., 1) guards an empty loader.
        print('[%d] train loss: %.3f' %
              (epoch + 1, epoch_loss / max(num_batches, 1)))

        with torch.no_grad():
            total_loss = 0.0
            num_passes = 0
            for inputs, real_fun_vals in testloader:
                outputs = net(inputs)
                loss = criterion(outputs, real_fun_vals)
                total_loss += loss.item()
                num_passes += 1
            print('[%d, %d] test loss: %.3f' % (run_number, epoch + 1, total_loss / max(num_passes, 1)))

        if epoch % 5 == 0:
            print("start count")
            print(datetime.datetime.now())
            num_regions, linear_boundaries = count_linear_regions('', net, sample_range=sample_range)
            average_distance = get_avg_distance(tractrix_sampler, linear_boundaries, sample_range=sample_range)
            average_distances.append(average_distance / max_distance)
            num_linear_regions.append(num_regions)
            # Sum of per-batch test losses (not averaged), as recorded before.
            total_losses.append(total_loss)
            linear_region_epochs.append(epoch)
            print(datetime.datetime.now())

    data_arr = [linear_region_epochs, num_linear_regions, average_distances, total_losses]
    with open("data/tractrix_" + str(run_number) + ".pkl", "wb") as f_out:
        pkl.dump(data_arr, f_out)


def get_mean_and_variance(sample_range, num_samples=3000):
    """Estimate the per-coordinate mean and standard deviation of the raw tractrix.

    Draws ``num_samples`` points from an unstandardized ``Tractrix2DSampler``,
    with ``t`` uniform on ``sample_range``.

    Returns:
        ``(x_mean, y_mean, x_std, y_std)``. Despite the function name, these are
        standard deviations (``np.std``, population form with ``ddof=0``), not variances.
    """
    tractrix_sampler = Tractrix2DSampler()
    # get_n_samples returns (xs, ys, ts, sorted_samples); only xs and ys are needed.
    xs, ys, _, _ = tractrix_sampler.get_n_samples(num_samples, sample_range=sample_range)

    return np.mean(xs), np.mean(ys), np.std(xs), np.std(ys)

from scipy.interpolate import make_interp_spline, BSpline


if __name__ == '__main__':
    # Full training experiment (currently disabled). To run it, uncomment:
    #
    # sample_range = (-5, 5)
    # x_mean, y_mean, x_std, y_std = get_mean_and_variance(sample_range)
    # np.random.seed(12)
    # tractrix_sampler = Tractrix2DSampler(x_mean, y_mean, x_std, y_std)
    #
    # num_runs = 20
    # for i in range(num_runs):
    #     generate_training_data(tractrix_sampler, 'data/training_data.pkl', num_samples=1000, periodic_freq=1.5,
    #                            sample_range=sample_range)
    #     net = create_new_network(x_mean, y_mean, x_std, y_std)
    #     train_model(tractrix_sampler, 'data/training_data.pkl', model_path='data/2d_model.torch', net=net,
    #                 num_epochs=300, sample_range=sample_range, run_number=i)

    # Draw a smooth tractrix through 1000 random samples with t in [-7, 7].
    # The samples are sorted by t, and x(t) is increasing in t, so the x values
    # are ordered as make_interp_spline requires.
    sampler = Tractrix2DSampler()
    *_, sorted_samples = sampler.get_n_samples(n=1000, sample_range=(-7, 7))
    sample_matrix = np.array(sorted_samples)
    curve_x, curve_y = sample_matrix[:, 1], sample_matrix[:, 2]
    print(curve_x.shape)
    spline = make_interp_spline(curve_x, curve_y)
    dense_x = np.linspace(curve_x.min(), curve_x.max(), 300)
    plt.plot(dense_x, spline(dense_x))
    plt.grid()
    plt.title('Tractrix')
    plt.savefig("tractrix.png")

