import numpy as np
import tensorflow as tf
from edgeml.graph.rnn import FastGRNNCell, FastRNNCell

class DataIngest:
    '''
    Feedable tf.data input pipeline that yields (x, y) mini-batches.

    The data matrices, the batch size and the number of epochs are all fed
    through placeholders when the dataset initializer is run (see
    runInitializer), so one graph can be re-initialized with new data.

    x_dim, y_dim: Shapes given to the X and Y placeholders. Only the trailing
        dimensions (index 1 onwards) are compared with the data handed to
        runInitializer.
    prefetchNum: Buffer size passed to Dataset.prefetch.
    '''

    def __init__(self, x_dim, y_dim, prefetchNum=5):
        self.scope = 'divide-conquer/'
        self.x_dim = x_dim
        self.y_dim = y_dim
        # Use the caller supplied buffer size rather than a fixed constant.
        self.prefetchNum = prefetchNum
        self.graph = None

        self.X = None
        self.Y = None
        self.batchSize = None
        self.numEpochs = None
        self.dataset_init = None
        self.x_batch, self.y_batch = None, None
        self.__graphCreated = False

    def createPipeline(self):
        '''
        Create the placeholders, the dataset and its iterator under the name
        scope '<scope>input-pipeline/' and remember the resulting tensors.

        The next-batch tensors are also added to the graph collections
        'next-x-batch' and 'next-y-batch' so that restoreFromGraph can find
        them again. May only be called once per instance.
        '''
        assert self.__graphCreated is False
        scope = self.scope + 'input-pipeline/'
        with tf.name_scope(scope):
            X = tf.placeholder(tf.float32, self.x_dim, name='inpX')
            Y = tf.placeholder(tf.float32, self.y_dim, name='inpY')
            batchSize = tf.placeholder(tf.int64, name='batch-size')
            numEpochs = tf.placeholder(tf.int64, name='num-epochs')
            dataset_x_target = tf.data.Dataset.from_tensor_slices(X)
            dataset_y_target = tf.data.Dataset.from_tensor_slices(Y)
            couple = (dataset_x_target, dataset_y_target)
            # Epochs are realised by repeating the zipped dataset before
            # batching.
            ds_target = tf.data.Dataset.zip(couple).repeat(numEpochs)
            ds_target = ds_target.batch(batchSize)
            ds_target = ds_target.prefetch(self.prefetchNum)
            ds_iterator_target = tf.data.Iterator.from_structure(
                ds_target.output_types, ds_target.output_shapes)
            # Running this op (with the placeholders fed) points the iterator
            # at the freshly fed data.
            ds_init_target = ds_iterator_target.make_initializer(
                ds_target, name='dataset-init')
            x_batch, y_batch = ds_iterator_target.get_next()
            tf.add_to_collection('next-x-batch', x_batch)
            tf.add_to_collection('next-y-batch', y_batch)
        self.X = X
        self.Y = Y
        self.batchSize = batchSize
        self.numEpochs = numEpochs
        self.dataset_init = ds_init_target
        self.x_batch, self.y_batch = x_batch, y_batch
        self.__graphCreated = True

    def restoreFromGraph(self, graph):
        '''
        Look up, by name, the tensors/ops that createPipeline produced in
        `graph` and attach them to this instance.

        Fails with an AssertionError if the 'next-x-batch' / 'next-y-batch'
        collections do not hold exactly one entry each.
        '''
        assert self.__graphCreated is False
        scope = self.scope + 'input-pipeline/'
        self.X = graph.get_tensor_by_name(scope + "inpX:0")
        self.Y = graph.get_tensor_by_name(scope + "inpY:0")
        self.batchSize = graph.get_tensor_by_name(scope + "batch-size:0")
        self.numEpochs = graph.get_tensor_by_name(scope + "num-epochs:0")
        self.dataset_init = graph.get_operation_by_name(scope + "dataset-init")
        self.x_batch = graph.get_collection('next-x-batch')
        self.y_batch = graph.get_collection('next-y-batch')
        msg = 'More than one tensor named next-x-batch/next-y-batch. '
        msg += 'Are you not resetting your graph?'
        assert len(self.x_batch) == 1, msg
        assert len(self.y_batch) == 1, msg
        self.x_batch = self.x_batch[0]
        self.y_batch = self.y_batch[0]
        self.__graphCreated = True

    def __call__(self, graph=None):
        '''
        Return the (x_batch, y_batch) tensors, building the pipeline on the
        first call.

        graph: When None a new pipeline is created; otherwise the pipeline is
            restored by name from `graph`. Ignored (apart from being stored)
            once the pipeline exists.
        '''
        self.graph = graph
        if self.__graphCreated:
            return self.x_batch, self.y_batch
        if self.graph is None:
            self.createPipeline()
        else:
            self.restoreFromGraph(self.graph)
        return self.x_batch, self.y_batch

    def runInitializer(self, sess, x_data, y_data, batchSize, numEpochs):
        '''
        This method is used to ingest data by the dataset API. Call this method
        with the data matrices after the graph has been initialized.
        x_data, y_data, batchSize: Self explanatory.
        numEpochs: The Tensorflow dataset API implements iteration over epochs
            by appending the data to itself numEpochs times and then iterating
            over the resulting data as if it was a single data set.
        '''
        msg = "Graph not created, please invoke __call__"
        assert self.__graphCreated is True, msg
        # Wrap the dims in a 1-tuple: a bare tuple on the right of '%' would
        # be unpacked as the argument list and raise a TypeError.
        msg = 'X shape should be %r' % (self.x_dim,)
        assert np.array_equal(x_data.shape[1:], self.x_dim[1:]), msg
        msg = 'X and Y should have same first dimension'
        assert y_data.shape[0] == x_data.shape[0], msg
        msg = 'Y shape should be %r' % (self.y_dim,)
        assert np.array_equal(y_data.shape[1:], self.y_dim[1:]), msg
        feed_dict = {
            self.X: x_data,
            self.Y: y_data,
            self.batchSize: batchSize,
            self.numEpochs: numEpochs
        }
        assert self.dataset_init is not None, 'Internal error!'
        sess.run(self.dataset_init, feed_dict=feed_dict)

