# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for t5x.adafactor."""

import functools
import operator
from typing import Sequence

from absl.testing import absltest
from absl.testing import parameterized

import flax
from flax import optim  # used for equivalence testing only
from flax import traverse_util
import jax
from jax import numpy as jnp
from jax import random
import numpy as np

from t5x import adafactor
from t5x import optimizers

OptimizerState = optimizers.OptimizerState

_AdafactorHyperParams = adafactor._AdafactorHyperParams
_AdafactorParamState = adafactor._AdafactorParamState

_BATCH = adafactor.FactorDim.BATCH
_ROW = adafactor.FactorDim.ROW
_COL = adafactor.FactorDim.COLUMN

# Testing helpers


def _assert_numpy_allclose(a, b, atol=None, rtol=None):
  a, b = jnp.array(a), jnp.array(b)
  a = a.astype(np.float32) if a.dtype == jnp.bfloat16 else a
  b = b.astype(np.float32) if b.dtype == jnp.bfloat16 else b
  kw = {}
  if atol:
    kw['atol'] = atol
  if rtol:
    kw['rtol'] = rtol
  np.testing.assert_allclose(a, b, **kw)


def check_eq(xs, ys, atol=None, rtol=None):
  xs_leaves, xs_tree = jax.tree_flatten(xs)
  ys_leaves, ys_tree = jax.tree_flatten(ys)
  assert xs_tree == ys_tree, f"Tree shapes don't match. \n{xs_tree}\n{ys_tree}"
  assert jax.tree_util.tree_all(
      jax.tree_multimap(lambda x, y: np.array(x).shape == np.array(y).shape,
                        xs_leaves, ys_leaves)), "Leaves' shapes don't match."
  assert jax.tree_multimap(
      functools.partial(_assert_numpy_allclose, atol=atol, rtol=rtol),
      xs_leaves, ys_leaves)


def flattened_state_dict(x):
  s = flax.serialization.to_state_dict(x)
  return flax.traverse_util.flatten_dict(s, sep='/')


def tree_shape(x):
  return jax.tree_map(jnp.shape, x)


def tree_equals(x, y):
  return jax.tree_util.tree_all(jax.tree_multimap(operator.eq, x, y))


def _get_multi_adafactor(
    learning_rate: float, step_offset: int,
    adafactor_exclude_from_parameter_scale: Sequence[str]
) -> optim.MultiOptimizer:
  """Get adafactor with support for excluding some parameters from scaling."""

  def _should_not_scale(path):
    return any([s in path for s in adafactor_exclude_from_parameter_scale])

  scaled_vars = traverse_util.ModelParamTraversal(
      lambda path, _: not _should_not_scale(path))
  unscaled_vars = traverse_util.ModelParamTraversal(
      lambda path, _: _should_not_scale(path))
  scaled_opt = optim.Adafactor(
      learning_rate, decay_rate=0.8, step_offset=step_offset)
  unscaled_opt = optim.Adafactor(
      learning_rate,
      decay_rate=0.8,
      step_offset=step_offset,
      multiply_by_parameter_scale=False)
  return optim.MultiOptimizer((scaled_vars, scaled_opt),
                              (unscaled_vars, unscaled_opt))


# Inline test data

