import multiprocessing
from datetime import datetime

import math
from dataclasses import dataclass
from concurrent.futures import ProcessPoolExecutor
import numpy as np
import init

if init.enable_GPU:
    import cupy as xp
else:
    import numpy as xp


def slow_decay(size, rate, gap_rate):
    """
    Build a sparse power-law vector.

    For every integer i >= 1 with floor(i ** g) <= size, the entry at
    (1-based) position floor(i ** g) is set to i ** -r; all other entries
    are zero.

    :param size: length of the returned vector
    :param rate: r, decay exponent of the non-zero values
    :param gap_rate: g, growth exponent of the non-zero positions
    :return: float array ``res`` of length ``size`` with
             ``res[floor(i ** g) - 1] = i ** -r``
    """
    res = xp.zeros(size, dtype=float)
    # size ** (1 / g) may land just below an integer (e.g. 1000 ** (1/3) ==
    # 9.999999999999998), which would drop the last valid i. Over-estimate the
    # count and keep exactly the i whose position fits.
    count = int(size ** (1 / gap_rate)) + 2
    idx = xp.arange(count) + 1
    powers = xp.asarray(idx, float) ** gap_rate
    # floor(p) <= size  <=>  p < size + 1; comparing in float also rejects
    # overflowed (inf) powers before they are cast to int.
    keep = powers < size + 1
    idx = idx[keep]
    pos = xp.array(powers[keep], dtype=int)
    res[pos - 1] = idx ** (-float(rate))
    return res


def decay_v(length, power):
    """
    Power-law sequence.

    :param length: number of entries
    :param power: decay exponent p
    :return: array ``[1 ** -p, 2 ** -p, ..., length ** -p]``
    """
    exponent = -float(power)
    positions = 1 + xp.arange(length)
    return positions ** exponent


def seqLinearGD_only_gen(z_list, T, eta_list, theta_list, ticks=100):
    """
    Closed-form error trajectory of coordinate-wise linear gradient flow.

    Coordinate j of the estimator at time t is z_j * (1 - exp(-eta_j * t)).
    Its squared distance to ``theta_list`` is evaluated at ``ticks`` evenly
    spaced times from 0 to T (both endpoints included).

    :param z_list: observations z
    :param T: final time
    :param eta_list: per-coordinate rates eta
    :param theta_list: reference coefficients the error is measured against
    :param ticks: number of evaluation times
    :return: float array of shape (ticks,) with the squared error at each time
    """
    times = xp.linspace(0, T, ticks)
    errors = xp.zeros(ticks, dtype=float)
    # One time point per iteration: avoids materializing a (len(z), ticks) matrix.
    for k, t in enumerate(times):
        shrink = -xp.expm1(-eta_list * t)  # 1 - exp(-eta * t), accurate near 0
        errors[k] = ((z_list * shrink - theta_list) ** 2).sum()
    return errors


def seq_abDb_gen(z, Lambda, T, theta_star, D=0, b0=1, ticks=1000):
    a = xp.sqrt(Lambda)
    b = xp.ones_like(z) * b0
    beta = xp.zeros_like(z)
    dt = T / ticks
    # L = (z-\theta)^2 /2, \theta = a b^D \beta
    gen_errs = xp.zeros(ticks, dtype=np.float64)
    for t in range(ticks):
        theta = a * b ** D * beta
        gen_errs[t] = xp.sum((theta - theta_star) ** 2)

        grad_L_theta = theta - z
        grad_theta_a = b ** D * beta
        grad_theta_b = a * D * b ** (D - 1) * beta
        grad_theta_beta = a * b ** D

        a -= dt * grad_L_theta * grad_theta_a
        b -= dt * grad_L_theta * grad_theta_b
        beta -= dt * grad_L_theta * grad_theta_beta

    return gen_errs  # shape (ticks,)


