'''
 *
 *     ICTP: Irreducible Cartesian Tensor Potentials
 *
 *        File:  product_basis.py
 *
 *     Authors: Deleted for purposes of anonymity 
 *
 *     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 * 
 * The software and its source code contain valuable trade secrets and shall be maintained in
 * confidence and treated as confidential information. The software may only be used for 
 * evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 * license agreement or nondisclosure agreement with the proprietor of the software. 
 * Any unauthorized publication, transfer to third parties, or duplication of the object or
 * source code---either totally or in part---is strictly prohibited.
 *
 *     Copyright (c) 2024 Proprietor: Deleted for purposes of anonymity
 *     All Rights Reserved.
 *
 * THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR 
 * IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY 
 * AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT 
 * DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION. 
 * 
 * NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 * IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE 
 * LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 * FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 * OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 * ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 * TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGES.
 * 
 * For purposes of anonymity, the identity of the proprietor is not given herewith. 
 * The identity of the proprietor will be given once the review of the 
 * conference submission is completed. 
 *
 * THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 *
'''
from typing import List, Optional

import torch
import torch.nn as nn

import torch.fx
import opt_einsum_fx

from src.o3.tensor_product import PlainTensorProduct
from src.utils.o3 import get_slices, get_shapes


# Largest Cartesian-harmonic rank l the product basis supports (enforced in WeightedPathSummation).
L_MAX: int = 3
# Leading (batch) dimension of the random example tensors used to trace and optimize einsum contractions.
BATCH_SIZE: int = 10


