import argparse
import os
# Command-line configuration for the training run.
parser = argparse.ArgumentParser('Training FNO')
parser.add_argument('--lr', type=float, default=1e-3)
parser.add_argument('--epochs', type=int, default=500)
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument("--n1", type=int, default=32)
parser.add_argument("--n2", type=int, default=32)
parser.add_argument("--width", type=int, default=32, help="Width")
# Every other flag uses underscores, so accept `--batch_size` as well as the
# original `--batch-size`; both store into `args.batch_size`.
parser.add_argument('--batch-size', '--batch_size', type=int, default=8)
parser.add_argument("--use_tb", type=int, default=0, help="Use TensorBoard: 1 for True, 0 for False")
# Kept as a string because it is written verbatim into CUDA_VISIBLE_DEVICES.
parser.add_argument("--gpu", type=str, default='1', help="GPU index to use")
# Gradient clipping is only applied when a value is supplied.
parser.add_argument('--max_grad_norm', type=float, default=None)
parser.add_argument('--downsample', type=int, default=5)
parser.add_argument('--ntrain', type=int, default=1000)
parser.add_argument('--ntrain2', type=int, default=100)
parser.add_argument('--dropout', type=float, default=0.)
# Restrict to the architectures the script can build so a typo is rejected
# here instead of surfacing later, after the dataset has been loaded.
parser.add_argument(
    "--model",
    type=str,
    default='ofno',
    choices=['ofno', 'ufno', 'ffno', 'fno', 'fnoall', 'fnomlp'],
    help="Model",
)
parser.add_argument('--dropout_type', type=str, default="GD", help="Dropout Type: MC for typical dropout, GD for Gaussian dropout")
args = parser.parse_args()

# Select the visible GPU(s) from --gpu. This is deliberately done before torch
# is imported: the variable has to be in place before CUDA is initialised,
# which is why the imports below are not at the top of the file.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
import torch
import numpy as np
from timeit import default_timer
# Star import; MatReader, UnitGaussianNormalizer, LpLoss and count_params used
# below are not defined in this file and presumably come from here.
from utilities3 import *
from AM_FNO import FNO2d, FNO2dMLP
from FNOs import UFNO2d , FNOFactorizedMesh2D, vannilaFNO2d
# Fixed seeds for the torch and NumPy random number generators.
torch.manual_seed(42)
np.random.seed(42)



# Dataset locations; both are empty placeholders in this file.
TRAIN_PATH, TEST_PATH = '', ''

# Number of samples taken from each split.
ntest = 200
ntrain = args.ntrain

# Optimisation settings mirrored from the command line.
epochs, epochs2 = args.epochs, 100
learning_rate = args.lr
batch_size = args.batch_size

# Grid resolution after keeping every r-th point along each axis.
# NOTE(review): 421 is presumably the per-side resolution of the raw data - confirm.
r = args.downsample
s = h = int((421 - 1) / r + 1)
dx = 1.0 / h

def _read_pair(mat_reader, count):
    """Read the first `count` 'coeff' and 'sol' samples from the loaded file.

    Each field is strided by `r` along both grid axes and cropped to `s` x `s`.
    Returns the pair `(coeff, sol)`.
    """
    coeff = mat_reader.read_field('coeff')[:count, ::r, ::r][:, :s, :s]
    sol = mat_reader.read_field('sol')[:count, ::r, ::r][:, :s, :s]
    return coeff, sol


# Training split first, then the same reader is pointed at the test file.
reader = MatReader(TRAIN_PATH)
x_train, y_train = _read_pair(reader, ntrain)

reader.load_file(TEST_PATH)
x_test, y_test = _read_pair(reader, ntest)

# Input normaliser is built from the training inputs and applied to both splits.
x_normalizer = UnitGaussianNormalizer(x_train)
x_train = x_normalizer.encode(x_train)
x_test = x_normalizer.encode(x_test)

# Only the training targets are encoded; y_test is left as read.
y_normalizer = UnitGaussianNormalizer(y_train)
y_train = y_normalizer.encode(y_train)

