# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for t5x.checkpoints."""
import concurrent.futures
import functools
import itertools
import os
from typing import Any, Mapping

from absl import flags
from absl.testing import absltest
from absl.testing import parameterized
from flax import serialization
from flax import traverse_util
from flax.metrics import tensorboard
import jax
import jax.numpy as jnp
import numpy as np
from t5x import checkpoints
from t5x import optimizers
from t5x import partitioning
from t5x import state_utils
from t5x import test_utils
from t5x import train_state as train_state_lib
from t5x import utils
import tensorflow as tf
from tensorflow.io import gfile
import tensorstore as ts

# Parse absl flags test_srcdir and test_tmpdir.
jax.config.parse_flags_with_absl()

# Short module-level aliases used by the tests in this file.
mock = absltest.mock
PartitionSpec = partitioning.PartitionSpec
FLAGS = flags.FLAGS
LazyArray = checkpoints.LazyArray

# Path of the `testdata` directory that sits next to this file.
TESTDATA = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'testdata')

FlaxOptimTrainState = train_state_lib.FlaxOptimTrainState


def make_train_state(
    *,
    step: int,
    params: Mapping[str, Any],
    param_states: Mapping[str, Any],
    flax_optimizer_def: optimizers.OptimizerDefType = optimizers.sgd(0.1)
) -> FlaxOptimTrainState:
  """Wraps the given values in a `FlaxOptimTrainState` for use in tests.

  Args:
    step: Stored as the `step` of the optimizer state.
    params: Used as the optimizer target.
    param_states: Stored as the `param_states` of the optimizer state.
    flax_optimizer_def: Optimizer definition handed to `optimizers.Optimizer`.

  Returns:
    A `FlaxOptimTrainState` built around the newly created optimizer.
  """
  optimizer_state = optimizers.OptimizerState(
      step=step, param_states=param_states)
  return FlaxOptimTrainState(
      optimizers.Optimizer(
          flax_optimizer_def, state=optimizer_state, target=params))


def make_train_state_multi_optimizer(params: Mapping[str, Any],
                                     param_states: Mapping[str, Any],
                                     step: int) -> FlaxOptimTrainState:
  """Wraps the given values in a train state backed by a `MultiOptimizer`.

  The multi optimizer has a single (traversal, sgd(0.1)) entry whose traversal
  filter accepts only paths that do not contain 'kernel'.

  Args:
    params: Used as the optimizer target.
    param_states: Stored as the `param_states` of the optimizer state.
    step: Stored as the `step` of the optimizer state.

  Returns:
    A `FlaxOptimTrainState` built around the newly created optimizer.
  """

  def _path_has_no_kernel(path, _):
    return 'kernel' not in path

  non_kernel_sgd = (traverse_util.ModelParamTraversal(_path_has_no_kernel),
                    optimizers.sgd(0.1))
  multi_optimizer_def = optimizers.MultiOptimizer([non_kernel_sgd])
  optimizer_state = optimizers.OptimizerState(
      step=step, param_states=param_states)
  return FlaxOptimTrainState(
      optimizers.Optimizer(
          multi_optimizer_def, state=optimizer_state, target=params))


def update_train_state_step(train_state: FlaxOptimTrainState,
                            step: int) -> FlaxOptimTrainState:
  """Returns `train_state` restored from its state dict with `step` replaced.

  Args:
    train_state: Train state whose `state_dict()` is taken as the base.
    step: Value written to `['state']['step']` of that state dict.

  Returns:
    The result of `train_state.restore_state` on the modified state dict.
  """
  new_state_dict = train_state.state_dict()
  new_state_dict['state']['step'] = step
  restored = train_state.restore_state(new_state_dict)
  return restored


class CheckpointChunkShapeTest(absltest.TestCase):
  """Checks `checkpoints._choose_chunk_shape` on small 1-D and 2-D inputs."""

  def test_simple(self):
    # Each row is (expected result, first argument, second argument) for
    # `checkpoints._choose_chunk_shape`.
    cases = (
        ([4096, 4096], [4096, 4096], 4096 * 4096),
        ([4096, 4096], [8192, 8192], 4096 * 4096),
        ([4096, 2731], [8192, 8193], 4096 * 4096),
        ([4096], [8192], 4096),
        ([2731], [8193], 4096),
    )
    for expected_chunks, full_shape, size_arg in cases:
      self.assertEqual(expected_chunks,
                       checkpoints._choose_chunk_shape(full_shape, size_arg))


