import pandas as pd
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import matplotlib.animation as animation
from mpl_toolkits import mplot3d
from scipy.integrate import odeint, solve_ivp
from scipy.stats import entropy
from scipy.linalg import sinm, cosm, logm
from scipy.linalg import expm
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from pathlib import Path
from sympy.physics.quantum import TensorProduct
from tqdm import tqdm
import cmath
from helpers import *

###### Multiplicative Weights Update ######

def updateMP(current, adv, eps, payoff1, payoff2):
    '''
    Matching Pennies update rule for classical multiplicative weights.
    Takes in current x,y strategy vectors and returns new x,y vectors.

    Every entry of a strategy is multiplied by (1+eps) raised to that entry's
    payoff against the opponent's current strategy, then the vector is
    renormalised to sum to one. Both players are updated from the *current*
    pair. The rule is written in vector form, so it is not restricted to two
    strategies.

    Parameters
    ----------
    current : array_like
        Strategy of the first player, as a column vector or a flat vector.
    adv : array_like
        Strategy of the second player, same layout as `current`.
    eps : float
        Step size of the update.
    payoff1 : array_like
        Matrix such that `payoff1 @ adv` gives the first player's payoffs.
    payoff2 : array_like
        Matrix such that `payoff2 @ current` gives the second player's payoffs.

    Returns
    -------
    tuple
        `(x_new, y_new, check)` where `x_new` and `y_new` are column vectors
        and `check` is a scalar diagnostic built from the payoffs between the
        old and the new strategies.
    '''
    x = np.asarray(current).reshape(-1, 1)
    y = np.asarray(adv).reshape(-1, 1)

    # Float base so that integer payoffs with negative entries are valid exponents.
    base = 1.0 + eps

    # Each matrix-vector product is evaluated once and reused.
    x_factor = base ** (payoff1 @ y)
    y_factor = base ** (payoff2 @ x)

    x_new = x * (x_factor / np.sum(x * x_factor))
    y_new = y * (y_factor / np.sum(y * y_factor))

    check = base*((-base)*x_new.T@payoff2@y - base*x.T@payoff1@y_new)
    return (x_new, y_new, check.item())

def MWUAlgorithmMP(T, payoff1, payoff2, px=np.array([[0.7], [0.3]]), py=np.array([[0.4], [0.6]]), eps=np.exp(0.1/8)-1):
    '''
    Classical MWU algorithm for Matching Pennies. Returns trajectories for T time-steps given initial conditions and payoff matrices.

    Parameters
    ----------
    T : int
        Number of update steps to run.
    payoff1, payoff2 : array_like
        Payoff matrices forwarded unchanged to `updateMP`.
    px, py : array_like
        Initial strategies of the two players (two entries each). Flat
        vectors are accepted and converted to column vectors.
    eps : float
        Constant step size forwarded to `updateMP`.

    Returns
    -------
    dict
        Keys "x_heads", "x_tails", "y_heads", "y_tails" hold the trajectories
        (initial point followed by T updates); "check" holds the per-step
        diagnostic returned by `updateMP`.
    '''
    # reshape returns a new array, so the result must be rebound to take effect.
    px = np.asarray(px).reshape(2, 1)
    py = np.asarray(py).reshape(2, 1)

    check = []
    x_steps = [px]
    y_steps = [py]
    current_x, current_y = px, py
    print('Running MWU for', T, 'iterations')
    for _ in tqdm(range(T)):
        current_x, current_y, step_check = updateMP(current_x, current_y, eps, payoff1, payoff2)
        x_steps.append(current_x)
        y_steps.append(current_y)
        check.append(step_check)

    # Join the columns once at the end rather than re-copying the whole
    # history on every iteration.
    px = np.concatenate(x_steps, axis=1)
    py = np.concatenate(y_steps, axis=1)

    data = {}
    data["x_heads"] = px[0]
    data["x_tails"] = px[1]
    data["y_heads"] = py[0]
    data["y_tails"] = py[1]
    data['check'] = check
    print('Complete!')
    print('Time average x: ', np.average(px, axis=1))
    print('Time average y: ', np.average(py, axis=1))
    return data

