{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 302,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import math\n",
    "import os\n",
    "import pickle\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import scipy.special as ss\n",
    "import scipy.stats as sst"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 141,
   "metadata": {},
   "outputs": [],
   "source": [
    "class empty_class:\n",
    "    pass"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 177,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the nu-vs-mu experiment on the sphere S^{d-1}.\n",
    "args = empty_class()\n",
    "args.d = 12                   # ambient dimension of the samples\n",
    "args.k = 6                    # degree of the polynomial P_{k,d} defining nu and mu\n",
    "args.n_samples = 10000000     # uniform proposals drawn before rejection sampling\n",
    "args.n_samples_d_f2 = 10000   # random directions used by d_f2_estimate\n",
    "args.seed = 42\n",
    "args.alpha = 1\n",
    "args.gamma = 0.5\n",
    "\n",
    "# Seed the global RNGs so Restart & Run All reproduces the same draws\n",
    "# (args.seed was previously only part of the cache file names).\n",
    "torch.manual_seed(args.seed)\n",
    "np.random.seed(args.seed)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 178,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_nu_samples(args):\n",
    "    fname = os.path.join('nu_samples', \n",
    "                         f'nu_samples_{args.d}_{args.k}_{args.n_samples}_{args.seed}.pkl')\n",
    "    if os.path.exists(fname):\n",
    "        X = pickle.load(open(fname, 'rb'))\n",
    "        return X\n",
    "    else:\n",
    "        q_k_d = ss.jacobi(args.k, (args.d-3)/2.0, (args.d-3)/2.0)\n",
    "        legendre_k_d = q_k_d/q_k_d(1)\n",
    "        X0 = torch.randn(args.n_samples,args.d)\n",
    "        X0 = torch.nn.functional.normalize(X0, p=2, dim=1)\n",
    "        #acceptance_prob = torch.from_numpy(0.49 + 0.49*legendre_k_d(X0[:,args.d-1]))\n",
    "        acceptance_prob = torch.nn.functional.relu(torch.from_numpy(0.99*legendre_k_d(X0[:,args.d-1])))\n",
    "        acceptance_vector = torch.bernoulli(acceptance_prob)\n",
    "        print('Acc. prob. sum:', torch.norm(acceptance_prob, p=1), 'Acc. vec. sum:', torch.norm(acceptance_vector, p=1))\n",
    "        accepted_rows = []\n",
    "        for i in range(args.n_samples):\n",
    "            if acceptance_vector[i] == 1:\n",
    "                accepted_rows.append(i)\n",
    "        accepted_rows_tensor = torch.tensor(accepted_rows).unsqueeze(1).expand([len(accepted_rows),args.d])\n",
    "        X = torch.gather(X0, 0, accepted_rows_tensor)\n",
    "        if not os.path.exists('nu_samples'):\n",
    "            os.makedirs('nu_samples')\n",
    "        pickle.dump(X, open(fname, 'wb'))\n",
    "        return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 214,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mu_samples(args):\n",
    "    fname = os.path.join('mu_samples', \n",
    "                         f'mu_samples_{args.d}_{args.k}_{args.n_samples}_{args.seed}.pkl')\n",
    "    if os.path.exists(fname):\n",
    "        X = pickle.load(open(fname, 'rb'))\n",
    "        return X\n",
    "    else:\n",
    "        #X = torch.randn(args.n_samples,args.d)\n",
    "        #X = torch.nn.functional.normalize(X, p=2, dim=1)\n",
    "        q_k_d = ss.jacobi(args.k, (args.d-3)/2.0, (args.d-3)/2.0)\n",
    "        legendre_k_d = q_k_d/q_k_d(1)\n",
    "        X0 = torch.randn(args.n_samples,args.d)\n",
    "        X0 = torch.nn.functional.normalize(X0, p=2, dim=1)\n",
    "        acceptance_prob = torch.nn.functional.relu(torch.from_numpy(-0.99*legendre_k_d(X0[:,args.d-1])))\n",
    "        acceptance_vector = torch.bernoulli(acceptance_prob)\n",
    "        print('Acc. prob. sum:', torch.norm(acceptance_prob, p=1), 'Acc. vec. sum:', torch.norm(acceptance_vector, p=1))\n",
    "        accepted_rows = []\n",
    "        for i in range(args.n_samples):\n",
    "            if acceptance_vector[i] == 1:\n",
    "                accepted_rows.append(i)\n",
    "        accepted_rows_tensor = torch.tensor(accepted_rows).unsqueeze(1).expand([len(accepted_rows),args.d])\n",
    "        X = torch.gather(X0, 0, accepted_rows_tensor)\n",
    "        if not os.path.exists('mu_samples'):\n",
    "            os.makedirs('mu_samples')\n",
    "        pickle.dump(X, open(fname, 'wb'))\n",
    "        return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# P_{k,d}: Jacobi polynomial with alpha = beta = (d-3)/2 (a Gegenbauer polynomial up to scale),\n",
    "# rescaled so that it equals 1 at t = 1.\n",
    "q_k_d = ss.jacobi(args.k, alpha=(args.d - 3) / 2.0, beta=(args.d - 3) / 2.0)\n",
    "legendre_k_d = q_k_d / q_k_d(1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 180,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Acc. prob. sum: tensor(30889.0918) Acc. vec. sum: tensor(30813.)\n"
     ]
    }
   ],
   "source": [
    "# Samples from nu (loaded from the nu_samples/ cache when present).\n",
    "X_nu = get_nu_samples(args=args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 181,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Acc. prob. sum: tensor(30745.7695) Acc. vec. sum: tensor(30843.)\n"
     ]
    }
   ],
   "source": [
    "# Samples from mu (loaded from the mu_samples/ cache when present).\n",
    "X_mu = get_mu_samples(args=args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 182,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([30813, 12]) torch.Size([30843, 12])\n",
      "torch.Size([30813, 12]) torch.Size([30813, 12])\n"
     ]
    }
   ],
   "source": [
    "print(X_nu.shape, X_mu.shape)\n",
    "# Rejection sampling yields different counts; keep the first min(n_nu, n_mu) rows of each.\n",
    "min_num = np.min((X_nu.shape[0], X_mu.shape[0]))\n",
    "X_nu, X_mu = X_nu[:min_num], X_mu[:min_num]\n",
    "print(X_nu.shape, X_mu.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 183,
   "metadata": {},
   "outputs": [],
   "source": [
    "def d_f1_estimate(X_nu, X_mu, a, b):\n",
    "    gen_moment_nu_positive = a*torch.mean(torch.nn.functional.relu(X_nu[:,args.d-1])) + b*torch.mean(torch.nn.functional.relu(-X_nu[:,args.d-1]))\n",
    "    gen_moment_nu_negative = a*torch.mean(torch.nn.functional.relu(-X_nu[:,args.d-1])) + b*torch.mean(torch.nn.functional.relu(X_nu[:,args.d-1]))\n",
    "    gen_moment_mu_positive = a*torch.mean(torch.nn.functional.relu(X_mu[:,args.d-1])) + b*torch.mean(torch.nn.functional.relu(-X_mu[:,args.d-1]))\n",
    "    gen_moment_mu_negative = a*torch.mean(torch.nn.functional.relu(-X_mu[:,args.d-1])) + b*torch.mean(torch.nn.functional.relu(X_mu[:,args.d-1]))\n",
    "    return torch.max(torch.abs(gen_moment_nu_positive - gen_moment_mu_positive),torch.abs(gen_moment_nu_negative - gen_moment_mu_negative))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 184,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0325)\n"
     ]
    }
   ],
   "source": [
    "print(d_f1_estimate(X_nu, X_mu, a=1, b=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 185,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Dimension of the space of degree-k spherical harmonics in d dimensions.\n",
    "N_kd = ((2*args.k + args.d - 2) * math.factorial(args.k + args.d - 3)\n",
    "        / (math.factorial(args.k) * math.factorial(args.d - 2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 186,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "11011.0 104.93331215586402\n"
     ]
    }
   ],
   "source": [
    "print(N_kd, math.sqrt(N_kd))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 187,
   "metadata": {},
   "outputs": [],
   "source": [
    "lambda_kd = ((args.d - 2) * math.factorial(args.alpha) * math.gamma((args.d - 1) / 2) * math.gamma(args.k - args.alpha)\n",
    "             / ((2**args.k) * math.gamma((args.k - args.alpha + 1) / 2) * math.gamma((args.k + args.d + args.alpha) / 2)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 188,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1215.5000000000005\n"
     ]
    }
   ],
   "source": [
    "# Reciprocal of lambda_kd.\n",
    "print(1.0 / lambda_kd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 189,
   "metadata": {},
   "outputs": [],
   "source": [
    "def f2_kernel_evaluation(X0, X1, fill_diag = True):\n",
    "    #X = torch.cat((X_mu, -X_nu), 0)\n",
    "    if fill_diag:\n",
    "        inner_prod = torch.matmul(X0,X1.t()).fill_diagonal_(fill_value = 1)\n",
    "    else:\n",
    "        inner_prod = torch.matmul(X0,X1.t())\n",
    "    values = ((np.pi-torch.acos(inner_prod))*inner_prod \\\n",
    "            + torch.sqrt(1-inner_prod*inner_prod))/(2*np.pi*(args.d+1))\n",
    "    return values \n",
    "    #2*np.sqrt(torch.mean(values))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 190,
   "metadata": {},
   "outputs": [],
   "source": [
    "def d_f2_estimate_exact_kernel(X_nu, X_mu, a, b):\n",
    "    \"\"\"Plug-in estimate of the F2 distance using the closed-form kernel f2_kernel_evaluation.\n",
    "\n",
    "    Returns sqrt(mean K(mu, mu) + mean K(nu, nu) - 2 * mean K(mu, nu)) over all pairs\n",
    "    (diagonals included). Builds three dense Gram matrices, so memory is O(n^2).\n",
    "    NOTE(review): a and b are unused (kept for signature parity with d_f2_estimate);\n",
    "    presumably the kernel corresponds to a = 1, b = 0 -- confirm.\n",
    "    \"\"\"\n",
    "    kernel_eval_X_mu_X_mu = f2_kernel_evaluation(X_mu, X_mu)\n",
    "    kernel_eval_X_nu_X_nu = f2_kernel_evaluation(X_nu, X_nu)\n",
    "    kernel_eval_X_mu_X_nu = f2_kernel_evaluation(X_mu, X_nu, fill_diag=False)\n",
    "    d_f2_sq = (torch.mean(kernel_eval_X_mu_X_mu) + torch.mean(kernel_eval_X_nu_X_nu)\n",
    "               - 2*torch.mean(kernel_eval_X_mu_X_nu))\n",
    "    # Cancellation can leave a tiny negative value; clamp so the square root is not NaN.\n",
    "    # torch.sqrt keeps the result a tensor without a NumPy round-trip.\n",
    "    return torch.sqrt(torch.clamp(d_f2_sq, min=0.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 191,
   "metadata": {},
   "outputs": [],
   "source": [
    "def d_f2_estimate(X_nu, X_mu, a, b):\n",
    "    Y0 = torch.randn(args.d,args.n_samples_d_f2)\n",
    "    Y0 = torch.nn.functional.normalize(Y0, p=2, dim=0)\n",
    "    gen_moment_nu_positive = a*torch.mean(torch.nn.functional.relu(torch.matmul(X_nu,Y0)), dim=0) + b*torch.mean(torch.nn.functional.relu(-torch.matmul(X_nu,Y0)), dim=0)\n",
    "    gen_moment_nu_negative = a*torch.mean(torch.nn.functional.relu(-torch.matmul(X_nu,Y0)), dim=0) + b*torch.mean(torch.nn.functional.relu(torch.matmul(X_nu,Y0)), dim=0)\n",
    "    gen_moment_mu_positive = a*torch.mean(torch.nn.functional.relu(torch.matmul(X_mu,Y0)), dim=0) + b*torch.mean(torch.nn.functional.relu(-torch.matmul(X_mu,Y0)), dim=0)\n",
    "    gen_moment_mu_negative = a*torch.mean(torch.nn.functional.relu(-torch.matmul(X_mu,Y0)), dim=0) + b*torch.mean(torch.nn.functional.relu(torch.matmul(X_mu,Y0)), dim=0)\n",
    "    d_f2_sq = torch.mean(0.5*(gen_moment_nu_positive-gen_moment_mu_positive)**2 + 0.5*(gen_moment_nu_negative-gen_moment_mu_negative)**2)\n",
    "    return torch.sqrt(d_f2_sq)\n",
    "    #return torch.max(torch.abs(gen_moment_nu_positive - gen_moment_mu_positive),torch.abs(gen_moment_nu_negative - gen_moment_mu_negative))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 192,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0013)\n"
     ]
    }
   ],
   "source": [
    "print(d_f2_estimate(X_nu, X_mu, a=1, b=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 193,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0063)\n"
     ]
    }
   ],
   "source": [
    "print(d_f2_estimate_exact_kernel(X_mu, X_nu, a=1, b=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 174,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gamma_ratio(X_nu, X_mu):\n",
    "    return torch.mean(torch.nn.functional.relu(X_nu[:,args.d-1])) - torch.mean(torch.nn.functional.relu(X_mu[:,args.d-1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 175,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gamma_ratio_2(n_samples=10000000):\n",
    "    \"\"\"Weighted-average counterpart of gamma_ratio that avoids rejection sampling.\n",
    "\n",
    "    Draws n_samples uniform points on S^{d-1} (d and k from the global args) and weights them by\n",
    "    the nu / mu acceptance probabilities relu(+-0.99 * P_{k,d}(x_d)). Prints the total weight of\n",
    "    each measure and returns the difference of the weighted means of relu(x_d).\n",
    "\n",
    "    Args:\n",
    "        n_samples: number of uniform draws (previously hard-coded to 10000000, the default).\n",
    "    \"\"\"\n",
    "    X0 = torch.randn(n_samples, args.d)\n",
    "    X0 = torch.nn.functional.normalize(X0, p=2, dim=1)\n",
    "    q_k_d = ss.jacobi(args.k, (args.d-3)/2.0, (args.d-3)/2.0)\n",
    "    legendre_k_d = q_k_d/q_k_d(1)\n",
    "    last = X0[:, args.d-1]\n",
    "    # Evaluate the polynomial once and reuse it for both measures.\n",
    "    poly = legendre_k_d(last)\n",
    "    acceptance_prob_plus = torch.nn.functional.relu(torch.from_numpy(0.99*poly))\n",
    "    acceptance_prob_minus = torch.nn.functional.relu(torch.from_numpy(-0.99*poly))\n",
    "    print(torch.sum(acceptance_prob_plus), torch.sum(acceptance_prob_minus))\n",
    "    relu_last = torch.nn.functional.relu(last)\n",
    "    return (torch.sum(relu_last*acceptance_prob_plus)/torch.sum(acceptance_prob_plus)\n",
    "            - torch.sum(relu_last*acceptance_prob_minus)/torch.sum(acceptance_prob_minus))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 176,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(-0.0253)\n",
      "tensor(5748.1133) tensor(5755.8169)\n",
      "tensor(-0.0217)\n"
     ]
    }
   ],
   "source": [
    "# Rejection-sampling estimate (gamma_ratio) vs. weighted estimate (gamma_ratio_2).\n",
    "print(gamma_ratio(X_nu=X_nu, X_mu=X_mu))\n",
    "print(gamma_ratio_2())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(4.3692)\n"
     ]
    }
   ],
   "source": [
    "# Ratio of gamma_ratio to the closed-form F2 distance estimate.\n",
    "print(gamma_ratio(X_nu, X_mu) / d_f2_estimate_exact_kernel(X_mu, X_nu, a=1, b=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 231,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the kernel-Stein-discrepancy experiment.\n",
    "argssd = empty_class()\n",
    "argssd.d = 12\n",
    "argssd.k = 6\n",
    "argssd.n_samples = 1000          # uniform samples drawn by get_mu_samples_sd\n",
    "argssd.n_samples_d_f2 = 10000\n",
    "argssd.seed = 42\n",
    "argssd.alpha = 1\n",
    "argssd.gamma = 1                 # scale used by score_function and sd_f1_theoretical"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 202,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sd_f1_theoretical():\n",
    "    \"\"\"Theoretical F1 Stein-discrepancy value for the argssd settings.\n",
    "\n",
    "    Returns gamma * lambda_{alpha+1,k,d} * k * (d - 3) / (alpha + 1), where the coefficient\n",
    "    lambda_{alpha+1,k,d} is zero unless k and alpha have different parity and k > alpha + 1.\n",
    "    \"\"\"\n",
    "    d, k, alpha = argssd.d, argssd.k, argssd.alpha\n",
    "    lambda_alpha_p1_k_d = 0\n",
    "    if k % 2 != alpha % 2 and k > alpha + 1:\n",
    "        numerator = (d-2)*math.factorial(alpha + 1)*math.gamma((d-1)/2)*math.gamma(k - alpha)\n",
    "        lambda_alpha_p1_k_d = numerator/(2*np.pi*(2**k)*math.gamma((k - alpha + 1)/2)*math.gamma((k + d + alpha - 1)/2))\n",
    "    return argssd.gamma*lambda_alpha_p1_k_d*k*(d-3)/(alpha+1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 203,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.020919505127274286\n"
     ]
    }
   ],
   "source": [
    "# Theoretical F1 Stein-discrepancy value.\n",
    "print(str(sd_f1_theoretical()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 204,
   "metadata": {},
   "outputs": [],
   "source": [
    "def sd_ratio_lower_bound_theoretical():\n",
    "    \"\"\"Theoretical lower bound for the Stein-discrepancy comparison in the argssd settings.\n",
    "\n",
    "    Returns numerator / (denominator * sqrt(N_kd)), with numerator = k(d-3)/(alpha+1),\n",
    "    denominator = sqrt(2 * (k(k+d-2)(d+alpha-2)^2/(alpha+1)^2 + numerator^2)), and N_kd the\n",
    "    dimension of the degree-k spherical harmonics in d dimensions.\n",
    "    \"\"\"\n",
    "    # N_kd now uses argssd like the rest of the function; it previously read k and d from the\n",
    "    # unrelated global args (same values today, silently wrong if the two configs diverge).\n",
    "    N_kd = (2*argssd.k + argssd.d - 2) * math.factorial(argssd.k + argssd.d - 3) / (math.factorial(argssd.k) * math.factorial(argssd.d - 2))\n",
    "    numerator = argssd.k*(argssd.d-3)/(argssd.alpha+1)\n",
    "    denominator = np.sqrt(2*(argssd.k*(argssd.k + argssd.d - 2)*(argssd.d + argssd.alpha - 2)**2/(argssd.alpha + 1)**2 + numerator**2))\n",
    "    return numerator/(denominator*np.sqrt(N_kd))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 205,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.0030185799710393704\n"
     ]
    }
   ],
   "source": [
    "# Theoretical lower bound for the Stein-discrepancy comparison.\n",
    "print(str(sd_ratio_lower_bound_theoretical()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 288,
   "metadata": {},
   "outputs": [],
   "source": [
    "def score_function(X):\n",
    "    derivative_factor = argssd.k*(argssd.k + argssd.d - 2)/(argssd.d - 1)\n",
    "    q_km1_dp2 = ss.jacobi(argssd.k-1, (argssd.d-1)/2.0, (argssd.d-3)/2.0)\n",
    "    legendre_km1_dp2 = q_km1_dp2/q_km1_dp2(1)\n",
    "    e_d = torch.zeros(1,argssd.d)\n",
    "    e_d[0,argssd.d-1] = 1\n",
    "    #print(torch.from_numpy(legendre_km1_dp2(X[:,args.d-1])).shape, e_d.repeat(argssd.n_samples,argssd.d).shape)\n",
    "    result = argssd.gamma*derivative_factor*torch.from_numpy(legendre_km1_dp2(X[:,args.d-1]))*e_d.repeat(argssd.n_samples,1)\n",
    "    #print(result.shape, X.squeeze(0).shape, torch.sum((X.squeeze(0)*result), dim=1).unsqueeze(1).shape)\n",
    "    result = result - torch.sum((X.squeeze(0)*result), dim=1).unsqueeze(1)*X.squeeze(0)\n",
    "    return result\n",
    "    \n",
    "#careful with X squeeze/unsqueeze\n",
    "\n",
    "def f2_kernel_evaluation(X):\n",
    "    dimension = X.shape[1]\n",
    "    inner_prod = torch.matmul(X,X.t()).fill_diagonal_(fill_value = 1)\n",
    "    return ((np.pi-torch.acos(inner_prod))*inner_prod \\\n",
    "            + torch.sqrt(1-inner_prod*inner_prod))/(2*np.pi*(dimension+1))\n",
    "\n",
    "def f2_kernel_derivatives(X):\n",
    "    dimension = X.shape[1]\n",
    "    inner_prod = torch.matmul(X,X.t()).fill_diagonal_(fill_value = 1)\n",
    "    gradient = (np.pi-torch.acos(inner_prod)).fill_diagonal_(fill_value  = 0).unsqueeze(2)*X.unsqueeze(0)/(2*np.pi*(dimension+1))\n",
    "    print(gradient.shape, X.shape)\n",
    "    return gradient - torch.sum((X.squeeze(0)*gradient), dim=1).unsqueeze(1)*X.squeeze(0) #fix this\n",
    "    #return ((np.pi-torch.acos(inner_prod)) \\\n",
    "            #+ 2*(inner_prod/torch.sqrt(1-inner_prod*inner_prod)).fill_diagonal_(fill_value  = 0)).unsqueeze(2)*X.unsqueeze(0)/(2*np.pi*(dimension+1))\n",
    "        \n",
    "    #subtract gradients\n",
    "def f2_kernel_trace(X):\n",
    "    dimension = X.shape[1]\n",
    "    inner_prod = torch.matmul(X,X.t()).fill_diagonal_(fill_value = 1)\n",
    "    return argssd.d*(np.pi-torch.acos(inner_prod)) + inner_prod/torch.sqrt(1-inner_prod*inner_prod)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 279,
   "metadata": {},
   "outputs": [],
   "source": [
    "def KSD_computation(X,kernel_eval,kernel_der,kernel_tr):\n",
    "    \"\"\"Off-diagonal (U-statistic style) kernel Stein discrepancy estimate.\n",
    "\n",
    "    With s = score_function(X) - (d - 1) * X (one row per point), the three terms are sums over\n",
    "    pairs i != j divided by n(n-1) of\n",
    "        first:  <s_i, s_j> * kernel_eval[i, j]\n",
    "        second: 2 * <kernel_der[i, j], s_j>\n",
    "        third:  kernel_tr[i, j]\n",
    "    Each term is printed and their sum is returned. kernel_tr's diagonal is zeroed in place.\n",
    "\n",
    "    Args:\n",
    "        X: tensor (n, d) of sample points.\n",
    "        kernel_eval, kernel_tr: (n, n) tensors; kernel_der: (n, n, d) tensor.\n",
    "    \"\"\"\n",
    "    n_samples, data_dim = X.shape[0], X.shape[1]\n",
    "    n_pairs = n_samples*(n_samples-1)\n",
    "    # (d, n) layout: column i is the corrected score of point i.\n",
    "    score_corrected = score_function(X.unsqueeze(0)).t() - (data_dim-1)*X.t()\n",
    "    score_gram = torch.matmul(score_corrected.t(), score_corrected)\n",
    "    first_term = torch.sum((score_gram*kernel_eval).fill_diagonal_(fill_value=0))/n_pairs\n",
    "    cross = torch.sum(kernel_der*score_corrected.t().unsqueeze(0), dim=2)\n",
    "    second_term = 2*torch.sum(cross.fill_diagonal_(fill_value=0))/n_pairs\n",
    "    third_term = torch.sum(kernel_tr.fill_diagonal_(fill_value=0))/n_pairs\n",
    "    for label, term in (('first term', first_term), ('second term', second_term), ('third term', third_term)):\n",
    "        print(label, term)\n",
    "    return first_term + second_term + third_term"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 235,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_mu_samples_sd(argssd):\n",
    "    fname = os.path.join('mu_samples_sd', \n",
    "                         f'mu_samples_sd_{argssd.d}_{argssd.k}_{argssd.n_samples}_{argssd.seed}.pkl')\n",
    "    if os.path.exists(fname):\n",
    "        X = pickle.load(open(fname, 'rb'))\n",
    "        return X\n",
    "    else:\n",
    "        X = torch.randn(argssd.n_samples,argssd.d)\n",
    "        X = torch.nn.functional.normalize(X, p=2, dim=1)\n",
    "        if not os.path.exists('mu_samples_sd'):\n",
    "            os.makedirs('mu_samples_sd')\n",
    "        pickle.dump(X, open(fname, 'wb'))\n",
    "        return X"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 236,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Uniform samples on the sphere for the KSD experiment (cached in mu_samples_sd/).\n",
    "X_mu_sd = get_mu_samples_sd(argssd=argssd)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 237,
   "metadata": {},
   "outputs": [],
   "source": [
    "def KSD_f2():\n",
    "    \"\"\"KSD estimate for the global samples X_mu_sd using the F2 kernel, its gradient and trace term.\"\"\"\n",
    "    return KSD_computation(X_mu_sd,\n",
    "                           f2_kernel_evaluation(X_mu_sd),\n",
    "                           f2_kernel_derivatives(X_mu_sd),\n",
    "                           f2_kernel_trace(X_mu_sd))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 289,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([1000, 1000, 12]) torch.Size([1000, 12])\n",
      "first term tensor(0.1924)\n",
      "second term tensor(34.8398)\n",
      "third term tensor(18.8411)\n",
      "tensor(53.8734)\n"
     ]
    }
   ],
   "source": [
    "print(str(KSD_f2()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 234,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "torch.Size([10000000, 12])\n"
     ]
    }
   ],
   "source": [
    "print(X_mu_sd.size())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 325,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Settings for the Gaussian (sliced-Wasserstein) comparison.\n",
    "argsw = empty_class()\n",
    "argsw.d = 12\n",
    "argsw.n_samples = 10000      # samples per distribution\n",
    "argsw.avg_samples = 10000    # random directions for avg_sliced / d_f2_estimate_w\n",
    "argsw.seed = 42\n",
    "argsw.alpha = 1\n",
    "argsw.large_var = 1          # variance of every coordinate of mu\n",
    "argsw.small_var = 0.1        # variance of the last coordinate of nu"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 326,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Isotropic covariance for mu; nu is identical except for a smaller variance on the last axis.\n",
    "mu_variance = argsw.large_var*np.identity(argsw.d)\n",
    "nu_variance = mu_variance.copy()\n",
    "nu_variance[-1, -1] = argsw.small_var"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 327,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed NumPy's global RNG from argsw.seed so these samples are reproducible (the seed was unused).\n",
    "np.random.seed(argsw.seed)\n",
    "X_mu_w = np.random.multivariate_normal(np.zeros(argsw.d), mu_variance, argsw.n_samples)\n",
    "X_nu_w = np.random.multivariate_normal(np.zeros(argsw.d), nu_variance, argsw.n_samples)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 328,
   "metadata": {},
   "outputs": [],
   "source": [
    "def max_sliced(X_mu_w,X_nu_w):\n",
    "    return sst.wasserstein_distance(X_mu_w[:,argsw.d-1], X_nu_w[:,argsw.d-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 329,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.5346138008407564\n"
     ]
    }
   ],
   "source": [
    "print(max_sliced(X_mu_w=X_mu_w, X_nu_w=X_nu_w))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 330,
   "metadata": {},
   "outputs": [],
   "source": [
    "def avg_sliced(X_mu_w,X_nu_w):\n",
    "    Y0 = torch.randn(argsw.d,argsw.avg_samples)\n",
    "    Y0 = torch.nn.functional.normalize(Y0, p=2, dim=0).double()\n",
    "    average = 0\n",
    "    for i in range(argsw.avg_samples):\n",
    "        #print(torch.from_numpy(X_mu_w).dtype, Y0[:,i].dtype)\n",
    "        #print(torch.matmul(torch.from_numpy(X_mu_w),Y0[:,i]).shape)\n",
    "        average = average + sst.wasserstein_distance(torch.matmul(torch.from_numpy(X_mu_w),Y0[:,i]), torch.matmul(torch.from_numpy(X_nu_w),Y0[:,i]))\n",
    "    average = average/argsw.avg_samples\n",
    "    return average"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 331,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0.041188734348935045\n"
     ]
    }
   ],
   "source": [
    "print(avg_sliced(X_mu_w=X_mu_w, X_nu_w=X_nu_w))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 341,
   "metadata": {},
   "outputs": [],
   "source": [
    "def d_f1_estimate_w(X_nu_w, X_mu_w, a, b):\n",
    "    ones_d = torch.ones(argsw.n_samples,1).double()\n",
    "    X_mu = torch.cat((torch.from_numpy(X_mu_w),ones_d), 1)\n",
    "    X_nu = torch.cat((torch.from_numpy(X_nu_w),ones_d), 1)\n",
    "    gen_moment_nu_positive = a*torch.mean(torch.nn.functional.relu(torch.sqrt(X_nu[:,args.d-1]**2 + 1))) + b*torch.mean(torch.nn.functional.relu(-torch.sqrt(X_nu[:,args.d-1]**2 + 1)))\n",
    "    gen_moment_nu_negative = a*torch.mean(torch.nn.functional.relu(-torch.sqrt(X_nu[:,args.d-1]**2 + 1))) + b*torch.mean(torch.nn.functional.relu(torch.sqrt(X_nu[:,args.d-1]**2 + 1)))\n",
    "    gen_moment_mu_positive = a*torch.mean(torch.nn.functional.relu(torch.sqrt(X_mu[:,args.d-1]**2 + 1))) + b*torch.mean(torch.nn.functional.relu(-torch.sqrt(X_mu[:,args.d-1]**2 + 1)))\n",
    "    gen_moment_mu_negative = a*torch.mean(torch.nn.functional.relu(-torch.sqrt(X_mu[:,args.d-1]**2 + 1))) + b*torch.mean(torch.nn.functional.relu(torch.sqrt(X_mu[:,args.d-1]**2 + 1)))\n",
    "    return torch.max(torch.abs(gen_moment_nu_positive - gen_moment_mu_positive),torch.abs(gen_moment_nu_negative - gen_moment_mu_negative))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 342,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.3024, dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "print(d_f1_estimate_w(X_mu_w, X_nu_w, a=1, b=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 346,
   "metadata": {},
   "outputs": [],
   "source": [
    "def d_f2_estimate_w(X_nu_w, X_mu_w, a, b):\n",
    "    ones_d = torch.ones(argsw.n_samples,1).double()\n",
    "    X_mu = torch.cat((torch.from_numpy(X_mu_w),ones_d), 1)\n",
    "    X_nu = torch.cat((torch.from_numpy(X_nu_w),ones_d), 1)\n",
    "    Y0 = torch.randn(argsw.d+1,argsw.avg_samples)\n",
    "    Y0 = torch.nn.functional.normalize(Y0, p=2, dim=0).double()\n",
    "    gen_moment_nu_positive = a*torch.mean(torch.nn.functional.relu(torch.matmul(X_nu,Y0)), dim=0) + b*torch.mean(torch.nn.functional.relu(-torch.matmul(X_nu,Y0)), dim=0)\n",
    "    gen_moment_nu_negative = a*torch.mean(torch.nn.functional.relu(-torch.matmul(X_nu,Y0)), dim=0) + b*torch.mean(torch.nn.functional.relu(torch.matmul(X_nu,Y0)), dim=0)\n",
    "    gen_moment_mu_positive = a*torch.mean(torch.nn.functional.relu(torch.matmul(X_mu,Y0)), dim=0) + b*torch.mean(torch.nn.functional.relu(-torch.matmul(X_mu,Y0)), dim=0)\n",
    "    gen_moment_mu_negative = a*torch.mean(torch.nn.functional.relu(-torch.matmul(X_mu,Y0)), dim=0) + b*torch.mean(torch.nn.functional.relu(torch.matmul(X_mu,Y0)), dim=0)\n",
    "    d_f2_sq = torch.mean(0.5*(gen_moment_nu_positive-gen_moment_mu_positive)**2 + 0.5*(gen_moment_nu_negative-gen_moment_mu_negative)**2)\n",
    "    return torch.sqrt(d_f2_sq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 347,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0255, dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "print(d_f2_estimate_w(X_mu_w, X_nu_w, a=1, b=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 362,
   "metadata": {},
   "outputs": [],
   "source": [
    "def d_tilde_f2_estimate_w(X_nu_w, X_mu_w, a, b):\n",
    "    ones_d = torch.ones(argsw.n_samples,1).double()\n",
    "    X_mu = torch.cat((torch.from_numpy(X_mu_w),ones_d), 1)\n",
    "    X_nu = torch.cat((torch.from_numpy(X_nu_w),ones_d), 1)\n",
    "    Z0 = torch.randn(argsw.d,argsw.avg_samples)\n",
    "    Z0 = torch.nn.functional.normalize(Z0, p=2, dim=0).double()\n",
    "    w0 = (np.pi*torch.rand(argsw.avg_samples) - 0.5*np.pi).unsqueeze(0).double()\n",
    "    #print((torch.cos(w0)*Z0).shape,w0.shape)\n",
    "    Y0 = torch.cat((torch.cos(w0)*Z0,torch.sin(w0)),0)\n",
    "    gen_moment_nu_positive = a*torch.mean(torch.nn.functional.relu(torch.matmul(X_nu,Y0)), dim=0) + b*torch.mean(torch.nn.functional.relu(-torch.matmul(X_nu,Y0)), dim=0)\n",
    "    gen_moment_nu_negative = a*torch.mean(torch.nn.functional.relu(-torch.matmul(X_nu,Y0)), dim=0) + b*torch.mean(torch.nn.functional.relu(torch.matmul(X_nu,Y0)), dim=0)\n",
    "    gen_moment_mu_positive = a*torch.mean(torch.nn.functional.relu(torch.matmul(X_mu,Y0)), dim=0) + b*torch.mean(torch.nn.functional.relu(-torch.matmul(X_mu,Y0)), dim=0)\n",
    "    gen_moment_mu_negative = a*torch.mean(torch.nn.functional.relu(-torch.matmul(X_mu,Y0)), dim=0) + b*torch.mean(torch.nn.functional.relu(torch.matmul(X_mu,Y0)), dim=0)\n",
    "    d_f2_sq = torch.mean(0.5*(gen_moment_nu_positive-gen_moment_mu_positive)**2 + 0.5*(gen_moment_nu_negative-gen_moment_mu_negative)**2)\n",
    "    return torch.sqrt(d_f2_sq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 363,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "tensor(0.0168, dtype=torch.float64)\n"
     ]
    }
   ],
   "source": [
    "print(d_tilde_f2_estimate_w(X_mu_w, X_nu_w, a=1, b=0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
