;; Output directory handed to `mnist-eval` (as :attach-dir) at the end of
;; this script: the current working directory.
(setv attach-dir '".")
(require [hy.contrib.walk [let]])

(import jax
        [jax.numpy :as jnp]
        [jax.experimental.stax :as stax]
        [neural_tangents :as nt]
        [neural_tangents [stax :as nt-stax]]
        [jax.experimental.optimizers :as optimizers]
        [jax.flatten_util [ravel_pytree]]
        [numpy :as np]
        [matplotlib.pyplot :as plt]
        [tqdm [tqdm trange]]
        [sklearn.model_selection [train_test_split StratifiedShuffleSplit]]
        [toolz.dicttoolz [merge]]
        [math [ceil]]
        [nn_utilities :as nn_utils]
        os
        pickle)

;; `(bound? x)` expands to an expression that evaluates the form `x` and
;; yields True if that succeeds, or False if it raises NameError (for
;; example when `x` is an unbound symbol). The form really is evaluated, so
;; any side effects of `x` happen; exceptions other than NameError propagate.
(defmacro bound? [x]
  `(try
     (do ~x True)
     (except [NameError] False)))

;; `(default x d)` yields `d` when evaluating the form `x` raises NameError
;; (see `bound?`), and otherwise the value of `x`. When `x` is bound it is
;; evaluated twice: once by the check and once for the result. `d` is only
;; evaluated when `x` is unbound.
(defmacro default [x d]
  `(if (not (bound? ~x)) ~d ~x))

;; Width of the final Dense layer, i.e. the number of network outputs.
(setv num-outputs 10)

(defn IdentityBlock [out-chan]
  "Residual block whose shortcut is the identity.

  Main path: 3x3 Conv -> BatchNorm -> Relu -> 3x3 Conv -> BatchNorm, every
  Conv with `out-chan` channels and 'SAME' padding. The main path and the
  untouched input are summed (FanInSum) and passed through Relu, so the
  input presumably must already have `out-chan` channels (TODO confirm at
  call sites)."
  (setv main-path (stax.serial
                    (stax.Conv :out-chan out-chan :filter-shape (, 3 3) :padding "SAME")
                    (stax.BatchNorm)
                    stax.Relu
                    (stax.Conv :out-chan out-chan :filter-shape (, 3 3) :padding "SAME")
                    (stax.BatchNorm)))
  (stax.serial (stax.FanOut 2)
               (stax.parallel main-path stax.Identity)
               stax.FanInSum
               stax.Relu))

(defn ConvBlock [out-chan [stride 1]]
  "Residual block with a convolutional shortcut.

  Main path: 3x3 Conv (strides `stride`) -> BatchNorm -> Relu -> 3x3 Conv ->
  BatchNorm. Shortcut: 3x3 Conv (strides `stride`) -> BatchNorm. Both paths
  use `out-chan` channels and 'SAME' padding; they are summed and passed
  through Relu."
  ;; NOTE(review): the shortcut uses a 3x3 kernel; ResNet projection
  ;; shortcuts are usually 1x1 -- confirm this is intended.
  (setv main-path (stax.serial
                    (stax.Conv :out-chan out-chan :filter-shape (, 3 3) :strides (, stride stride) :padding "SAME")
                    (stax.BatchNorm)
                    stax.Relu
                    (stax.Conv :out-chan out-chan :filter-shape (, 3 3) :padding "SAME")
                    (stax.BatchNorm))
        projection (stax.serial
                     (stax.Conv :out-chan out-chan :filter-shape (, 3 3) :strides (, stride stride) :padding "SAME")
                     (stax.BatchNorm)))
  (stax.serial (stax.FanOut 2)
               (stax.parallel main-path projection)
               stax.FanInSum
               stax.Relu))

;; The layer list is kept as a plain Python list, not wrapped in stax.serial,
;; so that copies of it can be edited (e.g. layers inserted) before use.
(setv net [;; stem
           (stax.Conv :out-chan 16 :filter-shape (, 3 3) :strides (, 1 1) :padding "SAME")
           (stax.BatchNorm)
           stax.Relu
           ;; stage 1: three identity blocks at 16 channels
           (IdentityBlock 16)
           (IdentityBlock 16)
           (IdentityBlock 16)
           ;; stage 2: stride-2 conv block to 32 channels, then two identity blocks
           (ConvBlock 32 :stride 2)
           (IdentityBlock 32)
           (IdentityBlock 32)
           ;; stage 3: stride-2 conv block to 64 channels, then two identity blocks
           (ConvBlock 64 :stride 2)
           (IdentityBlock 64)
           (IdentityBlock 64)
           ;; head
           ;; NOTE(review): with stax's default pooling strides this 4x4 window
           ;; is only a global pool if the feature map is 4x4 -- confirm the
           ;; intended input size.
           (stax.AvgPool :window-shape (, 4 4) :padding "VALID")
           stax.Flatten
           (stax.Dense num-outputs)])

;; Independent shallow copies of `net`: mutating them (e.g. inserting
;; layers) leaves `net` itself unchanged.
(setv train-net (list net)
      test-net (list net))

;; Splice a Dropout layer in after each of the nine residual blocks of the
;; layer list. The blocks start at index 3, and each insertion shifts the
;; following blocks one slot to the right, so the insertion slots are
;; 4, 6, ..., 20. The training copy leaves :mode at stax's default; the
;; test copy passes :mode "test".
;; NOTE(review): in jax.experimental.stax, Dropout's `rate` is the keep
;; probability rather than the drop probability -- confirm 0.8 is intended.
(for [slot (range 4 21 2)]
  (.insert test-net slot (stax.Dropout :rate 0.8 :mode "test"))
  (.insert train-net slot (stax.Dropout :rate 0.8)))

;; Load the data, build the train/test functions and the optimizer step, and
;; configure the splitter that `mnist-eval` uses below.
;; NOTE(review): `mnist-data`, `mnist-train-net` and `mnist-test-net` are not
;; defined in this part of the file; the destructuring below mirrors how
;; they are used here -- confirm against their definitions.
(setv [train-images test-images train-labels test-labels] (mnist-data :train-set "fashion_motion_blur"
                                                                      :test-set "fashion"
                                                                      :conv True)
      ;; presumably the image side length for :conv data -- TODO confirm
      input-shape (get (np.shape train-images) 1)
      ;; Use the dropout variants built above: `train-net` (Dropout in train
      ;; mode) for training and `test-net` (Dropout in "test" mode) for
      ;; evaluation. Both lists have Dropout at the same slots, so their
      ;; parameter trees should line up. Passing the plain `net` here left
      ;; the dropout nets unused.
      ;; NOTE(review): stax Dropout in train mode needs an `rng` argument to
      ;; its apply function; confirm the train apply from `mnist-train-net`
      ;; passes one (mnist-eval receives :jax-rng).
      [train-apply calc-loss-train opt-update opt-get new-opt-state] (mnist-train-net train-net input-shape :conv True)
      [test-apply calc-loss-test] (mnist-test-net test-net)
      ;; no regularization penalty
      penalty (constantly 0.0)
      opt-step (nn-utils.create-opt-step calc-loss-train penalty opt-update opt-get)
      ;; Train fraction chosen to give ~6000 samples per split. `len` counts
      ;; samples whether the labels are 1-D or one-hot 2-D, whereas `np.size`
      ;; counted every entry (10x too many for one-hot labels).
      splitter (StratifiedShuffleSplit :n-splits 3 :train-size (/ 6000 (len train-labels)))
      epochs 15)

;; Run the experiment. Passed positionally:
;;   - the four data arrays,
;;   - the train/test apply functions,
;;   - the test loss,
;;   - the epoch count,
;;   - the optimizer accessor, step and initial state.
;; Run settings follow as keywords.
(mnist-eval train-images test-images train-labels test-labels
            train-apply test-apply calc-loss-test epochs
            opt-get opt-step new-opt-state
            :batch-size 32
            :label-noise 0.3
            :splitter splitter
            :jax-rng (jax.random.PRNGKey 0)  ; fixed seed 0
            :show-progress False
            :attach-dir attach-dir)
