
import pickle
import inspect
import numpy as np

import tfutil
import networks

#----------------------------------------------------------------------------
# Custom unpickler that is able to load network pickles produced by
# the old Theano implementation.

class LegacyUnpickler(pickle.Unpickler):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def find_class(self, module, name):
        if module == 'network' and name == 'Network':
            return tfutil.Network
        return super().find_class(module, name)

#----------------------------------------------------------------------------
# Import handler for tfutil.Network that silently converts networks produced
# by the old Theano implementation to a suitable format.

theano_gan_remap = {
    'G_paper':          'G_paper',
    'G_progressive_8':  'G_paper',
    'D_paper':          'D_paper',
    'D_progressive_8':  'D_paper'}

def patch_theano_gan(state):
    if 'version' in state or state['build_func_spec']['func'] not in theano_gan_remap:
        return state

    spec = dict(state['build_func_spec'])
    func = spec.pop('func')
    resolution = spec.get('resolution', 32)
    resolution_log2 = int(np.log2(resolution))
    use_wscale = spec.get('use_wscale', True)

    assert spec.pop('label_size',       0)          == 0
    assert spec.pop('use_batchnorm',    False)      == False
    assert spec.pop('tanh_at_end',      None)       is None
    assert spec.pop('mbstat_func',      'Tstdeps')  == 'Tstdeps'
    assert spec.pop('mbstat_avg',       'all')      == 'all'
    assert spec.pop('mbdisc_kernels',   None)       is None
    spec.pop(       'use_gdrop',        True)       # doesn't make a difference
    assert spec.pop('use_layernorm',    False)      == False
    spec[           'fused_scale']                  = False
    spec[           'mbstd_group_size']             = 16

    vars = []
    param_iter = iter(state['param_values'])
    relu = np.sqrt(2); linear = 1.0
    def flatten2(w): return w.reshape(w.shape[0], -1)
    def he_std(gain, w): return gain / np.sqrt(np.prod(w.shape[:-1]))
    def wscale(gain, w): return w * next(param_iter) / he_std(gain, w) if use_wscale else w
    def layer(name, gain, w): return [(name + '/weight', wscale(gain, w)), (name + '/bias', next(param_iter))]
    
    if func.startswith('G'):
        vars += layer('4x4/Dense', relu/4, flatten2(next(param_iter).transpose(1,0,2,3)))
        vars += layer('4x4/Conv', relu, next(param_iter).transpose(2,3,1,0)[::-1,::-1])
        for res in range(3, resolution_log2 + 1):
            vars += layer('%dx%d/Conv0' % (2**res, 2**res), relu, next(param_iter).transpose(2,3,1,0)[::-1,::-1])
            vars += layer('%dx%d/Conv1' % (2**res, 2**res), relu, next(param_iter).transpose(2,3,1,0)[::-1,::-1])
        for lod in range(0, resolution_log2 - 1):
            vars += layer('ToRGB_lod%d' % lod, linear, next(param_iter)[np.newaxis, np.newaxis])

    if func.startswith('D'):
        vars += layer('FromRGB_lod0', relu, next(param_iter)[np.newaxis, np.newaxis])
        for res in range(resolution_log2, 2, -1):
            vars += layer('%dx%d/Conv0' % (2**res, 2**res), relu, next(param_iter).transpose(2,3,1,0)[::-1,::-1])
            vars += layer('%dx%d/Conv1' % (2**res, 2**res), relu, next(param_iter).transpose(2,3,1,0)[::-1,::-1])
            vars += layer('FromRGB_lod%d' % (resolution_log2 - (res - 1)), relu, next(param_iter)[np.newaxis, np.newaxis])
        vars += layer('4x4/Conv', relu, next(param_iter).transpose(2,3,1,0)[::-1,::-1])
        vars += layer('4x4/Dense0', relu, flatten2(next(param_iter)[:,:,::-1,::-1]).transpose())
        vars += layer('4x4/Dense1', linear, next(param_iter))

    vars += [('lod', state['toplevel_params']['cur_lod'])]

    return {
        'version':          2,
        'name':             func,
        'build_module_src': inspect.getsource(networks),
        'build_func_name':  theano_gan_remap[func],
        'static_kwargs':    spec,
        'variables':        vars}

tfutil.network_import_handlers.append(patch_theano_gan)

#----------------------------------------------------------------------------
# Import handler for tfutil.Network that ignores unsupported/deprecated
# networks produced by older versions of the code.

def ignore_unknown_theano_network(state):
    if 'version' in state:
        return state

    print('Ignoring unknown Theano network:', state['build_func_spec']['func'])
    return {
        'version':          2,
        'name':             'Dummy',
        'build_module_src': 'def dummy(input, **kwargs): input.set_shape([None, 1]); return input',
        'build_func_name':  'dummy',
        'static_kwargs':    {},
        'variables':        []}

tfutil.network_import_handlers.append(ignore_unknown_theano_network)

#----------------------------------------------------------------------------
