from itertools import product


def get_marginals(t, var_indices, joint_density):
    """Return the probability that the variables at ``var_indices`` equal ``t``.

    ``joint_density`` maps binary assignment tuples (one entry per variable,
    each 0 or 1) to probabilities.  Positions listed in ``var_indices`` are
    fixed to their values in ``t``; every other position is summed out over
    {0, 1}.  The number of variables is taken from ``len(t)`` (4 for the
    ``(x1, x2, y1, y2)`` tuples used by ``iid_causal_effect_computation``).

    Raises ``KeyError`` if ``joint_density`` lacks an assignment that must be
    summed, and ``IndexError`` if an index in ``var_indices`` is out of range
    for ``t``.
    """
    domains = [[0, 1] for _ in t]
    for i in var_indices:
        domains[i] = [t[i]]
    # Same enumeration order as itertools.product, so the float sum is stable.
    return sum(joint_density[key] for key in product(*domains))


def _conditional(p_joint, p_parent, child, parent, parent_intervened):
    """Return p(child | parent) = p_joint / p_parent for one x_i/y_i pair.

    If p_parent is 0 and the parent keeps its own factor, that factor is 0 and
    already zeroes the whole product, so 0.0 is returned instead of dividing.
    If the parent was intervened on, its factor is replaced by 1 and the
    conditional is genuinely undefined (positivity violation), so
    ZeroDivisionError is raised with an explanatory message.
    """
    if p_parent:
        return p_joint / p_parent
    if parent_intervened:
        raise ZeroDivisionError(
            'p(' + child + ' | ' + parent + ') is undefined: the intervened '
            'value of ' + parent + ' has zero probability in joint_density'
        )
    return 0.0


def _pair_factors(x_name, y_name, p_x, p_y, p_xy, dag, intervened):
    """Return the truncated-factorization factors (f_x, f_y) for one x_i/y_i pair.

    An intervened variable's mechanism is removed, so its factor is 1.
    Otherwise a root node contributes its marginal, and under 'xtoy'/'ytox'
    the child node contributes its conditional given the other variable.
    Conditionals are only computed when actually used.
    """
    x_fixed = x_name in intervened
    y_fixed = y_name in intervened
    if dag == 'xtoy':
        f_x = 1 if x_fixed else p_x
        f_y = 1 if y_fixed else _conditional(p_xy, p_x, y_name, x_name, x_fixed)
    elif dag == 'ytox':
        f_x = 1 if x_fixed else _conditional(p_xy, p_y, x_name, y_name, y_fixed)
        f_y = 1 if y_fixed else p_y
    else:  # 'xindy'
        f_x = 1 if x_fixed else p_x
        f_y = 1 if y_fixed else p_y
    return f_x, f_y


def iid_causal_effect_computation(joint_density, intervened_var, intervened_val, dag):
    """Compute an interventional distribution by truncated factorization.

    Variables are ordered ``(x1, x2, y1, y2)``.  Under ``dag`` the pairs
    ``(x1, y1)`` and ``(x2, y2)`` each follow the same structure: 'xtoy'
    (x_i -> y_i), 'ytox' (y_i -> x_i) or 'xindy' (x_i independent of y_i).

    Args:
        joint_density: observed joint, a dict mapping each binary 4-tuple
            ``(x1, x2, y1, y2)`` to its probability.
        intervened_var: a variable name, e.g. ``'x1'``, or a list of names,
            e.g. ``['x1', 'y2']``.
        intervened_val: the intervention value, e.g. ``0``, or a list of
            values parallel to ``intervened_var``, e.g. ``[0, 1]``.
        dag: ``'xtoy'``, ``'ytox'``, or ``'xindy'``.

    Returns:
        ``{desc: {t: p}}`` where ``desc`` is e.g. ``'do(x1)=0'`` (or
        ``'do(x1,y2)=[0, 1]'`` for lists) and ``p`` is the interventional
        probability of assignment ``t``.  Assignments that contradict an
        intervened value get 0.

    Raises:
        ValueError: unknown ``dag`` or name/value lists of different length.
        KeyError: unknown variable name or entry missing from ``joint_density``.
        ZeroDivisionError: a conditional given an intervened parent value with
            zero observed probability is required (positivity violation).
    """
    if dag not in ('xtoy', 'ytox', 'xindy'):
        raise ValueError("dag must be 'xtoy', 'ytox', or 'xindy', got %r" % (dag,))
    var_to_idx = {'x1': 0, 'x2': 1, 'y1': 2, 'y2': 3}
    if isinstance(intervened_var, str):
        names, values = [intervened_var], [intervened_val]
    else:
        names, values = list(intervened_var), list(intervened_val)
        if len(names) != len(values):
            raise ValueError('intervened_var and intervened_val must have the same length')
    # For a single variable this is exactly 'do(<var>)=<val>'.
    intervention_desc = 'do(' + ','.join(names) + ')' + '=' + str(intervened_val)
    fixed = [(var_to_idx[name], value) for name, value in zip(names, values)]
    intervened = set(names)

    distribution = {}
    for t in product([0, 1], repeat=4):
        # Any assignment contradicting an intervened value has probability 0.
        if any(t[idx] != value for idx, value in fixed):
            distribution[t] = 0
            continue
        p_x1 = get_marginals(t, [0], joint_density)
        p_x2 = get_marginals(t, [1], joint_density)
        p_y1 = get_marginals(t, [2], joint_density)
        p_y2 = get_marginals(t, [3], joint_density)
        p_x1_y1 = get_marginals(t, [0, 2], joint_density)
        p_x2_y2 = get_marginals(t, [1, 3], joint_density)
        fx1, fy1 = _pair_factors('x1', 'y1', p_x1, p_y1, p_x1_y1, dag, intervened)
        fx2, fy2 = _pair_factors('x2', 'y2', p_x2, p_y2, p_x2_y2, dag, intervened)
        # Same multiplication order as before, so float results are unchanged.
        distribution[t] = fx1 * fx2 * fy1 * fy2

    return {intervention_desc: distribution}
