import torch
import numpy as np
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))

from modules.conAR_mis_beta import conAR_mis
from modules.cigp_v10 import cigp
from tools.prepare_data import data_preparation
from tools.calculate_metrix import calculate_metrix

# Report the installed PyTorch version. This work was developed against
# torch 1.11.0; older releases may not work.
print(torch.__version__)

# Work around an error seen when running on macOS: KMP_DUPLICATE_LIB_OK
# tells the Intel OpenMP runtime to tolerate being loaded more than once.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

# Module-level numeric constants (not referenced elsewhere in this file).
# NOTE(review): PI is a truncated approximation (3.1415); math.pi may be
# what consumers of this constant actually want.
JITTER, EPS, PI = 1e-6, 1e-10, 3.1415

print('testing')
print(torch.__version__)

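# train_samples_num: number of training samples taken from the first fidelity.
# dec_rate: how the data is degraded to produce the lower fidelities
#           (note: this is not a parameter of data_model).
# fidelity_num: number of fidelities used for training and testing.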
def _impute_missing_fidelity(x_base, y_fid, mask_row, n_samples,
                             niteration=200, learning_rate=0.02):
    """Fill the masked-out training outputs of one fidelity with a GP prediction.

    Sample ``k`` (``k < n_samples``) is missing when ``mask_row[k] == 0``.
    A ``cigp`` model is fitted on the observed ``(x_base[k], y_fid[k])``
    pairs. Its predictive mean replaces the missing outputs.

    Returns:
        ``(y_full, var_full)``:
            ``y_full`` stacks all ``n_samples`` outputs in original order.
            ``var_full`` holds the predictive variance for imputed samples
            and 0 for observed ones.

    Raises:
        ValueError: if no sample of this fidelity is observed, since there
            is nothing to fit the imputation model on.
    """
    is_missing = [bool(mask_row[k] == 0) for k in range(n_samples)]
    observed_idx = [k for k, miss in enumerate(is_missing) if not miss]
    missing_idx = [k for k, miss in enumerate(is_missing) if miss]

    if not observed_idx:
        raise ValueError(
            "every training sample of this fidelity is masked out; "
            "at least one observed sample is required")

    x_obs = torch.stack([x_base[k] for k in observed_idx])
    y_obs = torch.stack([y_fid[k] for k in observed_idx])

    # Fully observed fidelity: no model to fit, zero extra variance.
    if not missing_idx:
        return y_obs, torch.zeros(n_samples)

    x_mis = torch.stack([x_base[k] for k in missing_idx])

    gp = cigp(x_obs, y_obs)
    gp.train_adam(niteration, lr=learning_rate)
    with torch.no_grad():
        missing_mean, missing_variance = gp.forward(x_mis)

    # Zero variance for observed points, matching the dtype/device of the
    # predicted variances so torch.stack never mixes int/CPU with float/GPU.
    zero_var = missing_variance.new_zeros(())

    # Merge observed and imputed rows back into the original sample order.
    # NOTE(review): ``row[0]`` assumes the predictive variance has one row
    # per test point with the variance in column 0 (e.g. shape (n, 1)) —
    # confirm against cigp.forward.
    obs_rows = iter(y_obs)
    mis_means = iter(missing_mean)
    mis_vars = iter(missing_variance)
    y_full, var_full = [], []
    for miss in is_missing:
        if miss:
            y_full.append(next(mis_means))
            var_full.append(next(mis_vars)[0])
        else:
            y_full.append(next(obs_rows))
            var_full.append(zero_var)

    return torch.stack(y_full), torch.stack(var_full)


def data_model(data_name,
               mask,
               train_begin_index = 0,
               test_begin_index = 0,
               train_samples_num = 16,
               test_samples_num = 128,
               fidelity_num = 5,
               seed = 1,
               need_inerp = True):
    """Train conAR_mis on partially observed multi-fidelity data and score it.

    For every fidelity ``i``, training sample ``k`` is treated as missing
    when ``mask[i][k] == 0``. The inputs ``xtr[0]`` are shared by all
    fidelities.

    Missing outputs are imputed with the predictive mean of a ``cigp``
    model fitted on that fidelity's observed samples. The predictive
    variances (0 for observed samples) are passed to ``conAR_mis`` as
    ``missing_var``. The trained model is then evaluated against the test
    outputs of the last fidelity (``fidelity_num - 1``).

    Args:
        data_name: dataset identifier forwarded to ``data_preparation``.
        mask: indexable as ``mask[i][k]`` for ``i < fidelity_num`` and
            ``k < train_samples_num``; a value of 0 marks a missing output.
        train_begin_index: accepted for API compatibility but currently
            unused. Training always uses samples ``0..train_samples_num-1``
            and ``conAR_mis`` is given ``train_begin_index=0``.
        test_begin_index: start (inclusive) of the test window.
        train_samples_num: number of training samples per fidelity; also
            forwarded to ``data_preparation``.
        test_samples_num: end (exclusive) of the test window, i.e. the
            window is ``[test_begin_index:test_samples_num]``.
        fidelity_num: number of fidelities used for training and testing.
        seed: forwarded to ``data_preparation`` and ``conAR_mis``.
        need_inerp: accepted for API compatibility but currently unused.

    Returns:
        Whatever ``calculate_metrix`` returns (also printed).

    Raises:
        ValueError: if some fidelity has no observed training sample.
    """
    xtr, ytr, xte, yte = data_preparation(data_name, fidelity_num, seed, train_samples_num)

    # Impute missing training outputs fidelity by fidelity.
    ytr_f = []
    missing_var = []
    for i in range(fidelity_num):
        y_full, var_full = _impute_missing_fidelity(
            xtr[0], ytr[i], mask[i], train_samples_num)
        ytr_f.append(y_full)
        missing_var.append(var_full)

    # Inputs and targets use the same test window so predictions line up
    # with the ground truth. (test_begin_index is honored here; it used to
    # be silently reset to 0.)
    xte = xte[0][test_begin_index:test_samples_num]

    # Training process
    train_num = [train_samples_num] * fidelity_num
    m_fid = conAR_mis(xtr, ytr_f, xte,
                      train_begin_index = 0,
                      train_num = train_num,
                      fidelity_num = fidelity_num,
                      seed = seed,
                      niteration = 200,
                      learning_rate = 0.02,
                      missing_var = missing_var,
                      normal_y_mode = 0)
    yte_mean, yte_var = m_fid.train_mod()

    # Test and return the evaluation results
    yte_test = yte[fidelity_num - 1][test_begin_index:test_samples_num]
    cp_metrics = calculate_metrix(y_test = yte_test, y_mean_pre = yte_mean, y_var_pre = yte_var)

    print("loss of our model:", cp_metrics)

    return cp_metrics
    



