import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Random;
import java.util.Scanner;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.zip.GZIPInputStream;

import base.ColorGroup;
import base.Point;
import base.Utility;
import clust.Clustering;
import clust.Objective;
import clust.WeightedPoint;
import clust.ZObjective;
import coreset.CoresetKMeans;
import coreset.CoresetKMedian;

/**
 * Experiment driver that compares clustering coresets against the full data set.
 *
 * <p>Loads the Adult and Bank data files and rescales every coordinate by its maximum absolute
 * value. It then draws random tests, each a per-color count matrix {@code F} plus {@code k} random
 * centers {@code C}. For each test it compares {@code Clustering.evaluate} on the full data with
 * the same evaluation on coresets:
 * <ul>
 * <li>{@link CoresetKMeans} and a size-matched BICO coreset for the z = 2 objective;</li>
 * <li>{@link CoresetKMedian} for the z = 1 objective.</li>
 * </ul>
 * For every {@code eps} the worst relative error over all tests, the coreset sizes and the average
 * evaluation times are appended as a CSV row to {@code output/<name>_means.csv} or
 * {@code output/<name>_median.csv}.
 */
public class Main
{
	/** Number of worker threads used by {@link #runAll} to evaluate test cases. */
	private static final int EVAL_THREADS = 1;

	/** Zero-based columns of the Adult file that are parsed as numeric coordinates. */
	private static final int[] ADULT_NUMERIC_COLUMNS = {0, 2, 4, 10, 11, 12};

	/**
	 * Test-cache mode read by {@link #prepareTests}:
	 * <ul>
	 * <li>{@code "true"}: load tests and original objective values from {@code data/debug_<name>.data};</li>
	 * <li>{@code "gen"}: compute them and also write that file;</li>
	 * <li>any other value: compute them without caching.</li>
	 * </ul>
	 */
	String debug = "true";

	/**
	 * Keeps each point independently when a uniform draw is at most {@code prob}.
	 *
	 * <p>A fresh {@code Random(0)} is created per call, so the same input list always yields the
	 * same sample.
	 */
	private ArrayList<WeightedPoint> downSample(ArrayList<WeightedPoint> l, double prob)
	{
		ArrayList<WeightedPoint> res = new ArrayList<>();
		Random rand = new Random(0);
		for (WeightedPoint p : l)
		{
			if (rand.nextDouble() <= prob)
			{
				res.add(p);
			}
		}
		return res;
	}

	/**
	 * Divides every coordinate, in place, by the largest absolute value of that coordinate over the
	 * instance, so each coordinate ends up in [-1, 1]. Coordinates that are all zero are left alone.
	 * An empty instance is a no-op.
	 *
	 * @param instance points to rescale; their {@code data.coor} arrays are mutated
	 */
	public static void normalize(ArrayList<WeightedPoint> instance)
	{
		if (instance.isEmpty())
		{
			return;
		}
		int dim = instance.get(0).data.dim;
		for (int i = 0; i < dim; i++)
		{
			double max = 0;
			for (WeightedPoint wp : instance)
			{
				max = Math.max(max, Math.abs(wp.data.coor[i]));
			}
			if (max != 0)
			{
				for (WeightedPoint wp : instance)
				{
					wp.data.coor[i] /= max;
				}
			}
		}
	}

	/**
	 * Maps each (lower-cased) name to its position in {@code col}; these positions are the color ids.
	 */
	private static Map<String, Integer> nameToInt(String[] col)
	{
		HashMap<String, Integer> res = new HashMap<>();
		int cnt = 0;
		for (String s : col)
		{
			res.put(s.toLowerCase(Locale.ROOT), cnt++);
		}
		return res;
	}

	/**
	 * Looks up {@code key} in a category-to-color map.
	 *
	 * @throws IllegalArgumentException naming the offending line when the category is unknown
	 *         (instead of an unboxing NullPointerException)
	 */
	private static int categoryId(Map<String, Integer> map, String key, String line)
	{
		Integer id = map.get(key);
		if (id == null)
		{
			throw new IllegalArgumentException("unknown category '" + key + "' in line: " + line);
		}
		return id;
	}

