"""
Turns graphs into persistence intervals. The filtration has to be given in the graph.
"""

from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV
import gudhi as gd
from os.path import expanduser, exists
import networkx as nx
import pickle
from joblib import Parallel, delayed, cpu_count
import numpy as np
from tqdm import tqdm
from warnings import warn
from random import choice
import mma
from sklearn.neighbors import KernelDensity
from typing import Callable, Iterable
from os import walk
from pandas import read_csv
from sklearn.preprocessing import LabelEncoder
import matplotlib.pyplot as plt
from gudhi.representations import Landscape
from gudhi.representations.kernel_methods import SlicedWassersteinKernel
# import shapely
# import matplotlib 
# import sympy

DATASET_PATH = expanduser("~/Datasets/")

################################################################################################# DATASET GET / SET
def get_graph_dataset(dataset, label= "lb", N=100)->tuple[list[nx.Graph], list[int]]:
	"""
	Loads a graph dataset stored as `.mat` adjacency matrices under `DATASET_PATH/<dataset>/mat/`.

	Inputs:
		dataset: name of the dataset folder; the special value "SBM" generates a synthetic dataset instead
		label: token to look for in the "_"-separated file name; the token following it is used as the label
		N: number of graphs per class, only used when dataset == "SBM"
	Outputs:
		gs: list of networkx graphs built from the 'A' entry of each file
		labels: list of labels (raw file-name tokens, i.e. strings), aligned with gs
	"""
	if dataset == "SBM":	return get_sbm_dataset(N)
	from os import walk
	from os.path import join
	from scipy.io import loadmat
	from warnings import warn
	path = DATASET_PATH + dataset  +"/mat/"
	labels:list[int] = []
	gs:list[nx.Graph] = []
	for root, _dirs, files in walk(path):
		for file in files:
			file_ppties  = file.split("_")
			# The label token must be followed by its value, hence the last token is not a candidate.
			candidates = file_ppties[:-1]
			if label not in candidates:
				warn(f"Cannot find label {label} on file {file}.")
				# Skip the file: keeping the graph without a label would misalign gs and labels.
				continue
			labels.append(file_ppties[candidates.index(label) + 1])
			# Use the walked directory so that files in sub-folders are opened at their real location.
			adj_mat = np.array(loadmat(join(root, file))['A'], dtype=np.float32)
			gs.append(nx.Graph(adj_mat))
	return gs, labels


def get_modelnet(version=10):
	"""
	Placeholder for a ModelNet loader: only imports the torch_geometric helpers and returns None.
	The `version` argument is currently unused.
	"""
	from torch_geometric.transforms import FaceToEdge
	from torch_geometric.datasets import ModelNet
	return None

def read_off_file(path:str, dtype=np.ndarray, distance = lambda x,y : np.linalg.norm(x-y)):
	file = open(path, "r")
	is_empty = lambda line : len(line) == 0 or line[0] == '#'
	lines:list[str] = file.readlines()
	i:int = 0
	nlines = len(lines)
	def next_line()->str | None:  # TODO : turn that into a real generator
		nonlocal i 
		if i > nlines:	return None
		line = lines[i].strip()		
		while (not line is None) and i < nlines and is_empty(line):	line = next_line()
		i = i+1
		if line is None:	warn("File ended too quickly.")
		# print(line)
		return line
	if 'OFF' != next_line():
		warn('Not a valid OFF header')
		return
	line = next_line()
	if line is None:
		warn("Non valid file")
		return
	n_vertices, n_faces, n_edges = tuple([int(s) for s in line.split(' ')])
	vertices = np.array([np.array([next_line().split(" ")], dtype=float) for _ in range(n_vertices)])
	faces = np.array([np.array(next_line().split(" "), dtype=int)[1:] for _ in range(n_faces)])
	if dtype is np.ndarray:
		return vertices, faces
	if dtype is gd.SimplexTree:
		simplextree = gd.SimplexTree()
		simplextree.insert_batch(faces.T, np.zeros(len(faces)))
		for s,_ in simplextree.get_simplices():
			if len(s) != 2:	continue
			i,j = s
			x = vertices[i]
			y = vertices[j]
			simplextree.assign_filtration(s, distance(x,y))
		simplextree.make_filtration_non_decreasing()
		return simplextree
	if dtype is nx.Graph:
		return


### Graphs
def get_sbm(n1,n2,p,q)->nx.Graph:
	"""
	Samples a two-block stochastic block model graph.

	Inputs:
		n1, n2: number of nodes of the first and second block (nodes of block 2 are shifted by n1)
		p: probability of an edge between two nodes of the same block
		q: probability of an edge between nodes of different blocks
	Outputs:
		a networkx graph with n1+n2 nodes
	"""
	draw = np.random.uniform
	# draw() is evaluated before the u < v test, i.e. once per ordered pair.
	block1 = [[u,v] for u in range(n1) for v in range(n1) if (draw()<p) and u < v]
	block2 = [[u+n1,v+n1] for u in range(n2) for v in range(n2) if (draw()<p) and (u < v)]
	cross = [[u,v+n1] for u in range(n1) for v in range(n2) if (draw()<q)]
	edge_array = np.array(block1 + block2 + cross)
	graph = nx.Graph()
	graph.add_nodes_from(range(n1+n2))
	for u, v in edge_array:
		graph.add_edge(u, v)
	return graph

def get_sbm_dataset(N=100, progress=False)->tuple[list[nx.Graph], list[int]]:
	"""
	Builds a two-class synthetic dataset of stochastic block model graphs.

	Class 0 uses (n1=100, n2=50, p=0.5, q=0.1), class 1 uses (n1=75, n2=75, p=0.4, q=0.2).

	Inputs:
		N: number of graphs per class
		progress: whether to display a progress bar
	Outputs:
		graphs: the 2N graphs (class 0 first), labels: the matching 0/1 labels
	"""
	def sample(**sbm_params)->list[nx.Graph]:
		return Parallel(n_jobs=5)(delayed(get_sbm)(**sbm_params) for _ in tqdm(range(N), disable = not progress))
	first_class = sample(n1=100,n2=50,p=0.5,q= 0.1)
	second_class = sample(n1=75,n2=75, p=0.4, q=0.2)
	graphs = first_class + second_class
	labels = [0]*len(first_class) + [1]*len(second_class)
	return graphs, labels

#### graphs tools
def get_graphs(dataset:str, N:int|str="")->tuple[list[nx.Graph], list[int]]:
	"""
	Returns the graphs and labels of a dataset, using the pickle cache when it exists.

	Inputs:
		dataset: name of the dataset (folder under DATASET_PATH)
		N: suffix of the cache files (graphs{N}.pkl / labels{N}.pkl); when it is an int it is also
			forwarded as the dataset size to get_graph_dataset
	Outputs:
		graphs, labels
	"""
	graphs_path = f"{DATASET_PATH}{dataset}/graphs{N}.pkl"
	labels_path = f"{DATASET_PATH}{dataset}/labels{N}.pkl"
	if exists(graphs_path) and exists(labels_path):
		# NOTE: pickle must only be used on trusted, locally generated cache files.
		with open(graphs_path, "rb") as graphs_file:
			graphs = pickle.load(graphs_file)
		with open(labels_path, "rb") as labels_file:
			labels = pickle.load(labels_file)
		return graphs, labels
	if isinstance(N, int):
		graphs, labels = get_graph_dataset(dataset, N=N)
	else:
		graphs, labels = get_graph_dataset(dataset)
	print("Saving graphs at :", graphs_path)
	# Forward N so that the cache is written at the path that was just checked.
	set_graphs(graphs = graphs, labels = labels, dataset = dataset, N = N)
	return graphs, labels


def set_graphs(graphs:list[nx.Graph], labels:list, dataset:str, N:int|str=""):
	"""
	Saves graphs (and the filtration values they carry) and their labels into pickle files.

	Inputs:
		graphs: graphs to save
		labels: labels to save
		dataset: name of the dataset (folder under DATASET_PATH)
		N: suffix appended to the file names (graphs{N}.pkl / labels{N}.pkl)
	"""
	graphs_path = f"{DATASET_PATH}{dataset}/graphs{N}.pkl"
	labels_path = f"{DATASET_PATH}{dataset}/labels{N}.pkl"
	# Context managers guarantee the data is flushed and the handles are closed.
	with open(graphs_path, "wb") as graphs_file:
		pickle.dump(graphs, graphs_file)
	with open(labels_path, "wb") as labels_file:
		pickle.dump(labels, labels_file)
	return

def reset_graphs(dataset:str, N=100):
	"""
	Regenerates a dataset from its source and overwrites the default cache (resets filtration values on graphs).

	Inputs:
		dataset: name of the dataset
		N: dataset size forwarded to get_graph_dataset
	"""
	fresh = get_graph_dataset(dataset, N=N)
	set_graphs(fresh[0], fresh[1], dataset)
	return None