MODEL_SHAPE = {
    'decoder': {
        'decoder_norm': {'scale': [128]},
        'layers_0': {
            'encoder_decoder_attention': {
                'key': {'kernel': [128, 256]},
                'out': {'kernel': [256, 128]},
                'query': {'kernel': [128, 256]},
                'value': {'kernel': [128, 256]}},
            'mlp': {
                'wi': {'kernel': [128, 512]},
                'wo': {'kernel': [512, 128]}},
            'pre_cross_attention_layer_norm': {'scale': [128]},
            'pre_mlp_layer_norm': {'scale': [128]},
            'pre_self_attention_layer_norm': {'scale': [128]},
            'self_attention': {
                'key': {'kernel': [128, 256]},
                'out': {'kernel': [256, 128]},
                'query': {'kernel': [128, 256]},
                'value': {'kernel': [128, 256]}}},
        'layers_1': {
            'encoder_decoder_attention': {
                'key': {'kernel': [128, 128]},
                'out': {'kernel': [128, 128]},
                'query': {'kernel': [128, 128]},
                'value': {'kernel': [128, 128]}},
            'mlp': {
                'wi': {'kernel': [128, 512]},
                'wo': {'kernel': [512, 128]}},
            'pre_cross_attention_layer_norm': {'scale': [128]},
            'pre_mlp_layer_norm': {'scale': [128]},
            'pre_self_attention_layer_norm': {'scale': [128]},
            'self_attention': {
                'key': {'kernel': [128, 256]},
                'out': {'kernel': [256, 128]},
                'query': {'kernel': [128, 256]},
                'value': {'kernel': [128, 256]}}},
        'relpos_bias': {'rel_embedding': [2, 32]}},
    'encoder': {
        'encoder_norm': {'scale': [128]},
        'layers_0': {
            'attention': {
                'key': {'kernel': [128, 256]},
                'out': {'kernel': [256, 128]},
                'query': {'kernel': [128, 256]},
                'value': {'kernel': [128, 256]}},
            'mlp': {
                'wi': {'kernel': [128, 512]},
                'wo': {'kernel': [512, 128]}},
            'pre_attention_layer_norm': {'scale': [128]},
            'pre_mlp_layer_norm': {'scale': [128]}},
        'layers_1': {
            'attention': {
                'key': {'kernel': [128, 256]},
                'out': {'kernel': [256, 128]},
                'query': {'kernel': [128, 256]},
                'value': {'kernel': [128, 256]}},
            'mlp': {
                'wi': {'kernel': [128, 512]},
                'wo': {'kernel': [512, 128]}},
            'pre_attention_layer_norm': {'scale': [128]},
            'pre_mlp_layer_norm': {'scale': [128]}},
        'relpos_bias': {'rel_embedding': [2, 32]}},
    'token_embedder': {'embedding': [32128, 128]}}  # pyformat: disable


