---
dataset:
    task:
        # The background class is switched on for this model.
        background_class:
            use: true
        preprocess:
            # Input image shape: 299 x 299 with 3 channels.
            shape:
                height: 299
                width: 299
                channels: 3
            # `validate` step: central crop with a fraction of 0.875.
            validate:
                type: central_crop
                fraction: 0.875
            # `final_cpu` steps, in listed order.  The linear map presumably
            # computes x * 2.0 - 1.0 -- confirm against its implementation.
            final_cpu:
                - type: resize
                  fill: false
                - type: linear_map
                  scale: 2.0
                  shift: -1.0
            # Disabled alternative pipelines, kept for reference:
            # validate: null
            # _final_cpu_imagenet:
            #     - {type: subtract_channel_means}
            #     - {type: linear_map, scale: 255.0}
            # _final_cpu:
            #     - {type: linear_map, scale: 2.0, shift: -1.0}
            # final_cpu: !arith >
            #     $(dataset.task._final_cpu_imagenet)
            #     if $(dataset.name) == 'imagenet' else '_final_cpu'
            # validate:
            #     - {type: central_crop, fraction: 0.875}
            # final_cpu:
            #     - {type: subtract_channel_means}
            #     - {type: linear_map, scale: 255.0}
            #     # - {type: normalize_channels}