@dataclass
class ExprRes:
    """Results of repeatedExpr / repeatedExpr_parallel.

    The list-valued fields are aligned: entry k of each belongs to the k-th
    sample size ``ns[k]``.
    """

    ns: list[int]  # sample sizes n
    Ns: list[int]  # truncation dimension N used for each n
    b0_list: list[float]  # initial value of b for each n
    T_ada_list: list[float]  # training horizon T for each n
    # gen_err_list[k] stacks the error trajectories for ns[k]; as built by
    # repeatedExpr it has shape (repeats, ticks).
    gen_err_list: list[np.ndarray]
    # Run configuration; repeatedExpr stores keys such as "D", "repeats",
    # "f_decay", "f_gap", "lambda_decay", "ticks" here.
    meta: dict

    def __getitem__(self, item):
        """Select the experiment(s) at ``item`` from every list-valued field.

        An integer index yields scalar fields (e.g. ``ns`` becomes one int);
        a slice keeps them as lists. ``meta`` is shared, not copied.
        """
        return ExprRes(
            ns=self.ns[item],
            Ns=self.Ns[item],
            b0_list=self.b0_list[item],
            T_ada_list=self.T_ada_list[item],
            gen_err_list=self.gen_err_list[item],
            meta=self.meta,
        )


def repeatedExpr(ns, f_decay, f_gap, f_factor,
                 lambda_decay, lambda_factor=1.,
                 D=0, b0_factor=1., ticks=1000, repeats=10, N0=None,
                 T_factor=1.0, seed=1001, verbose=True):
    """
    Run ``repeats`` noisy experiments for every sample size in ``ns``.

    For each n the target is f = f_factor * slow_decay(N, f_decay, f_gap) and
    the spectrum is lambda_factor * decay_v(N, lambda_decay). Each repeat
    draws z = f + eps / sqrt(n) with standard normal eps and records the
    error trajectory of the selected training dynamics.

    :param ns: sample sizes n
    :param f_decay: decay rate of the target (``rate`` of slow_decay)
    :param f_gap: gap rate of the target (``gap_rate`` of slow_decay)
    :param f_factor: scale of the target
    :param lambda_decay: power of the spectrum, lambda_j = j ** -lambda_decay
    :param lambda_factor: scale applied to the spectrum
    :param D: -1 runs seqLinearGD_only_gen; any other value runs seq_abDb_gen
              with this D
    :param b0_factor: scale of the initial b (unused when D == -1)
    :param ticks: number of recorded time points per trajectory
    :param repeats: number of noise draws per n
    :param N0: truncation dimension; defaults to 2 * max(ns)
    :param T_factor: scale of the training horizon
    :param seed: seed of the numpy Generator that produces the noise
    :param verbose: print progress and timing
    :return: ExprRes whose gen_err_list[k] is a numpy array of shape
             (repeats, ticks)
    """
    meta = {
        "ns": ns, "repeats": repeats,
        "D": D, "b0_factor": b0_factor,
        "f_decay": f_decay, "f_gap": f_gap, "f_factor": f_factor,
        "lambda_decay": lambda_decay, "lambda_factor": lambda_factor,
        "ticks": ticks, "T_factor": T_factor
    }

    start_time = datetime.now()
    if verbose:
        print("Running: ", meta)
        print(f"Start time: {start_time:%m/%d %H:%M}")

    if N0 is None:
        N0 = 2 * max(ns)
    Ns = [N0] * len(ns)
    if D == -1:
        T_ada_list = [T_factor * n for n in ns]
        b0_list = [b0_factor for _ in ns]
    else:
        T_ada_list = [T_factor * n ** ((D + 1) / (D + 2)) for n in ns]
        b0_list = [b0_factor * n ** (-D / (2 * D + 2)) for n in ns]
    rng = np.random.default_rng(seed)

    results = []
    for n, N, T_Ada, b0 in zip(ns, Ns, T_ada_list, b0_list):
        start_time_n = datetime.now()
        f_abs = slow_decay(N, f_decay, f_gap) * f_factor
        # lambda_factor was previously accepted (and recorded in meta) but ignored.
        lambda_list = decay_v(N, lambda_decay) * lambda_factor

        gen_err_list = []
        for _ in range(repeats):
            # The Generator draws on the host; move the noise to xp's device so
            # the sum also works when xp is cupy (no-op for numpy).
            noise = xp.asarray(rng.standard_normal(N))
            z = f_abs + noise / math.sqrt(n)
            if D == -1:
                gen_errs = seqLinearGD_only_gen(z, T_Ada, lambda_list, f_abs, ticks=ticks)
            else:
                gen_errs = seq_abDb_gen(z, lambda_list, T_Ada, f_abs, D=D, b0=b0, ticks=ticks)
            if init.enable_GPU:
                # ExprRes stores numpy arrays; np.array() refuses implicit cupy conversion.
                gen_errs = xp.asnumpy(gen_errs)
            gen_err_list.append(gen_errs)
        # reshape keeps the (repeats, ticks) layout even when repeats == 0
        results.append(np.array(gen_err_list).reshape(repeats, ticks))
        if verbose:
            # drop fractional seconds; str(timedelta) omits them when zero
            elapsed = str(datetime.now() - start_time_n).split(".")[0]
            print(f"n={n}; Time={elapsed}")
    if verbose:
        print("Finished. Total time =", datetime.now() - start_time)
    return ExprRes(ns, Ns, b0_list, T_ada_list, results, meta)


