# Copyright 2020 Self-Distillation Authors.


# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

# https://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# This code has been tested on Colab using Tensorflow version 1.15.0,
# Keras version 2.2.5, and numpy version 1.17.5

import tensorflow as tf
from tensorflow import keras
# Eager mode is required below: training drives tf.GradientTape directly and
# get_metrics calls .numpy() on the resulting tensors.
tf.enable_eager_execution()

# the main loss function for training with self-distillation
def self_distillation_loss(labels, logits, model, reg_coef, teacher=None, data=None):
  """Mean-squared error between softmax outputs, plus L2 weight decay.

  Without a teacher, the student's softmax is regressed onto `labels`. With a
  teacher, it is regressed onto the softmax of `teacher(data)` instead.

  Args:
    labels: target distributions (one-hot in this script); ignored when
      `teacher` is given.
    logits: student outputs for the batch; softmax is applied here.
      NOTE(review): the ResNet50 built in this file uses include_top=True, so
      these may already be probabilities and get softmaxed twice; confirm
      that is intended.
    model: student model whose trainable weights are L2-regularized.
    reg_coef: L2 coefficient multiplying sum(tf.nn.l2_loss(w)) over the
      model's trainable weights.
    teacher: optional callable mapping `data` to teacher outputs.
    data: inputs fed to `teacher`; required when `teacher` is given.

  Returns:
    Scalar loss tensor.

  Raises:
    ValueError: if `teacher` is given but `data` is None.
  """
  student_probs = tf.nn.softmax(logits)
  if teacher is None:
    targets = labels
  else:
    if data is None:
      raise ValueError('`data` is required when a teacher is given')
    # Teacher predictions are fixed targets; never backpropagate into them.
    targets = tf.stop_gradient(tf.nn.softmax(teacher(data)))
  main_loss = tf.reduce_mean(tf.squared_difference(targets, student_probs))

  weights = model.trainable_weights
  if weights:
    reg_loss = reg_coef * tf.add_n([tf.nn.l2_loss(w) for w in weights])
  else:
    # tf.add_n rejects an empty list; a weightless model has no decay term.
    reg_loss = 0.0
  return main_loss + reg_loss

def get_metrics(model, x_test, y_test, teacher=None, reg_coef=None):
  """Evaluate `model` on a held-out set.

  Args:
    model: Keras model; the output of its `predict` is passed to the loss as
      `logits`.
    x_test: evaluation inputs.
    y_test: one-hot evaluation labels (argmax over axis 1 gives the class).
    teacher: optional teacher model. When given, the reported loss compares
      the model against the teacher's predictions instead of `y_test`.
    reg_coef: L2 coefficient used in the reported loss. Defaults to the
      module-level `reg_coef` setting, which is what this function always
      used before the parameter existed.

  Returns:
    (accuracy, loss) as numpy scalars.
  """
  if reg_coef is None:
    # Preserve the historical behaviour of reading the script-level setting.
    reg_coef = globals()['reg_coef']

  y_test_pred = model.predict(x_test)
  true_classes = tf.argmax(y_test, axis=1)
  correct = tf.nn.in_top_k(y_test_pred, true_classes, k=1)
  acc = tf.reduce_mean(tf.cast(correct, tf.float32))

  teacher_fn = None
  if teacher is not None:
    # Calling teacher(x_test) would push the whole evaluation set through the
    # network at once; a Keras model's predict() runs it in mini-batches.
    teacher_fn = getattr(teacher, 'predict', teacher)
  loss = self_distillation_loss(y_test, y_test_pred, model, reg_coef,
                                teacher_fn, x_test)
  return acc.numpy(), loss.numpy()
  
def get_cifar10():
  """Load CIFAR-10 as float32 images divided by 255 and one-hot labels.

  Returns:
    (x_train, y_train, x_test, y_test): images cast to float32 and scaled
    by 1/255, labels one-hot encoded over 10 classes.
  """
  num_classes = 10
  (train_images, train_labels), (test_images, test_labels) = (
      keras.datasets.cifar10.load_data())

  def _scale(images):
    return images.astype('float32') / 255.

  def _one_hot(labels):
    return keras.utils.to_categorical(labels, num_classes)

  return (_scale(train_images), _one_hot(train_labels),
          _scale(test_images), _one_hot(test_labels))
  
def get_resnet_model():
  """Build an untrained ResNet50 classifier for 32x32x3 inputs and 10 classes.

  NOTE(review): with include_top=True Keras appends its own classification
  head. If that head applies a softmax, the `logits` later passed to
  self_distillation_loss are already probabilities and get softmaxed again;
  confirm this is intended.
  """
  resnet_config = dict(
      include_top=True,
      weights=None,  # no pretrained weights are loaded
      input_tensor=None,
      input_shape=[32, 32, 3],
      classes=10,
  )
  return keras.applications.resnet.ResNet50(**resnet_config)

# the main training procedure for single step of self-distillation
def self_distillation_train(model, train_dataset, optimizer, reg_coef=1e-5,
                            epochs=60, teacher=None, verbose=None,
                            eval_data=None, eval_every=2):
  """Train `model` in place for one self-distillation step.

  Args:
    model: Keras model to train.
    train_dataset: iterable of (x_batch, y_batch) pairs, e.g. a batched
      tf.data.Dataset.
    optimizer: optimizer whose `apply_gradients` updates `model`.
    reg_coef: L2 coefficient passed to self_distillation_loss.
    epochs: number of passes over `train_dataset`.
    teacher: optional teacher model. When given, the student is fit to the
      teacher's predictions instead of the labels.
    verbose: whether to report test metrics periodically. Defaults to the
      module-level `verbose` setting (the original behaviour); False if that
      global does not exist.
    eval_data: (x, y) pair used for the periodic report. Defaults to the
      module-level `x_test`, `y_test` (the original behaviour).
    eval_every: report every `eval_every` epochs, starting at epoch 0; must
      be positive.

  Returns:
    The trained `model` (the same object that was passed in).
  """
  if verbose is None:
    verbose = globals().get('verbose', False)

  for epoch in range(epochs):
    print('Start of epoch %d' % (epoch,))
    for x_batch_train, y_batch_train in train_dataset:
      batch_teacher = None
      if teacher is not None:
        # Run the teacher outside the tape. Its output is a constant target,
        # so recording its forward pass would only cost memory.
        teacher_out = teacher(x_batch_train)
        batch_teacher = lambda _inputs, out=teacher_out: out
      with tf.GradientTape() as tape:
        logits = model(x_batch_train, training=True)
        loss_value = self_distillation_loss(y_batch_train, logits, model,
                                            reg_coef, batch_teacher,
                                            x_batch_train)
      grads = tape.gradient(loss_value, model.trainable_weights)
      optimizer.apply_gradients(zip(grads, model.trainable_weights))

    if verbose and epoch % eval_every == 0:
      if eval_data is None:
        eval_data = (globals()['x_test'], globals()['y_test'])
      x_eval, y_eval = eval_data
      acc, loss = get_metrics(model, x_eval, y_eval, teacher)
      # get_metrics evaluates the whole evaluation set, not a single batch.
      print('epoch %d test accuracy %s and loss %s (full test set)'
            % (epoch, acc, loss))
  return model

# hyperparameters
# NOTE: these are module-level globals that other functions read by name
# (get_metrics uses `reg_coef`; self_distillation_train uses `verbose`,
# `x_test` and `y_test`), so keep the names stable.
batch_size = 16
epochs = 20  # ~64000 iterations * 16 images per batch / 50000 CIFAR-10 training images
reg_coef = 1e-4
learning_rate = 1e-3

# self-distillation parameters
distillation_steps = 10  # number of models trained in sequence
verbose = True  # print test metrics periodically during training

# reading data
x_train, y_train, x_test, y_test = get_cifar10()
train_dataset = tf.data.Dataset.from_tensor_slices((x_train, y_train))
# NOTE(review): a 1024-element shuffle buffer only partially mixes the
# training set between epochs; confirm this is intended.
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(batch_size)

# self-distillation steps: step 0 fits the labels (teacher is None); each later
# step trains a freshly built model to match the previous step's model.
teacher = None
test_accs = []  # per-step test accuracy; not consumed elsewhere in the visible code
for step in range(distillation_steps):
  model = get_resnet_model()
  # A fresh optimizer per step so no Adam state carries over from the teacher.
  optimizer = keras.optimizers.Adam(learning_rate=learning_rate)
  model = self_distillation_train(model, train_dataset, optimizer, 
                                  reg_coef, epochs, teacher)
  # For step > 0 the reported loss is measured against the teacher's
  # predictions, not the ground-truth labels; accuracy always uses the labels.
  acc, loss = get_metrics(model, x_test, y_test, teacher)
  test_accs.append(acc)
  print('distillation step %d, test accuracy %f and loss %f ' % (step, acc, loss))
  # This step's student becomes the next step's teacher.
  teacher = model
  