# Append a trailing channel axis of size 1 to the inputs.
x_train = x_train.reshape(ntrain, s, s, 1)
x_test = x_test.reshape(ntest, s, s, 1)

train_set = torch.utils.data.TensorDataset(x_train, y_train)
test_set = torch.utils.data.TensorDataset(x_test, y_test)
train_loader = torch.utils.data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=batch_size, shuffle=False)

################################################################
# training and evaluation
################################################################
# Map every supported --model name to a zero-argument factory, so that only the
# requested network is actually constructed.
_model_factories = {
    "ofno": lambda: FNO2d(n1=args.n1, n2=args.n2, width=args.width, input_dim=1, output_dim=1, mlp_dropout=args.dropout),
    "ufno": lambda: UFNO2d(12, 12, 32, input_dim=1, output_dim=1),
    "ffno": lambda: FNOFactorizedMesh2D(modes_x=12, modes_y=12, width=32, input_dim=1, output_dim=1),
    "fno": lambda: vannilaFNO2d(12, 12, 32, input_dim=1, output_dim=1),
    "fnoall": lambda: vannilaFNO2d(137, 69, 32, input_dim=1, output_dim=1),
    # NOTE(review): H and W are hard-coded to 85, which equals `s` only for the
    # default --downsample 5 - confirm whether they should follow `s`.
    "fnomlp": lambda: FNO2dMLP(n1=args.n1, n2=args.n2, width=args.width, input_dim=1, output_dim=1, H=85, W=85),
}

# Fail with a clear message instead of a NameError on `model` further down.
if args.model not in _model_factories:
    raise ValueError(
        "Unknown --model {!r}; expected one of: {}".format(
            args.model, ", ".join(sorted(_model_factories))
        )
    )
model = _model_factories[args.model]().cuda()

print(count_params(model))
print(args)

optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
# Alternative: torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)

# The scheduler is stepped once per batch, so the cosine period is expressed in
# optimiser steps (epochs * batches per epoch).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs*len(train_loader))
# Alternative: torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=args.lr, epochs=epochs, steps_per_epoch=len(train_loader))
print(model)

# size_average=False: per-batch losses are accumulated by the training loop and
# divided by the number of samples there.
myloss = LpLoss(size_average=False)
# The normaliser decodes GPU tensors during training and evaluation.
y_normalizer.cuda()
# Main loop: one pass over the training set followed by a full evaluation on
# the test set, with per-epoch timing and error reporting.
for ep in range(args.epochs):
    model.train()
    t1 = default_timer()
    train_l2 = 0
    for x, y in train_loader:
        x, y = x.cuda(), y.cuda()
        bsz = x.shape[0]
        optimizer.zero_grad()

        # Training targets were stored encoded, so prediction and target are
        # both decoded and the loss is measured on de-normalised values.
        out = model(x).reshape(bsz, s, s)
        out = y_normalizer.decode(out)
        y = y_normalizer.decode(y)

        loss = myloss(out.view(bsz, -1), y.view(bsz, -1))
        loss.backward()

        # Clip only after backward() has produced this step's gradients;
        # clipping before it acts on the just-cleared gradients and does nothing.
        if args.max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

        optimizer.step()
        # Per-batch scheduler step (T_max was given in optimiser steps).
        scheduler.step()
        train_l2 += loss.item()

    model.eval()
    test_l2 = 0.0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.cuda(), y.cuda()
            bsz = x.shape[0]
            # Test targets were never encoded, so only the output is decoded.
            out = model(x).reshape(bsz, s, s)
            out = y_normalizer.decode(out)

            test_l2 += myloss(out.view(bsz, -1), y.view(bsz, -1)).item()

    # Turn the accumulated losses into per-sample averages.
    train_l2 /= ntrain
    test_l2 /= ntest

    t2 = default_timer()
    print("Epoch : {}, time : {:.5f}, train_err : {:.5f}, test_err : {:.5f}.".format(ep, t2-t1, train_l2, test_l2))
    

    
