import numpy as np

from MarkovChains.MarkovChain import MarkovChain


def generate_worstcase2_mc(eps, beta, num_states=2, num_dimensions=100, seed=777, cov_eigengap_threshold=60):
    """Build the "worstcase2" MarkovChain with per-state means and covariances.

    The transition matrix keeps the current state with probability
    ``1 - eps`` and moves to each other state with probability
    ``eps / (num_states - 1)``, so every row sums to 1. State ``i`` gets the
    covariance ``exp(-c_i * |x - y|) * s_x * s_y``. Here
    ``c_i = 1 + 9 * i / num_states``, so higher-index states have faster
    decaying correlations, and ``s_k = 5 * (k + 1) ** -beta``.

    Args:
        eps: Total probability of leaving the current state in one step.
        beta: Exponent of the power-law per-dimension scale ``s_k``.
            Integer values are accepted.
        num_states: Number of states; must be at least 2.
        num_dimensions: Length of each mean vector / side of each covariance.
        seed: Seeds NumPy's *global* RNG (a side effect visible to callers)
            and is forwarded to ``MarkovChain``.
        cov_eigengap_threshold: Currently unused; kept for interface
            compatibility.

    Returns:
        The constructed ``MarkovChain``, after printing it via its
        ``print()`` method.

    Raises:
        ValueError: If ``num_states < 2``; the off-diagonal transition
            probability divides by ``num_states - 1``.
    """
    if num_states < 2:
        raise ValueError(f"num_states must be at least 2, got {num_states}")

    np.random.seed(seed)

    # Off-diagonal entries split eps equally; the diagonal holds the rest.
    transition_matrix = np.full((num_states, num_states), eps / (num_states - 1), dtype=float)
    np.fill_diagonal(transition_matrix, 1 - eps)

    # NOTE(review): ``means`` always has 10 rows, independent of num_states
    # (the default num_states=2 still yields a 10 x num_dimensions array).
    # Confirm MarkovChain tolerates this, or slice to num_states rows.
    # All values lie in [0, 0.05); presumably a frozen draw of
    # np.random.uniform(0, 0.05, 10) -- TODO confirm.
    means = np.array([[0.00913972, 0.02798253, 0.0339301 , 0.04385393, 0.01067654,
                       0.0423545 , 0.02732101, 0.00999117, 0.00352119, 0.00229116]])
    # (1, 10) -> (10, 1) -> (10, num_dimensions): row i is means[i] repeated.
    means = np.tile(means.transpose(), (1, num_dimensions))
    print("Shape of means : ", means.shape)

    covariance_matrices = [
        _decaying_covariance(num_dimensions, beta, 1 + (9 * i) / num_states)
        for i in range(num_states)
    ]

    # Random positive weights normalised into a probability vector.
    initial_distribution = np.abs(np.random.randn(num_states))
    initial_distribution /= np.sum(initial_distribution)

    markov_chain = MarkovChain(transition_matrix=transition_matrix,
                               means=means,
                               covariance_matrices=covariance_matrices,
                               initial_distribution=initial_distribution,
                               seed=seed)

    markov_chain.print()
    return markov_chain


def _decaying_covariance(num_dimensions, beta, c):
    """Return the (num_dimensions, num_dimensions) covariance for one state.

    Entry (x, y) is ``exp(-c * |x - y|) * s_x * s_y`` with
    ``s_k = 5 * (k + 1) ** -beta``.
    """
    idx = np.arange(num_dimensions)
    # Float base so that an integer beta does not hit NumPy's
    # "integers to negative integer powers" ValueError.
    scale = 5 * np.power(idx + 1.0, -beta)
    decay = np.exp(-np.abs(idx[:, None] - idx[None, :]) * c)
    # Same evaluation order as the element-wise formula: (decay * s_x) * s_y.
    cov = decay * scale[:, None] * scale[None, :]
    # (decay * s_x) * s_y and (decay * s_y) * s_x can differ in the last bit;
    # averaging with the transpose makes the matrix exactly symmetric.
    cov = (cov + np.transpose(cov)) / 2
    assert (cov == np.transpose(cov)).all()
    return cov
