# %% [markdown]
# Ripley -- immuno

# %%

import gudhi as gd
import numpy as np
import mma
from sklearn.neighbors import KernelDensity
from sklearn.svm import SVC
from os.path import expanduser
from os import walk
from pandas import read_csv
import matplotlib.pyplot as plt
from xgboost import XGBClassifier
from os.path import exists
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
from sklearn.ensemble import RandomForestClassifier
from sklearn.svm import SVC
from random import choice
from mma_wrappers.mmawrappers import *

from joblib import Parallel, delayed, cpu_count


# %%
# Parse the command line first, so that `--help` and invalid options exit
# before the dataset is loaded.
from argparse import ArgumentParser
p = ArgumentParser()
p.add_argument(
    "-p", "--pipeline", default="image", type=str,
    # Any other value never defines a vectorizer further down and used to end
    # in a NameError; reject it here instead.
    choices=["image", "landscape", "landscape1d", "landscape1d2", "landscape1dalpha"],
    help="pipeline to consider : either image (our representation) or landscape, "
         "or landscape1d / landscape1d2 / landscape1dalpha (compare with 1d parameter pipelines)",
)
p.add_argument("-k", "--k", default=5, type=int, help="Number of folds for train and for test")
p.add_argument(
    "-c", "--classifier", default="rf", type=str,
    # Same reasoning: only these three values define a classifier below.
    choices=["rf", "svm", "xgboost"],
    help="Classifier to put after the vectorization : rf, svm, xgboost",
)
# Left unrestricted on purpose: the landscape1dalpha pipeline builds its own
# complex and does not use this option to construct the pipeline.
p.add_argument("-C", "--complex", default="rips", type=str, help="The triangulation to consider : rips or alpha")
args = p.parse_args()
print("Arguments", args, flush=True)

# Dataset: one entry of X per sample, with its label name; label names are
# encoded as integers for the classifiers.
X, labels_name = get_immuno_regions()
labels = LabelEncoder().fit_transform(labels_name)
# Number of CPUs reported by joblib; the pipeline section derives its worker
# counts from this value.
n_jobs = cpu_count()



# For test purposes
#p = np.random.choice(range(len(X)),50, replace =False)
#X = [x for i, x in enumerate(X) if i in p]
#labels = [l for i, l in enumerate(labels) if i in p]
# %%



# %%


# %%
#def get_rips(x, **kwargs)->mma.SimplexTree:
#	st = gd.RipsComplex(points = x, max_edge_length = kwargs.get("threshold", np.inf), sparse = kwargs.get("sparse",None)).create_simplex_tree()
#	st = mma.from_gudhi(st, parameters=2)
#	kde = KernelDensity(kernel='gaussian', bandwidth=kwargs.get("kde_bandwidth", 0.05)).fit(x)
#	codensity = -kde.score_samples(x)
#	st.fill_lowerstar(codensity, 1)
#	if kwargs.get("collapse", True):
#		st.collapse_edges(num=100, max_dimension=1, progress=False); st.collapse_edges(num=100, strong=0, progress=0)
#	st.expansion(kwargs.get("max_dimension", 2))
#	return st
#def get_landscape(x, basepoint:list[int] = [0,0], degree:int=1, **kwargs):
#	st = get_rips(x, collapse=0, max_dimension=1).to_gudhi(basepoint=basepoint)
#	st.collapse_edges(10)
#	st.expansion(2)
#	st.compute_persistence()
#	return st.persistence_intervals_in_dimension(degree)
#def get_landscapes(X, n_jobs=5, progress=True, **kwargs):
#	from gudhi.representations import Landscape
#	print("Computing Diagrams...", flush=1)
#	dgms = Parallel(n_jobs=n_jobs)(delayed(get_landscape)(x, **kwargs) for x in tqdm(X, disable= not progress, desc="Computing diagrams"))
#	print("Computing Landscapes...", flush=1)
#	landscapes = Landscape().fit_transform(dgms)
#	return landscapes