class DNN_Network:
    '''
    Fully connected network: input -> hidden1 -> hidden2 -> output.

    Both hidden layers are followed by a softmax; the output layer is left
    linear (logits). All parameters are initialised from tf.random_normal.
    '''

    def __init__(self, numInput, hidden1, hidden2, numOutput):
        self.numInput, self.hidden1 = numInput, hidden1
        self.hidden2, self.numOutput = hidden2, numOutput

        # Populated by the forward pass.
        self.__weights = self.__biases = self.__logits = None

    def getLogits(self):
        '''Logits tensor of the last forward pass (None before it).'''
        return self.__logits

    def getWeightsAndBiases(self):
        '''
        Returns (weights, biases): dicts keyed 'w1'..'w3' and 'b1'..'b3', or
        (None, None) before the forward pass has been built.
        '''
        return self.__weights, self.__biases

    def __forward(self, x_batch):
        '''Create the variables and the three layer graph on top of x_batch.'''
        sizes = (self.numInput, self.hidden1, self.hidden2, self.numOutput)
        # Initial values first (weights, then biases), variables afterwards.
        w_init = [tf.random_normal(shape=[n_in, n_out])
                  for n_in, n_out in zip(sizes[:-1], sizes[1:])]
        b_init = [tf.random_normal(shape=[n_out]) for n_out in sizes[1:]]
        weights = {key: tf.Variable(val, name=key)
                   for key, val in zip(('w1', 'w2', 'w3'), w_init)}
        biases = {key: tf.Variable(val, name=key)
                  for key, val in zip(('b1', 'b2', 'b3'), b_init)}

        act1 = tf.matmul(x_batch, weights['w1']) + biases['b1']
        act1 = tf.nn.softmax(act1)
        act2 = tf.matmul(act1, weights['w2']) + biases['b2']
        act2 = tf.nn.softmax(act2)
        logits = tf.matmul(act2, weights['w3']) + biases['b3']

        self.__weights, self.__biases = weights, biases
        self.__logits = logits
        return logits

    def __call__(self, x_batch):
        '''Build the network on x_batch and return the logits tensor.'''
        return self.__forward(x_batch)

class LSTM_Network:
    '''
    Single layer LSTM (BasicLSTMCell wrapped in a DropoutWrapper) followed by
    a linear projection to numOutput.

    earlyMode: When False the projection is applied to the last time step
        only; when True it is applied to the outputs of every time step.
    graph: When given, nothing is built; logits, 'w1', 'b1' (and 'keep_prob'
        if dropout is used) are fetched from this graph by name instead.
    '''

    def __init__(self, numTimeSteps, numInput, numHidden, numOutput,
                 useDropout=False, graph=None, earlyMode=False):
        self.numTimeSteps = numTimeSteps
        self.numInput = numInput
        self.numHidden = numHidden
        self.numOutput = numOutput
        self.useDropout = useDropout
        self.graph = graph
        self.earlyMode = earlyMode

        # 1.0 means "no dropout"; None is replaced by a placeholder (or a
        # restored tensor) later on.
        self.keep_prob = 1.0 if useDropout is False else None
        self.__weights = None
        self.__biases = None
        self.__logits = None
        self.__graphCreated = False

    def getLogits(self):
        '''Logits tensor, or None while the graph has not been created.'''
        return self.__logits

    def getWeightsAndBiases(self):
        '''Returns the (weights, biases) lists; (None, None) before creation.'''
        return self.__weights, self.__biases

    def __forward(self, x_batch):
        '''Build the LSTM and projection on x_batch; returns the logits.'''
        assert self.__graphCreated is False
        if self.keep_prob is not None:
            assert self.useDropout is False
            self.keep_prob = 1.0
        else:
            assert self.useDropout == True
            self.keep_prob = tf.placeholder(tf.float32, name='keep_prob')

        fc_w_init = tf.random_normal(shape=[self.numHidden, self.numOutput])
        fc_b_init = tf.random_normal(shape=[self.numOutput])
        fc_w = tf.Variable(fc_w_init, name='w1')
        fc_b = tf.Variable(fc_b_init, name='b1')

        # One tensor per time step, split along axis 1.
        steps = tf.unstack(x_batch, self.numTimeSteps, 1)
        lstm_cell = tf.contrib.rnn.BasicLSTMCell(self.numHidden,
                                                 forget_bias=1.0,
                                                 name='cell')
        # Dropout is applied on the cell input as well as its output, see
        # https://github.com/keras-team/keras/issues/96
        wrapped_cell = tf.contrib.rnn.DropoutWrapper(
            lstm_cell,
            input_keep_prob=self.keep_prob,
            output_keep_prob=self.keep_prob)
        # outputs: list with one [-1, HIDDEN_DIM] tensor per time step.
        outputs, _ = tf.contrib.rnn.static_rnn(wrapped_cell, steps,
                                               dtype=tf.float32)
        if self.earlyMode:
            # [-1, 1, HIDDEN_DIM] each -> [-1, NUM_TS, HIDDEN_DIM] joined.
            per_step = [tf.expand_dims(step_out, axis=1)
                        for step_out in outputs]
            all_steps = tf.concat(per_step, axis=1)
            projected = tf.tensordot(all_steps, fc_w, axes=1)
        else:
            projected = tf.matmul(outputs[-1], fc_w)
        logits = tf.add(projected, fc_b, name='logits')
        self.__logits = logits

        kernel, bias = lstm_cell.variables
        self.__weights = [kernel, fc_w]
        self.__biases = [bias, fc_b]
        self.__graphCreated = True
        return logits

    def __restore(self, graph):
        '''Fetch logits, b1, w1 (and keep_prob when used) from graph by name.'''
        assert self.__graphCreated is False
        self.__logits = graph.get_tensor_by_name('logits:0')
        fc_b = graph.get_tensor_by_name('b1:0')
        fc_w = graph.get_tensor_by_name('w1:0')
        if self.useDropout:
            self.keep_prob = graph.get_tensor_by_name('keep_prob:0')
        # Only the projection parameters are recovered here.
        self.__weights = [fc_w]
        self.__biases = [fc_b]
        self.__graphCreated = True

    def __call__(self, x_batch):
        '''
        Return the logits, creating them on first use: built from x_batch
        when no graph was supplied, otherwise restored from that graph.
        '''
        if not self.__graphCreated:
            if self.graph is None:
                self.__forward(x_batch)
            else:
                self.__restore(self.graph)
        return self.__logits

