{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import itertools\n",
    "from matplotlib import pyplot as plt"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Linear equations\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import itertools"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_par(root):\n",
    "    if root == 0:\n",
    "        return 'x'\n",
    "    elif root > 0:\n",
    "        return f'(x-{root})'\n",
    "    else:\n",
    "        return f'(x+{np.abs(root)})'\n",
    "        \n",
    "def get_coef(roots, power):\n",
    "    s = np.sum([np.prod(cfs) for cfs in itertools.combinations(roots, len(roots)-power)])\n",
    "    return pow(-1, power) * s\n",
    "\n",
    "def get_monome(coef, power, dim=2):\n",
    "    monome = ''\n",
    "\n",
    "    if coef == 0:\n",
    "        return monome\n",
    "\n",
    "    if coef == -1:\n",
    "        monome += '-' if power > 0 else '-1'\n",
    "    elif coef == 1:\n",
    "        if power == 0:\n",
    "            monome += '+1'\n",
    "        elif power < dim:\n",
    "            monome += '+'\n",
    "    elif coef != 1:\n",
    "        if coef < 0:\n",
    "            c = f'{coef}'\n",
    "        elif power == dim:\n",
    "            c = f'{coef}'\n",
    "        else:\n",
    "            c = f'+{coef}'\n",
    "        monome += c\n",
    "\n",
    "        \n",
    "        if power > 0:\n",
    "            monome += '*'\n",
    "        # if power == dim:\n",
    "        #     monome = monome[1:]\n",
    "    \n",
    "    if power == 0:\n",
    "        return monome\n",
    "\n",
    "    if power == 1:\n",
    "        monome += 'x'\n",
    "    elif power > 1:\n",
    "        monome += f'x^{power}'\n",
    "    return monome"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "max_root = 100\n",
    "dim = 2\n",
    "precision=0\n",
    "shuffle_limit = 10\n",
    "\n",
    "# roots = np.random.uniform(-max_root, max_root, dim).round(precision)\n",
    "roots = np.random.uniform(-max_root, max_root, dim).astype(int)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# paretheses form\n",
    "pars = [get_par(r) for r in sorted(roots)]\n",
    "par_eq = '*'.join(pars) + '=0'\n",
    "\n",
    "# canonic form\n",
    "coeffs = [int(get_coef(roots, n)) for n in range(len(roots)+1)][::-1]\n",
    "canonic_eq = ''.join([get_monome(get_coef(roots, n), n) for n in range(len(roots)+1)][::-1]) + '=0'\n",
    "\n",
    "# discriminant\n",
    "D = int(coeffs[1]**2 - 4*coeffs[0]*coeffs[2])\n",
    "sqd = int(np.sqrt(D))\n",
    "discr_eq = f'D={coeffs[1]}^2-4*{coeffs[0]}*{coeffs[2]}={D}={sqd}^2'\n",
    "solution_eq = f'x=({-coeffs[1]}-{sqd})/{int(2*coeffs[0])}={min(roots)};x=({-coeffs[1]}+{sqd})/{int(2*coeffs[0])}={max(roots)}'\n",
    "\n",
    "# shuffled form\n",
    "shuffle_coeff = np.random.choice([-1, 1]) * round(np.random.uniform(1, shuffle_limit), precision)\n",
    "base_coeffs = np.array(coeffs) * shuffle_coeff\n",
    "\n",
    "if precision == 0:\n",
    "    base_coeffs = base_coeffs.astype(int)\n",
    "base_eq = ''.join([get_monome(coef, n) for coef, n in zip(base_coeffs, range(len(roots)+1)[::-1])]) + '=0'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_equation_solution(max_root = 100,\n",
    "                        dim = 2,\n",
    "                        shuffle_limit = 10,\n",
    "                        with_D = True,\n",
    "                        with_canonic = True):\n",
    "    roots = np.random.uniform(-max_root, max_root, dim).astype(int)\n",
    "    # # paretheses form\n",
    "    # pars = [get_par(r) for r in sorted(roots)]\n",
    "    # par_eq = '*'.join(pars) + '=0'\n",
    "\n",
    "    # canonic form\n",
    "    canonic_coeffs = [int(get_coef(roots, n)) for n in range(len(roots)+1)][::-1]\n",
    "    canonic_eq = ''.join([get_monome(get_coef(roots, n), n) for n in range(len(roots)+1)][::-1]) + '=0'\n",
    "\n",
    "    # shuffled form\n",
    "    shuffle_coeff = np.random.choice([-1, 1]) * round(np.random.uniform(1, shuffle_limit), precision)\n",
    "    base_coeffs = (np.array(canonic_coeffs) * shuffle_coeff).astype(int)\n",
    "\n",
    "    # discriminant\n",
    "    if with_canonic:\n",
    "        coeffs = canonic_coeffs\n",
    "    else:\n",
    "        coeffs = base_coeffs\n",
    "\n",
    "    D = int(coeffs[1]**2 - 4*coeffs[0]*coeffs[2])\n",
    "    sqd = int(np.sqrt(D))\n",
    "    discr_eq = f'D={np.abs(coeffs[1])}^2-4*{coeffs[0]}*{coeffs[2]}={D}={sqd}^2'\n",
    "    solution_eq = f'x=({-coeffs[1]}-{sqd})/{int(2*coeffs[0])}={min(roots)};x=({-coeffs[1]}+{sqd})/{int(2*coeffs[0])}={max(roots)}'\n",
    "\n",
    "    if precision == 0:\n",
    "        base_coeffs = base_coeffs.astype(int)\n",
    "    base_eq = ''.join([get_monome(coef, n) for coef, n in zip(base_coeffs, range(len(roots)+1)[::-1])]) + '=0'\n",
    "\n",
    "    answer = ','.join(np.unique(roots).astype(str))\n",
    "    \n",
    "    solution_equations = []\n",
    "    if with_canonic:\n",
    "        solution_equations.append(canonic_eq)\n",
    "    if with_D:\n",
    "        solution_equations.append(discr_eq)\n",
    "        solution_equations.append(solution_eq)\n",
    "        \n",
    "    solution = ';'.join(solution_equations)\n",
    "    return base_eq, solution, answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[1, 154, 5673]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "coeffs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[-93 -61]\n",
      "(x+93)*(x+61)=0\n",
      "x=(-154-32)/2=-93;x=(-154+32)/2=-61\n",
      "D=154^2-4*1*5673=1024=32^2\n",
      "x^2+154*x+5673=0\n",
      "-8*x^2-1232*x-45384=0\n"
     ]
    }
   ],
   "source": [
    "print(roots)\n",
    "print(par_eq)\n",
    "print(solution_eq)\n",
    "print(discr_eq)\n",
    "print(canonic_eq)\n",
    "print(base_eq)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "eq, sol, ans = gen_equation_solution(with_canonic=True, with_D=True)\n",
    "total = ';'.join((eq, sol, ans))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "def gen_no_roots_equation(max_root=100, with_canonic=True, with_D=True):\n",
    "    b = np.random.randint(max_root*2)\n",
    "    c = np.random.randint(b**2/4, max_root**2)\n",
    "    canonic_coeffs = [1, b, c]\n",
    "\n",
    "    canonic_eq = ''.join([get_monome(c, pow) for c, pow in zip(canonic_coeffs, range(len(canonic_coeffs))[::-1])]) + '=0'\n",
    "\n",
    "    # shuffled form\n",
    "    shuffle_coeff = np.random.choice([-1, 1]) * round(np.random.uniform(1, shuffle_limit), precision)\n",
    "    base_coeffs = np.array(canonic_coeffs) * shuffle_coeff\n",
    "    if precision == 0:\n",
    "        base_coeffs = base_coeffs.astype(int)\n",
    "\n",
    "    # discriminant\n",
    "    coeffs = canonic_coeffs if with_canonic else base_coeffs\n",
    "    D = int(coeffs[1]**2 - 4*coeffs[0]*coeffs[2])\n",
    "    # sqd = int(np.sqrt(D))\n",
    "    discr_eq = f'D={np.abs(coeffs[1])}^2-4*{coeffs[0]}*{coeffs[2]}={D}'\n",
    "    solution_eq = ''#f'x=({-coeffs[1]}-{sqd})/{int(2*coeffs[0])}={min(roots)},x=({-coeffs[1]}+{sqd})/{int(2*coeffs[0])}={max(roots)}'\n",
    "\n",
    "\n",
    "    \n",
    "    base_eq = ''.join([get_monome(coef, n) for coef, n in zip(base_coeffs, range(len(roots)+1)[::-1])]) + '=0'\n",
    "\n",
    "    answer = ''#','.join(np.unique(roots).astype(str))\n",
    "    solution_equations = []\n",
    "    if with_canonic:\n",
    "        solution_equations.append(canonic_eq)\n",
    "    if with_D:\n",
    "        solution_equations.append(discr_eq)\n",
    "        solution_equations.append(solution_eq)\n",
    "    solution = ';'.join(solution_equations)\n",
    "    # equations = [canonic_eq, discr_eq, solution_eq]\n",
    "    # equations = [eq for i, eq in enumerate(equations) if i not in skip_steps]\n",
    "    return base_eq, solution, answer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "class equation_generator:\n",
    "    def __init__(self, power=2, max_root=1000, shuffle_limit=10, batch_size=32, seq_len=30, prob_no_roots=0.2, with_canonic=True, with_D=True, add_eos=False):\n",
    "        self.batch_size = batch_size\n",
    "        self.power = power\n",
    "        self.max_root = max_root\n",
    "        self.shuffle_limit = shuffle_limit\n",
    "        self.set_dict()\n",
    "        self.seq_len = seq_len\n",
    "        self.src_mask = self.tgt_mask = None\n",
    "        self.prob_no_roots = prob_no_roots\n",
    "\n",
    "        self.with_canonic = with_canonic\n",
    "        self.with_D = with_D\n",
    "        \n",
    "        self.add_eos = add_eos\n",
    "        self.generated = []\n",
    "\n",
    "    \n",
    "    def set_dict(self):\n",
    "        keys = [str(i) for i in range(10)]\n",
    "        self.c2token = dict(zip(keys + ['+', '-', '*', '/', '^', '=', '(', ')', 'x', 'D', ';', ','], range(2, 24)))\n",
    "        self.eos = 22\n",
    "        \n",
    "    \n",
    "    def __next__(self):\n",
    "        seq_len = self.seq_len\n",
    "\n",
    "        if self.with_D:\n",
    "            if self.with_canonic:\n",
    "                N = 6\n",
    "            else:\n",
    "                N = 5\n",
    "        else:\n",
    "            N = 2\n",
    "\n",
    "        src = np.zeros([self.batch_size, seq_len*N]).astype(int)\n",
    "        tgt = np.zeros([self.batch_size, seq_len*N]).astype(int)\n",
    "        # src[:, 0] = 1\n",
    "        for i in range(self.batch_size):\n",
    "            if np.random.uniform(0, 1) < self.prob_no_roots:\n",
    "                eq, sol, ans = gen_no_roots_equation(with_canonic=self.with_canonic, with_D=self.with_D)\n",
    "            else:\n",
    "                eq, sol, ans = gen_equation_solution(with_canonic=self.with_canonic, with_D=self.with_D)\n",
    "            total = ';'.join((eq, sol, ans))\n",
    "            self.generated.append({'problem': eq, 'steps': sol, 'answer': ans})\n",
    "            eq_tokens = [self.c2token[c] for c in eq]\n",
    "            sol_tokens = [[self.c2token[c] for c in sol] for sol in sol.split(';')]\n",
    "            ans_tokens = [self.c2token[c] for c in ans]\n",
    "            \n",
    "            src[i, :len(eq_tokens)] = eq_tokens\n",
    "            tgt[i, :len(eq_tokens)-1] = eq_tokens[1:]\n",
    "            \n",
    "            if self.add_eos:\n",
    "                src[i, len(eq_tokens)] = self.eos\n",
    "                tgt[i, len(eq_tokens)-1] = self.eos\n",
    "            \n",
    "            for j, tokens in enumerate(sol_tokens):\n",
    "                src[i, seq_len*(j+1):seq_len*(j+1) + len(tokens)] = tokens\n",
    "                tgt[i, seq_len*(j+1)-1:seq_len*(j+1) + len(tokens)-1] = tokens\n",
    "\n",
    "                if self.add_eos:\n",
    "                    src[i, seq_len*(j+1) + len(tokens)] = self.eos\n",
    "                    tgt[i, seq_len*(j+1) + len(tokens)-1] = self.eos\n",
    "\n",
    "            # print(len(ans_tokens), ans_tokens)\n",
    "            src[i, seq_len*(N-1):seq_len*(N-1)+len(ans_tokens)] = ans_tokens\n",
    "            tgt[i, seq_len*(N-1)-1:seq_len*(N-1)+len(ans_tokens)-1] = ans_tokens\n",
    "            \n",
    "            if self.add_eos:\n",
    "                src[i, seq_len*(N-1)+len(ans_tokens)] = self.eos\n",
    "                tgt[i, seq_len*(N-1)+len(ans_tokens)-1] = self.eos\n",
    "            \n",
    "\n",
    "        return torch.tensor(src), torch.tensor(tgt), self.src_mask, self.tgt_mask  "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen = equation_generator(with_canonic=True, with_D=True, add_eos=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 51,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def generate_data(generator, task_name, path='data', train_size=10_000, val_size=1_000, test_size=2_000, batch_size=32):\n",
    "    Xs, ys = [], []\n",
    "    total_size = train_size + test_size + val_size\n",
    "    num_batches = total_size // batch_size * 3\n",
    "    for _ in range(num_batches):\n",
    "        X, y, _, _ = next(generator)\n",
    "        Xs.append(X)\n",
    "        ys.append(y)\n",
    "\n",
    "    Xs = torch.vstack(Xs)\n",
    "    ys = torch.vstack(ys)\n",
    "    \n",
    "    _, inds = np.unique(Xs, axis=0, return_index=True)\n",
    "    inds = np.random.permutation(inds)\n",
    "    \n",
    "    Xs = Xs[inds][:total_size]\n",
    "    ys = ys[inds][:total_size]\n",
    "    \n",
    "    np.save(f'{path}/{task_name}_train_X.npy', Xs[:train_size].cpu())\n",
    "    np.save(f'{path}/{task_name}_train_y.npy', ys[:train_size].cpu())\n",
    "\n",
    "    np.save(f'{path}/{task_name}_val_X.npy', Xs[train_size:train_size+val_size].cpu())\n",
    "    np.save(f'{path}/{task_name}_val_y.npy', ys[train_size:train_size+val_size].cpu())\n",
    "\n",
    "    np.save(f'{path}/{task_name}_test_X.npy', Xs[train_size+val_size:train_size+val_size+test_size].cpu())\n",
    "    np.save(f'{path}/{task_name}_test_y.npy', ys[train_size+val_size:train_size+val_size+test_size].cpu())\n",
    "    return inds"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 52,
   "metadata": {},
   "outputs": [],
   "source": [
    "gen = equation_generator(with_canonic=True, with_D=True)\n",
    "generate_data(gen, task_name='sqeq_cd', path='data', train_size=100_000, val_size=10_000, test_size=20_000, batch_size=32)"
   ]
  }
 ],
 "metadata": {
  "interpreter": {
   "hash": "3ad594eba56fa4d30e478d0eb2c02077805d7e655fd2c7b71fc86bcad8bf7b09"
  },
  "kernelspec": {
   "display_name": "Python 3.9.0 ('cudaenv')",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.0"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
