""" File that holds adversarial perturbations as torch.nn.Modules.
    An adversarial perturbation is an example-specific

TODO: this needs to be fleshed out, as a general technique to make adversarial
      perturbations.
"""

import torch
import torch.nn as nn
import spatial_transformers as st
import utils.image_utils as img_utils
import utils.pytorch_utils as utils
from torch.autograd import Variable
import functools


# assert initialized decorator
def initialized(func):
    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        assert self.initialized, ("Parameters not initialized yet. "
                                  "Call .forward(...) first")
        return func(self, *args, **kwargs)
    return wrapper

##############################################################################
#                                                                            #
#                                   SKELETON CLASS                           #
#                                                                            #
##############################################################################

class AdversarialPerturbation(nn.Module):
    """ Skeleton class to hold adversarial perturbations FOR A SINGLE MINIBATCH.
        For general input-agnostic adversarial perturbations, see the
        ThreatModel class

        All subclasses need the following:
        - perturbation_norm() : no args -> scalar Variable
        - self.parameters() needs to iterate over params we want to optimize
        - constrain_params() : no args -> no return,
             modifies the parameters such that this is still a valid image
        - forward : no args -> Variable - applies the adversarial perturbation
                    the originals and outputs a Variable of how we got there
        - adversarial_tensors() : applies the adversarial transform to the
                                  originals and outputs TENSORS that are the
                                  adversarial images
    """

    def __init__(self, threat_model, perturbation_params):

        super(AdversarialPerturbation, self).__init__()
        self.threat_model = threat_model
        self.initialized = False
        self.perturbation_params = perturbation_params

        if isinstance(perturbation_params, tuple):
            self.use_gpu = perturbation_params[1].use_gpu or utils.use_gpu()
        else:
            self.use_gpu = perturbation_params.use_gpu or utils.use_gpu()
        # Stores parameters of the adversarial perturbation and hyperparams
        # to compute total perturbation norm here


    def __call__(self, x):
        return self.forward(x)

    def __repr__(self):
        if isinstance(self.perturbation_params, tuple):
            output_str = "[Perturbation] %s: %s" % (self.__class__.__name__,
                                                    self.perturbation_params[1])
            output_str += '\n['
            for el in self.perturbation_params[0]:
                output_str += '\n\t%s,' % el
            output_str += '\n]'
            return output_str
        else:
            return "[Perturbation] %s: %s"  % (self.__class__.__name__,
                                               self.perturbation_params)

    def _merge_setup(self, *args):
        """ Internal method to be used when initializing a new perturbation
            from merging only. Should not be called outside this file!!
        """
        pass

    def setup(self, x):
        """ This is the standard setup technique and should be used to
            initialize a perturbation (i.e. sets up parameters and unlocks
            all other methods)
        ARGS:
            x : Variable or Tensor (NxCxHxW) - the images this perturbation is
                intended for
        """
        self.num_examples = x.shape[0]


    @initialized
    def perturbation_norm(self, x=None):
        """ This returns the 'norm' of this perturbation. Optionally, for
            certain norms, having access to the images for which the
            perturbation is intended can have an effect on the output.
        ARGS:
            x : Variable or Tensor (NxCxHxW) - optionally can be the images
                that the perturbation was intended for
        RETURNS:
            Scalar Variable
        """
        raise NotImplementedError("Need to call subclass method here")

    @initialized
    def constrain_params(self):
        """ This modifies the parameters such that the perturbation falls within
            the threat model it belongs to. E.g. for l-infinity threat models,
            this clips the params to match the right l-infinity bound.

            TODO: for non-lp norms, projecting to the nearest point in the level
                  set
        """
        raise NotImplementedError("Need to call subclass method here")

    @initialized
    def make_valid_image(self, x):
        """ This takes in the minibatch self's parameters were tuned for and
            clips the parameters such that this is still a valid image.
        ARGS:
            x : Variable or Tensor (NxCxHxW) - the images this this perturbation
                was intended for
        RETURNS:
            None
        """
        pass # Only implement in classes that can create invalid images

    @initialized
    def forward(self, x):
        """ This takes in the minibatch self's parameters were tuned for and
            outputs a variable of the perturbation applied to the images
        ARGS:
            x : Variable (NxCxHxW) - the images this this perturbation
                was intended for
        RETURNS:
            Variable (NxCxHxW) - the perturbation applied to the input images
        """
        raise NotImplementedError("Need to call subclass method here")

    @initialized
    def add_to_params(self, grad_data):
        """ This takes in a Tensor the same shape as self's parameters and
            adds to them. Note that this usually won't preserve gradient
            information
            (also this might have different signatures in subclasses)
        ARGS:
            x : Tensor (params-shape) - Tensor to be added to the
                parameters of self
        RETURNS:
            None, but modifies self's parameters
        """
        raise NotImplementedError("Need to call subclass method here")

    @initialized
    def update_params(self, step_fxn):
        """ This takes in a function step_fxn: Tensor -> Tensor that generates
            the change to the parameters that we step along. This loops through
            all parameters and updates signs accordingly.
            For sequential perturbations, this also multiplies by a scalar if
            provided
        ARGS:
            step_fxn : Tensor -> Tensor - function that maps tensors to tensors.
                       e.g. for FGSM, we want a function that multiplies signs
                       by step_size
        RETURNS:
            None, but updates the parameters
        """
        raise NotImplementedError("Need to call subclass method here")


    @initialized
    def adversarial_tensors(self, x=None):
        """ Little helper method to get the tensors of the adversarial images
            directly
        """
        assert x is not None or self.originals is not None
        if x is None:
            x = self.originals

        return self.forward(x).data

    @initialized
    def attach_attr(self, attr_name, attr):
        """ Special method to set an attribute if it doesn't exist in this
            object yet. throws error if this attr already exists
        ARGS:
            attr_name : string - name of attribute we're attaching
            attr: object - attribute we're attaching
        RETURNS:
            None
        """
        if hasattr(self, attr_name):
            raise Exception("%s already has attribute %s" % (self, attr_name))
        else:
            setattr(self, attr_name, attr)


    @initialized
    def attach_originals(self, originals):
        """ Little helper method to tack on the original images to self to
            pass around the (images, perturbation) in a single object
        """
        self.attach_attr('originals', originals)


    @initialized
    def random_init(self):
        """ Modifies the parameters such that they're randomly initialized
            uniformly across the threat model (this is harder for nonLp threat
            models...). Takes no args and returns nothing, but modifies the
            parameters
        """
        raise NotImplementedError("Need to call subclass method here")


    @initialized
    def clone_perturbation():
        """ Special technique to create a 'clone' of this perturbaton where
            we return a new perturbation object with a deep copy of
            parameters. Uses 'merge_perturbation' as a subroutine
        """
        pass


    @initialized
    def merge_perturbation(self, other, self_mask):
        """ Special technique to merge this perturbation with another
            perturbation of the same threat model.
            This will return a new perturbation object that, for each parameter
            will return the parameters of self for self_mask, and the
            perturbation of other for NOT(self_mask)

        ARGS:
            other: AdversarialPerturbation Object - instance of other
                   adversarial perturbation that is instantiated with the
                   same threat model as self
            self_indices: ByteTensor [N] : bytetensor indicating which
                          parameters to include from self and which to include
                          from other
        """

        # this parent class just does the shared asserts such that this is a
        # valid thing
        assert self.__class__ == other.__class__
        assert self.threat_model == other.threat_model
        assert self.num_examples == other.num_examples
        assert self.perturbation_params == other.perturbation_params
        assert other.initialized


    @initialized
    def scatter_perturbation(self, scatter_size, mask):
        """ Takes a perturbation over examples of shape NxCxHxW and scatters
            them into a perturbation for examples of shape N'xCxHxW where N' > N
            The 'originals' for unscattered indices in the returned perturbation
            are zeros for unscattered, and the perturbation params are the
            identity params
        ARGS:
            scatter_size : int - the number of examples the scattering maps into
            mask: int[] - list of indices that each index of self maps into.
                          Should be sorted and unique
        RETURNS: Subclasses will return a perturbation of the same type, but
                 over a larger set
                 This class just returns the scattered
        """
        try:
            assert scatter_size >= self.num_examples
        except Exception as err:
            print("ERROR MESSAGE", scatter_size, self.num_examples)
            raise err
        assert len(mask) == self.num_examples
        assert sorted(set(mask)) == mask # sorted + unique

        # Also generate the scat
        scattered_originals = utils.scatter_expand(self.originals,
                                                   scatter_size, mask)
        return scattered_originals


    @initialized
    def collect_successful(self, classifier_net, normalizer,
                           success_def='alter_top_logit', topk=None,
                           labels=None):
        """ Returns an object with the original examples, adversarial examples,
            and indices of successful attacks, where the definition if success
            is provided as an argument. Success can be either to 'misclassify'
            or 'alter_top_logit':
            'misclassify' := adversarial examples have a different topk logit
                             than the provided labels
            'alter_top_logit' := adversarial examples have a different top-1
                                 logit than the original examples
        ARGS:
            classifier_net : nn.Module subclass - neural net that is the
                             relevant classifier
            normalizer : DifferentiableNormalize object - object to convert
                         input data to mean-zero, unit-var examples
            success_def : string - must be 'misclassify', 'alter_top_logit'



        RETURNS:
            {'adversarials': N'xCxHxW tensor of SUCCESSFUL AE's,
             'originals': N'xCxHxW tensor of CORRESPONDING ORIGINAL examples
             'success_idxs': indices (out of original N examples) of SUCCESSFUL
                             AE's
            }
        """

        # Safety check
        assert self.originals is not None
        assert success_def in ['misclassify', 'alter_top_logit']
        if success_def == 'misclassify':
            if topk is None:
                topk = (1,)
            assert labels is not None

        # Gather tensors
        adversarials = Variable(self.adversarial_tensors())
        originals = Variable(self.originals)

        # Run adversarials through the net
        adv_out = classifier_net(normalizer(adversarials))
        if success_def == 'alter_top_logit':
            labels = torch.max(classifier_net(normalizer(originals)), 1)[1]

        correctly_class_idxs = set(utils.accuracy(adv_out, labels,
                                                   return_correct_idxs=True)[1])
        success_idxs = torch.LongTensor([_ for _ in range(self.num_examples)
                                         if _ not in correctly_class_idxs])

        if self.originals.is_cuda:
            success_idxs = success_idxs.cuda()

        return {'adversarials': self.adversarial_tensors().index_select(0,
                                                              success_idxs),
                'originals': self.originals.index_select(0, success_idxs),
                'success_idxs': success_idxs}


    @initialized
    def display(self, scale=5, successful_only=False, classifier_net=None,
                normalizer=None, num_to_display=None):
        """ Displays this adversarial perturbation in a 3-row format:
            top row is adversarial images, second row is original images,
            bottom row is difference magnified by scale (default 5)
        ARGS:
            scale: int - how much to magnify differences by
            successful_only: bool - if True we only display successful (in that
                             advs output different classifier outputs)
                             If this is not None, classifie_net and normalizer
                             cannot be None
        RETURNS:
            None, but displays images
        """
        if successful_only:
            assert classifier_net is not None
            assert normalizer is not None

            successful = self.collect_successful(classifier_net, normalizer)
            advs = successful['adversarials']
            origs = successful['originals']
        else:
            advs = self.adversarial_tensors()
            origs = self.originals

        if num_to_display is not None:
            advs = advs[:num_to_display]
            origs = origs[:num_to_display]

        diffs = torch.clamp((advs - origs) * scale + 0.5, 0, 1)
        img_utils.show_images([advs, origs, diffs])