def merge_results(res_list):
    """
    Concatenate the repeats of several ExprRes runs that share the same setup.

    ``ns``, ``Ns``, ``b0_list`` and ``T_ada_list`` are taken from the first
    result; for every n the error arrays are stacked along axis 0 (repeats).
    The inputs are not modified.

    :param res_list: non-empty sequence of ExprRes with matching ``ns``
    :return: merged ExprRes whose meta["repeats"] is the total repeat count
    :raises ValueError: if ``res_list`` is empty
    """
    if not res_list:
        raise ValueError("merge_results needs at least one ExprRes")
    first = res_list[0]
    # gen_err_list entries are numpy arrays (see ExprRes), so use numpy even
    # when xp is cupy.
    gen_err_list = [
        np.concatenate([res.gen_err_list[i] for res in res_list], axis=0)
        for i in range(len(first.ns))
    ]
    # Build a new dict: updating first.meta in place would corrupt the input.
    meta = {**first.meta, "repeats": sum(res.meta["repeats"] for res in res_list)}
    return ExprRes(first.ns, first.Ns, first.b0_list, first.T_ada_list, gen_err_list, meta)


def repeatedExpr_parallel(ns, f_decay, f_gap,
                          lambda_decay, f_factor=1.0, lambda_factor=1.,
                          D=2, b0_factor=1.0, ticks=1000,
                          repeats=10, N0=None,
                          T_factor=1.0, seed=1001, max_workers=None, verbose=True):
    """
    Run repeatedExpr with its repeats split across processes, then merge.

    The repeats are divided into m near-equal chunks. Chunk 0 runs in the
    calling process and chunks 1..m-1 in a process pool. Chunk i uses seed
    ``seed + 1000 * i``.

    :param max_workers: upper bound on m (defaults to the CPU count)
    :param verbose: print progress and timing
    :return: merged ExprRes (see merge_results)
    Other parameters are forwarded to repeatedExpr.
    """
    start_time = datetime.now()
    m = min(multiprocessing.cpu_count(), repeats)
    if max_workers is not None:
        m = min(m, max_workers)
    m = max(m, 1)  # repeats == 0 or max_workers == 0 would divide by zero below
    if verbose:
        print(f"Running in parallel, using {m} workers for {repeats} repeats.")
        print(f"Start time: {datetime.now():%m/%d %H:%M}")

    # divide the repeats into m parts whose sizes differ by at most one
    base, extra = divmod(repeats, m)
    repeats_list = [base + (1 if i < extra else 0) for i in range(m)]
    random_seeds = [seed + i * 1000 for i in range(m)]

    # The calling process runs chunk 0, so the pool only needs m - 1 workers;
    # previously the pool ignored m/max_workers and defaulted to the CPU count.
    with ProcessPoolExecutor(max_workers=max(m - 1, 1)) as executor:
        future_list = [
            executor.submit(repeatedExpr, ns, f_decay, f_gap, f_factor,
                            lambda_decay, lambda_factor,
                            D, b0_factor, ticks, repeats_list[i],
                            N0, T_factor, random_seeds[i], False) for i in range(1, m)]

        first = repeatedExpr(ns, f_decay, f_gap, f_factor, lambda_decay, lambda_factor,
                             D, b0_factor, ticks, repeats_list[0],
                             N0, T_factor, random_seeds[0], verbose=verbose)
        res_list = [first] + [future.result() for future in future_list]
    final_res = merge_results(res_list)

    if verbose:
        print("Parallel job finished. Total time =", datetime.now() - start_time)

    return final_res
