from generate_data import *
from training import *
from projection_module import *
from visualization import *
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

import sys
# Redirect every print() in this process to outputs.txt.
# NOTE(review): this runs at import time, so merely importing this module also
# redirects stdout. buffering=1 makes the text file line-buffered: each line is
# flushed as it is written, so the log survives a crash/kill mid-training and
# can be followed live (the file object is never explicitly closed).
sys.stdout = open("outputs.txt", "w", buffering=1, encoding="utf-8")

def training_main(iter_num, num_particles=4, dimension=2, device="cuda",
                  train_opts=None, save_path="model.pt"):
    """Train a ``Projection`` model for one ``iter`` setting and save its weights.

    Args:
        iter_num: Passed to ``Projection(iter=...)``; also used as the prefix
            of the experiment directory handed to ``train`` (e.g. ``2models``).
        num_particles: Passed to ``Projection(num_particles=...)``.
        dimension: Passed to ``Projection(dimension=...)``.
        device: Device the model is moved to before training (default
            ``"cuda"``, matching the previous hard-coded ``.cuda()``).
        train_opts: Optional dict of overrides merged on top of the default
            training options defined below.
        save_path: File the final ``{'state': model.state_dict()}`` dict is
            written to via ``save``.
    """
    import os

    opts = {
        "num_epochs": 600,
        "lr": 1e-3,
        "lr_step": 20,
        "lr_gamma": 0.8,
        "batch_size": 256,
        "loss": "l1",  # alternative loss: 'my_regu'
        "weight_decay": 0,  # e.g. 1e-6 to enable
    }
    if train_opts:
        opts.update(train_opts)

    train_ds, val_ds = get_data_loader()
    model = Projection(num_particles=num_particles, dimension=dimension,
                       iter=iter_num).to(device)

    # Portable replacement for the Windows-only "<iter>models\\" literal. The
    # trailing separator is kept because train() may concatenate file names
    # directly onto exp_dir (NOTE(review): confirm against train()).
    exp_dir = os.path.join(str(iter_num) + "models", "")

    print(model)
    print(model.iter)

    train(model, train_ds, val_ds, train_opts=opts, exp_dir=exp_dir)

    save({"state": model.state_dict()}, save_path)

def test_main(model_path='models/checkpoint _1000.pt', start=1024 - 100,
              end=1024, out_dir="bad_results", error_threshold=1e-3):
    """Plot model predictions for frames ``[start, end)`` and save bad fits.

    For each frame the input (yellow), label (green), a single model pass
    (blue) and the model applied twice (red, semi-transparent) are scattered
    using columns 0 and 1 as x/y. Frames whose sum of squared differences
    between label and single-pass prediction exceeds ``error_threshold`` are
    saved as JPEGs in ``out_dir``.

    Args:
        model_path: Checkpoint containing a whole pickled model (not a
            state dict).
        start: First frame index passed to ``read_data``.
        end: End frame index passed to ``read_data``.
        out_dir: Directory for saved figures; created if missing.
        error_threshold: Frames with error above this are saved.
    """
    import os

    # map_location lets a checkpoint saved from a CUDA model load on a
    # CPU-only machine.
    # NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
    # files. Newer PyTorch defaults to weights_only=True, which rejects whole
    # pickled models; pass weights_only=False there if loading fails.
    model = torch.load(model_path, map_location="cpu").cpu()
    model.eval()  # inference mode for any dropout / batch-norm layers

    data, label = read_data(start, end)
    with torch.no_grad():
        pred = model(data)
        pred2 = model(pred)  # feed the prediction back in for a second pass
    print(data.size(), label.size(), pred.size())

    pred = pred.detach().numpy()
    pred2 = pred2.detach().numpy()
    data = data.detach().numpy()
    label = label.numpy()

    print(label[0, :, :])

    os.makedirs(out_dir, exist_ok=True)  # savefig fails on a missing directory
    for i in range(pred.shape[0]):
        plt.scatter(data[i, :, 0], data[i, :, 1], c='y')
        plt.scatter(label[i, :, 0], label[i, :, 1], c='g')
        plt.scatter(pred[i, :, 0], pred[i, :, 1], c='b')
        plt.scatter(pred2[i, :, 0], pred2[i, :, 1], c='r', alpha=0.5)
        # Squared error over every particle and every coordinate; previously
        # only coordinate 0 was compared, so y-errors were ignored.
        error = np.sum((label[i] - pred[i]) ** 2)
        plt.title("frame: " + str(start) + ' + ' + str(i) + '; error: ' + str(error))
        if error > error_threshold:
            plt.savefig(os.path.join(
                out_dir, str(start) + '-' + str(i) + '-' + str(error) + '.jpg'))
        plt.clf()
    
if __name__ == '__main__':
    # Train one model per entry in this tuple. Other iteration counts that
    # have been run: 1, 3, 5, 7, 9. To evaluate a saved checkpoint instead of
    # training, call test_main().
    for n_iter in (2,):
        print("iter = " + str(n_iter))
        training_main(n_iter)