import numpy as np
import torch


class TimeTop:
    """Time-decayed multi-hop interaction scorer over a fixed set of nodes.

    State:
        score_matrix: list of ``decode_hop + 1`` dense ``(node_num, node_num)``
            float tensors. Entry 0 is the identity; entry ``i >= 1`` accumulates
            time-weighted scores of ``i``-edge walks, where each newly observed
            edge extends a walk already stored at hop ``i - 1`` before that
            edge's batch was applied.
        node_degree: per-node, time-decayed sum of interaction weights.
        edge_happen_flag / node_happen_flag: dicts whose keys record every
            directed edge (both orientations) / node passed to ``update``.
    """

    def __init__(self, node_num, min_time, lam, encode_type, decode_hop, decode_beta, device):
        """Allocate empty state.

        Args:
            node_num: number of nodes. Node ids are used directly as row/column
                indices, so they must lie in ``[0, node_num)``.
            min_time: timestamp treated as "now" before the first update.
            lam: decay rate applied to stored hop scores between updates
                (hop ``i`` decays ``i`` times as fast as hop 1).
            encode_type: stored as ``self.encode_type``; not read by this class.
            decode_hop: number of hops tracked; ``get_probability`` needs >= 1.
            decode_beta: geometric weight for hops 2.. in ``get_probability``.
            device: torch device on which every state tensor is allocated.
        """
        self.device = device
        # Hop 0 is the identity so that hop 1 is built from one-hot rows.
        self.score_matrix = [torch.eye(node_num, device=self.device)] + [
            torch.zeros((node_num, node_num), device=self.device) for _ in range(decode_hop)
        ]
        self.node_degree = torch.zeros(node_num, device=device)
        self.now_time = min_time
        self.lam = lam
        # Faster decay rate (fixed at 1000 x lam) used for within-batch event
        # weights and for node degrees.
        self.new_lam = lam * 1000
        self.node_num = node_num
        self.encode_type = encode_type
        self.decode_hop = decode_hop
        self.decode_beta = decode_beta
        self.edge_happen_flag = dict()
        self.node_happen_flag = dict()

    def update(self, src_ids, dst_ids, interact_times):
        """Fold a batch of undirected interactions into the hop scores and degrees.

        Every interaction is inserted in both directions. ``interact_times[-1]``
        is taken as the batch's current time, so the batch is assumed to be
        sorted by time (TODO confirm with callers). An empty batch is a no-op.

        Args:
            src_ids, dst_ids: integer node ids (any integer dtype), same length.
            interact_times: timestamps aligned with ``src_ids``/``dst_ids``.
        """
        if len(interact_times) == 0:
            return

        concat_src_ids = np.concatenate([src_ids, dst_ids])
        concat_dst_ids = np.concatenate([dst_ids, src_ids])
        concat_times = np.tile(interact_times, 2)
        next_time = interact_times[-1]
        elapsed = next_time - self.now_time

        # Decay stored hop-i scores by exp(-lam * elapsed * i). Done in place so
        # no fresh N x N tensor is allocated per hop on every update.
        for hop in range(1, self.decode_hop + 1):
            self.score_matrix[hop].mul_(float(np.exp(-self.lam * elapsed * hop)))

        # NOTE(review): stored hops decay with `lam`, while events inside this
        # batch are weighted with `new_lam` (1000x faster) -- confirm intended.
        time_weight = torch.tensor(np.exp(-self.new_lam * (next_time - concat_times)), device=self.device,
                                   dtype=torch.float)

        # scatter_add_ requires int64 indices; convert once (handles int32 ids).
        src_index = torch.as_tensor(concat_src_ids, dtype=torch.long, device=self.device)
        dst_index = torch.as_tensor(concat_dst_ids, dtype=torch.long, device=self.device)
        row_index = src_index[:, None].expand(-1, self.node_num)

        # Go from the highest hop down so hop i reads hop i-1 before this batch
        # has modified it.
        for hop in range(self.decode_hop, 0, -1):
            self.score_matrix[hop].scatter_add_(dim=0, index=row_index,
                                                src=self.score_matrix[hop - 1][dst_index] * time_weight[:, None])

        self.node_degree.mul_(float(np.exp(-self.new_lam * elapsed)))
        self.node_degree.scatter_add_(dim=0, index=src_index, src=time_weight)
        self.now_time = next_time
        self.update_happen_flag(src_ids, dst_ids)

    def update_happen_flag(self, src_ids, dst_ids):
        """Record each (src, dst) edge in both orientations, and both endpoints."""
        for src_id, dst_id in zip(src_ids, dst_ids):
            self.edge_happen_flag[(src_id, dst_id)] = 1
            self.edge_happen_flag[(dst_id, src_id)] = 1
            self.node_happen_flag[src_id] = 1
            self.node_happen_flag[dst_id] = 1

    def query_edge_happen(self, src_ids, dst_ids):
        """Return a bool tensor: whether each (src, dst) pair was seen by ``update``."""
        results = [(src_id, dst_id) in self.edge_happen_flag for src_id, dst_id in zip(src_ids, dst_ids)]
        # Explicit dtype so an empty query still yields a bool tensor.
        return torch.tensor(results, dtype=torch.bool, device=self.device)

    def query_node_happen(self, node_ids):
        """Return a bool tensor: whether each node was seen by ``update``."""
        results = [node_id in self.node_happen_flag for node_id in node_ids]
        return torch.tensor(results, dtype=torch.bool, device=self.device)

    def two_dimensional_index_select(self, scores, idx1, idx2):
        """Return ``scores[idx1[k], idx2[k]]`` for every k as a 1-D tensor.

        Indices may be any integer array-like or tensor; they are converted to
        int64 on ``scores.device``. Only the requested entries are read, instead
        of first materialising the full rows.
        """
        rows = torch.as_tensor(idx1, dtype=torch.long, device=scores.device)
        cols = torch.as_tensor(idx2, dtype=torch.long, device=scores.device)
        return scores[rows, cols]

    def get_node_degree(self, node_ids):
        """Return the decayed degree of ``node_ids`` scaled by a fixed 1e-5.

        NOTE(review): the meaning of the 1e-5 scale is not established here.
        """
        return self.node_degree[node_ids] * 0.00001

    def _decode_scores(self, src_ids, dst_ids):
        """Return sum over hops h of decode_beta**(h-1) * score_matrix[h][src, dst]."""
        rows = torch.as_tensor(src_ids, dtype=torch.long, device=self.device)
        cols = torch.as_tensor(dst_ids, dtype=torch.long, device=self.device)
        scores = self.two_dimensional_index_select(self.score_matrix[1], rows, cols)
        beta = self.decode_beta
        for hop in range(2, self.decode_hop + 1):
            scores = scores + beta * self.two_dimensional_index_select(self.score_matrix[hop], rows, cols)
            beta = beta * self.decode_beta
        return scores

    def get_probability(self, src_ids, dst_ids, neg_src_ids, neg_dst_ids):
        """Score positive and negative (src, dst) pairs.

        Each score is ``hop1 + beta * hop2 + beta**2 * hop3 + ...`` read at the
        pair's entry. Only the requested entries are gathered per hop, so no
        full N x N combined matrix is built.

        Returns:
            (pos_probability, neg_probability): 1-D float tensors aligned with
            the positive and negative pairs respectively.
        """
        pos_probability = self._decode_scores(src_ids, dst_ids)
        neg_probability = self._decode_scores(neg_src_ids, neg_dst_ids)
        return pos_probability, neg_probability

