# Copyright (c) 2017, Oren Kraus All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation and/or
# other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
import tensorflow as tf
import nn_layers
import h5py
import numpy as np
from PIL import Image
import argparse
import cellDataClass as dataClass # NO QUEUE
import preprocess_images as procIm # NO QUEUE
import copy

import argparse

# Command-line options: which pretrained checkpoint to restore, plus an output folder.
parser = argparse.ArgumentParser(description='Evaluate DeepLoc model on COOS7 test sets')
parser.add_argument("-l", "--logdir", action="store", dest="logdir",
                    help="path prefix of the pretrained model checkpoint to restore",
                    default='./logs_COOS7/model.ckpt-4500')
parser.add_argument("-o", "--output-folder", action="store", dest="outputdir", help="directory to store results",
                    default='./output_figures')
args = parser.parse_args()
# A single pre-formatted argument prints the same text under Python 2 and 3.
print('log dir: %s out dir: %s' % (args.logdir, args.outputdir))

locNetCkpt = args.logdir
# NOTE(review): output_dir is not read by any code visible in this script;
# per-sample results are written under ./COOS7_features/ instead.
output_dir = args.outputdir

# tf.train.import_meta_graph needs the '<checkpoint>.meta' graph definition file.
if not os.path.exists(locNetCkpt + '.meta'):
    raise NameError('please download pretrained model using download_datasets.sh')

def getInitImage(inputCell, low=0.1, high=99.9):
    """Crop a 2-pixel border off *inputCell* and contrast-stretch each channel.

    Only the first two channels are kept. In each channel, the ``low``
    percentile is shifted to 0 and the ``high`` percentile is scaled to 1.
    Values outside that range are not clipped. A channel whose two
    percentiles coincide (e.g. a flat channel) is only shifted, not divided,
    so the result never contains NaN/inf from a zero range.

    Args:
        inputCell: array indexed as (rows, cols, channels) with at least two
            channels. A 64x64 input yields a 60x60 output.
        low: lower stretch percentile (default 0.1).
        high: upper stretch percentile (default 99.9).

    Returns:
        float64 array of shape (rows - 4, cols - 4, 2).
    """
    # Copy as float64 so the in-place stretch below never touches caller data.
    outData = np.array(inputCell[2:-2, 2:-2, :2], dtype=np.float64)
    for chan in range(outData.shape[2]):
        channel = outData[:, :, chan]  # view: in-place ops update outData
        p_low = np.percentile(channel, low)
        p_high = np.percentile(channel, high)
        channel -= p_low
        # Guard against a zero range, which would otherwise produce NaN/inf.
        if p_high != p_low:
            channel /= (p_high - p_low)

    return outData

def accuracy_numpy(y_pred, y_lab):
    """Return the fraction of rows whose arg-max class matches between two arrays.

    Args:
        y_pred: per-class scores indexed as (sample, class), e.g. softmax output.
        y_lab: labels with the same layout, e.g. one-hot rows.

    Returns:
        Mean of the per-row matches, a float in [0, 1].
    """
    predicted_class = np.argmax(y_pred, axis=1)
    true_class = np.argmax(y_lab, axis=1)
    hits = predicted_class == true_class
    return np.mean(hits)


# Rebuild the exported graph in its own tf.Graph and restore the weights.
loc = tf.Graph()
with loc.as_default():
    loc_saver = tf.train.import_meta_graph(locNetCkpt + '.meta')
locSession = tf.Session(graph=loc)
loc_saver.restore(locSession, locNetCkpt)

# Fetch the graph endpoints by the names they were exported under:
# the softmax output, the image input and the is_training flag.
pred_loc, input_loc, is_training_loc = [
    loc.get_tensor_by_name(tensor_name)
    for tensor_name in (u'softmax:0', u'input:0', u'is_training:0')
]

testsets = ['test1', 'test2', 'test3', 'test4']

# Evaluation settings shared by every test set (loop-invariant).
cropSize = 60
batchSize = 1
stretchLow = 0.1  # stretch channels lower percentile
stretchHigh = 99.9  # stretch channels upper percentile
imSize = 64
numClasses = 7
numChan = 2
numCrops = 5  # crops indexed per image in processedBatch[:, crop, ...]

featuresDir = "./COOS7_features/"
# open(..., "a") below fails if the directory does not exist, so create it.
if not os.path.isdir(featuresDir):
    os.makedirs(featuresDir)

for testset in testsets:
    testHdf5 = "./datasets/COOS7_" + testset + ".hdf5"
    outfile = featuresDir + testset + ".txt"

    data = dataClass.Data(testHdf5, ['data', 'Index'], batchSize)
    numberDataPoints = data.stopInd - data.startInd

    # Open the results file once per test set instead of once per sample.
    # Append mode is kept, so re-running adds lines after earlier results.
    with open(outfile, "a") as output:
        for _ in range(numberDataPoints):
            crop_list = np.zeros((data.batchSize, numCrops, numClasses))
            batch = data.getBatch()
            processedBatch = procIm.preProcessTestImages(batch['data'],
                                                         imSize, cropSize, numChan,
                                                         rescale=False, stretch=True,
                                                         means=None, stds=None,
                                                         stretchLow=stretchLow,
                                                         stretchHigh=stretchHigh)
            # Softmax scores for each crop; they are averaged for the final call.
            # The assignment copies into crop_list, so no defensive copy is needed.
            for crop in range(numCrops):
                images = processedBatch[:, crop, :, :, :]
                crop_list[:, crop, :] = locSession.run(
                    pred_loc, feed_dict={input_loc: images, is_training_loc: False})

            mean_crops = np.mean(crop_list, 1)
            curAcc = accuracy_numpy(mean_crops, batch['Index'])
            # Per-row argmax for both labels and predictions (identical to a
            # flattened argmax when batchSize == 1); record the first sample.
            trueClass = np.argmax(batch['Index'], 1)[0]
            predClass = np.argmax(mean_crops, 1)[0]
            # One tab-separated line per sample: true class, predicted class, accuracy.
            output.write("\t".join([str(trueClass), str(predClass), str(curAcc)]) + "\n")