############################################## Immuno 1.5mm
def get_immuno_regions():
	"""
	Loads the 1.5mm immuno regions: one csv file per sample, stored in one folder per class.

	Outputs:
		X: list of arrays (csv content divided by 1500), randomly permuted
		labels: encoded class labels, permuted consistently with X
	"""
	base = DATASET_PATH+"1.5mmRegions/"
	samples, classes = [], []
	for label in ("FoxP3", "CD8", "CD68"):
		folder = base + label + "/"
		for _root, _dirs, files in walk(folder):
			for name in files:
				samples.append(np.array(read_csv(folder+name))/1500)
				classes.append(label)
	encoded = np.array(LabelEncoder().fit_transform(classes))
	order = np.random.permutation(len(encoded))
	return [samples[idx] for idx in order], np.array(encoded)[order]


########################################## LARGE IMMUNO
def get_immuno(i=1):
	"""
	Loads the large hypoxic region number i.

	Inputs:
		i: index of the region (file LargeHypoxicRegion{i}.csv under DATASET_PATH)
	Outputs:
		X, Y: the 'x' and 'y' columns, each divided by its maximum
		labels: encoded 'Celltype' column
	"""
	immu_dataset = read_csv(DATASET_PATH+f"LargeHypoxicRegion{i}.csv")
	# Force a float dtype: an in-place true division on integer coordinates raises a casting error.
	X = np.array(immu_dataset['x'], dtype=float)
	X /= np.max(X)
	Y = np.array(immu_dataset['y'], dtype=float)
	Y /= np.max(Y)
	labels = LabelEncoder().fit_transform(immu_dataset['Celltype'])
	return X,Y, labels


############################################## UCR
def get_UCR_dataset(dataset = "Coffee", test = False):
	"""
	Loads the train or test split of a UCR dataset stored as a tab-separated file.

	Inputs:
		dataset: name of the dataset (folder and file prefix under DATASET_PATH/UCR/)
		test: whether to load the _TEST file instead of the _TRAIN file
	Outputs:
		the columns 1 to second-to-last of the file (the last column is excluded), and the encoded first column
		NOTE(review): excluding the last column may be unintended -- confirm against the file format.
	"""
	suffix = "_TEST.tsv" if test else "_TRAIN.tsv"
	dataset_path = DATASET_PATH +"UCR/"+ dataset + "/" + dataset + suffix
	table = np.array(read_csv(dataset_path, delimiter='\t', header=None, index_col=None))
	series = table[:,1:-1]
	targets = LabelEncoder().fit_transform(table[:,0])
	return series, targets

################################################ Synthetic
def circle_pole(n_circle:int=500,n_noise:int=100, low:float= 1, high:float=1.1, k:int=3, sigma:float=0.5, pinch:bool=False)->np.ndarray:
	"""
	Samples a noisy annulus whose points concentrate around k evenly spaced angles, plus uniform noise.

	Inputs:
		n_circle: number of points sampled on the annulus
		n_noise: number of noise points, uniform in [-1.2, 1.2]^2
		low, high: inner and outer radius of the annulus
		k: number of poles (angles 2*pi*j/k)
		sigma: standard deviation of the angular gaussian noise around each pole
		pinch: whether to shrink the radius of points that are close to a pole
	Outputs:
		a shuffled array of shape (n_circle + n_noise, 2)
	"""
	def pt()->np.ndarray:
		n = np.random.normal(loc=0,scale=sigma)
		# The squared radius is uniform on [low**2, high**2], so that the radius lies in [low, high].
		r = np.sqrt(np.random.uniform(low = low**2, high = high**2)) - pinch*0.1/sigma*(1-np.abs(n))
		theta = np.random.choice(range(k)) * 2*np.pi / k + n
		return np.array([r*np.cos(theta), r* np.sin(theta)])
	# reshape keeps a valid (0,2) shape when n_circle == 0
	circle = np.array([pt() for _ in range(n_circle)], dtype=float).reshape(-1,2)
	noise = np.random.uniform(low=-1.2, high=1.2, size=(n_noise,2))
	out = np.vstack([circle, noise])
	np.random.shuffle(out)
	return out

################################################################################################# WRAPPERS
from sklearn.base import BaseEstimator, TransformerMixin



############## INTERVALS (for sliced wasserstein)

#################### To module 
class ToModule(BaseEstimator, TransformerMixin):
	"""
	Scikit-Learn transformer turning its input into a list of multiparameter persistence modules.

	Accepted inputs of transform:
		- a list of mma.PyModule,
		- a list of mma.SimplexTreeMulti,
		- a pair [function, dataset] (delayed computation): function(x) has to return an object with a
		  persistence_approximation method; the computation is then done in parallel,
		- anything else is assumed to be a list of module dumps.
	Parameters:
		n_jobs: number of parallel jobs used for the delayed computation
		dtype: mma.PyModule to output modules, None to output their dumps
		**kwargs: arguments forwarded to persistence_approximation
	"""
	def __init__(self,n_jobs=-1,dtype:mma.PyModule|None=mma.PyModule, **kwargs) -> None:
		super().__init__()
		self.persistence_args = kwargs
		self.n_jobs=n_jobs
		self.dtype:mma.PyModule|None=dtype
		return		
	def fit(self, X, y=None):
		return self
	def transform(self,X)->list[mma.PyModule]:
		if len(X) == 0:	return [] # nothing to convert, and X[0] does not exist
		input_type = type(X[0])
		if input_type.__name__ == "function":
			assert len(X) == 2
			get_st = X[0]
			dataset = X[1]
			# Modules are dumped in the workers so that the results can be sent back to the main process.
			todo = lambda x : get_st(x).persistence_approximation(**self.persistence_args).dump()
			mods = Parallel(n_jobs=self.n_jobs)(delayed(todo)(x) for x in dataset)
			if self.dtype is None:
				return mods
			return [mma.from_dump(mod) for mod in mods]
		if input_type is mma.PyModule:
			module_list = X
		elif input_type is mma.SimplexTreeMulti:
			module_list = [x.persistence_approximation(**self.persistence_args) for x in X]
		else:
			# Input assumed to be dumps already.
			if self.dtype is mma.PyModule:
				return [mma.from_dump(x) for x in X]
			# Dumps requested and dumps given: nothing to do (this used to fail on an undefined variable).
			return X
		
		if self.dtype is None:
			return [x.dump() for x in module_list]
		return module_list


################ Rips + Density simplextree
class RipsDensity2SimplexTree(BaseEstimator, TransformerMixin):
	"""
	Scikit-Learn transformer turning point clouds into 2-parameter simplex trees:
	Rips filtration on the first parameter, codensity (minus the log-density given by a kernel density
	estimation) as a lower-star filtration on parameter 1.

	With dtype=mma.SimplexTreeMulti the simplex trees are computed; with dtype=None the computation is
	delayed and [function, dataset] is returned for a later pipeline step.
	"""
	def __init__(self, bandwidth:float=1, threshold:float=np.inf, sparse:float|None=None, num_collapse:int=100, max_dimension:int=None, num_parameters:int=2, kernel:str="gaussian", dtype=mma.SimplexTreeMulti, rescale_density:bool=True) -> None:
		super().__init__()
		self.bandwidth = bandwidth
		self.threshold = threshold
		self.sparse = sparse
		self.num_collapse = num_collapse
		self.max_dimension = max_dimension
		self.num_parameters = num_parameters
		self.kernel = kernel
		self.dtype = dtype
		self.rescale_density = rescale_density
		return
	def fit(self, X:np.ndarray|list, y=None):
		# A missing max_dimension is inferred from the length of the first sample.
		if len(X) > 0 and self.max_dimension is None:
			self.max_dimension = len(X[0])
		return self
	def transform(self,X):
		density_estimator:KernelDensity=KernelDensity(bandwidth=self.bandwidth, kernel=self.kernel)
		def get_st(x)->mma.SimplexTreeMulti:
			rips = gd.RipsComplex(points = x, max_edge_length=self.threshold, sparse=self.sparse)
			# Only the graph is built here; higher simplices are handled by the collapses below.
			graph_st = rips.create_simplex_tree(max_dimension=1)
			multi_st:mma.SimplexTreeMulti=mma.SimplexTreeMulti(graph_st, num_parameters = self.num_parameters)
			density_estimator.fit(x)
			codensity = -density_estimator.score_samples(x)
			if self.rescale_density:
				codensity -= codensity.min()
				codensity /= max(1,codensity.max())
			multi_st.fill_lowerstar(codensity, parameter = 1)
			multi_st.collapse_edges(num=self.num_collapse)
			multi_st.collapse_edges(num=self.num_collapse, strong = False, max_dimension = self.max_dimension)
			return multi_st
		if self.dtype is None:
			return [get_st, X] # delay the computation for the to_module pipe
		if self.dtype is mma.SimplexTreeMulti:
			return [get_st(x) for x in X] # No parallel possible here unless Gudhi serialize simplextrees.
		warn(f"Bad ouput type {self.dtype} !")
		return