	/**
	 * Loads {@code data/adult/adult.data.csv.gz}.
	 *
	 * <p>Each comma-separated line becomes a weight-1 point built from the columns in
	 * {@link #ADULT_NUMERIC_COLUMNS} (age, fnlwgt, education-num, capital-gain, capital-loss,
	 * hours-per-week). The point gets two colors:
	 * <ul>
	 * <li>sex: ids 0-1;</li>
	 * <li>marital-status: ids 2-8.</li>
	 * </ul>
	 * Blank lines are skipped.
	 *
	 * @throws IllegalArgumentException if a line has an unknown sex or marital-status value
	 */
	public ArrayList<WeightedPoint> adult() throws Exception
	{
		/*
		 * age: continuous.
		 * workclass: Private, Self-emp-not-inc, Self-emp-inc, Federal-gov, Local-gov, State-gov, Without-pay, Never-worked.
		 * fnlwgt: continuous.
		 * education: Bachelors, Some-college, 11th, HS-grad, Prof-school, Assoc-acdm, Assoc-voc, 9th, 7th-8th, 12th, Masters, 1st-4th, 10th, Doctorate, 5th-6th, Preschool.
		 * education-num: continuous.
		 * marital-status: Married-civ-spouse, Divorced, Never-married, Separated, Widowed, Married-spouse-absent, Married-AF-spouse.
		 * occupation: Tech-support, Craft-repair, Other-service, Sales, Exec-managerial, Prof-specialty, Handlers-cleaners, Machine-op-inspct, Adm-clerical, Farming-fishing, Transport-moving, Priv-house-serv, Protective-serv, Armed-Forces.
		 * relationship: Wife, Own-child, Husband, Not-in-family, Other-relative, Unmarried.
		 * race: White, Asian-Pac-Islander, Amer-Indian-Eskimo, Other, Black.
		 * sex: Female, Male.
		 * capital-gain: continuous.
		 * capital-loss: continuous.
		 * hours-per-week: continuous.
		 * native-country: United-States, Cambodia, England, Puerto-Rico, Canada, Germany, Outlying-US(Guam-USVI-etc), India, Japan, Greece, South, China, Cuba, Iran, Honduras, Philippines, Italy, Poland, Jamaica, Vietnam, Mexico, Portugal, Ireland, France, Dominican-Republic, Laos, Ecuador, Taiwan, Haiti, Columbia, Hungary, Guatemala, Nicaragua, Scotland, Thailand, Yugoslavia, El-Salvador, Trinadad&Tobago, Peru, Hong, Holand-Netherlands.
		 */
		String[] col = new String[] {"male", "female",
			"Married-civ-spouse", "Divorced",
			"Never-married", "Separated",
			"Widowed", "Married-spouse-absent", "Married-AF-spouse"};
		Map<String, Integer> colName = Main.nameToInt(col);
		ArrayList<WeightedPoint> instance = new ArrayList<>();
		try (FileInputStream file = new FileInputStream(new File("data/adult/adult.data.csv.gz"));
				Scanner scan = new Scanner(new GZIPInputStream(file)))
		{
			while (scan.hasNext())
			{
				String line = scan.nextLine();
				if (line.trim().isEmpty())
				{
					continue; // a stray blank line would otherwise fail in parseDouble("")
				}
				String[] l = line.split(",");
				double[] coor = new double[ADULT_NUMERIC_COLUMNS.length];
				for (int i = 0; i < coor.length; i++)
				{
					coor[i] = Double.parseDouble(l[ADULT_NUMERIC_COLUMNS[i]]);
				}
				int sex = categoryId(colName, l[9].trim().toLowerCase(Locale.ROOT), line);
				int marital = categoryId(colName, l[5].trim().toLowerCase(Locale.ROOT), line);
				instance.add(new WeightedPoint(new Point(coor), 1, new ColorGroup(new int[] {sex, marital})));
			}
		}
		return instance;
	}

