# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is licensed under a Creative Commons
# Attribution-NonCommercial-ShareAlike 4.0 International License.
# You should have received a copy of the license along with this
# work. If not, see http://creativecommons.org/licenses/by-nc-sa/4.0/

"""Generate random images using the techniques described in the paper
"Elucidating the Design Space of Diffusion-Based Generative Models"."""

import os
import re
import click
import tqdm
import pickle
import numpy as np
import torch
import PIL.Image
import dnnlib
from torch_utils import distributed as dist
from torchvision.utils import make_grid, save_image
from torch.distributions import Beta
import glob
from torch_utils import misc
import wandb
#----------------------------------------------------------------------------
# EDM sampler (Algorithm 2 of the paper), extended with PFGM++ latents and restart sampling.

def parse_restart_info(restart_info):
    """Parse a ``--restart_info`` JSON string into ``{int_step_index: config}``.

    Format: a JSON object ``{restart_index: [number of restart iterations,
    restart_max_time, num_steps], ...}``, e.g. ``'{"12": [2, 5, 3]}'``.

    Empty/``None`` input and falsy JSON values (``0``, ``null``, ``{}``, ``[]``)
    mean "no restarts" and return ``{}``.  This matters because the CLI default
    is ``'0'`` and ``json.loads('0')`` is the int ``0``, which has no ``.items()``.

    Raises:
        ValueError: if the string is not valid JSON, or decodes to a non-empty
            value that is not a JSON object.
    """
    import json

    if not restart_info:
        return {}
    parsed = json.loads(restart_info)
    if not parsed:
        return {}
    if not isinstance(parsed, dict):
        raise ValueError(f'restart_info must be a JSON object, got {restart_info!r}')
    # JSON object keys are always strings; the sampler indexes schedules with ints.
    return {int(k): v for k, v in parsed.items()}


