import numpy

import matplotlib.pyplot as plt
from scipy.integrate import odeint
from torchdiffeq import odeint as odeint_torch
import torch

def plot_2d_vector_field(func, x_min, x_mx, y_min, y_max, normalize=True, scale=None, model_d=2, points=20):
    """
    Plot the vector field of ``func`` on a regular grid over the given range.

    Only the first two dimensions are drawn. When model_d > 2 the remaining
    coordinates of every grid point are set to 0 before calling ``func``.
    :params
        func: callable ``func(x, t)``; called with a length-``model_d`` numpy
            array and t=0. Components [0] and [1] of its result are plotted.
        x_min: float, the minimum value of x
        x_mx: float, the maximum value of x
        y_min: float, the minimum value of y
        y_max: float, the maximum value of y
        normalize: bool, if True every arrow is rescaled to unit length;
            zero vectors (fixed points) are left as zero instead of NaN
        scale: if not None, passed to ``ax.quiver`` together with units='width'
        model_d: number of dimensions of the model data
        points: number of grid points along each axis
    :return:
        the matplotlib Axes the field was drawn on
    """
    _, ax = plt.subplots(figsize=(8, 8))
    x = numpy.linspace(x_min, x_mx, points)
    y = numpy.linspace(y_min, y_max, points)
    X, Y = numpy.meshgrid(x, y)
    U = numpy.zeros((points, points))
    V = numpy.zeros((points, points))
    # Extra model dimensions are pinned to zero for every grid point.
    padding = [0] * (model_d - 2)
    for i in range(points):
        for j in range(points):
            data_point = numpy.array([X[i, j], Y[i, j]] + padding)
            element = func(data_point, 0)
            U[i, j] = element[0]
            V[i, j] = element[1]
    if normalize:
        magnitude = numpy.sqrt(U ** 2 + V ** 2)
        # A zero vector would give 0/0 = NaN (and a RuntimeWarning); dividing
        # it by 1 instead keeps it a zero-length arrow.
        magnitude[magnitude == 0] = 1.0
        U = U / magnitude
        V = V / magnitude
    if scale is not None:
        ax.quiver(X, Y, U, V, units='width', scale=scale)
    else:
        ax.quiver(X, Y, U, V)
    # draw the x-axis and y-axis, thin
    ax.axhline(y=0, color='k', linestyle='--', linewidth=0.5)
    ax.axvline(x=0, color='k', linestyle='--', linewidth=0.5)
    ax.set_xlim([x_min, x_mx])
    ax.set_ylim([y_min, y_max])
    return ax

    

def solve_ivp(f, x_0, t_start, t_end, t_step=1):
    """
    Integrate dx/dt = f(x, t) from x_0 with scipy's odeint and sample the flow.

    params:
        f: function f(x, t) returning dx/dt (scipy odeint calling convention)
        x_0: numpy array, the initial point
        t_start: float, the first sample time
        t_end: float, end of the time range (exclusive, as in numpy.arange)
        t_step: float, spacing between sample times
    return:
        X: numpy array of shape (len(t), len(x_0)), the state at each sample time
        x_dot: numpy array with the same shape as X, f evaluated at (X[i], t[i])
    """
    t = numpy.arange(t_start, t_end, t_step)
    X = odeint(f, x_0, t)
    x_dot = numpy.zeros(X.shape)
    # Evaluate the vector field at each sample's own time (not t=0) so the
    # derivative is also correct for non-autonomous f.
    for i, (x_i, t_i) in enumerate(zip(X, t)):
        x_dot[i] = f(x_i, t_i)
    return X, x_dot

def model_to_f(model):
    """
    Wrap a torch model as a function with the ``f(x, t)`` signature used by
    scipy's odeint and the helpers in this file.

    The time argument ``t`` is ignored. ``x`` is converted with ``torch.Tensor``
    (torch's default float dtype). The model is evaluated under
    ``torch.no_grad()``, so the returned tensor carries no autograd graph and
    can be converted to a numpy array by callers such as odeint.
    """
    def f(x, t):
        # Without no_grad, a model with trainable parameters returns a tensor
        # that requires grad, and converting it to numpy raises RuntimeError.
        with torch.no_grad():
            return model(torch.Tensor(x))
    return f