class StackedLSTM_Network:
    def __init__(self, numBricks, numBrickTimeSteps, numInput, numHidden0,
                 numHidden1, numOutput, useDropout=False, initializer=None):
        '''
        A two layer divide-and-conquer LSTM network.

        The input must be a rank-4 tensor. It is split along axis 1 into
        numBricks bricks and every brick is split along its axis 1 into
        numBrickTimeSteps steps, i.e. the layout is
            [-1, NUM_BRICKS, NUM_BRICK_TIMESTEPS, <features>]
        (the feature size is presumably numInput; the code does not check).

        numHidden0: Hidden size of the brick level (leaf) LSTM
        numHidden1: Hidden size of the post-brick (parent) LSTM
        initializer: 'xavier' selects xavier initialisation for the final
            projection; anything else uses tf.random_normal.

        The leaf LSTM, shared by all bricks, produces one final output per
        brick. The parent LSTM consumes these in brick order and its last
        output is projected to numOutput logits.
        '''
        self.XAVIER = 'xavier'
        self.numBricks = numBricks
        self.numBrickTimeSteps = numBrickTimeSteps
        self.numInput = numInput
        self.numHidden0 = numHidden0
        self.numHidden1 = numHidden1
        self.numOutput = numOutput
        self.useDropout = useDropout
        self.initializer = initializer
        # Keep probabilities of the leaf (0) and parent (1) LSTM; set when
        # the graph is built.
        self.keep_prob0 = None
        self.keep_prob1 = None

        self.__weights = None
        self.__biases = None
        self.__logits = None

    def getLogits(self):
        '''Logits tensor of the built graph (None before building).'''
        return self.__logits

    def getWeightsAndBiases(self):
        '''
        Returns (weights, biases). After the graph is built these are the
        lists [kernel0, kernel1, fcW] and [bias0, bias1, fcB]; before that
        both are None.
        '''
        return self.__weights, self.__biases

    def __forward(self, x_batch):
        '''Build the leaf LSTMs, the parent LSTM and the output projection.'''
        steps_per_brick = self.numBrickTimeSteps
        batchSize = tf.shape(x_batch)[0]

        fc_w_shape = [self.numHidden1, self.numOutput]
        fc_b_shape = [self.numOutput]
        if self.initializer == self.XAVIER:
            print("Using xavier initialization")
            xavier = tf.contrib.layers.xavier_initializer()
            fc_w_init = xavier(shape=fc_w_shape)
            fc_b_init = xavier(shape=fc_b_shape)
        else:
            fc_w_init = tf.random_normal(shape=fc_w_shape)
            fc_b_init = tf.random_normal(shape=fc_b_shape)
        W1 = tf.Variable(fc_w_init, name='w1')
        B1 = tf.Variable(fc_b_init, name='b1')

        assert self.keep_prob0 is None
        assert self.keep_prob1 is None
        if self.useDropout is True:
            self.keep_prob0 = tf.placeholder(tf.float32, name='keep_prob0')
            self.keep_prob1 = tf.placeholder(tf.float32, name='keep_prob1')
        else:
            self.keep_prob0 = 1.0
            self.keep_prob1 = 1.0

        leaf_cell = tf.contrib.rnn.BasicLSTMCell(self.numHidden0,
                                                 forget_bias=1.0,
                                                 name='lstm_cell_0')
        leaf_wrapped = tf.contrib.rnn.DropoutWrapper(
            leaf_cell,
            input_keep_prob=self.keep_prob0,
            output_keep_prob=self.keep_prob0)
        parent_cell = tf.contrib.rnn.BasicLSTMCell(self.numHidden1,
                                                   forget_bias=1.0,
                                                   name='lstm_cell_1')
        parent_wrapped = tf.contrib.rnn.DropoutWrapper(
            parent_cell,
            input_keep_prob=self.keep_prob1,
            output_keep_prob=self.keep_prob1)

        # Layer 0: run the shared leaf cell over every brick, step by step,
        # starting from a zero state, and keep the final (output, state).
        leaf_results = []
        for idx, brick in enumerate(tf.unstack(x_batch, self.numBricks, 1)):
            print('leaf: %d '% idx, end='')
            with tf.name_scope('lstm0_%s' % idx):
                brick_steps = tf.unstack(brick, steps_per_brick, 1)
                state0 = leaf_cell.zero_state(batchSize, tf.float32)
                for step_input in brick_steps:
                    output0, state0 = leaf_wrapped(step_input, state0)
                leaf_results.append((output0, state0))

        # Layer 1: feed the per-brick outputs, in order, to the parent cell.
        with tf.name_scope('lstm1'):
            state1 = parent_cell.zero_state(batchSize, tf.float32)
            for idx, (leaf_out, _) in enumerate(leaf_results):
                print('parent: %d '% idx, end='')
                output1, state1 = parent_wrapped(leaf_out, state1)
        print()

        logits = tf.add(tf.matmul(output1, W1), B1, name='logits')

        kernel0, bias0 = leaf_cell.variables
        kernel1, bias1 = parent_cell.variables
        self.__weights = [kernel0, kernel1, W1]
        self.__biases = [bias0, bias1, B1]
        self.__logits = logits
        return logits

    def __call__(self, x_batch):
        '''Build the network on the rank-4 tensor x_batch; returns logits.'''
        assert x_batch.get_shape().ndims == 4
        return self.__forward(x_batch)

class StackedLSTMFC_Network:
    def __init__(self, numBricks, numBrickTimeSteps, numInput, numHidden0,
                 numFC, numOutput, useDropout=False, initializer=None):
        '''
        A two layer divide-and-conquer network whose first layer is an LSTM
        and whose second layer is a fully connected (FC) network.

        The input must be a rank-4 tensor. It is split along axis 1 into
        numBricks bricks and every brick is split along its axis 1 into
        numBrickTimeSteps steps, i.e. the layout is
            [-1, NUM_BRICKS, NUM_BRICK_TIMESTEPS, <features>]
        (the feature size is presumably numInput; the code does not check).

        numHidden0: Hidden size of the brick level LSTM
        numFC: Width of the post-brick FC layer
        initializer: 'xavier' selects xavier initialisation for the FC
            parameters; anything else uses tf.random_normal.

        The LSTM, shared by all bricks, produces one final output per brick.
        These are concatenated, passed through a sigmoid FC layer and then
        projected to numOutput logits.
        '''
        self.XAVIER = 'xavier'
        self.numBricks = numBricks
        self.numBrickTimeSteps = numBrickTimeSteps
        self.numInput = numInput
        self.numHidden0 = numHidden0
        self.numFC = numFC
        self.numOutput = numOutput
        self.useDropout = useDropout
        self.initializer = initializer
        # Keep probability of the brick level LSTM; set on graph creation.
        self.keep_prob0 = None
        self.__weights = None
        self.__biases = None
        self.__logits = None

    def getLogits(self):
        '''Logits tensor of the built graph (None before building).'''
        return self.__logits

    def getWeightsAndBiases(self):
        '''
        Returns (weights, biases). After the graph is built these are the
        lists [kernel0, W0, W1] and [bias0, B0, B1]; before that both are
        None.
        '''
        return self.__weights, self.__biases

    def __forward(self, x_batch):
        '''Build the per-brick LSTMs and the two FC layers on x_batch.'''
        hidden = self.numHidden0
        brick_count = self.numBricks
        steps_per_brick = self.numBrickTimeSteps
        batchSize = tf.shape(x_batch)[0]

        fc0_w_shape = [hidden * brick_count, self.numFC]
        fc0_b_shape = [self.numFC]
        fc1_w_shape = [self.numFC, self.numOutput]
        fc1_b_shape = [self.numOutput]
        if self.initializer == self.XAVIER:
            print("Using xavier initialization")
            make_init = tf.contrib.layers.xavier_initializer()
        else:
            make_init = tf.random_normal
        w0_init = make_init(shape=fc0_w_shape)
        b0_init = make_init(shape=fc0_b_shape)
        w1_init = make_init(shape=fc1_w_shape)
        b1_init = make_init(shape=fc1_b_shape)
        W0 = tf.Variable(w0_init, name='w0')
        B0 = tf.Variable(b0_init, name='b0')
        W1 = tf.Variable(w1_init, name='w1')
        B1 = tf.Variable(b1_init, name='b1')

        assert self.keep_prob0 is None
        if self.useDropout is True:
            self.keep_prob0 = tf.placeholder(tf.float32, name='keep_prob0')
        else:
            self.keep_prob0 = 1.0

        leaf_cell = tf.contrib.rnn.BasicLSTMCell(hidden, forget_bias=1.0,
                                                 name='lstm_cell_0')
        leaf_wrapped = tf.contrib.rnn.DropoutWrapper(
            leaf_cell,
            input_keep_prob=self.keep_prob0,
            output_keep_prob=self.keep_prob0)

        # Layer 0: run the shared cell over every brick, step by step,
        # starting from a zero state, and keep the final (output, state).
        leaf_results = []
        for idx, brick in enumerate(tf.unstack(x_batch, brick_count, 1)):
            print('leaf: %d '% idx, end='')
            with tf.name_scope('lstm0_%s' % idx):
                brick_steps = tf.unstack(brick, steps_per_brick, 1)
                state0 = leaf_cell.zero_state(batchSize, tf.float32)
                for step_input in brick_steps:
                    output0, state0 = leaf_wrapped(step_input, state0)
                leaf_results.append((output0, state0))

        # Layer 1: concatenate the [batch_size, nhidden] brick outputs and
        # push them through the FC stack.
        brick_outputs = [final_out for final_out, _ in leaf_results]
        layer_0_out = tf.concat(brick_outputs, axis=1)
        assert layer_0_out.shape[1] == brick_count * hidden
        layer_1_out = tf.matmul(layer_0_out, W0) + B0
        fc_activation = tf.nn.sigmoid(layer_1_out, name='sigmoid0')
        logits = tf.add(tf.matmul(fc_activation, W1), B1, name='logits')

        kernel0, bias0 = leaf_cell.variables
        self.__weights = [kernel0, W0, W1]
        self.__biases = [bias0, B0, B1]
        self.__logits = logits
        return logits

    def __call__(self, x_batch):
        '''Build the network on the rank-4 tensor x_batch; returns logits.'''
        assert x_batch.get_shape().ndims == 4
        return self.__forward(x_batch)