class AlphaDensity2ST(BaseEstimator, TransformerMixin):
	"""
	Scikit-Learn transformer turning point clouds into 2-parameter simplex trees:
	alpha complex filtration on the first parameter, codensity (minus the log-density given by a kernel
	density estimation) as a lower-star filtration on parameter 1.

	With dtype=mma.SimplexTreeMulti the simplex trees are computed; with dtype=None the computation is
	delayed and [function, dataset] is returned for a later pipeline step.
	"""
	def __init__(self, bandwidth:float=1, threshold:float=np.inf, max_dimension:int=1, num_parameters:int=2, kernel:str="gaussian", dtype=mma.SimplexTreeMulti, rescale_density:bool=True) -> None:
		super().__init__()
		self.bandwidth = bandwidth
		self.threshold = threshold
		self.max_dimension = max_dimension
		self.num_parameters = num_parameters
		self.kernel = kernel
		self.n_jobs = -1
		self.dtype = dtype
		self.rescale_density = rescale_density
		return
	def fit(self, X:np.ndarray|list, y=None):
		return self
	def transform(self,X):
		density_estimator:KernelDensity=KernelDensity(bandwidth=self.bandwidth, kernel=self.kernel)
		def get_st(x)->mma.SimplexTreeMulti:
			# Duplicated points are removed before building the alpha complex.
			points = np.unique(x, axis=0)
			alpha_complex = gd.AlphaComplex(points = points)
			alpha_st = alpha_complex.create_simplex_tree(max_alpha_square = self.threshold**2)
			multi_st:mma.SimplexTreeMulti=mma.SimplexTreeMulti(alpha_st, num_parameters = self.num_parameters)
			density_estimator.fit(points)
			# The density is evaluated on the points as indexed by the alpha complex.
			complex_points = [alpha_complex.get_point(i) for i, _ in enumerate(points)]
			codensity = -density_estimator.score_samples(complex_points)
			if self.rescale_density:
				codensity -= codensity.min()
				codensity /= max(1,codensity.max())
			multi_st.fill_lowerstar(codensity, parameter = 1)
			return multi_st
		if self.dtype is None:
			return [get_st, X] # delay the computation for the to_module pipe
		if self.dtype is mma.SimplexTreeMulti:
			return [get_st(x) for x in X] # No parallel possible here unless Gudhi serialize simplextrees.
		warn(f"Bad ouput type {self.dtype} !")
		return

# Multipers SimplexTree to Gudhi SimplexTree
class SimplexTreeSlice(BaseEstimator, TransformerMixin):
	"""
	Scikit-Learn transformer slicing multi-parameter simplex trees along a line, then collapsing the edges
	of the slice and expanding it up to max_dimension.

	The input of transform is either a list of mma.SimplexTreeMulti or a delayed pair [function, dataset].
	With dtype=None the computation is delayed (once again) and [function, dataset] is returned.
	"""
	def __init__(self, basepoint:np.ndarray|list[float]|None=None, parameter:int=0, dtype=gd.simplex_tree.SimplexTree, num_collapse:int=100, max_dimension:int=None) -> None:
		super().__init__()
		self.basepoint = basepoint
		self.parameter = parameter
		self.dimension:int|None = len(basepoint) if basepoint is not None else None
		self.dtype = dtype
		self.num_collapse = num_collapse
		self.max_dimension = max_dimension
		return
	def fit(self, X:list[mma.SimplexTreeMulti], y=None):
		# A missing max_dimension is inferred from the length of the first sample.
		if len(X) > 0 and self.max_dimension is None:
			self.max_dimension = len(X[0])
		return self
	def transform(self,X:list[mma.SimplexTreeMulti]):
		# The line is frozen now; num_collapse and max_dimension are read when the slice is computed.
		line_basepoint, line_parameter = self.basepoint, self.parameter
		def slice_one(st:mma.SimplexTreeMulti):
			sliced_st = st.project_on_line(basepoint=line_basepoint, parameter=line_parameter)
			sliced_st.collapse_edges(nb_iterations=self.num_collapse)
			sliced_st.expansion(self.max_dimension)
			return sliced_st
		if type(X[0]) is mma.SimplexTreeMulti:
			dataset, todo = X, slice_one
		else:
			previous_step, dataset = X
			def todo(x): # Delays once again the computation
				return slice_one(previous_step(x))
		if self.dtype is None:
			return [todo, dataset]
		return [todo(st) for st in dataset]



############################## MULTIPERS IMGS
class Module2Image(BaseEstimator, TransformerMixin):
	"""
	Scikit-Learn transformer turning modules (or module dumps) into flattened multiparameter persistence images.

	For each module, the images of all the requested degrees are flattened and concatenated.
	If no box is given, fit sets it to the bounding box of the training modules.
	"""
	def __init__(self, bandwidth:float=0.1, resolution=[50,50], normalize:bool=False, degrees:list[int] = [0,1], p:float=0, box=None, plot:bool=False, n_jobs=1) -> None:
		super().__init__()
		self.bandwidth:float = bandwidth
		self.resolution:list[int] = resolution
		self.normalize:bool = normalize
		self.p = p
		self.degrees = degrees
		self.box = box
		self.plot = plot
		self.n_jobs = n_jobs
		return
	def fit(self, X, y=None):
		if self.box is not None:
			return self
		modules = X if type(X[0]) is mma.PyModule else [mma.from_dump(mod) for mod in X]
		lower = np.min([mod.get_bottom() for mod in modules], axis=0)
		upper = np.max([mod.get_top() for mod in modules], axis=0)
		self.box = [lower,upper]
		return self
	def transform(self,X)->list[np.ndarray]:
		def vectorize(mod):
			per_degree = [
				mod.image(bandwidth=self.bandwidth,p=self.p, resolution = self.resolution, normalize = self.normalize, degree=degree, plot=self.plot, box=self.box).flatten()
				for degree in self.degrees
			]
			return np.concatenate(per_degree).flatten()
		if type(X[0]) is mma.PyModule:
			return [vectorize(mod) for mod in X]
		# Dumps can be sent to other processes: rebuild the module in the worker.
		def vectorize_dump(dump):
			return vectorize(mma.from_dump(dump))
		return Parallel(n_jobs=self.n_jobs)(delayed(vectorize_dump)(dump) for dump in X)

################################# MULTIPERS LANDSCAPES
class Module2Landscape(BaseEstimator, TransformerMixin):
	"""
	Scikit-Learn transformer turning modules (or module dumps) into flattened multiparameter landscapes.

	For each module and each degree, the landscapes of indices ks are aggregated with phi along axis 0,
	flattened, and the results of all degrees are concatenated.
	If no box is given, fit sets it to the bounding box of the training modules.
	NOTE(review): the box is stored but not forwarded to the landscape computation -- confirm this is intended.
	"""
	def __init__(self, resolution=[100,100], degrees:list[int]|None = [0,1], ks:Iterable[int]=range(5), phi:Callable = np.sum, box=None, plot:bool=False, n_jobs=1) -> None:
		super().__init__()
		self.resolution:list[int] = resolution
		self.degrees = degrees
		self.ks = ks
		self.phi = phi # Has to have a axis=0 !
		self.box = box
		self.plot = plot
		self.n_jobs = n_jobs
		return
	def fit(self, X, y=None):
		if self.box is not None:
			return self
		modules = X if type(X[0]) is mma.PyModule else [mma.from_dump(mod) for mod in X]
		lower = np.min([mod.get_bottom() for mod in modules], axis=0)
		upper = np.max([mod.get_top() for mod in modules], axis=0)
		self.box = [lower,upper]
		return self
	def transform(self,X)->list[np.ndarray]:
		if len(X) <= 0:	return
		def vectorize(mod):
			per_degree = [
				self.phi(mod.landscapes(ks=self.ks, resolution = self.resolution, degree=degree, plot=self.plot), axis=0).flatten()
				for degree in self.degrees
			]
			return np.concatenate(per_degree).flatten()
		if type(X[0]) is mma.PyModule:
			return [vectorize(mod) for mod in X]
		# Dumps can be sent to other processes: rebuild the module in the worker.
		def vectorize_dump(dump):
			return vectorize(mma.from_dump(dump))
		return Parallel(n_jobs=self.n_jobs)(delayed(vectorize_dump)(dump) for dump in X)

###################################### Euler surface
# TODO





