
import gym
from gym import spaces
from gym.utils import seeding
import numpy as np
from os import path
import scipy.integrate
solve_ivp = scipy.integrate.solve_ivp

class TwoBodyEnv(gym.Env):
    """Planar gravitational two-body simulation exposed through the gym.Env API.

    ``self.state`` is a ``(2, 5)`` array with one row per body laid out as
    ``[m, qx, qy, vx, vy]``: mass, position and velocity. ``dynamics`` sets
    ``d(qx, qy)/dt = (vx, vy)``, so columns 3:5 are velocities, not momenta.
    Gravity uses G = 1. ``step`` advances the system by ``self.dt`` with
    ``scipy.integrate.solve_ivp``; the action is ignored and the reward is 0.
    """

    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'video.frames_per_second': 30,
    }

    def __init__(self, g=10.0):
        # NOTE(review): max_speed, max_torque, g, L and mu look like leftovers
        # from a pendulum template; apart from the nominal action_space bounds
        # nothing in this class reads them. Kept for backward compatibility.
        self.max_speed = 100.
        self.max_torque = 10.
        self.dt = .05
        self.g = g
        self.viewer = None

        # step() ignores its action, so this space is nominal only.
        self.action_space = spaces.Box(low=-self.max_torque, high=self.max_torque, shape=(1,), dtype=np.float32)
        # Observations are the flattened (2, 5) float64 state returned by
        # _get_obs; positions and velocities are unbounded.
        self.observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=(10,), dtype=np.float64)

        self.seed()

        self.L = 1
        self.mu = 0.1

    def seed(self, seed=None):
        """Seed ``self.np_random`` (used by ``random_config``) and return ``[seed]``."""
        self.np_random, seed = seeding.np_random(seed)
        return [seed]

    ##### INITIALIZE THE TWO BODIES #####
    def random_config(self, orbit_noise=5e-2, min_radius=0.5, max_radius=1.5, mass_1=1.0, mass_2=1.0):
        """Return a random initial state of shape ``(2, 5)``, rows ``[m, qx, qy, vx, vy]``.

        Body 1's |qx| and |qy| are each drawn uniformly from
        ``[min_radius, max_radius)`` with independent random signs. Body 2 is
        placed at the mirror point through the origin with the opposite
        velocity. The speed is the circular-orbit value for two unit masses
        (G = 1), scaled by a single ``1 + orbit_noise * N(0, 1)`` factor.

        Randomness comes from ``self.np_random``, so ``seed()`` makes the
        configuration reproducible.
        """
        rng = self.np_random
        state = np.zeros((2, 5))
        state[0, 0] = mass_1
        state[1, 0] = mass_2

        # uniform/standard_normal exist on both RandomState and Generator,
        # whichever gym's seeding returns.
        pos = rng.uniform(min_radius, max_radius, size=2)
        if rng.uniform() > 0.5:
            pos[0] = -pos[0]
        if rng.uniform() > 0.5:
            pos[1] = -pos[1]

        r = np.sqrt(np.sum(pos ** 2))

        # Each body sits at distance r from the origin (separation 2r), so
        # v^2 / r = 1 / (2r)^2 gives |v| = 1 / (2 sqrt(r)) for unit masses.
        # The direction (-y, x) / r is perpendicular to pos.
        # NOTE: the orbit is only circular for mass_1 == mass_2 == 1.
        vel = np.flipud(pos) / (2 * r ** 1.5)
        vel[0] *= -1
        # Perturb the speed so the circular orbit becomes SLIGHTLY elliptical.
        vel *= 1 + orbit_noise * rng.standard_normal()

        state[:, 1:3] = pos
        state[:, 3:5] = vel
        # Body 2 mirrors body 1: negate its position and velocity, not its mass.
        state[1, 1:] *= -1
        return state

    def get_accelerations(self, state, epsilon=0):
        """Return the ``(nbodies, 2)`` gravitational accelerations (G = 1).

        ``state`` has one row per body laid out as ``[m, qx, qy, vx, vy]``.
        ``epsilon`` is added to the cubed distance in the denominator. With the
        default of 0, coincident bodies cause a division by zero.
        """
        net_accs = []
        for i in range(state.shape[0]):
            # Every body except body i.
            other_bodies = np.concatenate([state[:i, :], state[i + 1:, :]], axis=0)
            displacements = other_bodies[:, 1:3] - state[i, 1:3]
            distances = (displacements ** 2).sum(1, keepdims=True) ** 0.5
            masses = other_bodies[:, 0:1]
            pointwise_accs = masses * displacements / (distances ** 3 + epsilon)
            net_accs.append(pointwise_accs.sum(0, keepdims=True))
        return np.concatenate(net_accs, axis=0)

    def dynamics(self, t, y):
        """Right-hand side for ``solve_ivp``: the derivative of the flattened state.

        ``t`` is unused because the system is autonomous. Masses are constant,
        so their derivative stays 0.
        """
        y = y.reshape(-1, 5)
        deriv = np.zeros_like(y)
        deriv[:, 1:3] = y[:, 3:5]  # d(position)/dt = velocity
        deriv[:, 3:5] = self.get_accelerations(y)
        return deriv.reshape(-1)

    def step(self, action=None):
        """Advance the simulation by ``self.dt`` and return ``(obs, 0, False, {})``.

        ``action`` is accepted for gym API compatibility but ignored; the system
        evolves freely.

        Raises:
            RuntimeError: if ``solve_ivp`` fails (e.g. on a near-collision),
                rather than silently returning a state short of ``self.dt``.
        """
        sol = solve_ivp(fun=self.dynamics, t_span=[0, self.dt], y0=self.state.flatten())
        if not sol.success:
            raise RuntimeError("solve_ivp failed: {}".format(sol.message))
        self.state = sol.y[:, -1].reshape(-1, 5)
        return self._get_obs(), 0, False, {}

    def reset(self, orbit_noise=5e-2, min_radius=0.5, max_radius=1.5, mass_1=1.0, mass_2=1.0, dt=None):
        """Draw a new random configuration and return the initial observation.

        The first five arguments are forwarded to ``random_config``. Their
        defaults match it, so a plain gym-style ``env.reset()`` works. ``dt``
        sets the integration interval used by ``step``; ``None`` keeps the
        current ``self.dt``.
        """
        self.state = self.random_config(orbit_noise, min_radius, max_radius, mass_1, mass_2)
        if dt is not None:
            self.dt = dt
        return self._get_obs()

    def _get_obs(self):
        """Return a flattened copy of the state, body by body ``[m, qx, qy, vx, vy]``."""
        return self.state.flatten()

    def render(self, mode='human'):
        """Draw both bodies. Returns an RGB array when ``mode == 'rgb_array'``."""
        ball_size = 0.25

        if self.viewer is None:
            from myenv import rendering
            # NOTE(review): presumably a 32x32-pixel viewport; confirm against
            # myenv.rendering.Viewer.
            self.viewer = rendering.Viewer(32, 32)
            self.viewer.set_bounds(-2.2, 2.2, -2.2, 2.2)
            ball_1 = rendering.make_circle(ball_size)
            ball_1.set_color(0, 1, 0)
            self.ball_1 = rendering.Transform()
            ball_1.add_attr(self.ball_1)
            self.viewer.add_geom(ball_1)

            ball_2 = rendering.make_circle(ball_size)
            ball_2.set_color(0, 0, 1)
            self.ball_2 = rendering.Transform()
            ball_2.add_attr(self.ball_2)
            self.viewer.add_geom(ball_2)

        # Place each ball at its body's position (state columns 1:3 = qx, qy).
        self.ball_1.set_translation(self.state[0][1], self.state[0][2])
        self.ball_2.set_translation(self.state[1][1], self.state[1][2])

        return self.viewer.render(return_rgb_array=mode == 'rgb_array')

    def close(self):
        """Close the viewer if one was opened."""
        if self.viewer:
            self.viewer.close()
            self.viewer = None

def angle_normalize(x):
    """Wrap an angle (or array of angles) in radians into ``[-pi, pi)``."""
    full_turn = 2 * np.pi
    return (x + np.pi) % full_turn - np.pi