#
# Obtain hyperspherical prototypes prior to network training.
#
# Hyperspherical Prototype Networks.
#
import os
import sys
import numpy as np
import random
import argparse
import torch
import torch.optim as optim
import torch.nn.functional as F
from   torch import nn

#
# Parse user arguments.
#
def parse_args():
    """Build the command-line parser and parse the process arguments.

    Recognised options (all optional):
      -c  number of classes (int, default 100)      -> ``classes``
      -d  number of dimensions (int, default 100)   -> ``dims``
      -l  learning rate (float, default 0.1)        -> ``learning_rate``
      -m  momentum (float, default 0.9)             -> ``momentum``
      -e  number of epochs (int, default 10000)     -> ``epochs``
      -s  random seed (int, default 300)            -> ``seed``
      -r  result directory (str, default "")        -> ``resdir``

    Returns:
        The ``argparse.Namespace`` produced by parsing ``sys.argv``.
    """
    cli = argparse.ArgumentParser(description="Hyperspherical prototype construction")

    # One row per option: (flag, destination, help text, default, converter).
    option_table = (
        ('-c', "classes", "Nr. classes", 100, int),
        ('-d', "dims", "Nr. dimensions", 100, int),
        ('-l', "learning_rate", "Learning rate", 0.1, float),
        ('-m', "momentum", "Momentum", 0.9, float),
        ('-e', "epochs", "Nr. epochs", 10000, int),
        ('-s', "seed", "Seed", 300, int),
        ('-r', "resdir", "Resdir", "", str),
    )
    for flag, target, description, fallback, converter in option_table:
        cli.add_argument(flag, dest=target, help=description,
                         default=fallback, type=converter)

    return cli.parse_args()

#
# Compute the loss related to the prototypes.
#
def prototype_loss(prototypes):
    """Separation loss over a set of prototypes.

    Args:
        prototypes: 2-D tensor with one prototype per row. The rows are
            expected to be L2-normalized (the caller in this file normalizes
            along dim=1), in which case the dot products below are cosine
            similarities.

    Returns:
        A pair ``(loss, separation)``: the mean over rows of each row's
        largest shifted similarity, and the single largest shifted similarity
        in the whole matrix. Both values include the ``+ 1`` shift.
    """
    # Pairwise dot products between all prototypes, shifted by +1.
    shifted = torch.matmul(prototypes, prototypes.t()) + 1
    # Remove the diagonal from the loss: subtracting twice the diagonal flips
    # the sign of every self-similarity entry, which for normalized rows puts
    # it below all other entries so it is never picked by the max.
    self_terms = torch.diag(torch.diag(shifted))
    shifted = shifted - 2. * self_terms
    # Minimize the maximum similarity of each prototype to any other one.
    row_max, _ = shifted.max(dim=1)
    return row_max.mean(), shifted.max()

#
# Compute a loss based on ranking order w.r.t. semantics.
#
def rank_loss(prototypes, ranktensor, triplets):
    """Ranking loss over prototype similarities for a set of index triplets.

    For every triplet ``(a, b, c)`` the difference between the dot products
    ``prototypes[a] . prototypes[b]`` and ``prototypes[a] . prototypes[c]`` is
    squashed with a sigmoid and compared, via binary cross-entropy, against
    the target ``ranktensor[a, b, c]``.

    Args:
        prototypes: 2-D tensor with one prototype per row.
        ranktensor: tensor indexed as ``[a, b, c]`` that supplies the target
            for each triplet (presumably 1 when b should rank above c for
            anchor a -- confirm against the code that builds it).
        triplets: 2-D integer index tensor; columns 0, 1 and 2 are used as
            the ``a``, ``b`` and ``c`` indices.

    Returns:
        Scalar tensor: the mean binary cross-entropy over all triplets.
    """
    similarities = torch.matmul(prototypes, prototypes.t())

    anchor_idx = triplets[:, 0]
    first_idx = triplets[:, 1]
    second_idx = triplets[:, 2]

    # Similarity margin between the two candidates of each triplet.
    margin = similarities[anchor_idx, first_idx] - similarities[anchor_idx, second_idx]
    probabilities = torch.sigmoid(margin)

    targets = ranktensor[anchor_idx, first_idx, second_idx]
    return F.binary_cross_entropy(probabilities, targets)

#
# Main entry point of the script.
#
if __name__ == "__main__":
    # Parse user arguments.
    args = parse_args()
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"

    # Seed every random number generator that is touched below.
    seed = args.seed
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Resolve the output location up front, so a missing result directory is
    # dealt with before the optimization rather than after it. os.path.join
    # leaves the default (empty) resdir untouched and inserts the separator
    # when resdir is given without a trailing slash.
    if args.resdir and not os.path.isdir(args.resdir):
        os.makedirs(args.resdir)
    outfile = os.path.join(
        args.resdir, "prototypes-%dd-%dc.npy" % (args.dims, args.classes))

    # Initialize prototypes (random directions on the unit sphere) and optimizer.
    prototypes = torch.randn(args.classes, args.dims)
    prototypes = nn.Parameter(F.normalize(prototypes, p=2, dim=1))
    optimizer = optim.SGD([prototypes], lr=args.learning_rate, momentum=args.momentum)

    # Optimize for separation.
    for i in range(args.epochs):
        # Compute loss. Only the separation term is used: rank_loss needs a
        # rank tensor and triplets, and this script never builds either (the
        # earlier call referenced undefined names and raised NameError).
        loss, sep = prototype_loss(prototypes)
        # Update.
        loss.backward()
        optimizer.step()
        # Renormalize prototypes back onto the unit sphere. A fresh Parameter
        # is created, so its gradient starts empty and no zero_grad() call is
        # needed.
        prototypes = nn.Parameter(F.normalize(prototypes.detach(), p=2, dim=1))
        # NOTE(review): the optimizer is rebuilt to track the new Parameter,
        # which also discards its momentum state every iteration, so the -m
        # setting has no lasting effect. Kept as is to preserve the resulting
        # prototypes -- confirm whether persistent momentum is wanted.
        optimizer = optim.SGD([prototypes], lr=args.learning_rate, momentum=args.momentum)
        # Progress line, rewritten in place via the carriage return.
        sys.stdout.write("%03d/%d: %.4f\r" % (i, args.epochs, sep.item()))
        sys.stdout.flush()
    sys.stdout.write("\n")

    # Store result.
    np.save(outfile, prototypes.detach().numpy())
