import torch
import torch.nn as nn
from data_utils import *
from torch.utils.data import DataLoader
from torch.utils.data.sampler import SubsetRandomSampler

from tqdm import tqdm
from IPython import embed
import numpy as np
from numpy import linalg as LA
import argparse, os, sys, subprocess
import copy

from deepset import *
from cdm import *

def model_eval(model, loader, batch_size, criterion, cutoff):
    """Return the mean per-batch loss of ``model`` over ``loader``.

    Only batches whose first dimension equals ``batch_size`` are scored;
    every other batch (typically the final, partial one) is skipped.

    Args:
        model: Callable mapping a batch of inputs to predictions.
        loader: Iterable yielding ``(team1, score)`` batches.
        batch_size: Required number of rows for a batch to be scored.
        criterion: Loss callable invoked as ``criterion(out, score)``.
        cutoff: Unused; kept so existing call sites keep working.

    Returns:
        The sum of the batch losses divided by the number of scored
        batches (a 0-dim tensor when ``criterion`` returns tensors).

    Raises:
        ZeroDivisionError: If no batch of size ``batch_size`` was seen.
    """
    total_loss = 0.0
    processed_no = 0

    # Evaluation needs no autograd graph and must not touch the gradients
    # accumulated on the model, so zero_grad() is deliberately not called.
    with torch.no_grad():
        for team1, score in loader:
            if team1.shape[0] != batch_size:
                continue

            out = model(team1)  # outcome wrt team 1

            # as_tensor reuses an existing tensor instead of copy-constructing
            # it (torch.tensor(tensor) copies and emits a UserWarning).
            loss = criterion(out, torch.as_tensor(score)).float()
            total_loss += loss.detach()
            processed_no += 1

    if processed_no == 0:
        raise ZeroDivisionError(
            "no batch of size {} found in loader; cannot average the loss".format(batch_size)
        )
    return total_loss / processed_no

def columnize(res, loader_type):
    """Flatten batched evaluation results into comma-separated text rows.

    Args:
        res: Iterable of ``(team, score, pred)`` tensor batches.
        loader_type: Label written as the first field of every row.

    Returns:
        A list with one string per example, formatted as
        ``loader_type,<space-separated nonzero team indices>,score,pred``,
        where score and pred are the first column of their row.
    """
    def _format_row(members, target, predicted):
        # Indices of the nonzero entries of the membership vector.
        roster = " ".join(str(idx) for idx in np.where(members)[0])
        return ",".join([loader_type, roster, str(target[0]), str(predicted[0])])

    lines = []
    for team_batch, score_batch, pred_batch in res:
        teams = team_batch.numpy()
        scores = score_batch.numpy()
        preds = pred_batch.numpy()
        for row_no, members in enumerate(teams):
            lines.append(_format_row(members, scores[row_no], preds[row_no]))
    return lines

def eval_all(model, loader, batch_size, criterion):
    """Score ``model`` on ``loader`` and keep every batch's predictions.

    Only batches whose first dimension equals ``batch_size`` are scored;
    every other batch (typically the final, partial one) is skipped.

    Args:
        model: Callable mapping a batch of inputs to predictions.
        loader: Iterable yielding ``(team1, score)`` batches.
        batch_size: Required number of rows for a batch to be scored.
        criterion: Loss callable invoked as ``criterion(out, score)``.

    Returns:
        A pair ``(result, mean_loss)`` where ``result`` is a list of
        ``(team1, score, out)`` tuples, one per scored batch, and
        ``mean_loss`` is the summed batch loss divided by the number of
        scored batches.

    Raises:
        ZeroDivisionError: If no batch of size ``batch_size`` was seen.
    """
    total_loss = 0.0
    processed_no = 0
    result = []

    # Evaluation needs no autograd graph and must not touch the gradients
    # accumulated on the model, so zero_grad() is deliberately not called.
    with torch.no_grad():
        for team1, score in loader:
            if team1.shape[0] != batch_size:
                continue

            out = model(team1)  # outcome wrt team 1
            result.append((team1, score, out))

            # as_tensor reuses an existing tensor instead of copy-constructing
            # it (torch.tensor(tensor) copies and emits a UserWarning).
            loss = criterion(out, torch.as_tensor(score)).float()
            total_loss += loss.detach()
            processed_no += 1

    if processed_no == 0:
        raise ZeroDivisionError(
            "no batch of size {} found in loader; cannot average the loss".format(batch_size)
        )
    return result, total_loss / processed_no

