#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Jul  8 14:20:43 2020

@author: zw
"""


import os
from skimage import io, transform
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms#, utils
# import torch.optim as optim

import numpy as np
from PIL import Image
import glob

from data_loader import RescaleT
from data_loader import ToTensorLab
from data_loader import CObjDataset

from model import BaseFG, BaseFGM
import os
import cv2
import argparse

import pytorch_iou

# (flag, default, type) for every command-line option, in --help order.
#   --modelname : checkpoint file name inside model_save<tag>/
#   --gpuname   : value exported below as CUDA_VISIBLE_DEVICES
#   --batchsize : DataLoader batch size
#   --level     : selects dataset/FGDataset<level>/
#   --tag       : suffix of the model_save<tag>/ and Prediction<tag>/ dirs
_CLI_OPTIONS = (
    ('--modelname', '11', str),
    ('--gpuname', '0', str),
    ('--batchsize', 2, int),
    ('--level', 'E', str),
    ('--tag', '0', str),
)


def _build_arg_parser():
    """Create the argument parser for this script from ``_CLI_OPTIONS``."""
    # NOTE(review): the description says "Training" although this script
    # runs inference; kept verbatim because it is user-visible --help text.
    arg_parser = argparse.ArgumentParser(description='PyTorch Training')
    for flag, default_value, value_type in _CLI_OPTIONS:
        arg_parser.add_argument(flag, default=default_value, type=value_type)
    return arg_parser


parser = _build_arg_parser()
args = parser.parse_args()

# Restrict which GPUs PyTorch can see; set before any CUDA call is made.
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpuname

def normPRED(d):
    """Min-max normalise a tensor to the range [0, 1].

    The minimum and maximum are taken over *all* elements of ``d``, so a
    batched tensor is normalised jointly, not per sample.

    Args:
        d: tensor of any shape.

    Returns:
        Tensor of the same shape equal to ``(d - min) / (max - min)``.
        When every element is equal (max == min), a zero tensor is returned
        instead of the NaNs that the plain division would produce.
    """
    ma = torch.max(d)
    mi = torch.min(d)
    value_range = ma - mi

    # A constant map (e.g. an all-background prediction) has zero range.
    if value_range == 0:
        return torch.zeros_like(d)

    return (d - mi) / value_range

def save_output(image_name, pred, d_dir):
    """Write one predicted map as ``<d_dir>/<stem>.png`` at the source image size.

    Args:
        image_name: path of the source image. It is read only to obtain its
            height and width; its file name without the final extension
            names the output file.
        pred: prediction tensor. Singleton dimensions are squeezed away, so
            it should reduce to a 2-D map (TODO confirm with callers). It is
            multiplied by 255, so values are presumably in [0, 1].
        d_dir: output directory; created if it does not exist.

    Raises:
        OSError: if OpenCV fails to write the output file.
    """
    # detach() drops autograd history so the tensor can become a numpy array.
    predict_np = pred.squeeze().detach().cpu().numpy() * 255

    image = io.imread(image_name)
    height, width = image.shape[:2]
    # cv2.resize expects the target size as (width, height).
    imo = cv2.resize(predict_np, (width, height))

    # cv2.imwrite does not create directories; it just returns False.
    if d_dir:
        os.makedirs(d_dir, exist_ok=True)

    # splitext keeps inner dots ("a.b.jpg" -> "a.b") and, unlike the old
    # manual split, does not crash on names without an extension.
    stem = os.path.splitext(os.path.basename(image_name))[0]
    out_path = os.path.join(d_dir, stem + '.png')

    # cv2.imwrite signals failure via its return value, not an exception.
    if not cv2.imwrite(out_path, imo):
        raise OSError('could not write prediction to ' + out_path)

def save_out(save_name_list, save_img_list, d_dir):
    """Write predicted maps as PNG files, each resized to its source image.

    Args:
        save_name_list: paths of the source images. Each image is read only
            to obtain its height and width; its file name without the final
            extension becomes the output name (``<stem>.png``).
        save_img_list: 2-D numpy arrays, one per entry of ``save_name_list``
            and in the same order. The inference loop in this file passes
            normalised maps already multiplied by 255.
        d_dir: output directory; created if it does not exist.

    Raises:
        ValueError: if the two lists have different lengths.
        OSError: if OpenCV fails to write an output file.
    """
    if len(save_name_list) != len(save_img_list):
        raise ValueError(
            'save_name_list and save_img_list differ in length: %d vs %d'
            % (len(save_name_list), len(save_img_list)))

    # cv2.imwrite does not create directories; it just returns False.
    if d_dir:
        os.makedirs(d_dir, exist_ok=True)

    for image_name, predict_np in zip(save_name_list, save_img_list):
        image = io.imread(image_name)
        height, width = image.shape[:2]
        # cv2.resize expects the target size as (width, height).
        imo = cv2.resize(predict_np, (width, height))

        # splitext keeps inner dots ("a.b.jpg" -> "a.b") and, unlike the old
        # manual split, does not crash on names without an extension.
        stem = os.path.splitext(os.path.basename(image_name))[0]
        out_path = os.path.join(d_dir, stem + '.png')

        # cv2.imwrite signals failure via its return value, not an exception.
        if not cv2.imwrite(out_path, imo):
            raise OSError('could not write prediction to ' + out_path)

# --------- 1. get image path and name ---------

Model_path = args.modelname

# All paths are relative to the current working directory, not this file.
basedir = os.getcwd()

image_dir = basedir + '/dataset/FGDataset' + args.level + '/TestDataset/ToyFG/Imgs/'
prediction_dir = basedir + '/Prediction' + args.tag + '/'
model_dir = basedir + '/model_save' + args.tag + '/' + Model_path  # checkpoint file path
print(image_dir)

# glob returns files in arbitrary filesystem order; sort for reproducible runs.
img_name_list = sorted(glob.glob(image_dir + '*.jpg'))
img_list_len = len(img_name_list)
if img_list_len == 0:
    print('Warning: no .jpg images found in ' + image_dir)

# cv2.imwrite silently fails (returns False) if the target directory is
# missing, so make sure the prediction directory exists up front.
os.makedirs(prediction_dir, exist_ok=True)

# --------- 2. dataloader ---------
# Test-time pipeline: RescaleT(224) followed by ToTensorLab. Label and edge
# lists are empty because only images are needed for inference.
test_transform = transforms.Compose([
    RescaleT(224),
    ToTensorLab(flag=0, state='custom'),
])

test_codobj_dataset = CObjDataset(
    img_name_list=img_name_list,
    lbl_name_list=[],
    edg_name_list=[],
    transform=test_transform,
)

# NOTE(review): with num_workers > 0 and the 'spawn' start method (default on
# Windows/macOS) workers re-import this module, which has no
# `if __name__ == "__main__":` guard -- confirm the script works there.
test_codobj_dataloader = DataLoader(
    test_codobj_dataset,
    batch_size=args.batchsize,
    shuffle=False,  # fixed order: batch position maps back to img_name_list
    num_workers=2,
)

# --------- 3. model define ---------
print("...Load NetWork...")

net = BaseFGM()

# Load the tensors onto the CPU first: without map_location, a checkpoint
# saved from a GPU run cannot be opened on a CPU-only machine. The model is
# moved to the GPU below when one is available, so GPU runs are unaffected.
net.load_state_dict(torch.load(model_dir, map_location='cpu'))

if torch.cuda.is_available():
    net.cuda()

# Put layers such as dropout / batch-norm into inference mode.
net.eval()

# Index into img_name_list of the next prediction to save. This mapping is
# valid only because the dataloader is built with shuffle=False.
i_img = 0

# --------- 4. inference for each image ---------
# Inference only: disable autograd so no computation graph is built.
with torch.no_grad():
    for i_test, data_test in enumerate(test_codobj_dataloader):

        if i_test % 30 == 0:
            print("[Test OutPut:  %.2f %%]" % (i_test * args.batchsize * 100 / img_list_len), end="\r", flush=True)

        inputs_test = data_test['image'].type(torch.FloatTensor)
        if torch.cuda.is_available():
            inputs_test = inputs_test.cuda()

        D = net(inputs_test)
        # Channel 0 of D[0][0]. The [:, 0, :, :] index requires D[0][0] to be
        # 4-D, presumably (batch, channel, H, W) -- TODO confirm vs BaseFGM.
        pred = D[0][0][:, 0, :, :]
        del D

        batch_names = []
        batch_maps = []
        # Iterate over the batch dimension directly. The previous squeeze()
        # turned a final batch of one image into a 2-D array, and the 3-way
        # shape unpacking then raised ValueError.
        for single_pred in pred:
            # Normalise each map on its own so that its values do not depend
            # on which other images happen to share the batch.
            batch_maps.append(normPRED(single_pred).cpu().numpy() * 255)
            batch_names.append(img_name_list[i_img])
            i_img = i_img + 1

        # Write this batch now instead of keeping every map in memory.
        save_out(batch_names, batch_maps, prediction_dir)
    