def forward_step(f, x0, t_steps, dt):
    """
    Roll out a discrete-time map for a fixed number of steps.

    Starting from x0, repeatedly applies x(t + dt) = f(x(t), dt) and stores
    every state, including the initial one. The integration scheme (e.g.
    forward Euler) is whatever ``f`` implements.
    :params
        f: discrete time function f(x(t), dt) = x(t+dt) where x is (b, d)
        x0: array of shape (b, d), the initial point
        t_steps: int, total number of stored states (x0 is the first)
        dt: float, the time step passed to f
    :return:
        X: torch tensor (default float dtype) of shape (t_steps, b, d)
    """
    batch_size, dim = x0.shape[0], x0.shape[1]
    trajectory = torch.zeros((t_steps, batch_size, dim))
    trajectory[0] = x0
    for step in range(1, t_steps):
        previous_state = trajectory[step - 1]
        trajectory[step] = f(previous_state, dt)
    return trajectory
    
def forward_step_until_converge(f, x0, dt, max_steps=1000, tol=1e-5):
    """
    Roll out a discrete-time map until consecutive states stop changing.

    Starting from x0, repeatedly applies x(t + dt) = f(x(t), dt). It stops
    after the first step whose change (Frobenius norm over the whole batch)
    is below ``tol``, or once ``max_steps`` states have been stored.
    :params
        f: discrete time function f(x(t), dt) = x(t+dt) where x is (b, d)
        x0: array of shape (b, d), the initial point
        dt: float, the time step passed to f
        max_steps: int, maximum number of stored states (x0 counts as one)
        tol: float, convergence threshold on ||x(t+dt) - x(t)||
    :return:
        X: torch tensor of shape (k, b, d) with 1 <= k <= max_steps, the
            states actually computed (including x0)
    """
    X = torch.zeros((max_steps, x0.shape[0], x0.shape[1]))
    X[0] = x0
    # Index of the last filled row. Initialised before the loop so that
    # max_steps == 1 (loop body never runs) returns just x0 instead of
    # raising UnboundLocalError.
    last = 0
    for i in range(1, max_steps):
        X[i] = f(X[i - 1], dt)
        last = i
        if torch.norm(X[i] - X[i - 1]) < tol:
            break
    return X[:last + 1]


def DTWD(trajectory, reference):
    """
    Symmetric mean nearest-neighbour distance between a trajectory and a reference.

    Despite the name, this is not dynamic time warping: temporal order is
    ignored. For each of the N reference points, the Euclidean distance to
    its nearest trajectory point is averaged over N. For each of the n
    trajectory points, the distance to its nearest reference point is
    averaged over n. The result is the sum of the two means (a
    Chamfer-style distance).
    :params
        trajectory: torch tensor [n, d], the trajectory
        reference: list of N [d] tensors (or array-likes), or a tensor [N, d]
    :return:
        float, the distance
    """
    if isinstance(reference, torch.Tensor):
        ref = reference
    else:
        ref = torch.stack([torch.as_tensor(r) for r in reference])
    dtype = torch.promote_types(trajectory.dtype, ref.dtype)
    ref = ref.to(device=trajectory.device, dtype=dtype)
    # (n, N) matrix of pairwise Euclidean distances. The non-mm mode computes
    # ||x - y|| directly (like torch.norm(x - y)) instead of the faster but
    # less accurate matrix-multiplication expansion.
    dists = torch.cdist(trajectory.to(dtype), ref, p=2.0,
                        compute_mode="donot_use_mm_for_euclid_dist")
    n, N = dists.shape
    min_x = torch.min(dists, dim=0).values / N  # nearest trajectory point per reference point
    min_y = torch.min(dists, dim=1).values / n  # nearest reference point per trajectory point
    return (min_x.sum() + min_y.sum()).item()

    