class StackedFastCellFC_Network:
    def __init__(self, numBricks, numBrickTimeSteps, numInput, numHidden0,
                 numFC, numOutput, celltype='FastRNN', useDropout=False,
                 initializer=None, gate_non_linearity='sigmoid',
                 update_non_linearity='sigmoid', wRank=None, uRank=None,
                 alphaInit=-3.0, betaInit=3.0, zetaInit=1.0, nuInit=-4.0):
        '''
        A two layer divide-and-conquer network whose first layer is a
        FastRNN / FastGRNN cell and whose second layer is a fully connected
        (FC) network.

        The input must be a rank-4 tensor. It is split along axis 1 into
        numBricks bricks and every brick is split along its axis 1 into
        numBrickTimeSteps steps, i.e. the layout is
            [-1, NUM_BRICKS, NUM_BRICK_TIMESTEPS, <features>]
        (the feature size is presumably numInput; the code does not check).

        numHidden0: Hidden size of the brick level cell
        numFC: Width of the post-brick FC layer
        celltype: in ['FastRNN', 'FastGRNN']
        initializer: in [None, 'xavier'], used for the FC parameters
        gate_non_linearity, update_non_linearity, wRank, uRank, alphaInit,
        betaInit, zetaInit, nuInit: forwarded to the edgeml cell constructor
            (alpha/beta to FastRNNCell, gate/zeta/nu to FastGRNNCell).

        The cell, shared by all bricks, produces one final output per brick.
        These are concatenated, passed through the FC layer and projected to
        numOutput logits.
        '''
        assert celltype in ['FastRNN', 'FastGRNN']
        self.cell_type = celltype
        self.XAVIER = 'xavier'
        self.numBricks = numBricks
        self.numBrickTimeSteps = numBrickTimeSteps
        self.numInput = numInput
        self.numHidden0 = numHidden0
        self.numFC = numFC
        self.numOutput = numOutput
        self.useDropout = useDropout
        self.initializer = initializer
        self.gate_non_linearity = gate_non_linearity
        self.update_non_linearity = update_non_linearity
        self.wRank, self.uRank = wRank, uRank
        self.zetaInit, self.nuInit = zetaInit, nuInit
        self.alphaInit, self.betaInit = alphaInit, betaInit
        # Keep probability of the brick level cell; set when the graph is
        # built or restored.
        self.keep_prob0 = None
        self.__weights = None
        self.__biases = None
        self.__logits = None

    def getLogits(self):
        '''Logits tensor of the built/restored graph (None before that).'''
        return self.__logits

    def getWeightsAndBiases(self):
        '''
        Returns (weights, biases); both are None until the graph has been
        built or restored.

        After building:
            weights: [cell0.getVars(), W0, W1] -- the cell variables are kept
                as one nested entry
            biases: [B0, B1]
        After restoring (FastRNN only):
            weights: W, U (each split in two when low-rank), alpha, beta,
                followed by the FC weights w0, w1
            biases: B_h, b0, b1
        '''
        return self.__weights, self.__biases

    def __getCells(self, x_batch):
        '''
        Create the brick level cell and its dropout wrapped version.

        x_batch is not used; the parameter is kept so existing call sites
        stay valid. Returns (cell, wrapped_cell).
        '''
        if self.cell_type == 'FastGRNN':
            cell0 = FastGRNNCell(self.numHidden0,
                                 gate_non_linearity=self.gate_non_linearity,
                                 update_non_linearity=self.update_non_linearity,
                                 wRank=self.wRank, uRank=self.uRank,
                                 zetaInit=self.zetaInit, nuInit=self.nuInit,
                                 name='FastGRNN_cell_0')
        else:
            # We default to FastRNN
            cell0 = FastRNNCell(self.numHidden0,
                                update_non_linearity=self.update_non_linearity,
                                wRank=self.wRank, uRank=self.uRank,
                                alphaInit=self.alphaInit,
                                betaInit=self.betaInit,
                                name='FastRNN_cell_0')
        wcell0 = tf.contrib.rnn.DropoutWrapper(
            cell0,
            input_keep_prob=self.keep_prob0,
            output_keep_prob=self.keep_prob0)
        return cell0, wcell0

    def __forward(self, x_batch, use_non_linear):
        '''
        Build the per-brick cells and the FC layers on x_batch.

        use_non_linear: when true a sigmoid follows the first FC layer,
            otherwise that layer stays linear.
        '''
        numHidden0 = self.numHidden0
        numFC = self.numFC
        numTimeSteps = self.numBrickTimeSteps
        numOutput = self.numOutput
        numBricks = self.numBricks
        batchSize = tf.shape(x_batch)[0]

        if self.initializer == self.XAVIER:
            print("Using xavier initialization")
            iti = tf.contrib.layers.xavier_initializer()
            w0_ = iti(shape=[numHidden0 * numBricks, numFC])
            b0_ = iti(shape=[numFC])
            # The second FC layer maps numFC -> numOutput.
            w1_ = iti(shape=[numFC, numOutput])
            b1_ = iti(shape=[numOutput])
        else:
            w0_ = tf.random_normal(shape=[numHidden0 * numBricks, numFC])
            b0_ = tf.random_normal(shape=[numFC])
            w1_ = tf.random_normal(shape=[numFC, numOutput])
            b1_ = tf.random_normal(shape=[numOutput])
        W0 = tf.Variable(w0_, name='w0')
        B0 = tf.Variable(b0_, name='b0')
        W1 = tf.Variable(w1_, name='w1')
        B1 = tf.Variable(b1_, name='b1')
        assert self.keep_prob0 is None
        if self.useDropout is True:
            self.keep_prob0 = tf.placeholder(tf.float32, name='keep_prob0')
        else:
            self.keep_prob0 = 1.0

        cell0, wcell0 = self.__getCells(x_batch)
        # Layer 0: run the shared cell over every brick, step by step,
        # starting from a zero state, and keep the final (output, state).
        bricks = tf.unstack(x_batch, numBricks, 1)
        final_output_state_list = []
        for i, brick in enumerate(bricks):
            print('leaf: %d '% i, end='')
            with tf.name_scope('fast0_%s' % i):
                input0 = tf.unstack(brick, numTimeSteps, 1)
                state0 = cell0.zero_state(batchSize, tf.float32)
                for step_input in input0:
                    output0, state0 = wcell0(step_input, state0)
                final_output_state_list.append((output0, state0))

        # Layer 1: concatenate the brick outputs and apply the FC stack.
        outList = [elem[0] for elem in final_output_state_list]
        layer_0_out = tf.concat(outList, axis=1)
        assert layer_0_out.shape[1] == numBricks * numHidden0
        layer_1_out = tf.matmul(layer_0_out, W0) + B0
        if use_non_linear:
            output1 = tf.nn.sigmoid(layer_1_out, name='sigmoid0')
        else:
            output1 = layer_1_out
        logits = tf.matmul(output1, W1)
        logits = tf.add(logits, B1, name='logits')

        # The cell variables are stored as a single nested entry.
        self.__weights = [cell0.getVars(), W0, W1]
        self.__biases = [B0, B1]
        self.__logits = logits
        return logits

    def __restoreFromGraph(self, graph):
        '''
        Fetch the tensors of a previously built network from `graph` by name
        and return the logits tensor.

        Only celltype 'FastRNN' is supported; 'FastGRNN' raises
        NotImplementedError before any state is touched. This network has a
        single recurrent cell and a single keep probability, so only
        'FastRNN_cell_0' and 'keep_prob0' are looked up.
        '''
        def getWeights(graph, t):
            # W and U are each stored as one matrix, or as two factors when
            # the corresponding rank is set.
            weights = []
            if self.wRank is None:
                weights.append(graph.get_tensor_by_name(t + 'W:0'))
            else:
                weights.append(graph.get_tensor_by_name(t + 'W1:0'))
                weights.append(graph.get_tensor_by_name(t + 'W2:0'))
            if self.uRank is None:
                weights.append(graph.get_tensor_by_name(t + 'U:0'))
            else:
                weights.append(graph.get_tensor_by_name(t + 'U1:0'))
                weights.append(graph.get_tensor_by_name(t + 'U2:0'))
            return weights

        assert self.__logits is None
        assert self.keep_prob0 is None
        assert self.__weights is None
        assert self.__biases is None
        if self.cell_type == 'FastGRNN':
            raise NotImplementedError

        self.__logits = graph.get_tensor_by_name('logits:0')
        # The placeholder only exists when the graph was built with dropout.
        if self.useDropout is True:
            self.keep_prob0 = graph.get_tensor_by_name('keep_prob0:0')
        else:
            self.keep_prob0 = 1.0

        self.__weights, self.__biases = [], []
        # NOTE(review): this prefix is assumed to match the variable scope
        # under which the FastRNN cell variables were created -- confirm
        # against a saved graph.
        t = 'fast_rnn_cell/FastRNN_cell_0/FastRNNcell/'
        weights = getWeights(graph, t)
        b0 = graph.get_tensor_by_name(t + 'B_h:0')
        alpha = graph.get_tensor_by_name(t + 'alpha:0')
        beta = graph.get_tensor_by_name(t + 'beta:0')
        self.__weights.extend(weights)
        self.__weights.extend([alpha, beta])
        self.__biases.append(b0)
        # FC parameters, under the names given to them in the forward pass.
        self.__weights.append(graph.get_tensor_by_name('w0:0'))
        self.__weights.append(graph.get_tensor_by_name('w1:0'))
        self.__biases.append(graph.get_tensor_by_name('b0:0'))
        self.__biases.append(graph.get_tensor_by_name('b1:0'))
        return self.__logits

    def __call__(self, x_batch, graph=None, use_non_linear=True):
        '''
        Return the logits tensor.

        graph: when given, the network is restored from it by name and
            x_batch / use_non_linear are ignored.
        x_batch: rank-4 input tensor used to build the network otherwise.
        use_non_linear: whether a sigmoid follows the first FC layer.
        '''
        if graph is not None:
            return self.__restoreFromGraph(graph)
        assert x_batch.get_shape().ndims == 4
        logits = self.__forward(x_batch, use_non_linear)
        return logits

