import matplotlib.pyplot as plt
import numpy as np

from sklearn.datasets import load_breast_cancer
from sklearn.tree import DecisionTreeClassifier
from sklearn.metrics import accuracy_score

from sklearn.model_selection import train_test_split
from sklearn.model_selection import learning_curve


################### policy learning
# Hard-coded results for the "policy learning" figure.
#
# `train_sizes` holds the x-axis tick labels as strings. The labels stand for
# 0 to 4,000,000 time-steps in increments of 500,000. The plotting code below
# draws every entry except the last one and passes the last one to
# `set_xlim`, so the trailing empty string is a padding entry, not a data
# point.
train_sizes = [
    "0",
    "0.5M",
    "1M",
    "1.5M",
    "2M",
    "2.5M",
    "3M",
    "3.5M",
    "4M",
    '',  # padding entry: never plotted, only used as the upper x-limit
]

# Each table maps a baseline name to one value per plotted label
# (len(train_sizes) - 1 values, aligned with train_sizes[:-1]).
# The *_means arrays are drawn as lines; the *_stds arrays give the
# half-width of the shaded band around each line.
train_scores_means = {
    "EmbCLIP":  np.array([0, 73.03, 81.58, 88.82, 89.47, 88.82, 97.37, 94.08, 95.39]),
    "ATC":      np.array([0, 67.76, 75.66, 78.95, 78.95, 77.63, 88.16, 83.55, 77.63]),
    "ConPE":    np.array([0, 76.32, 80.26, 90.79, 96.05, 97.37, 95.39, 94.08, 95.39]),
}
train_scores_stds = {
    "EmbCLIP":  np.array([0, 16.79, 24.4, 8.8, 9.12, 12.11, 2.63, 7.53, 2.18]),
    "ATC":      np.array([0, 10.59, 7.76, 24.48, 18.42, 32.68, 14.59, 22.44, 32.79]),
    "ConPE":    np.array([0, 6.71, 16.06, 5.43, 2.94, 1.86, 3.89, 7.3, 3.89]),
}

test_scores_means = {
    "EmbCLIP":  np.array([0, 59.17, 72.76, 69.08, 63.03, 69.04, 75.74, 68.56, 69.78]),
    "ATC":      np.array([0, 57.02, 64.43, 63.25, 59.12, 57.68, 65.88, 61.32, 54.17]),
    "ConPE":    np.array([0, 63.86, 67.63, 83.03, 77.72, 83.56, 83.2, 76.19, 76.27]),
}
test_scores_stds = {
    "EmbCLIP":  np.array([0, 18.43, 18.1, 18.33, 21.62, 16.84, 17.87, 18.65, 18.14]),
    "ATC":      np.array([0, 13.39, 14.92, 22.97, 16.99, 22.84, 19.36, 22.01, 23.53]),
    "ConPE":    np.array([0, 11.09, 14.71, 13.81, 12.03, 10.75, 8.67, 12.43, 12.68]),
}

# Order in which the curves are drawn; the plotting code indexes into this
# list to pick the table keys and the legend labels.
baselines = ["EmbCLIP", "ATC", "ConPE"]


######################### prompt ensemble adaptation
# train_sizes = [
#     0,
#     100_000,
#     200_000,
#     300_000,
#     400_000,
#     500_000,
#     600_000,
#     700_000,
#     800_000,
# ]

# train_sizes = [
#     "0",
#     "0.1M",
#     "0.2M",
#     "0.3M",
#     "0.4M",
#     "0.5M",
#     "0.6M",
#     "0.7M",
#     "0.8M",
#     " "
# ]

# train_scores_means = {
#     "Pretrained":         np.array([95.27, 95.27, 95.27, 95.27, 95.27, 95.27, 95.27, 95.27, 95.27, 95.27]),
#     "ATTEMPT":  np.array([0, 20., 20., 20.83, 20., 20., 20., 20., 20.]),
#     "SESoM":    np.array([0, 22.15, 27.02, 23.79, 24.84, 22.31, 23.12, 23.79, 25.48]),
#     "ConPE":    np.array([0, 100., 98.33, 96.67, 97.5, 100., 97.5, 99.17, 100]),
# }
# train_scores_stds = {
#     "Pretrained":         np.array([4.59, 4.59, 4.59, 4.59, 4.59, 4.59, 4.59, 4.59, 4.59]),
#     "ATTEMPT":  np.array([0, 0., 0., 1.44, 0., 0., 0., 0., 0.]),
#     "SESoM":    np.array([0, 2.87, 5.12, 3.74, 3.17, 1.37, 3.13, 1.69, 5.22]),
#     "ConPE":    np.array([0, 0., 2.89, 4.08, 2.76, 0., 2.76, 1.44, 0.]),
# }