###################################### OLD MULTIPERS IMAGE
# def multipersistence_image(decomposition, bnds=None, resolution=[100,100], return_raw=False, bandwidth=1., power=1., line_weight=lambda x: 1):
class Module2OldImage(BaseEstimator, TransformerMixin):
	"""
	Scikit-Learn transformer turning modules (or module dumps) into flattened multiparameter persistence
	images computed with multipersistence_image, from the barcodes of the module along nlines lines.

	If no box is given, fit sets it to the bounding box of the training modules.
	"""
	def __init__(self, resolution=[100,100], degree:int = 1, nlines:int=100, box=None, plot:bool=False, bandwidth:float=1., power=1., n_jobs=1) -> None:
		super().__init__()
		self.resolution:list[int] = resolution
		self.degree = degree
		self.box = box
		self.nlines = nlines
		self.plot = plot
		self.bandwidth = bandwidth
		self.power = power
		self.n_jobs = n_jobs
		return
	def fit(self, mods:list[mma.PyModule], y=None):
		if self.box is not None:
			return self
		modules = mods if type(mods[0]) is mma.PyModule else [mma.from_dump(mod) for mod in mods]
		lower = np.min([mod.get_bottom() for mod in modules], axis=0)
		upper = np.max([mod.get_top() for mod in modules], axis=0)
		self.box = [lower,upper]
		return self
	def transform(self,X)->list[np.ndarray]:
		if len(X) <= 0:	return
		def vectorize(mod):
			barcodes = mod.barcodes(num=self.nlines, degree = self.degree, threshold=1, box=self.box).to_multipers()
			# The box [[xm,ym],[xM,yM]] is flattened column-wise into the bounds [xm,xM,ym,yM].
			bounds = np.array(self.box).flatten("F")
			multipers_img = multipersistence_image(barcodes, bnds=bounds, bandwidth = self.bandwidth, power = self.power, resolution=self.resolution)
			if self.plot:
				(x_min, y_min), (x_max, y_max) = (self.box[0][0], self.box[0][1]), (self.box[1][0], self.box[1][1])
				extent = [x_min, x_max, y_min, y_max]
				aspect = (x_max-x_min) / (y_max-y_min)
				plt.imshow(multipers_img, extent=extent, aspect=aspect, origin="lower")
				plt.colorbar()
			return multipers_img.flatten()
		if type(X[0]) is mma.PyModule:
			return [vectorize(mod) for mod in X]
		# Dumps can be sent to other processes: rebuild the module in the worker.
		def vectorize_dump(dump):
			return vectorize(mma.from_dump(dump))
		return Parallel(n_jobs=self.n_jobs)(delayed(vectorize_dump)(dump) for dump in X)
 



def get_multiper_imgs(mod:mma.PyModule | list,degree=1, box:None|np.ndarray | list=None, nlines=100, plot:bool=False, **kwargs)->np.ndarray:
    """
    Computes the multiparameter persistence image of a module with multipersistence_image.

    Inputs:
        mod: a module, or the dump (list) of a module
        degree: homological degree of the barcodes
        box: bounding box [[xm,ym],[xM,yM]]; defaults to the box of the module
        nlines: number of lines used to compute the barcodes
        plot: whether to show the image with matplotlib
        **kwargs: arguments forwarded to multipersistence_image
    Outputs:
        the image, as a 2D numpy array
    """
    # `mod is list` compared mod with the list type itself and was always False.
    if isinstance(mod, list):
        mod:mma.PyModule = mma.from_dump(mod)
    if box is None:
        box = mod.get_box()
    bcs = mod.barcodes(num=nlines, degree = degree, threshold=1, box=box).to_multipers()
    # The box is flattened column-wise into the bounds [xm,xM,ym,yM].
    multipers_img = multipersistence_image(bcs, bnds=np.array(box).flatten("F"), **kwargs)
    if plot:
        extent = [box[0][0], box[1][0], box[0][1], box[1][1]]
        aspect = (box[1][0]-box[0][0]) / (box[1][1]-box[0][1])
        plt.imshow(multipers_img, extent=extent, aspect=aspect, origin="lower")
        plt.colorbar()
    return multipers_img

############################################################################################### ACCURACIES HELPERS
def kfold_acc(cls,x,y, k:int=10, clsn=None):
	"""
	Evaluates several classifiers with a shuffled stratified k-fold cross validation.

	Inputs:
		cls: list of classifiers (with fit and score methods)
		x, y: indexable data and labels
		k: number of folds
		clsn: names of the classifiers, defaults to their indices
	Outputs:
		one string per classifier, giving the mean and standard deviation of its accuracy (in percent)
	"""
	if clsn is None:
		clsn = range(len(cls))
	from sklearn.model_selection import StratifiedKFold as sKFold
	accuracies = np.zeros((len(cls), k))
	folds = sKFold(k, shuffle=True).split(x,y)
	for fold,(train_idx, test_idx) in enumerate(tqdm(folds, total=k, desc="Computing kfold")):
		# The split does not depend on the classifier: build it once per fold.
		xtrain = [x[t] for t in train_idx]
		ytrain = [y[t] for t in train_idx]
		xtest = [x[t] for t in test_idx]
		ytest = [y[t] for t in test_idx]
		for j, cl in enumerate(cls):
			cl.fit(xtrain, ytrain)
			accuracies[j][fold] = cl.score(xtest, ytest)
	return [f"Classifier {cl_name} : {np.mean(acc*100).round(decimals=3)}% ±{np.std(acc*100).round(decimals=3)}" for cl_name,acc in zip(clsn, accuracies)]
	


def kfold_to_csv(X,Y,cl, cln:str, k:int=10, dataset:str = "", filtration:str = "", verbose:bool=True):
	"""
	Runs a shuffled k-fold cross validation of a classifier and appends the result to result_{dataset}.csv.

	Inputs:
		X, Y: indexable data and labels
		cl: classifier (with fit and score methods)
		cln: name of the classifier / pipeline, written in the csv
		k: number of folds
		dataset: name of the dataset, written in the csv and used in the file name
		filtration: name of the filtration, written in the csv
		verbose: whether to print the accuracy of each fold
	"""
	import pandas as pd
	from sklearn.model_selection import KFold
	splits = KFold(k, shuffle=True).split(X,Y)
	accuracies = np.zeros(k)
	for fold,(train_idx, test_idx) in enumerate(tqdm(splits, total=k, desc="Computing kfold")):
		cl.fit([X[t] for t in train_idx], [Y[t] for t in train_idx])
		accuracies[fold] = cl.score([X[t] for t in test_idx], [Y[t] for t in test_idx])
		if verbose:	print(f"step {fold} : {accuracies[fold]}", flush=True)
	file_path:str = f"result_{dataset}.csv"
	columns:list[str] = ["dataset", "filtration", "pipeline", "cv", "mean", "std"]
	# Previous results are kept: the new line is appended to the existing table.
	df:pd.DataFrame = pd.read_csv(file_path) if exists(file_path) else pd.DataFrame(columns= columns)
	summary = [dataset, filtration, cln, k, np.mean(accuracies), np.std(accuracies)]
	new_line:pd.DataFrame = pd.DataFrame([summary], columns = columns)
	pd.concat([df, new_line]).to_csv(file_path, index=False)





def to_csv(X,Y,cl, cln:str, k:int=10, dataset:str = "", filtration:str = "", verbose:bool=True):
	"""
	Alias of kfold_to_csv, kept for backward compatibility.

	Runs a shuffled k-fold cross validation of the classifier cl on (X, Y) and appends the mean and the
	standard deviation of the accuracies to result_{dataset}.csv.
	"""
	# This function used to be a verbatim copy of kfold_to_csv: delegate to keep a single implementation.
	return kfold_to_csv(X, Y, cl, cln, k=k, dataset=dataset, filtration=filtration, verbose=verbose)











######################################### MULTIPERS

import sys
import os
import numpy as np
import matplotlib.pyplot as plt
import matplotlib
import time
import gudhi as gd
from sklearn.preprocessing import MinMaxScaler
from sklearn.base import BaseEstimator, TransformerMixin
import gudhi.representations as sktda
from sklearn.neighbors import KDTree
import math
import random
from concurrent import futures
from joblib import Parallel, delayed

#from dionysus_vineyards import ls_vineyards as lsvine
#from custom_vineyards import ls_vineyards as lsvine

def DTM(X,query_pts,m):
	"""
	Code for computing distance to measure. Taken from https://github.com/GUDHI/TDA-tutorial/blob/master/Tuto-GUDHI-DTM-filtrations.ipynb

	Inputs:
		X: point cloud, one point per row
		query_pts: points on which the distance to measure is evaluated
		m: mass parameter; floor(m * number of points) + 1 nearest neighbors are used
	Outputs:
		one value per query point: root mean squared distance to its nearest neighbors in X
	"""
	n_neighbors = math.floor(m*X.shape[0])+1
	tree = KDTree(X, leaf_size=30, metric='euclidean')
	neighbor_dists, _neighbor_ids = tree.query(query_pts, n_neighbors, return_distance=True)
	mean_squared = np.sum(neighbor_dists*neighbor_dists,axis=1) / n_neighbors
	return np.sqrt(mean_squared)


