from pipeline_utils import *
from baseline_utils import get_multi_metadata
from sdv.evaluation.multi_table import evaluate_quality
import pandas as pd
import numpy as np
import argparse
import os
import json

if __name__ == '__main__':
    # Script entry point: load the real tables and the per-relation synthetic
    # tables, resolve each child table into a single final synthetic table,
    # write the results to disk, and run SDV's multi-table quality evaluation.
    parser = argparse.ArgumentParser(description='Matching pipeline')
    parser.add_argument('--data_dir', type=str, default='complex_data/instacart/preprocessed')
    parser.add_argument('--NUM_MATCHING_CLUSTERS', type=int, default=1)
    parser.add_argument('--exp_name', type=str, default='instacart_exp')
    parser.add_argument('--batch_size', type=int, default=1000)
    parser.add_argument('--unique_matching', action='store_true')
    # Required: without it os.path.join(None, ...) below fails with an opaque
    # TypeError instead of a clear usage message.
    parser.add_argument(
        '--working_dir',
        type=str,
        required=True,
        help='directory that contains the <exp_name> experiment folder',
    )

    args = parser.parse_args()

    # Use context managers so the JSON file handles are closed deterministically.
    with open(os.path.join(args.data_dir, 'dataset_meta.json')) as meta_file:
        dataset_meta = json.load(meta_file)
    # Iterated below as (parent, child) pairs.
    relation_order = dataset_meta['relation_order']
    save_dir = os.path.join(args.working_dir, args.exp_name)

    # Real data: one entry per table listed in the dataset metadata.
    tables = {}
    for table, meta in dataset_meta['tables'].items():
        real_df = pd.read_csv(os.path.join(args.data_dir, f'{table}.csv'))
        with open(os.path.join(args.data_dir, f'{table}_domain.json')) as domain_file:
            domain = json.load(domain_file)
        tables[table] = {
            'df': real_df,
            'domain': domain,
            'children': meta['children'],
            'parents': meta['parents'],
            # Remember the original schema so helper columns can be dropped
            # from the synthetic output later on.
            'original_cols': list(real_df.columns),
            # Untouched copy of the real data, used for the evaluation.
            'original_df': real_df.copy(),
        }

    # Synthetic data produced before matching: one CSV per (parent, child) relation.
    synthetic_tables = {}
    for parent, child in relation_order:
        table = pd.read_csv(os.path.join(save_dir, 'before_matching', f'{(parent, child)}_synthetic.csv'))
        synthetic_tables[(parent, child)] = {
            'df': table,
        }

    # Resolve exactly one final synthetic table per child. A child may appear in
    # several relations; only its first occurrence in relation_order is processed.
    final_tables = {}
    for parent, child in relation_order:
        if child in final_tables:
            continue
        if len(tables[child]['parents']) > 1:
            print(f'matching tables {child} with multiple parents {tables[child]["parents"]}')
            # handle_multi_parent comes from pipeline_utils; presumably it merges
            # the per-parent synthetic tables of this child -- TODO confirm.
            final_tables[child] = handle_multi_parent(
                child,
                tables[child]['parents'],
                synthetic_tables,
                args.NUM_MATCHING_CLUSTERS,
                args.unique_matching,
                args.batch_size,
            )
        else:
            final_tables[child] = synthetic_tables[(parent, child)]['df']

    # Restrict every synthetic table to the columns (and column order) of its real counterpart.
    cleaned_tables = {
        name: synth_df[tables[name]['original_cols']]
        for name, synth_df in final_tables.items()
    }

    for name, synth_df in cleaned_tables.items():
        synth_df.to_csv(os.path.join(save_dir, f'{name}_synthetic.csv'), index=False)

    # Evaluate the synthetic tables against the untouched real data.
    real_data = {name: info['original_df'] for name, info in tables.items()}

    multi_meta = get_multi_metadata(tables, relation_order)
    quality = evaluate_quality(
        real_data,
        cleaned_tables,
        multi_meta
    )