def updateRPS(current, adv, eps, payoff1, payoff2):
    '''
    Rock-Paper-Scissors update rule for classical multiplicative weights.
    Takes in current x,y strategy vectors and returns new x,y vectors.

    Every entry of a strategy is multiplied by (1+eps) raised to that entry's
    payoff against the opponent's current strategy, then the vector is
    renormalised to sum to one. Both players are updated from the *current*
    pair. The rule is written in vector form, so it is not restricted to
    three strategies.

    Parameters
    ----------
    current : array_like
        Strategy of the first player, as a column vector or a flat vector.
    adv : array_like
        Strategy of the second player, same layout as `current`.
    eps : float
        Step size of the update.
    payoff1 : array_like
        Matrix such that `payoff1 @ adv` gives the first player's payoffs.
    payoff2 : array_like
        Matrix such that `payoff2 @ current` gives the second player's payoffs.

    Returns
    -------
    tuple
        `(x_new, y_new, check)` where `x_new` and `y_new` are column vectors
        and `check` is a scalar diagnostic built from the payoffs between the
        old and the new strategies.
    '''
    x = np.asarray(current).reshape(-1, 1)
    y = np.asarray(adv).reshape(-1, 1)

    # Float base so that integer payoffs with negative entries are valid exponents.
    base = 1.0 + eps

    # Each matrix-vector product is evaluated once and reused.
    x_factor = base ** (payoff1 @ y)
    y_factor = base ** (payoff2 @ x)

    x_new = x * (x_factor / np.sum(x * x_factor))
    y_new = y * (y_factor / np.sum(y * y_factor))

    check = base*((-base)*x_new.T@payoff2@y - base*x.T@payoff1@y_new)
    return (x_new, y_new, check.item())

def MWUAlgorithmRPS(T, payoff1, payoff2, px=np.array([[0.3], [0.3], [0.4]]), py=np.array([[0.7], [0.2], [0.1]]), eps=np.exp(0.1/8)-1):
    '''
    Classical MWU algorithm for Rock-Paper-Scissors. Returns trajectories for T time-steps given initial conditions and payoff matrices.

    Parameters
    ----------
    T : int
        Number of update steps to run.
    payoff1, payoff2 : array_like
        Payoff matrices forwarded unchanged to `updateRPS`.
    px, py : array_like
        Initial strategies of the two players (three entries each). Flat
        vectors are accepted and converted to column vectors.
    eps : float
        Constant step size forwarded to `updateRPS`.

    Returns
    -------
    dict
        Per-action trajectories under "x_rock", "x_paper", "x_scissors",
        "y_rock", "y_paper", "y_scissors"; the full trajectories with one row
        per time step under "x" and "y"; and the per-step diagnostic returned
        by `updateRPS` under "check".
    '''
    # reshape returns a new array, so the result must be rebound to take effect.
    px = np.asarray(px).reshape(3, 1)
    py = np.asarray(py).reshape(3, 1)

    check = []
    x_steps = [px]
    y_steps = [py]
    current_x, current_y = px, py
    print('Running MWU for', T, 'iterations')
    for _ in tqdm(range(T)):
        current_x, current_y, step_check = updateRPS(current_x, current_y, eps, payoff1, payoff2)
        x_steps.append(current_x)
        y_steps.append(current_y)
        check.append(step_check)

    # Join the columns once at the end rather than re-copying the whole
    # history on every iteration.
    px = np.concatenate(x_steps, axis=1)
    py = np.concatenate(y_steps, axis=1)

    data = {}
    data["x_rock"] = px[0]
    data["x_paper"] = px[1]
    data["x_scissors"] = px[2]
    data["y_rock"] = py[0]
    data["y_paper"] = py[1]
    data["y_scissors"] = py[2]
    data['x'] = px.T
    data['y'] = py.T
    data['check'] = check
    print('Complete!')

    print('Time average x: ', np.average(px, axis=1))
    print('Time average y: ', np.average(py, axis=1))
    return data

