import torch
import math
import time, os
import numpy as np
import pennylane as qml
from math import pi
from model import Discretization, step_function

# ======================= hyperparameter ==================
# Input dimensionality, stored under both 'data_dim' and 'd' so the two
# entries cannot drift apart.
_DATA_DIM = 2

# Run configuration. Insertion order matters: it is the order in which the
# entries are echoed to stdout below.
# NOTE(review): 'd', 's', 'batch_size', 'random_seed_2' and 'random_seed_3'
# are not read anywhere in the visible part of this script -- confirm whether
# other code consumes them before removing.
paras = {
    'data_dim':      _DATA_DIM,
    'd':             _DATA_DIM,
    'K':             2,
    'eps':           0.01,
    's':             2,
    'batch_size':    100,
    'random_seed_1': 28_10_2000,
    'random_seed_2': 13_02_1967,
    'random_seed_3': 27_11_2000,
    'max_shots':     500,
    'learning_rate': 1e-2,
    'ckpt_dir':      'outputs/models_K2_new/',
}

# exist_ok=True makes directory creation a single idempotent call instead of
# a check-then-create sequence, which could fail if the directory appeared
# between the check and the creation.
os.makedirs(paras['ckpt_dir'], exist_ok=True)

# Echo the full configuration so every log starts with the settings used.
print('===================================================', flush=True)
for key, value in paras.items():
    print(f'{key}:{value}', flush=True)
print('===================================================', flush=True)

# ======================== define data ====================
# Inputs are 100 evenly spaced points on [0, 1]; the targets are whatever
# step_function returns for those points and paras['K'], flattened to 1-D.
input_data = torch.linspace(start=0, end=1, steps=100).float()
true_label = step_function(input_data, paras['K']).float().view(-1)

# ======================== define model ====================
model = Discretization(
    K=paras['K'],
    eps=paras['eps'],
    random_seed=paras['random_seed_1'],
)
# Two squared-error criteria over the same inputs: `criterion` sums the
# per-element errors, `criterion1` averages them.
criterion, criterion1 = (
    torch.nn.MSELoss(reduction=mode) for mode in ("sum", "mean")
)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=paras['learning_rate'],
)

def _checkpoint_path(step):
    """Return the checkpoint file path for the given step counter.

    Uses os.path.join so the path stays inside paras['ckpt_dir'] whether or
    not that setting ends with a path separator.
    """
    return os.path.join(paras['ckpt_dir'], 'model_K2_{}.pth'.format(step))


# Optimise until the summed loss is at most 0.01 or the iteration budget is
# used up. The `<=` bound means t runs over 0..max_shots inclusive, i.e. at
# most max_shots + 1 updates.
t          = 0
loss_val   = 1.0  # starts above the 0.01 threshold so the loop is entered
while (t <= paras['max_shots']) and (loss_val > 0.01):
    # NOTE(review): `train` is called with the inputs and its return value is
    # used as the prediction, so Discretization presumably overrides the
    # stock nn.Module.train(mode) -- confirm in model.py.
    y_pred     = model.train(input_data)
    loss       = criterion(y_pred, true_label)
    loss_val   = loss.item()
    print('y_pred: ', y_pred, flush=True)
    print('loss sum: ', loss_val, flush=True)
    # The mean loss is only reported, never back-propagated, so skip building
    # an autograd graph for it.
    with torch.no_grad():
        mean_loss = criterion1(y_pred, true_label).item()
    print('loss means: ', mean_loss, flush=True)
    print(f"---- iter: {t}, loss: {round(loss_val, 4)} -----", flush=True)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Saved after optimizer.step(): the file labelled t holds the parameters
    # after t + 1 updates.
    if t % 10 == 0:
        torch.save(model.state_dict(), _checkpoint_path(t))
    t += 1
# Final snapshot; t now equals the number of updates performed.
torch.save(model.state_dict(), _checkpoint_path(t))