def recursive_insert(st, base_splx, splx, name_dict, filt):
	"""
	Inserts in st the chains of the barycentric subdivision that start with base_splx and continue
	through splx and its successive faces.

	Inputs:
		st: simplex tree of the subdivision (vertices already inserted)
		base_splx: subdivision vertices chosen so far
		splx: current simplex of the original complex (list of vertices)
		name_dict: maps a simplex (as a tuple) to its vertex in the subdivision
		filt: largest filtration value met so far along the chain
	"""
	if len(splx) == 1:
		# End of the chain: insert it with the accumulated filtration value.
		st.insert(base_splx + [name_dict[tuple(splx)]], filt)
		return
	for idx, _ in enumerate(splx):
		face = splx[:idx] + splx[idx+1:]
		chain = base_splx + [name_dict[tuple(splx)]]
		face_value = st.filtration([name_dict[tuple(face)]])
		recursive_insert(st, chain, face, name_dict, max(filt, face_value))

def barycentric_subdivision(st, list_splx=None, use_sqrt=False, use_mean=False):
	"""
	Code for computing the barycentric subdivision of a Gudhi simplex tree.
	
	Inputs:
		st: input simplex tree
		list_splx: a list of (simplex, filtration) pairs of st (useful if you want to give specific names for the barycentric subdivision simplices)
		use_sqrt: whether to take the square root for the barycentric subdivision filtration values (useful if st was computed from Gudhi AlphaComplex for instance)
		use_mean: whether to take the mean of the vertices for the new vertex value (useful for refining lower star for instance)
	Outputs:
		bary: barycentric subdivision of st
	"""
	bary = gd.SimplexTree()
	vertex_of = {}

	# One vertex of the subdivision per simplex, named by its rank in the enumeration.
	source = st.get_filtration() if list_splx is None else list_splx
	for new_vertex, (splx, f) in enumerate(source):
		if use_sqrt:
			value = np.sqrt(f)
		elif use_mean:
			value = np.mean([st.filtration([v]) for v in splx])
		else:
			value = f
		bary.insert([new_vertex], value)
		vertex_of[tuple(splx)] = new_vertex

	# Higher simplices of the subdivision are the chains of faces of each non-vertex simplex.
	for splx, _ in st.get_filtration():
		if len(splx) != 1:
			recursive_insert(bary, [], splx, vertex_of, bary.filtration([vertex_of[tuple(splx)]]))

	return bary

def gudhi_line_diagram(simplextree, F, homology=0, extended=False, essential=False, mode="Gudhi"):
	"""
	Wrapper code for computing the lower-star filtration of a simplex tree.

	Inputs:
		simplextree: input simplex tree
		F: array containing the filtration values of the simplex tree vertices
		homology: homological dimension
		extended: whether to compute extended persistence
		essential: whether to keep features with infinite ordinates
		mode: method for giving the simplex tree: either in the native Gudhi format ("Gudhi"), or as a list of numpy arrays containing the simplices in each dimension ("Numpy")
	Outputs:
		dgm: a Numpy array containing the persistence diagram points
	Raises:
		ValueError: if mode is neither "Gudhi" nor "Numpy"
	"""
	if mode not in ("Gudhi", "Numpy"):
		# Previously an unknown mode failed later on an undefined variable.
		raise ValueError(f"Unknown mode {mode}: expected \"Gudhi\" or \"Numpy\".")

	# Work on a copy of the combinatorial structure: the filtration values are overwritten below.
	st = gd.SimplexTree()
	if mode == "Gudhi":
		for (s,_) in simplextree.get_filtration():	st.insert(s)
	else:
		for ls in simplextree:
			for s in range(len(ls)):	st.insert([v for v in ls[s,:]])
	
	# Lower-star filtration: vertices get their value in F, the other simplices the max of their vertices.
	for (s,_) in st.get_filtration():	st.assign_filtration(s, -1e10)
	for (v,_) in st.get_skeleton(0):	st.assign_filtration(v, F[v[0]])

	st.make_filtration_non_decreasing()
	if extended:	
		st.extend_filtration()
		dgms = st.extended_persistence(min_persistence=1e-10)
		dgms = [dgm for dgm in dgms if len(dgm) > 0]
		# Points are reordered so that the first coordinate is the smallest one.
		ldgm = [[np.array([[    min(pt[1][0], pt[1][1]), max(pt[1][0], pt[1][1])    ]]) for pt in dgm if pt[0] == homology] for dgm in dgms]
		ldgm = [ np.vstack(dgm) for dgm in ldgm if len(dgm) > 0]
		dgm = np.vstack(ldgm) if len(ldgm) > 0 else np.empty([0,2])
	else:		
		st.persistence()
		dgm = st.persistence_intervals_in_dimension(homology)

	if not essential:	dgm = dgm[np.ravel(np.argwhere(dgm[:,1] != np.inf)),:]
	return dgm


def gudhi_matching(dgm1, dgm2):
	"""
	Code for computing a partial matching between two persistence diagrams.

	Essential points (infinite death) are matched together by increasing birth; finite points are matched
	with the optimal matching of gudhi's wasserstein_distance (order=1, internal_p=2).

	Inputs:
		dgm1: first persistence diagram
		dgm2: second persistence diagram
	Outputs:
		mtc: a Numpy array containing the partial matching between the inputs
		(pairs of indices, -1 meaning "matched to the diagonal")
	"""
	import gudhi.wasserstein
	finite1, essential1 = np.flatnonzero(dgm1[:,1] != np.inf), np.flatnonzero(dgm1[:,1] == np.inf)
	finite2, essential2 = np.flatnonzero(dgm2[:,1] != np.inf), np.flatnonzero(dgm2[:,1] == np.inf)

	# Essential points: i-th smallest birth of dgm1 with i-th smallest birth of dgm2.
	order1 = essential1[np.argsort(dgm1[essential1][:,0])]
	order2 = essential2[np.argsort(dgm2[essential2][:,0])]
	mtci = np.hstack([order1[:,np.newaxis], order2[:,np.newaxis]])

	# Finite points: indices returned by gudhi are relative to the finite sub-diagrams.
	_, local_matching = gd.wasserstein.wasserstein_distance(dgm1[finite1], dgm2[finite2], matching=True, order=1, internal_p=2)
	mtcf = []
	for row in range(len(local_matching)):
		left, right = local_matching[row,0], local_matching[row,1]
		if left == -1:
			pair = [ -1, finite2[right] ]
		elif right == -1:
			pair = [ finite1[left], -1 ]
		else:
			pair = [ finite1[left], finite2[right] ]
		mtcf.append(np.array([pair]))
	if len(mtcf) > 0:
		return np.vstack([mtci, np.vstack(mtcf)])
	return mtci



def intersect_boundaries(summand, bnds, visu=False):
	"""
	Clips the bars of a summand to a bounding rectangle.

	Inputs:
		summand: array with one bar per row, as [x_start, x_end, y_start, y_end]
		bnds: bounding rectangle, as [xm, xM, ym, yM]
		visu: whether to plot the clipped bars
	Outputs:
		array of the clipped bars (same layout); bars with no extent along x or y, and bars that do not
		cross the rectangle, are removed
	"""
	xm, xM, ym, yM = bnds[0], bnds[1], bnds[2], bnds[3]

	# Select good bars: non-zero extent along both axes
	summand = summand[np.abs(summand[:,1]-summand[:,0]) > 0.]
	summand = summand[np.abs(summand[:,3]-summand[:,2]) > 0.]

	# Each bar is parametrized by t in [0,1]; compute the t at which it crosses each side of the rectangle
	x0, x1, y0, y1 = summand[:,0:1], summand[:,1:2], summand[:,2:3], summand[:,3:4]
	inv_dx, inv_dy = 1./(x1-x0), 1./(y1-y0)
	tx_a, tx_b = np.multiply(xm-x0, inv_dx), np.multiply(xM-x0, inv_dx)
	ty_a, ty_b = np.multiply(ym-y0, inv_dy), np.multiply(yM-y0, inv_dy)
	# Entry / exit parameters of the bar in the rectangle
	t_in = np.maximum(np.minimum(tx_a, tx_b), np.minimum(ty_a, ty_b))
	t_out = np.minimum(np.maximum(tx_a, tx_b), np.maximum(ty_a, ty_b))

	# Keep the bars whose intersection with the rectangle is non-empty and overlaps [0,1]
	keep = ((t_out - t_in > 0.) & (t_in < 1.) & (t_out > 0.))[:,0]
	summand, t_in, t_out = summand[keep], t_in[keep], t_out[keep]
	t_in = np.maximum(t_in, np.zeros(t_in.shape))
	t_out = np.minimum(t_out, np.ones(t_in.shape))

	x0, x1, y0, y1 = summand[:,0:1], summand[:,1:2], summand[:,2:3], summand[:,3:4]
	summand = np.hstack([
		(1.-t_in) * x0 + t_in * x1,
		(1.-t_out) * x0 + t_out * x1,
		(1.-t_in) * y0 + t_in * y1,
		(1.-t_out) * y0 + t_out * y1,
	])

	if visu:
		plt.figure()
		for bar in range(len(summand)):	plt.plot([summand[bar,0], summand[bar,1]], [summand[bar,2], summand[bar,3]], c="red")
		plt.xlim(xm-1., xM+1.)
		plt.ylim(ym-1., yM+1.)
		plt.show()

	return summand










