########################################
# The experimental parameters that matter for the reported results are marked
# with a "NeurIPS" comment tag; search this file for that tag to find them.
########################################
import scipy as sp
import numpy as np
import numpy.linalg as la
from sklearn import datasets
import sklearn
from itertools import permutations
import histools as ht
from sklearn.model_selection import KFold
from sklearn.model_selection import ShuffleSplit as ss
import multiprocessing as mp

### Experimental parameters ###
dim = 2                     # NeurIPS dimension (handed to the dataset loader)
minBin, maxBin = 1, 12      # NeurIPS max number of bins; sweep is minBin..maxBin inclusive
minRank, maxRank = 1, 8     # NeurIPS max value for r in Tucker Decomposition; sweep is inclusive
fitMethod = ht.nnegTucker   # low-rank fitting routine; nnegParfac doesn't work so well
printBin = False            # print the bin count after each step of the bin sweep

### Validation ###

swapTest = True             # swap the training and testing index sets of every evaluation split
testKFold = False           # use KFold (instead of ShuffleSplit) for the evaluation splits
trainKFold = False          # use KFold (instead of ShuffleSplit) for the hyperparameter splits

# KFold parameters
testFoldNum = 4             # number of folds for final testing
hyperFoldNum = 8            # number of folds for hyperparameter fitting

# ShuffleSplit parameters (evaluation splits)
testSplits = 30             # number of random evaluation splits
testSize = 200              # test_size of each evaluation split (an int, i.e. a sample count)

# ShuffleSplit parameters (hyperparameter splits)
trainSplits = 80            # number of random hyperparameter splits
trainSize = 0.2             # test_size of each hyperparameter split (a float, i.e. a fraction)
############################

# Load the data set and run it through the histools preprocessing helpers.
X = ht.loadMnistPca(dim)    #NeurIPS choose dataset
X = ht.unifNormalize(X)
X = ht.dataPerm(X)
numSamp = X.shape[1]        # samples are indexed along axis 1 (test() slices X[:, idx])

# Outer (evaluation) splits: an iterable of (index array, index array) pairs
# over all samples, consumed once by the worker pool.
if testKFold:
    testKf = KFold(n_splits=testFoldNum, shuffle=True)
    testSplit = testKf.split(np.arange(numSamp))
else:
    ssSplit = ss(n_splits=testSplits, test_size=testSize)
    testSplit = ssSplit.split(np.arange(numSamp))

# Inner (hyperparameter) splitter. Only the splitter object is built here;
# test() calls trainKf.split() on the training subset of each outer split.
# The KFold branch previously referenced the undefined name `trainFoldNum`
# (a NameError as soon as trainKFold was enabled); the fold count for
# hyperparameter fitting is hyperFoldNum.
if trainKFold:
    trainKf = KFold(n_splits=hyperFoldNum, shuffle=True)
else:
    trainKf = ss(n_splits=trainSplits, test_size=trainSize)

def _heldOutLoss(Dest, Dheld):
    """Return genHistInner(Dest, Dest) - 2 * genHistInner(Dest, Dheld).

    NOTE(review): this looks like a held-out squared-L2 criterion with the
    estimate-independent term dropped -- confirm against ht.genHistInner.
    """
    return ht.genHistInner(Dest, Dest) - 2*ht.genHistInner(Dest, Dheld)