# test_scores_means = {
#     "Pretrained":         np.array([80.94, 80.94, 80.94, 80.94, 80.94, 80.94, 80.94, 80.94, 80.94, 80.94,]),
#     "ATTEMPT":  np.array([0, 20.56, 20.61, 20.39, 21., 20.23, 21.06, 20.2, 20.68]),
#     "SESoM":    np.array([0, 24.79, 24.94, 23.7, 24.07, 24.24, 23.97, 23.84, 23.41]),
#     "ConPE":    np.array([0, 83.09, 85.42, 82.24, 84.55, 85.23, 83.81, 83.07, 83.59]),
# }
# test_scores_stds = {
#     "Pretrained":         np.array([1.64, 1.64, 1.64, 1.64, 1.64, 1.64, 1.64, 1.64, 1.64]),
#     "ATTEMPT":  np.array([0, 1.23, 1.24, 1.07, 1.98, 0.7, 1.89, 0.79, 1.61]),
#     "SESoM":    np.array([0, 3.31, 3.06, 3.19, 3.12, 3.74, 2.96, 2.65, 2.16]),
#     "ConPE":    np.array([0, 9.75, 8.76, 10.61, 9.34, 8.68, 10.28, 10.79, 8.38]),
# }

# baselines = ["ATTEMPT", "SESoM", "ConPE", "Pretrained"]



# Line/band colour for each curve, paired positionally with the curve names.
_CURVE_COLORS = ["r", "g", "b"]


def _draw_panel(ax, tick_labels, names, means, stds, *, show_legend=False):
    """Draw one mean +/- std learning-curve panel on `ax`.

    Args:
        ax: The matplotlib Axes to draw on.
        tick_labels: X-axis labels. Every entry except the last is used as an
            x value; the last entry is only passed to `set_xlim` as the upper
            limit.
        names: Curve names, used both as keys into `means`/`stds` and as
            legend labels. Only as many curves as there are entries in
            `_CURVE_COLORS` are drawn.
        means: Mapping from curve name to an array of mean values, one per
            plotted x value.
        stds: Mapping from curve name to an array of standard deviations with
            the same length as the matching `means` array.
        show_legend: If True, add a legend to this panel.
    """
    x_values = tick_labels[:-1]

    ax.grid()
    for name, color in zip(names, _CURVE_COLORS):
        center = means[name]
        spread = stds[name]
        # Shaded band first, then the line with markers on top of it.
        ax.fill_between(x_values, center - spread, center + spread, alpha=0.1, color=color)
        ax.plot(x_values, center, 'o-', color=color, label=name, linewidth=3, markersize=12)

    # Optional dashed reference line for a fourth, constant "Pretrained" entry
    # (only present in the commented-out prompt-ensemble dataset):
    # ax.plot(tick_labels, means[names[3]], '--', color="black", label=names[3], linewidth=3)

    if show_legend:
        ax.legend(loc="best")

    ax.set_xlabel("Time-steps")
    ax.set_ylabel("Success Rate")
    # No title is set at the moment; the size applies if one is added later.
    ax.title.set_size(30)
    ax.xaxis.label.set_size(26)
    ax.yaxis.label.set_size(26)
    ax.tick_params(axis="x", labelsize=22)
    ax.tick_params(axis="y", labelsize=22)

    # The x values are strings, so matplotlib handles the x axis as a
    # categorical axis. Passing the trailing (unplotted) label as the upper
    # limit must happen after plotting; it presumably leaves one empty slot to
    # the right of the last data point -- confirm visually if the labels change.
    ax.set_xlim([0, tick_labels[-1]])
    # Scores are plotted on a 0-100 scale with a little headroom at the top.
    ax.set_ylim([0, 100+5])
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)


fig, axes = plt.subplots(1, 2, figsize=(26, 7))
plt.rc('legend', fontsize=26)  # legend font size
# plt.style.use('ggplot')

# Left panel: train scores (with the legend). Right panel: test scores.
_draw_panel(axes[0], train_sizes, baselines, train_scores_means, train_scores_stds, show_legend=True)
_draw_panel(axes[1], train_sizes, baselines, test_scores_means, test_scores_stds)

fig.savefig("ded.png")