import tensorflow as tf
import os
# Default side length (pixels) of the square images produced by this pipeline.
SIZE = 224
# Ratio used by decode_and_center_crop to derive its crop padding.
CROP_FRACTION = 0.875

# Per-channel mean / standard deviation constants, stored as bfloat16 tensors.
MEAN_IMAGENET = tf.constant(
    [0.485, 0.456, 0.406], dtype=tf.bfloat16, shape=[3])
STD_IMAGENET = tf.constant(
    [0.229, 0.224, 0.225], dtype=tf.bfloat16, shape=[3])
# Reciprocals, so that normalisation can multiply instead of divide.
DIVISOR = tf.cast(1.0 / 255.0, tf.bfloat16)
STD_DIVISOR = tf.cast(1.0 / STD_IMAGENET, tf.bfloat16)
# Lets tf.data choose parallelism / prefetch depth dynamically.
AUTO = tf.data.AUTOTUNE
# Schema of the serialized tf.train.Example records parsed by this module.
IMAGENET_FEATURE_DESCRIPTION = {
    'image/encoded': tf.io.FixedLenFeature(shape=[], dtype=tf.string),
    'image/class/label': tf.io.FixedLenFeature(shape=[], dtype=tf.int64),
}

def mixup(images, labels, proba=0.5):
    """Blend the two samples of a size-2 batch into each other.

    For each of the two samples a mixing weight is drawn uniformly from
    [0, 1) and then zeroed unless an independent uniform draw is <= `proba`.
    A zero weight leaves that sample (and its label) unchanged.

    Args:
        images: Batch whose entries 0 and 1 are the two images. The output is
            reshaped to (2, s, s, 3) where s is the static size of the first
            image's leading dimension, so square 3-channel images with a
            known static shape are expected. The mixing weights are bfloat16,
            so the images presumably need to be bfloat16 too -- confirm
            against the caller.
        labels: Batch whose entries 0 and 1 are the matching label vectors;
            the output is reshaped to (2, 1000).
        proba: Threshold controlling how often mixing is applied to a sample.

    Returns:
        Tuple `(images, labels)` holding the two blended samples.
    """
    first_image, second_image = images[0], images[1]
    first_label, second_label = labels[0], labels[1]

    side = first_image.shape[0]

    # Gates are 1.0 when mixing is enabled for the sample and 0.0 otherwise.
    # The four random ops are created in the same order as before.
    gate_1 = tf.cast(tf.random.uniform([], 0, 1) <= proba, tf.bfloat16)
    gate_2 = tf.cast(tf.random.uniform([], 0, 1) <= proba, tf.bfloat16)
    weight_1 = tf.random.uniform([], 0, 1, dtype=tf.bfloat16) * gate_1
    weight_2 = tf.random.uniform([], 0, 1, dtype=tf.bfloat16) * gate_2

    mixed_images = tf.stack([
        (1. - weight_1) * first_image + weight_1 * second_image,
        (1. - weight_2) * second_image + weight_2 * first_image,
    ])
    mixed_labels = tf.stack([
        (1. - weight_1) * first_label + weight_1 * second_label,
        (1. - weight_2) * second_label + weight_2 * first_label,
    ])

    # The reshapes pin down the static shapes of the outputs.
    mixed_images = tf.reshape(mixed_images, (2, side, side, 3))
    mixed_labels = tf.reshape(mixed_labels, (2, 1000))
    return mixed_images, mixed_labels

def flip_left_right(image, label):
    """Randomly mirror `image` horizontally; `label` is passed through.

    A fresh two-element int32 seed (values below 1000) is drawn for every
    call and handed to the stateless flip op.

    Returns:
        Tuple `(image, label)` with the possibly flipped image.
    """
    flip_seed = tf.random.uniform(shape=[2], maxval=1_000, dtype=tf.int32)
    flipped = tf.image.stateless_random_flip_left_right(image, flip_seed)
    return flipped, label



def _distorted_bounding_box_crop(image_bytes,
                                 bbox,
                                 min_object_covered=0.1,
                                 aspect_ratio_range=(0.75, 1.33),
                                 area_range=(0.05, 1.0),
                                 max_attempts=100,
                                 scope=None):
    """Decode a randomly sampled crop of a JPEG without decoding the rest.

    Args:
        image_bytes: Encoded JPEG string tensor.
        bbox: Bounding boxes forwarded to
            `tf.image.sample_distorted_bounding_box`.
        min_object_covered: Forwarded to the sampler.
        aspect_ratio_range: Forwarded to the sampler.
        area_range: Forwarded to the sampler.
        max_attempts: Forwarded to the sampler.
        scope: Unused; kept for interface compatibility.

    Returns:
        The decoded 3-channel crop.
    """
    jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
    begin, size, _ = tf.image.sample_distorted_bounding_box(
        jpeg_shape,
        bounding_boxes=bbox,
        min_object_covered=min_object_covered,
        aspect_ratio_range=aspect_ratio_range,
        area_range=area_range,
        max_attempts=max_attempts,
        use_image_if_no_bounding_boxes=True)

    # Drop the channel component: the crop window is
    # [offset_y, offset_x, height, width].
    top, left, _ = tf.unstack(begin)
    height, width, _ = tf.unstack(size)
    window = tf.stack([top, left, height, width])

    return tf.image.decode_and_crop_jpeg(image_bytes, window, channels=3)