class PerturbationParameters(dict):
    """ Object that stores parameters like a dictionary.
        This allows perturbation classes to be only partially instantiated and
        then fed various 'originals' later.
    Implementation taken from : https://stackoverflow.com/a/14620633/3837607
    (and then modified with the getattribute trick to return none instead of
     error for missing attributes)
    """
    def __init__(self, *args, **kwargs):
        super(PerturbationParameters, self).__init__(*args, **kwargs)
        if kwargs.get('manual_gpu') is not None:
            self.use_gpu = kwargs['manual_gpu']
        else:
            self.use_gpu = utils.use_gpu()
        self.__dict__ = self

    def __getattribute__(self, name):
        try:
            return object.__getattribute__(self, name)
        except AttributeError:
            return None


class ThreatModel(object):
    def __init__(self, perturbation_class, param_kwargs, *other_args):
        """ Factory class to generate per_minibatch instances of Adversarial
            perturbations.
        ARGS:
            perturbation_class : class - subclass of Adversarial Perturbations
            param_kwargs : dict - dict containing named kwargs to instantiate
                           the class in perturbation class
        """
        assert issubclass(perturbation_class, AdversarialPerturbation)
        self.perturbation_class = perturbation_class
        if isinstance(param_kwargs, dict):
            param_kwargs = PerturbationParameters(**param_kwargs)
        self.param_kwargs = param_kwargs
        self.other_args = other_args

    def __repr__(self):
        return "[Threat] %s: %s"  % (str(self.perturbation_class.__name__),
                                     self.param_kwargs)

    def __call__(self, *args):
        if args == ():
            return self.perturbation_obj()
        else:
            perturbation_obj = self.perturbation_obj()
            perturbation_obj.setup(*args)
            return perturbation_obj



    def perturbation_obj(self):
        return self.perturbation_class(self, self.param_kwargs, *self.other_args)



