import numpy as np
import torch
import wandb
from torch import nn

class LinearData:
    """Noiseless linear-regression dataset with precomputed statistics.

    The inputs ``X`` are stored with one sample per column (statistics are
    taken over ``dim=1`` and normalised by ``X.shape[1]``).  Targets are
    generated as ``y = X.T @ w_gt + b_gt``.

    Attributes set by the constructor (all derived quantities are computed in
    float64 and stored as float32):

    * ``y``, ``ymean``, ``tily``, ``ystd`` -- targets, their mean, the centred
      targets and their (population) standard deviation.
    * ``xmean``, ``tilX``, ``xcov`` -- per-feature mean, centred inputs and
      the (population) covariance ``tilX @ tilX.T / n``.
    * ``w_opt``, ``b_opt`` -- minimum-norm weights fitting the centred data
      and the matching intercept.
    * ``z`` -- ``tilX @ tily / ystd / n``.
    * ``H0`` -- ``2 * (xcov - outer(z, z))``.
    """

    def __init__(self, X, w_gt, b_gt, S_gt):
        """Store the raw data and precompute all derived statistics.

        Args:
            X: 2-D tensor of inputs, one sample per column.
            w_gt: 1-D ground-truth weight vector.
            b_gt: scalar ground-truth intercept tensor.
            S_gt: matrix used as the quadratic form in ``test_loss``.

        Raises:
            AssertionError: if the fitted ``(w_opt, b_opt)`` does not reach a
                training loss below ``1e-5``.
        """
        self.X = X
        self.w_gt = w_gt
        self.b_gt = b_gt
        self.S_gt = S_gt

        with torch.no_grad():
            # All statistics are computed in double precision and only cast
            # to float32 once everything has been derived.
            X64 = self.X.double()
            n_samples = self.X.shape[1]

            y = X64.T @ w_gt.double() + b_gt.double()
            ymean = torch.mean(y)
            tily = y - ymean
            ystd = torch.mean(tily ** 2) ** 0.5

            xmean = torch.mean(X64, dim=1)
            tilX = X64 - xmean.view(-1, 1)
            xcov = tilX @ tilX.T / n_samples

            # Minimum-norm solution of tilX.T @ w = tily.  The Gram matrix of
            # centred data is always singular (the centred columns sum to
            # zero), so a direct solve on it is ill-posed; lstsq handles that
            # and also the case of more samples than features.
            solution = np.linalg.lstsq(
                tilX.T.cpu().numpy(), tily.cpu().numpy(), rcond=None
            )[0]
            w_opt = torch.tensor(solution, dtype=tilX.dtype, device=tilX.device)
            b_opt = ymean - torch.inner(w_opt, xmean)

            self.y = y.float()
            self.ymean = ymean.float()
            self.tily = tily.float()
            self.ystd = ystd.float()
            self.xmean = xmean.float()
            self.tilX = tilX.float()
            self.xcov = xcov.float()
            self.w_opt = w_opt.float()
            self.b_opt = b_opt.float()

            # Sanity check: the fitted solution must interpolate the data.
            fit_loss = self.train_loss(self.w_opt, self.b_opt)
            assert fit_loss < 1e-5, f'train loss for lstsq: {fit_loss}'

            self.z = self.tilX @ self.tily / self.ystd / self.X.shape[1]
            self.H0 = 2 * (self.xcov - torch.outer(self.z, self.z))

    def eff_wb(self, gamma, beta, w):
        """Return the effective ``(weights, bias)`` of the normalised model.

        ``w`` is rescaled so that ``tilw @ xcov @ tilw == gamma ** 2`` and the
        bias is chosen so that the mean prediction on ``X`` equals ``beta``.
        ``w`` is expected to be a 1-D tensor.
        """
        # Plain `@` on 1-D tensors: `.T` on a non-2-D tensor is deprecated.
        tilw = gamma * w / (w @ self.xcov @ w) ** 0.5
        tilb = beta - torch.inner(tilw, self.xmean)
        return tilw, tilb

    def train_loss(self, w, b):
        """Mean squared error of the affine model ``X.T @ w + b`` on ``y``."""
        return ((self.X.T @ w + b - self.y) ** 2).mean()

    def direct_train_loss(self, gamma, beta, w):
        """Mean squared error of the normalised model on ``y``.

        The raw outputs ``X.T @ w`` are centred, divided by their (population)
        standard deviation, then scaled by ``gamma`` and shifted by ``beta``.
        """
        out = self.X.T @ w
        out = out - torch.mean(out)
        out = out / torch.mean(out ** 2) ** 0.5
        out = gamma * out + beta
        return ((out - self.y) ** 2).mean()

    def test_loss(self, w, b):
        """Return ``(w - w_gt) S_gt (w - w_gt) + (b - b_gt) ** 2``.

        ``w`` is expected to be a 1-D tensor.
        """
        delta = w - self.w_gt
        return delta @ self.S_gt @ delta + (b - self.b_gt) ** 2

    def hess(self, gamma, beta, w):
        """Hessian of ``direct_train_loss`` with respect to ``w``."""
        def cur_loss(ww):
            return self.direct_train_loss(gamma, beta, ww)
        return torch.autograd.functional.hessian(cur_loss, w)
    

