import jax
import jax.numpy as np

def cosine_annealing_lr(i, lr_min, lr_max, n_epochs):
    """Return the cosine-annealed learning rate at epoch/step ``i``.

    Computes ``lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * i / n_epochs))``.
    The result is ``lr_max`` at ``i == 0`` and ``lr_min`` at ``i == n_epochs``.

    Args:
        i: Current epoch (or step).
        lr_min: Learning rate reached at the end of the cycle.
        lr_max: Learning rate at the start of the cycle.
        n_epochs: Length of the annealing cycle, in the same units as ``i``.

    Returns:
        The annealed learning rate (a JAX value, since ``np.cos`` is JAX's).
    """
    progress = i / n_epochs
    half_range = 0.5 * (lr_max - lr_min)
    return lr_min + half_range * (1 + np.cos(progress * np.pi))

def swa_lr_schedule(i, base_lr, swa_lr, begin_epoch, n_epochs):
    """Piecewise-linear learning-rate schedule used before/during SWA.

    With ``t = i / begin_epoch`` (progress relative to the epoch at which
    averaging begins), the learning rate is:

    * ``base_lr`` for ``t <= 0.5``;
    * linearly decayed from ``base_lr`` to ``swa_lr`` for ``0.5 < t <= 0.9``;
    * ``swa_lr`` for ``t > 0.9``.

    If ``begin_epoch == 0``, ``t`` is non-finite. Both ``inf`` and ``nan``
    fail the ``<=`` tests, so the schedule returns ``swa_lr``.

    Args:
        i: Current epoch (or step); a Python scalar or a JAX array.
        base_lr: Initial learning rate.
        swa_lr: Learning rate reached at ``t = 0.9`` and held afterwards.
        begin_epoch: Epoch at which SWA averaging begins.
        n_epochs: Unused; kept so existing callers' signatures still work.

    Returns:
        The learning rate as a JAX array, broadcast over ``i`` when ``i`` is
        an array.
    """
    # Promote to a JAX array so Python and JAX inputs divide the same way
    # (x / 0 gives inf/nan instead of raising ZeroDivisionError).
    t = np.asarray(i) / begin_epoch
    lr_ratio = swa_lr / base_lr
    decay = 1.0 - (1.0 - lr_ratio) * (t - 0.5) / 0.4
    # Select branches with where instead of summing 0/1-masked terms: with
    # masks, a non-finite inactive branch turns the result into
    # 0 * inf == nan.
    lr_scale = np.where(t <= 0.5, 1.0, np.where(t <= 0.9, decay, lr_ratio))
    return lr_scale * base_lr