def RunParallelAlgo(payoff, eps, n, m, s_a=np.array([0.45,0.55]), s_b=np.array([0.55, 0.45]), extra=1):
    '''
    Run MMWU algorithm as defined in Algorithm 1 of the main paper.
    Returns trajectories of MMWU given initial conditions and step-size.

    Parameters
    ----------
    payoff : ndarray
        Payoff matrix; it and `-payoff.T` are each passed to `GetRMatrix`.
    eps : float
        Accuracy parameter. The step size is `eps/8` and the base iteration
        count is `ceil(64*log(n*m)/eps**2)`.
    n, m : int
        Dimensions of the first and second player's matrices.
    s_a, s_b : ndarray
        Initial conditions. Flat vectors are placed on the diagonal of a
        matrix; 2-D inputs are used as the initial matrices directly. Both
        must have the same number of dimensions.
    extra : int
        Multiplier applied to the base iteration count.

    Returns
    -------
    dict or None
        Lists of iterates under "rho", "sig" (trace-normalised) and "A", "B"
        (unnormalised). None is returned, after printing a message, when the
        lengths of `s_a` / `s_b` do not match `n` / `m`.

    Raises
    ------
    ValueError
        If `s_a` and `s_b` are not both 1-D or both 2-D.
    '''
    s_a = np.asarray(s_a)
    s_b = np.asarray(s_b)

    mu = eps/8
    N = int(np.ceil(64.0*np.log(n*m)/eps**2))
    if len(s_a) != n or len(s_b) != m:
        print('Size mismatch for initial conditions.')
        return

    payoff_b = -payoff.T
    R = GetRMatrix(payoff)
    R_b = GetRMatrix(payoff_b)

    if s_a.ndim == 1 and s_b.ndim == 1:
        A = [np.diag(s_a)]
        B = [np.diag(s_b)]
    elif s_a.ndim == 2 and s_b.ndim == 2:
        A = [s_a]
        B = [s_b]
    else:
        raise ValueError('s_a and s_b must both be 1-D or both be 2-D.')
    rho = [A[0]/np.trace(A[0])]
    sig = [B[0]/np.trace(B[0])]

    print('Initial rho: ', rho[0])
    print('Initial sig: ', sig[0])

    eye_n = np.eye(n)
    eye_m = np.eye(m)
    cumsum_A = np.zeros((n,n), dtype=complex)
    cumsum_B = np.zeros((m,m), dtype=complex)

    for j in tqdm(range(extra*N)):
        # Partial trace over the second (size-m) factor leaves an n x n matrix.
        phi_A_j = np.trace(np.array(R@np.kron(eye_n, sig[j].T)).reshape(n, m, n, m), axis1=1, axis2=3)
        cumsum_A += phi_A_j
        A_j = expm(mu*cumsum_A)@A[0]
        A.append(A_j)
        rho.append(A_j/np.trace(A_j))

        # rho[j] is still the previous iterate, so both players update in parallel.
        # The operand is ordered (m, n) by the kron above, hence the (m, n, m, n)
        # reshape; tracing the size-n factor leaves the m x m matrix that
        # cumsum_B expects. This is the same as (n, m, n, m) whenever n == m.
        phi_B_j = np.trace(np.array(R_b@np.kron(eye_m, rho[j].T)).reshape(m, n, m, n), axis1=1, axis2=3)
        cumsum_B += phi_B_j
        B_j = expm(mu*cumsum_B)@B[0]
        B.append(B_j)
        sig.append(B_j/np.trace(B_j))

    rho_avg = np.average(rho, axis=0)
    sig_avg = np.average(sig, axis=0)

    # N is the base count; the loop above runs extra*N steps.
    print('Number of iterations: ', N)
    print('Equilibrium rho: ', rho_avg)
    print('Equilibrium sig: ', sig_avg)

    data = {}
    data['rho'] = rho
    data['sig'] = sig
    data['A'] = A
    data['B'] = B
    return data

