# Copyright (c) 2017, Oren Kraus All rights reserved.
#
# Redistribution and use in source and binary forms, with or without modification,
# are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation and/or
# other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its contributors
# may be used to endorse or promote products derived from this software without
# specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
# ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
# (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
# LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
# ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
# SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import os
os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

import tensorflow as tf
import nn_layers
import h5py
import numpy as np
from PIL import Image
import argparse


import argparse
# Command-line options: the pretrained checkpoint prefix to restore and an
# output folder.
parser = argparse.ArgumentParser(description='Visualize DeepLoc model on COOS7')
parser.add_argument("-l", "--logdir", action="store", dest="logdir",
                    help="directory to save models",
                    default='./logs_COOS7/model.ckpt-4500')
parser.add_argument("-o", "--output-folder", action="store", dest="outputdir",
                    help="directory to store results",
                    default='./output_figures')
args = parser.parse_args()
print('log dir:', args.logdir, 'out dir:', args.outputdir)

# --logdir is used as a checkpoint prefix: the graph metadata is expected at
# "<prefix>.meta" and the prefix itself is passed to Saver.restore.
locNetCkpt = args.logdir
# NOTE(review): output_dir is not referenced by the feature-extraction loop in
# this file, which writes to a hard-coded ./COOS7_deeploc_features/ path.
output_dir = args.outputdir

if not os.path.exists(locNetCkpt + '.meta'):
    # A missing file is an I/O condition (IOError is OSError on Python 3);
    # NameError is reserved for unbound identifiers. Include the path so the
    # user can see which checkpoint was looked for.
    raise IOError('please download pretrained model using download_datasets.sh'
                  ' (missing ' + locNetCkpt + '.meta)')


#################
# DeepLoc MODEL #
#################

# Rebuild the DeepLoc graph: eight 3x3 conv layers in three pooled stages,
# two 512-unit fully connected layers, and a 7-unit output layer.
# NOTE(review): presumably nn_layers scopes its variables under the names
# passed here ('conv_1' ... 'final_layer'), which Saver.restore matches
# against the checkpoint — do not rename or reorder without confirming.
is_training = tf.placeholder(tf.bool, [], name='is_training') # fed False at inference; presumably switches nn_layers' batch norm mode
inputs = tf.placeholder('float32', shape = [60,60,2], name='inputs')  # one 60x60 image with 2 channels; reshaped to a batch of 1 below
# NOTE(review): `labels` is not used in the visible code, and its 19 columns do
# not match the 7-unit final layer — likely left over from another model; confirm.
labels = tf.placeholder('float32', shape = [None,19], name ='labels')

input_reshape = tf.reshape(inputs, [1, 60, 60 ,2])
# conv_layer arguments appear to be (input, kernel_h, kernel_w, in_channels,
# out_channels, stride, name) — confirm against nn_layers.
conv1 = nn_layers.conv_layer(input_reshape, 3, 3, 2, 64, 1, 'conv_1', is_training=is_training)
conv2 = nn_layers.conv_layer(conv1, 3, 3, 64, 64, 1, 'conv_2', is_training=is_training)
pool1 = nn_layers.pool2_layer(conv2, 'pool1')
conv3 = nn_layers.conv_layer(pool1, 3, 3, 64, 128, 1, 'conv_3', is_training=is_training)
conv4 = nn_layers.conv_layer(conv3, 3, 3, 128, 128, 1, 'conv_4', is_training=is_training)
pool2 = nn_layers.pool2_layer(conv4, 'pool2')
conv5 = nn_layers.conv_layer(pool2, 3, 3, 128, 256, 1, 'conv_5', is_training=is_training)
conv6 = nn_layers.conv_layer(conv5, 3, 3, 256, 256, 1, 'conv_6', is_training=is_training)
conv7 = nn_layers.conv_layer(conv6, 3, 3, 256, 256, 1, 'conv_7', is_training=is_training)
conv8 = nn_layers.conv_layer(conv7, 3, 3, 256, 256, 1, 'conv_8', is_training=is_training)
pool3 = nn_layers.pool2_layer(conv8, 'pool3')
# The reshape below assumes pool3 is 8x8x256 per sample. 60 -> 30 -> 15 -> 8
# implies pool2_layer rounds up (e.g. SAME padding) — confirm in nn_layers.
pool3_flat = tf.reshape(pool3, [-1, 8 * 8 * 256])
fc_1 = nn_layers.nn_layer(pool3_flat, 8 * 8 * 256, 512, 'fc_1', act=tf.nn.relu, is_training=is_training)
fc_2 = nn_layers.nn_layer(fc_1, 512, 512, 'fc_2', act=tf.nn.relu, is_training=is_training)
# 7-unit output without activation (logits); the visible feature-extraction
# code only fetches fc_1 and fc_2.
lastAct = nn_layers.nn_layer(fc_2, 512, 7, 'final_layer', act=None, is_training=is_training)


