import torch
import math
import time, os
import numpy as np
import pennylane as qml
from math import pi
from model import QNN

# ======================= hyperparameter ==================
# All run settings live in `paras` so the complete configuration can be echoed
# to the log below. Key order matters: it is also the order the keys are printed.
_data_dim = 2
paras = {
    'data_dim':       _data_dim,
    'd':              _data_dim,  # always kept equal to 'data_dim'
    'K':              10,
    'eps':            0.01,
    's':              2,
    'depth_constant': 2,
    'batch_size':     100,
    'random_seed_1':  28_10_2000,
    'random_seed_2':  13_02_1999,
    'random_seed_3':  27_11_2000,
    'max_shots':      1000,
    'learning_rate':  1e-3,
    'ckpt_dir':       'outputs_new_target_4/pure_quantum/outputs_K10_s2_lr3_more_sample/',
}

# exist_ok=True makes directory creation safe when several runs start at the
# same time (no check-then-create race), and it still raises if the path
# exists but is not a directory instead of failing later at save time.
os.makedirs(paras['ckpt_dir'], exist_ok=True)

# Echo the configuration so every log file records the settings it was run with.
_banner = '==================================================='
print(_banner, flush=True)
for key, value in paras.items():
    print(f'{key}:{value}', flush=True)
print(_banner, flush=True)

# ======================= target function ==================
def target_func(x):
    """Evaluate the two-variable regression target at the point ``x``.

    Only ``x[0]`` and ``x[1]`` are read. The result is the sum of three
    squared expressions in those two coordinates, divided by ``5 * pi**2``.
    """
    # target 1
    u, v = x[0], x[1]
    first_term = (u**2 + v - 1.5*pi)**2
    second_term = (u + v**2 + pi)**2
    third_term = (u + v - 0.5 * pi)**2
    return (first_term + second_term + third_term) / (5*pi**2)
    

# ======================= training data ==================
# Sample the unit square on a regular grid with step 0.1 along each axis
# (the stop value 1 is excluded by the slice); after the reshape/transpose
# each row is one (x0, x1) point.
grid_points = np.mgrid[0:1:0.1, 0:1:0.1].reshape(2, -1).T
# Evaluate the target at every grid point, one point at a time.
grid_labels = np.array([target_func(point) for point in grid_points])
# Convert both arrays to float32 tensors; the labels are flattened to 1-D.
input_data = torch.from_numpy(grid_points).float()
true_label = torch.from_numpy(grid_labels).float().view(-1)

# ======================== define model ====================
# Build the QNN from the hyperparameters collected in `paras`.
model = QNN(
    s=paras['s'],
    depth_constant=paras['depth_constant'],
    K=paras['K'],
    eps=paras['eps'],
    d=paras['data_dim'],
    random_seed_1=paras['random_seed_1'],
    random_seed_2=paras['random_seed_2'],
)
# Build the weights path with os.path.join so it resolves on every platform;
# the previous hard-coded '.\model\...' string only worked on Windows.
# NOTE(review): presumably this loads pretrained weights for the model's
# discretization part - confirm against model.QNN.
model.load_discretization_model(
    os.path.join('.', 'model', 'localization', 'model_poly_K10_322.pth')
)
# model.load_polynomial_model()
criterion  = torch.nn.MSELoss(reduction="sum")   # optimized objective
criterion1 = torch.nn.MSELoss(reduction="mean")  # logged only
optimizer  = torch.optim.Adam(model.parameters(), lr=paras['learning_rate'])


def _checkpoint_path(step):
    """Return the checkpoint file path used for training step ``step``."""
    return os.path.join(paras['ckpt_dir'], 'model_K10_s2_{}.pth'.format(step))


# Full-batch training: every iteration feeds the whole grid through the model.
# The loop stops once the summed squared error drops to `loss_tol` or below,
# or after the iteration counter has passed paras['max_shots'].
loss_tol   = 0.01
t          = 0
loss_val   = 1.0  # sentinel above loss_tol so the first iteration always runs
while (t <= paras['max_shots']) and (loss_val > loss_tol):
    y_pred     = model(input_data)
    loss       = criterion(y_pred, true_label)
    loss_val   = loss.item()
    print('y_pred: ', y_pred, flush=True)
    print('loss sum: ', loss_val, flush=True)
    print('loss means: ', criterion1(y_pred, true_label).item(), flush=True)
    print(f"---- iter: {t}, loss: {round(loss_val, 4)} -----", flush=True)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Periodic checkpoint, written after the parameter update of iteration t.
    if t % 10 == 0:
        torch.save(model.state_dict(), _checkpoint_path(t))
    t += 1

# Always store the final parameters, labelled with the post-loop counter value.
torch.save(model.state_dict(), _checkpoint_path(t))