def MWUAlgorithmMPDecrease(T, payoff1, payoff2, px=np.array([[0.7], [0.3]]), py=np.array([[0.4], [0.6]]), exponent = 1/2):
    '''
    Classical MWU for Matching Pennies with decreasing step-size.

    Step i (counting from 0) uses the step size 1/(i+1)**exponent.

    Parameters
    ----------
    T : int
        Number of update steps to run.
    payoff1, payoff2 : array_like
        Payoff matrices forwarded unchanged to `updateMP`.
    px, py : array_like
        Initial strategies of the two players (two entries each). Flat
        vectors are accepted and converted to column vectors.
    exponent : float
        Decay exponent of the step-size schedule.

    Returns
    -------
    dict
        Keys "x_heads", "x_tails", "y_heads", "y_tails" hold the trajectories
        (initial point followed by T updates); "check" holds the per-step
        diagnostic returned by `updateMP`.
    '''
    # reshape returns a new array, so the result must be rebound to take effect.
    px = np.asarray(px).reshape(2, 1)
    py = np.asarray(py).reshape(2, 1)

    check = []
    x_steps = [px]
    y_steps = [py]
    current_x, current_y = px, py
    print('Running MWU for', T, 'iterations')
    for i in tqdm(range(T)):
        # At i == 0 this evaluates to 1, so the first step needs no special case.
        eps = 1/(i+1)**(exponent)
        current_x, current_y, step_check = updateMP(current_x, current_y, eps, payoff1, payoff2)
        x_steps.append(current_x)
        y_steps.append(current_y)
        check.append(step_check)

    # Join the columns once at the end rather than re-copying the whole
    # history on every iteration.
    px = np.concatenate(x_steps, axis=1)
    py = np.concatenate(y_steps, axis=1)

    data = {}
    data["x_heads"] = px[0]
    data["x_tails"] = px[1]
    data["y_heads"] = py[0]
    data["y_tails"] = py[1]
    data['check'] = check
    print('Complete!')
    print('Time average x: ', np.average(px, axis=1))
    print('Time average y: ', np.average(py, axis=1))
    return data

def MWUAlgorithmRPSDecrease(T, payoff1, payoff2, px=np.array([[0.3], [0.3], [0.4]]), py=np.array([[0.3], [0.3], [0.4]]), exponent=1/2):
    '''
    Classical MWU for Rock-Paper-Scissors with decreasing step-size.

    Step i (counting from 0) uses the step size 1/(i+1)**exponent.

    Parameters
    ----------
    T : int
        Number of update steps to run.
    payoff1, payoff2 : array_like
        Payoff matrices forwarded unchanged to `updateRPS`.
    px, py : array_like
        Initial strategies of the two players (three entries each). Flat
        vectors are accepted and converted to column vectors.
    exponent : float
        Decay exponent of the step-size schedule.

    Returns
    -------
    dict
        Per-action trajectories under "x_rock", "x_paper", "x_scissors",
        "y_rock", "y_paper", "y_scissors"; the full trajectories with one row
        per time step under "x" and "y"; and the per-step diagnostic returned
        by `updateRPS` under "check".
    '''
    # reshape returns a new array, so the result must be rebound to take effect.
    px = np.asarray(px).reshape(3, 1)
    py = np.asarray(py).reshape(3, 1)

    check = []
    x_steps = [px]
    y_steps = [py]
    current_x, current_y = px, py
    print('Running MWU for', T, 'iterations')
    for i in tqdm(range(T)):
        # At i == 0 this evaluates to 1, so the first step needs no special case.
        eps = 1/(i+1)**(exponent)
        current_x, current_y, step_check = updateRPS(current_x, current_y, eps, payoff1, payoff2)
        x_steps.append(current_x)
        y_steps.append(current_y)
        check.append(step_check)

    # Join the columns once at the end rather than re-copying the whole
    # history on every iteration.
    px = np.concatenate(x_steps, axis=1)
    py = np.concatenate(y_steps, axis=1)

    data = {}
    data["x_rock"] = px[0]
    data["x_paper"] = px[1]
    data["x_scissors"] = px[2]
    data["y_rock"] = py[0]
    data["y_paper"] = py[1]
    data["y_scissors"] = py[2]
    data['x'] = px.T
    data['y'] = py.T
    data['check'] = check
    print('Complete!')

    print('Time average x: ', np.average(px, axis=1))
    print('Time average y: ', np.average(py, axis=1))
    return data

