from collections import OrderedDict
from copy import deepcopy
import os
import warnings

import torch


class ModelEma:
    """Keep an exponential moving average (EMA) copy of a model's state.

    The EMA copy is a ``deepcopy`` of ``model``, switched to eval mode, with
    ``requires_grad`` disabled on all of its parameters. Each call to
    :meth:`update` blends the live model's floating-point parameters and
    buffers into the copy::

        ema = decay * ema + (1 - decay) * model

    Non-floating-point state entries (e.g. BatchNorm's integer
    ``num_batches_tracked`` counter) cannot be averaged meaningfully. They are
    copied from the live model verbatim.

    Args:
        model: the ``torch.nn.Module`` to track.
        decay: EMA decay factor; values closer to 1 average over more steps.
        device: if non-empty, the EMA copy is moved to this device, and the
            live model's tensors are moved there before blending.
    """

    # State-dict key prefix added by wrappers that expose the wrapped network
    # as ``.module`` (presumably DataParallel / DistributedDataParallel).
    _WRAPPER_PREFIX = 'module.'

    def __init__(self, model, decay=0.9999, device=''):
        self.ema = deepcopy(model)
        self.ema.eval()
        self.decay = decay
        self.device = device
        if device:
            self.ema.to(device=device)
        # True when the EMA copy is itself a wrapper exposing ``.module``.
        self.ema_is_dp = hasattr(self.ema, 'module')
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def _match_ema_prefix(self, key):
        """Add or strip the leading wrapper prefix so ``key`` fits the EMA model."""
        prefix = self._WRAPPER_PREFIX
        has_prefix = key.startswith(prefix)
        if self.ema_is_dp and not has_prefix:
            return prefix + key
        if not self.ema_is_dp and has_prefix:
            # Strip only the leading prefix; nested ``.module.`` segments are
            # real submodule names and must be kept.
            return key[len(prefix):]
        return key

    def _model_key(self, ema_key, model_is_dp):
        """Translate an EMA state-dict key into the live model's key space."""
        prefix = self._WRAPPER_PREFIX
        if model_is_dp and not self.ema_is_dp:
            return prefix + ema_key
        if self.ema_is_dp and not model_is_dp and ema_key.startswith(prefix):
            return ema_key[len(prefix):]
        return ema_key

    def load_checkpoint(self, checkpoint):
        """Load EMA weights from ``checkpoint['model_ema']``.

        Args:
            checkpoint: a checkpoint dict, or a filesystem path (``str`` or
                ``os.PathLike``) that is read with ``torch.load``.

        Keys saved with or without the ``module.`` wrapper prefix are adapted
        to match the EMA model. If the checkpoint has no ``'model_ema'``
        entry, a warning is issued and the EMA weights are left unchanged.
        """
        if isinstance(checkpoint, (str, os.PathLike)):
            # Load onto CPU. load_state_dict copies into the EMA model's
            # existing tensors anyway, so this avoids requiring (and filling)
            # the device the checkpoint was saved from.
            checkpoint = torch.load(checkpoint, map_location='cpu')

        assert isinstance(checkpoint, dict)
        if 'model_ema' not in checkpoint:
            warnings.warn(
                "Checkpoint has no 'model_ema' entry; EMA weights were not loaded."
            )
            return
        new_state_dict = OrderedDict(
            (self._match_ema_prefix(k), v)
            for k, v in checkpoint['model_ema'].items()
        )
        self.ema.load_state_dict(new_state_dict)

    def state_dict(self):
        """Return the EMA model's ``state_dict``."""
        return self.ema.state_dict()

    def update(self, model):
        """Blend ``model``'s current state into the EMA copy in place.

        Either side may be wrapped (exposing ``.module``) independently of the
        other; keys are translated accordingly.
        """
        model_is_dp = hasattr(model, 'module')
        with torch.no_grad():
            curr_msd = model.state_dict()
            for k, ema_v in self.ema.state_dict().items():
                model_v = curr_msd[self._model_key(k, model_is_dp)].detach()
                if self.device:
                    model_v = model_v.to(device=self.device)
                if ema_v.is_floating_point():
                    ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
                else:
                    # Integer buffers (e.g. step counters) would be truncated
                    # by copy_ if averaged; track the live value instead.
                    ema_v.copy_(model_v)