class CheckpointsTest(parameterized.TestCase):

  def setUp(self):
    """Creates the train states, mesh axes, dataset and temp dirs for a test."""
    super().setUp()

    def fresh_params():
      # Builds new arrays on every call so the train states below do not share
      # parameter buffers.
      return {
          'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)),
          'kernel': np.arange(32, dtype=np.float32).reshape((2, 16))
      }

    self.train_state = make_train_state(
        step=np.int32(42),
        params=fresh_params(),
        param_states={
            'bias': np.int32(1),
            'kernel': np.array([1, 2], np.uint8)
        })

    # Same parameters, but the optimizer state for 'kernel' is None.
    self.train_state_multi_optimizer = make_train_state_multi_optimizer(
        step=np.int32(42),
        params=fresh_params(),
        param_states={
            'bias': np.int32(1),
            'kernel': None
        })

    # A train-state-shaped container of PartitionSpecs; 'bias' names 'model'
    # on its first dimension and 'kernel' on its second.
    self.default_mesh_axes = make_train_state(
        step=None,
        params={
            'bias': PartitionSpec('model', None),
            'kernel': PartitionSpec(None, 'model')
        },
        param_states=dict.fromkeys(('bias', 'kernel')))

    self.ds = tf.data.Dataset.range(1024)

    self.checkpoints_dir = self.create_tempdir()
    self.tmp_dir = self.checkpoints_dir.full_path

    # A second temp dir pre-populated with one `checkpoint_<step>/checkpoint`
    # file per entry of `self.steps`.
    fake_checkpoints_dir = self.create_tempdir()
    self.fake_checkpoints = fake_checkpoints_dir.full_path
    self.steps = (0, 100, 200)
    for fake_step in self.steps:
      fake_step_dir = fake_checkpoints_dir.mkdir(f'checkpoint_{fake_step}')
      fake_step_dir.create_file('checkpoint')

  @mock.patch('jax._src.lib.xla_bridge.process_index')
  @mock.patch('jax.devices')
  @mock.patch('jax.local_devices')
  def get_partitioner(self,
                      process_index,
                      host_count,
                      num_partitions,
                      local_devices_fn,
                      devices_fn,
                      process_index_fn,
                      params_on_devices: bool = True,
                      mesh_axes=None):
    """Builds a `TestPartitioner` for one simulated host.

    `jax.local_devices`, `jax.devices` and
    `jax._src.lib.xla_bridge.process_index` are patched while this method runs.
    The three `*_fn` parameters receive those mocks from the decorators
    (innermost decorator first), so callers only pass the first three
    positional arguments plus the optional keywords.

    Args:
      process_index: Value the patched `xla_bridge.process_index` returns.
      host_count: Key into `host_count_to_layout` (1, 2, 4, 8, 16 or 32).
      num_partitions: Key into `num_partitions_to_mps` (1, 2, 4 or 16); also
        forwarded to `BasePartitioner.__init__`.
      local_devices_fn: Injected mock for `jax.local_devices`.
      devices_fn: Injected mock for `jax.devices`.
      process_index_fn: Injected mock for `xla_bridge.process_index`.
      params_on_devices: Forwarded to `BasePartitioner.__init__`;
        `move_params_to_devices` asserts that it is true.
      mesh_axes: Value returned by `get_mesh_axes`. Falls back to
        `self.default_mesh_axes` when falsy.

    Returns:
      A new `TestPartitioner` instance.
    """
    # Arguments for `test_utils.make_devices`, keyed by host count.
    # NOTE(review): the meaning of the four numbers is defined by
    # `make_devices`, which is not visible here.
    host_count_to_layout = {
        1: (2, 2, 1, 2),
        2: (4, 2, 1, 2),
        4: (4, 4, 1, 2),
        8: (4, 8, 1, 2),
        16: (8, 8, 1, 2),
        32: (8, 16, 1, 2)
    }
    devices = test_utils.make_devices(*host_count_to_layout[host_count])
    devices_fn.return_value = devices
    # The local devices are always those of process 0, independently of the
    # `process_index` argument.
    local_devices = [d for d in devices if d.process_index == 0]
    local_devices_fn.return_value = local_devices
    process_index_fn.return_value = process_index
    # `model_parallel_submesh` passed to `partitioning.get_mesh`, keyed by the
    # number of partitions.
    num_partitions_to_mps = {
        1: (1, 1, 1, 1),
        2: (1, 1, 1, 2),
        4: (2, 1, 1, 2),
        16: (4, 2, 1, 2)
    }
    mesh = partitioning.get_mesh(
        model_parallel_submesh=num_partitions_to_mps[num_partitions],
        input_devices=devices,
        input_local_devices=local_devices)
    mesh_axes = mesh_axes or self.default_mesh_axes
    local_chunker = partitioning.LocalChunker(mesh)

    class TestPartitioner(partitioning.BasePartitioner):
      """Partitioner that serves the mesh and chunker built above."""

      def __init__(self):
        # Initialized to 0; nothing in this class updates it.
        self.move_params_to_devices_calls = 0
        super().__init__(
            num_partitions, None, params_on_devices=params_on_devices)

      @property
      def _local_chunker(self):
        return local_chunker

      @property
      def _mesh(self):
        return mesh

      def partition(self,
                    fn,
                    in_axis_resources,
                    out_axis_resources,
                    static_argnums=(),
                    donate_argnums=()):
        raise NotImplementedError

      def compile(self, partitioned_fn, *args):
        raise NotImplementedError

      def move_params_to_devices(self, train_state, train_state_axes):
        # Returns the train state unchanged; only valid when the partitioner
        # was created with `params_on_devices=True`.
        assert params_on_devices
        return train_state

      def get_mesh_axes(self, train_state):
        return mesh_axes

    return TestPartitioner()

  # pylint:disable=no-value-for-parameter
  @mock.patch(
      'jax.experimental.multihost_utils.sync_global_devices', return_value=None)
  @mock.patch('time.time', return_value=0)
  @mock.patch('jax.host_count')
  @mock.patch('jax.process_index')
  def call_host_checkpointer(self,
                             process_index,
                             host_count,
                             partitioner,
                             fn,
                             save_dtype,
                             ds_iter,
                             mock_process_index,
                             mock_host_count,
                             unused_mock_host_time,
                             unused_mock_sync_devices,
                             restore_dtype=np.float32):
    """Runs `fn` on a `Checkpointer` for `self.train_state` as a given host.

    While this method runs, `jax.process_index` and `jax.host_count` return the
    supplied values, `time.time` returns 0 and `sync_global_devices` returns
    None. The four mock parameters are injected by the decorators.

    Args:
      process_index: Value returned by the patched `jax.process_index`.
      host_count: Value returned by the patched `jax.host_count`.
      partitioner: Partitioner handed to the `Checkpointer`.
      fn: Callable invoked with the new `Checkpointer`.
      save_dtype: `save_dtype` of the `Checkpointer`.
      ds_iter: Dataset iterator handed to the `Checkpointer`, or None.
      mock_process_index: Injected mock for `jax.process_index`.
      mock_host_count: Injected mock for `jax.host_count`.
      unused_mock_host_time: Injected mock for `time.time`.
      unused_mock_sync_devices: Injected mock for `sync_global_devices`.
      restore_dtype: `restore_dtype` of the `Checkpointer`.

    Returns:
      Whatever `fn` returns.
    """
    mock_host_count.return_value = host_count
    mock_process_index.return_value = process_index

    dtype_kwargs = dict(save_dtype=save_dtype, restore_dtype=restore_dtype)
    host_checkpointer = checkpoints.Checkpointer(self.train_state, partitioner,
                                                 self.tmp_dir, ds_iter,
                                                 **dtype_kwargs)
    return fn(host_checkpointer)

  # pylint:disable=no-value-for-parameter
  @mock.patch(
      'jax.experimental.multihost_utils.sync_global_devices', return_value=None)
  @mock.patch('time.time', return_value=0)
  @mock.patch('jax.host_count')
  @mock.patch('jax.process_index')
  def call_host_multioptimizer_checkpointer(self, process_index, host_count,
                                            partitioner, fn, save_dtype,
                                            ds_iter, mock_process_index,
                                            mock_host_count,
                                            unused_mock_host_time,
                                            unused_mock_sync_devices):
    """Runs `fn` on a `Checkpointer` for the multi-optimizer train state.

    Same patching as `call_host_checkpointer`, but the `Checkpointer` is built
    from `self.train_state_multi_optimizer` and no `restore_dtype` is passed.

    Args:
      process_index: Value returned by the patched `jax.process_index`.
      host_count: Value returned by the patched `jax.host_count`.
      partitioner: Partitioner handed to the `Checkpointer`.
      fn: Callable invoked with the new `Checkpointer`.
      save_dtype: `save_dtype` of the `Checkpointer`.
      ds_iter: Dataset iterator handed to the `Checkpointer`, or None.
      mock_process_index: Injected mock for `jax.process_index`.
      mock_host_count: Injected mock for `jax.host_count`.
      unused_mock_host_time: Injected mock for `time.time`.
      unused_mock_sync_devices: Injected mock for `sync_global_devices`.

    Returns:
      Whatever `fn` returns.
    """
    mock_host_count.return_value = host_count
    mock_process_index.return_value = process_index

    host_checkpointer = checkpoints.Checkpointer(
        self.train_state_multi_optimizer,
        partitioner,
        self.tmp_dir,
        ds_iter,
        save_dtype=save_dtype)
    return fn(host_checkpointer)

  def test_get_parameter_infos(self):
    # Compares `Checkpointer._get_parameter_infos()` leaf by leaf against a
    # hand-written tree for one host of a multi-host topology.
    train_state = make_train_state(
        step=np.int32(42),
        params={
            'bias': np.ones((8192, 8192), np.float32),
            'kernel': np.ones((2, 16), np.float32)
        },
        param_states={
            'bias': np.int32(1),
            'kernel': np.array([1, 2])
        })
    # host 3 of a 4x4 with mesh 'model' dim == 16
    partitioner = self.get_partitioner(3, 4, 16)
    checkpointer = checkpoints.Checkpointer(train_state, partitioner,
                                            self.tmp_dir)

    def info_without_spec(name, shape):
      # Optimizer-state entries: no TensorStore spec, chunk info or axes.
      return checkpoints._ParameterInfo(
          name=name,
          shape=shape,
          ts_spec=None,
          local_chunk_info=None,
          axes=None)

    def zarr_spec(path, chunks, shape):
      # Gzip-compressed float32 zarr spec on the file kvstore driver.
      return ts.Spec({
          'driver': 'zarr',
          'dtype': 'float32',
          'kvstore': {  # pylint:disable=duplicate-key
              'driver': 'file',
              'path': path,
          },
          'metadata': {
              'chunks': chunks,
              'compressor': {
                  'id': 'gzip'
              },
              'shape': shape,
          },
      })

    whole_axis = slice(None, None, None)
    expected_state_infos = {
        'step': info_without_spec('state/step', ()),
        'param_states': {
            'bias': info_without_spec('state/param_states/bias', ()),
            'kernel': info_without_spec('state/param_states/kernel', (2,)),
        },
    }
    expected_target_infos = {
        'bias':
            checkpoints._ParameterInfo(
                name='target/bias',
                shape=(8192, 8192),
                ts_spec=zarr_spec('target.bias', [4096, 4096], [8192, 8192]),
                local_chunk_info=partitioning.LocalChunkInfo(
                    slice=(slice(4096, 8192, None), whole_axis), replica_id=1),
                axes=PartitionSpec('model', None)),
        'kernel':
            checkpoints._ParameterInfo(
                name='target/kernel',
                shape=(2, 16),
                ts_spec=zarr_spec('target.kernel', [2, 8], [2, 16]),
                local_chunk_info=partitioning.LocalChunkInfo(
                    slice=(whole_axis, slice(8, 16, None)), replica_id=1),
                axes=PartitionSpec(None, 'model')),
    }
    expected_parameter_infos = {
        'state': expected_state_infos,
        'target': expected_target_infos,
    }

    actual_parameter_infos = checkpointer._get_parameter_infos()
    jax.tree_multimap(self.assertEqual, actual_parameter_infos,
                      expected_parameter_infos)

  def test_get_multioptimizer_parameter_infos(self):
    # A `None` optimizer state must come back as a `None` parameter info.
    train_state = make_train_state(
        step=np.int32(42),
        params={
            'bias': np.ones((8192, 8192), jnp.bfloat16),
            'kernel': np.ones((2, 16), np.float32)
        },
        param_states={
            'bias': np.int32(1),
            # `None` stands in for a multioptimizer that does not update
            # the kernel parameter.
            'kernel': None
        })
    # host 3 of a 4x4 with mesh 'model' dim == 16
    partitioner = self.get_partitioner(3, 4, 16)
    checkpointer = checkpoints.Checkpointer(train_state, partitioner,
                                            self.tmp_dir)

    parameter_infos = checkpointer._get_parameter_infos()
    self.assertIsNone(parameter_infos['state']['param_states']['kernel'])

  def test_all_steps(self):
    # `lambda c: c` hands the checkpointer itself back to the test.
    partitioner = self.get_partitioner(0, 1, 1)
    checkpointer = self.call_host_checkpointer(0, 1, partitioner, lambda c: c,
                                               np.float32, None)
    self.assertIsNone(checkpointer.latest_step())

    # Create `checkpoint_<suffix>/checkpoint` as an empty file for each suffix.
    for suffix in ('0', '42', '10', '999.tmp-0', '100'):
      step_dir = os.path.join(checkpointer.checkpoints_dir,
                              f'checkpoint_{suffix}')
      gfile.makedirs(step_dir)
      with gfile.GFile(os.path.join(step_dir, 'checkpoint'), 'w') as marker:
        marker.write('')

    # The directory is passed with a trailing slash; the '999.tmp-0' entry is
    # expected to be left out and the rest returned in ascending order.
    found_steps = checkpoints.all_steps(checkpointer.checkpoints_dir + '/')
    self.assertSequenceEqual(found_steps, [0, 10, 42, 100])

  def test_all_latest_step(self):
    # `lambda c: c` hands the checkpointer itself back to the test.
    partitioner = self.get_partitioner(0, 1, 1)
    checkpointer = self.call_host_checkpointer(0, 1, partitioner, lambda c: c,
                                               np.float32, None)
    self.assertIsNone(checkpointer.latest_step())

    # Create `checkpoint_<suffix>/checkpoint` as an empty file for each suffix.
    # After the loop `marker_path` refers to the file of the last suffix, '100'.
    for suffix in ('0', '42', '10', '999.tmp-0', '100'):
      step_dir = os.path.join(checkpointer.checkpoints_dir,
                              f'checkpoint_{suffix}')
      gfile.makedirs(step_dir)
      marker_path = os.path.join(step_dir, 'checkpoint')
      with gfile.GFile(marker_path, 'w') as marker:
        marker.write('')

    self.assertSequenceEqual(checkpointer.all_steps(), [0, 10, 42, 100])
    self.assertEqual(checkpointer.latest_step(), 100)

    # Delete only the marker file of step 100; its directory stays in place.
    gfile.remove(marker_path)
    self.assertSequenceEqual(checkpointer.all_steps(), [0, 10, 42])
    self.assertEqual(checkpointer.latest_step(), 42)

  def test_all_latest_step_public(self):
    # Same scenario as the checkpointer-based test, but through the
    # module-level `checkpoints.all_steps` / `checkpoints.latest_step`.
    root = self.tmp_dir
    self.assertIsNone(checkpoints.latest_step(root))

    # Create `checkpoint_<suffix>/checkpoint` as an empty file for each suffix.
    # After the loop `marker_path` refers to the file of the last suffix, '100'.
    for suffix in ('0', '42', '10', '999.tmp-0', '100'):
      step_dir = os.path.join(root, f'checkpoint_{suffix}')
      gfile.makedirs(step_dir)
      marker_path = os.path.join(step_dir, 'checkpoint')
      with gfile.GFile(marker_path, 'w') as marker:
        marker.write('')

    self.assertSequenceEqual(checkpoints.all_steps(root), [0, 10, 42, 100])
    self.assertEqual(checkpoints.latest_step(root), 100)

    # Delete only the marker file of step 100; its directory stays in place.
    gfile.remove(marker_path)
    self.assertSequenceEqual(checkpoints.all_steps(root), [0, 10, 42])
    self.assertEqual(checkpoints.latest_step(root), 42)

  def validate_restore(self,
                       host_count,
                       num_partitions,
                       step=42,
                       checkpoint_dataset=False,
                       expected_restore_dtype=np.float32,
                       lazy_parameters=False,
                       disable_partitioning=False):
    """Restores on every simulated host and checks the restored train state.

    Args:
      host_count: Number of hosts to simulate; each is restored in turn.
      num_partitions: Number of partitions passed to `get_partitioner`.
      step: Checkpoint step to restore and the step expected afterwards.
      checkpoint_dataset: Whether a dataset iterator is handed to the
        checkpointer and its position checked after the restore.
      expected_restore_dtype: Passed as `restore_dtype` and expected as the
        dtype of every restored parameter.
      lazy_parameters: Forwarded to `restore`; when set, every leaf of the
        result is resolved with `.get()` before checking, and the partitioner
        is created with `params_on_devices=False`.
      disable_partitioning: When set, the partitioner reports mesh axes
        produced by mapping every leaf of `self.default_mesh_axes` to None.
    """
    expected_params = self.train_state.params
    expected_param_states = self.train_state.param_states

    def restore_fn(checkpointer):
      return checkpointer.restore(step=step, lazy_parameters=lazy_parameters)

    for host_id in range(host_count):
      if disable_partitioning:
        mesh_axes = jax.tree_map(lambda x: None, self.default_mesh_axes)
      else:
        mesh_axes = None
      partitioner = self.get_partitioner(
          host_id,
          host_count,
          num_partitions,
          params_on_devices=not lazy_parameters,
          mesh_axes=mesh_axes)
      ds_shard_id = partitioner.get_data_layout().shard_id

      # Slices of the reference parameters this host is expected to hold.
      bias_slice = partitioner.get_local_chunk_info(
          expected_params['bias'].shape, ('model', None)).slice
      kernel_slice = partitioner.get_local_chunk_info(
          expected_params['kernel'].shape, (None, 'model')).slice

      ds_iter = iter(self.ds)

      restored = self.call_host_checkpointer(
          host_id,
          host_count,
          partitioner,
          restore_fn,
          np.float32,
          ds_iter if checkpoint_dataset else None,
          restore_dtype=expected_restore_dtype)
      if lazy_parameters:
        restored = jax.tree_map(lambda x: x.get(), restored)

      # Optimizer definition and step.
      self.assertEqual(restored._optimizer.optimizer_def,
                       self.train_state._optimizer.optimizer_def)
      self.assertEqual(restored.step, step)
      self.assertEqual(restored.step.dtype, np.int32)
      self.assertEqual(restored._optimizer.state.step.dtype, np.int32)

      # Optimizer state.
      jax.tree_multimap(np.testing.assert_array_equal, restored.param_states,
                        expected_param_states)
      self.assertEqual(restored.param_states['kernel'].dtype, np.uint8)

      # Parameters: names, dtypes and the values of this host's slices.
      self.assertSameElements(restored.params, ('bias', 'kernel'))
      dtype_matches = [
          leaf.dtype == expected_restore_dtype
          for leaf in jax.tree_leaves(restored.params)
      ]
      self.assertTrue(all(dtype_matches))
      np.testing.assert_equal(restored.params['bias'],
                              expected_params['bias'][bias_slice])
      np.testing.assert_equal(restored.params['kernel'],
                              expected_params['kernel'][kernel_slice])

      if checkpoint_dataset:
        # The next value from the restored iterator should equal the
        # replica set id.
        self.assertEqual(next(ds_iter).numpy(), ds_shard_id)

  def validate_multioptimizer_restore(self,
                                      host_count,
                                      num_partitions,
                                      step=42,
                                      checkpoint_dataset=False,
                                      expected_restore_dtype=np.float32):
    """Restores into the multi-optimizer train state on every simulated host.

    Args:
      host_count: Number of hosts to simulate; each is restored in turn.
      num_partitions: Number of partitions passed to `get_partitioner`.
      step: Checkpoint step to restore and the step expected afterwards.
      checkpoint_dataset: Whether a dataset iterator is handed to the
        checkpointer and its position checked after the restore.
      expected_restore_dtype: Dtype expected for every restored parameter.
    """
    expected_params = self.train_state_multi_optimizer.params
    expected_param_states = self.train_state_multi_optimizer.param_states

    def restore_fn(checkpointer):
      return checkpointer.restore(step=step)

    for host_id in range(host_count):
      partitioner = self.get_partitioner(host_id, host_count, num_partitions)
      ds_shard_id = partitioner.get_data_layout().shard_id

      # Slices of the reference parameters this host is expected to hold.
      bias_slice = partitioner.get_local_chunk_info(
          expected_params['bias'].shape, ('model', None)).slice
      kernel_slice = partitioner.get_local_chunk_info(
          expected_params['kernel'].shape, (None, 'model')).slice

      ds_iter = iter(self.ds)

      restored = self.call_host_multioptimizer_checkpointer(
          host_id, host_count, partitioner, restore_fn, np.float32,
          ds_iter if checkpoint_dataset else None)
      restored_optimizer = restored._optimizer  # pylint: disable=protected-access
      restored_step = restored.step
      restored_params = restored.params
      restored_param_states = restored.param_states

      # Optimizer definition, step dtype and target dtypes.
      self.assertEqual(
          restored_optimizer.optimizer_def,
          self.train_state_multi_optimizer._optimizer.optimizer_def)
      self.assertEqual(restored_optimizer.state.step.dtype, np.int32)
      for target_leaf in jax.tree_leaves(restored_optimizer.target):
        self.assertEqual(target_leaf.dtype, expected_restore_dtype)
      self.assertEqual(restored_step, step)
      self.assertEqual(restored_step.dtype, np.int32)

      # Optimizer state.
      jax.tree_multimap(np.testing.assert_array_equal, restored_param_states,
                        expected_param_states)

      # Parameters: names, dtypes and the values of this host's slices.
      self.assertSameElements(restored_params, ('bias', 'kernel'))
      dtype_matches = [
          leaf.dtype == expected_restore_dtype
          for leaf in jax.tree_leaves(restored_params)
      ]
      self.assertTrue(all(dtype_matches))
      np.testing.assert_equal(restored_params['bias'],
                              expected_params['bias'][bias_slice])
      np.testing.assert_equal(restored_params['kernel'],
                              expected_params['kernel'][kernel_slice])

      if checkpoint_dataset:
        # The next value from the restored iterator should equal the
        # replica set id.
        self.assertEqual(next(ds_iter).numpy(), ds_shard_id)

  def validate_save(self,
                    host_count,
                    num_partitions,
                    step=42,
                    save_dtype=np.float32,
                    checkpoint_dataset=False,
                    multi_optimizer=False,
                    disable_partitioning=False):
    """Saves from every simulated host and checks what was written.

    Hosts are processed in reverse order. After each host's save (unless
    `disable_partitioning` is set) the TensorStore arrays for 'bias' and
    'kernel' are read back and compared with the slices written so far. Once
    all hosts have saved, the msgpack `checkpoint` file is read and its
    version, optimizer state and step dtype are checked.

    Args:
      host_count: Number of hosts to simulate.
      num_partitions: Number of partitions passed to `get_partitioner`.
      step: Step stored in the saved train state; also part of the checkpoint
        directory name.
      save_dtype: `save_dtype` of the checkpointer and dtype of the expected
        arrays.
      checkpoint_dataset: Whether a dataset iterator is handed to the
        checkpointer and the dataset checkpoint files are checked.
      multi_optimizer: Selects `self.train_state_multi_optimizer` instead of
        `self.train_state` as the source of params, param states and
        optimizer definition.
      disable_partitioning: When set, the partitioner reports mesh axes
        produced by mapping every leaf of `self.default_mesh_axes` to None,
        the per-host TensorStore checks are skipped, and the parameters are
        expected inside the msgpack file instead.
    """
    if multi_optimizer:
      params = self.train_state_multi_optimizer.params
      param_states = self.train_state_multi_optimizer.param_states
      optimizer_def = self.train_state_multi_optimizer._optimizer.optimizer_def
    else:
      params = self.train_state.params
      param_states = self.train_state.param_states
      optimizer_def = self.train_state._optimizer.optimizer_def
    # Update these on each save.
    step = np.int32(step)
    # Start as zeros; each host's slice is filled in after that host saves.
    expected_bias = np.zeros((4, 1), save_dtype)
    expected_kernel = np.zeros((2, 16), save_dtype)

    # TensorStore specs pointing at the temporary checkpoint directory; the
    # paths are rewritten below once host 0 has saved.
    bias_tspec = {
        'driver': 'zarr',
        'kvstore': {
            'driver': 'file',
            'path': f'{self.tmp_dir}/checkpoint_{step}.tmp-0/target.bias',
        }
    }
    kernel_tspec = {
        'driver': 'zarr',
        'kvstore': {
            'driver': 'file',
            'path': f'{self.tmp_dir}/checkpoint_{step}.tmp-0/target.kernel',
        }
    }

    # Test save.
    # Each host sets its partition to its host number + 1.
    # Go in reverse since host 0 renames the directory.
    for i in reversed(range(host_count)):
      partitioner = self.get_partitioner(
          i,
          host_count,
          num_partitions,
          mesh_axes=jax.tree_map(lambda x: None, self.default_mesh_axes)
          if disable_partitioning else None)
      data_layout = partitioner.get_data_layout()
      num_ds_shards = data_layout.num_shards
      ds_shard_id = data_layout.shard_id
      # Replica id of this host's chunk along the 'data' axis; used below to
      # decide whether dataset checkpoint files are expected for this host.
      chunk_id_for_shard = partitioner.get_local_chunk_info(
          jnp.ones((num_ds_shards,)), ['data']).replica_id

      bias_chunk = partitioner.get_local_chunk_info(params['bias'].shape,
                                                    ('model', None))
      kernel_chunk = partitioner.get_local_chunk_info(params['kernel'].shape,
                                                      (None, 'model'))

      ds_iter = iter(self.ds)

      # pylint:disable=cell-var-from-loop
      def _save_ckpt(checkpointer):
        """Saves this host's parameter slices through `checkpointer`."""
        # Set the checkpoint so that the next value on restore will be the
        # replica set id.
        for _ in range(ds_shard_id):
          next(ds_iter)

        train_state = make_train_state(
            step=step,
            params={
                'bias': params['bias'][bias_chunk.slice],
                'kernel': params['kernel'][kernel_chunk.slice]
            },
            param_states=param_states,
            flax_optimizer_def=optimizer_def)
        checkpointer.save(train_state)

      # pylint:enable=cell-var-from-loop

      # NOTE(review): the checkpointer is created by `call_host_checkpointer`
      # (i.e. around `self.train_state`) even when `multi_optimizer` is set.
      self.call_host_checkpointer(i, host_count, partitioner, _save_ckpt,
                                  save_dtype,
                                  ds_iter if checkpoint_dataset else None)

      if disable_partitioning:
        continue

      # Read the current TensorStore.
      if i == 0:
        # Host 0 moves the files.
        bias_tspec['kvstore']['path'] = (
            bias_tspec['kvstore']['path'].replace('.tmp-0', ''))
        kernel_tspec['kvstore']['path'] = (
            kernel_tspec['kvstore']['path'].replace('.tmp-0', ''))

      if checkpoint_dataset:
        ckpt_dir = f'{self.tmp_dir}/checkpoint_{step}'
        if i != 0:
          ckpt_dir += '.tmp-0'
        ds_ckpt_glob = gfile.glob(ckpt_dir + '/train_ds-' +
                                  f'{ds_shard_id:03}-of-{num_ds_shards:03}*')
        # Two matching files are expected when this host's replica id is 0,
        # none otherwise.
        if chunk_id_for_shard == 0:
          self.assertLen(ds_ckpt_glob, 2)
        else:
          self.assertEmpty(ds_ckpt_glob)

      # only replica_id=0 is saved for each array chunk
      if bias_chunk.replica_id == 0:
        current_bias = ts.open(bias_tspec).result().read().result().view(
            save_dtype)
        expected_bias[bias_chunk.slice] = (params['bias'][bias_chunk.slice])
        np.testing.assert_equal(current_bias, expected_bias)

      if kernel_chunk.replica_id == 0:
        current_kernel = ts.open(kernel_tspec).result().read().result().view(
            save_dtype)
        expected_kernel[kernel_chunk.slice] = (
            params['kernel'][kernel_chunk.slice])
        np.testing.assert_equal(current_kernel, expected_kernel)

    # All hosts have saved: inspect the msgpack file in the final directory.
    with gfile.GFile(f'{self.tmp_dir}/checkpoint_{step}/checkpoint', 'rb') as f:
      ckpt_contents = serialization.msgpack_restore(f.read())
    self.assertEqual(ckpt_contents['version'], checkpoints.VERSION)
    jax.tree_multimap(np.testing.assert_allclose,
                      ckpt_contents['optimizer']['state']['param_states'],
                      param_states)
    self.assertEqual(ckpt_contents['optimizer']['state']['step'].dtype,
                     np.int32)
    if disable_partitioning:
      # Parameters should also be in the msgpack checkpoint file.
      jax.tree_multimap(
          np.testing.assert_allclose, ckpt_contents['optimizer']['target'],
          jax.tree_map(lambda arr: arr.astype(save_dtype), params))

    # Jax tree maps ignore Nones so actually check this value is None
    if multi_optimizer:
      self.assertIsNone(
          ckpt_contents['optimizer']['state']['param_states']['kernel'])

  # (host_count, num_partitions)
  # Each pair is unpacked into the two leading arguments of the
  # `validate_save` / `validate_restore` helpers. Both values must be keys of
  # the lookup tables in `get_partitioner`.
  TOPOLOGIES = [
      (1, 1),  # 1 host, 1 partition
      (1, 2),  # 1 host, 2 partitions
      (2, 1),  # 2 hosts, 1 partition
      (2, 2),  # 2 hosts, 2 partitions
      (4, 4),  # 4 hosts, 4 partitions
      (4, 1),  # 4 hosts, 1 partition
      (4, 2),  # 4 hosts, 2 partitions
      (8, 2),  # 8 hosts, 2 partitions
  ]

  # Dtypes that `test_save_as_type` crosses as (save_dtype, restore_dtype).
  DTYPES = [
      jnp.int32, jnp.float32, jnp.bfloat16, jnp.uint32, jnp.int64, jnp.float64
  ]

  @parameterized.parameters(itertools.product(TOPOLOGIES, TOPOLOGIES))
  def test_save_restore(self, save_topology, restore_topology):
    # Runs once per (save, restore) pairing of TOPOLOGIES.
    save_hosts, save_partitions = save_topology
    restore_hosts, restore_partitions = restore_topology
    self.validate_save(save_hosts, save_partitions)
    self.validate_restore(restore_hosts, restore_partitions)

  @parameterized.parameters(itertools.product(TOPOLOGIES, TOPOLOGIES))
  def test_save_restore_lazy(self, save_topology, restore_topology):
    # Same pairing as `test_save_restore`, restoring with lazy parameters.
    save_hosts, save_partitions = save_topology
    restore_hosts, restore_partitions = restore_topology
    self.validate_save(save_hosts, save_partitions)
    self.validate_restore(
        restore_hosts, restore_partitions, lazy_parameters=True)

  @parameterized.parameters(itertools.product(TOPOLOGIES, TOPOLOGIES))
  def test_save_multioptimizer_restore(self, save_topology, restore_topology):
    # Save from the regular train state, restore into the multi-optimizer one.
    save_hosts, save_partitions = save_topology
    restore_hosts, restore_partitions = restore_topology
    self.validate_save(save_hosts, save_partitions)
    self.validate_multioptimizer_restore(restore_hosts, restore_partitions)

  @parameterized.parameters(itertools.product(TOPOLOGIES, TOPOLOGIES))
  def test_multioptimizer_save_multioptimizer_restore(self, save_topology,
                                                      restore_topology):
    # Both the save and the restore use the multi-optimizer train state.
    save_hosts, save_partitions = save_topology
    restore_hosts, restore_partitions = restore_topology
    self.validate_save(save_hosts, save_partitions, multi_optimizer=True)
    self.validate_multioptimizer_restore(restore_hosts, restore_partitions)

  def test_load_t5x_checkpoint(self):
    # Save on one host with one partition, then load the whole checkpoint and
    # compare it with the state dict of the reference train state.
    self.validate_save(host_count=1, num_partitions=1)
    loaded = checkpoints.load_t5x_checkpoint(self.tmp_dir)
    expected = self.train_state.state_dict()
    jax.tree_multimap(np.testing.assert_array_equal, expected, loaded)

  def test_load_t5x_checkpoint_of_multioptimizer(self):
    # Save the multi-optimizer train state on one host with one partition and
    # compare the loaded checkpoint with its state dict.
    self.validate_save(host_count=1, num_partitions=1, multi_optimizer=True)
    loaded = checkpoints.load_t5x_checkpoint(self.tmp_dir)
    expected = self.train_state_multi_optimizer.state_dict()
    jax.tree_multimap(np.testing.assert_array_equal, expected, loaded)
    # Jax tree maps skip None entries, so the None kernel state has to be
    # checked explicitly.
    self.assertIsNone(loaded['state']['param_states']['kernel'])

  def test_load_t5x_checkpoint_lazy(self):
    # A lazily loaded checkpoint, once every leaf is resolved with `.get()`,
    # must equal the eagerly loaded one.
    self.validate_save(host_count=1, num_partitions=1)
    eager_ckpt = checkpoints.load_t5x_checkpoint(self.tmp_dir)
    lazy_ckpt = checkpoints.load_t5x_checkpoint(
        self.tmp_dir, lazy_parameters=True)
    resolved_ckpt = jax.tree_map(lambda leaf: leaf.get(), lazy_ckpt)
    jax.tree_multimap(np.testing.assert_array_equal, eager_ckpt, resolved_ckpt)

  def test_load_t5x_checkpoint_of_multioptimizer_lazy(self):
    # Lazy versus eager loading of a checkpoint saved from the
    # multi-optimizer train state.
    self.validate_save(host_count=1, num_partitions=1, multi_optimizer=True)
    eager_ckpt = checkpoints.load_t5x_checkpoint(self.tmp_dir)
    lazy_ckpt = checkpoints.load_t5x_checkpoint(
        self.tmp_dir, lazy_parameters=True)
    resolved_ckpt = jax.tree_map(lambda leaf: leaf.get(), lazy_ckpt)
    jax.tree_multimap(np.testing.assert_array_equal, eager_ckpt, resolved_ckpt)
    # Jax tree maps skip None entries, so the None kernel state has to be
    # checked explicitly.
    self.assertIsNone(resolved_ckpt['state']['param_states']['kernel'])

  @parameterized.parameters(TOPOLOGIES)
  def test_save_restore_dataset(self, *topology):
    # The same topology is used for both phases because the number of replica
    # sets has to match between save and restore.
    host_count, num_partitions = topology
    self.validate_save(host_count, num_partitions, checkpoint_dataset=True)
    self.validate_restore(host_count, num_partitions, checkpoint_dataset=True)

  @parameterized.parameters(itertools.product(DTYPES, DTYPES))
  def test_save_as_type(self, save_dtype, restore_dtype):
    # Runs once per (save dtype, restore dtype) pairing of DTYPES on a single
    # host with a single partition.
    self.validate_save(host_count=1, num_partitions=1, save_dtype=save_dtype)
    self.validate_restore(
        host_count=1, num_partitions=1, expected_restore_dtype=restore_dtype)

  @parameterized.parameters(TOPOLOGIES)
  def test_reload_wrong_shape(self, *restore_topology):
    self.validate_save(host_count=1, num_partitions=1)

    # Replace the reference train state with one whose kernel is (4, 8)
    # instead of the saved (2, 16).
    reshaped_kernel = np.arange(32, dtype=np.float32).reshape((4, 8))
    self.train_state = make_train_state(
        step=np.int32(42),
        params={
            'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)),
            'kernel': reshaped_kernel
        },
        param_states={
            'bias': np.int32(1),
            'kernel': np.array([1, 2])
        })

    expected_message = (
        'Shape of `target/kernel` in checkpoint (2, 16) does not match '
        'expected (4, 8).')
    with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
      self.validate_restore(*restore_topology)

  @parameterized.parameters(TOPOLOGIES)
  def test_save_partitioned_restore_non_partitioned(self, *restore_topology):
    # Save with the default partitioning on 2 hosts and 2 partitions.
    self.validate_save(host_count=2, num_partitions=2)
    # Restore with partitioning disabled.
    host_count, num_partitions = restore_topology
    self.validate_restore(
        host_count, num_partitions, disable_partitioning=True)

  @parameterized.parameters(TOPOLOGIES)
  def test_save_non_partitioned_restore_partitioned(self, *restore_topology):
    # Save with partitioning disabled on 2 hosts and 1 partition.
    self.validate_save(
        host_count=2, num_partitions=1, disable_partitioning=True)
    # Restore with the default partitioning.
    host_count, num_partitions = restore_topology
    self.validate_restore(host_count, num_partitions)

  @parameterized.parameters(TOPOLOGIES)
  def test_save_non_partitioned_restore_non_partitioned(self,
                                                        *restore_topology):
    # Save with partitioning disabled on 2 hosts and 1 partition.
    self.validate_save(
        host_count=2, num_partitions=1, disable_partitioning=True)
    # Restore with partitioning disabled as well.
    host_count, num_partitions = restore_topology
    self.validate_restore(
        host_count, num_partitions, disable_partitioning=True)

  @mock.patch('time.time', return_value=0)
  def test_keep(self, unused_mock_time):
    """Exercises the `keep` limit, including changing it between saves."""
    partitioner = self.get_partitioner(0, 1, 1)
    base_state = self.train_state
    checkpointer = checkpoints.Checkpointer(
        base_state, partitioner, self.tmp_dir, keep=2)

    def save_and_expect(step, expected_steps):
      # Save a checkpoint at `step`, then compare the steps left on disk.
      checkpointer.save(update_train_state_step(base_state, step))
      self.assertSequenceEqual(checkpointer.all_steps(), expected_steps)

    # With keep=2 at most two checkpoints are expected after each save.
    save_and_expect(42, [42])
    save_and_expect(43, [42, 43])
    save_and_expect(44, [43, 44])

    # Lowering the limit is expected to take effect on the next save.
    checkpointer.keep = 1
    save_and_expect(45, [45])

    # Raising the limit: only the two existing checkpoints are expected.
    checkpointer.keep = 3
    save_and_expect(46, [45, 46])

  @mock.patch('time.time', return_value=0)
  def test_keep_pinned(self, unused_mock_time):
    """A checkpoint holding a `PINNED` file is expected to survive cleanup."""
    partitioner = self.get_partitioner(0, 1, 1)
    base_state = self.train_state
    checkpointer = checkpoints.Checkpointer(
        base_state, partitioner, self.tmp_dir, keep=1)

    def save_step(step):
      checkpointer.save(update_train_state_step(base_state, step))

    pinned_step = 42
    save_step(pinned_step)
    self.assertSequenceEqual(checkpointer.all_steps(), [pinned_step])

    # Pin the checkpoint by creating the `PINNED` marker file in its directory.
    pinned_dir = self.checkpoints_dir.mkdir(f'checkpoint_{pinned_step}')
    pinned_dir.create_file('PINNED')

    save_step(43)
    # Both the pinned and the most recent checkpoint are expected to remain.
    self.assertSequenceEqual(checkpointer.all_steps(), [pinned_step, 43])

    save_step(44)
    # The non-pinned checkpoint 43 is expected to be deleted, while the pinned
    # one and the most recent one remain.
    self.assertSequenceEqual(checkpointer.all_steps(), [pinned_step, 44])

  @mock.patch('time.time', return_value=0)
  def test_keep_with_save_best_checkpointer(self, unused_mock_time):
    """Exercises metric-based retention of `SaveBestCheckpointer` (keep=2)."""
    partitioner = self.get_partitioner(0, 1, 1)
    base_state = self.train_state

    checkpointer = checkpoints.SaveBestCheckpointer(
        base_state,
        partitioner,
        self.tmp_dir,
        keep=2,
        metric_name_to_monitor='train/accuracy',
        metric_mode='max',
        keep_checkpoints_without_metrics=False)

    def save_step(step):
      checkpointer.save(update_train_state_step(base_state, step))

    def assert_steps(expected_steps):
      self.assertSequenceEqual(checkpointer.all_steps(), expected_steps)

    # No metrics exist yet, so deletion is expected to fall back to the oldest
    # step (keep_checkpoints_without_metrics is False).
    for step, expected_steps in (
        (41, [41]),
        (42, [41, 42]),
        (43, [41, 42, 43]),
        (44, [42, 43, 44]),
    ):
      save_step(step)
      assert_steps(expected_steps)

    # Record accuracies for steps 42, 43 and 44.
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(self.tmp_dir, 'train'))
    for step, accuracy in ((42, 0.9), (43, 0.8), (44, 0.7)):
      summary_writer.scalar('accuracy', accuracy, step)

    # The newest checkpoint (no metric yet) and the best-accuracy ones are
    # expected to be kept.
    save_step(45)
    assert_steps([42, 43, 45])

    # Switch to `min`: the checkpoints with the highest accuracy are now the
    # ones expected to be removed.
    checkpointer._metric_mode = 'min'

    # Give step 45 a metric, then save 46 and 47 (46 also gets a metric).
    summary_writer.scalar('accuracy', 0.95, 45)
    save_step(46)
    summary_writer.scalar('accuracy', 0.99, 46)
    save_step(47)
    assert_steps([42, 43, 47])

  @mock.patch('time.time', return_value=0)
  def test_keep_pinned_save_best_checkpointer(self, unused_mock_time):
    """Checks `PINNED` handling in `SaveBestCheckpointer` with keep=2.

    A checkpoint is pinned by creating a `PINNED` file in its directory and
    unpinned by removing that file again; the expected surviving steps are
    asserted after each phase.
    """
    no_partitions_partitioner = self.get_partitioner(0, 1, 1)
    train_state = self.train_state

    checkpointer = checkpoints.SaveBestCheckpointer(
        train_state,
        no_partitions_partitioner,
        self.tmp_dir,
        keep=2,
        metric_name_to_monitor='train/accuracy',
        metric_mode='max',
        keep_checkpoints_without_metrics=False)

    summary_writer = tensorboard.SummaryWriter(
        os.path.join(self.tmp_dir, 'train'))

    # Each accuracy value is written right after its checkpoint is saved.
    checkpointer.save(update_train_state_step(train_state, 42))
    summary_writer.scalar('accuracy', 0.9, 42)
    checkpointer.save(update_train_state_step(train_state, 43))
    summary_writer.scalar('accuracy', 0.7, 43)
    checkpointer.save(update_train_state_step(train_state, 44))
    summary_writer.scalar('accuracy', 0.8, 44)
    self.assertSequenceEqual(checkpointer.all_steps(), [42, 43, 44])

    # Mark checkpoint 43 as always keep.
    ckpt_dir = self.checkpoints_dir.mkdir(f'checkpoint_{43}')
    always_keep_ckpt_43 = ckpt_dir.create_file('PINNED')

    # Verify that the pinned checkpoint 43 is always saved even though it does
    # not have the best metrics, and keep = 2.
    checkpointer.save(update_train_state_step(train_state, 45))
    self.assertSequenceEqual(checkpointer.all_steps(), [42, 43, 44, 45])
    checkpointer.save(update_train_state_step(train_state, 46))
    summary_writer.scalar('accuracy', 0.6, 46)

    # Remove the ALWAYS KEEP file for checkpoint 43.
    gfile.rmtree(always_keep_ckpt_43.full_path)

    # Checkpoint 43 should get deleted in the next update since it is not
    # pinned and does not have the best metrics.
    checkpointer.save(update_train_state_step(train_state, 47))
    self.assertSequenceEqual(checkpointer.all_steps(), [42, 44, 47])

  @mock.patch('time.time', return_value=0)
  def test_keep_pinned_save_best_checkpointer_missing_metrics(
      self, unused_mock_time):
    """Test for `keep_checkpoints_without_metrics` behavior with pinning."""
    partitioner = self.get_partitioner(0, 1, 1)
    base_state = self.train_state

    # `keep_checkpoints_without_metrics` is left at its default here.
    checkpointer = checkpoints.SaveBestCheckpointer(
        base_state,
        partitioner,
        self.tmp_dir,
        keep=1,
        metric_name_to_monitor='train/accuracy',
        metric_mode='max')

    def save_step(step):
      checkpointer.save(update_train_state_step(base_state, step))

    def assert_steps(expected_steps):
      self.assertSequenceEqual(checkpointer.all_steps(), expected_steps)

    # Write metrics up front, and only for some of the steps (41 and 42 get
    # none).
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(self.tmp_dir, 'train'))
    for step, accuracy in ((43, 0.5), (44, 0.4), (45, 0.8), (46, 0.3)):
      summary_writer.scalar('accuracy', accuracy, step)

    # Checkpoints 41 and 42 are expected to be kept even without metrics.
    for step in (41, 42, 43):
      save_step(step)
    assert_steps([41, 42, 43])

    # Pin checkpoints 41 and 43 so that they are not removed.
    for pinned_step in (41, 43):
      pinned_dir = self.checkpoints_dir.mkdir(f'checkpoint_{pinned_step}')
      pinned_dir.create_file('PINNED')

    # Checkpoints 41 and 43 should always be kept because they are pinned.
    save_step(44)
    assert_steps([41, 42, 43, 44])
    # Checkpoint 44 should get deleted on the next save. 43 is kept in spite
    # of its low accuracy because it is pinned.
    save_step(45)
    assert_steps([41, 42, 43, 45])

  @mock.patch('time.time', return_value=0)
  def test_save_best_checkpointer_from_restart(self, unused_mock_time):
    """Emulate restart/preempt condition."""
    partitioner = self.get_partitioner(0, 1, 1)
    base_state = self.train_state
    num_steps = 100

    # Start with a plain checkpointer that never deletes anything (keep=None).
    keep_all_checkpointer = checkpoints.Checkpointer(
        base_state, partitioner, self.tmp_dir, keep=None)

    # Many checkpoints are written on purpose to stress event collection
    # (some methods employ lossy/sampling collection).
    for step in range(num_steps):
      keep_all_checkpointer.save(update_train_state_step(base_state, step))
    self.assertSequenceEqual(keep_all_checkpointer.all_steps(),
                             list(range(num_steps)))

    # Write a metric for every step; steps 42 and 53 get a ten times larger
    # scale factor so they stand out as the best ones.
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(self.tmp_dir, 'train'))
    for step in range(num_steps):
      scale = 0.01 if step in (42, 53) else 0.001
      summary_writer.scalar('accuracy', step * scale, step)

    # Continue in the same directory with the SaveBest variant.
    save_best_checkpointer = checkpoints.SaveBestCheckpointer(
        base_state,
        partitioner,
        self.tmp_dir,
        keep=2,
        metric_name_to_monitor='train/accuracy',
        metric_mode='max')

    # The pre-existing metrics should be read and the appropriate checkpoints
    # deleted on the next save.
    save_best_checkpointer.save(update_train_state_step(base_state, 101))
    self.assertSequenceEqual(save_best_checkpointer.all_steps(), [42, 53, 101])

  def test_save_best_checkpointer_force_keep_period(self):
    no_partitions_partitioner = self.get_partitioner(0, 1, 1)
    train_state = self.train_state

    checkpointer = checkpoints.SaveBestCheckpointer(
        train_state,
        no_partitions_partitioner,
        self.tmp_dir,
        keep=2,
        metric_name_to_monitor='train/accuracy',
        metric_mode='max',
        keep_checkpoints_without_metrics=False,
        force_keep_period=3)

    summary_writer = tensorboard.SummaryWriter(
        os.path.join(self.tmp_dir, 'train'))

    # save checkpoints 0..9 with increasing accuracy
    dict_actual_steps = {}
    for c in range(10):
      checkpointer.save(update_train_state_step(train_state, c))
      summary_writer.scalar('accuracy', c / 100, c)
      dict_actual_steps[c] = checkpointer.all_steps()

    # Check when the last step=8 is not divisible by the keep_period=3
    actual_steps_8 = dict_actual_steps[8]
    expected_steps_8 = [0, 3, 5, 6, 7, 8]
    self.assertSequenceEqual(actual_steps_8, expected_steps_8)

    # Check when the last step=9 is divisible by the keep_period=3
    actual_steps_9 = dict_actual_steps[9]
    expected_steps_9 = [0, 3, 6, 7, 8, 9]
    self.assertSequenceEqual(actual_steps_9, expected_steps_9)

  @mock.patch('time.time', return_value=0)
  def test_save_best_checkpointer_missing_metrics(self, unused_mock_time):
    """Test for `keep_checkpoints_without_metrics` behavior."""
    partitioner = self.get_partitioner(0, 1, 1)
    base_state = self.train_state

    # `keep_checkpoints_without_metrics` is left at its default here.
    checkpointer = checkpoints.SaveBestCheckpointer(
        base_state,
        partitioner,
        self.tmp_dir,
        keep=1,
        metric_name_to_monitor='train/accuracy',
        metric_mode='max')

    def save_and_expect(step, expected_steps):
      checkpointer.save(update_train_state_step(base_state, step))
      self.assertSequenceEqual(checkpointer.all_steps(), expected_steps)

    # Write metrics up front, and only for steps 43-45.
    summary_writer = tensorboard.SummaryWriter(
        os.path.join(self.tmp_dir, 'train'))
    for step, accuracy in ((43, 0.6), (44, 0.5), (45, 0.4)):
      summary_writer.scalar('accuracy', accuracy, step)

    # Checkpoints 41 and 42 (no metrics) should always be kept; the number to
    # keep only applies to the other checkpoints.
    save_and_expect(41, [41])
    save_and_expect(42, [41, 42])
    save_and_expect(43, [41, 42, 43])
    save_and_expect(44, [41, 42, 43, 44])
    # Checkpoint 44 should get deleted on this save.
    save_and_expect(45, [41, 42, 43, 45])

    # After switching keep_checkpoints_without_metrics to False, checkpoints
    # 41 and 42 should be deleted as well.
    checkpointer._keep_checkpoints_without_metrics = False
    save_and_expect(46, [43, 46])

  def test_assignment_map(self):
    """Restores into a nested target through `apply_assignment_map`.

    The new target has `layer1`/`layer2` sub-trees; the assignment map sends
    `target/layer2/bias` to the checkpoint's `target/kernel` and every other
    `target/layer<N>/<name>` to `target/<name>`. The restored state is then
    compared against `self.train_state`.
    """
    self.validate_save(1, 1)
    # Change optimizer
    optimizer = optimizers.Optimizer(
        optimizers.sgd(0.1),
        state=optimizers.OptimizerState(
            step=np.int32(42),
            param_states={
                'bias': np.int32(1),
                'kernel': np.array([1, 2], np.uint8)
            }),
        target={
            'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)),
            'layer1': {
                'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)),
                'kernel': np.arange(32, dtype=np.float32).reshape((2, 16))
            },
            'layer2': {
                'bias': np.arange(32, dtype=np.float32).reshape((2, 16)),
                'kernel': np.arange(32, dtype=np.float32).reshape((2, 16))
            }
        })
    self.train_state = FlaxOptimTrainState(optimizer)

    # Every leaf of the train state gets `None` as its mesh axes.
    mesh_axes = jax.tree_util.tree_map(lambda x: None, self.train_state)
    actual_train_state = self.call_host_checkpointer(
        0,
        1,
        self.get_partitioner(0, 1, 1, mesh_axes=mesh_axes),
        lambda c: c.restore(  # pylint:disable=g-long-lambda
            step=42,
            state_transformation_fns=[
                functools.partial(
                    state_utils.apply_assignment_map,
                    assignment_map=[('target/layer2/bias', 'target/kernel'),
                                    ('target/layer\\d/(.*)', 'target/\\1')])
            ]),
        np.float32,
        None)
    self.assertEqual(actual_train_state.step, 42)
    self.assertEqual(actual_train_state._optimizer.optimizer_def,
                     self.train_state._optimizer.optimizer_def)
    # `jax.tree_multimap` is deprecated (and removed in newer JAX releases);
    # `jax.tree_util.tree_map` accepts several trees and is the replacement.
    jax.tree_util.tree_map(np.testing.assert_array_equal,
                           actual_train_state.param_states,
                           self.train_state.param_states)
    jax.tree_util.tree_map(np.testing.assert_array_equal,
                           actual_train_state.params, self.train_state.params)

  def test_assignment_map_unused(self):
    """An assignment-map pattern that matches nothing raises ValueError."""
    self.validate_save(1, 1)

    transform = functools.partial(
        state_utils.apply_assignment_map,
        assignment_map=[('target/layer\\d/(.*)', 'target/\\1')])

    def restore_with_map(checkpointer):
      return checkpointer.restore(
          step=42, state_transformation_fns=[transform])

    expected_message = (
        "Unused patterns in `assignment_map`: {'target/layer\\d/(.*)'}")
    with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
      self.call_host_checkpointer(0, 1, self.get_partitioner(0, 1, 1),
                                  restore_with_map, np.float32, None)

  def test_assignment_map_noexists(self):
    """Mapping onto a name missing from the checkpoint raises ValueError."""
    self.validate_save(1, 1)

    # Every `target/<name>` is redirected to `target/layer/<name>`; the
    # expected error reports `target/layer/bias` as missing.
    transform = functools.partial(
        state_utils.apply_assignment_map,
        assignment_map=[('target/(.*)', 'target/layer/\\1')])

    def restore_with_map(checkpointer):
      return checkpointer.restore(
          step=42, state_transformation_fns=[transform])

    expected_message = (
        "Parameter 'target/layer/bias' does not exist in restore checkpoint. "
        "Must be one of: ['state/param_states/bias', "
        "'state/param_states/kernel', 'state/step', 'target/bias', "
        "'target/kernel']")
    with self.assertRaisesWithLiteralMatch(ValueError, expected_message):
      self.call_host_checkpointer(0, 1, self.get_partitioner(0, 1, 1),
                                  restore_with_map, np.float32, None)

  def test_assignment_map_partial_restore(self):
    """Restores only the kernels and takes the rest from `fallback_state`.

    The assignment map sends `target/layer<N>/kernel` to the checkpoint's
    `target/kernel` and maps all biases and the whole optimizer state to
    `None`; those entries are supplied through `fallback_state` instead. The
    restored step is therefore expected to be the fallback value 1337 rather
    than 42.
    """
    self.validate_save(1, 1)
    # Change optimizer
    optimizer = optimizers.Optimizer(
        optimizers.sgd(0.1),
        state=optimizers.OptimizerState(
            step=np.int32(42),
            param_states={
                'bias': np.int32(1),
                'kernel': np.array([1, 2], np.uint8)
            }),
        target={
            'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)),
            'layer1': {
                'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)),
                'kernel': np.arange(32, dtype=np.float32).reshape((2, 16))
            },
            'layer2': {
                'bias': np.arange(32, dtype=np.float32).reshape((2, 16)),
                'kernel': np.arange(32, dtype=np.float32).reshape((2, 16))
            }
        })
    self.train_state = FlaxOptimTrainState(optimizer)

    # Every leaf of the train state gets `None` as its mesh axes.
    mesh_axes = jax.tree_util.tree_map(lambda x: None, self.train_state)
    actual_train_state = self.call_host_checkpointer(
        0,
        1,
        self.get_partitioner(0, 1, 1, mesh_axes=mesh_axes),
        lambda c: c.restore(  # pylint:disable=g-long-lambda
            step=42,
            state_transformation_fns=[
                functools.partial(
                    state_utils.apply_assignment_map,
                    assignment_map=[
                        # Restore only the target kernels.
                        (r'target/layer(\d+)/kernel', r'target/kernel'),
                        (r'target.*bias', None),
                        (r'state.*', None)])
            ],
            fallback_state={
                # Initialize biases and optimizer state "from scratch"
                'target': {
                    'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)),
                    'layer1': {
                        'bias': np.arange(4, dtype=jnp.bfloat16).reshape((4, 1)),
                    },
                    'layer2': {
                        'bias': np.arange(32, dtype=np.float32).reshape((2, 16)),
                    }
                },
                'state': {
                    'step': 1337,  # Note: original optimizer is step=42
                    'param_states': {
                        'bias': 1,
                        'kernel': np.array([1, 2], np.uint8)
                    }
                }
            }),
        np.float32,
        None)
    self.assertEqual(actual_train_state._optimizer.optimizer_def,
                     self.train_state._optimizer.optimizer_def)
    self.assertEqual(actual_train_state.step, 1337)  # note: from-scratch
    # `jax.tree_multimap` is deprecated (and removed in newer JAX releases);
    # `jax.tree_util.tree_map` accepts several trees and is the replacement.
    jax.tree_util.tree_map(np.testing.assert_array_equal,
                           actual_train_state.param_states,
                           self.train_state.param_states)
    jax.tree_util.tree_map(np.testing.assert_array_equal,
                           actual_train_state.params, self.train_state.params)

  def verify_restore_checkpoint_from_path(
      self,
      path,
      model,
      decoder_only=False,
      partitioner_class=partitioning.PjitPartitioner):
    """Restores the checkpoint at `path` for `model` and returns the state.

    Args:
      path: checkpoint path, restored with `mode='specific'`.
      model: object providing `optimizer_def` and `get_initial_variables`.
      decoder_only: if True, no `encoder_input_tokens` feature is created.
      partitioner_class: partitioner type, built with `num_partitions=1`.

    Returns:
      The single train state yielded by `from_checkpoints`.
    """
    partitioner = partitioner_class(num_partitions=1)

    # Dummy [2, 8] features; their shapes are what the initializer is given.
    feature_names = ['decoder_input_tokens']
    if not decoder_only:
      feature_names.append('encoder_input_tokens')
    input_features = {name: tf.zeros([2, 8]) for name in feature_names}
    train_ds = tf.data.Dataset.from_tensors(input_features)
    input_shapes = {
        name: spec.shape for name, spec in train_ds.element_spec.items()
    }

    train_state_initializer = utils.TrainStateInitializer(
        optimizer_def=model.optimizer_def,
        init_fn=model.get_initial_variables,
        input_shapes=input_shapes,
        partitioner=partitioner)

    restore_cfg = utils.RestoreCheckpointConfig(mode='specific', path=path)
    restored = list(train_state_initializer.from_checkpoints([restore_cfg]))
    self.assertLen(restored, 1)
    return restored[0]

  def test_checkpointer_in_threaded_env(self):
    """Tests use of asyncio in checkpointer works with non-main threads."""
    executor = concurrent.futures.thread.ThreadPoolExecutor(max_workers=1)
    save = executor.submit(self.validate_save, 1, 1)
    save.result()
    restore = executor.submit(self.validate_restore, 1, 1)
    restore.result()

  def test_find_checkpoint(self):
    """Checks `find_checkpoint` for model dirs, checkpoint dirs and files."""
    step = 100
    step_dir = os.path.join(self.fake_checkpoints, f'checkpoint_{step}')
    step_file = os.path.join(step_dir, 'checkpoint')

    # `model_dir` with no step
    self.assertEqual(
        checkpoints.find_checkpoint(self.fake_checkpoints),
        os.path.join(self.fake_checkpoints, f'checkpoint_{self.steps[-1]}',
                     'checkpoint'))
    # `model_dir` with step
    self.assertEqual(
        checkpoints.find_checkpoint(self.fake_checkpoints, step), step_file)
    # checkpoint_dir
    self.assertEqual(checkpoints.find_checkpoint(step_dir), step_file)
    # checkpoint_dir with step
    with self.assertRaises(ValueError):
      # The original statement ended in a stray comma, which wrapped the
      # result in a throwaway tuple; the call alone is what is under test.
      checkpoints.find_checkpoint(step_dir, 1000)
    # checkpoint_file
    self.assertEqual(checkpoints.find_checkpoint(step_file), step_file)
    # checkpoint_file with step
    self.assertEqual(checkpoints.find_checkpoint(step_file, 1000), step_file)
    # Error with step
    with self.assertRaises(ValueError):
      checkpoints.find_checkpoint(self.fake_checkpoints, 1000)
    # Error
    with self.assertRaises(ValueError):
      checkpoints.find_checkpoint(
          os.path.join(self.fake_checkpoints, 'checkpoint'))

  def test_restore_tf_as_t5x(self):
    """Restoring a TensorFlow checkpoint via `restore` raises ValueError."""
    tf_checkpoint_dir = os.path.join(TESTDATA, 'mtf_tiny_t5')
    partitioner = self.get_partitioner(0, 1, 1)

    def restore_from_path(checkpointer):
      return checkpointer.restore(path=tf_checkpoint_dir)

    expected_regex = (
        'Attempting to restore a TensorFlow checkpoint as a native T5X '
        'checkpoint. Use `restore_from_tf_checkpoint` instead. Path: .*')
    with self.assertRaisesRegex(ValueError, expected_regex):
      self.call_host_checkpointer(0, 1, partitioner, restore_from_path,
                                  np.float32, None)

  def test_restore_from_invalid_path(self):
    """Paths that are not T5X checkpoints are rejected with ValueError."""
    expected_regex = r'Path is not a valid T5X checkpoint: .*'
    # First the testdata directory itself, then a `checkpoint` entry directly
    # beneath it.
    for invalid_path in (TESTDATA, os.path.join(TESTDATA, 'checkpoint')):
      with self.assertRaisesRegex(ValueError, expected_regex):
        self.verify_restore_checkpoint_from_path(
            invalid_path, test_utils.get_t5_test_model())

  def test_save_lazy_optimizer(self):
    """Saves a train state restored with `lazy_parameters=True`.

    The state is restored lazily, its step is incremented, it is saved again,
    and the saved checkpoint is compared with the resolved lazy values.
    """
    # Call save once to get the parameters onto disk
    self.validate_save(1, 1)
    # Load the parameters in a lazy way
    partitioner = self.get_partitioner(0, 1, 1, params_on_devices=False)
    step = 42
    train_state = self.call_host_checkpointer(
        0,
        1,
        partitioner,
        lambda c: c.restore(  # pylint: disable=g-long-lambda
            step=step, lazy_parameters=True),
        np.float32,
        None)
    # Increment the step so we can save it
    new_step = train_state.step.get() + 1
    state_dict = train_state.state_dict()
    state_dict['state']['step'] = new_step
    train_state = train_state.restore_state(state_dict)

    # Save the train state that is made of lazy parameters.
    self.call_host_checkpointer(
        0, 1, partitioner,
        lambda c: c.save(train_state=train_state, concurrent_gb=2), np.float32,
        None)

    # Load what we just saved to inspect values
    loaded_train_state = checkpoints.load_t5x_checkpoint(
        self.tmp_dir, step=new_step)
    # Resolve any remaining `LazyArray` leaves so values can be compared.
    train_state = jax.tree_util.tree_map(
        lambda x: x.get()  # pylint: disable=g-long-lambda
        if isinstance(x, LazyArray) else x,
        train_state)
    # Make sure the parameters are the same. `jax.tree_multimap` is deprecated
    # (removed in newer JAX releases); `jax.tree_util.tree_map` accepts several
    # trees and is the replacement.
    jax.tree_util.tree_map(np.testing.assert_allclose,
                           train_state.state_dict(), loaded_train_state)

  def test_update_ts_from_gfile_to_gcs(self):
    """Compares `_maybe_update_ts_from_file_to_gcs` output to a GCS spec.

    The input holds one ndarray leaf and one TensorStore spec with a `file`
    kvstore; the expected result keeps the ndarray and has a `gcs` kvstore
    with the path moved to the spec's top level.
    """
    ckpt_contents = {
        'version': 3,
        'optimizer': {
            'target': {
                'unsharded_param': np.ones((5, 5), dtype=np.int32),
                'sharded_param': {
                    'driver': 'zarr',
                    'dtype': 'float32',
                    'kvstore': {
                        'driver': 'file',
                        'path': 'target.sharded_param'
                    },
                    'metadata': {
                        'chunks': [768, 768],
                        'compressor': {
                            'id': 'gzip',
                            'level': 1
                        },
                        'shape': [768, 768]
                    }
                }
            }
        }
    }

    expected = {
        'version': 3,
        'optimizer': {
            'target': {
                # np.ndarray should not change
                'unsharded_param': np.ones((5, 5), dtype=np.int32),
                'sharded_param': {
                    'driver': 'zarr',
                    'dtype': 'float32',
                    'kvstore': {
                        'bucket': 't5x-dummy-bucket',
                        'driver': 'gcs'
                    },
                    'metadata': {
                        'chunks': [768, 768],
                        'compressor': {
                            'id': 'gzip',
                            'level': 1
                        },
                        'shape': [768, 768]
                    },
                    'path': 'target.sharded_param',
                }
            }
        }
    }
    actual = checkpoints._maybe_update_ts_from_file_to_gcs(ckpt_contents)
    # `jax.tree_multimap` is deprecated (removed in newer JAX releases);
    # `jax.tree_util.tree_map` accepts several trees and is the replacement.
    jax.tree_util.tree_map(np.testing.assert_array_equal, actual, expected)

  def test_update_ts_from_gcs_to_file(self):
    """Compares `_maybe_update_ts_from_gcs_to_file` output to a file spec.

    The input holds one ndarray leaf and one TensorStore spec with a `gcs`
    kvstore; the expected result keeps the ndarray and has a `file` kvstore
    that carries the path.
    """
    ckpt_contents = {
        'version': 3,
        'optimizer': {
            'target': {
                # np.ndarray should not change
                'unsharded_param': np.ones((5, 5), dtype=np.int32),
                'sharded_param': {
                    'driver': 'zarr',
                    'dtype': 'float32',
                    'kvstore': {
                        'bucket': 't5x-dummy-bucket',
                        'driver': 'gcs'
                    },
                    'metadata': {
                        'chunks': [768, 768],
                        'compressor': {
                            'id': 'gzip',
                            'level': 1
                        },
                        'shape': [768, 768]
                    },
                    'path': 'target.sharded_param',
                }
            }
        }
    }

    expected = {
        'version': 3,
        'optimizer': {
            'target': {
                'unsharded_param': np.ones((5, 5), dtype=np.int32),
                'sharded_param': {
                    'driver': 'zarr',
                    'dtype': 'float32',
                    'kvstore': {
                        'driver': 'file',
                        'path': 'target.sharded_param'
                    },
                    'metadata': {
                        'chunks': [768, 768],
                        'compressor': {
                            'id': 'gzip',
                            'level': 1
                        },
                        'shape': [768, 768]
                    }
                }
            }
        }
    }

    actual = checkpoints._maybe_update_ts_from_gcs_to_file(ckpt_contents)
    # `jax.tree_multimap` is deprecated (removed in newer JAX releases);
    # `jax.tree_util.tree_map` accepts several trees and is the replacement.
    jax.tree_util.tree_map(np.testing.assert_array_equal, actual, expected)

  def assert_update_ts_path_from_relative_to_absolute(self, ts_spec_dict,
                                                      expected, ckpt_dir):
    """Tests that `ts_spec_dict` gets updated with `ckpt_dir` to `expected`.

    Args:
      ts_spec_dict: TensorStore spec dict; it is passed to the updater
        directly in the second half of this check, so it may be modified.
      expected: the spec dict expected after the update.
      ckpt_dir: checkpoint directory handed to the updater.
    """
    # `jax.tree_multimap` is deprecated (removed in newer JAX releases);
    # `jax.tree_util.tree_map` accepts several trees and is the replacement.
    assert_trees_equal = functools.partial(jax.tree_util.tree_map,
                                           np.testing.assert_array_equal)

    # Test with normalization (corresponds to tensorstore>=0.1.14)
    normalized_ts_spec_dict = ts.Spec(ts_spec_dict).to_json()
    checkpoints._update_ts_path_from_relative_to_absolute(
        ckpt_dir, normalized_ts_spec_dict)
    # Round-trip both sides through `ts.Spec` so they are compared in the same
    # JSON form.
    normalized_ts_spec_dict = ts.Spec(normalized_ts_spec_dict).to_json()
    normalized_expected = ts.Spec(expected).to_json()
    assert_trees_equal(normalized_ts_spec_dict, normalized_expected)

    # Test without normalization (corresponds to tensorstore<0.1.14)
    checkpoints._update_ts_path_from_relative_to_absolute(
        ckpt_dir, ts_spec_dict)
    assert_trees_equal(ts_spec_dict, expected)

  def test_update_ts_path_from_relative_to_absolute_gfile(self):
    """Checks the relative-to-absolute path update for a `file` kvstore."""
    # Input spec: the kvstore path is relative.
    ts_spec_dict = {
        'driver': 'zarr',
        'dtype': 'float32',
        'kvstore': {
            'driver': 'file',
            'path': 'target.encoder.layers_0.attention.query.kernel'
        },
        'metadata': {
            'chunks': [768, 768],
            'compressor': {
                'id': 'gzip',
                'level': 1
            },
            'shape': [768, 768]
        }
    }

    # Expected spec: identical except for the kvstore path.
    expected = {
        'driver': 'zarr',
        'dtype': 'float32',
        'kvstore': {
            'driver': 'file',
            # Path becomes absolute.
            'path': '/dir1/dir2/target.encoder.layers_0.attention.query.kernel'
        },
        'metadata': {
            'chunks': [768, 768],
            'compressor': {
                'id': 'gzip',
                'level': 1
            },
            'shape': [768, 768]
        }
    }
    ckpt_dir = '/dir1/dir2'

    self.assert_update_ts_path_from_relative_to_absolute(
        ts_spec_dict, expected, ckpt_dir)

  def test_update_ts_path_from_relative_to_absolute_gcs(self):
    """Checks the relative-to-absolute path update for a `gcs` kvstore."""
    # Input spec: the top-level path is relative and the bucket is a dummy.
    ts_spec_dict = {
        'driver': 'zarr',
        'dtype': 'float32',
        'kvstore': {
            'bucket': 't5x-dummy-bucket',
            'driver': 'gcs'
        },
        'metadata': {
            'chunks': [768, 768],
            'compressor': {
                'id': 'gzip',
                'level': 1
            },
            'shape': [768, 768]
        },
        'path': 'target.encoder.layers_0.attention.query.kernel',
        'transform': {
            'input_exclusive_max': [[768], [768]],
            'input_inclusive_min': [0, 0]
        }
    }

    # Expected spec: bucket and path are derived from `ckpt_dir` below.
    expected = {
        'driver': 'zarr',
        'dtype': 'float32',
        'kvstore': {
            'bucket': 'test-bucket',  # bucket should be changed.
            'driver': 'gcs'
        },
        'metadata': {
            'chunks': [768, 768],
            'compressor': {
                'id': 'gzip',
                'level': 1
            },
            'shape': [768, 768]
        },
        # Path becomes absolute without the "gs://bucket" portion stripped.
        'path': 'dir1/dir2/target.encoder.layers_0.attention.query.kernel',
        'transform': {
            'input_exclusive_max': [[768], [768]],
            'input_inclusive_min': [0, 0]
        }
    }

    ckpt_dir = 'gs://test-bucket/dir1/dir2'

    self.assert_update_ts_path_from_relative_to_absolute(
        ts_spec_dict, expected, ckpt_dir)

  def test_restore_tf_checkpoint(self):
    """Restores the `mtf_tiny_t5` TensorFlow checkpoint into a test model."""
    tf_checkpoint = os.path.join(TESTDATA, 'mtf_tiny_t5/model.ckpt-0')
    model = test_utils.get_t5_test_model(
        emb_dim=32, head_dim=64, num_heads=2, mlp_dim=64)
    self.verify_restore_checkpoint_from_path(tf_checkpoint, model)

  def test_restore_tf_checkpoint_wrong_config(self):
    """Restoring with the default test model config hits a shape mismatch."""
    expected_regex = r'Variable .* has shape .* != .*'
    with self.assertRaisesRegex(ValueError, expected_regex):
      tf_checkpoint = os.path.join(TESTDATA, 'mtf_tiny_t5/model.ckpt-0')
      default_model = test_utils.get_t5_test_model()
      self.verify_restore_checkpoint_from_path(tf_checkpoint, default_model)

  def test_convert_tf_checkpoint(self):
    """Runs `convert_from_tf_checkpoint` on the `mtf_tiny_t5` checkpoint."""
    tf_checkpoint = os.path.join(TESTDATA, 'mtf_tiny_t5/model.ckpt-0')

    # Minimal setup to create an optimizer with the matching config.
    model = test_utils.get_t5_test_model(
        emb_dim=32, head_dim=64, num_heads=2, mlp_dim=64)
    partitioner = partitioning.PjitPartitioner(num_partitions=1)
    input_shapes = {
        'encoder_input_tokens': (2, 512),
        'decoder_input_tokens': (2, 114),
    }

    def init_train_state(rng):
      variables = model.get_initial_variables(
          rng=rng, input_shapes=input_shapes)
      return FlaxOptimTrainState.create(model.optimizer_def, variables)

    # `jax.eval_shape` evaluates the initializer abstractly, so only the
    # shapes and dtypes of the train state are produced.
    abstract_train_state = jax.eval_shape(init_train_state,
                                          jax.random.PRNGKey(0))
    checkpointer = checkpoints.Checkpointer(abstract_train_state, partitioner,
                                            self.tmp_dir)
    # The result is discarded; the test only requires the call to succeed.
    checkpointer.convert_from_tf_checkpoint(tf_checkpoint)

  def test_load_matched(self):
    """`load_t5x_checkpoint` output matches the restored optimizer state."""
    checkpoint = os.path.join(TESTDATA, 'test_t5_tiny.checkpoint_0')
    train_state = self.verify_restore_checkpoint_from_path(
        checkpoint, test_utils.get_t5_test_model())
    state_dict = train_state._optimizer.state_dict()
    ckpt = checkpoints.load_t5x_checkpoint(checkpoint)
    # `jax.tree_multimap` is deprecated (removed in newer JAX releases);
    # `jax.tree_util.tree_map` accepts several trees and is the replacement.
    jax.tree_util.tree_map(np.testing.assert_array_equal, state_dict, ckpt)



# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