##############################################################################
#                                                                            #
#                            ADDITION PARAMETERS                             #
#                                                                            #
##############################################################################

class DeltaAddition(AdversarialPerturbation):

    def __init__(self, threat_model, perturbation_params, *other_args):
        """ Maintains a delta that gets addded to the originals to generate
            adversarial images. This is the type of adversarial perturbation
            that the literature extensivey studies
        ARGS:
            threat_model : ThreatModel object that is used to initialize self
            perturbation_params: PerturbationParameters object.
                { lp_style : None, int or 'inf' - if not None is the type of
                            Lp_bound that we apply to this adversarial example
                lp_bound : None or float - cannot be None if lp_style is
                           not None, but if not None should be the lp bound
                           we allow for adversarial perturbations
                custom_norm : None or fxn:(NxCxHxW) -> Scalar Variable. This is
                              not implemented for now
                }
        """

        super(DeltaAddition, self).__init__(threat_model, perturbation_params)
        self.lp_style = perturbation_params.lp_style
        self.lp_bound = perturbation_params.lp_bound
        if perturbation_params.custom_norm is not None:
            raise NotImplementedError("Only LP norms allowed for now")
        self.scalar_step = perturbation_params.scalar_step or 1.0


    def _merge_setup(self, num_examples, delta_data):
        """ DANGEROUS TO BE CALLED OUTSIDE OF THIS FILE!!!"""
        self.num_examples = num_examples
        self.delta = nn.Parameter(delta_data)
        self.initialized = True


    def setup(self, x):
        super(DeltaAddition, self).setup(x)
        self.delta = nn.Parameter(torch.zeros_like(x))
        self.initialized = True

    @initialized
    def perturbation_norm(self, x=None, lp_style=None):
        lp_style = lp_style or self.lp_style
        assert isinstance(lp_style, int) or lp_style == 'inf'
        return utils.batchwise_norm(self.delta, lp=lp_style)


    @initialized
    def constrain_params(self):
        new_delta = utils.batchwise_lp_project(self.delta.data, self.lp_style,
                                               self.lp_bound)
        delta_diff = new_delta - self.delta.data
        self.delta.data.add_(delta_diff)

    @initialized
    def make_valid_image(self, x):
        new_delta = self.delta.data
        change_in_delta = utils.clamp_0_1_delta(new_delta, x)
        self.delta.data.add_(change_in_delta)

    @initialized
    def update_params(self, step_fxn):
        assert self.delta.grad.data is not None
        self.add_to_params(step_fxn(self.delta.grad.data) * self.scalar_step)

    @initialized
    def add_to_params(self, grad_data):
        """ sets params to be self.params + grad_data
        """
        # grad data can either be a tensor of same shape/device as delta
        # OR a singleton list containing a tensor of same shape/device as delta

        if isinstance(grad_data, list):
            assert len(grad_data) == 1
            grad_data = grad_data[0]
        assert grad_data.shape == self.delta.data.shape
        self.delta.data.add_(grad_data)


    @initialized
    def random_init(self, param=True):
        if param is True:
            bound = self.lp_bound
        else:
            bound = param
        self.delta = nn.Parameter(utils.random_from_lp_ball(self.delta.data,
                                                            self.lp_style,
                                                            bound))

    @initialized
    def clone_perturbation(self):
        new_perturbation = DeltaAddition(self.threat_model,
                                         self.perturbation_params)
        new_perturbation._merge_setup(self.num_examples, self.delta.data)
        self_mask = self.delta.new_ones((self.num_examples,), dtype=torch.uint8)
        return self.merge_perturbation(new_perturbation, self_mask)

    @initialized
    def merge_perturbation(self, other, self_mask):
        super(DeltaAddition, self).merge_perturbation(other, self_mask)

        # initialize a new perturbation
        new_perturbation = DeltaAddition(self.threat_model,
                                         self.perturbation_params)

        # make the new parameters
        new_delta = utils.fold_mask(self.delta.data, other.delta.data,
                                    self_mask)

        # do the merge setup and return the object
        new_perturbation._merge_setup(self.num_examples,
                                      new_delta)
        return new_perturbation

    @initialized
    def scatter_perturbation(self, scatter_size, mask):
        scattered_originals = super(DeltaAddition, self)\
                                .scatter_perturbation(scatter_size, mask)
        new_perturbation = DeltaAddition(self.threat_model,
                                         self.perturbation_params)
        new_perturbation.setup(scattered_originals)
        scattered_params = utils.scatter_expand(self.delta,
                                                scatter_size, mask)
        new_perturbation.add_to_params(scattered_params)


        return new_perturbation

    def forward(self, x):
        if not self.initialized:
            self.setup(x)
        self.make_valid_image(x) # not sure which one to do first...
        self.constrain_params()
        return x + self.delta


    @initialized
    def binsearch_closer(self, model, normalizer, labels, binsearch_steps=10):
        """ Takes in a model, normalizer and perturbation and labels and binary
            searches along the line between the originals and adverarsarial
            examples to find adversarial examples that are closer than the provided ones.
        ARGS:
            model: nn.Module subclass - neural net that we have adversarial examples
                                        for
            normalizer: DifferentiableNormalizer object - object to whiten inputs

            perturbation: AdversarialPerturbation object - must be a DeltaAddition
                          and have attached originals
            labels : Tensor - labels for the tensor we are searching for
            num_binsearch_steps: int - number of binary search steps we perform
        RETURNS:
            new Perturbation object with closer adversarial tensors
        """

        # Setup stuff
        assert self.originals is not None
        successful = self.collect_successful(model, normalizer,
                                             success_def='misclassify',
                                             labels=labels)
        num_examples = self.originals.shape[0]
        channels = self.originals.dim() - 1
        success_mask = torch.zeros(num_examples, dtype=torch.uint8,
                                                 device=self.originals.device)
        if successful['success_idxs'].numel() == 0:
            return self
        success_mask.index_fill_(0, successful['success_idxs'], 1).float()


        scale_hi = torch.ones(num_examples, dtype=self.originals.dtype,
                                            device=self.originals.device)
        scale_lo = torch.zeros_like(scale_hi)


        # Start binsearch
        for _ in range(binsearch_steps):
            # -- first check if the midpoint is misclassified
            scale = (scale_hi + scale_lo) / 2.0
            scale = scale.view((-1,) + (1,) * channels)
            to_checks = self.originals + scale * self.delta
            scale = scale.squeeze()
            check_labels = torch.max(model(normalizer(to_checks)), 1)[1]
            iter_misclass = (check_labels != labels).float()

            # -- adjust upper and lower bounds accordingly
            scale_hi = iter_misclass * scale +(1 - iter_misclass) * scale_hi
            scale_lo = iter_misclass * scale_lo + (1 - iter_misclass) * scale


        # Now make new perturbation with the binary searched things
        scale_hi = scale_hi.view((-1,) + (1,) * channels)
        success_mask = success_mask.view((-1,) + (1,) * channels).float()

        new_delta = success_mask * self.delta.data * scale_hi +\
                    (1 - success_mask) * self.delta.data

        new_perturbation = DeltaAddition(self.threat_model,
                                         self.perturbation_params)
        new_perturbation._merge_setup(num_examples, new_delta)
        new_perturbation.attach_originals(self.originals)

        return new_perturbation