def decode_and_center_crop(image_bytes, image_size):
    """Decode the central square of a JPEG and resize it to `image_size`.

    The square's side is the shorter image side scaled by
    `image_size / (image_size + crop_padding)`, where `crop_padding` is
    derived from `CROP_FRACTION`.

    Args:
        image_bytes: Encoded JPEG string tensor.
        image_size: Target side length; used in Python arithmetic (`round`),
            so it is expected to be a plain number rather than a tensor.

    Returns:
        The cropped image, bicubically resized to
        `[image_size, image_size]`.
    """
    jpeg_shape = tf.image.extract_jpeg_shape(image_bytes)
    height, width = jpeg_shape[0], jpeg_shape[1]

    # crop_fraction = image_size / (image_size + crop_padding)
    crop_padding = round(image_size * (1/CROP_FRACTION - 1))
    scale = image_size / (image_size + crop_padding)
    shorter_side = tf.cast(tf.minimum(height, width), tf.float32)
    crop_side = tf.cast(scale * shorter_side, tf.int32)

    # Centre the square window inside the image.
    top = ((height - crop_side) + 1) // 2
    left = ((width - crop_side) + 1) // 2
    window = tf.stack([top, left, crop_side, crop_side])

    cropped = tf.image.decode_and_crop_jpeg(image_bytes, window, channels=3)
    resized = tf.image.resize(
        [cropped], [image_size, image_size], method='bicubic')
    return resized[0]

def _at_least_x_are_equal(a, b, x):
    """Return a boolean tensor: do `a` and `b` agree in at least `x` positions?"""
    num_matching = tf.reduce_sum(tf.cast(tf.equal(a, b), tf.int32))
    return tf.greater_equal(num_matching, x)

def decode_and_random_crop(image_bytes, image_size):
    """Decode a random crop of a JPEG, resized to `image_size`.

    A crop is sampled against a bounding box covering the whole image. If
    the decoded crop has the same shape as the full image (all three
    dimensions match), the function falls back to
    `decode_and_center_crop`.

    Args:
        image_bytes: Encoded JPEG string tensor.
        image_size: Target side length of the square output.

    Returns:
        The crop resized to `[image_size, image_size]`.
    """
    whole_image_box = tf.constant(
        [0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4])
    random_crop = _distorted_bounding_box_crop(
        image_bytes,
        whole_image_box,
        min_object_covered=0.1,
        aspect_ratio_range=(3. / 4, 4. / 3.),
        area_range=(0.08, 1.0),
        max_attempts=10,
        scope=None)

    full_shape = tf.image.extract_jpeg_shape(image_bytes)
    crop_is_whole_image = _at_least_x_are_equal(
        full_shape, tf.shape(random_crop), 3)

    def _center_crop_fallback():
        return decode_and_center_crop(image_bytes, image_size)

    def _resize_random_crop():
        resized = tf.image.resize(
            [random_crop], [image_size, image_size], method='bicubic')
        return resized[0]

    return tf.cond(
        crop_is_whole_image, _center_crop_fallback, _resize_random_crop)



def normalize(image):
    """Scale by 1/255, subtract `MEAN_IMAGENET`, divide by `STD_IMAGENET`.

    All arithmetic is done in bfloat16; the division by the standard
    deviation is carried out as a multiplication by `STD_DIVISOR`.

    Returns:
        The normalised bfloat16 image.
    """
    pixels = tf.cast(image, tf.bfloat16)
    unit_range = pixels * DIVISOR
    centered = unit_range - MEAN_IMAGENET
    return centered * STD_DIVISOR

def normalize_vgg(image):
    """Cast to bfloat16 and subtract `MEAN_IMAGENET * 255`.

    Unlike `normalize`, the pixels are neither rescaled by 1/255 nor divided
    by the standard deviation.

    Returns:
        The mean-subtracted bfloat16 image.
    """
    pixels = tf.cast(image, tf.bfloat16)
    mean_at_pixel_scale = MEAN_IMAGENET * 255.0
    return pixels - mean_at_pixel_scale

