#from smartnoise_sdk.synth.snsynth import Synthesizer
#from smartnoise_sdk.synth.snsynth.pytorch.nn.dpctgan import DPCTGAN
import argparse
import os

from snsynth import Synthesizer
import snsynth
import pandas as pd

import numpy as np
from ucimlrepo import fetch_ucirepo
from folktables import ACSDataSource, ACSEmployment, ACSIncome, ACSPublicCoverage, ACSTravelTime, ACSMobility

from sdmetrics.reports.single_table import DiagnosticReport
from sdmetrics.reports.single_table import QualityReport
from sdmetrics.visualization import get_column_plot
from sdmetrics.visualization import get_column_pair_plot
from sdmetrics.single_table import NewRowSynthesis
from sklearn.model_selection import train_test_split

# Command-line interface.
# --dataset and --epsilon are mandatory and validated up front so that a typo
# or a missing flag fails immediately with a usage message instead of much
# later (after the data download) with an unrelated NameError/TypeError.
parser = argparse.ArgumentParser(
    description="Train differentially private GAN synthesizers on a tabular dataset."
)
parser.add_argument(
    "-d", "--dataset",
    required=True,
    choices=[
        "adult",
        "acs_employment",
        "acs_income",
        "acs_public_coverage",
        "acs_travel",
        "acs_mobility",
    ],
    help="dataset to synthesize",
)
parser.add_argument(
    "-e", "--epsilon",
    required=True,
    type=float,
    help="privacy budget (epsilon) given to every synthesizer",
)
# NOTE(review): --name is parsed but not referenced in the visible part of
# this script; kept so existing invocations keep working.
parser.add_argument("-n", "--name", help="optional run name")
args = parser.parse_args()
dataset = args.dataset
epsilon = args.epsilon
name = args.name
#dataset = 'acs_employment'



# ---------------------------------------------------------------------------
# Dataset loading.
#
# Produces three module-level names used by the rest of the script:
#   X        - feature DataFrame (target column excluded)
#   label    - target column(s), aligned with X
#   metadata - {"columns": {column: {"sdtype": ...}}} covering X and label
#
# Each table below maps column name -> sdtype, in the column order used for
# that dataset. The entry marked "target" is the prediction label.
# ---------------------------------------------------------------------------
_NUM, _CAT, _BOOL = "numerical", "categorical", "boolean"

_ADULT_SDTYPES = {
    "age": _NUM,
    "workclass": _CAT,
    "fnlwgt": _NUM,
    "education": _CAT,
    "education-num": _NUM,
    "marital-status": _CAT,
    "occupation": _CAT,
    "relationship": _CAT,
    "race": _CAT,
    "sex": _BOOL,
    "capital-gain": _NUM,
    "capital-loss": _NUM,
    "hours-per-week": _NUM,
    "native-country": _CAT,
    "income": _BOOL,  # target
}