class WeightedPathSummation(nn.Module):
    """Computes a weighted, species (atom type) dependent summation over the separate paths
    leading to each rotational order `l` of the Cartesian harmonics provided in the
    first input tensor. The second input tensor is typically one-hot encoded species and
    is used to obtain species dependent weights. The number of features in the output tensor
    is the same as in the first input tensor.

    Args:
        in1_l_max (int): Maximal rank of the first input tensor.
        out_l_max (int): Maximal rank of the output tensor; must satisfy 0 <= out_l_max <= in1_l_max.
        in1_features (int): Number of features in the first input tensor.
        in2_features (int): Number of features in the second input tensor.
        in1_paths (List[int], optional): Number of paths used to generate the Cartesian harmonics
                                         of each rank l in the first input tensor (one entry per
                                         rank, i.e. ``in1_l_max + 1`` entries). The weighted sum
                                         is computed across these paths for each rank l.
                                         Defaults to a single path per rank.
        coupled_feats (bool, optional): If True, features and paths are mixed jointly by a weight of
                                        shape (in2_features, in1_features * n_paths, in1_features).
                                        Otherwise each feature gets its own path weights of shape
                                        (in2_features, in1_features, n_paths).

    Raises:
        RuntimeError: If ``in1_l_max`` or ``out_l_max`` exceeds ``L_MAX``.
        ValueError: If ``out_l_max`` is negative or larger than ``in1_l_max``, or if ``in1_paths``
                    does not contain exactly ``in1_l_max + 1`` entries.
    """
    def __init__(self,
                 in1_l_max: int,
                 out_l_max: int,
                 in1_features: int,
                 in2_features: int,
                 in1_paths: Optional[List[int]] = None,
                 coupled_feats: bool = False):
        super(WeightedPathSummation, self).__init__()
        self.in1_l_max = in1_l_max
        self.out_l_max = out_l_max
        self.in1_features = in1_features
        self.in2_features = in2_features

        if self.out_l_max > L_MAX or self.in1_l_max > L_MAX:
            raise RuntimeError(f'Product basis is implemented for l <= {L_MAX=}.')
        # Only ranks present in the first input can be summed. Without this check, the per-rank
        # lookups below fail with an opaque IndexError.
        if not 0 <= self.out_l_max <= self.in1_l_max:
            raise ValueError(f'Expected 0 <= out_l_max <= in1_l_max, got {out_l_max=} and {in1_l_max=}.')

        # number of paths used to compute the Cartesian harmonics of each rank in the first input tensor
        self.in1_paths = [1] * (in1_l_max + 1) if in1_paths is None else in1_paths
        # explicit check instead of `assert`, which is stripped when Python runs with -O
        if len(self.in1_paths) != in1_l_max + 1:
            raise ValueError(f'Expected {in1_l_max + 1} entries in in1_paths (one per rank l <= {in1_l_max}), '
                             f'got {len(self.in1_paths)}.')

        # slices and shapes for tensors of rank l in the flattened input tensor
        self.in1_slices = get_slices(in1_l_max, in1_features, self.in1_paths)
        self.in1_shapes = get_shapes(in1_l_max, in1_features, self.in1_paths, use_prod=bool(coupled_feats))

        # dimensions of the input tensors for the sanity checks in forward
        self.in1_dim = sum((3 ** l) * in1_features * n_paths for l, n_paths in enumerate(self.in1_paths))
        self.in2_dim = in2_features

        # one species-dependent weight per output rank l = 0, ..., out_l_max
        self.weight = nn.ParameterList([])
        for n_paths in self.in1_paths[:self.out_l_max + 1]:
            if coupled_feats:
                self.weight.append(nn.Parameter(torch.randn(in2_features, in1_features * n_paths, in1_features) / n_paths / in1_features ** 0.5))
            else:
                self.weight.append(nn.Parameter(torch.randn(in2_features, in1_features, n_paths) / n_paths))

        # trace and optimize one einsum contraction per output rank l
        self.contractions = nn.ModuleList()
        example_in2 = None
        for l in range(self.out_l_max + 1):
            n_paths = self.in1_paths[l]
            cart = 'ijk'[:l]  # one Cartesian index per rank (l <= L_MAX is checked above)
            if coupled_feats:
                equation = f'wvu,a{cart}v,aw->a{cart}u'
                example_weight = torch.randn(in2_features, in1_features * n_paths, in1_features)
                example_in1 = torch.randn(BATCH_SIZE, *([3] * l), in1_features * n_paths)
            else:
                equation = f'wvp,a{cart}vp,aw->a{cart}v'
                example_weight = torch.randn(in2_features, in1_features, n_paths)
                example_in1 = torch.randn(BATCH_SIZE, *([3] * l), in1_features, n_paths)
            if example_in2 is None:
                # Drawn once, right after the l = 0 examples. This keeps the global RNG draw order
                # unchanged, so seeded initialisations of modules built later stay reproducible.
                example_in2 = torch.randn(BATCH_SIZE, in2_features)
            self.contractions.append(
                self._optimized_contraction(equation, (example_weight, example_in1, example_in2)))

    @staticmethod
    def _optimized_contraction(equation: str, example_inputs: tuple) -> nn.Module:
        """Traces ``torch.einsum(equation, w, x, y)`` with torch.fx and optimizes it with opt_einsum_fx.

        Args:
            equation (str): Einsum equation with three operands (weight, first input, second input).
            example_inputs (tuple): Example (weight, first input, second input) tensors whose shapes
                                    are used to optimize the contraction.

        Returns:
            nn.Module: The optimized, traced contraction.
        """
        # `equation` is fixed for this call, so the closure captures the intended string at trace time.
        traced = torch.fx.symbolic_trace(lambda w, x, y: torch.einsum(equation, w, x, y))
        return opt_einsum_fx.optimize_einsums_full(model=traced, example_inputs=example_inputs)

    def forward(self,
                x: torch.Tensor,
                y: torch.Tensor) -> torch.Tensor:
        """Computes the weighted sum across the contraction paths.

        Args:
            x (torch.Tensor): First input tensor of shape (n_batch, in1_dim). Typically contains
                              concatenated Cartesian harmonics.
            y (torch.Tensor): Second input tensor of shape (..., in2_features). Typically contains
                              one-hot encoded species (atom types).

        Returns:
            torch.Tensor: Tensor of shape (n_batch, in1_features * (1 + 3 + ... + 3 ** out_l_max))
                          with the concatenated, path-summed Cartesian harmonics of ranks 0..out_l_max.
        """
        torch._assert(x.shape[-1] == self.in1_dim, 'Incorrect last dimension for x.')
        torch._assert(y.shape[-1] == self.in2_dim, 'Incorrect last dimension for y.')

        n_batch = x.shape[0]
        outputs = []
        for l in range(self.out_l_max + 1):
            # Reshape the flat rank-l block into the layout expected by the rank-l einsum:
            # (n_batch, 3, ..., 3, in1_features, n_paths), or
            # (n_batch, 3, ..., 3, in1_features * n_paths) when features are coupled.
            x_l = x[:, self.in1_slices[l]].view(n_batch, *self.in1_shapes[l])
            # weighted sum over paths -> (n_batch, 3, ..., 3, in1_features)
            x_l = self.contractions[l](self.weight[l], x_l, y)
            outputs.append(x_l.reshape(n_batch, (3 ** l) * self.in1_features))
        # a single rank needs no concatenation
        return outputs[0] if len(outputs) == 1 else torch.cat(outputs, -1)

    def __repr__(self) -> str:
        n_weights = sum(w.numel() for w in self.weight)
        out_paths = [1] * (self.out_l_max + 1)
        return (f"{self.__class__.__name__} ({self.in1_l_max} -> {self.out_l_max} | {self.in1_paths[:self.out_l_max+1]} -> {out_paths} paths | {n_weights} weights)")