# Open a session and give every global variable an initial value before the
# pretrained weights are loaded over them.
sess = tf.Session()
init_op = tf.global_variables_initializer()
sess.run(init_op, feed_dict={is_training: False})

# Overwrite all global variables with the values stored under the checkpoint
# prefix given on the command line.
saver = tf.train.Saver(var_list=tf.global_variables())
saver.restore(sess, locNetCkpt)

def getInitImage(inputCell):
    """Crop and contrast-stretch a two-channel cell image for DeepLoc.

    Args:
        inputCell: array indexed as (row, col, channel). Channels 0 and 1 are
            used; any further channels are ignored. A 2-pixel border is
            trimmed from each spatial edge, so a 64x64 spatial size is
            expected to fill the 60x60 output (other sizes generally fail to
            broadcast into the output buffer).

    Returns:
        float64 array of shape (60, 60, 2). Each channel is shifted so its
        0.1th percentile maps to 0 and divided by the 0.1-99.9 percentile span,
        so its 99.9th percentile maps to 1. Values outside that range are not
        clipped. A channel whose span is zero (e.g. a blank channel) is only
        shifted, not divided, so it never becomes NaN/inf.
    """
    outData = np.zeros((60, 60, 2))
    for chan in range(2):
        # Trim the 2-pixel border of this channel into the output buffer.
        outData[:, :, chan] = inputCell[2:-2, 2:-2, chan]
        p_low = np.percentile(outData[:, :, chan], 0.1)
        p_high = np.percentile(outData[:, :, chan], 99.9)
        span = p_high - p_low
        outData[:, :, chan] -= p_low
        # Dividing by a zero span would produce NaN/inf features downstream.
        if span > 0:
            outData[:, :, chan] /= span

    return outData

# For every "<stem>_protein.tif" / "<stem>_nucleus.tif" image pair in each
# COOS7 split, append one tab-separated row per layer
# ("<class dir>\t<stem>\t<feat>\t...\t<feat>\t") to
# ./COOS7_deeploc_features/<layer name>/<split>.txt.
testsets = ['train', 'test1', 'test2', 'test3', 'test4']
outpath = "./COOS7_deeploc_features/"
layers = [(fc_1, "fc_1"), (fc_2, "fc_2")]
layer_tensors = [tensor for tensor, _ in layers]

# Create the per-layer output folders once instead of re-checking per image.
for _, l_name in layers:
    if not os.path.exists(outpath + l_name):
        os.makedirs(outpath + l_name)

for testset in testsets:
    # NOTE(review): "COO7_images" differs from the "COOS7" spelling used for the
    # output folder; kept as-is since it must match the dataset folder on disk
    # — confirm.
    datapath = "./COO7_images/" + testset + "/"
    # Keep one append-mode handle per layer open for the whole split instead of
    # reopening the file for every image; the finally clause always closes them.
    outputs = []
    try:
        for _, l_name in layers:
            outputs.append(open(outpath + l_name + "/" + testset + ".txt", "a"))

        # Sorted listings make the row order reproducible across runs.
        for class_dir in sorted(os.listdir(datapath)):
            class_path = datapath + class_dir + "/"
            # Skip stray files (e.g. .DS_Store); os.listdir would raise on them.
            if not os.path.isdir(class_path):
                continue
            for image in sorted(os.listdir(class_path)):
                if not image.endswith("_protein.tif"):
                    continue
                print("Evaluating " + image)
                gfp = np.array(Image.open(class_path + image)).astype(np.float32)
                nucleus_file = image.replace("_protein.tif", "_nucleus.tif")
                rfp = np.array(Image.open(class_path + nucleus_file)).astype(np.float32)
                input_image = getInitImage(np.stack((gfp, rfp), axis=-1))

                # One forward pass fetches every requested layer. The original
                # ran the network once per layer; with is_training=False the
                # activations should be the same (confirm if nn_layers has
                # stochastic inference ops).
                activations = sess.run(layer_tensors,
                                       {inputs: input_image, is_training: False})

                name = image.rsplit("_", 1)[0]
                row_prefix = class_dir + "\t" + name + "\t"
                for layer_out, output in zip(activations, outputs):
                    prediction = np.squeeze(layer_out)
                    # Every feature, including the last, is followed by a tab.
                    output.write(row_prefix
                                 + "".join(str(feat) + "\t" for feat in prediction)
                                 + "\n")
    finally:
        for output in outputs:
            output.close()