def edm_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=0.002, sigma_max=80, rho=7,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=0,
    pfgmpp=False, ensemble=1, inter_steps=18, restart=0, restart_value=-1, restart_info="", restart_gamma=0
):
    """EDM Heun sampler (Algorithm 2) with optional PFGM++ latents and restarts.

    Integrates on the EDM power-``rho`` schedule from the (clipped) ``sigma_max``
    down to ``t_N = 0``.  Heun's 2nd-order correction is used on every main step
    except the last.  After a main step lands on an index listed in
    ``restart_info``, the state is re-noised up to that restart's max time and
    re-integrated back down to the same level.  This is repeated the configured
    number of times.

    Args:
        net: denoiser exposing ``img_channels``, ``img_resolution``,
            ``sigma_min``, ``sigma_max``, ``round_sigma`` and
            ``net(x, sigma, class_labels)``.  ``net.D`` is also read when
            ``pfgmpp`` is True and restarts are configured.
        latents: initial noise; multiplied by ``t_0`` unless ``pfgmpp``.
        class_labels: forwarded to ``net`` unchanged.
        randn_like: noise source for churn and for non-PFGM++ restart noise.
        num_steps, sigma_min, sigma_max, rho: EDM time discretization.  The
            sigma range is clipped to what ``net`` supports.
        S_churn, S_min, S_max, S_noise: EDM stochasticity.  ``S_noise`` also
            scales the restart noise, so with the default 0 restarts inject no
            noise.
        pfgmpp: use ``latents`` as-is (presumably already drawn at full PFGM++
            scale by the caller) and use PFGM++ noise for restarts.
        restart_info: JSON string (see ``parse_restart_info``).  Keys index an
            18-step schedule and are mapped to the nearest index of the current
            ``num_steps`` schedule.
        restart_gamma: churn factor applied inside the restart loops (only where
            ``S_min <= t <= S_max``).
        ensemble, inter_steps, restart, restart_value: accepted for interface
            compatibility; not used by this function.

    Returns:
        float64 tensor with the same shape as ``latents``.
    """

    def get_steps(min_t, max_t, num_steps, rho):
        # EDM power-rho schedule of num_steps points from max_t down to min_t (no trailing zero).
        step_indices = torch.arange(num_steps, dtype=torch.float, device=latents.device)
        t_steps = (max_t ** (1 / rho) + step_indices / (num_steps - 1) * (min_t ** (1 / rho) - max_t ** (1 / rho))) ** rho
        return t_steps

    N = net.img_channels * net.img_resolution * net.img_resolution
    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Time step discretization.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    t_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (
                sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    t_steps = torch.cat([net.round_sigma(t_steps), torch.zeros_like(t_steps[:1])])  # t_N = 0

    if pfgmpp:
        x_next = latents.to(torch.float64)
    else:
        x_next = latents.to(torch.float64) * t_steps[0]

    # {restart_index : [number of restart iteration, restart_max_time, num_steps], ... }
    # Indices refer to an 18-step schedule; remap each to the closest step of t_steps.
    restart_list = parse_restart_info(restart_info)
    old_steps = get_steps(sigma_min, sigma_max, 18, rho)
    restart_list = {int(torch.argmin(abs(t_steps - old_steps[k]), dim=0)): v for k, v in restart_list.items()}
    dist.print0(f"list:{restart_list}")

    # Main sampling loop.
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])):  # 0, ..., N-1

        x_cur = x_next
        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= t_cur <= S_max else 0
        t_hat = net.round_sigma(t_cur + gamma * t_cur)
        x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)

        # Euler step.
        denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
        d_cur = (x_hat - denoised) / t_hat
        x_next = x_hat + (t_next - t_hat) * d_cur

        # Apply 2nd order correction (skipped on the final step, where t_next == 0).
        if i < num_steps - 1:
            denoised = net(x_next, t_next, class_labels).to(torch.float64)
            d_prime = (x_next - denoised) / t_next
            x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

        # ================= restart ================== #
        restart_idx = i + 1
        if restart_idx not in restart_list:
            continue
        restart_cfg = restart_list[restart_idx]

        for restart_iter in range(restart_cfg[0]):

            # TODO tuning max_t and number_steps
            new_t_steps = get_steps(min_t=t_steps[restart_idx], max_t=restart_cfg[1],
                                    num_steps=restart_cfg[2], rho=rho)
            new_total_step = len(new_t_steps)

            # Forward-noise from the current level back up to the restart's max time.
            if pfgmpp:
                # TODO: modify with min_t != 0
                # NOTE(review): the radius is scaled by the lower end t_steps[restart_idx], whereas the
                # Gaussian branch uses sqrt(max_t^2 - min_t^2) -- confirm this is intended.
                beta_gen = Beta(torch.FloatTensor([N / 2.]), torch.FloatTensor([net.D / 2.]))
                sample_norm = beta_gen.sample(torch.Size([len(x_next)])).to(x_next.device).double()
                # inverse beta distribution
                inverse_beta = sample_norm / (1 - sample_norm)

                sample_norm = torch.sqrt(inverse_beta) * t_steps[restart_idx] * np.sqrt(net.D)
                # Draw an independent unit direction for every sample in the batch.
                gaussian = torch.randn(len(x_next), N).to(sample_norm.device)
                unit_gaussian = gaussian / torch.norm(gaussian, p=2, dim=1, keepdim=True)
                init_sample = unit_gaussian * sample_norm
                x_next = x_next + init_sample.view_as(x_next) * S_noise
            else:
                x_next = x_next + randn_like(x_next) * (new_t_steps[0] ** 2 - new_t_steps[-1] ** 2).sqrt() * S_noise

            dist.print0(new_t_steps)

            # Re-integrate from max_t back down to t_steps[restart_idx].
            for j, (t_cur, t_next) in enumerate(zip(new_t_steps[:-1], new_t_steps[1:])):  # 0, ..., N-1

                x_cur = x_next
                gamma = restart_gamma if S_min <= t_cur <= S_max else 0
                t_hat = net.round_sigma(t_cur + gamma * t_cur)

                x_hat = x_cur + (t_hat ** 2 - t_cur ** 2).sqrt() * S_noise * randn_like(x_cur)
                denoised = net(x_hat, t_hat, class_labels).to(torch.float64)
                d_cur = (x_hat - denoised) / (t_hat)
                x_next = x_hat + (t_next - t_hat) * d_cur

                # Apply 2nd order correction unless this is the last step and it ends at t = 0.
                if j < new_total_step - 2 or new_t_steps[-1] != 0:
                    denoised = net(x_next, t_next, class_labels).to(torch.float64)
                    d_prime = (x_next - denoised) / t_next
                    x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

    return x_next