def save_predictions(best_model, loaders, batch_size, loss_function, filename="cdm_nba_prediction.csv"):
    """Evaluate a model on the train/val/test loaders and dump the rows to a file.

    Args:
        best_model: Model to evaluate.
        loaders: Sequence of exactly three loaders, ordered train, val, test.
        batch_size: Batch size forwarded to ``eval_all``.
        loss_function: Loss forwarded to ``eval_all``.
        filename: Path of the text file to (over)write.

    For each split the mean loss is printed, then the rows produced by
    ``columnize`` are collected; all rows are written newline-separated.
    """
    train_loader, val_loader, test_loader = loaders
    splits = (("train", train_loader), ("val", val_loader), ("test", test_loader))

    rows = []
    for tag, split_loader in splits:
        predictions, split_loss = eval_all(best_model, split_loader, batch_size, loss_function)
        print(split_loss)
        rows += columnize(predictions, tag)

    with open(filename, "w") as out_file:
        out_file.write("\n".join(rows))


def compute_l1(l1_loss, model, factor):
    """Return ``factor`` times the summed L1 penalty of the model parameters.

    Args:
        l1_loss: Callable invoked as ``l1_loss(param, target=zeros)`` for
            each parameter tensor.
        model: Object exposing ``parameters()``.
        factor: Multiplier applied to the summed penalty.

    Returns:
        ``factor * sum_of_penalties``; ``factor * 0`` when the model has
        no parameters.
    """
    penalties = (
        l1_loss(weights, target=torch.zeros_like(weights))
        for weights in model.parameters()
    )
    # sum() starts from the integer 0, matching an empty-model result of factor * 0.
    return factor * sum(penalties)

def _build_model(args):
    """Instantiate the architecture named by ``args.model``.

    Any value other than "fhoi", "cdm" or "linear" selects the DeepSet model.
    """
    if args.model == "fhoi":
        return FHoi_single(args.num_players)
    if args.model == "cdm":
        return CDM_single(args.num_players, args.embed_size)
    if args.model == "linear":
        return LR_single(args.num_players)
    return DeepSet_single(args.num_players, args.embed_size, linear_dim=args.linear_dim)


