---
dataset:
    task:
        # The background class is switched off for this task.
        background_class:
            use: false
        preprocess:
            # Declared image shape: 224 x 224 with 3 channels.
            shape:
                height: 224
                width: 224
                channels: 3
            # Validation steps, in listed order: resize to 256 x 256, then
            # crop or pad to 224 x 224 (matching `shape` above).
            validate:
                - type: resize
                  height: 256
                  width: 256
                  fill: true
                - type: crop_or_pad
                  height: 224
                  width: 224
            # Final steps, in listed order: linear map (scale 255, shift 0),
            # per-channel mean subtraction, then channel reordering.
            final_cpu:
                - type: linear_map
                  scale: 255.0
                  shift: 0.0
                - type: subtract_channel_means
                  means: [123.680, 116.779, 103.939]
                # Order [2, 1, 0] reverses the three channels.
                # NOTE(review): reversing after the mean subtraction suggests
                # the means are listed in RGB order and the network expects
                # BGR input -- confirm against the data loader.
                - type: permute_channels
                  order: [2, 1, 0]
model:
    # Version number, referenced by `name` below through the
    # $(model.version) placeholder. The placeholder is not YAML syntax;
    # presumably the config loader substitutes it -- confirm.
    version: 1
    name: resnet_v$(model.version)_50
    # Literal block scalar (`|`): the line breaks in the text are preserved.
    description: |
        ResNet-50 implementation from::
        https://github.com/keras-team/keras-applications/blob/master/keras_applications/resnet50.py
    layers:
        # Shared layer defaults, exposed as anchor `common`: he_normal weight
        # initializer and slim batch_norm as the normalizer with the
        # parameters listed below. Merged (via `<<`) into `_conv` and `fc`.
        # NOTE(review): the underscore-prefixed entries are never listed in
        # the top-level graph; presumably they act only as templates.
        _common: &common
            weights_initializer:
                type: tensorflow.initializers.he_normal
            normalizer_fn: tensorflow.contrib.slim.batch_norm
            normalizer_params:
                center: true
                scale: true
                decay: 0.99
                epsilon: 0.001
        # Convolution defaults, exposed as anchor `conv`: everything from
        # `common` plus the layer type and `valid` padding. Layers that merge
        # this override individual keys (e.g. `padding`) where needed.
        _conv: &conv
            <<: *common
            type: convolution
            padding: valid
        # First convolution: 7x7 kernel, stride 2, 64 outputs. The explicit
        # `padding: 3` replaces the `valid` padding merged in from `conv`
        # (explicit keys take precedence over merged ones).
        conv1:
            <<: *conv
            kernel_size: 7
            stride: 2
            num_outputs: 64
            padding: 3
        # Max pooling: 3x3 kernel, stride 2, padding 1.
        pool1:
            type: max_pool
            kernel_size: 3
            stride: 2
            padding: 1
        # Bottleneck module template, exposed as anchor `bb`. Each instance
        # merges it and supplies `stride`, `depth` and `shortcut`; the
        # ^(name) placeholders below refer to those kwargs (presumably
        # substituted by the config loader -- confirm).
        _bottleneck: &bb
            type: module
            kwargs:
                stride: null
                depth: null
                shortcut: null
            layers:
                # 1x1 convolution; carries the module's stride.
                conv1:
                    <<: *conv
                    kernel_size: 1
                    stride: ^(stride)
                    num_outputs: ^(depth)
                # 3x3 convolution, stride 1, `same` padding.
                conv2:
                    <<: *conv
                    kernel_size: 3
                    stride: 1
                    padding: same
                    num_outputs: ^(depth)
                # 1x1 convolution with 4 * depth outputs and no activation.
                # The output count is anchored as `final_depth` so that
                # `conv_shortcut` reuses the same expression.
                conv3:
                    <<: *conv
                    kernel_size: 1
                    stride: 1
                    num_outputs: &final_depth !arith 4 * ^(depth)
                    activation_fn: null
                # 1x1 convolution shortcut: same stride as `conv1`, same
                # output count as `conv3`, no activation.
                conv_shortcut:
                    <<: *conv
                    kernel_size: 1
                    stride: ^(stride)
                    num_outputs: *final_depth
                    activation_fn: null
                # Pass-through alternative to `conv_shortcut`.
                identity_shortcut:
                    type: identity
                add:
                    type: add
                relu:
                    type: activation
                    mode: relu
            graph:
                # Residual branch: input -> conv1 -> conv2 -> conv3.
                - from: input
                  with: [conv1, conv2, conv3]
                  to: residual
                # Shortcut branch: the `shortcut` kwarg picks either
                # `conv_shortcut` or `identity_shortcut` by name.
                - from: input
                  with: ^(shortcut)_shortcut
                  to: shortcut
                # Both branches are added, then passed through relu.
                - from: [residual, shortcut]
                  with: [add, relu]
                  to: output
        # Bottleneck instances. Each merges the `bb` template and sets its
        # three kwargs. In every group the first module uses the conv
        # shortcut and the remaining ones use the identity shortcut.

        # Group b2: depth 64, three modules, all with stride 1.
        b2a:
            <<: *bb
            stride: 1
            depth: 64
            shortcut: conv
        b2b:
            <<: *bb
            stride: 1
            depth: 64
            shortcut: identity
        b2c:
            <<: *bb
            stride: 1
            depth: 64
            shortcut: identity

        # Group b3: depth 128, four modules; the first has stride 2.
        b3a:
            <<: *bb
            stride: 2
            depth: 128
            shortcut: conv
        b3b:
            <<: *bb
            stride: 1
            depth: 128
            shortcut: identity
        b3c:
            <<: *bb
            stride: 1
            depth: 128
            shortcut: identity
        b3d:
            <<: *bb
            stride: 1
            depth: 128
            shortcut: identity

        # Group b4: depth 256, six modules; the first has stride 2.
        b4a:
            <<: *bb
            stride: 2
            depth: 256
            shortcut: conv
        b4b:
            <<: *bb
            stride: 1
            depth: 256
            shortcut: identity
        b4c:
            <<: *bb
            stride: 1
            depth: 256
            shortcut: identity
        b4d:
            <<: *bb
            stride: 1
            depth: 256
            shortcut: identity
        b4e:
            <<: *bb
            stride: 1
            depth: 256
            shortcut: identity
        b4f:
            <<: *bb
            stride: 1
            depth: 256
            shortcut: identity

        # Group b5: depth 512, three modules; the first has stride 2.
        b5a:
            <<: *bb
            stride: 2
            depth: 512
            shortcut: conv
        b5b:
            <<: *bb
            stride: 1
            depth: 512
            shortcut: identity
        b5c:
            <<: *bb
            stride: 1
            depth: 512
            shortcut: identity

        # Average pooling with a `global` kernel size, followed by a flatten.
        avg_pool:
            type: average_pool
            kernel_size: global
        flatten:
            type: flatten
        # Fully connected output layer. It merges the shared defaults but
        # disables both the activation and the normalizer. The output count
        # comes from $(dataset.task.num_classes), which is not defined in the
        # visible part of this file -- presumably supplied by a dataset
        # config; confirm.
        fc:
            <<: *common
            type: fully_connected
            num_outputs: $(dataset.task.num_classes)
            activation_fn: null
            normalizer_fn: null
    # Top-level graph: one chain from `input` to `output` that passes
    # through the layers below in listed order.
    graph:
        from: input
        with:
            # Initial convolution and pooling.
            - conv1
            - pool1
            # Bottleneck groups b2 to b5.
            - b2a
            - b2b
            - b2c
            - b3a
            - b3b
            - b3c
            - b3d
            - b4a
            - b4b
            - b4c
            - b4d
            - b4e
            - b4f
            - b5a
            - b5b
            - b5c
            # Pooling, flatten and the fully connected output layer.
            - avg_pool
            - flatten
            - fc
        to: output