#----------------------------------------------------------------------------
# Generalized ablation sampler, representing the superset of all sampling
# methods discussed in the paper.

def ablation_sampler(
    net, latents, class_labels=None, randn_like=torch.randn_like,
    num_steps=18, sigma_min=None, sigma_max=None, rho=7,
    solver='heun', discretization='edm', schedule='linear', scaling='none',
    epsilon_s=1e-3, C_1=0.001, C_2=0.008, M=1000, alpha=1,
    S_churn=0, S_min=0, S_max=float('inf'), S_noise=1,
    pfgmpp=False, ensemble=1, inter_steps=None, restart=None, restart_value=None,
    restart_info=None, restart_gamma=None,
):
    """Generalized EDM ablation sampler covering the paper's sampling variants.

    Args:
        net: denoiser exposing ``sigma_min``, ``sigma_max``, ``round_sigma`` and
            ``net(x, sigma, class_labels)`` returning the denoised estimate.
        latents: presumably standard-normal noise; scaled by ``sigma(t_0) * s(t_0)``.
        solver: ``'euler'`` or ``'heun'`` (the last step is always Euler).
        discretization: time-step spacing, one of ``'vp' | 've' | 'iddpm' | 'edm'``.
        schedule: noise schedule ``sigma(t)``, one of ``'vp' | 've' | 'linear'``.
        scaling: signal scaling ``s(t)``, ``'vp'`` or ``'none'``.
        S_churn, S_min, S_max, S_noise: EDM stochasticity parameters.
        pfgmpp: PFGM++ latents are not supported here; passing True raises.
        ensemble, inter_steps, restart, restart_value, restart_info, restart_gamma:
            ``edm_sampler``-only options.  Accepted so a caller can forward the
            same keyword set to either sampler; ignored here.

    Returns:
        float64 tensor with the same shape as ``latents``.

    Raises:
        NotImplementedError: if ``pfgmpp`` is True.
        AssertionError: on an unknown solver/discretization/schedule/scaling.
    """
    if pfgmpp:
        raise NotImplementedError('ablation_sampler does not support PFGM++ latents; use edm_sampler')
    assert solver in ['euler', 'heun']
    assert discretization in ['vp', 've', 'iddpm', 'edm']
    assert schedule in ['vp', 've', 'linear']
    assert scaling in ['vp', 'none']

    # Helper functions for VP & VE noise level schedules.
    vp_sigma = lambda beta_d, beta_min: lambda t: (np.e ** (0.5 * beta_d * (t ** 2) + beta_min * t) - 1) ** 0.5
    # Self-contained: uses the VP sigma with the same betas rather than an outer `sigma` bound later.
    vp_sigma_deriv = lambda beta_d, beta_min: lambda t: 0.5 * (beta_min + beta_d * t) * (vp_sigma(beta_d, beta_min)(t) + 1 / vp_sigma(beta_d, beta_min)(t))
    vp_sigma_inv = lambda beta_d, beta_min: lambda sigma: ((beta_min ** 2 + 2 * beta_d * (sigma ** 2 + 1).log()).sqrt() - beta_min) / beta_d
    ve_sigma = lambda t: t.sqrt()
    ve_sigma_deriv = lambda t: 0.5 / t.sqrt()
    ve_sigma_inv = lambda sigma: sigma ** 2

    # Select default noise level range based on the specified time step discretization.
    if sigma_min is None:
        vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=epsilon_s)
        sigma_min = {'vp': vp_def, 've': 0.02, 'iddpm': 0.002, 'edm': 0.002}[discretization]
    if sigma_max is None:
        vp_def = vp_sigma(beta_d=19.1, beta_min=0.1)(t=1)
        sigma_max = {'vp': vp_def, 've': 100, 'iddpm': 81, 'edm': 80}[discretization]

    # Adjust noise levels based on what's supported by the network.
    sigma_min = max(sigma_min, net.sigma_min)
    sigma_max = min(sigma_max, net.sigma_max)

    # Compute corresponding betas for VP.
    vp_beta_d = 2 * (np.log(sigma_min ** 2 + 1) / epsilon_s - np.log(sigma_max ** 2 + 1)) / (epsilon_s - 1)
    vp_beta_min = np.log(sigma_max ** 2 + 1) - 0.5 * vp_beta_d

    # Define time steps in terms of noise level.
    step_indices = torch.arange(num_steps, dtype=torch.float64, device=latents.device)
    if discretization == 'vp':
        orig_t_steps = 1 + step_indices / (num_steps - 1) * (epsilon_s - 1)
        sigma_steps = vp_sigma(vp_beta_d, vp_beta_min)(orig_t_steps)
    elif discretization == 've':
        orig_t_steps = (sigma_max ** 2) * ((sigma_min ** 2 / sigma_max ** 2) ** (step_indices / (num_steps - 1)))
        sigma_steps = ve_sigma(orig_t_steps)
    elif discretization == 'iddpm':
        u = torch.zeros(M + 1, dtype=torch.float64, device=latents.device)
        alpha_bar = lambda j: (0.5 * np.pi * j / M / (C_2 + 1)).sin() ** 2
        for j in torch.arange(M, 0, -1, device=latents.device): # M, ..., 1
            u[j - 1] = ((u[j] ** 2 + 1) / (alpha_bar(j - 1) / alpha_bar(j)).clip(min=C_1) - 1).sqrt()
        u_filtered = u[torch.logical_and(u >= sigma_min, u <= sigma_max)]
        sigma_steps = u_filtered[((len(u_filtered) - 1) / (num_steps - 1) * step_indices).round().to(torch.int64)]
    else:
        assert discretization == 'edm'
        sigma_steps = (sigma_max ** (1 / rho) + step_indices / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho

    # Define noise level schedule.
    if schedule == 'vp':
        sigma = vp_sigma(vp_beta_d, vp_beta_min)
        sigma_deriv = vp_sigma_deriv(vp_beta_d, vp_beta_min)
        sigma_inv = vp_sigma_inv(vp_beta_d, vp_beta_min)
    elif schedule == 've':
        sigma = ve_sigma
        sigma_deriv = ve_sigma_deriv
        sigma_inv = ve_sigma_inv
    else:
        assert schedule == 'linear'
        sigma = lambda t: t
        sigma_deriv = lambda t: 1
        sigma_inv = lambda sigma: sigma

    # Define scaling schedule.
    if scaling == 'vp':
        s = lambda t: 1 / (1 + sigma(t) ** 2).sqrt()
        s_deriv = lambda t: -sigma(t) * sigma_deriv(t) * (s(t) ** 3)
    else:
        assert scaling == 'none'
        s = lambda t: 1
        s_deriv = lambda t: 0

    # Compute final time steps based on the corresponding noise levels.
    t_steps = sigma_inv(net.round_sigma(sigma_steps))
    t_steps = torch.cat([t_steps, torch.zeros_like(t_steps[:1])]) # t_N = 0

    # Main sampling loop.
    t_next = t_steps[0]
    x_next = latents.to(torch.float64) * (sigma(t_next) * s(t_next))
    for i, (t_cur, t_next) in enumerate(zip(t_steps[:-1], t_steps[1:])): # 0, ..., N-1
        x_cur = x_next

        # Increase noise temporarily.
        gamma = min(S_churn / num_steps, np.sqrt(2) - 1) if S_min <= sigma(t_cur) <= S_max else 0
        t_hat = sigma_inv(net.round_sigma(sigma(t_cur) + gamma * sigma(t_cur)))
        x_hat = s(t_hat) / s(t_cur) * x_cur + (sigma(t_hat) ** 2 - sigma(t_cur) ** 2).clip(min=0).sqrt() * s(t_hat) * S_noise * randn_like(x_cur)

        # Euler step.
        h = t_next - t_hat
        denoised = net(x_hat / s(t_hat), sigma(t_hat), class_labels).to(torch.float64)
        d_cur = (sigma_deriv(t_hat) / sigma(t_hat) + s_deriv(t_hat) / s(t_hat)) * x_hat - sigma_deriv(t_hat) * s(t_hat) / sigma(t_hat) * denoised
        x_prime = x_hat + alpha * h * d_cur
        t_prime = t_hat + alpha * h

        # Apply 2nd order correction (never on the final step, which ends at t = 0).
        if solver == 'euler' or i == num_steps - 1:
            x_next = x_hat + h * d_cur
        else:
            assert solver == 'heun'
            denoised = net(x_prime / s(t_prime), sigma(t_prime), class_labels).to(torch.float64)
            d_prime = (sigma_deriv(t_prime) / sigma(t_prime) + s_deriv(t_prime) / s(t_prime)) * x_prime - sigma_deriv(t_prime) * s(t_prime) / sigma(t_prime) * denoised
            x_next = x_hat + h * ((1 - 1 / (2 * alpha)) * d_cur + 1 / (2 * alpha) * d_prime)

    return x_next

#----------------------------------------------------------------------------
# Wrapper for torch.Generator that allows specifying a different random seed
# for each sample in a minibatch.

class StackedRandomGenerator:
    """Per-sample seeded random number generation for a minibatch.

    One ``torch.Generator`` is created per seed.  Sample ``k`` of every draw
    comes from ``seeds[k]``'s generator, so its values do not depend on which
    other seeds share the batch.
    """

    def __init__(self, device, seeds):
        super().__init__()
        # Seeds are reduced mod 2**32 before seeding each generator.
        self.generators = [torch.Generator(device).manual_seed(int(seed) % (1 << 32)) for seed in seeds]
        self.seeds = seeds
        self.device = device

    def randn(self, size, **kwargs):
        """Standard-normal tensor of ``size``; row ``k`` uses generator ``k``."""
        assert size[0] == len(self.generators)
        return torch.stack([torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators])

    def rand_beta_prime(self, size, N=3072, D=128, **kwargs):
        """Draw PFGM++ prior latents, one per seed.

        For each seed: ``s ~ Beta(N/2, D/2)``, radius ``sqrt(s / (1 - s)) * 80 *
        sqrt(D)``, times a random unit direction in ``R^N``.  The result is
        reshaped to ``(1, *size[1:])`` and concatenated along dim 0.

        Args:
            size: output shape; ``size[0]`` must equal the number of seeds and
                ``prod(size[1:])`` must equal ``N``.
            N: data dimensionality.
            D: PFGM++ augmentation dimension.
            **kwargs: ``device`` selects the output device (defaults to the
                device given at construction); other keys are ignored.

        Returns:
            float64 tensor of shape ``size``.

        Raises:
            NotImplementedError: for ``N >= 256 * 256 * 3`` (no sigma_max is
                configured for that size).

        Note:
            Randomness comes from the *global* torch RNG, re-seeded with
            ``torch.manual_seed(seed)`` for each seed.  The global RNG state
            after this call therefore depends on the last seed.
        """
        assert size[0] == len(self.seeds)
        # Only the sigma_max used for resolutions below 256x256x3 is wired up.
        if N >= 256 * 256 * 3:
            raise NotImplementedError
        sigma_max = 80
        device = kwargs.get('device', self.device)

        beta_gen = Beta(torch.FloatTensor([N / 2.]), torch.FloatTensor([D / 2.]))
        latent_list = []
        for seed in self.seeds:
            torch.manual_seed(seed)
            sample_norm = beta_gen.sample().to(device).double()
            # Map Beta(N/2, D/2) to the inverse-beta (beta prime) distribution.
            inverse_beta = sample_norm / (1 - sample_norm)

            sample_norm = torch.sqrt(inverse_beta) * sigma_max * np.sqrt(D)
            gaussian = torch.randn(N).to(sample_norm.device)
            unit_gaussian = gaussian / torch.norm(gaussian, p=2)
            init_sample = unit_gaussian * sample_norm
            latent_list.append(init_sample.reshape((1, *size[1:])))

        return torch.cat(latent_list, dim=0)

    def randn_like(self, input):
        """Per-sample ``randn`` matching ``input``'s shape, dtype, layout and device."""
        return self.randn(input.shape, dtype=input.dtype, layout=input.layout, device=input.device)

    def randint(self, *args, size, **kwargs):
        """Per-sample ``torch.randint``; row ``k`` uses generator ``k``."""
        assert size[0] == len(self.generators)
        return torch.stack([torch.randint(*args, size=size[1:], generator=gen, **kwargs) for gen in self.generators])

#----------------------------------------------------------------------------
# Parse a comma separated list of numbers or ranges and return a list of ints.
# Example: '1,2,5-10' returns [1, 2, 5, 6, 7, 8, 9, 10]

def parse_int_list(s):
    """Expand a comma-separated spec of ints and inclusive ranges into a list.

    ``'1,2,5-10'`` -> ``[1, 2, 5, 6, 7, 8, 9, 10]``.  A value that is already a
    list is returned unchanged (click may pass defaults through as-is).
    """
    if isinstance(s, list):
        return s
    span_pattern = re.compile(r'^(\d+)-(\d+)$')
    values = []
    for token in s.split(','):
        span = span_pattern.match(token)
        if span is None:
            values.append(int(token))
            continue
        first, last = int(span.group(1)), int(span.group(2))
        values.extend(range(first, last + 1))
    return values

#----------------------------------------------------------------------------

def _checkpoint_output_dir(outdir, ckpt_num, seeds, use_pickle, name, num_steps):
    """Return the directory that images sampled from checkpoint ``ckpt_num`` go to."""
    if use_pickle:
        suffix = f'steps_{num_steps}' if name is None else name
        return os.path.join(outdir, f'ckpt_{ckpt_num:06d}_{suffix}')
    # Large seed ranges of .pt checkpoints get their own numbered directories.
    if 49999 < seeds[-1] <= 99999:
        return os.path.join(outdir, f'ckpt_2_{ckpt_num:06d}')
    if seeds[-1] > 99999:
        return os.path.join(outdir, f'ckpt_3_{ckpt_num:06d}')
    suffix = f'steps_{num_steps}' if name is None else f'steps_{num_steps}_{name}'
    return os.path.join(outdir, f'ckpt_{ckpt_num:06d}_{suffix}')


@click.command()
@click.option('--network', 'network_pkl',  help='Network pickle filename', metavar='PATH|URL',                      type=str)
@click.option('--outdir',                  help='Where to save the output images', metavar='DIR',                   type=str, required=True)
@click.option('--seeds',                   help='Random seeds (e.g. 1,2,5-10)', metavar='LIST',                     type=parse_int_list, default='0-63', show_default=True)
@click.option('--subdirs',                 help='Create subdirectory for every 1000 seeds',                         is_flag=True)
@click.option('--save_images',             help='only save a batch images for grid visualization',                     is_flag=True)
@click.option('--class', 'class_idx',      help='Class label  [default: random]', metavar='INT',                    type=click.IntRange(min=0), default=None)
@click.option('--batch', 'max_batch_size', help='Maximum batch size', metavar='INT',                                type=click.IntRange(min=1), default=64, show_default=True)

@click.option('--steps', 'num_steps',      help='Number of sampling steps', metavar='INT',                          type=click.IntRange(min=1), default=18, show_default=True)
@click.option('--inter_steps',      help='Number of intermediate sampling steps', metavar='INT',                          type=click.IntRange(min=0), default=0, show_default=True)
@click.option('--sigma_min',               help='Lowest noise level  [default: varies]', metavar='FLOAT',           type=click.FloatRange(min=0, min_open=True))
@click.option('--sigma_max',               help='Highest noise level  [default: varies]', metavar='FLOAT',          type=click.FloatRange(min=0, min_open=True))
@click.option('--rho',                     help='Time step exponent', metavar='FLOAT',                              type=click.FloatRange(min=0, min_open=True), default=7, show_default=True)
@click.option('--S_churn', 'S_churn',      help='Stochasticity strength', metavar='FLOAT',                          type=click.FloatRange(min=0), default=0, show_default=True)
@click.option('--S_min', 'S_min',          help='Stoch. min noise level', metavar='FLOAT',                          type=click.FloatRange(min=0), default=0, show_default=True)
@click.option('--S_max', 'S_max',          help='Stoch. max noise level', metavar='FLOAT',                          type=click.FloatRange(min=0), default=0, show_default=True)
@click.option('--S_noise', 'S_noise',      help='Stoch. noise inflation', metavar='FLOAT',                          type=float, default=0, show_default=True)
@click.option('--ckpt', 'ckpt',      help='begin ckpt', metavar='INT',                          type=int, default=0, show_default=True)
@click.option('--resume', 'resume',      help='resume ckpt', metavar='INT',                          type=int, default=None, show_default=True)
@click.option('--end_ckpt', 'end_ckpt',      help='end ckpt', metavar='INT',                          type=int, default=100000000, show_default=True)

@click.option('--solver',                  help='Ablate ODE solver', metavar='euler|heun',                          type=click.Choice(['euler', 'heun']))
@click.option('--disc', 'discretization',  help='Ablate time step discretization {t_i}', metavar='vp|ve|iddpm|edm', type=click.Choice(['vp', 've', 'iddpm', 'edm']))
@click.option('--schedule',                help='Ablate noise schedule sigma(t)', metavar='vp|ve|linear',           type=click.Choice(['vp', 've', 'linear']))
@click.option('--scaling',                 help='Ablate signal scaling s(t)', metavar='vp|none',                    type=click.Choice(['vp', 'none']))
@click.option('--use_pickle',          help='load model by pickle', metavar='BOOL',              type=bool, default=False, show_default=True)
@click.option('--name',          help='ckpt name',              type=str, default=None, show_default=True)
@click.option('--restart_info', 'restart_info',             help='restart information', metavar='STR', type = str, default='0', show_default=True)
@click.option('--restart_gamma', 'restart_gamma',             help='restart gamma', metavar='FLOAT',                            type=click.FloatRange(min=0), default=0, show_default=True)

@click.option('--pfgmpp',          help='Train PFGM++', metavar='BOOL',              type=bool, default=False, show_default=True)
@click.option('--aug_dim',             help='additional dimension', metavar='INT',                            type=click.IntRange(min=2), default=128, show_default=True)
@click.option('--ensemble',             help='ensemble number', metavar='INT',                            type=click.IntRange(min=1), default=1, show_default=True)
@click.option('--restart',             help='restart number', metavar='INT',                            type=click.IntRange(min=0), default=0, show_default=True)
@click.option('--restart_value',             help='restart value', metavar='FLOAT',                            type=click.FloatRange(min=-1), default=-1, show_default=True)
def main(ckpt, end_ckpt, outdir, subdirs, seeds, class_idx, max_batch_size, save_images, pfgmpp, aug_dim, ensemble, use_pickle, name, device=torch.device('cuda'), **sampler_kwargs):
    """Generate random images using the techniques described in the paper
    "Elucidating the Design Space of Diffusion-Based Generative Models".

    Samples every checkpoint found in --outdir (training-state-*.pt, or *.pkl
    with --use_pickle) and writes one PNG per seed under a per-checkpoint
    subdirectory of --outdir.

    \f
    Developer notes:
      * .pt checkpoints outside [--ckpt, --end_ckpt] are skipped before loading;
        the number is read from the 6 digits preceding ".pt" in the filename.
      * --subdirs is accepted but ignored: images always go into 1000-seed
        subdirectories.
      * --network and --resume are accepted but unused, and are not forwarded
        to the sampler (neither sampler accepts them).
      * --save_images writes a single grid for the first batch, then exits.
    """
    wandb_enabled = False
    dist.init()
    num_batches = ((len(seeds) - 1) // (max_batch_size * dist.get_world_size()) + 1) * dist.get_world_size()
    all_batches = torch.as_tensor(seeds).tensor_split(num_batches)
    rank_batches = all_batches[dist.get_rank() :: dist.get_world_size()]

    # Options left unset arrive as None; drop them so the sampler defaults apply.
    sampler_kwargs = {key: value for key, value in sampler_kwargs.items() if value is not None}
    # CLI-only options that neither sampler accepts (forwarding them raises TypeError).
    sampler_kwargs.pop('network_pkl', None)
    sampler_kwargs.pop('resume', None)
    have_ablation_kwargs = any(x in sampler_kwargs for x in ['solver', 'discretization', 'schedule', 'scaling'])
    sampler_fn = ablation_sampler if have_ablation_kwargs else edm_sampler

    pattern = '*.pkl' if use_pickle else 'training-state-*.pt'
    # glob order is filesystem-dependent; sort so processing order is
    # deterministic and identical on every rank (the rank-0-first loading below
    # assumes all ranks load the same checkpoint together).
    stats = sorted(glob.glob(os.path.join(outdir, pattern)))

    for ckpt_dir in stats:
        if use_pickle:
            ckpt_num = 0
        else:
            ckpt_num = int(ckpt_dir[-9:-3])
            # Skip out-of-range checkpoints before paying for the load.
            if ckpt_num < ckpt or ckpt_num > end_ckpt:
                continue

        if wandb_enabled:
            wandb.init(project="sde_restart", entity="goodeat", settings=wandb.Settings(start_method='fork'), config = sampler_kwargs)
        # Load network.
        dist.print0(f'Loading network from "{ckpt_dir}"...')
        # Rank 0 goes first.
        if dist.get_rank() != 0:
            torch.distributed.barrier()

        if use_pickle:
            with dnnlib.util.open_url(ckpt_dir, verbose=(dist.get_rank() == 0)) as f:
                net = pickle.load(f)['ema'].to(device)
        else:
            data = torch.load(ckpt_dir, map_location=torch.device('cpu'))
            net = data['ema'].eval().to(device)
            assert net.D == aug_dim

        # Other ranks follow.
        if dist.get_rank() == 0:
            torch.distributed.barrier()

        if class_idx is not None and not net.label_dim:
            raise click.UsageError('--class requires a class-conditional network')

        temp_dir = _checkpoint_output_dir(outdir, ckpt_num, seeds, use_pickle, name, sampler_kwargs['num_steps'])

        # Loop over batches.
        dist.print0(f'Generating {len(seeds)} images to "{temp_dir}"...')
        for batch_seeds in tqdm.tqdm(rank_batches, unit='batch', disable=(dist.get_rank() != 0)):
            torch.distributed.barrier()
            batch_size = len(batch_seeds)
            if batch_size == 0:
                continue

            # Pick latents and labels.
            rnd = StackedRandomGenerator(device, batch_seeds)
            latent_shape = [batch_size, net.img_channels, net.img_resolution, net.img_resolution]
            if pfgmpp:
                N = net.img_channels * net.img_resolution * net.img_resolution
                latents = rnd.rand_beta_prime(latent_shape, N=N, D=aug_dim, pfgmpp=pfgmpp, device=device)
            else:
                latents = rnd.randn(latent_shape, device=device)
            class_labels = None
            if net.label_dim:
                class_labels = torch.eye(net.label_dim, device=device)[
                    rnd.randint(net.label_dim, size=[batch_size], device=device)]
            if class_idx is not None:
                class_labels[:, :] = 0
                class_labels[:, class_idx] = 1

            # Generate images.
            with torch.no_grad():
                images = sampler_fn(net, latents, class_labels, randn_like=rnd.randn_like,
                                    pfgmpp=pfgmpp, ensemble=ensemble, **sampler_kwargs)

            if save_images:
                # Save a single grid of this batch (images mapped from [-1, 1] to [0, 1]) and stop.
                images_ = (images + 1) / 2.
                print("len:", len(images))
                image_grid = make_grid(images_, nrow=int(np.sqrt(len(images))))
                save_image(image_grid, os.path.join(outdir, f'ode_images_{ckpt_num}.png'))
                raise SystemExit(0)

            # Save images.
            images_np = (images * 127.5 + 128).clip(0, 255).to(torch.uint8).permute(0, 2, 3, 1).cpu().numpy()
            for seed, image_np in zip(batch_seeds, images_np):
                image_dir = os.path.join(temp_dir, f'{seed - seed % 1000:06d}')
                os.makedirs(image_dir, exist_ok=True)
                image_path = os.path.join(image_dir, f'{seed:06d}.png')
                if image_np.shape[2] == 1:
                    PIL.Image.fromarray(image_np[:, :, 0], 'L').save(image_path)
                else:
                    PIL.Image.fromarray(image_np, 'RGB').save(image_path)

        # Done.
        torch.distributed.barrier()
        dist.print0('Done.')

#----------------------------------------------------------------------------

if __name__ == "__main__":
    # Script entry point: invoking the click command parses the CLI options declared on main.
    main()

#----------------------------------------------------------------------------
