from generate_data import *
from training import *
from projection_module import *
from visualization import *
import numpy as np
from mpl_toolkits.mplot3d import Axes3D

def training_main():
    """Train a ``Projection`` model and save its final weights.

    Loads the training/validation data via ``get_data_loader()`` and builds a
    ``Projection`` network on the GPU. It then runs ``train`` with the options
    below and an experiment directory under ``models``. Finally it writes
    ``model.state_dict()`` to ``model.pt`` in the current working directory.
    """
    import os  # local import: the file's top-level imports are wildcard-only

    train_opts = {
        "num_epochs": 1000,
        "lr": 1e-3,
        # Presumably a step learning-rate schedule (step size / decay factor);
        # TODO confirm against train().
        "lr_step": 20,
        "lr_gamma": 0.8,
        "batch_size": 512,
        "loss": "l1",
        "weight_decay": 0,  # alternative value left by the author: 1e-6
    }
    train_ds, val_ds = get_data_loader()

    # NOTE(review): the model is moved to CUDA unconditionally, so this
    # entry point requires a GPU.
    model = Projection(num_particles=16, dimension=2, iter=8).cuda()

    # Use the platform's separator rather than a hard-coded Windows "\\".
    # The trailing separator is kept ("models/" or "models\\") in case
    # train() concatenates file names directly onto exp_dir.
    exp_dir = os.path.join("models", "")
    os.makedirs(exp_dir, exist_ok=True)

    print(model)
    train(model, train_ds, val_ds, train_opts=train_opts, exp_dir=exp_dir)
    save(model.state_dict(), "model.pt")

def test_main(model_path='models/checkpoint _10.pt', start=8192 - 100, count=100):
    """Visualise a saved checkpoint's predictions on a window of frames.

    The whole pickled model in ``model_path`` is loaded onto the CPU. Frames
    ``[start, start + count)`` are read with ``read_data``. The model is run
    once on the input (``pred``) and a second time on its own output
    (``pred2``).

    Each frame gets one figure with a scatter per array: input in yellow,
    label in green, first prediction in blue, and second prediction in
    semi-transparent red. The title shows the summed squared error between
    label and first prediction. The figure is displayed, then cleared.

    Args:
        model_path: Path to a checkpoint containing a full module object
            (it is called and ``.cpu()``'d directly, not loaded as a state dict).
        start: Index of the first frame to read.
        count: Number of consecutive frames to read and plot.
    """
    # map_location lets a checkpoint saved from a CUDA model load on a
    # CPU-only machine; the model is used on the CPU anyway.
    # NOTE(review): recent torch versions default torch.load to
    # weights_only=True, which rejects pickled full modules; pass
    # weights_only=False there if the checkpoint is trusted.
    # NOTE(review): consider model.eval() if Projection uses dropout/batchnorm.
    model = torch.load(model_path, map_location='cpu').cpu()

    data, label = read_data(start, start + count)

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        pred = model(data)
        # Feed the prediction back through the model; presumably this checks
        # that re-projecting an already-projected state changes little.
        pred2 = model(pred)
    print(data.size(), label.size(), pred.size())

    pred = pred.detach().numpy()
    pred2 = pred2.detach().numpy()
    data = data.detach().numpy()
    label = label.detach().numpy()

    print(label[0, :, :])

    for i in range(pred.shape[0]):
        plt.scatter(data[i, :, 0], data[i, :, 1], c='y')
        plt.scatter(label[i, :, 0], label[i, :, 1], c='g')
        plt.scatter(pred[i, :, 0], pred[i, :, 1], c='b')
        plt.scatter(pred2[i, :, 0], pred2[i, :, 1], c='r', alpha=0.5)
        # Sum of squared errors over every particle and every coordinate.
        error = np.sum((label[i] - pred[i]) ** 2)
        plt.title(f"frame: {start} + {i}; error: {error!s}")
        plt.show()
        # To collect poorly-fit frames instead of viewing them interactively:
        # if error > 1e-3:
        #     plt.savefig(f"bad_results/{start}-{i}-{error!s}.jpg")
        plt.clf()
    
if __name__ == '__main__':
    # Script entry point: train by default.
    # Call test_main() here instead to visualise a saved checkpoint.
    training_main()