"""
AIR applied to the multi-mnist data set [1].

[1] Eslami, SM Ali, et al. "Attend, infer, repeat: Fast scene
understanding with generative models." Advances in Neural Information
Processing Systems. 2016.
"""


import argparse
import math
import os
import time
from functools import partial

import numpy as np
import torch
import visdom

import pyro
import pyro.contrib.examples.multi_mnist as multi_mnist
import pyro.optim as optim
import pyro.poutine as poutine
from air import AIR, latents_to_tensor
from pyro.contrib.examples.util import get_data_directory
from pyro.infer import SVI, JitTraceGraph_ELBO, TraceGraph_ELBO
from viz import draw_many, tensor_to_objs


def count_accuracy(X, true_counts, air, batch_size):
    """Compare the object counts inferred by the guide with the true counts.

    The data is processed in consecutive batches of ``batch_size``. For each
    batch ``air.guide`` is run, the per-step ``z_pres`` values are summed to
    obtain an inferred count per example, and a confusion matrix of
    (true count, inferred count) is accumulated.

    Args:
        X: input tensor; its first dimension indexes examples.
        true_counts: tensor of true object counts, one per example of ``X``.
            Values index rows of a 3-row matrix, so they must lie in 0..2.
        air: object exposing ``guide(batch, batch_size)`` returning
            ``(z_where, z_pres)``, where ``z_pres`` is an iterable of tensors.
        batch_size: number of examples per guide call; must divide
            ``X.size(0)``.

    Returns:
        A tuple ``(acc, counts, error_latents, error_indices)``:
        ``acc`` is the diagonal mass of ``counts`` divided by ``X.size(0)``;
        ``counts`` is a 3x4 LongTensor (rows: true count 0..2, columns:
        inferred count 0..3); ``error_latents`` is the concatenation of
        ``latents_to_tensor`` output for the mis-counted examples;
        ``error_indices`` is a 1-D tensor of positions in ``X`` whose count
        was wrong (moved to CUDA when ``X`` is on CUDA).
    """
    assert X.size(0) == true_counts.size(0), 'Size mismatch.'
    assert X.size(0) % batch_size == 0, 'Input size must be multiple of batch_size.'
    counts = torch.LongTensor(3, 4).zero_()
    error_latents = []
    error_indicators = []

    def count_vec_to_mat(vec, max_index):
        # One-hot encode each count in `vec` as a row of width max_index + 1.
        out = torch.LongTensor(vec.size(0), max_index + 1).zero_()
        out.scatter_(1, vec.type(torch.LongTensor).view(vec.size(0), 1), 1)
        return out

    for i in range(X.size(0) // batch_size):
        X_batch = X[i * batch_size:(i + 1) * batch_size]
        true_counts_batch = true_counts[i * batch_size:(i + 1) * batch_size]
        z_where, z_pres = air.guide(X_batch, batch_size)
        # Flatten explicitly (rather than squeeze()) so that a batch of one
        # example still yields a 1-D vector of counts.
        inferred_counts = sum(z.cpu() for z in z_pres).data.reshape(-1)
        true_counts_m = count_vec_to_mat(true_counts_batch, 2)
        inferred_counts_m = count_vec_to_mat(inferred_counts, 3)
        # (one-hot true)^T @ (one-hot inferred) adds this batch's confusion counts.
        counts += torch.mm(true_counts_m.t(), inferred_counts_m)
        # Use `!=` instead of `1 - (a == b)`: subtracting from a bool tensor
        # is rejected by recent PyTorch versions.
        error_ind = true_counts_batch != inferred_counts
        # squeeze(-1) drops only nonzero()'s trailing index dimension, so a
        # single error still gives a 1-D index tensor.
        error_ix = error_ind.nonzero().squeeze(-1)
        error_latents.append(latents_to_tensor((z_where, z_pres)).index_select(0, error_ix))
        error_indicators.append(error_ind)

    acc = counts.diag().sum().float() / X.size(0)
    # Kept 1-D for the same reason as above; callers use .size(0) and slicing.
    error_indices = torch.cat(error_indicators).nonzero().squeeze(-1)
    if X.is_cuda:
        error_indices = error_indices.cuda()
    return acc, counts, torch.cat(error_latents), error_indices


# Defines something like a truncated geometric. Like the geometric,
# this has the property that there's a constant difference in log prob
# between p(steps=n) and p(steps=n+1).
def make_prior(k, num_steps=3):
    """Build per-trial success probabilities for a truncated-geometric-like prior.

    The implied distribution over the number of successful trials
    ``n = 0..num_steps`` is proportional to ``k**n``, i.e.
    ``[1 - p0, p0 * (1 - p1), ..., p0 * p1 * ... * p_last]`` equals
    ``[u, k * u, ..., k**num_steps * u]`` with ``u`` the normalizer.

    Args:
        k: ratio between the probabilities of ``n + 1`` and ``n`` successes;
            must satisfy ``0 < k <= 1``.
        num_steps: number of trials (time steps). Defaults to 3, which
            reproduces the original hard-coded behavior.

    Returns:
        A function mapping a trial index ``t`` (``0 <= t < num_steps``) to the
        success probability of that trial. Indexing past the last trial raises
        ``IndexError``.
    """
    assert 0 < k <= 1
    assert num_steps >= 1
    # Probability of zero successes: normalizer of 1 + k + ... + k**num_steps.
    u = 1 / sum(k**i for i in range(num_steps + 1))
    trial_probs = []
    # Probability of reaching trial t, i.e. the product of all earlier
    # success probabilities.
    reach = 1
    for t in range(num_steps):
        # Stopping after exactly t successes must have probability k**t * u.
        p = 1 - (k**t * u) / reach
        trial_probs.append(p)
        reach = reach * p
    return lambda t: trial_probs[t]


# Implements "prior annealing" as described in this blog post:
# http://akosiorek.github.io/ml/2017/09/03/implementing-air.html

# That implementation does something very close to the following:
# --z-pres-prior (1 - 1e-15)
# --z-pres-prior-raw
# --anneal-prior exp
# --anneal-prior-to 1e-7
# --anneal-prior-begin 1000
# --anneal-prior-duration 1e6

# e.g. After 200K steps z_pres_p will have decayed to ~0.04

# The decay functions below compute the value of a decaying quantity at time t.
# initial: initial value
# final: final value, reached after begin + duration steps
# begin: number of steps before decay begins
# duration: number of steps over which decay occurs
# t: current time step


def lin_decay(initial, final, begin, duration, t):
    """Linearly interpolate from ``initial`` to ``final``, clamped to that range.

    The value moves along the straight line that equals ``initial`` at step
    ``begin`` and ``final`` at step ``begin + duration``. The result is capped
    above by ``initial`` and then floored at ``final``.

    Args:
        initial: starting value.
        final: value reached after ``begin + duration`` steps.
        begin: number of steps before the decay begins.
        duration: number of steps over which the decay occurs; must be > 0.
        t: current time step.
    """
    assert duration > 0
    offset = (final - initial) * (t - begin) / duration
    value = offset + initial
    # Cap at the starting value (applies before `begin`).
    if initial < value:
        value = initial
    # Floor at the target value (applies after `begin + duration`).
    if final > value:
        value = final
    return value


def exp_decay(initial, final, begin, duration, t):
    """Exponentially decay from ``initial`` to ``final``, clamped to that range.

    The decay rate is chosen so that the unclamped curve equals ``initial`` at
    step ``begin`` and ``final`` at step ``begin + duration``. The result is
    capped above by ``initial`` and then floored at ``final``.

    Args:
        initial: starting value.
        final: value reached after ``begin + duration`` steps; must be > 0.
        begin: number of steps before the decay begins.
        duration: number of steps over which the decay occurs; must be > 0.
        t: current time step.
    """
    assert final > 0
    assert duration > 0
    # half_life = math.log(2) / math.log(initial / final) * duration
    decay_rate = math.log(initial / final) / duration
    # No decay happens before `begin`. Clamping the elapsed time at zero gives
    # the same clamped result as extrapolating backwards, but avoids
    # math.exp raising OverflowError when t is far before `begin` and the
    # decay rate is large.
    elapsed = max(t - begin, 0)
    x = initial * math.exp(-decay_rate * elapsed)
    return max(min(x, initial), final)


def load_data():
    """Load multi-MNIST and return ``(X, counts)`` as torch tensors.

    ``X`` holds the images converted to float32 and divided by 255.
    ``counts`` holds, for every example, the number of entries in its label
    list, stored as floats.
    """
    data_dir = get_data_directory(__file__)
    images, labels = multi_mnist.load(data_dir)
    scaled = images.astype(np.float32) / 255.0
    # Counts are kept as a FloatTensor so they can be compared directly with
    # values sampled from a Bernoulli.
    object_counts = torch.FloatTensor(list(map(len, labels)))
    return torch.from_numpy(scaled), object_counts


def main(**kwargs):
    """Train AIR on multi-MNIST with SVI, with optional viz/eval/checkpoints.

    ``kwargs`` are the parsed command-line options. They are wrapped in an
    ``argparse.Namespace``; optional entries (``save``, ``load`` and the keys
    listed in ``model_arg_keys``) may be absent and are detected with ``in``.

    Raises:
        RuntimeError: if ``save`` is given and that path already exists.
    """

    args = argparse.Namespace(**kwargs)

    # Refuse to start if the checkpoint path is already taken.
    if 'save' in args:
        if os.path.exists(args.save):
            raise RuntimeError('Output file "{}" already exists.'.format(args.save))

    if args.seed is not None:
        pyro.set_rng_seed(args.seed)

    X, true_counts = load_data()
    X_size = X.size(0)
    if args.cuda:
        X = X.cuda()

    # Build a function to compute z_pres prior probabilities.
    if args.z_pres_prior_raw:
        # Use the same success probability at every time step.
        def base_z_pres_prior_p(t):
            return args.z_pres_prior
    else:
        base_z_pres_prior_p = make_prior(args.z_pres_prior)

    # Wrap with logic to apply any annealing.
    # opt_step is the optimization step, time_step the model time step.
    def z_pres_prior_p(opt_step, time_step):
        p = base_z_pres_prior_p(time_step)
        if args.anneal_prior == 'none':
            return p
        else:
            decay = dict(lin=lin_decay, exp=exp_decay)[args.anneal_prior]
            return decay(p, args.anneal_prior_to, args.anneal_prior_begin,
                         args.anneal_prior_duration, opt_step)

    # Optional model options: only those present in args are forwarded to AIR.
    model_arg_keys = ['window_size',
                      'rnn_hidden_size',
                      'decoder_output_bias',
                      'decoder_output_use_sigmoid',
                      'baseline_scalar',
                      'encoder_net',
                      'decoder_net',
                      'predict_net',
                      'embed_net',
                      'bl_predict_net',
                      'non_linearity',
                      'pos_prior_mean',
                      'pos_prior_sd',
                      'scale_prior_mean',
                      'scale_prior_sd']
    model_args = {key: getattr(args, key) for key in model_arg_keys if key in args}
    air = AIR(
        num_steps=args.model_steps,
        x_size=50,
        use_masking=not args.no_masking,
        use_baselines=not args.no_baselines,
        z_what_size=args.encoder_latent_size,
        use_cuda=args.cuda,
        **model_args
    )

    if args.verbose:
        print(air)
        print(args)

    if 'load' in args:
        print('Loading parameters...')
        air.load_state_dict(torch.load(args.load))

    # Viz sample from prior.
    if args.viz:
        vis = visdom.Visdom(env=args.visdom_env)
        # Bind opt_step=0 so the prior is evaluated before any annealing step.
        z, x = air.prior(5, z_pres_prior_p=partial(z_pres_prior_p, 0))
        vis.images(draw_many(x, tensor_to_objs(latents_to_tensor(z))))

    # A parameter is treated as a baseline parameter when 'bl_' occurs in its
    # module name or parameter name.
    def isBaselineParam(module_name, param_name):
        return 'bl_' in module_name or 'bl_' in param_name

    # Baseline parameters get their own learning rate.
    def per_param_optim_args(module_name, param_name):
        lr = args.baseline_learning_rate if isBaselineParam(module_name, param_name) else args.learning_rate
        return {'lr': lr}

    adam = optim.Adam(per_param_optim_args)
    elbo = JitTraceGraph_ELBO() if args.jit else TraceGraph_ELBO()
    svi = SVI(air.model, air.guide, adam, loss=elbo)

    # Do inference.
    t0 = time.time()
    # Fixed examples reused for every visualization.
    examples_to_viz = X[5:10]

    for i in range(1, args.num_steps + 1):

        # Bind the current optimization step so the annealed prior only needs
        # the time step as its remaining argument.
        loss = svi.step(X, batch_size=args.batch_size, z_pres_prior_p=partial(z_pres_prior_p, i))

        if args.progress_every > 0 and i % args.progress_every == 0:
            # elapsed is reported in hours; the printed elbo is loss / X_size.
            print('i={}, epochs={:.2f}, elapsed={:.2f}, elbo={:.2f}'.format(
                i,
                (i * args.batch_size) / X_size,
                (time.time() - t0) / 3600,
                loss / X_size))

        if args.viz and i % args.viz_every == 0:
            # Trace the guide on the fixed examples, then replay the prior
            # against that trace to obtain latents and reconstructions.
            trace = poutine.trace(air.guide).get_trace(examples_to_viz, None)
            z, recons = poutine.replay(air.prior, trace=trace)(examples_to_viz.size(0))
            z_wheres = tensor_to_objs(latents_to_tensor(z))

            # Show data with inferred object positions.
            vis.images(draw_many(examples_to_viz, z_wheres))
            # Show reconstructions of data.
            vis.images(draw_many(recons, z_wheres))

        if args.eval_every > 0 and i % args.eval_every == 0:
            # Measure count accuracy over all of X, in batches of 1000.
            acc, counts, error_z, error_ix = count_accuracy(X, true_counts, air, 1000)
            print('i={}, accuracy={}, counts={}'.format(i, acc, counts.numpy().tolist()))
            if args.viz and error_ix.size(0) > 0:
                # Show up to five of the mis-counted examples.
                vis.images(draw_many(X[error_ix[0:5]], tensor_to_objs(error_z[0:5])),
                           opts=dict(caption='errors ({})'.format(i)))

        if 'save' in args and i % args.save_every == 0:
            # Each save overwrites the same file.
            print('Saving parameters...')
            torch.save(air.state_dict(), args.save)


if __name__ == '__main__':
    # This example is pinned to a specific Pyro release.
    assert pyro.__version__.startswith('0.5.0')
    # argument_default=SUPPRESS: options that have no explicit default and are
    # not given on the command line are left out of the namespace entirely.
    # main() relies on this for its `'save' in args` style checks.
    parser = argparse.ArgumentParser(description="Pyro AIR example", argument_default=argparse.SUPPRESS)
    parser.add_argument('-n', '--num-steps', type=int, default=int(1e8),
                        help='number of optimization steps to take')
    parser.add_argument('-b', '--batch-size', type=int, default=64,
                        help='batch size')
    parser.add_argument('-lr', '--learning-rate', type=float, default=1e-4,
                        help='learning rate')
    parser.add_argument('-blr', '--baseline-learning-rate', type=float, default=1e-3,
                        help='baseline learning rate')
    parser.add_argument('--progress-every', type=int, default=1,
                        help='number of steps between writing progress to stdout')
    parser.add_argument('--eval-every', type=int, default=0,
                        help='number of steps between evaluations')
    parser.add_argument('--baseline-scalar', type=float,
                        help='scale the output of the baseline nets by this value')
    parser.add_argument('--no-baselines', action='store_true', default=False,
                        help='do not use data dependent baselines')
    parser.add_argument('--encoder-net', type=int, nargs='+', default=[200],
                        help='encoder net hidden layer sizes')
    parser.add_argument('--decoder-net', type=int, nargs='+', default=[200],
                        help='decoder net hidden layer sizes')
    parser.add_argument('--predict-net', type=int, nargs='+',
                        help='predict net hidden layer sizes')
    parser.add_argument('--embed-net', type=int, nargs='+',
                        help='embed net architecture')
    parser.add_argument('--bl-predict-net', type=int, nargs='+',
                        help='baseline predict net hidden layer sizes')
    parser.add_argument('--non-linearity', type=str,
                        help='non linearity to use throughout')
    parser.add_argument('--viz', action='store_true', default=False,
                        help='generate vizualizations during optimization')
    parser.add_argument('--viz-every', type=int, default=100,
                        help='number of steps between vizualizations')
    parser.add_argument('--visdom-env', default='main',
                        help='visdom enviroment name')
    parser.add_argument('--load', type=str,
                        help='load previously saved parameters')
    parser.add_argument('--save', type=str,
                        help='save parameters to specified file')
    # argparse does not apply `type` to non-string defaults, so the default
    # is given as an int explicitly (a bare 1e4 would stay a float).
    parser.add_argument('--save-every', type=int, default=int(1e4),
                        help='number of steps between parameter saves')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use cuda')
    parser.add_argument('--jit', action='store_true', default=False,
                        help='use PyTorch jit')
    parser.add_argument('-t', '--model-steps', type=int, default=3,
                        help='number of time steps')
    parser.add_argument('--rnn-hidden-size', type=int, default=256,
                        help='rnn hidden size')
    parser.add_argument('--encoder-latent-size', type=int, default=50,
                        help='attention window encoder/decoder latent space size')
    parser.add_argument('--decoder-output-bias', type=float,
                        help='bias added to decoder output (prior to applying non-linearity)')
    parser.add_argument('--decoder-output-use-sigmoid', action='store_true',
                        help='apply sigmoid function to output of decoder network')
    parser.add_argument('--window-size', type=int, default=28,
                        help='attention window size')
    parser.add_argument('--z-pres-prior', type=float, default=0.5,
                        help='prior success probability for z_pres')
    parser.add_argument('--z-pres-prior-raw', action='store_true', default=False,
                        help='use --z-pres-prior directly as success prob instead of a geometric like prior')
    parser.add_argument('--anneal-prior', choices='none lin exp'.split(), default='none',
                        help='anneal z_pres prior during optimization')
    parser.add_argument('--anneal-prior-to', type=float, default=1e-7,
                        help='target z_pres prior prob')
    parser.add_argument('--anneal-prior-begin', type=int, default=0,
                        help='number of steps to wait before beginning to anneal the prior')
    parser.add_argument('--anneal-prior-duration', type=int, default=100000,
                        help='number of steps over which to anneal the prior')
    parser.add_argument('--pos-prior-mean', type=float,
                        help='mean of the window position prior')
    parser.add_argument('--pos-prior-sd', type=float,
                        help='std. dev. of the window position prior')
    parser.add_argument('--scale-prior-mean', type=float,
                        help='mean of the window scale prior')
    parser.add_argument('--scale-prior-sd', type=float,
                        help='std. dev. of the window scale prior')
    parser.add_argument('--no-masking', action='store_true', default=False,
                        help='do not mask out the costs of unused choices')
    parser.add_argument('--seed', type=int, help='random seed', default=None)
    parser.add_argument('-v', '--verbose', action='store_true', default=False,
                        help='write hyper parameters and network architecture to stdout')
    main(**vars(parser.parse_args()))