def test(indexes):
    """Run one evaluation split: select hyperparameters, then score both models.

    The training part of the split is re-split with the module-level
    ``trainKf`` splitter. For every bin count in ``minBin..maxBin`` the plain
    histogram is scored on each inner validation part, and for every
    (bin count, rank) pair the ``fitMethod`` estimate is scored the same way.
    The settings with the lowest mean inner loss are then refit on the whole
    training part and scored on the test part.

    Parameters
    ----------
    indexes : sequence of two index arrays
        One item produced by the outer splitter. ``indexes[0]`` is used for
        training and ``indexes[1]`` for testing, or the other way round when
        ``swapTest`` is true.

    Returns
    -------
    list
        ``[histRes, lrHistRes, bestBin, bestLrBin, bestLrRank, bestl2hat]``:
        test loss of the plain histogram, test loss of the low-rank
        histogram, selected bin count of the plain histogram, selected bin
        count and rank of the low-rank histogram, and
        ``genHistInner(Dhat, Dhat)`` of the final low-rank estimate.

    Side effects: prints the selected bin count and the selected
    (bin count, rank) pair; prints every bin count when ``printBin`` is set.
    """
    if swapTest:
        trainIndex, testIndex = indexes[1], indexes[0]
    else:
        trainIndex, testIndex = indexes[0], indexes[1]
    Xtrain = X[:, trainIndex]
    Xtest = X[:, testIndex]
    trainNumSamp = Xtrain.shape[1]

    # Size the result buffers from the splitter that is actually used, so the
    # fold axis is correct for both the KFold and the ShuffleSplit setting
    # (it used to be hard-wired to trainSplits).
    numFolds = trainKf.get_n_splits()
    numBins = maxBin - minBin + 1
    numRanks = maxRank - minRank + 1
    histHResults = np.zeros([numBins, numFolds])
    histlrHResults = np.zeros([numBins, numRanks, numFolds])

    innerSplits = trainKf.split(np.arange(trainNumSamp))
    for currFold, (fitIndex, valIndex) in enumerate(innerSplits):
        Xhtrain = Xtrain[:, fitIndex]
        Xhtest = Xtrain[:, valIndex]
        for binNum in range(minBin, maxBin + 1):
            Dtest = ht.histTransform(Xhtest, [binNum]*dim)
            Dtrain = ht.histTransform(Xhtrain, [binNum]*dim)
            histHResults[binNum - minBin, currFold] = _heldOutLoss(Dtrain, Dtest)
            for rank in range(minRank, maxRank + 1):
                try:
                    Dhat = fitMethod(Dtrain, rank)
                    histlrHResults[binNum - minBin, rank - minRank, currFold] = _heldOutLoss(Dhat, Dtest)
                except ValueError:
                    # A failed fit gets an infinite loss, which makes the mean
                    # over folds infinite and so rules this setting out.
                    histlrHResults[binNum - minBin, rank - minRank, currFold] = np.inf
            if printBin:
                print(binNum)

    # Plain histogram: bin count with the lowest mean inner loss, refit on
    # the full training part and scored on the test part.
    bestBin = np.argmin(np.mean(histHResults, 1)) + minBin
    Dtest = ht.histTransform(Xtest, [bestBin]*dim)
    print(bestBin)
    Dtrain = ht.histTransform(Xtrain, [bestBin]*dim)
    histRes = _heldOutLoss(Dtrain, Dtest)

    # Low-rank histogram: (bin count, rank) pair with the lowest mean inner
    # loss; argmin over the flattened grid is mapped back to 2-D indices.
    meanLrLoss = np.mean(histlrHResults, 2)
    bestBR = np.array(np.unravel_index(np.argmin(meanLrLoss), histlrHResults.shape[:2]))
    bestBR[0] += minBin
    bestBR[1] += minRank
    Dtest = ht.histTransform(Xtest, [bestBR[0]]*dim)
    print(bestBR)
    Dtrain = ht.histTransform(Xtrain, [bestBR[0]]*dim)
    Dhat = fitMethod(Dtrain, bestBR[1])
    lrHistRes = _heldOutLoss(Dhat, Dtest)

    bestl2hat = ht.genHistInner(Dhat, Dhat)
    return [histRes, lrHistRes, bestBin, bestBR[0], bestBR[1], bestl2hat]

# The guard keeps worker processes from re-running the experiment when they
# re-import this module, which happens under the "spawn" start method.
if __name__ == "__main__":
    # Imported explicitly: `import scipy as sp` by itself does not guarantee
    # that the scipy.stats submodule has been loaded.
    from scipy.stats import wilcoxon

    # mp.Pool is the documented entry point; mp.pool is only an attribute of
    # the package if multiprocessing.pool happens to have been imported.
    with mp.Pool() as pool:
        # One row per evaluation split, columns as returned by test().
        res = np.array(pool.map(test, testSplit))

    # (label, column of res) in the order the summary is printed.
    summary = (
        ('Tucker Hist mean loss: ', 1),
        ('Standard Hist mean loss: ', 0),
        ('Standard Hist avg. bins: ', 2),
        ('Tucker Hist avg. bins: ', 3),
        ('Tucker Hist avg. components: ', 4),
    )
    for label, col in summary:
        print(label + str(np.mean(res[:, col])) + ' +/- ' + str(np.std(res[:, col])))

    # Paired test between the per-split losses of the two estimators.
    print('Wilcoxon Signed rank test p-value: ' + str(wilcoxon(res[:, 0], res[:, 1])[1]))