def _apply_mixup(dataset):
    """Apply `mixup` to random pairs of elements of `dataset`.

    Elements are shuffled (buffer of 1024), grouped in pairs, mixed, split
    back into single elements and shuffled again.

    Returns:
        The resulting unbatched dataset.
    """
    pairs = dataset.shuffle(1024).batch(2)
    mixed_pairs = pairs.map(mixup, num_parallel_calls=AUTO)
    return mixed_pairs.unbatch().shuffle(1024)


def _preprocess_for_train(image_bytes, image_size=SIZE):
    """Random-crop an encoded JPEG and apply `normalize_vgg`.

    Args:
        image_bytes: Encoded JPEG string tensor.
        image_size: Side length of the square output.

    Returns:
        A `[image_size, image_size, 3]` bfloat16 image.
    """
    cropped = decode_and_random_crop(image_bytes, image_size)
    # The reshape fixes the static shape of the crop.
    cropped = tf.reshape(cropped, [image_size, image_size, 3])
    return normalize_vgg(cropped)


def _preprocess_for_eval(image_bytes, image_size=SIZE):
    """Center-crop an encoded JPEG and cast it to bfloat16.

    No mean subtraction or scaling is applied here.

    Args:
        image_bytes: Encoded JPEG string tensor.
        image_size: Side length of the square output.

    Returns:
        A `[image_size, image_size, 3]` bfloat16 image.
    """
    cropped = decode_and_center_crop(image_bytes, image_size)
    # The reshape fixes the static shape of the crop.
    cropped = tf.reshape(cropped, [image_size, image_size, 3])
    return tf.cast(cropped, tf.bfloat16)


def _init_shards(shards, training=False):
    """Open TFRecord shards as an endlessly repeating, unordered dataset.

    Args:
        shards: TFRecord file names to read.
        training: Currently unused.

    Returns:
        A dataset of serialized records that repeats indefinitely and does
        not guarantee a deterministic element order.
    """
    unordered = tf.data.Options()
    unordered.experimental_deterministic = False

    records = tf.data.TFRecordDataset(
        shards,
        buffer_size=8 * 1024 * 1024,
        num_parallel_reads=AUTO)
    return records.with_options(unordered).repeat()


def _parse_imagenet_prototype(prototype, size, training=False):
    """Parse one serialized ImageNet example into `(image, one_hot_label)`.

    Args:
        prototype: Serialized example matching
            `IMAGENET_FEATURE_DESCRIPTION`.
        size: Side length of the square output image.
        training: Currently has no effect (see the note in the body).

    Returns:
        Tuple of the preprocessed image and a 1000-way one-hot bfloat16
        label.
    """
    example = tf.io.parse_single_example(
        prototype, IMAGENET_FEATURE_DESCRIPTION)

    # Stored labels are in [1, 1000]; shift them to [0, 999]. Labels are
    # always one-hot so that mixup can blend them.
    class_index = tf.cast(example['image/class/label'], tf.int32) - 1
    one_hot_label = tf.one_hot(class_index, 1_000, dtype=tf.bfloat16)

    # NOTE(review): the training and evaluation paths both use the eval
    # preprocessing, so `training` does not change the result and
    # `_preprocess_for_train` is never called from here. Confirm whether
    # this is intentional before re-enabling train-time augmentation.
    image = _preprocess_for_eval(example['image/encoded'], size)

    return image, one_hot_label

def get_train_dataset(batch_size, mixup=False, size=SIZE, imagenet_train_rep = "/mnt/terminus/imagenet/extracted/data/imagenet2012/5.1.0/" ):
    """Build the batched, prefetched ImageNet training dataset.

    Args:
        batch_size: Number of examples per batch; incomplete batches are
            dropped.
        mixup: When true, examples are pair-wise blended via `_apply_mixup`.
        size: Side length of the square images.
        imagenet_train_rep: Directory containing the TFRecord shards. Files
            whose name contains 'train' are treated as training shards. A
            trailing path separator is optional.

    Returns:
        A `tf.data.Dataset` of `(images, labels)` batches.

    Raises:
        FileNotFoundError: If the directory holds no matching shard. Without
            this check the repeated dataset would be empty.
    """
    # os.path.join works with or without a trailing separator on the
    # directory. os.listdir returns names in arbitrary order, so the names
    # are sorted to make the shard selection below reproducible.
    imagenet_train_shards = sorted(
        os.path.join(imagenet_train_rep, f)
        for f in os.listdir(imagenet_train_rep)
        if 'train' in f)
    if not imagenet_train_shards:
        raise FileNotFoundError(
            f"no file containing 'train' found in {imagenet_train_rep!r}")

    # NOTE(review): only the first two shards are read -- confirm whether
    # this truncation is meant to stay.
    imagenet_train_shards = imagenet_train_shards[:2]
    print(imagenet_train_shards)
    imagenet_dataset = _init_shards(imagenet_train_shards, training=True).map(
        lambda proto: _parse_imagenet_prototype(proto, size, training=True), num_parallel_calls=AUTO)

    # Skip records that fail to parse or decode instead of aborting.
    imagenet_dataset = imagenet_dataset.apply(tf.data.experimental.ignore_errors())

    if mixup:
        imagenet_dataset = _apply_mixup(imagenet_dataset)

    train_dataset = imagenet_dataset.map(flip_left_right, num_parallel_calls=AUTO)
    train_dataset = train_dataset.batch(batch_size, drop_remainder=True)

    train_dataset = train_dataset.prefetch(AUTO)

    return train_dataset

def get_imagenet_val_dataset(batch_size, size,imagenet_validation_rep = "/mnt/terminus/imagenet/extracted/data/imagenet2012/5.1.0/"):
    """Build the batched, prefetched ImageNet validation dataset.

    The underlying shard dataset repeats indefinitely, so callers have to
    bound the number of evaluation steps themselves.

    Args:
        batch_size: Number of examples per batch; incomplete batches are
            dropped.
        size: Side length of the square images.
        imagenet_validation_rep: Directory containing the TFRecord shards.
            Files whose name contains 'validation' are treated as validation
            shards. A trailing path separator is optional.

    Returns:
        A `tf.data.Dataset` of `(images, labels)` batches.

    Raises:
        FileNotFoundError: If the directory holds no matching shard. Without
            this check the repeated dataset would be empty.
    """
    # os.path.join works with or without a trailing separator on the
    # directory. os.listdir returns names in arbitrary order, so the names
    # are sorted to make the shard selection below reproducible.
    imagenet_validation_shards = sorted(
        os.path.join(imagenet_validation_rep, f)
        for f in os.listdir(imagenet_validation_rep)
        if 'validation' in f)
    if not imagenet_validation_shards:
        raise FileNotFoundError(
            f"no file containing 'validation' found in {imagenet_validation_rep!r}")

    print(imagenet_validation_shards[:10])
    # NOTE(review): only the first two shards are read -- confirm whether
    # this truncation is meant to stay.
    imagenet_validation_shards = imagenet_validation_shards[:2]
    val_imagenet = _init_shards(imagenet_validation_shards, training=False).map(
        lambda proto: _parse_imagenet_prototype(proto, size, training=False), num_parallel_calls=AUTO)

    val_imagenet = val_imagenet.batch(batch_size, drop_remainder=True)
    val_imagenet = val_imagenet.prefetch(AUTO)

    return val_imagenet


def random_pca(image, pca_std=0.1):
    """Add a random per-channel offset to `image` and clip to [0, 255].

    Three coefficients are drawn from a normal distribution with mean 0 and
    standard deviation `pca_std`. Each one scales a column of a fixed 3x3
    matrix, and the result is multiplied by a fixed 3x1 vector to obtain
    one offset per channel.

    Args:
        image: Image whose last dimension has three channels; presumably
            float32 on a 0-255 scale, given the clipping range -- confirm
            against the caller.
        pca_std: Standard deviation of the random coefficients.

    Returns:
        The shifted image, clipped to [0, 255].
    """
    eigenvalues = tf.transpose(
        tf.convert_to_tensor([[55.46, 4.794, 1.148]]))
    eigenvectors = tf.convert_to_tensor([
        [-0.5836, -0.6948, 0.4203],
        [-0.5808, -0.0045, -0.8140],
        [-0.5675, 0.7192, 0.4009],
    ])
    coefficients = tf.random.normal((3,), 0, pca_std)
    channel_shift = tf.squeeze((coefficients * eigenvectors) @ eigenvalues)
    shifted = image + channel_shift
    return tf.clip_by_value(shifted, 0, 255)




def imagenet_dataset(batch_size,
                     preprocess="VGG",
                     size_min=256,
                     size_max=386,
                     shuffle=0,
                     write_dir="/mnt/terminus/imagenet/extracted/data/imagenet2012/5.1.0/",
                     compute_train_val=False):
    """Create the ImageNet training and validation datasets.

    Both datasets always use the module-level `SIZE` as image size.

    Args:
        batch_size: Number of examples per batch for both datasets.
        preprocess: Currently unused.
        size_min: Currently unused.
        size_max: Currently unused.
        shuffle: Currently unused.
        write_dir: Directory holding both the training and the validation
            shards.
        compute_train_val: Selects the shape of the returned tuple.

    Returns:
        `(train, val, None)` by default, or `(train, val, val, None)` when
        `compute_train_val` is true. In the second form the validation
        dataset appears twice.
    """
    train = get_train_dataset(
        batch_size, mixup=False, size=SIZE, imagenet_train_rep=write_dir)
    val = get_imagenet_val_dataset(
        batch_size, size=SIZE, imagenet_validation_rep=write_dir)

    if not compute_train_val:
        return train, val, None
    return train, val, val, None