import torch
import numpy as np
from torch.autograd import Variable
from sklearn.model_selection import KFold
from binarygridsearch.binarygridsearch import compareValsBaseCase

def vprint(v, *args):
    """Forward ``args`` to the builtin ``print`` only when ``v`` is truthy.

    Args:
        v: Verbosity switch; any falsy value suppresses output.
        *args: Positional values handed unchanged to ``print``.
    """
    if not v:
        return
    print(*args)
        
def get_l2_reg(model, alpha, n):
    """Sum of squared ``weight`` parameters of ``model``, divided by ``n``.

    Only parameters whose name contains ``'weight'`` contribute. ``alpha`` is
    used purely as an on/off switch: it is never multiplied into the result
    (presumably the caller applies it — verify against callers).

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        alpha: Non-negative regularisation strength; ``0`` disables the penalty.
        n: Divisor applied to each parameter's squared sum.

    Returns:
        A scalar tensor when ``alpha > 0`` and at least one weight parameter
        exists; otherwise the float ``0.``.
    """
    assert alpha >= 0
    if not alpha > 0:
        return 0.
    # Start the running sum from the float 0. so an empty selection yields 0.
    squared_terms = ((param ** 2).sum() / n
                     for pname, param in model.named_parameters()
                     if 'weight' in pname)
    return sum(squared_terms, 0.)

def get_l1_reg(model, alpha, n):
    """Sum of absolute values of ``weight`` parameters of ``model``, divided by ``n``.

    Only parameters whose name contains ``'weight'`` contribute. ``alpha`` is
    used purely as an on/off switch: it is never multiplied into the result
    (presumably the caller applies it — verify against callers).

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        alpha: Non-negative regularisation strength; ``0`` disables the penalty.
        n: Divisor applied to each parameter's absolute sum.

    Returns:
        A scalar tensor when ``alpha > 0`` and at least one weight parameter
        exists; otherwise the float ``0.``.
    """
    assert alpha >= 0
    if not alpha > 0:
        return 0.
    # Start the running sum from the float 0. so an empty selection yields 0.
    abs_terms = (param.abs().sum() / n
                 for pname, param in model.named_parameters()
                 if 'weight' in pname)
    l1_reg = sum(abs_terms, 0.)
    return l1_reg

def null(*args, **kwargs):
    """Accept and ignore any arguments, always returning the float ``0.``.

    Presumably a no-op stand-in wherever a callable returning a number is
    expected — verify against callers.
    """
    del args, kwargs  # deliberately unused
    return 0.0

# TODO: currently returns the average over the per-fold *parameters*, not the
# parameter value winning a majority vote across folds.
def line_search(x, y, w, prm_str, score,
                n_splits=3, lo=0., hi=0.5, decim=3,
                random_state=None):
    """Choose a hyper-parameter value by K-fold cross-validated search.

    For each of ``n_splits`` folds the rows are split into a training and a
    validation part. ``compareValsBaseCase`` is run on the training part, with
    the validation data and both weight slices passed through ``prms``. The
    last entry of the ``prm_str`` column of the table it returns is taken as
    that fold's best value, and the per-fold values are averaged.

    Args:
        x: Feature array, indexed as ``x[idx, :]`` (so it must be 2-D).
        y: Targets, indexed by row as ``y[idx]``.
        w: Per-sample weights, indexed by row as ``w[idx]``.
        prm_str: Name of the hyper-parameter. It is forwarded to
            ``compareValsBaseCase`` and used as a column of its result table.
        score: Forwarded unchanged to ``compareValsBaseCase``.
        n_splits: Number of ``KFold`` folds.
        lo, hi: Search bounds forwarded to ``compareValsBaseCase``.
        decim: Forwarded to ``compareValsBaseCase`` (presumably the decimal
            precision of the search — confirm).
        random_state: When ``None`` (default), folds are contiguous and
            unshuffled. Otherwise the rows are shuffled with this seed before
            being split.

    Returns:
        ``np.mean`` of the per-fold best values.
    """
    # scikit-learn >= 0.24 raises ValueError when random_state is set while
    # shuffle=False, and older versions silently ignored the seed. Shuffle
    # exactly when a seed is supplied so the seed actually takes effect; the
    # default (None) path is unchanged.
    kf = KFold(n_splits=n_splits, shuffle=random_state is not None,
               random_state=random_state)
    best_prms = []
    for trn_idx, val_idx in kf.split(x):
        x_trn, y_trn, w_trn = x[trn_idx, :], y[trn_idx], w[trn_idx]
        x_val, y_val, w_val = x[val_idx, :], y[val_idx], w[val_idx]
        prms = {'w_trn': w_trn, 'w_val': w_val, 'x_val': x_val, 'y_val': y_val}
        # NOTE(review): the original comment here read "hack to input seed";
        # its meaning is not visible from this file — confirm with
        # compareValsBaseCase.
        tbl = compareValsBaseCase(x_trn, y_trn, score, prms, prm_str,
                                  decim, lo, hi)
        # Assumes the final row of the returned table holds the fold's best
        # (converged) value — NOTE(review): confirm against compareValsBaseCase.
        best_prms.append(tbl[prm_str].iloc[-1])
    best_prm = np.mean(np.asarray(best_prms))
    return best_prm