class StackedFastCell_Network:
    def __init__(self, numBricks, numBrickTimeSteps, numInput, numHidden0,
                 numHidden1, numOutput, celltype='FastRNN', useDropout=False,
                 initializer=None, gate_non_linearity='sigmoid',
                 update_non_linearity='sigmoid', wRank=None, uRank=None,
                 alphaInit=-3.0, betaInit=3.0, zetaInit=1.0, nuInit=-4.0):
        '''
        A two layer divide and conquer network built from FastRNN or
        FastGRNN cells.
        The input data dictates the stacking ideally though this code is
        independent of it. The data format is as follows:
            [-1, NUM_BRICKS, NUM_BRICKTIMESTEPS, NUM_INPUT]

        numHidden0: Hidden size of the brick level (leaf) cell
        numHidden1: Hidden size of the post-brick (parent) cell
        numOutput: Width of the final fully connected layer
        celltype: in ['FastRNN', 'FastGRNN']
        useDropout: When True, keep_prob0 and keep_prob1 are created as
            placeholders during the forward pass; otherwise both are fixed
            to 1.0
        initializer: in [None, 'xavier']. Only used for the final fully
            connected layer.
        gate_non_linearity, zetaInit, nuInit: forwarded to FastGRNNCell only
        alphaInit, betaInit: forwarded to FastRNNCell only
        update_non_linearity, wRank, uRank: forwarded to either cell type

        Here, the layer-0 cell will create NUM_BRICKS hidden states --- one
        for each group. Then a secondary layer consumes this to produce the
        requisite outputs.
        '''
        assert celltype in ['FastRNN', 'FastGRNN']
        self.cell_type = celltype
        self.XAVIER = 'xavier'
        self.numBricks = numBricks
        self.numBrickTimeSteps = numBrickTimeSteps
        self.numInput = numInput
        self.numHidden0 = numHidden0
        self.numHidden1 = numHidden1
        self.numOutput = numOutput
        self.useDropout = useDropout
        self.initializer = initializer
        self.gate_non_linearity = gate_non_linearity
        self.update_non_linearity = update_non_linearity
        self.wRank, self.uRank = wRank, uRank
        self.zetaInit, self.nuInit = zetaInit, nuInit
        self.alphaInit, self.betaInit = alphaInit, betaInit
        # leaf
        self.keep_prob0 = None
        # parent
        self.keep_prob1 = None
        self.__weights = None
        self.__biases = None
        self.__logits = None

    def getLogits(self):
        '''
        Returns the logits tensor, or None if the network has neither been
        built nor restored yet.
        '''
        return self.__logits

    def getWeightsAndBiases(self):
        '''
        Returns (weights, biases); both are None until the network has been
        built or restored.

        After a forward build:
            weights: [cell0.getVars(), cell1.getVars(), fcW]
            biases: [fcB]
        After restoring a FastRNN graph:
            weights: W0, U0, alpha, beta, W1, U1, alpha, beta
                (W and U are each split in two when the matching low-rank
                parameter is set)
            biases: B_h of layer 0, B_h of layer 1
        '''
        return self.__weights, self.__biases

    def __getCells(self, x_batch):
        '''
        Create the leaf (layer 0) and parent (layer 1) cells of the
        configured type and wrap each one in a DropoutWrapper driven by
        keep_prob0 and keep_prob1 respectively.

        x_batch: Not used; kept so the call signature stays unchanged.

        Returns cell0, cell1, wcell0, wcell1 where wcellN is the dropout
        wrapped version of cellN.
        '''
        if self.cell_type == 'FastGRNN':
            cellClass = FastGRNNCell
            namePrefix = 'FastGRNN_cell_'
            cellArgs = {
                'gate_non_linearity': self.gate_non_linearity,
                'update_non_linearity': self.update_non_linearity,
                'wRank': self.wRank, 'uRank': self.uRank,
                'zetaInit': self.zetaInit, 'nuInit': self.nuInit,
            }
        else:
            # We default to FastRNN
            cellClass = FastRNNCell
            namePrefix = 'FastRNN_cell_'
            cellArgs = {
                'update_non_linearity': self.update_non_linearity,
                'wRank': self.wRank, 'uRank': self.uRank,
                'alphaInit': self.alphaInit, 'betaInit': self.betaInit,
            }
        cell0 = cellClass(self.numHidden0, name=namePrefix + '0', **cellArgs)
        wcell0 = tf.contrib.rnn.DropoutWrapper(cell0,
                                               input_keep_prob=self.keep_prob0,
                                               output_keep_prob=self.keep_prob0)
        cell1 = cellClass(self.numHidden1, name=namePrefix + '1', **cellArgs)
        wcell1 = tf.contrib.rnn.DropoutWrapper(cell1,
                                               input_keep_prob=self.keep_prob1,
                                               output_keep_prob=self.keep_prob1)
        return cell0, cell1, wcell0, wcell1

    def __forward(self, x_batch):
        '''
        Build the two layer network on x_batch and return the logits.

        The same leaf cell is run over every brick, starting from a zero
        state each time. The last output of every brick is then fed, in
        order, through the parent cell, and the final parent output goes
        through a fully connected layer (w1, b1).

        Side effects: sets keep_prob0, keep_prob1, and the stored weights,
        biases and logits. May only be called once per instance (the
        keep_prob asserts fail on a second call). Prints progress markers
        to stdout while the graph is being built.
        '''
        numHidden1 = self.numHidden1
        numTimeSteps = self.numBrickTimeSteps
        numOutput = self.numOutput
        numBricks = self.numBricks
        batchSize = tf.shape(x_batch)[0]

        if self.initializer == self.XAVIER:
            print("Using xavier initialization")
            iti = tf.contrib.layers.xavier_initializer()
            w1_ = iti(shape=[numHidden1, numOutput])
            b1_ = iti(shape=[numOutput])
        else:
            w1_ = tf.random_normal(shape=[numHidden1, numOutput])
            b1_ = tf.random_normal(shape=[numOutput])
        W1 = tf.Variable(w1_, name='w1')
        B1 = tf.Variable(b1_, name='b1')
        assert self.keep_prob0 is None
        assert self.keep_prob1 is None
        if self.useDropout is True:
            self.keep_prob0 = tf.placeholder(tf.float32, name='keep_prob0')
            self.keep_prob1 = tf.placeholder(tf.float32, name='keep_prob1')
        else:
            self.keep_prob0 = 1.0
            self.keep_prob1 = 1.0

        cell0, cell1, wcell0, wcell1 = self.__getCells(x_batch)
        # Unstack to get a list of 'numBricks' tensors of shape (batch_size,
        # numBrickTimeSteps, numInput)
        bricks = tf.unstack(x_batch, numBricks, 1)
        final_output_state_list = []
        for i, brick in enumerate(bricks):
            print('leaf: %d ' % i, end='')
            with tf.name_scope('fast0_%s' % i):
                input0 = tf.unstack(brick, numTimeSteps, 1)
                state0 = cell0.zero_state(batchSize, tf.float32)
                # The cell is unrolled by hand over the brick; only the
                # output and state of the last time step are kept.
                for j in range(numTimeSteps):
                    output0, state0 = wcell0(input0[j], state0)
                final_output_state_list.append((output0, state0))
        # Create layer 1
        with tf.name_scope('fast1'):
            state1 = cell1.zero_state(batchSize, tf.float32)
            for i in range(numBricks):
                print('parent: %d ' % i, end='')
                out, _ = final_output_state_list[i]
                output1, state1 = wcell1(out, state1)
        print()
        logits = tf.matmul(output1, W1)
        logits = tf.add(logits, B1, name='logits')
        # NOTE(review): each getVars() result is appended as a single entry,
        # so if getVars() returns a list the weights end up nested, unlike
        # the flat list built by __restoreFromGraph --- confirm which layout
        # callers expect.
        self.__weights, self.__biases = [], []
        self.__weights.append(cell0.getVars())
        self.__weights.append(cell1.getVars())
        self.__weights.append(W1)
        self.__biases.append(B1)
        self.__logits = logits
        return logits

    def __restoreFromGraph(self, graph):
        '''
        Populate this instance from tensors looked up by name in graph and
        return the logits tensor.

        Only FastRNN is supported; for FastGRNN a NotImplementedError is
        raised after logits and keep_prob0/keep_prob1 have been set. May
        only be called on an instance that has not been built or restored
        before (enforced by the asserts).

        NOTE(review): 'keep_prob0:0' and 'keep_prob1:0' are always looked
        up here, while the forward pass only creates those placeholders
        when useDropout is True --- confirm restoring a graph built without
        dropout is not required.
        '''
        def getWeights(graph, t):
            # W and U are each stored either as one matrix or, when the
            # matching rank is set, as two low-rank factors.
            names = ['W:0'] if self.wRank is None else ['W1:0', 'W2:0']
            names += ['U:0'] if self.uRank is None else ['U1:0', 'U2:0']
            return [graph.get_tensor_by_name(t + name) for name in names]

        assert self.__logits is None
        assert self.keep_prob0 is None
        assert self.keep_prob1 is None
        assert self.__weights is None
        assert self.__biases is None
        self.__logits = graph.get_tensor_by_name('logits:0')
        self.keep_prob0 = graph.get_tensor_by_name('keep_prob0:0')
        self.keep_prob1 = graph.get_tensor_by_name('keep_prob1:0')
        self.__weights, self.__biases = [], []
        if self.cell_type == 'FastRNN':
            # Layer 0 (leaf) first, then layer 1 (parent).
            for layer in range(2):
                t = 'fast_rnn_cell/FastRNN_cell_%d/FastRNNcell/' % layer
                weights = getWeights(graph, t)
                bias = graph.get_tensor_by_name(t + 'B_h:0')
                alpha = graph.get_tensor_by_name(t + 'alpha:0')
                beta = graph.get_tensor_by_name(t + 'beta:0')
                self.__weights.extend(weights)
                self.__weights.extend([alpha, beta])
                self.__biases.append(bias)
        elif self.cell_type == 'FastGRNN':
            raise NotImplementedError
        return self.__logits

    def __call__(self, x_batch, graph=None):
        '''
        Return the logits of the network.

        x_batch: Input tensor; must have 4 dimensions when the network is
            built. It is not looked at when graph is given.
        graph: When not None, the network is restored from this graph
            instead of being built.
        '''
        if graph is not None:
            return self.__restoreFromGraph(graph)
        assert x_batch.get_shape().ndims == 4
        logits = self.__forward(x_batch)
        return logits