def get_data(N, d, seed=777, dtype=torch.float32, device='cuda'):
    """Build a reproducible synthetic ``LinearData`` instance.

    The ground-truth weights are a random unit vector, the intercept is a
    random scalar scaled by 0.1, and ``S_gt`` is diagonal with entries
    ``1 - k / d`` for ``k = 0 .. d - 1``.  Inputs are ``sqrt(S_gt)`` applied
    to standard-normal noise of shape ``(d, N)``.

    Args:
        N: number of samples (columns of ``X``).
        d: input dimension.
        seed: seed used only while generating the data.
        dtype: dtype of the generated tensors.
        device: device on which the tensors are created.

    Returns:
        A ``LinearData`` wrapping the generated ``X``, ``w_gt``, ``b_gt`` and
        ``S_gt``.

    The caller's RNG state is restored afterwards, even if generation fails.
    """
    target = torch.device(device)

    # torch.manual_seed reseeds every device, but torch.get_rng_state only
    # captures the CPU generator, so the CUDA state is saved separately.
    cpu_state = torch.get_rng_state()
    cuda_state = torch.cuda.get_rng_state(target) if target.type == 'cuda' else None

    try:
        torch.manual_seed(seed)
        with torch.no_grad():
            w_gt = torch.randn(d, dtype=dtype, device=target)
            w_gt /= w_gt.norm()
            b_gt = torch.randn([], dtype=dtype, device=target) * 0.1
            S_gt = torch.tensor([1 - k / d for k in range(d)], dtype=dtype, device=target).diag()

            X = (S_gt ** 0.5) @ torch.randn(d, N, dtype=dtype, device=target)
    finally:
        torch.set_rng_state(cpu_state)
        if cuda_state is not None:
            torch.cuda.set_rng_state(cuda_state, target)

    return LinearData(X, w_gt, b_gt, S_gt)