class AlphaComplex(BaseEstimator, TransformerMixin):
    """Transformer turning point clouds into Gudhi alpha-complex simplex trees.

    Parameters
    ----------
    dtype:
        When ``None``, ``transform`` builds nothing and returns the
        two-element list ``[builder, X]`` so that the construction can be
        performed later (presumably by the next pipeline step -- confirm
        against the consumer). Any other value makes ``transform`` build the
        simplex trees eagerly.
    """

    def __init__(self, dtype=gd.SimplexTree) -> None:
        super().__init__()
        # Not read anywhere in this class; kept because other code may look
        # for the attribute.
        self.max_dimension = None
        self.dtype = dtype

    def fit(self, X: np.ndarray | list, y=None):
        """Return ``self`` unchanged: nothing is learned from ``X``.

        The previous emptiness check was redundant (both paths returned
        ``self``) and only made ``fit`` fail on inputs without ``len``.
        """
        return self

    def transform(self, X):
        """Build one alpha-complex simplex tree per element of ``X``.

        Returns
        -------
        list
            ``[builder, X]`` when ``self.dtype`` is ``None`` (delayed
            computation), otherwise the list of simplex trees, in the order
            of ``X``.
        """
        def get_st(x) -> gd.SimplexTree:
            # `x` is handed to gudhi as the `points` of the alpha complex.
            return gd.AlphaComplex(points=x).create_simplex_tree()

        if self.dtype is None:
            return [get_st, X]
        return [get_st(x) for x in X]

# %%
#landscapes = get_landscapes(X, basepoint=[0.1,1.], degree=1, n_jobs=cpu_count(), sparse=0.2, threshold=0.2)

# Hyper-parameter grid handed to GridSearchCV. Keys follow scikit-learn's
# "<step name>__<parameter>" convention, so every prefix used here must be the
# name of a step of the pipeline built below.
params = {}