class StackedRNN3_Network:
    '''
    3 Layer stacked network. Each layer can be either a LSTM or FastCell

    Note that this is not entirely tested.
    '''
    def __init__(self, numBricks, numTimeSteps0, numInput, numHidden0,
                 numHidden1, numTimeSteps1, numHidden2, numOutput,
                 useDropout=False, initializer=None, cell0=None, cell1=None,
                 cell2=None):
        '''
        A 3 layer divide and conquer network

        The input data dictates the stacking ideally though this code is
        independent of it. The data format is as follows:
            [-1, NUM_BRICKS, NUM_BRICKTIMESTEPS, NUM_INPUT]

        numBricks: Number of layer-0 bricks; must be a multiple of
            numTimeSteps1
        numTimeSteps0: Time steps inside each layer-0 brick
        numTimeSteps1: Number of consecutive layer-0 outputs that form one
            layer-1 brick
        numHidden0: Brick level cell hidden size
        numHidden1: Layer-1 cell hidden size
        numHidden2: Layer-2 cell hidden size; also the input width of the
            final fully connected layer
        useDropout: When True, keep_prob0/1/2 are created as placeholders
            during the forward pass; otherwise they are fixed to 1.0
        initializer: in [None, 'xavier']. Only used for the final fully
            connected layer.
        cell0, cell1, cell2: Optional pre-built cells. A layer whose cell is
            None gets a BasicLSTMCell of the matching hidden size.

        Here, the layer-0 cell will create NUM_BRICKS hidden states --- one
        for each group. These are grouped numTimeSteps1 at a time and
        consumed by layer 1, whose outputs are in turn consumed by layer 2
        to produce the requisite outputs.
        '''
        self.XAVIER = 'xavier'
        self.numBricks = numBricks
        self.numTimeSteps0 = numTimeSteps0
        self.numTimeSteps1 = numTimeSteps1
        assert numBricks % numTimeSteps1 == 0
        self.numInput = numInput
        self.numHidden0 = numHidden0
        self.numHidden1 = numHidden1
        self.numHidden2 = numHidden2
        self.numOutput = numOutput
        self.useDropout = useDropout
        self.initializer = initializer
        self.cell0 = cell0
        self.cell1 = cell1
        self.cell2 = cell2
        # leaf
        self.keep_prob0 = None
        self.keep_prob1 = None
        self.keep_prob2 = None

        self.__weights = None
        self.__biases = None
        self.__logits = None

    def getLogits(self):
        '''
        Returns the logits tensor, or None if the network is not built yet.
        '''
        return self.__logits

    def getWeightsAndBiases(self):
        '''
        Returns (weights, biases); both are None until the network is built.
            weights: [cell0 variables, cell1 variables, cell2 variables, fcW]
            biases: [fcB]
        '''
        return self.__weights, self.__biases

    @staticmethod
    def __cellVars(cell):
        '''
        Return the variables of cell: the result of cell.getVars() when the
        cell provides that method, else cell.variables (the default
        BasicLSTMCell has no getVars()).
        '''
        getVars = getattr(cell, 'getVars', None)
        if getVars is not None:
            return getVars()
        return cell.variables

    def __forward(self, x_batch):
        '''
        Build the three layer network on x_batch and return the logits.

        Layer 0 runs over every brick from a zero state and keeps the last
        output. Consecutive groups of numTimeSteps1 layer-0 outputs are run
        through layer 1, again from a zero state per group. The layer-1
        outputs are run, in order, through layer 2 and its final output
        goes through a fully connected layer (w1, b1).

        Side effects: sets keep_prob0/1/2, stores default cells on
        self.cell0/1/2 when none were supplied, and stores the weights,
        biases and logits. May only be called once per instance (the
        keep_prob asserts fail on a second call).
        '''
        numHidden0 = self.numHidden0
        numHidden1 = self.numHidden1
        numHidden2 = self.numHidden2
        numTimeSteps0 = self.numTimeSteps0
        numTimeSteps1 = self.numTimeSteps1
        numOutput = self.numOutput
        numBricks0 = self.numBricks
        batchSize = tf.shape(x_batch)[0]

        if self.initializer == self.XAVIER:
            print("Using xavier initialization")
            iti = tf.contrib.layers.xavier_initializer()
            w1_ = iti(shape=[numHidden2, numOutput])
            b1_ = iti(shape=[numOutput])
        else:
            w1_ = tf.random_normal(shape=[numHidden2, numOutput])
            b1_ = tf.random_normal(shape=[numOutput])
        W1 = tf.Variable(w1_, name='w1')
        B1 = tf.Variable(b1_, name='b1')
        assert self.keep_prob0 is None
        assert self.keep_prob1 is None
        assert self.keep_prob2 is None
        if self.useDropout is True:
            self.keep_prob0 = tf.placeholder(tf.float32, name='keep_prob0')
            self.keep_prob1 = tf.placeholder(tf.float32, name='keep_prob1')
            self.keep_prob2 = tf.placeholder(tf.float32, name='keep_prob2')
        else:
            self.keep_prob0 = 1.0
            self.keep_prob1 = 1.0
            self.keep_prob2 = 1.0
        if self.cell0 is None:
            cell0 = tf.contrib.rnn.BasicLSTMCell(numHidden0, forget_bias=1.0,
                                                 name='lstm_cell_0')
            self.cell0 = cell0
        else:
            cell0 = self.cell0
        wcell0 = tf.contrib.rnn.DropoutWrapper(cell0,
                                               input_keep_prob=self.keep_prob0,
                                               output_keep_prob=self.keep_prob0)
        if self.cell1 is None:
            cell1 = tf.contrib.rnn.BasicLSTMCell(numHidden1, forget_bias=1.0,
                                                 name='lstm_cell_1')
            self.cell1 = cell1
        else:
            cell1 = self.cell1
        wcell1 = tf.contrib.rnn.DropoutWrapper(cell1,
                                               input_keep_prob=self.keep_prob1,
                                               output_keep_prob=self.keep_prob1)
        if self.cell2 is None:
            cell2 = tf.contrib.rnn.BasicLSTMCell(numHidden2, forget_bias=1.0,
                                                 name='lstm_cell_2')
            self.cell2 = cell2
        else:
            cell2 = self.cell2
        wcell2 = tf.contrib.rnn.DropoutWrapper(cell2,
                                               input_keep_prob=self.keep_prob2,
                                               output_keep_prob=self.keep_prob2)

        # Unstack to get a list of 'numBricks' tensors of shape (batch_size,
        # numTimeSteps0, numInput)
        brick0List = tf.unstack(x_batch, numBricks0, 1)
        final_output_state_list0 = []
        for i, brick in enumerate(brick0List):
            with tf.name_scope('layer0_%s' % i):
                input0 = tf.unstack(brick, numTimeSteps0, 1)
                state0 = cell0.zero_state(batchSize, tf.float32)
                for j in range(numTimeSteps0):
                    output0, state0 = wcell0(input0[j], state0)
            final_output_state_list0.append((output0, state0))
        assert len(final_output_state_list0) == numBricks0
        # Manually unstack: consecutive groups of numTimeSteps1 layer-0
        # results form one layer-1 brick.
        numBricks1 = len(final_output_state_list0) // numTimeSteps1
        brick1List = []
        for i in range(numBricks1):
            start = i * numTimeSteps1
            end = (i + 1) * numTimeSteps1
            brick1List.append(final_output_state_list0[start:end])
        assert len(brick1List) == numBricks1
        # Create layer 1
        final_output_state_list1 = []
        for brick1 in brick1List:
            assert len(brick1) == numTimeSteps1
            # Each brick1 is a list of (output, state) pairs from layer 0;
            # only the output is fed forward. The recurrent state starts at
            # zero for every brick and is carried from step to step.
            state1 = cell1.zero_state(batchSize, tf.float32)
            for inp, _ in brick1:
                output1, state1 = wcell1(inp, state1)
            final_output_state_list1.append((output1, state1))

        with tf.name_scope('layer2'):
            # A single pass over all layer-1 outputs, carrying the state.
            state2 = cell2.zero_state(batchSize, tf.float32)
            for out, _ in final_output_state_list1:
                output2, state2 = wcell2(out, state2)

        logits = tf.matmul(output2, W1)
        logits = tf.add(logits, B1, name='logits')
        self.__weights, self.__biases = [], []
        self.__weights.append(self.__cellVars(cell0))
        self.__weights.append(self.__cellVars(cell1))
        self.__weights.append(self.__cellVars(cell2))
        self.__weights.append(W1)
        self.__biases.append(B1)
        self.__logits = logits
        return logits

    def __call__(self, x_batch):
        '''
        Build the network on x_batch, which must have 4 dimensions, and
        return the logits.
        '''
        assert x_batch.get_shape().ndims == 4
        logits = self.__forward(x_batch)
        return logits

