import mma
import numpy
import pickle
from sklearn.base import BaseEstimator, TransformerMixin
from joblib import Parallel, delayed

def _to_PyModule(x):
	"""Coerce *x* into an ``mma.PyModule``.

	Accepted inputs:
	  - an ``mma.PyModule`` instance, returned unchanged;
	  - a ``list``, treated as a module dump and passed to ``mma.from_dump``;
	  - a ``str`` or path-like object naming a pickle file; the unpickled
	    content is passed to ``mma.from_dump``.

	Raises:
		TypeError: if *x* is none of the accepted input kinds. (Previously a
			message was printed and ``None`` returned, which only failed later
			with an unrelated ``AttributeError`` in the caller.)
	"""
	import os

	if isinstance(x, mma.PyModule):
		return x
	if isinstance(x, list):
		return mma.from_dump(x)
	if isinstance(x, (str, os.PathLike)):
		# SECURITY: pickle.load can execute arbitrary code stored in the file;
		# only pass paths to files from a trusted source.
		# The context manager guarantees the file handle is closed.
		with open(x, "rb") as f:
			return mma.from_dump(pickle.load(f))
	raise TypeError(f"Cannot convert object of type {type(x).__name__!r} to PyModule")

class MMAImageWrapper(BaseEstimator, TransformerMixin):
	"""
	Scikit-Learn wrapper for cross-validating MMA Images.

	Every sample handed to ``transform`` is coerced with ``_to_PyModule``.
	It is then turned into a single flat feature vector: the results of
	``PyModule.image`` for each degree in ``dimensions`` are concatenated
	and flattened.

	Parameters
	----------
	bdw : float
		Forwarded to ``PyModule.image`` as ``bandwidth``.
	power : float
		Forwarded to ``PyModule.image`` as ``p``.
	normalize, plot, resolution
		Forwarded unchanged to ``PyModule.image``.
	dimensions : list
		Degrees for which an image is computed (one ``image`` call each).
	qx, qy : float
		Fractions of the box width / height trimmed when building the box
		handed to ``PyModule.image`` (see ``_image_box``).
	box : [[xmin, ymin], [xmax, ymax]]
		Reference box from which the trimmed image box is derived.
	"""
	def __init__(self, bdw:float=0.1, power:float=1, normalize:bool=True, resolution:list=[50,50], plot:bool=False, dimensions:list=[0,1], qx:float=0, qy:float=0, box = [[0,0], [1,1]]):
		# scikit-learn's get_params / clone expect each constructor argument
		# to be stored verbatim under an attribute of the same name.
		self.bdw = bdw
		self.power = power
		self.normalize = normalize
		self.resolution = resolution
		self.plot = plot
		self.dimensions = dimensions
		self.qx = qx
		self.qy = qy
		self.box = box

	def fit(self, X, y=None):
		"""Do nothing (no state is learned) and return ``self``."""
		return self

	def _image_box(self):
		"""Return the box passed to ``PyModule.image``, derived from ``self.box``.

		Along x, ``qx * width`` is removed from both the left and right edges.
		Along y, only the top edge moves down by ``qy * height``; the bottom
		edge stays at ``ymin``.
		"""
		# NOTE(review): the y-axis is trimmed on one side only, unlike x —
		# confirm this asymmetry is intended.
		(x_lo, y_lo), (x_hi, y_hi) = self.box
		width = x_hi - x_lo
		height = y_hi - y_lo
		return [
			[x_lo + self.qx * width, y_lo],
			[x_hi - self.qx * width, y_hi - self.qy * height],
		]

	def transform(self, X):
		"""Map each sample of *X* to a flattened, degree-concatenated image vector.

		Parameters
		----------
		X : iterable
			Samples accepted by ``_to_PyModule`` (PyModule, dump list, or
			pickle path).

		Returns
		-------
		list of numpy.ndarray
			One 1-D array per sample, in input order.
		"""
		image_box = self._image_box()
		# All samples are converted before any image is computed.
		modules = [_to_PyModule(sample) for sample in X]
		features = []
		for module in modules:
			per_degree = [
				module.image(
					bandwidth=self.bdw,
					p=self.power,
					normalize=self.normalize,
					plot=self.plot,
					cb=1,
					resolution=self.resolution,
					degree=degree,
					box=image_box,
				)
				for degree in self.dimensions
			]
			features.append(numpy.concatenate(per_degree).flatten())
		return features

class MMALandscapeWrapper(BaseEstimator, TransformerMixin):
	"""
	Scikit-Learn wrapper for cross-validating MMA landscapes.

	Every sample handed to ``transform`` is coerced with ``_to_PyModule``.
	For each degree in ``dimensions``, the output of ``PyModule.landscapes``
	is summed along axis 0. The per-degree results are then concatenated and
	flattened into one feature vector per sample.

	Parameters
	----------
	ks : list
		Forwarded to ``PyModule.landscapes`` as ``ks``.
	resolution, plot
		Forwarded unchanged to ``PyModule.landscapes``.
	dimensions : list
		Degrees for which landscapes are computed (one call each).
	box : [[xmin, ymin], [xmax, ymax]]
		Reference box from which the trimmed landscape box is derived.
	qx, qy : float
		Fractions of the box width / height trimmed when building the box
		handed to ``PyModule.landscapes``.
	"""
	def __init__(self, ks=[0],  resolution:list=[50,50], plot:bool=False, dimensions:list=[0,1], box = [[0,0], [1,1]],qx=0, qy=0):
		# scikit-learn's get_params / clone expect each constructor argument
		# to be stored verbatim under an attribute of the same name.
		self.resolution, self.plot, self.dimensions = resolution, plot, dimensions
		self.qx = qx
		self.qy = qy
		self.ks = ks
		self.box = box

	def fit(self, X, y=None):
		"""Do nothing (no state is learned) and return ``self``."""
		return self

	def transform(self, X):
		"""Map each sample of *X* to a flattened, degree-concatenated landscape vector.

		Parameters
		----------
		X : iterable
			Samples accepted by ``_to_PyModule`` (PyModule, dump list, or
			pickle path).

		Returns
		-------
		list of numpy.ndarray
			One 1-D array per sample, in input order.
		"""
		[mxf, myf], [Mxf, Myf] = self.box
		lx = Mxf - mxf
		ly = Myf - myf
		# x is trimmed by qx*lx on both sides; y only at the top by qy*ly.
		# NOTE(review): confirm the one-sided y trimming is intended.
		box_ = [[mxf + self.qx * lx, myf], [Mxf - self.qx * lx, Myf - self.qy * ly]]
		out = []
		for x in X:
			# Convert once per sample. It was previously re-converted for every
			# degree, re-reading/unpickling files and re-running from_dump each
			# time. This matches MMAImageWrapper, which reuses one module.
			module = _to_PyModule(x)
			per_degree = [
				numpy.sum(
					module.landscapes(
						ks=self.ks,
						plot=self.plot,
						resolution=self.resolution,
						degree=d,
						box=box_,
					),
					axis=0,
				)
				for d in self.dimensions
			]
			out.append(numpy.concatenate(per_degree).flatten())
		return out