##############################################################################
#                                                                            #
#                               SPATIAL PARAMETERS                           #
#                                                                            #
##############################################################################

class ParameterizedXformAdv(AdversarialPerturbation):

    def __init__(self, threat_model, perturbation_params, *other_args):
        super(ParameterizedXformAdv, self).__init__(threat_model,
                                                    perturbation_params)
        assert issubclass(perturbation_params.xform_class,
                          st.ParameterizedTransformation)

        self.lp_style = perturbation_params.lp_style
        self.lp_bound = perturbation_params.lp_bound
        self.use_stadv = perturbation_params.use_stadv
        self.scalar_step = perturbation_params.scalar_step or 1.0


    def _merge_setup(self, num_examples, new_xform):
        """ DANGEROUS TO BE CALLED OUTSIDE OF THIS FILE!!!"""
        self.num_examples = num_examples
        self.xform = new_xform
        self.initialized = True

    def setup(self, originals):
        super(ParameterizedXformAdv, self).setup(originals)
        self.xform = self.perturbation_params.xform_class(shape=originals.shape,
                                                        manual_gpu=self.use_gpu)
        self.initialized = True

    @initialized
    def perturbation_norm(self, x=None, lp_style=None):
        lp_style = lp_style or self.lp_style
        if self.use_stadv is not None:
            assert isinstance(self.xform, st.FullSpatial)
            return self.xform.stAdv_norm()
        else:
            return self.xform.norm(lp=lp_style)

    @initialized
    def constrain_params(self, x=None):
        # Do lp projections
        if isinstance(self.lp_style, int) or self.lp_style == 'inf':
            self.xform.project_params(self.lp_style, self.lp_bound)



    @initialized
    def update_params(self, step_fxn):
        param_list = list(self.xform.parameters())
        assert len(param_list) == 1
        params = param_list[0]
        assert params.grad.data is not None
        self.add_to_params(step_fxn(params.grad.data) * self.scalar_step)


    @initialized
    def add_to_params(self, grad_data):
        """ Assumes only one parameters object in the Spatial Transform """
        param_list = list(self.xform.parameters())
        assert len(param_list) == 1
        params = param_list[0]

        # grad data can either be a tensor of same shape/device as delta
        # OR a singleton list containing a tensor of same shape/device as delta
        param_list = list(self.xform.parameters())
        assert len(param_list) == 1
        params = param_list[0]
        if isinstance(grad_data, list):
            assert len(grad_data) == 1
            grad_data = grad_data[0]
        assert grad_data.shape == params.data.shape

        params.data.add_(grad_data)

    @initialized
    def random_init(self):
        param_list = list(self.xform.parameters())
        assert len(param_list) == 1
        param = param_list[0]
        random_perturb = utils.random_from_lp_ball(param.data,
                                                   self.lp_style,
                                                   self.lp_bound)

        param.data.add_(self.xform.identity_params(self.xform.img_shape) +
                        random_perturb - self.xform.xform_params.data)


    @initialized
    def clone_perturbation(self):
        new_perturbation = ParameterizedXformAdv(self.threat_model,
                                                 self.perturbation_params)
        new_perturbation._merge_setup(self.num_examples, self.xform)
        self_mask = list(self.xform.parameters())[0].new_ones((self.num_examples,),
                                                      dtype=torch.uint8)
        return self.merge_perturbation(new_perturbation, self_mask)


    @initialized
    def merge_perturbation(self, other, self_mask):
        super(ParameterizedXformAdv, self).merge_perturbation(other, self_mask)
        new_perturbation = ParameterizedXformAdv(self.threat_model,
                                                 self.perturbation_params)

        new_xform = self.xform.merge_xform(other.xform, self_mask)
        new_perturbation._merge_setup(self.num_examples, new_xform)

        return new_perturbation


    @initialized
    def scatter_perturbation(self, scatter_size, mask):
        scattered_originals = super(ParameterizedXformAdv, self)\
                                .scatter_perturbation(scatter_size, mask)
        example_shape = scattered_originals.shape[1:]
        new_perturbation = DeltaAddition(self.threat_model,
                                         self.perturbation_params)
        new_perturbation.setup(scattered_originals)

        # Now have to merge the parameters

        identity_el = new_perturbation.xform.identity_params(1,
                                                           *example_shape)[0]

        scattered_params = utils.scatter_expand(self.xform.xform_params,
                                                scatter_size, mask,
                                                identity_el=identity_el)
        new_perturbation.xform.xform_params = nn.Parameter(scattered_params)
        return new_perturbation


    def forward(self, x):
        if not self.initialized:
            self.setup(x)
        self.constrain_params()
        return self.xform.forward(x)




