
import torch
from torch.optim import Adam
from torch.utils.data import DataLoader

from models.StereoNet_single import StereoNet

from dataloader.exrDatasetLoader import exrImagePairDataset

import argparse as args
import collections

def parseArguments(argv=None):
	"""Parse the command line options of the finetuning script.

	Args:
		argv: Optional list of argument strings. ``None`` (the default)
			lets argparse read ``sys.argv[1:]``.

	Returns:
		The ``argparse.Namespace`` holding the parsed options.
	"""
	# NOTE: ``args`` is this file's import alias for the argparse module. The
	# parsed namespace is deliberately NOT stored under that same name.
	parser = args.ArgumentParser(description='Finetune ActiveStereoNet on our dataset')
	
	parser.add_argument("--traindata", help="Path to the training images")
	#parser.add_argument("--validationdata", help="Path to the validation images")
	parser.add_argument("--numepochs", default=10, type=int, help="Number of epochs to run")
	parser.add_argument("--batchsize", default=4, type=int, help="Batch size")
	parser.add_argument("--numworkers", default=4, type=int, help="Number of workers threads used for loading dataset")
	parser.add_argument('--learningrate', default=1e-2, type=float, help="Learning rate for the optimizer")
	parser.add_argument('--ramcache', action="store_true", help="cache the whole dataset into ram. Do this only if you are certain it can fit.")
	
	parser.add_argument('-p', '--pretrained', default='./pretrained/ps_sceneflow_checkpoint.pth', help="Pretrained weights")
	parser.add_argument('-o', '--output', default='./pretrained/sn_finetuned_sim_stereo.pth', help="Output path for the finetuned weights")
	
	return parser.parse_args(argv)


def buildOptimizer(parameters, lr=1e-2):
	"""Return an Adam optimizer over ``parameters`` with learning rate ``lr``."""
	return Adam(parameters, lr=lr)


def getLoss(c=2):
	"""Build the regression loss used for finetuning.

	Args:
		c: Scale dividing the disparity error before it is penalised.

	Returns:
		A callable ``loss(d, gt_d)`` computing
		``mean(sqrt(((d - gt_d) / c) ** 2 + 1) - 1)`` as a scalar tensor.
	"""
	def scaledLoss(d, gt_d):
		return torch.mean(torch.sqrt(torch.square((d - gt_d)/c) + 1) - 1)
	
	return scaledLoss


def trainEpoch(model, loader, optimizer, loss, device="cuda"):
	"""Run one optimisation pass over ``loader``.

	Args:
		model: Callable taking ``(left, right)`` image batches; the first
			element of its return value is used as the predicted disparity.
		loader: Iterable of dict samples with the keys ``'frameLeft'``,
			``'frameRight'`` and ``'trueDisparity'``.
		optimizer: Optimizer stepping the parameters of ``model``.
		loss: Callable ``loss(prediction, ground_truth)`` returning a scalar
			tensor.
		device: Device the batches are moved to before the forward pass.

	Returns:
		The mean of the per-batch loss values as a float, or ``nan`` when
		``loader`` yields no batch.
	"""
	aggr = 0.
	count = 0
	
	for sampl in loader:
		
		imgLeft = sampl['frameLeft'].to(device)
		imgRight = sampl['frameRight'].to(device)
		imgGtDisp = sampl['trueDisparity'].to(device)
		
		# Replicate each frame three times along dim 1.
		# NOTE(review): presumably single-channel (N, 1, H, W) frames fed to
		# a network expecting 3 input channels -- confirm against the dataset.
		imgLeft = torch.cat((imgLeft, imgLeft, imgLeft), dim=1)
		imgRight = torch.cat((imgRight, imgRight, imgRight), dim=1)
		
		disp = model(imgLeft, imgRight)[0]
		l = loss(disp, imgGtDisp)
		
		optimizer.zero_grad()
		l.backward()
		optimizer.step()
		
		lval = l.item()
		
		#print(f"\tbatch loss = {lval}")
		
		aggr += lval
		count += 1
	
	# Guard against an empty loader instead of raising ZeroDivisionError.
	return aggr/count if count else float('nan')


def main(argv=None):
	"""Finetune StereoNet from pretrained weights and save the result.

	Args:
		argv: Optional list of argument strings forwarded to
			``parseArguments``; ``None`` reads ``sys.argv[1:]``.
	"""
	opts = parseArguments(argv)
	
	# Wrap before loading: the state dict is loaded into the DataParallel
	# wrapper, exactly as the saved checkpoint at the end is produced from it.
	model = torch.nn.DataParallel(StereoNet(k=3, r=4)).cuda()
	
	# Deserialize on CPU so the checkpoint does not depend on the GPU layout
	# it was saved with; load_state_dict copies the values into the model's
	# own (CUDA) parameters.
	checkpoint = torch.load(opts.pretrained, map_location="cpu")
	model.load_state_dict(checkpoint['state_dict'])
	
	dats = exrImagePairDataset(
		imagedir = opts.traindata,
		left_nir_channel = 'Left.SimulatedNir.A',
		right_nir_channel = 'Right.SimulatedNir.A',
		cache = False,
		ramcache = opts.ramcache,
		direction = 'l2r',
	)
	
	datl = DataLoader(
		dats,
		batch_size=opts.batchsize,
		shuffle=True,
		num_workers=opts.numworkers,
	)
	
	optimizer = buildOptimizer(model.parameters(), lr=opts.learningrate)
	loss = getLoss()
	
	for ep in range(opts.numepochs):
		avgLoss = trainEpoch(model, datl, optimizer, loss, device="cuda")
		print(f"Epoch {ep}: avg loss = {avgLoss}")
	
	torch.save({"state_dict" : model.state_dict()}, opts.output)


if __name__ == "__main__":
	main()
