# !/usr/bin/env python

import mxnet as mx
import numpy as np

from tools.layer_master import mx_dnn_layer

def dnn_model_define(user_feature_embedding, target_node_embedding, label, bs, eb_dim,
                     fea_groups="1,1,1,2,2,2,10,10,20,20", active_op='prelu', use_batch_norm=False):
    """Build the MXNet symbol graph of the user preference prediction model.

    The user feature embeddings are concatenated and reshaped to
    ``(bs, sum(fea_groups), eb_dim)``. Consecutive slots along axis 1 are then
    averaged per group (one group per time window), the group means are
    concatenated with the target node embedding, and the result is fed through
    a 4-layer DNN (128 -> 64 -> 24 -> 2) topped by ``SoftmaxOutput``.

    Args:
        user_feature_embedding: sequence of symbols, concatenated with
            ``mx.sym.concat`` (default axis) and reshaped to
            ``(bs, sum(fea_groups), eb_dim)``. Presumably their total size per
            batch row must equal ``sum(fea_groups) * eb_dim`` -- confirm
            against the caller.
        target_node_embedding: symbol concatenated along axis 1 with the
            grouped user features. Assumed to have width ``eb_dim``, since the
            first layer's input width is computed that way -- TODO confirm.
        label: label symbol. It is wrapped in ``BlockGrad``, passed to
            ``SoftmaxOutput`` and multiplied element-wise with the softmax
            output in the loss. This suggests a one-hot ``(bs, 2)`` layout --
            verify against the data iterator.
        bs: batch size, used in the reshape and to normalise ``loss``.
        eb_dim: embedding dimension of each feature slot.
        fea_groups: comma-separated positive integers. Each entry is the number
            of consecutive embedding slots averaged into one group.
        active_op: activation name forwarded to the three hidden layers.
        use_batch_norm: forwarded to the three hidden layers. The output layer
            never uses batch norm.

    Returns:
        tuple ``(prob, loss)``. ``prob`` is the ``SoftmaxOutput`` symbol.
        ``loss`` is ``-sum(log(prob) * label) / bs``.

    Raises:
        ValueError: if ``fea_groups`` parses to no groups or contains a
            non-positive group length.
    """
    groups = [int(s) for s in fea_groups.split(',')]  # slots per time window
    if not groups or any(g <= 0 for g in groups):
        raise ValueError("fea_groups must list positive integers, got %r" % (fea_groups,))
    total_group_length = sum(groups)

    user_input_before_reshape = mx.sym.concat(*user_feature_embedding)
    user_input = mx.sym.reshape(user_input_before_reshape, shape=(bs, total_group_length, eb_dim))

    # Mean embedding of each time window: slice the window's slots along
    # axis 1, sum them and divide by the window length.
    window_means = []
    idx = 0
    for group_length in groups:
        window = mx.sym.slice_axis(user_input, axis=1, begin=idx, end=idx + group_length)
        window_means.append(mx.sym.sum_axis(window, axis=1) / group_length)
        idx += group_length
    # Concatenate once instead of chaining concats keyed on ``idx == 0``.
    if len(window_means) == 1:
        grouped_user_input = window_means[0]
    else:
        grouped_user_input = mx.sym.concat(*window_means, dim=1)

    # The label is an input only; no gradient should flow into it.
    label = mx.symbol.BlockGrad(label)
    din = mx.symbol.concat(grouped_user_input, target_node_embedding, dim=1)

    # Three hidden layers plus a 2-unit output layer for the softmax.
    # Layer versions are "1_DNN" .. "4_DNN", unchanged from the original
    # definition; they are presumably used by mx_dnn_layer to name parameters.
    net_version = "DNN"
    hidden_widths = [128, 64, 24]
    din_width = (len(groups) + 1) * eb_dim
    layer_arr = []
    in_width = din_width
    for layer_no, out_width in enumerate(hidden_widths, start=1):
        layer_arr.append(mx_dnn_layer(in_width, out_width, active_op=active_op,
                                      use_batch_norm=use_batch_norm,
                                      version="%d_%s" % (layer_no, net_version)))
        in_width = out_width
    layer_arr.append(mx_dnn_layer(in_width, 2, active_op='', use_batch_norm=False,
                                  version="%d_%s" % (len(hidden_widths) + 1, net_version)))

    dout = din
    for layer in layer_arr:
        dout = layer.call(dout)

    prob = mx.symbol.SoftmaxOutput(data=dout, label=label, normalization='batch')
    # NOTE(review): log(prob) is -inf if any probability underflows to 0;
    # consider clipping if this loss is used for monitoring.
    loss = - mx.symbol.sum(mx.sym.log(prob) * label) / bs

    # Previously returned the undefined name ``prop`` (NameError on every call).
    return prob, loss