def persistence_image(dgm, bnds, resolution=[100,100], return_raw=False, bandwidth=1., power=1.):
	"""
	Code for computing 1D persistence images.

	Inputs:
		dgm: persistence diagram, one (birth, death) point per row
		bnds: bounds of the image, as [xm, xM, ym, yM]
		resolution: number of pixels for each image axis
		return_raw: whether to return the raw weights and distances of each point instead of the image
		bandwidth: image bandwidth
		power: exponent for the weight (persistence) of each point
	Outputs:
		image (as a numpy array) if return_raw is False, otherwise [list of weights, list of distance images]
	"""
	xm, xM, ym, yM = bnds[0], bnds[1], bnds[2], bnds[3]
	grid_x, grid_y = np.meshgrid(np.linspace(xm, xM, resolution[0]), np.linspace(ym, yM, resolution[1]))
	# A trailing axis is added to broadcast the grid against the points of the diagram
	grid_x, grid_y = grid_x[:,:,np.newaxis], grid_y[:,:,np.newaxis]

	births, deaths = np.reshape(dgm[:,0], [1,1,-1]), np.reshape(dgm[:,1], [1,1,-1])
	persistence = np.abs(deaths-births)
	distances = np.sqrt((grid_x-births)**2+(grid_y-deaths)**2)

	if return_raw:
		raw_weights = [persistence[0,0,pt] for pt in range(persistence.shape[2])]
		raw_distances = [distances[:,:,pt] for pt in range(distances.shape[2])]
		return [raw_weights, raw_distances]

	# Weighted sum of the gaussians centered on the points of the diagram
	return (np.multiply(persistence**power, np.exp(-distances**2/bandwidth))).sum(axis=2)

class PersistenceImageWrapper(BaseEstimator, TransformerMixin):
	"""
	Scikit-Learn wrapper for cross-validating 1D persistence images.

	The input of transform is, for each sample and each filtration, the raw output [weights, distances] of
	persistence_image (return_raw=True). Images are rebuilt with the bandwidth bdw and the weight exponent
	power, then downsampled by summing blocks of step x step pixels.
	"""
	def __init__(self, bdw=1., power=0, step=1):
		self.bdw, self.power, self.step = bdw, power, step

	def fit(self, X, y=None):
		return self

	def transform(self, X):
		bdw = self.bdw
		per_filtration = []
		for nf in range(len(X[0])):
			rows = []
			for sample in X:
				weights, dists = sample[nf][0], sample[nf][1]
				if len(weights) > 0:
					im = sum([(weights[i]**self.power)*np.exp(-dists[i]**2/bdw) for i in range(len(weights))])
				else:
					im = np.zeros([50,50])
				rows.append(np.reshape(im, [1,-1]))
			Y = np.vstack(rows)
			res = int(np.sqrt(Y.shape[1]))
			nr = int(res/self.step)
			# Sum the pixels by groups of `step` along one axis, then along the other one
			pooled = np.reshape(Y,[-1,res,nr,self.step]).sum(axis=3)
			pooled = np.transpose(pooled,(0,2,1))
			pooled = np.reshape(pooled,[-1,nr,nr,self.step]).sum(axis=3)
			pooled = np.transpose(pooled,(0,2,1))
			per_filtration.append(np.reshape(pooled,[-1,nr**2]))

		return np.hstack(per_filtration)










def multipersistence_image(decomposition, bnds=None, resolution=[100,100], return_raw=False, bandwidth=1., power=1., line_weight=lambda x: 1):
	"""
	Code for computing Multiparameter Persistence Images.

	Inputs:
		decomposition: vineyard decomposition as provided with interlevelset_multipersistence or sublevelset_multipersistence
		bnds: bounding rectangle [xm, xM, ym, yM]. NaN entries (or bnds=None) are replaced with the bounds of the decomposition
		resolution: number of pixels for each image axis
		return_raw: whether to return the raw images and weights for each summand (useful to save time when cross validating the image parameters) or the usual image
		bandwidth: image bandwidth
		power: exponent for summand weight
		line_weight: weight function for the lines in the vineyard decomposition

	Outputs:
		image (as a numpy array) if return_raw is False otherwise list of images and weights for each summand 
	"""
	# The default bnds=None used to crash in np.isnan: treat it as "all bounds unknown".
	if bnds is None:
		bnds = [np.nan, np.nan, np.nan, np.nan]
	if np.any(np.isnan(np.array(bnds))):
		# Unknown bounds are computed from the endpoints of all the bars of the decomposition.
		full = np.vstack(decomposition)
		maxs, mins = full.max(axis=0), full.min(axis=0)
		bnds = list(np.where(np.isnan(np.array(bnds)), np.array([min(mins[0],mins[1]), max(maxs[0],maxs[1]), min(mins[2],mins[3]), max(maxs[2],maxs[3])]), np.array(bnds)))

	xm, xM, ym, yM = bnds[0], bnds[1], bnds[2], bnds[3]
	x = np.linspace(xm, xM, resolution[0])
	y = np.linspace(ym, yM, resolution[1])
	X, Y = np.meshgrid(x, y)
	Zfinal = np.zeros(X.shape)
	X, Y = X[:,:,np.newaxis], Y[:,:,np.newaxis]

	if return_raw:	lw, lsum = [], []

	for summand in decomposition:

		summand = intersect_boundaries(summand, bnds)

		# Compute weight: sum of the areas (Heron's formula) of the triangles between consecutive bars,
		# normalized by the area of the bounding rectangle
		if return_raw or power > 0.:
			bars   = np.linalg.norm(summand[:,[0,2]]  -summand[:,[1,3]],  axis=1)		
			consm  = np.linalg.norm(summand[:-1,[0,2]]-summand[1:,[0,2]], axis=1)
			consM  = np.linalg.norm(summand[:-1,[1,3]]-summand[1:,[1,3]], axis=1)
			diags  = np.linalg.norm(summand[:-1,[0,2]]-summand[1:,[1,3]], axis=1)
			s1, s2 = .5 * (bars[:-1] + diags + consM), .5 * (bars[1:] + diags + consm)
			weight = np.sum(np.sqrt(np.abs(np.multiply(np.multiply(np.multiply(s1,s1-bars[:-1]),s1-diags),s1-consM)))+np.sqrt(np.abs(np.multiply(np.multiply(np.multiply(s2,s2-bars[1:]),s2-diags),s2-consm))))
			weight /= ((xM-xm)*(yM-ym))	
		else:	weight = 1.

		# Compute image: distance of each pixel to the closest bar of the summand
		P00, P01, P10, P11 = np.reshape(summand[:,0], [1,1,-1]), np.reshape(summand[:,2], [1,1,-1]), np.reshape(summand[:,1], [1,1,-1]), np.reshape(summand[:,3], [1,1,-1])
		# Bars reduced to a point are discarded
		good_xidxs, good_yidxs = np.argwhere(P00 != P10), np.argwhere(P01 != P11)
		good_idxs = np.unique(np.reshape(np.vstack([good_xidxs[:,2:3], good_yidxs[:,2:3]]), [-1]))
		if len(good_idxs) > 0:
			P00, P01, P10, P11 = P00[:,:,good_idxs], P01[:,:,good_idxs], P10[:,:,good_idxs], P11[:,:,good_idxs]
			vectors = [ P10[0,0,:]-P00[0,0,:], P11[0,0,:]-P01[0,0,:] ]
			vectors = np.hstack([v[:,np.newaxis] for v in vectors])
			norm_vectors = np.linalg.norm(vectors, axis=1)
			unit_vectors = np.multiply( vectors, 1./norm_vectors[:,np.newaxis] )
			W = np.array([line_weight(unit_vectors[i,:]) for i in range(len(unit_vectors))])
			# T is the position (clipped to [0,1]) of the projection of each pixel on each bar
			T = np.maximum(np.minimum(np.multiply(np.multiply(P00-X,P00-P10)+np.multiply(P01-Y,P01-P11),1./np.square(np.reshape(norm_vectors,[1,1,-1])) ),1),0)
			distlines = np.sqrt((X-P00+np.multiply(T,P00-P10))**2+(Y-P01+np.multiply(T,P01-P11))**2 )
			Zsummand = distlines.min(axis=2)
			arglines = np.argmin(distlines, axis=2)
			weightlines = W[arglines]
			#Zsummand = np.multiply(Zsummand, weightlines)
			if return_raw:
				lw.append(weight)
				lsum.append(Zsummand)
			else:
				weight = weight**power
				Zfinal += weight * np.multiply(np.exp(-Zsummand**2/bandwidth), weightlines)

	output = [lw, lsum] if return_raw else Zfinal
	return output


