from model.ops import f_eff
from model.ntk import ReLU_NTK as NTK
import torch


def _solve(A, B):
    """Return X such that ``A @ X == B``, on both current and legacy torch.

    ``torch.solve(B, A)`` was deprecated in favour of
    ``torch.linalg.solve(A, B)`` (note the swapped argument order) and has
    since been removed, so prefer the ``torch.linalg`` API when it exists.
    """
    linalg = getattr(torch, 'linalg', None)
    if linalg is not None and hasattr(linalg, 'solve'):
        return linalg.solve(A, B)
    # Legacy API (old torch only): requires a 2-D right-hand side and
    # returns a (solution, LU) pair.
    B2 = B.unsqueeze(-1) if B.dim() == 1 else B
    X, _ = torch.solve(B2, A)
    return X.reshape(B.shape)


def weight_project(target, init_net, dataset, rho=1.0, T=1.0, N=1000, device_name='cuda:0'):
    """Return ``y^T Theta^{-1} y`` for the f_eff-output gap between two nets.

    ``y`` is ``f_eff(target(X)) - f_eff(init_net(X))`` evaluated on inputs
    drawn from ``dataset`` and ``Theta`` is the kernel matrix returned by
    ``ReLU_NTK`` for those inputs.  The result is presumably the squared
    NTK-RKHS norm of the residual function (TODO confirm with the
    ``model.ntk`` implementation).

    Args:
        target: callable mapping the input batch ``X`` to outputs.
        init_net: callable mapping ``X`` to outputs; must expose
            ``weight_std`` and ``bias_std`` attributes, which are passed to
            the kernel.
        dataset: object supporting ``dataset[:]``; its ``datanum`` and
            ``online`` attributes are overwritten by this function.
        rho, T: forwarded unchanged to ``f_eff``.
        N: value assigned to ``dataset.datanum`` before sampling.
        device_name: currently unused; kept for interface compatibility.

    Returns:
        Python float: the sum of ``y * Theta^{-1} y``, which equals the
        quadratic form for a single output column.
    """
    # Side effect: reconfigures the caller's dataset object in place.
    dataset.datanum = N
    dataset.online = True

    with torch.no_grad():
        # NOTE(review): element 1 of ``dataset[:]`` is used as the input
        # batch X -- confirm ordering against the dataset's __getitem__.
        X = dataset[:][1]
        y = f_eff(target(X), rho, T) - f_eff(init_net(X), rho, T)
        Theta = NTK(X, init_net.weight_std, init_net.bias_std)

        # alpha = Theta^{-1} y, computed by a linear solve rather than by
        # forming the inverse explicitly.
        alpha = _solve(Theta, y)

    # Elementwise product + sum equals dot(y, alpha) for an (N,) or (N, 1)
    # ``y``.  Unlike torch.dot on squeezed tensors, it also works when
    # N == 1 (squeeze would give a 0-d tensor) and for multi-column ``y``.
    return (y * alpha).sum().item()