# class TimeTop:
#     def __init__(self, node_num, min_time, lam, encode_type, decode_hop, decode_beta, device):
#         self.device = device
#         self.score_matrix = torch.zeros((node_num, node_num), device=self.device)
#         self.new_score_matrix = torch.zeros_like(self.score_matrix)
#         self.now_time = min_time
#         self.lam = lam
#         self.new_lam = lam * 1000
#         self.encode_type = encode_type
#         self.decode_hop = decode_hop
#         self.decode_beta = decode_beta
#         self.edge_happen_flag = dict()
#         self.node_happen_flag = dict()
#
#     def update(self, src_ids, dst_ids, interact_times):
#         concat_src_ids = np.concatenate([src_ids, dst_ids])
#         concat_dst_ids = np.concatenate([dst_ids, src_ids])
#         concat_times = np.tile(interact_times, 2)
#         if self.encode_type == 'continuous':
#             next_time = interact_times[-1]
#             self.score_matrix = self.score_matrix * np.exp(-self.lam * (next_time - self.now_time))
#             torch.index_put_(input=self.score_matrix,
#                              indices=[torch.tensor(concat_src_ids, device=self.device),
#                                       torch.tensor(concat_dst_ids, device=self.device)],
#                              values=torch.tensor(np.exp(-self.lam * (next_time - concat_times)), device=self.device,
#                                                  dtype=torch.float),
#                              accumulate=True)
#             self.new_score_matrix = self.new_score_matrix * np.exp(-self.new_lam * (next_time - self.now_time))
#             torch.index_put_(input=self.new_score_matrix,
#                              indices=[torch.tensor(concat_src_ids, device=self.device),
#                                       torch.tensor(concat_dst_ids, device=self.device)],
#                              values=torch.tensor(np.exp(-self.new_lam * (next_time - concat_times)), device=self.device,
#                                                  dtype=torch.float),
#                              accumulate=True)
#             self.now_time = next_time
#         elif self.encode_type == 'discrete':
#             self.score_matrix = self.score_matrix * self.lam
#             torch.index_put_(input=self.score_matrix,
#                              indices=[torch.tensor(concat_src_ids), torch.tensor(concat_dst_ids)],
#                              values=torch.tensor(1.0), accumulate=True)
#         else:
#             raise ValueError("Not Implemented Encode Type of TimeTop")
#         self.update_happen_flag(src_ids, dst_ids)
#
#     def update_happen_flag(self, src_ids, dst_ids):
#         for src_id, dst_id in zip(src_ids, dst_ids):
#             self.edge_happen_flag[(src_id, dst_id)] = 1
#             self.edge_happen_flag[(dst_id, src_id)] = 1
#             self.node_happen_flag[src_id] = 1
#             self.node_happen_flag[dst_id] = 1
#
#     def query_edge_happen(self, src_ids, dst_ids):
#         results = [(src_id, dst_id) in self.edge_happen_flag for src_id, dst_id in zip(src_ids, dst_ids)]
#         return torch.tensor(results, device=self.device)
#
#     def query_node_happen(self, node_ids):
#         results = [node_id in self.node_happen_flag for node_id in node_ids]
#         return torch.tensor(results, device=self.device)
#
#     def two_dimensional_index_select(self, scores, idx1, idx2):
#         scores = scores[idx1]
#         scores = torch.gather(scores, 1, index=torch.tensor(idx2[:, None], device=self.device)).squeeze(dim=1)
#         return scores
#
#     def get_node_degree(self, node_ids):
#         node_degree = torch.sum(self.new_score_matrix, dim=1)
#         return node_degree[node_ids] * 0.00001
#
#     def get_probability(self, src_ids, dst_ids, neg_src_ids, neg_dst_ids):
#         # degree = self.score_matrix
#         score_matrix = self.score_matrix
#         norm_score_matrix = self.score_matrix / (
#                 torch.sqrt(torch.sum(self.score_matrix, keepdim=True, dim=0)) + 0.0000001)
#         norm_score_matrix = norm_score_matrix / (
#                 torch.sqrt(torch.sum(self.score_matrix, keepdim=True, dim=1)) + 0.0000001)
#         now_adj = norm_score_matrix
#         beta = 1.0
#         for i in range(1, self.decode_hop):
#             now_adj = torch.matmul(now_adj, norm_score_matrix)
#             beta = beta * self.decode_beta
#             score_matrix = score_matrix + beta * now_adj
#             # print(i, torch.sum(score_matrix), beta)
#             # print('------------------')
#         pos_probability = self.two_dimensional_index_select(score_matrix, src_ids, dst_ids)
#         neg_probability = self.two_dimensional_index_select(score_matrix, neg_src_ids, neg_dst_ids)
#         return pos_probability, neg_probability
