import numpy as np
import matplotlib.pyplot as plt
plt.switch_backend('agg')
import os

# For every row of the train/test loss-surface files, find the horizontal
# offset ("shift") of the test curve that best matches the training curve,
# print the per-row results, and save a plot of each row to
# parallel<row>.png / parallel<row>.pdf.
#
# NOTE(review): each row is assumed to be a 1-D loss profile sampled at evenly
# spaced points along one direction — confirm against the producer of these
# files.

# ndmin=2 keeps a single-row file 2-D so data_shape[1] below is always valid.
train_loss_surface = np.loadtxt("train_loss_surface.txt", delimiter=' ', ndmin=2)
test_loss_surface = np.loadtxt("test_loss_surface.txt", delimiter=' ', ndmin=2)
if train_loss_surface.shape != test_loss_surface.shape:
    raise ValueError(
        f"train/test loss surfaces differ in shape: "
        f"{train_loss_surface.shape} vs {test_loss_surface.shape}")

data_shape = train_loss_surface.shape

# Number of sample points trimmed from the training curve: half_throw from the
# start, the remainder from the end. The test curve is then compared against
# that window at every offset 0..throw_num.
throw_num = 20
half_throw = throw_num // 2

# Distance between neighbouring sample points on the plotted x axis.
distances_scale = 0.1
window_len = data_shape[1] - throw_num
if window_len <= 0:
    raise ValueError(
        f"each row needs more than throw_num={throw_num} points, "
        f"got {data_shape[1]}")

print(data_shape)

# Move each test row vertically so its minimum equals the minimum of the
# matching training row; the remaining mismatch is then mostly horizontal.
test_loss_surface_average = (
    test_loss_surface
    - test_loss_surface.min(axis=1, keepdims=True)
    + train_loss_surface.min(axis=1, keepdims=True)
)

# Best offset per row, the gap at that offset, and the gap at the centred
# (unshifted) offset half_throw. inf guarantees the first candidate is taken.
shift_data = np.zeros(data_shape[0], dtype=int)
shift_loss = np.full(data_shape[0], np.inf)
shift_loss_before = np.full(data_shape[0], np.nan)

# x coordinate of every sample point, centred on 0. The training window is
# plotted at its own sample positions, so both curves share one axis.
x_full = (np.arange(data_shape[1]) - (data_shape[1] - 1) / 2) * distances_scale

for parallel_num in range(data_shape[0]):
    train_window = train_loss_surface[parallel_num, half_throw:half_throw + window_len]
    test_row = test_loss_surface_average[parallel_num]

    for shift in range(throw_num + 1):
        # Largest absolute pointwise gap between the training window and the
        # test curve read from offset `shift`.
        max_gap = np.max(np.abs(train_window - test_row[shift:shift + window_len]))
        print('shift is', shift)
        print(max_gap)
        if shift == half_throw:
            shift_loss_before[parallel_num] = max_gap
        # Strict `<` keeps the smallest shift among equally good candidates.
        if max_gap < shift_loss[parallel_num]:
            shift_loss[parallel_num] = max_gap
            shift_data[parallel_num] = shift

    # Pass figsize directly (instead of mutating rcParams after the figure
    # exists) so every figure, including the first, gets the same size.
    fig = plt.figure(figsize=(5.0, 4.0))
    plt.subplots_adjust(bottom=.15, top=.99, left=.14, right=.99)
    plt.tick_params(labelsize=18)
    plt.plot(x_full[half_throw:half_throw + window_len],
             train_window,
             label=r'Training loss $\hat L$',
             color='dodgerblue',
             linewidth=3)
    plt.plot(x_full,
             test_row,
             label=r"Test loss $L'$",
             alpha=0.6,
             color='r',
             linewidth=3)
    # Test sample `shift + i` was matched with training sample `half_throw + i`,
    # so the shifted test curve is displaced by (half_throw - shift) samples.
    plt.plot(x_full + (half_throw - shift_data[parallel_num]) * distances_scale,
             test_row,
             label='Shifted test loss',
             linestyle='dashed',
             alpha=0.6,
             color='r',
             linewidth=3)

    # Arrow drawn from (1, 0.2) to (-1, 0.2) in data coordinates. Its position
    # is hard-coded and does not depend on the computed shift.
    plt.annotate("", xy=(-1, 0.2), xytext=(1, 0.2),
                 arrowprops=dict(arrowstyle="-|>, head_width=.4, head_length=.8",
                                 facecolor='r'))
    plt.text(-1.3, 0.25, 'Shift direction', ha='left', rotation=0, wrap=True, size=16)
    plt.legend(loc='upper right', fontsize=20)
    plt.xlabel(r'An asymmetric direction $\bf{u}$', fontsize=25)
    plt.ylabel('Loss', fontsize=25)

    fig_name = 'parallel' + str(parallel_num)
    plt.savefig(fig_name)
    plt.savefig(fig_name + '.pdf')
    # Release the figure; otherwise one figure per row stays open for the
    # whole run.
    plt.close(fig)

print(shift_loss)
print('---')
print(shift_data)
print(shift_loss_before)




