import torch
import torch.nn as nn

import exp_utils as PQ


class SGradOptimizer(nn.Module):
    """Gradient-based search over a batch of candidate states.

    A batch of latent vectors ``z`` is held as a learnable parameter.  The
    candidate states are ``state_box.decode(z)``, and each call to
    :meth:`step` takes one Adam step that increases the mean of the
    ``'hard_obj'`` entry returned by ``crabs.obj_eval``.

    NOTE(review): the meaning of ``'hard_obj'``, ``'constraint'`` and
    ``'obj'`` is defined by ``crabs.obj_eval``, which is not visible here;
    this class only assumes they are 1-D tensors with one entry per
    candidate -- confirm against the caller.
    """

    def __init__(self, crabs, state_box, *, n_samples: int = 10000, lr: float = 1e-3):
        """Create the latent batch and its Adam optimizer.

        Args:
            crabs: Object exposing ``obj_eval(states)`` that returns a mapping
                with at least ``'hard_obj'``, ``'constraint'`` and ``'obj'``.
            state_box: Object exposing ``shape`` (an iterable of ints giving
                the per-state shape) and ``decode(z)``.
            n_samples: Number of candidate states optimized in parallel.
            lr: Learning rate of the Adam optimizer over ``z``.

        Raises:
            ValueError: If ``n_samples`` is not positive.
        """
        super().__init__()
        if n_samples <= 0:
            raise ValueError(f"n_samples must be positive, got {n_samples}")
        self.crabs = crabs
        self.state_box = state_box
        # Latents start from a standard normal; `reinit` resamples them later.
        self.z = nn.Parameter(torch.randn(n_samples, *state_box.shape), requires_grad=True)
        self.opt = torch.optim.Adam([self.z], lr=lr)

    @property
    def s(self):
        """Candidate states: the latents ``z`` passed through ``state_box.decode``."""
        return self.state_box.decode(self.z)

    def step(self) -> torch.Tensor:
        """Take one Adam step on ``z`` and return the (scalar) loss.

        The loss is the negated mean of ``'hard_obj'`` over all candidates, so
        minimizing it pushes ``'hard_obj'`` upwards.
        """
        result = self.crabs.obj_eval(self.s)
        loss = (-result['hard_obj']).mean()

        self.opt.zero_grad()
        loss.backward()  # loss is already a scalar; no further reduction needed
        self.opt.step()
        return loss

    @torch.no_grad()
    def reinit(self):
        """Resample every latent uniformly from ``[-1, 1]`` in place.

        Only ``z`` is touched; the Adam optimizer state is left as it is.
        """
        PQ.log.debug("[SGradOpt] reinit")
        nn.init.uniform_(self.z, -1., 1.)

    def evaluate(self, *, step):
        """Log the candidate with the largest ``'hard_obj'`` and return that value.

        The message is logged at warning level when the maximum is strictly
        positive and at debug level otherwise.

        Args:
            step: Accepted for interface compatibility; not used here.

        Returns:
            dict: ``{'optimal': <max of 'hard_obj' as a Python float>}``.
        """
        # Decode once so the logged state is exactly the one that was evaluated.
        states = self.s
        result = self.crabs.obj_eval(states)
        hard_obj = result['hard_obj']
        constraint = result['constraint']
        obj = result['obj']

        best = hard_obj.argmax()
        max_obj = hard_obj.max().item()

        log = PQ.log.warning if max_obj > 0 else PQ.log.debug
        log(f"[S grad opt] hardD = {max_obj:.6f}, L = {constraint[best].item():.6f}, U = {obj[best].item():.6f}, "
            f"inside = {(constraint <= 0).sum().item()}, s = {states[best].cpu().detach()}")

        return {
            'optimal': max_obj,
        }