def convert_summand(summand):
	"""
	Flatten a summand given as pairs of 2D points into an array with one row per bar.

	Inputs:
		summand: sequence of bars, each bar being a pair [start, end] of indexable 2D points. A bar whose start point is empty is treated as missing.

	Outputs:
		numpy array of shape (len(summand), 4) whose rows are [start[0], end[0], start[1], end[1]]. Rows of missing bars are left at zero, and an empty summand gives an array of shape (0, 4).
	"""
	new_summand = np.zeros((len(summand), 4))
	for i, bar in enumerate(summand):
		start = bar[0]
		# Test emptiness with len() instead of `== []`: comparing a numpy array with a list
		# of another length is deprecated (and an error in recent numpy versions).
		if len(start) == 0:	continue
		end = bar[1]
		new_summand[i,0] = start[0]
		new_summand[i,1] = end[0]
		new_summand[i,2] = start[1]
		new_summand[i,3] = end[1]
	return new_summand

def convert_barcodes(barcodes):
	"""
	Convert every matching of a collection of barcodes with convert_summand.

	Inputs:
		barcodes: iterable of matchings, each one in the format accepted by convert_summand

	Outputs:
		list with one converted array per matching, in the same order
	"""
	return [convert_summand(matching) for matching in barcodes]





class MultiPersistenceImageWrapper(BaseEstimator, TransformerMixin):
	"""
	Scikit-Learn wrapper for cross-validating Multiparameter Persistence Images.

	Each sample of the input is indexed first by feature, and every feature entry is a
	pair [weights, distance_maps]: weights[i] is the weight of the i-th summand and
	distance_maps[i] the matching 2D array.
	NOTE(review): this looks like the return_raw=True output of multipersistence_image -- confirm against the caller.

	Parameters:
		bdw: bandwidth dividing the squared distance maps inside the exponential
		power: exponent applied to the summand weights
		step: side of the pixel blocks that are summed together to lower the resolution (assumed to divide the image side)
	"""
	def __init__(self, bdw=1., power=0, step=1):
		self.bdw, self.power, self.step = bdw, power, step

	def fit(self, X, y=None):
		"""Nothing is learned; returns the transformer itself."""
		return self

	def transform(self, X):
		"""
		Turn raw images into flattened, block-summed images.

		Inputs:
			X: sequence of samples as described in the class docstring

		Outputs:
			numpy array with one row per sample; the flattened images of the different features are concatenated along the columns
		"""
		final = []
		bdw = self.bdw
		for nf in range(len(X[0])):
			XX = [X[idx][nf] for idx in range(len(X))]
			# A sample without summand gets an all-zero image. Its shape is taken from the first
			# non-empty sample so that all rows can be stacked whatever the resolution is;
			# (50,50) is only kept as the fallback when every sample is empty.
			empty_shape = next((np.shape(raw[1][0]) for raw in XX if len(raw[0]) > 0), (50,50))
			lmpi = []
			for raw in XX:
				weights = raw[0]
				if len(weights) > 0:
					dists = raw[1]
					im = sum([(weights[i]**self.power)*np.exp(-dists[i]**2/bdw) for i in range(len(weights))])
				else:
					im = np.zeros(empty_shape)
				lmpi.append(np.reshape(im, [1,-1]))
			Y = np.vstack(lmpi)
			# Images are assumed square: recover the side from the flattened length
			res = int(np.sqrt(Y.shape[1]))
			nr = int(res/self.step)
			# Sum the pixels over blocks of `step` along the last image axis, then (after
			# swapping the two image axes) along the other one, and swap back before flattening.
			Y = np.reshape(Y, [-1,res,nr,self.step]).sum(axis=3)
			Y = np.reshape(np.transpose(Y, (0,2,1)), [-1,nr,nr,self.step]).sum(axis=3)
			Y = np.reshape(np.transpose(Y, (0,2,1)), [-1,nr**2])
			final.append(Y)

		return np.hstack(final)










def multipersistence_landscape(decomposition, bnds, delta, resolution=[100,100], k=None, return_raw=False, power=1., **kwargs):
	"""
	Code for computing Multiparameter Persistence Landscapes.

	Inputs:
		decomposition: decomposition as provided with sublevelset_multipersistence with corner=="dg"
		bnds: bounding rectangle [xm, xM, ym, yM]; NaN entries are replaced by the extent of the bars in the decomposition
		delta: distance between consecutive lines in the vineyard decomposition. It can be computed from the second output of sublevelset_multipersistence
		resolution: number of pixels for each landscape axis
		k: number of landscapes
		return_raw: whether to return the raw landscapes and weights for each summand (useful to save time when cross validating the landscape parameters) or the usual landscape. Only used when k is None
		power: exponent for summand weight (useful if silhouettes are computed)
		kwargs: ignored

	Outputs:
		if k is None and return_raw is False: landscape as a numpy array of shape (resolution[1], resolution[0])
		if k is None and return_raw is True: list with one [landscapes, bar lengths] pair per line
		if k is not None: numpy array of shape (resolution[1], resolution[0], k)
		For an empty decomposition the result is filled with zeros (with a last axis of size 1 instead of k when k is given).
	"""
	# Nothing to draw: return zeros of the right kind without touching the bounds, so that NaN
	# bounds (which can only be inferred from bars) do not fail here. The grid built below
	# has shape (resolution[1], resolution[0]).
	if len(decomposition) == 0:
		num_rows, num_cols = resolution[1], resolution[0]
		if k is not None:	return np.zeros([num_rows, num_cols, 1])
		if return_raw:	return [  [np.zeros([num_rows, num_cols, 1]), np.zeros([1])]  ]
		return np.zeros([num_rows, num_cols])

	# Replace the NaN bounds by the extent of the bars (columns 0,1 hold x coordinates, columns 2,3 y coordinates)
	if np.any(np.isnan(np.array(bnds))):
		full = np.vstack(decomposition)
		maxs, mins = full.max(axis=0), full.min(axis=0)
		data_bnds = np.array([min(mins[0],mins[1]), max(maxs[0],maxs[1]), min(mins[2],mins[3]), max(maxs[2],maxs[3])])
		bnds = list(np.where(np.isnan(np.array(bnds)), data_bnds, np.array(bnds)))

	# Pixel grid, flattened into a list of 2D points
	xm, xM, ym, yM = bnds[0], bnds[1], bnds[2], bnds[3]
	x = np.linspace(xm, xM, resolution[0])
	y = np.linspace(ym, yM, resolution[1])
	X, Y = np.meshgrid(x, y)
	X, Y = X[:,:,np.newaxis], Y[:,:,np.newaxis]
	mesh = np.reshape(np.concatenate([X,Y], axis=2), [-1,2])

	final = []

	# Line indices (column 4) present in the decomposition
	agl = np.sort(np.unique(np.concatenate([summand[:,4] for summand in decomposition])))

	for al in agl:

		# All the bars, over all summands, lying on the current line
		tris = np.vstack( [ summand[np.argwhere(summand[:,4] == int(al))[:,0]] for summand in decomposition ] )

		if len(tris) > 0:

			tris = intersect_boundaries(tris, bnds)
			P1x, P2x, P1y, P2y = tris[:,0:1], tris[:,1:2], tris[:,2:3], tris[:,3:4]
			bars = np.linalg.norm(np.hstack([P2x-P1x, P2y-P1y]), axis=1)
			# Drop the bars of zero length
			good_idxs = np.argwhere(bars > 0)[:,0]
			P1x, P2x, P1y, P2y, bars = P1x[good_idxs], P2x[good_idxs], P1y[good_idxs], P2y[good_idxs], np.reshape(bars[good_idxs], [1,-1])
			# e1s: unit vector along (-1,1); e2s: unit vectors along the bars
			e1s, e2s = np.array([[ -delta/2, delta/2 ]]), np.hstack([ P2x-P1x, P2y-P1y ])
			e1s = np.reshape(np.multiply(e1s, 1./np.linalg.norm(e1s, axis=1)[:,np.newaxis]).T, [1,2,-1])
			e2s = np.reshape(np.multiply(e2s, 1./np.linalg.norm(e2s, axis=1)[:,np.newaxis]).T, [1,2,-1])
			# Pixel coordinates relative to the bar starts shifted by (delta/4, -delta/4)
			pts = mesh[:,:,np.newaxis] - np.reshape(np.hstack([P1x+delta/4, P1y-delta/4]).T, [1,2,-1])
			scal1, scal2 = np.multiply(pts, e1s).sum(axis=1), np.multiply(pts, e2s).sum(axis=1)
			# Inside the strip of width sqrt(2)*delta/2 around a bar the value is the distance to the
			# closest endpoint along the bar, outside it is zero
			output = np.where( (scal1 >= 0) & (scal1 < np.sqrt(2)*delta/2) & (scal2 >= 0) & (scal2 <= bars), np.minimum(scal2, bars-scal2), np.zeros(scal2.shape))
			LS = np.reshape(output, [X.shape[0], X.shape[1], len(P1x)])

			if k is None:
				if return_raw:	final.append([LS, bars])
				else:	final.append( np.multiply(LS, np.reshape(bars**power, [1,1,-1])).sum(axis=2) )
			else:
				# Pad with a zero layer, then pick for kk = 1..k the kk-th largest value over the bars
				# at every pixel (partition index num-(kk-1), clamped to 0)
				pLS = np.concatenate([np.zeros([LS.shape[0], LS.shape[1], 1]), LS], axis=2)
				num = LS.shape[2]
				final.append(np.concatenate([np.partition(pLS, kth=max(num-(kk-1),0), axis=2)[:,:,max(num-(kk-1),0):max(num-(kk-1),0)+1] for kk in range(1,k+1) ], axis=2))

	if k is None and return_raw:	return final
	# Pixelwise maximum over the lines
	return np.maximum.reduce(final)

