import numpy as np
from typing import Union
import sys
sys.path.append('./')
from est.models.Model import Model



class EulerMaruyamaStepper:
    """
    One-step Euler-Maruyama scheme for an SDE model.

    Given the state x, the state after a time increment dt is computed as
        x_next = x + drift(x) * dt + diffusion(x) @ dZ * sqrt(dt)
    where dZ are standard normal variates supplied by the caller. When the
    diffusion coefficient or the noise is a scalar, the matrix product is
    replaced by an elementwise product.
    """

    def __init__(self, model: "Model"):
        """
        Euler-Maruyama simulation step
        :param model: the SDE model; must provide drift(x) and diffusion(x)
        """
        self._model = model

    def next(self,
             t: float,
             dt: float,
             x: Union[float, np.ndarray],
             dZ: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """
        Given the current state, and random variate(s), evolves state by one step over time increment dt
        :param t: float, current time. Accepted for interface compatibility; it is not passed to the
            model, whose drift and diffusion are evaluated at x only.
        :param dt: float, time increment (between now and next state transition), must be >= 0
        :param x: Union[float, np.ndarray], current state
        :param dZ: Union[float, np.ndarray], normal random variates, N(0,I), to evolve current state
        :return: next state, after evolving by one step. A scalar when the state, drift, diffusion and
            noise are all scalars; otherwise a flat array (column-vector model outputs are flattened).
        :raises ValueError: if dt is negative (sqrt(dt) would silently produce NaN)
        """
        if dt < 0:
            raise ValueError(f"time increment dt must be non-negative, got {dt}")

        sqrt_dt = np.sqrt(dt)
        mu = self._model.drift(x)
        sigma = self._model.diffusion(x)

        # np.matmul ('@') rejects 0-d operands, so a scalar diffusion coefficient
        # or a scalar noise term is combined elementwise instead.
        if np.ndim(sigma) == 0 or np.ndim(dZ) == 0:
            shock = np.multiply(sigma, dZ)
        else:
            shock = sigma @ dZ

        if np.ndim(x) == 0 and np.ndim(mu) == 0 and np.ndim(sigma) == 0 and np.ndim(dZ) == 0:
            # Fully scalar SDE: keep a scalar result rather than a 1-element array.
            return x + mu * dt + shock * sqrt_dt

        # Flatten so that column-vector (d, 1) model outputs line up with a flat (d,) state.
        return x + np.ravel(mu * dt) + np.ravel(shock * sqrt_dt)

    def __call__(self,
                 t: float,
                 dt: float,
                 x: Union[float, np.ndarray],
                 dZ: Union[float, np.ndarray]) -> Union[float, np.ndarray]:
        """ Same as a call to next() """
        return self.next(t=t, dt=dt, x=x, dZ=dZ)