	/**
	 * Loads {@code data/bank/bank-additional-full.csv.gz}.
	 *
	 * <p>The first and last character of each line are stripped and the rest is split on ';'. The
	 * numeric fields listed in {@code numId} (1-based) become the coordinates of a weight-1 point.
	 * The point gets two colors:
	 * <ul>
	 * <li>marital: ids 0-3;</li>
	 * <li>default: ids 4-6.</li>
	 * </ul>
	 * Blank lines are skipped.
	 *
	 * @throws IllegalArgumentException if a line has an unknown marital or default value
	 */
	public ArrayList<WeightedPoint> bank() throws Exception
	{
		/*
		 * Input variables:
		 * # bank client data:
		 *  1 - age (numeric)
		 *  2 - job : type of job (categorical: "admin.","blue-collar","entrepreneur","housemaid","management","retired","self-employed","services","student","technician","unemployed","unknown")
		 *  3 - marital : marital status (categorical: "divorced","married","single","unknown"; note: "divorced" means divorced or widowed)
		 *  4 - education (categorical: "basic.4y","basic.6y","basic.9y","high.school","illiterate","professional.course","university.degree","unknown")
		 *  5 - default: has credit in default? (categorical: "no","yes","unknown")
		 *  6 - housing: has housing loan? (categorical: "no","yes","unknown")
		 *  7 - loan: has personal loan? (categorical: "no","yes","unknown")
		 * # related with the last contact of the current campaign:
		 *  8 - contact: contact communication type (categorical: "cellular","telephone")
		 *  9 - month: last contact month of year (categorical: "jan", "feb", "mar", ..., "nov", "dec")
		 * 10 - day_of_week: last contact day of the week (categorical: "mon","tue","wed","thu","fri")
		 * 11 - duration: last contact duration, in seconds (numeric). Important note: this attribute highly affects the output target (e.g., if duration=0 then y="no"). Yet, the duration is not known before a call is performed. Also, after the end of the call y is obviously known. Thus, this input should only be included for benchmark purposes and should be discarded if the intention is to have a realistic predictive model.
		 * # other attributes:
		 * 12 - campaign: number of contacts performed during this campaign and for this client (numeric, includes last contact)
		 * 13 - pdays: number of days that passed by after the client was last contacted from a previous campaign (numeric; 999 means client was not previously contacted)
		 * 14 - previous: number of contacts performed before this campaign and for this client (numeric)
		 * 15 - poutcome: outcome of the previous marketing campaign (categorical: "failure","nonexistent","success")
		 * # social and economic context attributes
		 * 16 - emp.var.rate: employment variation rate - quarterly indicator (numeric)
		 * 17 - cons.price.idx: consumer price index - monthly indicator (numeric)
		 * 18 - cons.conf.idx: consumer confidence index - monthly indicator (numeric)
		 * 19 - euribor3m: euribor 3 month rate - daily indicator (numeric)
		 * 20 - nr.employed: number of employees - quarterly indicator (numeric)
		 *
		 * Output variable (desired target):
		 * 21 - y - has the client subscribed a term deposit? (binary: "yes","no")
		 */
		// 1-based field numbers (see the list above) of the numeric attributes used as coordinates.
		int[] numId = new int[] {
				1, 11, 12, 13, 14, 16, 17, 18, 19, 20
		};
		// NOTE(review): the keys include doubled quotes, which suggests each line is one CSV-quoted
		// field with its inner quotes escaped as "" -- confirm against the data file.
		// Built once here rather than once per line.
		Map<String, Integer> maritalColor = new HashMap<>();
		maritalColor.put("\"\"divorced\"\"", 0);
		maritalColor.put("\"\"married\"\"", 1);
		maritalColor.put("\"\"single\"\"", 2);
		maritalColor.put("\"\"unknown\"\"", 3);
		Map<String, Integer> defaultColor = new HashMap<>();
		defaultColor.put("\"\"no\"\"", 4);
		defaultColor.put("\"\"yes\"\"", 5);
		defaultColor.put("\"\"unknown\"\"", 6);

		ArrayList<WeightedPoint> instance = new ArrayList<>();
		try (FileInputStream file = new FileInputStream(new File("data/bank/bank-additional-full.csv.gz"));
				Scanner scan = new Scanner(new GZIPInputStream(file)))
		{
			while (scan.hasNext())
			{
				String line = scan.nextLine();
				if (line.trim().isEmpty())
				{
					continue; // substring(1, -1) would throw on an empty line
				}
				line = line.substring(1, line.length() - 1);
				String[] l = line.split(";");
				double[] coor = new double[numId.length];
				for (int i = 0; i < numId.length; i++)
				{
					coor[i] = Double.parseDouble(l[numId[i] - 1]);
				}
				int mar = categoryId(maritalColor, l[2], line);
				int def = categoryId(defaultColor, l[4], line);
				instance.add(new WeightedPoint(new Point(coor), 1, new ColorGroup(new int[] {mar, def})));
			}
		}
		return instance;
	}