class MultiPersistenceLandscapeWrapper(BaseEstimator, TransformerMixin):
	"""
	Scikit-Learn wrapper for cross-validating Multiparameter Persistence Landscapes.

	Parameters:
		power: exponent applied to the weights when k is None
		step: side of the pixel blocks that are summed together to lower the resolution
		k: if None, samples are lists of [landscapes, weights] pairs; otherwise samples are 3D arrays whose first k layers are summed
	"""
	def __init__(self, power=0, step=1, k=None):
		self.power = power
		self.step = step
		self.k = k

	def fit(self, X, y=None):
		"""Nothing is learned; returns the transformer itself."""
		return self

	def transform(self, X):
		"""
		Flatten and block-sum the landscapes of every feature, then concatenate the features along the columns.
		"""
		per_feature = []
		num_samples = len(X)
		for nf in range(len(X[0])):
			column = [X[i][nf] for i in range(num_samples)]
			if self.k is not None:
				# Sum of the first k layers of each sample
				rows = [ LS[:,:,:self.k].sum(axis=2).flatten()[np.newaxis,:] for LS in column ]
			else:
				# Weighted sum over the bars of each line, then pixelwise maximum over the lines
				rows = []
				for L in column:
					flat = [ np.multiply(im, np.reshape(w**self.power, [1,1,-1])).sum(axis=2).flatten()[np.newaxis,:] for [im,w] in L ]
					rows.append(np.maximum.reduce(flat))
			Y = np.vstack(rows)
			res = int(np.sqrt(Y.shape[1]))
			nr = int(res/self.step)
			# Block sums of size `step`: along the last image axis, then along the other one
			Y = np.reshape(Y, [-1,res,nr,self.step]).sum(axis=3)
			Y = np.transpose(Y, (0,2,1))
			Y = np.reshape(Y, [-1,nr,nr,self.step]).sum(axis=3)
			Y = np.transpose(Y, (0,2,1))
			per_feature.append(np.reshape(Y, [-1,nr**2]))
		return np.hstack(per_feature)










def extract_diagrams(decomposition, bnds, lines):
	"""
	Code for extracting all persistence diagrams from a decomposition.

	Inputs:
		decomposition: decomposition as provided with interlevelset_multipersistence or sublevelset_multipersistence
		bnds: bounding rectangle
		lines: lines used for computing decompositions (2D array or nested sequence; the first two columns of a row are used as the reference point of the line)

	Outputs:
		ldgms: list of persistence diagrams, one per line. The two coordinates of a diagram point are the distances
		from the reference point of the line to the two endpoints of a bar. A line without any bar (before or after
		intersecting with the bounding rectangle) gets the single point [[.5*(bnds[0]+bnds[1]), .5*(bnds[2]+bnds[3])]].
	"""
	def _placeholder():
		# Diagram used for the lines that carry no bar
		return np.array([[.5*(bnds[0]+bnds[1]), .5*(bnds[2]+bnds[3])]])

	num_lines = len(lines)
	if len(decomposition) == 0:
		return [_placeholder() for _ in range(num_lines)]

	# Allow nested sequences: the 2D indexing below needs an array
	lines = np.asarray(lines)
	mdgm = np.vstack(decomposition)
	ldgms = []

	for al in range(num_lines):

		dgm = None
		# Bars lying on line `al` (the line index is stored in column 4)
		idxs = np.argwhere(mdgm[:,4] == al)[:,0]
		if len(idxs) > 0:
			dg = intersect_boundaries(mdgm[idxs][:,:4].copy(), bnds)
			if len(dg) > 0:
				pt = np.array([[lines[al,0], lines[al,1]]])
				st, ed = dg[:,[0,2]], dg[:,[1,3]]
				dgm = np.hstack([ np.linalg.norm(st-pt, axis=1)[:,np.newaxis], np.linalg.norm(ed-pt, axis=1)[:,np.newaxis] ])

		ldgms.append(_placeholder() if dgm is None else dgm)

	return ldgms

def multipersistence_kernel(X, Y, lines, kernel, line_weight=lambda x: 1, same=False, metric=True, return_raw=False, power=1.):
	"""
	Code for computing Multiparameter Persistence Kernel.

	Inputs:
		X: first list of persistence diagrams extracted from decompositions as provided with extract_diagrams
		Y: second list of persistence diagrams extracted from decompositions as provided with extract_diagrams
		lines: lines used for computing decompositions, as provided with interlevelset_multipersistence or sublevelset_multipersistence
		kernel: function of two persistence diagrams; its values are passed through exp(-value/median) if metric == True and used as they are otherwise
		line_weight: weight function for the lines in the decomposition, called on the unit direction of each line
		same: are X and Y the same list? If so only the pairs (i, j) with j >= i of X are evaluated and mirrored
		metric: whether the values of `kernel` have to be exponentiated (see `kernel`)
		return_raw: whether to return the raw kernel matrices and weights for each line (useful to save time when cross validating the multiparam. kernel) or the usual kernel matrix
		power: exponent for line weight

	Outputs:
		kernel matrix (as a numpy array) if return_raw is False otherwise the per-line matrices (last axis indexes the lines) and the weights of the lines
	"""
	num_lines = len(lines)
	# Unit direction of every line: (columns 2,3) minus (columns 0,1), normalised
	directions = np.hstack([ lines[:,2:3]-lines[:,0:1], lines[:,3:4]-lines[:,1:2]])
	directions = np.multiply(directions, 1./np.linalg.norm(directions, axis=1)[:,np.newaxis])

	gram = np.zeros([len(X), len(Y), num_lines])
	weights = np.zeros([num_lines])

	for l in range(num_lines):
		weights[l] = line_weight(directions[l,:])
		for i in range(len(X)):
			row_dgms = X[i]
			if same:
				# Symmetric case: fill the upper triangle and mirror it
				for j in range(i, len(X)):
					col_dgms = X[j]
					gram[i,j,l] = kernel(row_dgms[l], col_dgms[l])
					gram[j,i,l] = gram[i,j,l]
			else:
				for j in range(len(Y)):
					col_dgms = Y[j]
					gram[i,j,l] = kernel(row_dgms[l], col_dgms[l])

	if metric:
		# Scale by the median of all values (1 if that median is zero) before exponentiating
		med = np.median(gram)
		if med == 0:
			med = 1
		gram = np.exp(-gram/med)

	if return_raw:
		return gram, weights
	return np.multiply( weights[np.newaxis, np.newaxis, :]**power, gram ).sum(axis=2)


class SubsampleWrapper(BaseEstimator, TransformerMixin):
	"""
	Scikit-Learn wrapper for cross-validating resolutions.

	Parameters:
		step: stride used when slicing the second axis of the input
	"""
	def __init__(self, step=1):
		self.step = step

	def fit(self, X, y=None):
		"""Nothing is learned; returns the transformer itself."""
		return self

	def transform(self, X):
		"""Return X with its second axis sliced with stride `step`, i.e. X[:, ::step]."""
		strided = slice(None, None, self.step)
		return X[:, strided]