def main(N, d=40, eta=0.02, peta=5e-3, lam=0.2, T=1_000_000, seed=777, train_gamma=False, train_beta=False, compute_phi=True):
    """Train the normalised linear model with weight decay and log to wandb.

    Each step evaluates ``data.direct_train_loss(gamma, beta, w)``, logs
    diagnostics (norms, losses, the top five Hessian eigenvalues at
    ``w / ||w||``, distances to the fitted optimum, the effective learning
    rate) and then applies ``w -= eta * (grad + lam * w)``.

    Args:
        N: number of training samples.
        d: input dimension.
        eta: learning rate for ``w`` (and for ``gamma`` / ``beta`` if trained).
        peta: step size used by the projection ``Phi``.
        lam: weight-decay coefficient applied to ``w``.
        T: number of training steps.
        seed: seed passed to ``get_data``.
        train_gamma: if true, ``gamma`` starts at 1 and is trained; otherwise
            it is fixed to ``data.ystd``.
        train_beta: if true, ``beta`` starts at 0 and is trained; otherwise it
            is fixed to ``data.ymean``.
        compute_phi: if true, also run ``Phi`` every step and log the
            projection-based diagnostics.
    """
    wandb.init(
        project='__project_name__',
        save_code=True,
        config={
            'N': N,
            'd': d,
            'lr': eta,
            'phi_lr': peta,
            'wd': lam,
            'T': T,
            'seed': seed
        }
    )

    data = get_data(N, d, seed=seed)

    opt_test_loss = data.test_loss(data.w_opt, data.b_opt).item()
    wandb.run.summary['wb_opt/test_loss'] = opt_test_loss
    wandb.run.summary['norm/w_opt'] = torch.linalg.norm(data.w_opt).item()
    print('expected test loss:', opt_test_loss)

    # Initialisation of w uses its own fixed seed, independent of `seed`.
    torch.manual_seed(2455)

    with torch.no_grad():
        if train_gamma:
            gamma = nn.Parameter(torch.ones([], dtype=torch.float32, device='cuda'), requires_grad=True)
        else:
            gamma = nn.Parameter(data.ystd, requires_grad=False)
        if train_beta:
            beta = nn.Parameter(torch.zeros([], dtype=torch.float32, device='cuda'), requires_grad=True)
        else:
            beta = nn.Parameter(data.ymean, requires_grad=False)
        w = nn.Parameter(torch.randn(d, dtype=torch.float32, device='cuda') / d ** 0.5, requires_grad=True)

    def Phi(theta):
        """Run sphere-projected gradient descent from ``theta`` to zero loss.

        Repeats ``point -= peta * grad; point /= ||point||`` until the
        training loss drops below ``1e-8`` and returns the final point.
        There is no iteration cap: the loop only ends once that threshold
        is reached.
        """
        point = nn.Parameter(theta.clone(), requires_grad=True)

        tot = 0
        while True:
            with torch.enable_grad():
                proj_loss = data.direct_train_loss(gamma, beta, point)
                if proj_loss < 1e-8:
                    print('total projection steps:', tot)
                    return point.data
                # Differentiate w.r.t. `point` only.  Using backward() here
                # would also accumulate into gamma.grad / beta.grad when they
                # are trainable and corrupt the outer update of this step.
                (point_grad,) = torch.autograd.grad(proj_loss, point)
                tot += 1

            with torch.no_grad():
                point.data -= peta * point_grad
                point.data /= torch.linalg.norm(point)

    # Fixed random direction used to pick a consistent sign for the top
    # eigenvector across steps.
    vrand = np.random.randn(d)

    for t in range(T):
        l = data.direct_train_loss(gamma, beta, w)
        l.backward()

        with torch.no_grad():
            w_norm = w.norm()
            theta = w / w_norm

            extra = {}
            if compute_phi:
                phi = Phi(theta)
                x = theta - phi

                phitilw, phitilb = data.eff_wb(gamma, beta, phi)

                phess = (torch.linalg.norm(phitilw) ** 2 * data.H0).cpu().numpy()
                peigs, peigv = np.linalg.eigh(phess)

                # eigh returns eigenvalues in ascending order, so column -1
                # is the top eigenvector; orient it along vrand.
                if np.inner(peigv[:, -1], vrand) > 0:
                    tpeigv = peigv[:, -1]
                else:
                    tpeigv = -peigv[:, -1]

                extra.update({
                    'norm/x': torch.linalg.norm(x),
                    'h': np.inner(x.cpu().numpy(), tpeigv),
                    **{f'pseigs/{k}': peigs[-(k + 1)] for k in range(5)},
                    'phi/test_loss': data.test_loss(phitilw, phitilb).item(),
                    'norm/phitilw': torch.linalg.norm(phitilw),
                    'val/phitilb': phitilb,
                    'phidist/w_opt': (phitilw - data.w_opt).norm(),
                    'phidist/b_opt': (phitilb - data.b_opt).norm(),
                })

            tilw, tilb = data.eff_wb(gamma, beta, w)
            hess = data.hess(gamma, beta, theta).cpu().numpy()
            eigs, eigv = np.linalg.eigh(hess)

            eeta = eta / ((1 - eta * lam) * w_norm ** 2)

            wandb.log({
                'gamma': gamma.item(),
                'beta': beta.item(),
                'norm/w': w_norm,
                'norm/tilw': tilw.norm(),
                'val/tilb': tilb,
                'loss/train': l,
                'loss/test': data.test_loss(tilw, tilb).item(),
                **{f'seigs/{k}': eigs[-(k + 1)] for k in range(5)},
                'dist/w_opt': (tilw - data.w_opt).norm(),
                'dist/b_opt': (tilb - data.b_opt).norm(),
                'elr/si': eeta,
                'two_over_elr/si': 2 / eeta,
                **extra,
            }, step=t, commit=(t % 1000 == 0))

        with torch.no_grad():
            # Gradient step with weight decay on w.
            w.data -= eta * (w.grad + lam * w.data)
            w.grad = None

            if train_gamma:
                gamma -= eta * gamma.grad
                gamma.grad = None
            if train_beta:
                beta -= eta * beta.grad
                beta.grad = None
            

if __name__ == '__main__':
    # Print numpy arrays with four decimals and without scientific notation.
    np.set_printoptions(suppress=True, precision=4)

    # Settings for the default run; everything else uses main()'s defaults.
    run_kwargs = dict(N=20, eta=0.5, lam=0.02, seed=33523325326)
    main(**run_kwargs)