def RunParallelAlgoEpsDecrease(payoff, n, m, s, extra=1, N=8873, exponent=1/2, alt=False):
    '''
    MMWU algorithm with decreasing step-size. Returns trajectories given initial conditions and payoff matrix.

    Step j (counting from 0) uses the step size log(1 + 1/(j+1)**exponent).

    Parameters
    ----------
    payoff : ndarray
        Either a payoff matrix of shape (n, m), which is converted with
        `GetRMatrix`, or any other shape, which is used as R directly.
    n, m : int
        Dimensions of the first and second player's matrices.
    s : sequence
        Initial conditions, passed through `split_list`; the two returned
        parts are concatenated and read as n*n entries for the first player
        followed by m*m entries for the second.
    extra : int
        Multiplier applied to `N` to get the number of iterations.
    N : int
        Base iteration count.
    exponent : float
        Decay exponent of the step-size schedule.
    alt : bool
        If True the second player reacts to the first player's freshly
        updated iterate (alternating updates) instead of the previous one.

    Returns
    -------
    dict
        "R" (the matrix used), the lists of iterates "rho", "sig", "A", "B",
        and the time averages of rho and sig under "nash 1" and "nash 2".
    '''
    # Automatically get diagonal R matrix if complex R not specified
    if payoff.shape == (n,m):
        R = GetRMatrix(payoff)
    else:
        R = payoff

    # Generate initial matrices A and B from list of initial conditions
    parts = split_list(s)
    s_converted = np.concatenate([parts[0], parts[1]])
    A = [s_converted[0:n**2].reshape(n,n)]
    # The second player's block is m x m (it multiplies expm of an m x m
    # matrix below); this equals the n-based slice whenever n == m.
    B = [s_converted[n**2:n**2 + m**2].reshape(m,m)]
    rho = [A[0]/np.trace(A[0])]
    sig = [B[0]/np.trace(B[0])]

    eye_n = np.eye(n)
    eye_m = np.eye(m)
    cumsum_A = np.zeros((n,n), dtype=complex)
    cumsum_B = np.zeros((m,m), dtype=complex)

    # Update steps for parallel algo
    for j in tqdm(range(extra*N)):
        # At j == 0 this evaluates to log(2), so no special case is needed.
        mu = np.log(1+(1/(j+1)**(exponent)))

        phi_A_j = mu*np.trace(np.array(R@np.kron(eye_n, sig[j].T)).reshape(n, m, n, m), axis1=1, axis2=3)
        cumsum_A += phi_A_j
        A_j = expm(cumsum_A)@A[0]
        A.append(A_j)
        rho.append(A_j/np.trace(A_j))

        # Option for alternating updates instead of parallel: rho[j+1] is the
        # iterate appended just above, rho[j] the one before it.
        rho_seen = rho[j+1] if alt else rho[j]
        phi_B_j = -mu*np.trace(np.array(R@np.kron(rho_seen.T, eye_m)).reshape(n,m,n,m), axis1=0, axis2=2)
        cumsum_B += phi_B_j
        B_j = expm(cumsum_B)@B[0]
        B.append(B_j)
        sig.append(B_j/np.trace(B_j))

    rho_avg = np.average(rho, axis=0)
    sig_avg = np.average(sig, axis=0)

    # N is the base count; the loop above runs extra*N steps.
    print('Number of iterations: ', N)
    print('Equilibrium rho: ', np.linalg.eigvals(rho_avg))
    print('Equilibrium sig: ', np.linalg.eigvals(sig_avg))

    data = {}
    data['R'] = R
    data['rho'] = rho
    data['sig'] = sig
    data['A'] = A
    data['B'] = B
    data['nash 1'] = rho_avg
    data['nash 2'] = sig_avg
    return data

###### Replicator Dynamics ######

def odeintz(func, z0, t, **kwargs):
    '''
    Integrate a complex-valued ODE system through scipy's real-valued odeint.

    The complex state is reinterpreted as interleaved real/imaginary floats
    for the solver and converted back before `func` is called and before the
    solution is returned. `func(z, t, *args)` receives and returns complex
    arrays. Any remaining keyword arguments are forwarded to odeint, except
    'Dfun', 'col_deriv', 'ml' and 'mu', which raise ValueError. With
    `full_output=True` the result is `(z, infodict)`, otherwise just `z`.
    '''
    for option in kwargs:
        if option in ('Dfun', 'col_deriv', 'ml', 'mu'):
            raise ValueError("The odeint argument %r is not supported by "
                             "odeintz." % (option,))

    start = np.array(z0, dtype=np.complex128, ndmin=1)

    def as_real_rhs(packed, time, *extra):
        # View the float buffer as complex, evaluate, and view the result back.
        derivative = func(packed.view(np.complex128), time, *extra)
        return np.asarray(derivative, dtype=np.complex128).view(np.float64)

    outcome = odeint(as_real_rhs, start.view(np.float64), t, **kwargs)

    if not kwargs.get('full_output', False):
        return outcome.view(np.complex128)

    solution, infodict = outcome[0], outcome[1]
    return solution.view(np.complex128), infodict
    
