""" Data generation for autism SMART Trial example

Off-policy Policy Evaluation Under Unobserved Confounding
Ramtin Keramati, Steve Yadlowsky, Hongseok Namkoong, Emma Brunskill

The data-generation model is adapted from Appendix B of
https://www.ncbi.nlm.nih.gov/pmc/articles/PMC4876020/
This paper is referred to below as LuIn.

Variable names follow those used in Appendix B of that paper.

This simulation is designed to imitate the autism SMART trial:
https://www.sciencedirect.com/science/article/pii/S0890856714001634
"""

import numpy as np
import pandas as pd

class DataGen(object):
    r""" Class used to generate simulation data for autism SMART trials

        Attributes
        ----------
        Gamma : float
            amount of confounding injected in the simulation; in Gen the
            confounded second-stage action is chosen with probability
            sqrt(Gamma)/(1+sqrt(Gamma))
        sigma : float
            noise in the simulation, $\sigma$ in LuIn Appendix B
        theta : float
            effect size, higher means a larger effect of the adaptive policy
        conf_sigma : float
            magnitude of the per-sample effect-size shift that acts as the
            unobserved confounder (fixed to 5)
        data : np.ndarray
            float array of the 7 columns read from the .csv file
            (6 covariates followed by R)
        num_data : int
            number of rows in data

        Methods
        -------
        Gen(sample_size, RCT=False)
            Generate sample_size simulated trajectories, confounded with
            strength self.Gamma unless RCT is True
    """
    def __init__(self, address='core/data.csv', config=None):
        """ __init__, reads data and defines parameters

            Parameters
            ----------
            address : str
                path to the .csv file containing the data (covariates).
                Columns 0-5 and 9 are extracted; they are expected to be:
                    age, gender, indicator of African American,
                    indicator of Caucasian, indicator of Hispanic,
                    indicator of Asian,
                    R, indicator of a fast or slow responder
            config : dict
                dictionary with the keys 'sigma', 'theta' and 'Gamma'
                (see the class attributes). Required despite the None
                default.

            Raises
            ------
            ValueError
                if config is None
        """
        # Fail early with a clear message instead of a TypeError on config[...].
        if config is None:
            raise ValueError(
                "config must be a dict with keys 'sigma', 'theta' and 'Gamma'")

        study_data = pd.read_csv(address)
        # Cast to float so the linear models in Gen operate on a numeric
        # array rather than a possibly object-typed one.
        self.data = study_data.to_numpy()[:, [0, 1, 2, 3, 4, 5, 9]].astype(float)
        self.num_data = self.data.shape[0]

        self.sigma = config['sigma']  # noise value
        self.theta = config['theta']  # effect size
        self.Gamma = config['Gamma']  # confounding

        # (const, age, gender, aamerican, caucasian, hispanic, asian)
        self.eta_0 = np.array([29.5, -5.1, -16.3, 0, 14.3, -11.8, 0.5])
        self.eta_11 = np.array([23.46, 1.4, -3.0, 16.6, 11.1, 6.5, 22.5])
        self.eta_21 = np.array([22.758, 1.20, 4.33, 12.33, 4.00, 7.53, 7.47])

        self.eta_12 = 0.3
        self.beta_11 = -1.0
        self.eta_22 = 0.2
        self.eta_24 = 0.2
        self.eta_23 = -1.8
        self.beta_21 = -self.theta
        self.eta_31 = 2 * self.eta_21
        self.eta_32 = 2 * self.eta_22
        self.eta_33 = 2 * self.eta_23
        self.eta_34 = 2 * self.eta_24 - 1
        self.beta_31 = -2 * self.theta

        # variation in effect size (magnitude of the confounder u in Gen)
        self.conf_sigma = 5

    def Gen(self, sample_size, RCT=False):
        """ Generates data from a confounded / RCT experiment.

        A1 is drawn uniformly from {-1, 1}. Each sample gets its own effect
        size theta[i] = self.theta + u[i], where u[i] is +self.conf_sigma or
        -self.conf_sigma with equal probability; u is the unobserved
        confounder. Samples with R == 0 and A1 == 1 (presumably the slow
        responders to A1 = 1 -- confirm against the data) receive a second
        action: A2 = 1 with probability sqrt(Gamma)/(1+sqrt(Gamma)) when
        u[i] < 0 (theta[i] below self.theta), and with probability
        1/(1+sqrt(Gamma)) otherwise; else A2 = -1. All other samples get
        A2 = 0. With RCT=True the probability is 0.5 regardless of Gamma.

        self.theta is not modified, so repeated calls are independent.

        Parameters
        ----------
        sample_size : int
            number of simulations
        RCT : bool
            if data is randomized (not confounded)

        Returns
        -------
        out : float np.array [sample_size, 14], with following values along axis 1:
            out[:, 0] = 1, out[:, 1] = age, out[:, 2] = gender,
            out[:, 3] = Iaamerican, out[:, 4] = Icaucasian,
            out[:, 5] = Ihispanic, out[:, 6] = Iasian,
            out[:, 7] = R, out[:, 8] = A1, out[:, 9] = A2,
            out[:, 10:14] = Y0, Y12, Y24, Y36

        NOTE: RCT data can be generated by either setting Gamma in the config
            dict to 1.0, or set the bool RCT to True
        """
        # Confounding: per-sample shift of the effect size. Kept local so the
        # instance's theta is not overwritten (and grown) on every call.
        u = np.random.choice([self.conf_sigma, -self.conf_sigma], size=(sample_size, 1))
        theta = self.theta + u
        # Recompute the parameters given the noises.
        beta_21 = -theta
        beta_31 = -2 * theta

        # Sample sample_size rows with replacement and prepend an intercept.
        rows = np.random.choice(self.num_data, size=sample_size, replace=True)
        X = np.concatenate([np.ones((sample_size, 1)), self.data[rows, :]], axis=1)
        R = X[:, -1:]   # last column is R, kept 2-D
        X = X[:, :-1]   # intercept + 6 covariates

        Y0 = np.dot(X, self.eta_0) + np.random.normal(loc=0, scale=self.sigma, size=sample_size)
        Y0 = np.expand_dims(Y0, axis=-1)

        # A1 random assignment
        A1 = np.expand_dims(np.random.choice([-1, 1], size=sample_size), axis=-1)

        Y12 = np.expand_dims(np.dot(X, self.eta_11), -1) + self.eta_12 * Y0 + \
            self.beta_11 * A1 + np.random.normal(loc=0, scale=self.sigma, size=(sample_size, 1))

        # A2 conditional on A1, R and the confounder u.
        prob_A2 = 0.5 if RCT else np.sqrt(self.Gamma) / (1 + np.sqrt(self.Gamma))
        A2 = np.zeros((sample_size, 1))
        eligible = (R[:, 0] == 0) & (A1[:, 0] == 1)
        p_one = np.where(u[:, 0] < 0, prob_A2, 1 - prob_A2)[eligible]
        # One uniform per eligible sample, in sample order.
        draws = np.random.rand(p_one.size)
        A2[eligible, 0] = np.where(draws <= p_one, 1, -1)

        # The (1-R)*(A1+1)*A2 factor is nonzero only for eligible samples.
        Y24 = np.expand_dims(np.dot(X, self.eta_21), -1) + self.eta_22 * Y0 + \
            self.eta_23 * A1 + self.eta_24 * Y12 + \
            beta_21 * (1 - R) * (A1 + 1) * A2 + \
            np.random.normal(loc=0, scale=self.sigma, size=(sample_size, 1))

        Y36 = np.expand_dims(np.dot(X, self.eta_31), -1) + self.eta_32 * Y0 + \
            self.eta_33 * A1 + self.eta_34 * Y12 + \
            beta_31 * (1 - R) * (A1 + 1) * A2 + \
            np.random.normal(loc=0, scale=self.sigma, size=(sample_size, 1))

        return np.concatenate([X, R, A1, A2, Y0, Y12, Y24, Y36], axis=-1)