#!/usr/bin/env python

import sys
import time
import numpy as np
import uitree_pb2 as proto
import fileinput
import math
from tree import TreeLearner

def tree_learning(d, tree):
    r"""The overall tree learning algorithm (Algorithm 2 in the paper).

    Every item starts at its level-0 ancestor. The tree is then walked
    downwards in steps of ``d`` levels: for each node ``ni`` at level
    ``l - d``, the items currently projected onto ``ni`` are re-assigned to
    nodes at level ``l`` by ``assign_parent``. The last step is shortened so
    that it ends exactly at level ``l_max``.

    Returns:
        the learnt new projection (\pi_{new}) as a dict mapping each item id
        (``tree.nodes[item_code].id``) to the node it was last assigned to

    Args:
        d (int, required): the tree learning level gap; values larger than
            the deepest level are clamped to it, and values <= 0 leave every
            item at its level-0 ancestor
        tree (tree, required): the old tree (\pi_{old})
    """
    l_max = tree.tree_meta.max_level - 1

    # A gap deeper than the tree would make the first step target a level
    # below the leaves, so never step further than the deepest level.
    d = min(d, l_max)
    l = d

    # \pi_{new} <- \pi_{old}: l - d is 0 here, i.e. every item starts at its
    # level-0 ancestor.
    pi_new = {
        tree.nodes[item_code].id: tree.get_ancestor(item_code, l - d)
        for item_code in tree.item_codes
    }

    while d > 0:
        for ni in tree.get_nodes_given_level(l - d):
            C_ni = get_itemset_given_ancestor(pi_new, ni)
            pi_star = assign_parent(l_max, l, d, ni, C_ni, tree)

            # update pi_new according to the found optimal pi_star
            pi_new.update(pi_star)

        # Shrink the final step so it lands on l_max; once l == l_max the
        # gap becomes 0 and the loop terminates.
        d = min(d, l_max - l)
        l = l + d

    return pi_new


def get_itemset_given_ancestor(pi_new, node):
    """Collect the items that pi_new currently projects onto `node`.

    Returns:
        list of the keys of pi_new whose value equals `node`, in the
        iteration order of pi_new

    Args:
        pi_new (dict, required): mapping from item to its assigned node
        node (node, required): the node to select items for
    """
    return [item for item, assigned in pi_new.items() if assigned == node]

def get_weights(C_ni, ni, children_of_ni_in_level_l, tree):
    r"""Use the user preference prediction model to calculate the required weights.

    For every item ck in C_ni and every candidate node in
    children_of_ni_in_level_l, the weight is the sum of the per-node scores
    returned by ``calculate_weight_use_prediction_model`` over the nodes of
    ``tree.get_parent_path(node, ni)``. An empty path yields a weight of 0.0.

    NOTE: ``calculate_weight_use_prediction_model`` is not defined in the
    visible part of this file; it has to be provided before this function is
    called with non-empty paths.

    Returns:
        dict mapping each item ck to a two-element list
        ``[candidate_nodes, candidate_weights]``; both inner lists follow the
        order of children_of_ni_in_level_l

    Args:
        C_ni (item, required): item set whose ancestor is the non-leaf node ni
        ni (node, required): a non-leaf node in level l-d
        children_of_ni_in_level_l (list, required): the level l-th children of ni
        tree (tree, required): the old tree (\pi_{old})
    """
    edge_weights = dict()

    for ck in C_ni:
        candidate_nodes = []    # the nodes in level l
        candidate_weights = []  # the weight of ck for each of those nodes

        for node in children_of_ni_in_level_l:
            path_to_ni = tree.get_parent_path(node, ni)
            weight = 0.0
            for n in path_to_ni:
                sample_set = set()  # the sample set that the target item is ck

                # use the user preference prediction model to calculate the required weights.
                # the detailed calculation process is omitted here
                weight += calculate_weight_use_prediction_model(sample_set, n)

            candidate_nodes.append(node)
            candidate_weights.append(weight)

        edge_weights[ck] = [candidate_nodes, candidate_weights]

    return edge_weights