# Worker count for the inner estimators and for the grid search. `n_jobs // 4`
# is 0 on machines with fewer than 4 cores, and joblib rejects n_jobs == 0,
# so always keep at least one worker.
inner_jobs = max(1, n_jobs // 4)


################################ COMPLEXES
if args.complex == "rips":
    cplxn = "rips"
    cplx_pipe = RipsDensity2SimplexTree()
    params.update({
        "rips__bandwidth": [0.05, 0.02, 0.07],  # [0.01,0.02,0.05,0.07,0.1],
        "rips__sparse": [None],
        "rips__threshold": [0.1],
        "rips__max_dimension": [2],
        "rips__dtype": [None],  # None returns a delayed computation, for parallelization
    })
elif args.complex == "alpha":
    cplxn = "alpha"
    cplx_pipe = AlphaDensity2ST()
    params.update({
        "alpha__bandwidth": [0.05, 0.02, 0.07],  # [0.01,0.02,0.05,0.07,0.1],
        "alpha__threshold": [1],
        "alpha__dtype": [None],  # None returns a delayed computation, for parallelization
    })
elif args.pipeline != "landscape1dalpha":
    # landscape1dalpha builds its own complex; every other pipeline needs
    # cplxn / cplx_pipe and would otherwise fail later with a NameError.
    raise ValueError(f"Unknown complex {args.complex!r}: expected 'rips' or 'alpha'")

################################ CLASSIFIERS
if args.classifier == "rf":
    cl = RandomForestClassifier()
    cln = "rf"
    params.update({
        "rf__n_estimators": [200],
    })
elif args.classifier == "svm":
    cl = SVC()
    cln = "svm"
    params.update({
        "svm__C": [0.1, 1., 10.],
        "svm__kernel": ["rbf", "linear"],
    })
elif args.classifier == "xgboost":
    cl = XGBClassifier()
    cln = "xgb"
else:
    raise ValueError(f"Unknown classifier {args.classifier!r}: expected 'rf', 'svm' or 'xgboost'")


########################## representations
if args.pipeline == "image":
    vectn = "img"
    vect = Module2Image()
    params.update({
        "img__bandwidth": [0.1, 0.01, 0.001],  # image \delta parameter
        "img__resolution": [[50, 50]],  # you increase the image resolution
        "img__degrees": [[0, 1]],  # homological degrees to consider
        "img__p": [0, 1],  # p parameter of the image. Here 0 works better
        "img__normalize": [0, 1],  # normalize parameter of the image (i.e. sum vs mean)
    })
elif args.pipeline == "landscape":
    vectn = "landscape"
    vect = Module2Landscape()
    params.update({
        "landscape__degrees": [[0, 1]],  # homological degrees
        "landscape__resolution": [[50, 50]],  # resolution of each landscape
        "landscape__ks": [range(10), range(5), range(1)],  # number of landscapes to consider
    })
elif args.pipeline not in ("landscape1d", "landscape1d2", "landscape1dalpha"):
    # Without this check an unknown pipeline reaches the multi-parameter
    # branch below with no vectorizer defined.
    raise ValueError(f"Unknown pipeline {args.pipeline!r}")


def _diagram_landscape_steps():
    """Return fresh "dgm" and "landscapes" steps shared by the 1-parameter pipelines."""
    return [
        ("dgm", SimplexTree2Diagram(degrees=[0, 1], extended=False, n_jobs=inner_jobs, progress=0)),
        ("landscapes", Dgms2Landscapes(num=5, resolution=100, n_jobs=inner_jobs)),
    ]


## BELOW HERE ARE 1-PARAMETER PIPELINES
if args.pipeline == "landscape1d":
    if args.complex == "rips":
        params.update({"rips__max_dimension": [1]})
        params.update({"rips__num_collapse": [0]})
        num_collapse = 10
    else:
        num_collapse = 0
    params.update({
        "landscapes__num": [1, 5, 10],
        "landscapes__resolution": [50, 100],
        "to_gudhi__basepoint": [[-0.2, 0.], [-0.1, 0.], [0., 0.], [0.05, 0.]],
    })
    pipeline = Pipeline([
        (cplxn, cplx_pipe),
        # This step was missing: the grid above tunes "to_gudhi__basepoint"
        # and num_collapse was computed but never used, so GridSearchCV would
        # reject the grid for lack of a "to_gudhi" step.
        ("to_gudhi", SimplexTreeSlice(dtype=None, num_collapse=num_collapse, max_dimension=2)),
        *_diagram_landscape_steps(),
        (cln, cl),
    ])
elif args.pipeline == "landscape1d2":
    if args.complex == "rips":
        params.update({"rips__max_dimension": [1]})
        params.update({"rips__num_collapse": [100]})
        num_collapse = 10
    else:
        num_collapse = 0
    params.update({
        "landscapes__num": [1, 5, 10],
        "landscapes__resolution": [50, 100],
        "to_gudhi__basepoint": [None],
        "to_gudhi__parameter": [0],
    })
    pipeline = Pipeline([
        (cplxn, cplx_pipe),
        ("to_gudhi", SimplexTreeSlice(dtype=None, num_collapse=num_collapse, max_dimension=2)),
        *_diagram_landscape_steps(),
        (cln, cl),
    ])
elif args.pipeline == "landscape1dalpha":
    # The grid is rebuilt from scratch because the "rips__*" / "alpha__*"
    # entries collected above do not apply to the plain AlphaComplex step.
    # NOTE(review): this also drops the classifier grid ("rf__*" / "svm__*"),
    # so the classifier runs with its defaults here -- confirm this is intended.
    params = {
        "landscapes__num": [1, 5, 10, 50],
        "landscapes__resolution": [50, 100, 500],
    }
    pipeline = Pipeline([
        ("alpha", AlphaComplex(dtype=None)),
        *_diagram_landscape_steps(),
        (cln, cl),
    ])
else:
    # Multi-parameter pipelines ("image" / "landscape").
    pipeline = Pipeline([
        (cplxn, cplx_pipe),
        ("mod", ToModule(dtype=None, n_jobs=inner_jobs, nlines=2000)),
        (vectn, vect),
        (cln, cl),
    ])



gridsearch_cl = GridSearchCV(estimator=pipeline, param_grid=params, n_jobs=inner_jobs, cv=args.k, verbose=10)
print(f"Computing {pipeline} perfs...", flush=True)
to_csv(X=X, Y=labels, cl=gridsearch_cl, cln=f"{args.complex}_{args.pipeline}_{args.classifier}", dataset="immuno", k=args.k, verbose=True)




#rfc = RandomForestClassifier(n_estimators=200, n_jobs=8, random_state=42)
#svmpc = SVC(kernel="poly")
#svmlc = SVC(kernel="linear")
#svmexpc = SVC(kernel="rbf")
#xgbc = XGBClassifier()


#print("Computing CV", flush=1)
#for cl, cln, in zip([rfc,svmexpc, xgbc], ["RandomForest", "SVM", "XGBoost"]):
#	to_csv(X=landscapes, Y=labels, cl=cl, cln=cln, k=10)

## %%