class LSTMPure:
    def __init__(self, numTimeSteps, kernelfile, biasfile):
        '''
        This is an LSTM implemented in pure python. A better version can be
        found in the Magnaexps package.

        numTimeSteps: Number of time steps predict() expects per example
        kernelfile: File loaded with np.load to obtain the kernel, of shape
            [featLen + statesLen, 4 * statesLen]
        biasfile: File loaded with np.load to obtain the bias that is added
            to the result of the kernel multiplication
        '''
        self.kernelfile = kernelfile
        self.biasfile = biasfile
        self.kernel = np.load(kernelfile)
        self.bias = np.load(biasfile)
        self.numTimeSteps = numTimeSteps
        # The kernel columns hold the four gates side by side.
        self.statesLen = self.kernel.shape[1] // 4
        self.featLen = self.kernel.shape[0] - self.statesLen

    def __cell(self, x, h, c, forgetBias):
        '''
        Run a single LSTM step.

        x: Input of this step, [batchSize, featLen]
        h, c: Previous hidden and cell state, [batchSize, statesLen]
        forgetBias: Constant added to the forget gate before the sigmoid

        Returns the new (h, c).
        '''
        # NOTE(review): `sigmoid` is not defined in this class; it has to
        # be available at module scope --- confirm.
        x_ = np.concatenate([x, h], axis=1)
        combOut = np.matmul(x_, self.kernel) + self.bias
        # i: input gate, j: candidate values, f: forget gate, o: output gate
        i, j, f, o = np.split(combOut, 4, axis=1)
        new_c = c * sigmoid(f + forgetBias) + sigmoid(i) * np.tanh(j)
        new_h = np.tanh(new_c) * sigmoid(o)
        return new_h, new_c

    def __unstack(self, x, axis=0):
        '''
        Split x along axis into a list of arrays, each with that axis
        removed.
        '''
        return list(np.rollaxis(x, axis=axis))

    def predict(self, x, forgetBias=1.0):
        '''
        Run the LSTM over x starting from all-zero hidden and cell states.

        x: Array of shape [batchSize, numTimeSteps, featLen]
        forgetBias: Constant added to the forget gate at every step

        Returns a list with one [batchSize, statesLen] array per time step:
        the hidden state after that step.
        '''
        assert x.ndim == 3
        assert x.shape[1] == self.numTimeSteps
        assert x.shape[2] == self.featLen
        batchSize = len(x)
        x_unstacked = self.__unstack(x, axis=1)
        # np.float was only an alias of the builtin float (float64) and no
        # longer exists in recent NumPy releases.
        h = np.zeros([batchSize, self.statesLen], dtype=np.float64)
        c = np.zeros([batchSize, self.statesLen], dtype=np.float64)
        outSteps = []
        for x_batch in x_unstacked:
            h, c = self.__cell(x_batch, h, c, forgetBias=forgetBias)
            outSteps.append(h)
        return outSteps
