import numpy as np
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import os

# Input surfaces: 2-D, space-delimited; each row is compared independently.
TRAIN_SURFACE_PATH = "train_loss_surface.txt"
TEST_SURFACE_PATH = "test_loss_surface.txt"

# Total number of columns left out of the comparison window. The reference
# (train) window starts THROW_NUM // 2 columns into each row and is
# (n_columns - THROW_NUM) columns long.
THROW_NUM = 20

# Factor turning a column offset into the plotted x value (delta).
DISTANCES_SCALE = 0.1

# Column offsets of the test window relative to the train window.
SHIFT_OFFSETS = np.arange(9)

# Marker colour per offset (indexed by offset, wrapping if there are more
# offsets than colours).
COLOR_LIST = ['r', 'orange', 'y', 'yellowgreen', 'dodgerblue', 'g',
              'purple', 'pink', 'black']


def align_minima(train, test):
    """Return *test* shifted row-wise so each row's minimum equals *train*'s.

    Row ``i`` of the result is ``test[i] - min(test[i]) + min(train[i])``,
    which removes a constant vertical offset before the rows are compared.
    1-D inputs are treated as a single row.
    """
    train = np.atleast_2d(np.asarray(train, dtype=float))
    test = np.atleast_2d(np.asarray(test, dtype=float))
    return test - test.min(axis=1, keepdims=True) + train.min(axis=1, keepdims=True)


def max_abs_deviation(train, test_aligned, offset, throw_num=THROW_NUM):
    """Return the per-row maximum of |train window - shifted test window|.

    The train window is ``train[:, s:s + w]`` with ``s = throw_num // 2`` and
    ``w = train.shape[1] - throw_num``. The test window has the same length
    but starts ``offset`` columns later (earlier when negative).

    Raises:
        ValueError: if the window is empty, or the shifted window does not fit
            inside *test_aligned*.
    """
    train = np.atleast_2d(train)
    test_aligned = np.atleast_2d(test_aligned)
    window_len = train.shape[1] - throw_num
    start = throw_num // 2
    shifted = start + int(offset)
    if window_len <= 0:
        raise ValueError(
            f"throw_num={throw_num} leaves no columns of a "
            f"{train.shape[1]}-column surface to compare"
        )
    if shifted < 0 or shifted + window_len > test_aligned.shape[1]:
        raise ValueError(
            f"offset {offset} moves the test window outside the "
            f"{test_aligned.shape[1]} available columns"
        )
    reference = train[:, start:start + window_len]
    candidate = test_aligned[:, shifted:shifted + window_len]
    return np.abs(reference - candidate).max(axis=1)


def shift_loss_ratios(train, test, offsets=SHIFT_OFFSETS, throw_num=THROW_NUM):
    """Map each offset to the per-row ratio xi_delta / xi_0.

    xi_delta is max_abs_deviation() at that offset, computed after
    align_minima(). xi_0 is the same quantity at offset 0, so the baseline no
    longer depends on throw_num being even. A zero baseline yields inf/nan
    (NumPy emits a RuntimeWarning).

    Returns:
        dict mapping int offset -> 1-D array with one ratio per row.
    """
    aligned = align_minima(train, test)
    train = np.atleast_2d(np.asarray(train, dtype=float))
    baseline = max_abs_deviation(train, aligned, 0, throw_num)
    return {
        int(offset): max_abs_deviation(train, aligned, offset, throw_num) / baseline
        for offset in offsets
    }


def plot_ratios(ratios, distances_scale=DISTANCES_SCALE, colors=COLOR_LIST):
    """Scatter each offset's ratios at x = distances_scale * offset."""
    plt.subplots_adjust(bottom=.15, top=.99, left=.15, right=.99)
    plt.tick_params(labelsize=18)
    for offset, ratio in ratios.items():
        x = np.full(ratio.shape, distances_scale * offset)
        plt.scatter(x, ratio, c=colors[offset % len(colors)], marker='x', s=120)
    plt.ylabel(r'$\xi_{\bf{\delta}} /\xi_0 $', fontsize=25)
    plt.xlabel(r'$\bf{\delta}$', fontsize=25)
    # Fixed range sized for the default offsets 0..8 at scale 0.1.
    plt.xlim((-0.05, 0.85))


def main():
    """Load both surfaces, print diagnostics and write ratio_2.png / ratio_2.pdf."""
    # ndmin=2 keeps a single-row file 2-D so row-wise indexing still works.
    train = np.loadtxt(TRAIN_SURFACE_PATH, delimiter=' ', ndmin=2)
    test = np.loadtxt(TEST_SURFACE_PATH, delimiter=' ', ndmin=2)
    print(train.shape)

    ratios = shift_loss_ratios(train, test)
    for offset, ratio in ratios.items():
        print('offset', offset, 'ratio', ratio)

    plot_ratios(ratios)
    # Save once, after every offset has been drawn on the same axes.
    plt.savefig('ratio_2.png')
    plt.savefig('ratio_2.pdf')


if __name__ == "__main__":
    main()