class WeightedProductBasis(nn.Module):
    """Weighted product basis obtained by contracting irreducible Cartesian tensors/Cartesian harmonics.

    The correlation-1 term is a weighted path summation of the first input itself. Each higher
    correlation order contracts the previous tensor-product output with the first input again.
    Its weighted path summation is then added to the basis.

    Args:
        in1_l_max (int): Maximal rank of the first input tensor.
        out_l_max (int): Maximal rank of the output tensor.
        in1_features (int): Number of features in the first input tensor.
        in2_features (int): Number of features in the second input tensor.
        correlation (int): Correlation order, i.e., number of contracted tensors; must be >= 1.
        coupled_feats (bool, optional): Forwarded to every WeightedPathSummation. If True, features
                                        and paths are mixed jointly by the path-summation weights.
        symmetric_product (bool, optional): If True, exploit symmetry of the first tensor product
                                            (first input with itself) to reduce the number of
                                            possible tensor contractions. Later products never
                                            use it.

    Raises:
        ValueError: If ``correlation`` is smaller than 1.
    """
    def __init__(self,
                 in1_l_max: int,
                 out_l_max: int,
                 in1_features: int,
                 in2_features: int,
                 correlation: int,
                 coupled_feats: bool = False,
                 symmetric_product: bool = True):
        super(WeightedProductBasis, self).__init__()
        # A correlation order below 1 has no meaning. It used to be silently treated as correlation 1.
        if correlation < 1:
            raise ValueError(f'correlation must be >= 1, got {correlation}.')
        self.correlation = correlation
        self.in1_l_max = in1_l_max
        self.out_l_max = out_l_max
        self.in1_features = in1_features

        # correlation - 1 tensor products are needed. Only the last one is truncated to out_l_max;
        # intermediate products keep all ranks up to in1_l_max so they can be contracted further.
        n_products = correlation - 1
        target_l_maxs = [out_l_max if i == n_products - 1 else in1_l_max for i in range(n_products)]

        # prepare tensor products for computing the product basis from the first input tensor
        # (the list stays empty when correlation == 1)
        self.tps = nn.ModuleList([])
        for i, target_l_max in enumerate(target_l_maxs):
            # the first product takes a single path per rank; later ones take the previous output's paths
            in1_paths = None if i == 0 else self.tps[-1].n_paths

            # While computing the first contraction, the symmetry of the tensor product can be used
            # to save computational cost. In that case, however, fewer parameters are learned.
            self.tps.append(PlainTensorProduct(in1_l_max=in1_l_max, in2_l_max=in1_l_max, out_l_max=target_l_max,
                                               in1_features=in1_features, in2_features=in1_features, out_features=in1_features,
                                               in1_paths=in1_paths, symmetric_product=symmetric_product if i == 0 else False))

        # prepare weighted path sums: index 0 for correlation 1, index i + 1 for the output of tps[i]
        self.weighted_sums = nn.ModuleList([])
        self.weighted_sums.append(WeightedPathSummation(in1_l_max=in1_l_max, out_l_max=out_l_max,
                                                        in1_features=in1_features, in2_features=in2_features,
                                                        coupled_feats=coupled_feats))
        for tp, target_l_max in zip(self.tps, target_l_maxs):
            self.weighted_sums.append(WeightedPathSummation(in1_l_max=target_l_max, out_l_max=out_l_max,
                                                            in1_features=in1_features, in2_features=in2_features,
                                                            in1_paths=tp.n_paths, coupled_feats=coupled_feats))

    def forward(self,
                x: torch.Tensor,
                y: torch.Tensor) -> torch.Tensor:
        """Computes the weighted product basis features.

        Args:
            x (torch.Tensor): First input tensor. Typically contains concatenated
                              Cartesian harmonics.
            y (torch.Tensor): Second input tensor. Typically contains one-hot
                              encoded species.

        Returns:
            torch.Tensor: Output tensor with product basis features, i.e. the sum of the weighted
                          path summations over all correlation orders.
        """
        # shape: n_batch x n_feats * (1 + 3 + ... + 3^l)
        # correlation = 1
        basis = self.weighted_sums[0](x, y)

        # Correlation > 1: the first product is tps[0](x, x); each further product contracts the
        # previous output with x again.
        out_tp = x
        for i, tp in enumerate(self.tps):
            out_tp = tp(out_tp, x)
            basis = basis + self.weighted_sums[i + 1](out_tp, y)
        return basis