def f_deriv(s,t,R,n,m):
    '''
    Quantum replicator equations. Returns time derivative given state of the system.

    Parameters
    ----------
    s : array_like
        Flat state vector: the first n*n entries are the first player's
        n x n matrix (row-major), the next m*m entries the second player's
        m x m matrix.
    t : float
        Time; unused, present so the function fits the odeint call signature.
    R : ndarray
        Matrix of shape (n*m, n*m) coupling the two players.
    n, m : int
        Dimensions of the first and second player's matrices.

    Returns
    -------
    ndarray
        Flat vector of length n*n + m*m holding the two time derivatives in
        the same layout as `s`.
    '''
    s = np.asarray(s)
    rho = s[0:n**2].reshape(n,n)
    # The second block is m x m (it is combined with np.eye(m) below); this
    # equals the n-based slice whenever n == m.
    sig = s[n**2:n**2 + m**2].reshape(m,m)

    eye_n = np.eye(n)
    eye_m = np.eye(m)

    # Partial traces: phi keeps the first player's indices, phi_star the second's.
    phi = np.trace(np.array(R@np.kron(eye_n, sig.T)).reshape(n,m,n,m), axis1=1, axis2=3)
    phi_star = np.trace(np.array(R@np.kron(rho.T, eye_m)).reshape(n,m,n,m), axis1=0, axis2=2)

    dxdt = (rho@(phi-np.trace(rho@phi)*eye_n)).flatten()
    dydt = (sig@(-phi_star+np.trace(sig@phi_star)*eye_m)).flatten()

    return np.concatenate([dxdt, dydt])

def rep_trajectory(payoff, s, f=f_deriv, timestep=0.1, numstep=2000, plot=True, n=2, m=2):
    '''
    Computes quantum replicator trajectories given initial conditions and payoff matrix.

    Parameters
    ----------
    payoff : ndarray
        Either a payoff matrix of shape (n, m), which is converted with
        `GetRMatrix`, or any other shape, which is used as R directly.
    s : sequence
        Flat initial state (first player's n*n entries, then the second
        player's m*m entries); it is handed to the integrator unchanged.
    f : callable
        Right-hand side called as `f(state, t, R, n, m)`.
    timestep : float
        Spacing of the output times.
    numstep : int
        Number of output times, starting at 0.
    plot : bool
        If True, draw the first eigenvalue and the first diagonal entry of
        each player's matrix over the time steps.
    n, m : int
        Dimensions of the first and second player's matrices.

    Returns
    -------
    dict
        "traj" (raw solution), "times", "x" and "y" (lists of per-step
        matrices), and "x eigvals" / "y eigvals" (per-step eigenvalues).
    '''
    if payoff.shape == (n,m):
        R = GetRMatrix(payoff)
    else:
        R = payoff

    # runs complex odeint
    tvals = np.arange(numstep)*timestep
    traj = odeintz(f, s, tvals, args=(R, n, m))

    # The second player's block is m*m entries long (it is reshaped to m x m);
    # this equals the n-based slice whenever n == m.
    x = [row.reshape(n,n) for row in traj[:, 0:n**2]]
    y = [row.reshape(m,m) for row in traj[:, n**2:n**2 + m**2]]

    # get eigenvalues of the data
    x_eigs = np.array([np.linalg.eigvals(mat) for mat in x])
    y_eigs = np.array([np.linalg.eigvals(mat) for mat in y])

    data = {}
    data['traj'] = traj
    data["times"] = tvals
    data["x"] = x
    data["y"] = y
    data['x eigvals'] = x_eigs
    data['y eigvals'] = y_eigs

    if plot:
        # Diagonal entries are only needed for the right-hand panel.
        x_flat = np.array([np.diag(mat) for mat in x])
        y_flat = np.array([np.diag(mat) for mat in y])

        plt.figure(figsize=(16,8))
        plt.subplot(1, 2, 1)
        plt.plot(x_eigs[:,0])
        plt.plot(y_eigs[:,0])
        plt.title('Eigenvalue trajectories')
        plt.xlabel("Time")
        plt.ylabel("Eigenvalues")
        plt.legend(['x1','y1'], loc='best')
        plt.grid()

        plt.subplot(1, 2, 2)
        plt.plot(x_flat[:,0])
        plt.plot(y_flat[:,0])
        plt.title('Replicator trajectories')
        plt.xlabel("Time")
        plt.ylabel("System State")
        plt.legend(['x1','y1'], loc='best')
        plt.grid()
    return data