model:
    name: resnet_v2_50_slim
    # A literal block scalar keeps the description a single string.  Written as
    # a bare multi-line scalar, the trailing "::" followed by a line break is
    # read as a key/value separator, which turns the description into a
    # mapping instead of text.
    description: |
        ResNet50 implementation from::
            https://github.com/tensorflow/models/blob/master/research/slim/nets/resnet_v2.py
    layers:
        # Shared convolution defaults; concrete layers pull these in with
        # `<<: *conv` and override individual keys.
        _conv: &conv
            type: convolution
            weights_initializer: &initializer
                type: tensorflow.variance_scaling_initializer
            weights_regularizer: &regularizer
                type: tensorflow.contrib.layers.l2_regularizer
                scale: 0.0001
            activation_fn: &activator tensorflow.nn.relu
            normalizer_fn: tensorflow.contrib.slim.batch_norm
            # Batch-norm settings; the bottleneck's `norm` layer merges the
            # same mapping through the `normalizer_params` anchor.
            normalizer_params: &normalizer_params
                decay: 0.997
                epsilon: 0.00001
                # Canonical lowercase booleans, consistent with the
                # `true`/`false` spelling used in the dataset section.
                scale: true
                center: true
        # Reusable bottleneck module; instantiated through the `neck` anchor.
        _bottleneck: &neck
            type: module
            # Module parameters, referenced below as ^(name).  All default to
            # null here -- presumably every instance must supply them.
            kwargs:
                depth: null
                neck_depth: null
                stride: null
                padding: null
                shortcut: null
            layers:
                # Batch normalization followed by the shared activation.
                norm: &norm
                    <<: *normalizer_params
                    type: batch_normalization
                    activation_fn: *activator
                # 1x1 convolution shortcut, without normalizer or activation.
                conv_shortcut:
                    <<: *conv
                    kernel_size: 1
                    stride: ^(stride)
                    num_outputs: ^(depth)
                    normalizer_fn: null
                    activation_fn: null
                # Max-pool shortcut with a kernel of 1 and the unit's stride.
                pool_shortcut:
                    type: max_pool
                    kernel_size: 1
                    stride: ^(stride)
                # 1x1 convolution down to the neck depth.
                conv1: &bottleneck_conv1
                    <<: *conv
                    kernel_size: 1
                    stride: 1
                    num_outputs: ^(neck_depth)
                # 3x3 convolution carrying the unit's stride and padding.
                conv2:
                    <<: *conv
                    kernel_size: 3
                    stride: ^(stride)
                    padding: ^(padding)
                    num_outputs: ^(neck_depth)
                # 1x1 convolution back to the full depth, without normalizer
                # or activation.
                conv3:
                    <<: *bottleneck_conv1
                    kernel_size: 1
                    stride: 1
                    num_outputs: ^(depth)
                    normalizer_fn: null
                    activation_fn: null
                add:
                    type: add
            graph:
                # Pre-activation: normalize and activate the unit input.
                - from: input
                  with: norm
                  to: preact
                # A conv shortcut reads the pre-activated tensor; any other
                # shortcut kind reads the raw unit input.
                - from: !arith "'preact' if ^(shortcut) == 'conv' else 'input'"
                  with: ^(shortcut)_shortcut
                  to: shortcut
                # Residual branch: the three convolutions in sequence.
                - from: preact
                  with: [conv1, conv2, conv3]
                  to: residual
                # Sum of shortcut and residual is the unit output.
                - from: [shortcut, residual]
                  with: add
                  to: output
        # Root block: a 7x7 stride-2 convolution with neither activation nor
        # normalizer, followed by a 3x3 stride-2 max pool.
        conv1:
            <<: *conv
            kernel_size: 7
            stride: 2
            padding: 3
            num_outputs: 64
            activation_fn: null
            normalizer_fn: null
        pool1:
            type: max_pool
            kernel_size: 3
            stride: 2
            padding: same
        # Bottleneck unit templates; each stage below combines them.
        # Opening unit: conv shortcut, stride 1.
        _start: &start
            <<: *neck
            shortcut: conv
            stride: 1
            padding: same
        # Interior unit: pool shortcut, stride 1.
        _mid: &mid
            <<: *neck
            shortcut: pool
            stride: 1
            padding: same
        # Closing unit: pool shortcut, stride 2 with a padding of 1.
        _end: &end
            <<: *neck
            shortcut: pool
            stride: 2
            padding: 1
        # Stage 1: depth 256, neck depth 64.
        b11:
            <<: *start
            depth: 256
            neck_depth: 64
        b12:
            <<: *mid
            depth: 256
            neck_depth: 64
        b13:
            <<: *end
            depth: 256
            neck_depth: 64
        # Stage 2: depth 512, neck depth 128.
        b21:
            <<: *start
            depth: 512
            neck_depth: 128
        b22:
            <<: *mid
            depth: 512
            neck_depth: 128
        b23:
            <<: *mid
            depth: 512
            neck_depth: 128
        b24:
            <<: *end
            depth: 512
            neck_depth: 128
        # Stage 3: depth 1024, neck depth 256.
        b31:
            <<: *start
            depth: 1024
            neck_depth: 256
        b32:
            <<: *mid
            depth: 1024
            neck_depth: 256
        b33:
            <<: *mid
            depth: 1024
            neck_depth: 256
        b34:
            <<: *mid
            depth: 1024
            neck_depth: 256
        b35:
            <<: *mid
            depth: 1024
            neck_depth: 256
        b36:
            <<: *end
            depth: 1024
            neck_depth: 256
        # Stage 4: depth 2048, neck depth 512.
        b41:
            <<: *start
            depth: 2048
            neck_depth: 512
        b42:
            <<: *mid
            depth: 2048
            neck_depth: 512
        # last bottleneck has stride=1
        b43:
            <<: *mid
            depth: 2048
            neck_depth: 512
        # Final normalization, reusing the bottleneck's `norm` settings.
        postnorm:
            <<: *norm
        # NOTE(review): kernel_size 10 presumably equals the spatial size of
        # the last feature map for 299x299 inputs -- confirm before changing
        # the input shape.
        pool5:
            type: average_pool
            kernel_size: 10
        # pool5: {type: reduce_mean, axis: [1, 2], keep_dims: True}
        # 1x1 convolution producing one output per class, without activation
        # or normalizer.
        logits:
            <<: *conv
            kernel_size: 1
            num_outputs: $(dataset.task.num_classes)
            activation_fn: null
            normalizer_fn: null
        # Squeezes axes 1 and 2 of the logits.
        squeeze:
            type: squeeze
            axis: [1, 2]
    # Top-level graph: one chain from `input` to `output` through the layers
    # in the order listed.
    graph:
        from: input
        with:
            # root block
            - conv1
            - pool1
            # stage 1
            - b11
            - b12
            - b13
            # stage 2
            - b21
            - b22
            - b23
            - b24
            # stage 3
            - b31
            - b32
            - b33
            - b34
            - b35
            - b36
            # stage 4
            - b41
            - b42
            - b43
            # head
            - postnorm
            - pool5
            - logits
            - squeeze
        to: output