def train(args):
    """Train the selected model and report the test loss of the best weights.

    The dataset at ``args.train_path`` is split into train/validation/test
    subsets; the model is optimised with SGD (L2 via ``weight_decay``, plus
    an optional L1 penalty), validated every ``args.eval_iter`` epochs, and
    the weights with the lowest validation loss are evaluated on the test
    subset and optionally saved to ``args.save_path``.

    Args:
        args: Namespace with the attributes read below (``num_epochs``,
            ``eval_iter``, ``batch_size``, ``learn_rate``,
            ``l2_regularization``, ``l1_regularization``, ``embed_size``,
            ``linear_dim``, ``num_players``, ``train_split``,
            ``train_path``, ``save_path``, ``load_path``, ``regress``,
            ``model``).
    """
    cutoff = 0.0 if args.regress else 0.5
    np.random.seed(0)

    # 10% of the data is held out for validation; the rest of the
    # non-training share becomes the test set.
    dataset = Single_DataSet(args.train_path, args.num_players, split=[args.train_split, 0.1, 0.9 - args.train_split])

    train_indices, val_indices, test_indices = dataset.get_split_indices()
    print("training set size {}, test size {}".format(len(train_indices), len(test_indices)))

    train_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=SubsetRandomSampler(train_indices))
    val_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=SubsetRandomSampler(val_indices))
    test_loader = DataLoader(dataset, batch_size=args.batch_size, sampler=SubsetRandomSampler(test_indices))

    model = _build_model(args)

    if args.load_path:
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        model.load_state_dict(torch.load(args.load_path))

    print(model)

    loss_function = nn.MSELoss() if args.regress else nn.BCELoss()
    l1_loss = nn.L1Loss()

    optimizer = torch.optim.SGD(model.parameters(), lr=args.learn_rate, weight_decay=args.l2_regularization)

    # Report the training loss roughly 20 times over the run. Integer
    # division (floored at 1) keeps the modulo test below well-defined for
    # short runs, where a fractional interval would misfire.
    train_eval = max(1, args.num_epochs // 20)

    avg_loss = model_eval(model, train_loader, args.batch_size, loss_function, cutoff)
    print("start train loss {}".format(avg_loss))
    avg_loss = model_eval(model, val_loader, args.batch_size, loss_function, cutoff)
    print("start validation loss {}".format(avg_loss))

    best_loss = float("inf")
    # Start from the current weights so a run in which validation never
    # improves (zero epochs, NaN losses) still has a state to evaluate.
    best_model_weights = copy.deepcopy(model.state_dict())

    for epoch in range(args.num_epochs):
        total_loss = 0.0
        processed_no = 0
        for team1, outcome in train_loader:
            # Partial batches are skipped, matching the evaluation helpers.
            if team1.shape[0] != args.batch_size:
                continue

            model.zero_grad()
            out = model(team1)  # outcome wrt team 1

            # as_tensor reuses the batch tensor instead of copy-constructing
            # it (torch.tensor(tensor) copies and emits a UserWarning).
            loss = loss_function(out, torch.as_tensor(outcome)).float()
            loss = loss + compute_l1(l1_loss, model, factor=args.l1_regularization)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            processed_no += 1

        if epoch % train_eval == 0 and processed_no:
            print("train epoch {} loss {}".format(epoch, total_loss / processed_no))

        if epoch % args.eval_iter == 0:
            avg_loss = model_eval(model, val_loader, args.batch_size, loss_function, cutoff)
            print("validation epoch {} loss {}".format(epoch, avg_loss))

            if avg_loss < best_loss:
                # Snapshot: state_dict() returns live references to the weights.
                best_model_weights = copy.deepcopy(model.state_dict())
                best_loss = avg_loss

    best_model = _build_model(args)
    best_model.load_state_dict(best_model_weights)
    test_loss = model_eval(best_model, test_loader, args.batch_size, loss_function, cutoff)

    # To also dump per-example predictions for every split, call:
    # save_predictions(best_model, [train_loader, val_loader, test_loader],
    #                  args.batch_size, loss_function, filename="cdm_nba_prediction.csv")

    if args.save_path:
        torch.save(best_model.state_dict(), args.save_path)

    print(args.train_path)
    # The index of the last epoch run; does not rely on the loop variable,
    # which is unbound when num_epochs is 0.
    print("test epoch {} loss {}".format(args.num_epochs - 1, test_loss))

def parse_args(argv=None):
    """Parse the command-line options for a training run.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The parsed namespace, with ``regress`` converted from text to bool.
    """
    parser = argparse.ArgumentParser()

    parser.add_argument('--num_epochs', '-epochs', type=int, default=100)
    parser.add_argument('--eval_iter', '-eval', type=int, default=20, help="Evaluate accuracy every eval_iter")
    parser.add_argument('--batch_size', '-batch_size', type=int, default=50, help="Batch size")
    parser.add_argument('--learn_rate', '-lr', type=float, default=1e-3, help="Learning rate")
    parser.add_argument('--l2_regularization', '-l2_lambda', type=float, default=1e-3, help="L2 regularization")
    parser.add_argument('--l1_regularization', '-l1_lambda', type=float, default=0.0, help="L1 regularization")
    parser.add_argument('--embed_size', '-embed_size', type=int, default=2, help="Embedding size")
    parser.add_argument('--linear_dim', '-linear_dim', type=int, default=10, help="Linear dim")
    parser.add_argument('--num_players', '-num_players', type=int, default=36, help="Number of players")
    parser.add_argument('--train_split', '-train_split', type=float, default=0.7, help="Train split")
    parser.add_argument('--train_path', '-train_path', type=str, default="data/spread_data12.txt", help="File path")
    parser.add_argument('--test_path', '-test_path', type=str, default="data/spread_data12.txt", help="File path")
    parser.add_argument('--save_path', '-save_path', type=str, default="", help="Save path")
    parser.add_argument('--load_path', '-load_path', type=str, default="", help="Load path")
    parser.add_argument('--regress', '-regress', type=str, default="True", help="regress score difference")
    parser.add_argument('--data_format', '-data_format', type=str, default="Matchup", help="Dataset format")
    parser.add_argument('--model', '-model', type=str, default="fhoi", help="model")

    args = parser.parse_args(argv)
    # --regress arrives as text. Compare case-insensitively so that
    # "true"/"TRUE" are not silently read as False; any other value still
    # disables regression, as before.
    args.regress = args.regress.strip().lower() in ("true", "1", "yes")
    return args


if __name__ == '__main__':
    args = parse_args()
    print(args)

    train(args)