# dataset name -> (folktables task, column sdtypes)
_ACS_TASKS = {
    'acs_employment': (ACSEmployment, {
        "AGEP": _NUM,
        "SCHL": _CAT,
        "MAR": _CAT,
        "RELP": _CAT,
        "DIS": _CAT,
        "ESP": _CAT,
        "CIT": _CAT,
        "MIG": _CAT,
        "MIL": _CAT,
        "ANC": _CAT,
        "NATIVITY": _CAT,
        "DEAR": _CAT,
        "DEYE": _CAT,
        "DREM": _CAT,
        "SEX": _CAT,
        "RAC1P": _CAT,
        "ESR": _CAT,  # target
    }),
    'acs_income': (ACSIncome, {
        "AGEP": _NUM,
        "SCHL": _CAT,
        "MAR": _CAT,
        "RELP": _CAT,
        "SEX": _CAT,
        "RAC1P": _CAT,
        "COW": _CAT,
        "OCCP": _CAT,
        "POBP": _CAT,
        "WKHP": _NUM,
        "PINCP": _CAT,  # target
    }),
    'acs_public_coverage': (ACSPublicCoverage, {
        "AGEP": _NUM,
        "SCHL": _CAT,
        "MAR": _CAT,
        "DIS": _CAT,
        "ESP": _CAT,
        "CIT": _CAT,
        "MIG": _CAT,
        "MIL": _CAT,
        "ANC": _CAT,
        "NATIVITY": _CAT,
        "DEAR": _CAT,
        "DEYE": _CAT,
        "DREM": _CAT,
        "SEX": _CAT,
        "RAC1P": _CAT,
        "PINCP": _NUM,
        "ESR": _CAT,
        "ST": _CAT,
        "FER": _CAT,
        "PUBCOV": _CAT,  # target
    }),
    'acs_travel': (ACSTravelTime, {
        "AGEP": _NUM,
        "SCHL": _CAT,
        "MAR": _CAT,
        "RELP": _CAT,
        "DIS": _CAT,
        "ESP": _CAT,
        "CIT": _CAT,
        "MIG": _CAT,
        "SEX": _CAT,
        "RAC1P": _CAT,
        "OCCP": _CAT,
        "ST": _CAT,
        "PUMA": _CAT,
        "JWTR": _CAT,
        "POWPUMA": _CAT,
        "POVPIP": _NUM,
        "JWMNP": _CAT,  # target
    }),
    'acs_mobility': (ACSMobility, {
        "AGEP": _NUM,
        "SCHL": _CAT,
        "MAR": _CAT,
        "RELP": _CAT,
        "DIS": _CAT,
        "ESP": _CAT,
        "CIT": _CAT,
        "MIG": _CAT,  # target
        "MIL": _CAT,
        "ANC": _CAT,
        "NATIVITY": _CAT,
        "DEAR": _CAT,
        "DEYE": _CAT,
        "DREM": _CAT,
        "SEX": _CAT,
        "RAC1P": _CAT,
        "COW": _CAT,
        "WKHP": _NUM,
        "PINCP": _NUM,
        "ESR": _CAT,
        "GCL": _CAT,
        "JWMNP": _NUM,
    }),
}

if dataset == 'adult':
    # UCI Adult (repository id 2), fetched as a single DataFrame.
    adult = fetch_ucirepo(id=2)
    _adult_table = adult.data.original
    # The downstream split needs the target separately from the features;
    # previously `label` was never defined for this dataset, so the script
    # crashed with a NameError at the train/test split.
    label = _adult_table[['income']]
    X = _adult_table.drop(columns=['income'])
    _sdtypes = _ADULT_SDTYPES
elif dataset in _ACS_TASKS:
    _task, _sdtypes = _ACS_TASKS[dataset]
    # Every ACS task reads the same source: 2018 1-Year person survey, CA.
    data_source = ACSDataSource(survey_year='2018', horizon='1-Year', survey='person')
    acs_data = data_source.get_data(states=["CA"], download=True)
    X, label, group = _task.df_to_pandas(acs_data)
else:
    # Fail loudly here rather than with a NameError on X further down.
    raise ValueError(
        f"Unknown dataset {dataset!r}; expected 'adult' or one of {sorted(_ACS_TASKS)}"
    )

metadata = {
    "columns": {column: {"sdtype": sdtype} for column, sdtype in _sdtypes.items()}
}
  
  
  
# Placeholders; not referenced in the visible part of this script but kept as
# module-level names in case later code expects them.
dp_input = None
sfd_input = None

print(X.shape)
# Keep a fixed 25% subsample (seeded) as the synthesizer training set; the
# remaining rows end up in X_test / y_test.
X, X_test, y, y_test = train_test_split(X, label, train_size=0.25, random_state=55)
# Re-attach the target so it is synthesized together with the features.
X = pd.concat([X, y], axis=1)

print('X shape = ' + str(X.shape))

##################### synth metadata fix
# Translate the sdtype metadata into the explicit column lists that are passed
# to fit_sample(). A column missing from `metadata` raises KeyError on purpose.
columns = list(X.columns)
_sdtype_of = {f: metadata['columns'][f]['sdtype'] for f in columns}
# 'boolean' columns (e.g. adult's sex / income) are two-valued categoricals;
# previously they matched neither filter and were left out of both lists.
categorical_columns = [f for f in columns if _sdtype_of[f] in ('categorical', 'boolean')]
continuous_columns = [f for f in columns if _sdtype_of[f] == 'numerical']
####################
####################



# Privacy budget shared by every synthesizer below.
eps = float(epsilon)

# Hyperparameters for the 'sfdpgan' models.
loss = 'kmm'
epochs = 150
gdim = (512, 512, 512)
dsteps = 1
lr = 2e-5
n_KL_slices = 150  # m
n_KL_slice_dim = 3  # k