	/**
	 * Returns the worst relative error |cor - obj| / obj over {@code cases} random center sets.
	 *
	 * <p>Here {@code obj} and {@code cor} are {@code Clustering.evaluate(points, C, O)} on the full
	 * instance and on the coreset. Each set {@code C} holds {@code k} centers drawn by
	 * {@link #getCenter}.
	 */
	double evaluateError(ArrayList<WeightedPoint> instance, ArrayList<WeightedPoint> coreset, int k, int cases, Objective O)
	{
		double err = 0;
		for (int t = 0; t < cases; t++)
		{
			ArrayList<Point> C = this.getCenter(instance, k);
			double obj = Clustering.evaluate(instance, C, O);
			double cor = Clustering.evaluate(coreset, C, O);
			err = Math.max(err, Math.abs(cor - obj) / obj);
		}
		return err;
	}

	/**
	 * Randomly splits {@code n} into {@code k} non-negative parts that sum to {@code n}, using
	 * {@code Utility.rand}.
	 *
	 * <p>When fewer than {@code k} units remain, the draw bound would be non-positive and
	 * {@code nextInt} would throw. In that case the part is set to 0 instead, which happens when a
	 * color group is smaller than {@code k}.
	 */
	private int[] randPart(int n, int k)
	{
		int[] res = new int[k];
		int remain = n;
		for (int j = 0; j < k - 1; j++)
		{
			int bound = remain - (k - j - 1);
			res[j] = bound > 0 ? Utility.rand.nextInt(bound) : 0;
			remain -= res[j];
		}
		res[k - 1] = remain;
		assert(res[k - 1] >= 0);
		return res;
	}

	/**
	 * Builds a random per-color, per-cluster count matrix {@code F[color][cluster]}.
	 *
	 * <p>The size of every color group is split over the {@code k} clusters by {@link #randPart}, and
	 * that split is added to the row of every color the group belongs to. The matrix has one row per
	 * color id in {@code 0..maxColor}, so non-contiguous ids no longer overflow the array. With
	 * contiguous ids this equals the number of distinct colors.
	 * NOTE(review): presumably the capacity constraints consumed by {@code Clustering.evaluate(X, F, C, O)}.
	 */
	int[][] getF(HashMap<ColorGroup, List<WeightedPoint>> colorMap, int k)
	{
		int maxColor = -1;
		for (ColorGroup c : colorMap.keySet())
		{
			for (Integer cc : c.c)
			{
				maxColor = Math.max(maxColor, cc);
			}
		}
		int[][] F = new int[maxColor + 1][k];
		// Same iteration order as keySet(), so the sequence of random draws is unchanged.
		for (Map.Entry<ColorGroup, List<WeightedPoint>> entry : colorMap.entrySet())
		{
			int[] tmp = randPart(entry.getValue().size(), k);
			for (Integer cc : entry.getKey().c)
			{
				for (int i = 0; i < k; i++)
				{
					F[cc][i] += tmp[i];
				}
			}
		}
		return F;
	}

	/** Draws {@code k} centers uniformly at random (with replacement) from the instance's points. */
	ArrayList<Point> getCenter(ArrayList<WeightedPoint> instance, int k)
	{
		ArrayList<Point> C = new ArrayList<>();
		for (int j = 0; j < k; j++)
		{
			C.add(instance.get(Utility.rand.nextInt(instance.size())).data);
		}
		return C;
	}

	/**
	 * Builds {@code cases} random tests.
	 *
	 * <p>Each test is an {@code Object[] {int[][] F, ArrayList<Point> C}}, with {@code F} from
	 * {@link #getF} and {@code C} from {@link #getCenter}.
	 */
	ArrayList<Object[]> getTests(ArrayList<WeightedPoint> instance, int cases, int k)
	{
		HashMap<ColorGroup, List<WeightedPoint>> colorMap = WeightedPoint.colorMap(instance);
		ArrayList<Object[]> res = new ArrayList<>();
		for (int t = 0; t < cases; t++)
		{
			int[][] F = getF(colorMap, k);
			ArrayList<Point> C = getCenter(instance, k);
			res.add(new Object[] {F, C});
		}
		return res;
	}

	/** Returns the {@code F} matrix of a test built by {@link #getTests}. */
	private static int[][] testF(Object[] test)
	{
		return (int[][]) test[0];
	}

	/** Returns the center list of a test built by {@link #getTests}. */
	@SuppressWarnings("unchecked") // getTests always stores an ArrayList<Point> at index 1
	private static ArrayList<Point> testCenters(Object[] test)
	{
		return (ArrayList<Point>) test[1];
	}

	/**
	 * Runs every task on a pool of {@link #EVAL_THREADS} threads and blocks until all have finished.
	 *
	 * <p>Waiting on each {@code Future} makes the tasks' array writes visible to the caller, which
	 * polling {@code isTerminated()} does not guarantee. It also rethrows a task failure instead of
	 * leaving a silent zero in the result arrays.
	 *
	 * @throws IllegalStateException if a task fails or the caller is interrupted (the interrupt flag
	 *         is restored)
	 */
	private static void runAll(List<Runnable> tasks)
	{
		ExecutorService ser = Executors.newFixedThreadPool(EVAL_THREADS);
		try
		{
			List<Future<?>> futures = new ArrayList<>(tasks.size());
			for (Runnable task : tasks)
			{
				futures.add(ser.submit(task));
			}
			for (Future<?> f : futures)
			{
				f.get();
			}
		}
		catch (InterruptedException e)
		{
			Thread.currentThread().interrupt();
			throw new IllegalStateException("interrupted while waiting for evaluation tasks", e);
		}
		catch (ExecutionException e)
		{
			throw new IllegalStateException("evaluation task failed", e.getCause());
		}
		finally
		{
			ser.shutdownNow();
		}
	}

	/**
	 * Evaluates {@code Clustering.evaluate(X, F, C, O)} on the coreset {@code X} for every test.
	 *
	 * @param objValue full-data objective per test, as produced by {@link #evaluateOriginal}
	 * @param objRunTime unused; kept for interface compatibility
	 * @return {@code Object[] {Double averageMillis, Double maxRelativeError}}, where the error of a
	 *         test is |cor - objValue[i]| / objValue[i]
	 */
	Object[] evaluateCoreset(List<WeightedPoint> X, int k, Objective O, ArrayList<Object[]> tests, long[] objRunTime, double[] objValue)
	{
		final int cases = tests.size();
		final double[] perTime = new double[cases];
		final double[] perErr = new double[cases];

		List<Runnable> tasks = new ArrayList<>(cases);
		for (int i = 0; i < cases; i++)
		{
			final int idx = i;
			tasks.add(() -> {
				int[][] F = testF(tests.get(idx));
				ArrayList<Point> C = testCenters(tests.get(idx));

				long t0 = System.currentTimeMillis();
				double cor = Clustering.evaluate(X, F, C, O);
				perTime[idx] = System.currentTimeMillis() - t0;

				perErr[idx] = Math.abs(cor - objValue[idx]) / objValue[idx];
				System.out.printf("evaluate coreset progress: %d/%d\n", idx, cases);
			});
		}
		runAll(tasks);

		double avgTime = 0;
		double err = 0;
		for (int i = 0; i < cases; i++)
		{
			avgTime += perTime[i];
			err = Math.max(err, perErr[i]);
		}
		return new Object[] {avgTime / cases, err};
	}

	/**
	 * Evaluates {@code Clustering.evaluate(X, F, C, O)} on the full instance for every test.
	 *
	 * @param name unused; kept for interface compatibility
	 * @return {@code Object[] {long[] runTimeMillis, double[] objValue}}, indexed like {@code tests}
	 */
	Object[] evaluateOriginal(ArrayList<WeightedPoint> X, int k, Objective O, ArrayList<Object[]> tests, String name)
	{
		final int cases = tests.size();
		final long[] runTime = new long[cases];
		final double[] objValue = new double[cases];

		List<Runnable> tasks = new ArrayList<>(cases);
		for (int i = 0; i < cases; i++)
		{
			final int idx = i;
			tasks.add(() -> {
				int[][] F = testF(tests.get(idx));
				ArrayList<Point> C = testCenters(tests.get(idx));

				long t0 = System.currentTimeMillis();
				objValue[idx] = Clustering.evaluate(X, F, C, O);
				runTime[idx] = System.currentTimeMillis() - t0;
				System.out.printf("evaluate original progress: %d/%d\n", idx, cases);
			});
		}
		runAll(tasks);
		return new Object[] {runTime, objValue};
	}

	/**
	 * Produces the tests plus their full-data run times and objective values, honoring {@link #debug}.
	 *
	 * <p>The cache file is {@code data/debug_<name>.data} and is read with Java deserialization. Only
	 * point it at files this program wrote itself.
	 *
	 * @return {@code Object[] {ArrayList<Object[]> tests, long[] objRunTime, double[] objValue}}
	 * @throws IllegalStateException if {@code debug} is {@code "true"} and the cache cannot be read
	 *         (previously this returned nulls, which failed later with a NullPointerException)
	 */
	Object[] prepareTests(ArrayList<WeightedPoint> instance, int k, int cases, Objective O, String name)
	{
		File cache = new File("data/debug_" + name + ".data");
		if (debug.equals("true"))
		{
			try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(cache)))
			{
				@SuppressWarnings("unchecked") // written by the "gen" branch below
				ArrayList<Object[]> tests = (ArrayList<Object[]>) in.readObject();
				long[] objRunTime = (long[]) in.readObject();
				double[] objValue = (double[]) in.readObject();
				return new Object[] {tests, objRunTime, objValue};
			}
			catch (IOException | ClassNotFoundException e)
			{
				throw new IllegalStateException("cannot load cached tests from " + cache, e);
			}
		}

		ArrayList<Object[]> tests = getTests(instance, cases, k);
		Object[] tmp = evaluateOriginal(instance, k, O, tests, name);
		long[] objRunTime = (long[]) tmp[0];
		double[] objValue = (double[]) tmp[1];
		if (debug.equals("gen"))
		{
			try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(cache)))
			{
				out.writeObject(tests);
				out.writeObject(objRunTime);
				out.writeObject(objValue);
			}
			catch (IOException e)
			{
				// Best effort: failing to write the cache must not abort the experiment.
				e.printStackTrace();
			}
		}
		return new Object[] {tests, objRunTime, objValue};
	}

	/** Arithmetic mean of the given millisecond timings (NaN for an empty array). */
	private static double meanMillis(long[] times)
	{
		double sum = 0;
		for (long t : times)
		{
			sum += t;
		}
		return sum / times.length;
	}

	/**
	 * Builds a BICO coreset whose size is within 5% of {@code targetSize}.
	 *
	 * <p>The size parameter is searched in {@code [1, 2 * targetSize]} with a bounded binary search;
	 * the old {@code start != end} loop could spin forever once the bounds crossed. If no parameter
	 * lands in the window, the last coreset built is returned. At least one coreset is always built.
	 * NOTE(review): the search assumes BICO's output size grows with its size parameter.
	 */
	private ArrayList<WeightedPoint> bicoMatchingSize(ArrayList<WeightedPoint> instance, int k, double e, int targetSize)
	{
		ArrayList<WeightedPoint> bicoCoreset = null;
		int lo = 1;
		int hi = Math.max(1, targetSize * 2);
		while (lo <= hi)
		{
			int mid = (lo + hi) >>> 1;
			bicoCoreset = new BICO(instance).getCoreset(k, e, mid);
			if (bicoCoreset.size() < targetSize * 0.95)
			{
				lo = mid + 1;
			}
			else if (bicoCoreset.size() > targetSize * 1.05)
			{
				hi = mid - 1;
			}
			else
			{
				break;
			}
		}
		return bicoCoreset;
	}

	/**
	 * Runs the z = 2 experiment (our k-means coreset vs. a size-matched BICO coreset) for every
	 * {@code eps} and writes one CSV row per {@code eps} to {@code output/<name>_means.csv}.
	 */
	void evaluateFairMeans(ArrayList<WeightedPoint> instance, int k, int cases, double[] eps, String name) throws Exception
	{
		System.out.println("evaluate fair means");
		Objective O = ZObjective.getObjective(2.0);

		try (PrintWriter out = new PrintWriter(new BufferedOutputStream(new FileOutputStream(new File("output/" + name + "_means.csv")))))
		{
			out.println("eps,emperr_our,emperr_bico,our_size,bico_size,obj_time,our_time,bico_time");

			Object[] tmp = this.prepareTests(instance, k, cases, O, name + "_means");
			@SuppressWarnings("unchecked") // prepareTests always stores the test list at index 0
			ArrayList<Object[]> tests = (ArrayList<Object[]>) tmp[0];
			long[] objRunTime = (long[]) tmp[1];
			double[] objValue = (double[]) tmp[2];
			double objTime = meanMillis(objRunTime);

			for (double e : eps)
			{
				System.out.printf("progress: %f\n", e);
				System.out.println("constructing coreset");
				List<WeightedPoint> ourCoreset = new CoresetKMeans(instance).getCoreset(e, k);
				System.out.println("constructing bico");
				ArrayList<WeightedPoint> bicoCoreset = bicoMatchingSize(instance, k, e, ourCoreset.size());
				System.out.println("bico finished constructing");

				Object[] our = evaluateCoreset(ourCoreset, k, O, tests, objRunTime, objValue);
				Object[] bico = evaluateCoreset(bicoCoreset, k, O, tests, objRunTime, objValue);

				double ourTime = (double) our[0];
				double ourErr = (double) our[1];
				double bicoTime = (double) bico[0];
				double bicoErr = (double) bico[1];
				out.printf("%.6f,%.6f,%.6f,%d,%d,%.6f,%.6f,%.6f", e, ourErr, bicoErr, ourCoreset.size(), bicoCoreset.size(), objTime, ourTime, bicoTime);
				out.println();
			}
		}
	}

	/**
	 * Runs the z = 1 experiment (our k-median coreset) for every {@code eps} and writes one CSV row
	 * per {@code eps} to {@code output/<name>_median.csv}.
	 */
	void evaluateFairMedian(ArrayList<WeightedPoint> instance, int k, int cases, double[] eps, String name) throws Exception
	{
		System.out.println("evaluate fair median");
		Objective O = ZObjective.getObjective(1.0);

		try (PrintWriter out = new PrintWriter(new BufferedOutputStream(new FileOutputStream(new File("output/" + name + "_median.csv")))))
		{
			out.println("eps,emperr_our,our_size,obj_time,our_time");

			Object[] tmp = this.prepareTests(instance, k, cases, O, name + "_median");
			@SuppressWarnings("unchecked") // prepareTests always stores the test list at index 0
			ArrayList<Object[]> tests = (ArrayList<Object[]>) tmp[0];
			long[] objRunTime = (long[]) tmp[1];
			double[] objValue = (double[]) tmp[2];
			double objTime = meanMillis(objRunTime);

			for (double e : eps)
			{
				System.out.printf("progress: %f\n", e);
				List<WeightedPoint> ourCoreset = new CoresetKMedian(instance).getCoreset(e, k);
				Object[] our = evaluateCoreset(ourCoreset, k, O, tests, objRunTime, objValue);

				double ourTime = (double) our[0];
				double ourErr = (double) our[1];
				out.printf("%.6f,%.6f,%d,%.6f,%.6f", e, ourErr, ourCoreset.size(), objTime, ourTime);
				out.println();
			}
		}
	}

	/**
	 * Runs the means and median experiments on the Adult and Bank data with k = 3, 500 tests and
	 * eps from 0.10 to 0.50 in steps of 0.05.
	 */
	public void run() throws Exception
	{
		final int k = 3;
		final int cases = 500;
		double[] eps = new double[] {0.1, 0.15, 0.20, 0.25, 0.30, 0.35, 0.40, 0.45, 0.50};

		// adult
		ArrayList<WeightedPoint> instance = adult();
		normalize(instance);
		System.out.println("adult size: " + instance.size());
		evaluateFairMeans(instance, k, cases, eps, "adult");
		evaluateFairMedian(instance, k, cases, eps, "adult");

		// bank
		instance = bank();
		normalize(instance);
		System.out.println("bank size: " + instance.size());
		evaluateFairMeans(instance, k, cases, eps, "bank");
		evaluateFairMedian(instance, k, cases, eps, "bank");
	}

	public static void main(String[] args) throws Exception
	{
		new Main().run();
	}

	// Scratch notes: joint counts for two binary attributes (m/fm x white/black).
	// x_{m, white} + x_{fm, white} = white
	// x_{m, black} + x_{fm, black} = black
	// x_{m, black} + x_{m, white} = m
	// x_{fm, black} + x_{fm, white} = fm
	//
	// x_{fm,black} - x_{m, white} = white - fm
}