class AdafactorTest(parameterized.TestCase):

  # Classic Adafactor Behavior Tests

  def test_2D_simple(self):
    x = {'a': jnp.ones((24, 16))}
    opt_def = adafactor.Adafactor(min_dim_size_to_factor=8)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24,), 'a/v_row': (16,)}
    self.assertTrue(tree_equals(shapes, ref))

  def test_2D_simple_nofactor(self):
    x = {'a': jnp.ones((24, 16))}
    opt_def = adafactor.Adafactor(min_dim_size_to_factor=32)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (24, 16), 'a/v_col': (1,), 'a/v_row': (1,)}
    self.assertTrue(tree_equals(shapes, ref))

  def test_2D_simple_nofactor_momentum(self):
    x = {'a': jnp.ones((24, 16))}
    opt_def = adafactor.Adafactor(min_dim_size_to_factor=32, beta1=0.1)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (24, 16), 'a/v': (24, 16), 'a/v_col': (1,), 'a/v_row': (1,)}
    self.assertTrue(tree_equals(shapes, ref))

  def test_3D_simple(self):
    x = {'a': jnp.ones((24, 4, 16))}
    factor_map = adafactor.HParamMap((('a', (_COL, _BATCH, _ROW)),))
    opt_def = adafactor.Adafactor(
        min_dim_size_to_factor=8, factor_map=factor_map)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24, 4), 'a/v_row': (4, 16)}
    self.assertTrue(tree_equals(shapes, ref))

  def test_init_state(self):
    params = {'x': np.zeros((3, 2))}
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1, decay_rate=0.8, beta1=None, min_dim_size_to_factor=0)
    state = optimizer_def.init_state(params)

    expected_hyper_params = _AdafactorHyperParams(0.1, True, True, None, 0.8, 0,
                                                  1.0, None, 0, 1e-30, 1e-3)
    self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
    expected_state = OptimizerState(
        0, {
            'x':
                _AdafactorParamState(
                    np.zeros((2,)), np.zeros((3,)), np.zeros(
                        (1,)), np.zeros((1,)))
        })
    check_eq(state, expected_state)

    # unfactorized
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1, decay_rate=0.8, beta1=0.0, min_dim_size_to_factor=32)
    state = optimizer_def.init_state(params)

    expected_hyper_params = _AdafactorHyperParams(0.1, True, True, 0.0, 0.8, 0,
                                                  1.0, None, 32, 1e-30, 1e-3)
    self.assertEqual(optimizer_def.hyper_params, expected_hyper_params)
    expected_state = OptimizerState(
        0, {
            'x':
                _AdafactorParamState(
                    np.zeros((1,)), np.zeros((1,)), np.zeros(
                        (3, 2)), np.zeros((3, 2)))
        })
    check_eq(state, expected_state)

  def test_apply_gradient(self):
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1, decay_rate=0.8, min_dim_size_to_factor=0)
    params = {'x': np.ones((3, 2), np.float32)}
    state = OptimizerState(
        1, {
            'x':
                _AdafactorParamState(
                    np.array([0.9, 0.9]), np.array([0.1, 0.1, 0.1]),
                    np.zeros((1,)), np.zeros((1,)))
        })
    grads = {'x': np.ones((3, 2), np.float32)}
    new_params, new_state = optimizer_def.apply_gradient(
        optimizer_def.hyper_params, params, state, grads)
    expected_new_state = OptimizerState(
        2, {
            'x':
                _AdafactorParamState(
                    np.array([0.9574349, 0.9574349]),
                    np.array([0.6169143, 0.6169143, 0.6169143]), np.zeros(
                        (1,)), np.zeros((1,)))
        })
    expected_new_params = {'x': 0.9 * np.ones((3, 2))}
    check_eq(new_params, expected_new_params)
    check_eq(new_state, expected_new_state, rtol=1e-6)

    # unfactored w momentum
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1, beta1=0.0, decay_rate=0.8, min_dim_size_to_factor=32)
    params = {'x': np.ones((3, 2), np.float32)}
    state = OptimizerState(
        1, {
            'x':
                _AdafactorParamState(
                    np.zeros(1,), np.zeros(1,), 0.5 * np.ones(
                        (3, 2)), np.zeros((3, 2)))
        })
    grads = {'x': np.ones((3, 2), np.float32)}
    new_params, new_state = optimizer_def.apply_gradient(
        optimizer_def.hyper_params, params, state, grads)
    expected_new_params = {'x': 0.9 * np.ones((3, 2))}
    check_eq(new_params, expected_new_params)
    expected_new_state = OptimizerState(
        2, {
            'x':
                _AdafactorParamState(
                    np.array([0.0]), np.array([0.0]), 0.787174 * np.ones(
                        (3, 2)), 0.1 * np.ones((3, 2)))
        })
    check_eq(new_state, expected_new_state, rtol=1e-6)

  def test_apply_gradient_with_global_norm_clipping(self):
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1,
        decay_rate=0.8,
        min_dim_size_to_factor=0,
        global_norm_clip_threshold=1.0)
    params = {'x': np.ones((3, 2), np.float32)}
    state = OptimizerState(
        1, {
            'x':
                _AdafactorParamState(
                    np.array([0.9, 0.9]), np.array([0.1, 0.1, 0.1]),
                    np.zeros((1,)), np.zeros((1,)))
        })
    grads = {'x': np.ones((3, 2), np.float32)}
    new_params, new_state = optimizer_def.apply_gradient(
        optimizer_def.hyper_params, params, state, grads)
    expected_new_state = OptimizerState(
        2, {
            'x':
                _AdafactorParamState(
                    np.array([0.478811, 0.478811]),
                    np.array([0.13829, 0.13829, 0.13829]), np.zeros(
                        (1,)), np.zeros((1,)))
        })
    expected_new_params = {'x': 0.9 * np.ones((3, 2))}
    check_eq(new_params, expected_new_params)
    check_eq(new_state, expected_new_state, rtol=1e-6)

  def test_factorizes(self):
    params = {'x': np.zeros((64, 64))}
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1,
        decay_rate=0.8,
        beta1=None,
        min_dim_size_to_factor=32)
    state = optimizer_def.init_state(params)
    self.assertEqual(state.param_states['x'].v.shape, (1,))
    self.assertEqual(state.param_states['x'].m.shape, (1,))
    self.assertEqual(state.param_states['x'].v_row.shape, (64,))
    self.assertEqual(state.param_states['x'].v_col.shape, (64,))

    params = {'x': np.zeros((31, 64))}
    optimizer_def = adafactor.Adafactor(
        learning_rate=0.1,
        decay_rate=0.8,
        beta1=None,
        min_dim_size_to_factor=32)
    state = optimizer_def.init_state(params)
    self.assertEqual(state.param_states['x'].v.shape, (31, 64))
    self.assertEqual(state.param_states['x'].m.shape, (1,))
    self.assertEqual(state.param_states['x'].v_row.shape, (1,))
    self.assertEqual(state.param_states['x'].v_col.shape, (1,))

  # Manually specified factorization rules tests.

  @parameterized.parameters(
      {'rule': (_ROW, _COL)},
      {'rule': (_COL, _ROW)},
  )
  def test_2D_ignore_specified_factor_rule(self, rule):
    x = {'a': jnp.ones((24, 16))}
    factor_map = adafactor.HParamMap((('a', rule),))
    opt_def = adafactor.Adafactor(
        min_dim_size_to_factor=8, factor_map=factor_map)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    # Since param is 2D, the explicit factor rule should be ignored and falls
    # back to heuristics where v_row corresponds to the smaller dim.
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24,), 'a/v_row': (16,)}
    self.assertTrue(tree_equals(shapes, ref))

  def test_3D_simple_manual_rules(self):
    x = {'a': jnp.ones((24, 4, 16))}

    factor_map = adafactor.HParamMap((('a', (_COL, _BATCH, _ROW)),))
    opt_def = adafactor.Adafactor(
        min_dim_size_to_factor=8, factor_map=factor_map)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24, 4), 'a/v_row': (4, 16)}
    self.assertTrue(tree_equals(shapes, ref))

    factor_map = adafactor.HParamMap((('a', (_ROW, _BATCH, _COL)),))
    opt_def = adafactor.Adafactor(
        min_dim_size_to_factor=8, factor_map=factor_map)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (4, 16), 'a/v_row': (24, 4)}
    self.assertTrue(tree_equals(shapes, ref))

    factor_map = adafactor.HParamMap((('a', (_COL, _ROW, _ROW)),))
    opt_def = adafactor.Adafactor(
        min_dim_size_to_factor=8, factor_map=factor_map)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24,), 'a/v_row': (4, 16)}
    self.assertTrue(tree_equals(shapes, ref))

    factor_map = adafactor.HParamMap((('a', (_COL, _COL, _ROW)),))
    opt_def = adafactor.Adafactor(
        min_dim_size_to_factor=8, factor_map=factor_map)
    optimizer = opt_def.create(x)
    shapes = tree_shape(flattened_state_dict(optimizer.state.param_states))
    ref = {'a/m': (1,), 'a/v': (1,), 'a/v_col': (24, 4), 'a/v_row': (16,)}
    self.assertTrue(tree_equals(shapes, ref))

  def test_standard_factor_rules(self):
    # one-off test to double-check that we're following the previous
    # heuristic convention for rows/columns.
    def test_standard_factor_rules():
      token_embedding = (_COL, _ROW)
      attn_qkv = (_ROW, _COL)
      attn_out = (_COL, _ROW)
      mlp_in = (_ROW, _COL)
      mlp_out = (_COL, _ROW)
      return ((r'_layer_norm/(bias|scale)',
               None), (r'(encoder|decoder)_norm/(bias|scale)', None),
              (r'(encoder_decoder_|self_|\b)attention/(query|key|value)/kernel',
               attn_qkv), (r'(encoder_decoder_|self_|\b)attention/out/kernel',
                           attn_out), (r'mlp/DenseGeneral_\d+/bias', None),
              (r'mlp/wi(_\d+)?/kernel', mlp_in), (r'mlp/wo/kernel', mlp_out),
              (r'\brelpos_bias', None), (r'token_embedder', token_embedding),
              (r'.*', adafactor.HEURISTIC_RULE))

    # create fake model parameters
    k = jax.random.PRNGKey(0)
    params = jax.tree_map(
        lambda shape: jax.random.uniform(k, shape),
        MODEL_SHAPE,
        is_leaf=lambda x: isinstance(x, list))
    # make traditional adafactor state with heuristic
    factor_map1 = adafactor.HParamMap(((r'.*', adafactor.HEURISTIC_RULE),))
    optimizer_def1 = adafactor.Adafactor(
        0.1,
        decay_rate=0.8,
        step_offset=0,
        multiply_by_parameter_scale=True,
        factor_map=factor_map1)
    optimizer1 = optimizer_def1.create(params)
    # make traditional adafactor state with explicit rules
    factor_map2 = adafactor.HParamMap(test_standard_factor_rules())
    optimizer_def2 = adafactor.Adafactor(
        0.1,
        decay_rate=0.8,
        step_offset=0,
        multiply_by_parameter_scale=True,
        factor_map=factor_map2)
    optimizer2 = optimizer_def2.create(params)
    # are they the same?
    check_eq(optimizer1.state.param_states, optimizer2.state.param_states)

  @parameterized.parameters(
      {'shape': (64, 64)},
      {'shape': (64, 132)},
      {'shape': (132, 64)},
      {'shape': (132, 132)},
      {'shape': (132, 140)},
      {'shape': (140, 132)},
  )
  def test_no_factor_map_equivalence(self, shape):
    k = random.PRNGKey(0)
    k1, k2 = random.split(k)
    p = {'a': random.uniform(k1, shape)}
    g = {'a': random.uniform(k2, shape)}

    orig_opt = optim.Adafactor(0.1).create(p)
    new_opt = adafactor.Adafactor(0.1, factor_map=None).create(p)
    check_eq(orig_opt.state_dict(), new_opt.state_dict())

    orig_opt1 = orig_opt.apply_gradient(g)
    new_opt1 = new_opt.apply_gradient(g)
    check_eq(orig_opt1.state_dict(), new_opt1.state_dict())

  @parameterized.parameters({
      'shape': (128, 128),
      'rule': (_ROW, _COL)
  }, {
      'shape': (132, 128),
      'rule': (_COL, _ROW)
  }, {
      'shape': (128, 132),
      'rule': (_ROW, _COL)
  })
  def test_simple_equivalence(self, shape, rule):
    k = random.PRNGKey(0)
    k1, k2 = random.split(k)
    k3, k4 = random.split(k1)
    k5, k6 = random.split(k2)

    p = {'a': random.uniform(k3, shape), 'b': random.uniform(k4, shape)}
    g = {'a': random.uniform(k5, shape), 'b': random.uniform(k6, shape)}

    orig_opt = optim.Adafactor(0.1).create(p)
    factor_map = adafactor.HParamMap(
        rules=((('a'), rule), ('.*', adafactor.HEURISTIC_RULE)))
    new_opt = adafactor.Adafactor(0.1, factor_map=factor_map).create(p)
    check_eq(orig_opt.state_dict(), new_opt.state_dict())

    orig_opt1 = orig_opt.apply_gradient(g)
    new_opt1 = new_opt.apply_gradient(g)
    check_eq(orig_opt1.state_dict(), new_opt1.state_dict())

  @parameterized.parameters({'shape': (64, 64)}, {'shape': (132, 132)})
  def test_multiply_by_parameter_scale_equivalence(self, shape):
    # Use large parameter values to magnify the parameter scaling effect.
    p = {'a': np.random.randn(*shape) * 100, 'b': np.random.randn(*shape) * 100}
    g = {'a': np.random.randn(*shape), 'b': np.random.randn(*shape)}
    orig_opt = _get_multi_adafactor(
        3.0, 0, adafactor_exclude_from_parameter_scale=('a',)).create(p)
    scaling_map = adafactor.HParamMap([('a', False), ('.*', True)])
    new_opt = adafactor.Adafactor(
        3.0, multiply_by_parameter_scale=scaling_map).create(p)
    check_eq(orig_opt.state_dict(), new_opt.state_dict())

    orig_opt1 = orig_opt.apply_gradient(g)
    new_opt1 = new_opt.apply_gradient(g)
    check_eq(orig_opt1.state_dict(), new_opt1.state_dict())

  def test_3d_without_factor_map(self):
    x = {'a': jnp.ones((24, 4, 16))}
    opt_def = adafactor.Adafactor(factor_map=None)
    with self.assertRaises(ValueError):
      _ = opt_def.create(x)


if __name__ == '__main__':
  absltest.main()
