#!/usr/bin/env python
# coding: utf-8

# ## NeurIPS Submission: Fictitious Play

# ### Code

# In[1]:


import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import pandas as pd
from cycler import cycler


# In[2]:


import pylab as plb
# Global matplotlib default: render text at 8 pt in every figure created
# after this point.
plb.rcParams.update({'font.size': 8})


# #### Hard-coded instances

# In[3]:


# Hard-coded square game matrices of increasing size (2x2 up to 8x8).
# Each one is meant to be passed as ``mat`` to ``FP`` below.

# 2x2 instance.
mat2 = np.array([
    [2, 3],
    [1, 0],
])

# 4x4 instance.
mat4 = np.array([
    [2, 0, 0, 3],
    [0, 6, 7, 0],
    [0, 5, 0, 4],
    [1, 0, 0, 0],
])

# 6x6 instance.
mat6 = np.array([
    [2, 0, 0, 0, 0, 3],
    [0, 6, 0, 0, 7, 0],
    [0, 0, 10, 11, 0, 0],
    [0, 0, 9, 0, 8, 0],
    [0, 5, 0, 0, 0, 4],
    [1, 0, 0, 0, 0, 0],
])

# 8x8 instance.
mat8 = np.array([
    [2, 0, 0, 0, 0, 0, 0, 3],
    [0, 6, 0, 0, 0, 0, 7, 0],
    [0, 0, 10, 0, 0, 11, 0, 0],
    [0, 0, 0, 14, 15, 0, 0, 0],
    [0, 0, 0, 13, 0, 12, 0, 0],
    [0, 0, 9, 0, 0, 0, 8, 0],
    [0, 5, 0, 0, 0, 0, 0, 4],
    [1, 0, 0, 0, 0, 0, 0, 0],
])


# #### Fictitious Play

# In[14]:


def FP(mat, T, state = None, eps=0, resample=False, verbose=True):
    """Run fictitious play on the matrix game ``mat``.

    Both players best-respond with respect to the same matrix.  In every
    round after the first, the row player plays
    ``argmax(mat.dot(y) + per_y)`` and the column player plays
    ``argmax(x.dot(mat) + per_x)``, where ``x`` and ``y`` are the cumulative
    play counts of the rows and columns.  ``np.argmax`` breaks ties towards
    the lowest index; a different rule could be substituted.

    Parameters
    ----------
    mat : ndarray of shape (n, m)
        Game matrix; rows are the first player's actions, columns the
        second player's.
    T : int
        Number of rounds to play.  ``T == -1`` means "keep playing until
        the most recently played pair ``(ei, ej)`` attains ``mat.max()``";
        this may never terminate if that entry is never reached.
    state : tuple (row, column) or None
        Actions played in round 1.  When ``None`` they are drawn uniformly
        at random.
    eps : float
        Upper bound of the uniform ``[0, eps)`` perturbation added to each
        utility vector before taking the argmax.
    resample : bool
        If true, the perturbations are redrawn every round; otherwise the
        ones drawn before round 1 are reused throughout.
    verbose : bool
        If true (default), print every recorded transition together with
        the matrix entry at the newly played pair.

    Returns
    -------
    tuple
        ``(X, Y, V, transitions, br_vector_x, br_vector_y, cntX, cntY)``
        where, with one row per recorded round:

        * ``X`` / ``Y`` -- empirical play frequencies ``x / t`` and ``y / t``;
        * ``V`` -- the value ``avgx . mat . avgy`` of the empirical pair;
        * ``transitions`` -- list of ``(t, 'i' | 'j', previous, new)``
          tuples, one for each round in which the row (``'i'``) or column
          (``'j'``) action changed;
        * ``br_vector_x`` / ``br_vector_y`` -- one-hot encodings of the
          played row / column;
        * ``cntX`` / ``cntY`` -- snapshots of the cumulative counts.

    Raises
    ------
    ValueError
        If ``T`` is smaller than -1 (the loop could never terminate).
    """
    if T < -1:
        raise ValueError("T must be a non-negative round count or -1")

    n, m = mat.shape
    v_max = np.max(mat)

    # Cumulative play counts of the row and the column player.
    x = np.zeros(n)
    y = np.zeros(m)

    X, Y, V = [], [], []
    best_response_x, best_response_y = [], []
    br_vector_x, br_vector_y = [], []
    cntX, cntY = [], []
    transitions = []

    if state is None:
        # ei indexes rows (n of them) and ej indexes columns (m of them).
        ei = np.random.randint(0, n)
        ej = np.random.randint(0, m)
    else:
        ei, ej = state

    # Perturbation of the row player's (length n) and the column player's
    # (length m) utility vectors.
    per_y = np.random.uniform(0, eps, n)
    per_x = np.random.uniform(0, eps, m)

    t = 0
    while t != T:
        t += 1

        # In "run until optimum" mode stop once the last played pair (or
        # the initial state) sits on a maximal entry of the matrix.
        if T == -1 and v_max == mat[ei, ej]:
            break

        # Round 1 plays the initial state; afterwards each player best
        # responds to the opponent's cumulative counts.
        if t != 1:
            util_y = mat.dot(y)
            util_x = x.dot(mat)

            if resample:
                per_y = np.random.uniform(0, eps, n)
                per_x = np.random.uniform(0, eps, m)

            ei = np.argmax(util_y + per_y)
            ej = np.argmax(util_x + per_x)

        x[ei] += 1
        y[ej] += 1

        # One-hot encodings of the actions just played.
        v_ei = np.zeros(n)
        v_ei[ei] = 1
        v_ej = np.zeros(m)
        v_ej[ej] = 1

        # Update statistics.
        avgx = x / t
        avgy = y / t
        X.append(avgx)
        Y.append(avgy)

        # x and y are mutated in place every round, so store snapshots
        # rather than references to the live arrays.
        cntX.append(x.copy())
        cntY.append(y.copy())

        V.append(np.dot(avgx, mat).dot(avgy))

        if best_response_x and ei != best_response_x[-1]:
            transitions.append((t, 'i', best_response_x[-1], ei))
            if verbose:
                print(transitions[-1], mat[ei,ej])

        if best_response_y and ej != best_response_y[-1]:
            transitions.append((t, 'j', best_response_y[-1], ej))
            if verbose:
                print(transitions[-1], mat[ei,ej])

        best_response_x.append(ei)
        best_response_y.append(ej)

        br_vector_x.append(v_ei)
        br_vector_y.append(v_ej)

    X = np.asarray(X)
    Y = np.array(Y)
    V = np.array(V)

    br_vector_x = np.array(br_vector_x)
    br_vector_y = np.array(br_vector_y)

    cntX = np.array(cntX)
    cntY = np.array(cntY)

    return (X, Y, V, transitions, br_vector_x, br_vector_y, cntX, cntY)


# #### Plot utility

# In[15]:


def plotX(X, log_timescale=False, legend=True):
    """Plot every column of ``X`` against the iteration number and save it.

    Parameters
    ----------
    X : ndarray of shape (T, n)
        One row per recorded iteration and one column per action.
    log_timescale : bool
        If true, use a logarithmic x-axis.
    legend : bool
        If true, add a legend labelling the columns ``1 .. n``.

    The figure is written to the file ``row_strategy`` (matplotlib picks
    the extension) at 1200 dpi; nothing is returned.
    """
    fig, ax1 = plt.subplots(1, 1, figsize=(4,3))

    T, n = X.shape

    # 1-based iteration axis, the same convention plot_nash_gap uses.
    # Plotting against the implicit 0-based index would place the first
    # row at x = 0, which a logarithmic axis cannot display.
    time = np.arange(1, T + 1, 1)

    ax1.plot(time, X)
    if legend:
        ax1.legend(np.arange(n) + 1)
    ax1.set_xlabel('iterations')

    if log_timescale:
        ax1.set_xscale("log")

    fig.tight_layout()
    fig.savefig("row_strategy", dpi=1200)


# #### Plot Nash gap

# In[16]:


def plot_nash_gap(mat, X, Y, V, log_timescale=False, time_lim=-1):
    """Plot the summed best-response gap of both players over time.

    For every row ``k`` the plotted quantity is
    ``(max_i (mat . Y[k])_i - V[k]) + (max_j (X[k] . mat)_j - V[k])``,
    i.e. how much each player could gain over ``V[k]`` by switching to a
    best response against the other's strategy, added together.

    Parameters
    ----------
    mat : ndarray of shape (n, m)
        Game matrix.
    X : ndarray of shape (T, n)
        Row-player strategies, one per iteration.
    Y : ndarray of shape (T, m)
        Column-player strategies, one per iteration.
    V : ndarray of shape (T,)
        Value obtained by the pair ``(X[k], Y[k])``.
    log_timescale : bool
        If true, use a logarithmic x-axis.
    time_lim : int
        Accepted but currently unused.

    The y-axis is clipped to ``[0, 1]`` and the figure is saved to the
    file ``nash_gap`` at 1200 dpi; nothing is returned.
    """
    num_steps, _ = X.shape
    iterations = np.arange(1, num_steps + 1, 1)

    # Payoff of every column against each row strategy, then the best one.
    col_payoffs = np.einsum('ki,ij->kj', X, mat)
    best_col_value = col_payoffs.max(axis=1)

    # Payoff of every row against each column strategy, then the best one.
    row_payoffs = np.einsum('ij,kj->ki', mat, Y)
    best_row_value = row_payoffs.max(axis=1)

    total_gap = (best_row_value - V) + (best_col_value - V)

    fig, axis = plt.subplots(1, 1, figsize=(4,3))

    axis.plot(iterations, total_gap)
    axis.set_title(r'$Gap_X$')
    axis.set_xlabel('iterations')
    axis.set_ylim([0,1])

    if log_timescale:
        axis.set_xscale("log")

    fig.tight_layout(pad=0)
    fig.savefig("nash_gap", dpi=1200)


# ### Experiments

# #### Case (4x4)

# In[7]:


# Run 100 * T4 = 100,000 rounds of fictitious play on the 4x4 instance,
# starting from row 3 / column 0 (the entry of mat4 with value 1).
# Only the trajectories, values and transitions are kept; the one-hot and
# count outputs of FP are discarded.
T4 = 1000
X4, Y4, V4, transitions4, _, _, _, _ = FP(mat4, 100 * T4, (3,0))


# In[8]:


# Tabulate the recorded action switches.  The trailing bare `df4` only
# displays the table when run as a notebook cell; in a script it is a no-op.
df4 = pd.DataFrame(transitions4, columns=['t','player','from','to']); df4


# In[9]:


# Row player's empirical strategy over time (saved to "row_strategy").
plotX(X4, log_timescale=True)


# In[10]:


# Summed best-response gap over time (saved to "nash_gap").
plot_nash_gap(mat4, X4, Y4, V4, log_timescale=True)


# #### Case (6x6)

# In[11]:


# Run 2 * T6 = 9,154,028 rounds on the 6x6 instance, starting from
# row 5 / column 0 (the entry of mat6 with value 1).
# NOTE(review): the origin of the hard-coded horizon T6 is not shown in
# this file -- confirm where the constant comes from.
T6=4577014
X6, Y6, V6, transitions6, _, _, _, _ = FP(mat6, 2 * T6, (5,0))


# In[20]:


# Tabulate the recorded action switches.  The trailing bare `df6` only
# displays the table when run as a notebook cell; in a script it is a no-op.
df6 = pd.DataFrame(transitions6, columns=['t','player','from','to']); df6


# In[18]:


# Plot every second iterate only.  Because of the [::2] subsampling, one
# step on the x-axis corresponds to two rounds of play.  This call writes
# to the same "row_strategy" file as the 4x4 plot and so overwrites it.
plotX(X6[::2], log_timescale=True, legend=False)


# In[19]:


# Same subsampling for the gap plot; this overwrites the "nash_gap" file
# produced for the 4x4 case.
plot_nash_gap(mat6, X6[::2], Y6[::2], V6[::2], log_timescale=True)


# In[ ]:




