
import os
import sys
import glob
import datetime
import pickle
import re
import numpy as np
from collections import OrderedDict 
import scipy.ndimage
import PIL.Image

import config
import dataset
import legacy

#----------------------------------------------------------------------------
# Convenience wrappers for pickle that are able to load data produced by
# older versions of the code.

def load_pkl(filename):
    """Load and return the object stored in the pickle file `filename`.

    Uses legacy.LegacyUnpickler so that data produced by older versions of the
    code can still be read. encoding='latin1' lets Python 3 decode pickles
    written by Python 2 that contain NumPy arrays.

    Only load files from trusted sources: unpickling can execute arbitrary code.
    """
    with open(filename, 'rb') as stream:
        unpickler = legacy.LegacyUnpickler(stream, encoding='latin1')
        return unpickler.load()

def save_pkl(obj, filename):
    """Serialize `obj` into `filename` using the highest pickle protocol available."""
    with open(filename, 'wb') as stream:
        pickle.dump(obj, stream, protocol=pickle.HIGHEST_PROTOCOL)

#----------------------------------------------------------------------------
# Image utils.

def adjust_dynamic_range(data, drange_in, drange_out):
    """Linearly remap `data` from the range `drange_in` to `drange_out`.

    Both ranges are (low, high) pairs. The endpoints are converted to float32
    before computing the scale and bias. When the two ranges compare equal,
    `data` is returned unchanged (the very same object).
    """
    if drange_in == drange_out:
        return data
    in_lo, in_hi = np.float32(drange_in[0]), np.float32(drange_in[1])
    out_lo, out_hi = np.float32(drange_out[0]), np.float32(drange_out[1])
    scale = (out_hi - out_lo) / (in_hi - in_lo)
    bias = out_lo - in_lo * scale
    return data * scale + bias

def create_image_grid(images, grid_size=None):
    """Tile a batch of images into a single image, filling rows left to right.

    Args:
        images: Array of shape [N, H, W] or [N, C, H, W].
        grid_size: Optional (columns, rows). When omitted, a near-square grid
            with ceil(sqrt(N)) columns and enough rows to hold N images is used.

    Returns:
        Array of shape [rows*H, cols*W] or [C, rows*H, cols*W], with the dtype
        of `images`. Unused cells are zero.
    """
    assert images.ndim in (3, 4)
    count = images.shape[0]
    cell_h, cell_w = images.shape[-2], images.shape[-1]

    if grid_size is None:
        cols = max(int(np.ceil(np.sqrt(count))), 1)
        rows = max((count - 1) // cols + 1, 1)
    else:
        cols, rows = tuple(grid_size)

    canvas_shape = list(images.shape[1:-2]) + [rows * cell_h, cols * cell_w]
    grid = np.zeros(canvas_shape, dtype=images.dtype)
    for i, img in enumerate(images):
        row, col = divmod(i, cols)
        top, left = row * cell_h, col * cell_w
        grid[..., top : top + cell_h, left : left + cell_w] = img
    return grid

def convert_to_pil_image(image, drange=(0, 1)):
    """Convert an HW or CHW numpy image into a PIL image.

    Args:
        image: 2-D (H, W) grayscale array, or 3-D (C, H, W) array. A single
            channel is treated as grayscale, 4 channels as RGBA, and any other
            channel count as RGB.
        drange: (low, high) value range of `image`. It is remapped to [0, 255],
            rounded, clipped and cast to uint8.

    Returns:
        PIL.Image.Image in mode 'L', 'RGB' or 'RGBA'.
    """
    assert image.ndim == 2 or image.ndim == 3
    if image.ndim == 3:
        if image.shape[0] == 1:
            image = image[0]  # grayscale CHW => HW
        else:
            image = image.transpose(1, 2, 0)  # CHW => HWC

    image = adjust_dynamic_range(image, drange, [0, 255])
    image = np.rint(image).clip(0, 255).astype(np.uint8)

    if image.ndim == 2:
        mode = 'L'
    elif image.shape[2] == 4:
        mode = 'RGBA'  # passing 'RGB' here would misinterpret the 4-byte pixel layout
    else:
        mode = 'RGB'
    return PIL.Image.fromarray(image, mode)

def save_image(image, filename, drange=(0, 1), quality=95):
    """Save an HW/CHW numpy image to `filename`.

    Files whose extension is .jpg or .jpeg (case-insensitive) are written as
    JPEG with the given `quality` and optimize=True. Any other extension is
    saved with PIL's format inferred from the filename.

    Args:
        image: Image array accepted by convert_to_pil_image().
        filename: Destination path (str or os.PathLike).
        drange: (low, high) value range of `image`.
        quality: JPEG quality; ignored for non-JPEG outputs.
    """
    img = convert_to_pil_image(image, drange)
    # Check the real extension. A substring test would also match
    # 'dir.jpgs/x.png' or 'a.jpg.png', and would miss '.JPG' and '.jpeg'.
    ext = os.path.splitext(filename)[1].lower()
    if ext in ('.jpg', '.jpeg'):
        img.save(filename, "JPEG", quality=quality, optimize=True)
    else:
        img.save(filename)

def save_image_grid(images, filename, drange=[0,1], grid_size=None):
    """Tile `images` with create_image_grid() and save the result to `filename`.

    `drange` is the (low, high) value range of `images`. `grid_size` is an
    optional (columns, rows) pair passed through to create_image_grid().
    """
    tiled = create_image_grid(images, grid_size)
    pil_image = convert_to_pil_image(tiled, drange)
    pil_image.save(filename)

#----------------------------------------------------------------------------
# Logging of stdout and stderr to a file.

class OutputLogger(object):
    """File-like sink for a log file that may be attached after output starts.

    Text written before set_log_file() is kept in memory. Once a file is
    attached, that backlog is written to it and later writes go straight to
    the file.
    """

    def __init__(self):
        self.file = None   # open log file, or None until set_log_file() is called
        self.buffer = ''   # text written before a file was attached; None afterwards

    def set_log_file(self, filename, mode='wt'):
        """Open `filename` with `mode` and write any buffered text into it."""
        assert self.file is None
        self.file = open(filename, mode)
        pending = self.buffer
        if pending is not None:
            self.file.write(pending)
            self.buffer = None

    def write(self, data):
        """Write `data` to the log file if attached, otherwise keep it buffered."""
        target = self.file
        if target is not None:
            target.write(data)
        if self.buffer is not None:
            self.buffer = self.buffer + data

    def flush(self):
        """Flush the log file. This is a no-op until a file is attached."""
        if self.file is None:
            return
        self.file.flush()

class TeeOutputStream(object):
    """Stream that forwards every write to each of `child_streams`, in order.

    When `autoflush` is true, every child is flushed after each write.
    """

    def __init__(self, child_streams, autoflush=False):
        self.child_streams = child_streams
        self.autoflush = autoflush

    def write(self, data):
        """Write `data` to every child stream, then flush if autoflush is set."""
        for child in self.child_streams:
            child.write(data)
        if not self.autoflush:
            return
        self.flush()

    def flush(self):
        """Flush every child stream."""
        for child in self.child_streams:
            child.flush()

# Module-level OutputLogger created by init_output_logging(); None until that is called.
output_logger = None

def init_output_logging():
    """Start capturing stdout and stderr into the module-level OutputLogger.

    sys.stdout and sys.stderr are replaced by autoflushing tees that still
    write to the original streams. Calling this again does nothing.
    """
    global output_logger
    if output_logger is not None:
        return
    output_logger = OutputLogger()
    sys.stdout = TeeOutputStream([sys.stdout, output_logger], autoflush=True)
    sys.stderr = TeeOutputStream([sys.stderr, output_logger], autoflush=True)

def set_output_log_file(filename, mode='wt'):
    """Attach `filename` as the captured-output log file.

    Does nothing unless init_output_logging() has already been called.
    """
    logger = output_logger
    if logger is None:
        return
    logger.set_log_file(filename, mode)

#----------------------------------------------------------------------------
# Reporting results.

def create_result_subdir(result_dir, desc):
    """Create a new, uniquely numbered subdirectory of `result_dir` for this run.

    The subdirectory is named '%03d-<desc>'. The number is one greater than the
    largest '<number>-...' prefix already present in `result_dir`, or 0 if there
    is none. Captured output is redirected to '<subdir>/log.txt'; this only takes
    effect if init_output_logging() was called. The public attributes of the
    `config` module are written to '<subdir>/config.txt' on a best-effort basis.

    Args:
        result_dir: Parent directory for all runs; created if it does not exist.
        desc: Short description appended to the run number.

    Returns:
        Path of the created subdirectory.
    """
    # Select run ID and create subdir. If makedirs() fails because the directory
    # already exists (e.g. another process picked the same ID), rescan and retry.
    while True:
        run_id = 0
        for fname in glob.glob(os.path.join(result_dir, '*')):
            # Only names of the form '<number>-...' count. A dash-less name such
            # as '17' is skipped instead of being truncated to '1'.
            prefix, dash, _ = os.path.basename(fname).partition('-')
            if not dash:
                continue
            try:
                run_id = max(run_id, int(prefix) + 1)
            except ValueError:
                pass

        result_subdir = os.path.join(result_dir, '%03d-%s' % (run_id, desc))
        try:
            os.makedirs(result_subdir)
            break
        except OSError:
            if os.path.isdir(result_subdir):
                continue
            raise

    print("Saving results to", result_subdir)
    set_output_log_file(os.path.join(result_subdir, 'log.txt'))

    # Export config. This is best effort: a failure must not abort the run,
    # but KeyboardInterrupt/SystemExit are allowed to propagate.
    try:
        with open(os.path.join(result_subdir, 'config.txt'), 'wt') as fout:
            for k, v in sorted(config.__dict__.items()):
                if not k.startswith('_'):
                    fout.write("%s = %s\n" % (k, str(v)))
    except Exception as err:
        print("Warning: failed to export config.txt:", err)

    return result_subdir

def format_time(seconds):
    """Format a duration in seconds as a short human-readable string.

    The value is first rounded to whole seconds with np.rint (round half to
    even). At most three units are shown, e.g. '42s', '3m 07s', '2h 05m 09s'
    or '1d 03h 20m'. Seconds are dropped once the duration reaches a day.
    """
    total = int(np.rint(seconds))
    if total < 60:
        return '%ds' % (total)
    minutes, secs = divmod(total, 60)
    hours, mins = divmod(minutes, 60)
    if total < 60 * 60:
        return '%dm %02ds' % (minutes, secs)
    days, hrs = divmod(hours, 24)
    if total < 24 * 60 * 60:
        return '%dh %02dm %02ds' % (hours, mins, secs)
    return '%dd %02dh %02dm' % (days, hrs, mins)

#----------------------------------------------------------------------------
# Locating results.

def locate_result_subdir(run_id_or_result_subdir):
    """Resolve a run ID or directory name to an existing result subdirectory.

    An existing directory path is returned as-is. Otherwise the subdirectories
    '', 'results' and 'networks' of config.result_dir are searched, in that
    order. In each, the function first tries an exact name match, then a
    unique '<prefix>-*' match. Integer IDs are zero-padded to three digits
    for the prefix.

    Raises:
        IOError: If no unique match is found.
    """
    target = run_id_or_result_subdir
    if isinstance(target, str) and os.path.isdir(target):
        return target

    prefix = '%03d' % target if isinstance(target, int) else str(target)
    for searchdir in ('', 'results', 'networks'):
        if searchdir == '':
            base = config.result_dir
        else:
            base = os.path.join(config.result_dir, searchdir)

        exact = os.path.join(base, str(target))
        if os.path.isdir(exact):
            return exact

        pattern = os.path.join(config.result_dir, searchdir, prefix + '-*')
        matches = [path for path in sorted(glob.glob(pattern)) if os.path.isdir(path)]
        if len(matches) == 1:
            return matches[0]

    raise IOError('Cannot locate result subdir for run', run_id_or_result_subdir)

def list_network_pkls(run_id_or_result_subdir, include_final=True):
    """Return the sorted 'network-*.pkl' paths in a run's result directory.

    If the first path in sorted order is 'network-final.pkl', it is moved to
    the end of the list. With include_final=False it is dropped instead.
    """
    result_subdir = locate_result_subdir(run_id_or_result_subdir)
    pkls = sorted(glob.glob(os.path.join(result_subdir, 'network-*.pkl')))
    if pkls and os.path.basename(pkls[0]) == 'network-final.pkl':
        final = pkls.pop(0)
        if include_final:
            pkls.append(final)
    return pkls

def locate_network_pkl(run_id_or_result_subdir_or_network_pkl, snapshot=None):
    """Resolve a run, result directory or pkl path to a network pickle path.

    An existing file path is returned as-is. Otherwise the run's network pkls
    are listed. Without `snapshot`, the last one is returned. With `snapshot`,
    the function returns the pkl whose trailing '-<number>' in its file name
    equals `snapshot`.

    Raises:
        IOError: If no matching pkl exists.
    """
    spec = run_id_or_result_subdir_or_network_pkl
    if isinstance(spec, str) and os.path.isfile(spec):
        return spec

    pkls = list_network_pkls(spec)
    if snapshot is None and pkls:
        return pkls[-1]

    for pkl in pkls:
        stem = os.path.splitext(os.path.basename(pkl))[0]
        try:
            if int(stem.split('-')[-1]) == snapshot:
                return pkl
        except (ValueError, IndexError):
            pass  # e.g. 'network-final' has no numeric suffix
    raise IOError('Cannot locate network pkl for snapshot', snapshot)

def get_id_string_for_network_pkl(network_pkl):
    """Build a short ID '<parent dir>-<file stem>' from a network pkl path.

    Both '/' and '\\' are accepted as separators. Only a trailing '.pkl' is
    removed; '.pkl' elsewhere in the path (e.g. a directory named 'a.pkls')
    is preserved.

    Example: 'results/000-x/network-snapshot-001.pkl' -> '000-x-network-snapshot-001'.
    """
    path = network_pkl.replace('\\', '/')
    if path.endswith('.pkl'):
        path = path[:-len('.pkl')]
    parts = path.split('/')
    return '-'.join(parts[-2:])

#----------------------------------------------------------------------------
# Loading and using trained networks.

def load_network_pkl(run_id_or_result_subdir_or_network_pkl, snapshot=None):
    """Locate a network pickle with locate_network_pkl() and load it with load_pkl()."""
    pkl_path = locate_network_pkl(run_id_or_result_subdir_or_network_pkl, snapshot)
    return load_pkl(pkl_path)

def random_latents(num_latents, G, random_state=None):
    """Sample standard-normal float32 latents shaped (num_latents,) + G.input_shape[1:].

    Draws from `random_state` (anything with a `randn` method, e.g.
    np.random.RandomState) when given, otherwise from the global np.random.
    """
    rng = np.random if random_state is None else random_state
    shape = (num_latents,) + tuple(G.input_shape[1:])
    return rng.randn(*shape).astype(np.float32)

def load_dataset_for_previous_run(run_id, **kwargs): # => dataset_obj, mirror_augment
    """Reload the dataset that a previous training run was configured with.

    Reads the 'dataset = ...' and 'train = ...' lines from the run's config.txt,
    which uses the 'key = value' format written by create_result_subdir. It then
    translates legacy option names and dataset names to their current forms
    and loads the dataset from config.data_dir.

    Args:
        run_id: Run ID or result subdirectory, resolved via locate_result_subdir().
        **kwargs: Overrides merged into the parsed dataset options before loading.

    Returns:
        (dataset_obj, mirror_augment) tuple.

    Raises:
        IOError: If the result subdirectory or its config.txt cannot be found.
        KeyError: If the parsed dataset options contain no 'tfrecord_dir'.
    """
    result_subdir = locate_result_subdir(run_id)

    # Parse config.txt.
    # Each matching line is executed as Python source into `parsed_cfg`, so the
    # right-hand side must be a valid Python expression (e.g. a dict repr).
    # SECURITY NOTE(review): exec() runs text read from disk as code. Only call
    # this on result directories you trust.
    parsed_cfg = dict()
    with open(os.path.join(result_subdir, 'config.txt'), 'rt') as f:
        for line in f:
            if line.startswith('dataset =') or line.startswith('train ='):
                exec(line, parsed_cfg, parsed_cfg)
    # NOTE(review): the .get()/.pop() calls below assume the exec'd values are dict-like — confirm for old runs.
    dataset_cfg = parsed_cfg.get('dataset', dict())
    train_cfg = parsed_cfg.get('train', dict())
    mirror_augment = train_cfg.get('mirror_augment', False)

    # Handle legacy options.
    # 'h5_path' becomes tfrecord_dir (with '.h5' removed). A mirror_augment
    # stored in the dataset options overrides the train setting. 'max_labels'
    # is renamed to max_label_size (None -> 0, 'all' -> 'full'). 'max_images'
    # is discarded.
    if 'h5_path' in dataset_cfg:
        dataset_cfg['tfrecord_dir'] = dataset_cfg.pop('h5_path').replace('.h5', '')
    if 'mirror_augment' in dataset_cfg:
        mirror_augment = dataset_cfg.pop('mirror_augment')
    if 'max_labels' in dataset_cfg:
        v = dataset_cfg.pop('max_labels')
        if v is None: v = 0
        if v == 'all': v = 'full'
        dataset_cfg['max_label_size'] = v
    if 'max_images' in dataset_cfg:
        dataset_cfg.pop('max_images')

    # Handle legacy dataset names.
    # Strip resolution suffixes, then rename old dataset directory names.
    # NOTE(review): these are plain substring replacements, so e.g. '-32' would
    # also match inside '-320'. Confirm that no current dataset name is affected.
    v = dataset_cfg['tfrecord_dir']
    v = v.replace('-32x32', '').replace('-32', '')
    v = v.replace('-128x128', '').replace('-128', '')
    v = v.replace('-256x256', '').replace('-256', '')
    v = v.replace('-1024x1024', '').replace('-1024', '')
    v = v.replace('celeba-hq', 'celebahq')
    v = v.replace('cifar-10', 'cifar10')
    # The 'cifar-10' replacement above already turns 'cifar-100' into
    # 'cifar100', so the next line never matches.
    v = v.replace('cifar-100', 'cifar100')
    v = v.replace('mnist-rgb', 'mnistrgb')
    # Reorder the size tag: 'lsun-100k-<cat>' -> 'lsun-<cat>-100k', likewise for 'full'.
    v = re.sub('lsun-100k-([^-]*)', 'lsun-\\1-100k', v)
    v = re.sub('lsun-full-([^-]*)', 'lsun-\\1-full', v)
    dataset_cfg['tfrecord_dir'] = v

    # Load dataset. Caller-supplied kwargs override the parsed options.
    dataset_cfg.update(kwargs)
    dataset_obj = dataset.load_dataset(data_dir=config.data_dir, **dataset_cfg)
    return dataset_obj, mirror_augment

def apply_mirror_augment(minibatch):
    """Return a copy of `minibatch` with a random half (on average) of samples mirrored.

    Each sample along axis 0 is flipped along axis 3, independently and with
    probability 0.5, using the global np.random state. The input is not
    modified. NOTE(review): presumably the batch is NCHW, so this is a
    horizontal flip; confirm against callers.
    """
    flip = np.random.rand(minibatch.shape[0]) < 0.5
    out = np.array(minibatch)
    mirrored = out[flip, :, :, ::-1]
    out[flip] = mirrored
    return out

#----------------------------------------------------------------------------
# Text labels.

# LRU cache used by setup_text_label(). It maps a label's rendering parameters
# to its (alpha, glow) arrays, with the least recently used entry first.
_text_label_cache = OrderedDict()

def draw_text_label(img, text, x, y, alignx=0.5, aligny=0.5, color=255, opacity=1.0, glow_opacity=1.0, **kwargs):
    """Return a copy of `img` with `text` composited onto it.

    Args:
        img: Image indexed as [row, col, channel]; must be 3-D.
        text: Label text, rendered via setup_text_label(text, **kwargs).
        x, y: Anchor position in pixels.
        alignx, aligny: Fraction of the label's width/height placed left of /
            above the anchor (0.5 centers the label on it).
        color: Scalar, or one value per channel of `img`. NOTE(review): the
            default of 255 suggests a [0, 255] image range — confirm.
        opacity: Overall strength of both the darkening and the text color.
        glow_opacity: Strength of the dark glow drawn around the glyphs.
        **kwargs: Forwarded to setup_text_label().

    Returns:
        A new array with the label drawn. If the label lies entirely outside
        the image, an unmodified copy is returned.
    """
    color = np.array(color).flatten().astype(np.float32)
    # Parenthesized on purpose: img must be 3-D regardless of whether color is scalar.
    assert img.ndim == 3 and (img.shape[2] == color.size or color.size == 1)
    alpha, glow = setup_text_label(text, **kwargs)
    xx = int(np.rint(x - alpha.shape[1] * alignx))
    yy = int(np.rint(y - alpha.shape[0] * aligny))

    # Clip the label rectangle against the image bounds (label coordinates).
    xb, yb = max(-xx, 0), max(-yy, 0)
    xe, ye = min(alpha.shape[1], img.shape[1] - xx), min(alpha.shape[0], img.shape[0] - yy)
    img = np.array(img)
    if xb >= xe or yb >= ye:
        # Fully off-image. Without this check, negative ends would slice
        # alpha from the end and cause a shape mismatch.
        return img

    a = alpha[yb:ye, xb:xe]
    g = glow[yb:ye, xb:xe]
    region = img[yy + yb : yy + ye, xx + xb : xx + xe, :]  # view into the copy
    # Darken the background under the text and its glow, then add the text color.
    region[:] = region * (1.0 - (1.0 - (1.0 - a) * (1.0 - g * glow_opacity)) * opacity)[:, :, np.newaxis]
    region[:] = region + a[:, :, np.newaxis] * (color * opacity)[np.newaxis, np.newaxis, :]
    return img

def setup_text_label(text, font='Calibri', fontsize=32, padding=6, glow_size=2.0, glow_coef=3.0, glow_exp=2.0, cache_size=100): # => (alpha, glow)
    """Render `text` to (alpha, glow) masks, memoized in an LRU cache.

    Args:
        text, font, fontsize: Passed to moviepy's TextClip.
        padding: Zero border (in pixels) added around the rendered mask.
        glow_size: Gaussian sigma used to blur the mask into a glow.
        glow_coef, glow_exp: Shape the glow as 1 - max(1 - blur*coef, 0)**exp.
        cache_size: Maximum number of cached labels. Values <= 0 disable
            caching instead of raising.

    Returns:
        (alpha, glow) arrays of equal shape. Cached arrays are shared between
        calls, so callers should not modify them in place.
    """
    # Lookup from cache; a hit becomes the most recently used entry.
    key = (text, font, fontsize, padding, glow_size, glow_coef, glow_exp)
    if key in _text_label_cache:
        _text_label_cache.move_to_end(key)
        return _text_label_cache[key]

    # Evict least-recently-used entries to make room. The emptiness check avoids
    # popitem() raising KeyError on an empty cache when cache_size <= 0.
    while _text_label_cache and len(_text_label_cache) >= cache_size:
        _text_label_cache.popitem(last=False)

    # Render text.
    # NOTE(review): moviepy 2.x removed `moviepy.editor` and changed TextClip's
    # arguments; this presumably requires moviepy < 2.0 — confirm.
    import moviepy.editor # pip install moviepy
    alpha = moviepy.editor.TextClip(text, font=font, fontsize=fontsize).mask.make_frame(0)
    alpha = np.pad(alpha, padding, mode='constant', constant_values=0.0)
    glow = scipy.ndimage.gaussian_filter(alpha, glow_size)
    glow = 1.0 - np.maximum(1.0 - glow * glow_coef, 0.0) ** glow_exp

    # Add to cache.
    value = (alpha, glow)
    if cache_size > 0:
        _text_label_cache[key] = value
    return value

#----------------------------------------------------------------------------
