{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## On OPT generated data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import glob\n",
    "import face\n",
    "\n",
    "from itertools import chain\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Read spectrum files\n",
    "\n",
    "fft_results_dir = \"../data/experiments_data/opt-original/\"\n",
    "gs_news_dir = \"../data/gs_news/\"\n",
    "gs_story_dir = \"../data/gs_story/\"\n",
    "gs_wiki_dir = \"../data/gs_wiki/\"\n",
    "\n",
    "# Generated data\n",
    "all_opt_files = glob.glob(os.path.join(fft_results_dir, '*.fft.csv'))\n",
    "print(len(all_opt_files))\n",
    "\n",
    "opt_sm_news = [f for f in all_opt_files if 'news' in f and '125m' in f]\n",
    "print(len(opt_sm_news))\n",
    "opt_sm_story = [f for f in all_opt_files if 'story' in f and '125m' in f]\n",
    "print(len(opt_sm_story))\n",
    "opt_sm_wiki = [f for f in all_opt_files if 'wiki' in f and '125m' in f]\n",
    "print(len(opt_sm_wiki))\n",
    "\n",
    "opt_bg_news = [f for f in all_opt_files if 'news' in f and '6.7b' in f]\n",
    "print(len(opt_bg_news))\n",
    "opt_bg_story = [f for f in all_opt_files if 'story' in f and '6.7b' in f]\n",
    "print(len(opt_bg_story))\n",
    "opt_bg_wiki = [f for f in all_opt_files if 'wiki' in f and '6.7b' in f]\n",
    "print(len(opt_bg_wiki))\n",
    "print(sorted([os.path.basename(f) for f in opt_bg_wiki]))\n",
    "\n",
    "# Human data\n",
    "gs_news = glob.glob(os.path.join(gs_news_dir, '*.csv'))\n",
    "print(len(gs_news))\n",
    "\n",
    "gs_story = glob.glob(os.path.join(gs_story_dir, '*.csv'))\n",
    "print(len(gs_story))\n",
    "\n",
    "gs_wiki = glob.glob(os.path.join(gs_wiki_dir, '*.csv'))\n",
    "print(len(gs_wiki))\n",
    "print(sorted([os.path.basename(f) for f in gs_wiki]))\n",
    "\n",
    "# Sort\n",
    "opt_sm_news = sorted(opt_sm_news)\n",
    "opt_sm_story = sorted(opt_sm_story)\n",
    "opt_sm_wiki = sorted(opt_sm_wiki)\n",
    "\n",
    "opt_bg_news = sorted(opt_bg_news)\n",
    "opt_bg_story = sorted(opt_bg_story)\n",
    "opt_bg_wiki = sorted(opt_bg_wiki)\n",
    "\n",
    "gs_news = sorted(gs_news)\n",
    "gs_story = sorted(gs_story)\n",
    "gs_wiki = sorted(gs_wiki)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Compute FACE-SO score for each domain and model size group\n",
    "\n",
    "# sm news\n",
    "so_sm_news_list = []\n",
    "for i in range(len(opt_sm_news)):\n",
    "    opt_file = opt_sm_news[i]\n",
    "    gs_file = gs_news[i]\n",
    "    _, _, so_sm_news = face.getSO(opt_file, gs_file)\n",
    "    so_sm_news_list.append(so_sm_news)\n",
    "\n",
    "# sm story\n",
    "so_sm_story_list = []\n",
    "for i in range(len(opt_sm_story)):\n",
    "    opt_file = opt_sm_story[i]\n",
    "    gs_file = gs_story[i]\n",
    "    _, _, so_sm_story = face.getSO(opt_file, gs_file)\n",
    "    so_sm_story_list.append(so_sm_story)\n",
    "\n",
    "# sm wiki\n",
    "so_sm_wiki_list = []\n",
    "for i in range(len(opt_sm_wiki)):\n",
    "    opt_file = opt_sm_wiki[i]\n",
    "    gs_file = gs_wiki[i]\n",
    "    _, _, so_sm_wiki = face.getSO(opt_file, gs_file)\n",
    "    so_sm_wiki_list.append(so_sm_wiki)\n",
    "\n",
    "# bg news\n",
    "so_bg_news_list = []\n",
    "for i in range(len(opt_bg_news)):\n",
    "    opt_file = opt_bg_news[i]\n",
    "    gs_file = gs_news[i]\n",
    "    _, _, so_bg_news = face.getSO(opt_file, gs_file)\n",
    "    so_bg_news_list.append(so_bg_news)\n",
    "\n",
    "# bg story\n",
    "so_bg_story_list = []\n",
    "for i in range(len(opt_bg_story)):\n",
    "    opt_file = opt_bg_story[i]\n",
    "    gs_file = gs_story[i]\n",
    "    _, _, so_bg_story = face.getSO(opt_file, gs_file)\n",
    "    so_bg_story_list.append(so_bg_story)\n",
    "\n",
    "# bg wiki\n",
    "so_bg_wiki_list = []\n",
    "for i in range(len(opt_bg_wiki)):\n",
    "    opt_file = opt_bg_wiki[i]\n",
    "    gs_file = gs_wiki[i]\n",
    "    _, _, so_bg_wiki = face.getSO(opt_file, gs_file)\n",
    "    so_bg_wiki_list.append(so_bg_wiki)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect and save results\n",
    "so_sm_news = list(chain.from_iterable(so_sm_news_list))\n",
    "so_sm_story = list(chain.from_iterable(so_sm_story_list))\n",
    "so_sm_wiki = list(chain.from_iterable(so_sm_wiki_list))\n",
    "print(len(so_sm_news), len(so_sm_story), len(so_sm_wiki))\n",
    "\n",
    "so_bg_news = list(chain.from_iterable(so_bg_news_list))\n",
    "so_bg_story = list(chain.from_iterable(so_bg_story_list))\n",
    "so_bg_wiki = list(chain.from_iterable(so_bg_wiki_list))\n",
    "print(len(so_bg_news), len(so_bg_story), len(so_bg_wiki))\n",
    "\n",
    "df_so = pd.DataFrame({'so_sm_news': so_sm_news,\n",
    "                      'so_sm_story': so_sm_story,\n",
    "                      'so_sm_wiki': so_sm_wiki,\n",
    "                      'so_bg_news': so_bg_news,\n",
    "                      'so_bg_story': so_bg_story,\n",
    "                      'so_bg_wiki': so_bg_wiki})\n",
    "df_so.to_csv('OPT_SO.csv', index=False)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