# Slice settings used only by the 'SWD-DP' variant.
n_KL_slices_wass = 150
n_KL_slice_dim_wass = 1
batch_size = 128

# Switches selecting which models are trained / sampled.
do_baseline = True
do_sliced = True
do_pate = True
do_wass = True

if do_pate:
    pate = Synthesizer.create('pategan', epsilon=eps)
if do_baseline:
    # The baseline gets twice as many epochs as the sfdpgan models.
    dp = Synthesizer.create('dpgan', epsilon=eps, verbose=True, epochs=epochs * 2)

# Keyword arguments common to both 'sfdpgan' instances.
_sfdpgan_shared = dict(
    epsilon=eps,
    verbose=True,
    epochs=epochs,
    batch_size=batch_size,
    generator_dim=gdim,
    discriminator_steps=dsteps,
    discriminator_lr=lr,
    generator_lr=lr,
)

# Constructed regardless of do_sliced, as before.
sfd = Synthesizer.create(
    'sfdpgan',
    loss=loss,
    n_KL_slices=n_KL_slices,
    n_KL_slice_dim=n_KL_slice_dim,
    **_sfdpgan_shared,
)

if do_wass:
    wass = Synthesizer.create(
        'sfdpgan',
        loss='SWD-DP',
        n_KL_slices=n_KL_slices_wass,
        n_KL_slice_dim=n_KL_slice_dim_wass,
        **_sfdpgan_shared,
    )
def _load_or_fit(synth, csv_path, train_df, categorical, continuous):
    """Return the synthetic sample cached at ``csv_path``, creating it if needed.

    If ``csv_path`` can be read it is returned as-is and ``synth`` is not
    trained. Otherwise ``synth.fit_sample`` is run on ``train_df`` and the
    result is written to ``csv_path`` (without the index) before being
    returned.

    Args:
        synth: synthesizer object exposing ``fit_sample``.
        csv_path: location of the cached CSV.
        train_df: training DataFrame handed to ``fit_sample``.
        categorical: names of the categorical columns of ``train_df``.
        continuous: names of the continuous columns of ``train_df``.

    Returns:
        The cached or freshly generated sample as a DataFrame.
    """
    try:
        return pd.read_csv(csv_path)
    except (FileNotFoundError, pd.errors.EmptyDataError):
        # Only a missing or empty cache triggers training; any other failure
        # (including Ctrl-C) propagates instead of being silently swallowed.
        pass

    # Create the output directory *before* training so a missing folder does
    # not throw away a finished run when the CSV is written.
    out_dir = os.path.dirname(csv_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    sample = synth.fit_sample(
        train_df,
        preprocessor_eps=0.5,
        nullable=True,
        categorical_columns=categorical,
        continuous_columns=continuous,
    )
    sample.to_csv(csv_path, index=False)
    return sample


if do_baseline:
    dp_sample = _load_or_fit(
        dp, f'./Results/{dataset}/dpGAN_eps{eps}.csv',
        X, categorical_columns, continuous_columns,
    )
if do_pate:
    # Stored under its own name; this used to overwrite dp_sample.
    pate_sample = _load_or_fit(
        pate, f'./Results/{dataset}/pateGANnew_eps{eps}.csv',
        X, categorical_columns, continuous_columns,
    )
if do_sliced:
    flname = f'./Results/{dataset}/sfdGAN_eps{eps}_loss{loss}_epochs{epochs}_gdim{gdim}_dsteps{dsteps}_lr{lr}_nslice{n_KL_slices}_slicedim{n_KL_slice_dim}.csv'
    sfd_sample = _load_or_fit(sfd, flname, X, categorical_columns, continuous_columns)

if do_wass:
    # NOTE(review): this name encodes `loss`, `n_KL_slices` and
    # `n_KL_slice_dim`, while `wass` is built with loss='SWD-DP' and the
    # *_wass slice settings. Changing the wass settings therefore does not
    # change the cache key. Left as-is so existing result files and any
    # scripts reading them keep working - confirm before renaming.
    flname = f'./Results/{dataset}/WassGAN_eps{eps}_loss{loss}_epochs{epochs}_gdim{gdim}_dsteps{dsteps}_lr{lr}_nslice{n_KL_slices}_slicedim{n_KL_slice_dim}.csv'
    wass_sample = _load_or_fit(wass, flname, X, categorical_columns, continuous_columns)

