import torch
# from scratch.user.quanqi-hu.single_loop_dc.pu_learning.utils.utils import check_tensor_shape

def check_tensor_shape(tensor, shape):
    """Bring ``tensor`` to the rank of ``shape``, reshaping only when needed.

    The tensor is reshaped to ``shape`` only when its number of dimensions
    differs from ``len(shape)``. A tensor whose rank already matches is
    returned unchanged, even if its individual dimension sizes differ from
    ``shape``.

    Args:
        tensor: The tensor to check.
        shape: Target shape, given as a tuple, a list, or a single int. An
            int ``k`` is treated as the 1-D shape ``(k,)``. ``-1`` entries are
            resolved by ``Tensor.reshape``.

    Returns:
        Either ``tensor`` itself (rank already matches) or
        ``tensor.reshape(shape)``.

    Raises:
        ValueError: If ``tensor`` is not a torch tensor, or if ``shape`` is
            not a tuple, list or int.
    """
    if not torch.is_tensor(tensor):
        raise ValueError('Input is not a valid torch tensor!')

    if isinstance(shape, int):
        target = torch.Size([shape])
    elif isinstance(shape, (tuple, list)):
        target = shape
    else:
        raise ValueError("Shape must be a tuple, an integer or a list!")

    # Only rank is compared; matching-rank tensors pass through untouched.
    if tensor.dim() == len(target):
        return tensor
    return tensor.reshape(target)
    
class HingeLoss(torch.nn.Module):
    """Mean hinge loss: ``mean(max(0, 1 - y_pred * y_true))`` over the batch.

    Labels are presumably in {-1, +1} (the usual hinge-loss convention).
    This is not enforced here.
    """

    def __init__(self, device=None):
        """Create the loss module.

        Args:
            device: Value stored on ``self.device``. When ``None``, defaults to
                CUDA if available, otherwise CPU. ``forward`` does not use it.
        """
        super(HingeLoss, self).__init__()
        # Use an explicit ``is None`` check rather than truthiness, so a falsy
        # but valid device such as index 0 is honored instead of replaced.
        if device is None:
            self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            self.device = device

    def forward(self, y_pred, y_true):
        """Compute the batch-averaged hinge loss.

        Both inputs pass through ``check_tensor_shape(..., (-1, 1))``. That
        helper reshapes them to a column of shape ``(N, 1)`` when their rank
        is not 2.

        Args:
            y_pred: Model scores.
            y_true: Target labels, with the same number of elements as
                ``y_pred``.

        Returns:
            A tensor holding the per-column mean hinge loss. For column-shaped
            input this has shape ``(1,)``. It is always a tensor, even when
            every margin is satisfied, so ``.backward()`` works.

        Raises:
            ValueError: If the two inputs have different shapes after
                reshaping, or if the batch is empty.
        """
        y_pred = check_tensor_shape(y_pred, (-1, 1))
        y_true = check_tensor_shape(y_true, (-1, 1))
        # Reject mismatches explicitly. Otherwise extra labels would be
        # ignored and a size-1 label tensor would silently broadcast.
        if y_pred.shape != y_true.shape:
            raise ValueError(
                'y_pred and y_true must have the same shape, got {} and {}'.format(
                    tuple(y_pred.shape), tuple(y_true.shape)))
        n = y_pred.size(0)
        if n == 0:
            raise ValueError('Cannot compute hinge loss of an empty batch!')
        # relu(x) == max(0, x) elementwise. Its gradient at exactly x == 0 is
        # 0, which matches a scalar max(0, x).
        margins = torch.relu(1 - y_pred * y_true)
        # Use sum / n (true division) rather than .mean(), so integer inputs
        # still yield a floating-point result.
        return margins.sum(dim=0) / n

    