def assign_parent(l_max, l, d, ni, C_ni, tree):
    r"""Implementation of line 5 of Algorithm 2.

    Each item of C_ni is first assigned to the level-l child of ni for which
    it has the largest weight. Children holding more than
    ``2 ** (l_max - l)`` items are then rebalanced, most loaded first: the
    items to keep are chosen by preferring items whose original level-l
    ancestor is that child, then by descending weight for that child. The
    overflow is moved to each item's next-best child that has not been
    rebalanced yet.

    Returns:
        the part of \pi_{new} for the items in C_ni: a dict mapping each
        item to its assigned level-l node

    Args:
        l_max (int, required): the max level of the tree
        l (int, required): current assign level
        d (int, required): level gap in tree_learning (not used in the body)
        ni (node, required): a non-leaf node in level l-d
        C_ni (item, required): item set whose ancestor is the non-leaf node ni
        tree (tree, required): the old tree (\pi_{old})
    """

    # get the children of ni in level l
    children_of_ni_in_level_l = tree.get_children_given_ancestor_and_level(ni, l)

    # get all the required weights
    edge_weights = get_weights(C_ni, ni, children_of_ni_in_level_l, tree)

    # assign each item to the level l node with the maximum weight.
    # Each entry is (item, candidate indices sorted by descending weight,
    # candidate nodes, candidate weights).
    assign_dict = dict()
    for ci, info in edge_weights.items():
        assign_candidate_nodes = info[0]
        assign_weights = np.array(info[1], dtype=np.float32)
        sorted_idx = np.argsort(-assign_weights)
        # assign item ci to the node with the largest weight
        max_weight_node = assign_candidate_nodes[sorted_idx[0]]
        assign_dict.setdefault(max_weight_node, []).append(
            (ci, sorted_idx, assign_candidate_nodes, assign_weights))

    # drop the reference so the raw weight lists can be reclaimed
    edge_weights = None

    # get each item's original assignment of level l in tree, used in rebalance process
    # NOTE(review): tree_learning calls get_ancestor with an item code while
    # the items here are the keys of pi_new (item ids) -- confirm that
    # get_ancestor accepts both.
    origin_relation = {ci: tree.get_ancestor(ci, l) for ci in C_ni}

    # rebalance: a level-l node can hold at most this many items
    max_assign_num = int(math.pow(2, l_max - l))
    processed_set = set()
    while True:
        # pick the most loaded child that has not been rebalanced yet
        max_assign_cnt = 0
        max_assign_node = None
        for node in children_of_ni_in_level_l:
            if node in processed_set or node not in assign_dict:
                continue
            if len(assign_dict[node]) > max_assign_cnt:
                max_assign_cnt = len(assign_dict[node])
                max_assign_node = node

        if max_assign_node is None or max_assign_cnt <= max_assign_num:
            break

        processed_set.add(max_assign_node)
        elements = assign_dict[max_assign_node]

        def keep_priority(entry, node=max_assign_node):
            # Items originally under this node sort first, then larger weight
            # for this node. The key must be made of scalars: comparing numpy
            # arrays inside a tuple raises ValueError on ties.
            item, _, candidates, weights = entry
            moved = int(node != origin_relation[item])
            return (moved, -float(weights[candidates.index(node)]))

        elements.sort(key=keep_priority)

        # move the overflow to each item's best remaining candidate; an item
        # whose candidates have all been processed already is not re-assigned
        for e in elements[max_assign_num:]:
            for idx in e[1]:
                other_parent_node = e[2][idx]
                if other_parent_node in processed_set:
                    continue
                assign_dict.setdefault(other_parent_node, []).append(e)
                break
        del elements[max_assign_num:]

    pi_new = dict()
    for parent_code, value in assign_dict.items():
        assert len(value) <= max_assign_num
        for e in value:
            assert e[0] not in pi_new
            pi_new[e[0]] = parent_code

    return pi_new


if __name__ == '__main__':
    # Level gap used by Algorithm 2 (tree learning).
    d = 7

    # The learnt tree is saved under the directory of the next tree index.
    tree_idx = 0
    new_tree_file = './cate_tree_%d/tree.pb' % (tree_idx + 1)
    old_tree_file = './cate_tree_%d/tree.pb' % tree_idx

    # Read the existing tree (the old projection) from disk.
    tree = TreeLearner('old_tree')
    tree.load_tree(old_tree_file)

    # Learn the new item-to-node projection, then let the tree assign the
    # leaf nodes from it and write the result to the new file.
    pi_new = tree_learning(d, tree)
    tree.assign_leaf_nodes(pi_new, new_tree_file)
