---
# Dataset-dependent settings. Values written as `$(a.b.c)` refer to other
# keys by dotted path; they are presumably expanded by the config loader,
# which is not part of this file.
dataset:
    # Input edge length for each supported dataset name.
    _sizes:
        imagenet: 224
        cifar10: 32
        cifar100: 32
    # Edge length of the dataset selected by `dataset.name` (not defined in
    # this section). Anchored as `size` and aliased by the preprocess shape.
    _size: &size $(dataset._sizes.$(dataset.name))
    # Expression comparing the selected size against 32: with the table above
    # it holds for imagenet and not for the cifar entries. The model section
    # branches on it. `!arith` is a custom tag — evaluation semantics are
    # defined by the loader, TODO confirm.
    _large: !arith $(dataset._size) > 32
    # Task-level input settings: tensor shape and preprocessing stages.
    task:
        background_class:
            use: false
        preprocess:
            # Height and width both alias the `size` anchor on dataset._size.
            shape: {height: *size, width: *size, channels: 3}
            # Validation-only stages, keyed by dataset name; the entry for
            # the current dataset is picked by `validate` below.
            _validate:
                imagenet:
                    # alternative validate preprocessing pipeline
                    # - {type: resize, height: 256, width: 256, fill: true}
                    # - {type: crop_or_pad, height: 224, width: 224}
                    - type: central_crop
                      fraction: 0.875
                cifar10: null
                cifar100: null
            validate: $(dataset.task.preprocess._validate.$(dataset.name))
            final_cpu:
                - type: normalize_channels
            final_gpu: null
# ResNet-18 model definition: named layers plus the graph that connects them.
model:
    name: resnet18
    # Literal block scalar so the description stays a single string. Written
    # as a plain scalar, the line ending in `::` is read as a mapping key and
    # the URL as its value, which turns `description` into a nested mapping.
    description: |
        ResNet18 implementation from::
            https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
    layers:
        # Shared convolution settings, merged into other layers with
        # `<<: *conv`.
        # NOTE(review): the leading underscore presumably marks this entry as
        # a template rather than an instantiated layer — confirm against the
        # loader; it is not listed in the model graph.
        _conv: &conv
            type: convolution
            normalizer_fn: tensorflow.contrib.slim.batch_norm
            normalizer_params:
                scale: true
                # Batch-norm decay: 0.997 when dataset._large, otherwise 0.9.
                decay: !arith 0.997 if $(dataset._large) else 0.9
                epsilon: 0.00001
            # Anchored as `initializer`; fc5 aliases it.
            weights_initializer: &initializer
                type: tensorflow.initializers.he_normal
            # Anchored as `regularizer`; the only other mention in this file
            # is the commented-out line under fc5.
            weights_regularizer: &regularizer
                type: tensorflow.contrib.layers.l2_regularizer
                # Regularizer scale: 0.0001 when dataset._large, else 0.0005.
                scale: !arith 0.0001 if $(dataset._large) else 0.0005
        # First layer of the model graph. Inherits the shared `_conv` settings
        # and picks its geometry from dataset._large: kernel 7 / stride 2 /
        # padding 3 when large, kernel 3 / stride 1 / padding 1 otherwise.
        # Always produces 64 output channels.
        conv1:
            <<: *conv
            kernel_size: !arith 7 if $(dataset._large) else 3
            stride: !arith 2 if $(dataset._large) else 1
            padding: !arith 3 if $(dataset._large) else 1
            num_outputs: 64
        # Pooling after conv1: a kernel-3, stride-2 max pool with padding 1
        # when dataset._large; otherwise the type becomes 'identity' and the
        # padding None.
        # NOTE(review): kernel_size and stride are still set in the identity
        # case — presumably ignored by an identity layer; confirm.
        pool1:
            # Quoted so the inner single-quoted names survive YAML parsing
            # and reach the `!arith` expression intact.
            type: !arith "'max_pool' if $(dataset._large) else 'identity'"
            kernel_size: 3
            stride: 2
            padding: !arith 1 if $(dataset._large) else None
        # Residual block template, anchored as `bb` and merged into b11..b42.
        # One branch runs conv1 then conv2, the other runs a shortcut layer;
        # the two are added and passed through a ReLU (see `graph` below).
        _basic_block: &bb
            type: module
            # Parameters each instance overrides (stride, depth, shortcut).
            # They are referenced inside the template as `^(name)`.
            # NOTE(review): instances set these as sibling keys of `kwargs`
            # rather than inside it — presumably the loader maps them onto
            # these slots; confirm.
            kwargs: {stride: null, depth: null, shortcut: null}
            layers:
                # Kernel-3 convolution carrying the block's stride and depth;
                # anchored as `bbconv` for the two variants below.
                conv1: &bbconv
                    <<: *conv
                    kernel_size: 3
                    stride: ^(stride)
                    padding: 1
                    num_outputs: ^(depth)
                # Same as conv1 but always stride 1 and with no activation.
                conv2:
                    <<: *bbconv
                    stride: 1
                    activation_fn: null
                # Kernel-1 convolution that keeps the block's stride and depth
                # (inherited from bbconv), with 'valid' padding and no
                # activation.
                downsample_shortcut:
                    <<: *bbconv
                    kernel_size: 1
                    padding: valid
                    activation_fn: null
                identity_shortcut: {type: identity}
                add: {type: add}
                relu: {type: activation, mode: relu}
            graph:
                # Main branch: input -> conv1 -> conv2 -> a.
                - {from: input, with: [conv1, conv2], to: a}
                # Shortcut branch: `^(shortcut)_shortcut` names either
                # identity_shortcut or downsample_shortcut, depending on the
                # instance's `shortcut` value.
                - {from: input, with: ^(shortcut)_shortcut, to: b}
                # Sum both branches, then apply the ReLU to form the output.
                - {from: [a, b], with: add, to: preact}
                - {from: preact, with: relu, to: output}
        # Eight instances of the `bb` template, in four pairs by depth
        # (64, 128, 256, 512). The first block of every pair after the first
        # uses stride 2 with the downsample shortcut; all others use stride 1
        # with the identity shortcut.
        b11:
            <<: *bb
            stride: 1
            depth: 64
            shortcut: identity
        b12:
            <<: *bb
            stride: 1
            depth: 64
            shortcut: identity
        b21:
            <<: *bb
            stride: 2
            depth: 128
            shortcut: downsample
        b22:
            <<: *bb
            stride: 1
            depth: 128
            shortcut: identity
        b31:
            <<: *bb
            stride: 2
            depth: 256
            shortcut: downsample
        b32:
            <<: *bb
            stride: 1
            depth: 256
            shortcut: identity
        b41:
            <<: *bb
            stride: 2
            depth: 512
            shortcut: downsample
        b42:
            <<: *bb
            stride: 1
            depth: 512
            shortcut: identity
        # Final layers: global average pool, flatten, then a fully connected
        # layer with one output per class.
        pool4:
            type: average_pool
            kernel_size: global
        flatten4:
            type: flatten
        fc5:
            type: fully_connected
            # Class count comes from the dataset section; it is not defined
            # in the part of the file visible here.
            num_outputs: $(dataset.task.num_classes)
            activation_fn: null
            weights_initializer: *initializer
            # weights_regularizer: *regularizer
    # Top-level wiring from `input` to `output` through the layers listed
    # under `with` (presumably applied in list order — confirm with loader).
    graph:
        from: input
        with:
            - conv1
            - pool1
            - b11
            - b12
            - b21
            - b22
            - b31
            - b32
            - b41
            - b42
            - pool4
            - flatten4
            - fc5
        to: output