##############################################################################
#                                                                            #
#                            SPATIAL + ADDITION PARAMETERS                   #
#                                                                            #
##############################################################################

class SequentialPerturbation(AdversarialPerturbation):
    """ Takes a list of perturbations and composes them. A norm needs to
        be specified here to describe the perturbations.
    """

    def __init__(self, threat_model, perturbation_sequence,
                 global_parameters=PerturbationParameters(pad=10),
                 preinit_pipeline=None):
        """ Initializes a sequence of adversarial perturbation layers
        ARGS:
            originals : NxCxHxW tensor - original images we create adversarial
                        perturbations for
            perturbation_sequence : ThreatModel[]  -
                list of ThreatModel objects
            global_parameters : PerturbationParameters - global parameters to
                                use. These contain things like how to norm this
                                sequence, how to constrain this sequence, etc
            preinit_pipelines: list[]
                if not None i
         """
        super(SequentialPerturbation, self).__init__(threat_model,
                                                    (perturbation_sequence,
                                                     global_parameters))

        if preinit_pipeline is not None:
            layers = preinit_pipeline
        else:
            layers = []
            for threat_model in perturbation_sequence:
                assert isinstance(threat_model, ThreatModel)
                layers.append(threat_model())

        self.pipeline = []
        for layer_no, layer in enumerate(layers):
            self.pipeline.append(layer)
            self.add_module('layer_%02d' % layer_no, layer)


        # norm: pipeline -> Scalar Variable
        self.norm = global_parameters.norm
        self.norm_weights = global_parameters.norm_weights

        # padding with black is useful to not throw information away during
        # sequential steps
        self.pad = nn.ConstantPad2d(global_parameters.pad or 0, 0)
        self.unpad = nn.ConstantPad2d(-1 * (global_parameters.pad or 0), 0)




    def _merge_setup(self, num_examples):
        self.num_examples = num_examples
        self.initialized = True


    def setup(self, x):
        super(SequentialPerturbation, self).setup(x)
        x = self.pad(x)
        for layer in self.pipeline:
            layer.setup(x)
        self.initialized = True


    @initialized
    def perturbation_norm(self, x=None, lp_style=None):
        # Need to define a nice way to describe the norm here. This can be
        # an empirical norm between input/output
        # For now, let's just say it's the sum of the norms of each constituent
        if self.norm is not None:
            return self.norm(self.pipeline, x=x, lp_style=lp_style)
        else:
            norm_weights = self.norm_weights or\
                              [1.0 for _ in range(len(self.pipeline))]
            out = None
            for i, layer in enumerate(self.pipeline):
                weight = norm_weights[i]
                layer_norm = layer.perturbation_norm(x=x, lp_style=lp_style)
                if out is None:
                    out = layer_norm * weight
                else:
                    out = out + layer_norm * weight
            return out

    @initialized
    def make_valid_image(self, x):
        x = self.pad(x)
        for layer in self.pipeline:
            layer.make_valid_image(x)
            x = layer(x)


    @initialized
    def constrain_params(self):
        # Need to do some sort of crazy projection operator for general things
        # For now, let's just constrain each thing in sequence

        for layer in self.pipeline:
            layer.constrain_params()

    @initialized
    def update_params(self, step_fxn):
        for layer in self.pipeline:
            layer.update_params(step_fxn)

    @initialized
    def add_to_params(self, update_list):
        """ A little messy here: update_list needs to be a list of length
            len(pipeline) such that each element in this list can be passed
            to its corresponding pipeline elements 'add_to_params' method
        """
        assert len(update_list) == len(self.pipeline)
        for i in xrange(len(update_list)):
            self.pipeline[i].add_to_params(update_list[i])


    @initialized
    def clone_perturbation(self):
        new_pipeline = []
        for layer in self.pipeline:
            new_pipeline.append(layer.clone_perturbation())

        layer_params, global_params = self.perturbation_params
        new_perturbation = SequentialPerturbation(self.threat_model,
                                                layer_params,
                                                global_parameters=global_params,
                                                preinit_pipeline=new_pipeline)
        new_perturbation._merge_setup(self.num_examples)
        return new_perturbation


    @initialized
    def merge_perturbation(self, other, self_mask):
        super(SequentialPerturbation, self).merge_perturbation(other, self_mask)


        new_pipeline = []
        for self_layer, other_layer in zip(self.pipeline, other.pipeline):
            new_pipeline.append(self_layer.merge_perturbation(other_layer,
                                                              self_mask))

        layer_params, global_params = self.perturbation_params

        new_perturbation = SequentialPerturbation(self.threat_model,
                                                layer_params,
                                                global_parameters=global_params,
                                                preinit_pipeline=new_pipeline)
        new_perturbation._merge_setup(self.num_examples)

        return new_perturbation


    @initialized
    def scatter_perturbation(self, scatter_size, mask):
        scattered_originals = super(SequentialPerturbation, self).\
                                scatter_perturbation(scatter_size, mask)

        # First scatter on each component perturbation
        new_pipeline = []
        for layer in self.pipeline:
            new_pipeline.append(layer.scatter_perturbation(scatter_size, mask))


        # Then build the new perturbation object
        layer_params, global_params = self.perturbation_params
        new_perturbation = SequentialPerturbation(self.threat_model,
                                                layer_params,
                                                global_parameters=global_params,
                                                preinit_pipeline=new_pipeline)
        new_perturbation._merge_setup(scatter_size)

        return new_perturbation


    def forward(self, x, layer_slice=None):
        """ Layer slice here is either an int or a tuple
        If int, we run forward only the first layer_slice layers
        If tuple, we start at the

        """

        # Blocks to handle only particular layer slices (debugging)
        if layer_slice is None:
            pipeline_iter = iter(self.pipeline)
        elif isinstance(layer_slice, int):
            pipeline_iter = iter(self.pipeline[:layer_slice])
        elif isinstance(layer_slice, tuple):
            pipeline_iter = iter(self.pipeline[layer_slice[0]: layer_slice[1]])
        # End block to handle particular layer slices

        # Handle padding
        original_hw = x.shape[-2:]
        if not self.initialized:
            self.setup(x)

        self.constrain_params()
        self.make_valid_image(x)

        x = self.pad(x)
        for layer in pipeline_iter:
            x = layer(x)
        return self.unpad(x)


    @initialized
    def random_init(self):
        for layer in self.pipeline:
            layer.random_init()

    @initialized
    def attach_originals(self, originals):
        self.originals = originals
        for layer in self.pipeline:
            layer.attach_originals(originals)





