import os
import sys

# Directory containing this file. It is pushed to the front of sys.path so the
# `models` package imported below is resolved from next to this file first.
HERE = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, HERE)

import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torch.nn.functional as F
import torch.backends.cudnn as cudnn

# Evict any already-imported `models` / `models.submodule` entries so the
# imports below are re-resolved against the package next to this file instead
# of a previously cached package of the same name.
# NOTE(review): other cached submodules (e.g. `models.psmnet`, `models.gwcnet`)
# are not evicted here, and afterwards `models` stays registered as this local
# package — confirm this cannot clash with other code importing a `models` pkg.
if 'models' in sys.modules :
	sys.modules.pop('models')
if 'models.submodule' in sys.modules :
	sys.modules.pop('models.submodule')

from models.psmnet import PSMNet
from models.gwcnet import GwcNet

import numpy as np

# Undo the sys.path insertion made above.
# NOTE(review): assumes sys.path[0] is still HERE, i.e. none of the imports
# above prepended their own entry to sys.path.
sys.path.pop(0)

class Bunch(object):
    """Simple attribute container built from a mapping.

    Every key/value pair of ``adict`` becomes an attribute of the instance,
    e.g. ``Bunch({'a': 1}).a == 1``.
    """

    def __init__(self, adict):
        # vars(self) is the instance namespace; merge the given entries into it.
        namespace = vars(self)
        namespace.update(adict)

def getModel(pretrained, stage=2):
    """Build a stereo network, load its weights and return an inference closure.

    Args:
        pretrained: Path to a checkpoint saved with ``torch.save`` whose
            ``'model'`` entry is a state dict for the network. If the path
            contains ``'gwcnet-c'`` a ``PSMNet`` is built, otherwise a ``GwcNet``.
        stage: Which cascade stage's prediction the closure returns; it reads
            ``outputs["stage{stage}"]["pred"]`` from the model output.

    Returns:
        ``testFunc(imgLeft, imgRight)``: runs the model (on CUDA, in eval mode,
        with autograd disabled) on one stereo pair and returns the selected
        stage's ``"pred"`` entry.
    """
    # NOTE(review): 'gwcnet-c' checkpoints map to PSMNet; kept as-is, confirm intended.
    ModelClass = PSMNet if 'gwcnet-c' in pretrained else GwcNet

    model = ModelClass(maxdisp=192,
                       ndisps=[48, 24],
                       disp_interval_pixel=[4, 1],
                       cr_base_chs=[32, 32, 16],
                       grad_method='detach',
                       using_ns=True,
                       ns_size=13)
    model.cuda()

    # Deserialize onto the CPU: load_state_dict copies the values into the
    # model's (CUDA) parameters anyway. This avoids a failure when the
    # checkpoint was saved from a GPU index absent on this machine, and avoids
    # holding a second copy of the weights on the GPU.
    checkpoint = torch.load(pretrained, map_location='cpu')
    model.load_state_dict(checkpoint['model'])

    model = nn.DataParallel(model)
    model.eval()

    stage_key = "stage{}".format(stage)

    def testFunc(imgLeft, imgRight):
        # Inference only: without no_grad every call would build an autograd
        # graph and keep all intermediate activations alive on the GPU.
        with torch.no_grad():
            # Prepend a batch dimension to each view and move it to the GPU.
            # (Presumably imgLeft/imgRight are unbatched CxHxW tensors — confirm.)
            outputs = model(imgLeft[np.newaxis, ...].cuda(),
                            imgRight[np.newaxis, ...].cuda())
        return outputs[stage_key]["pred"]

    return testFunc
