import torch
import torch.nn as nn
import numpy as np
import rl_utils

# Module-wide compute device: CUDA when a GPU is visible, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')


# Dynamics model: predicts the next robot state, then reconstructs the lidar observations from robot_pos and robot_angle.
class DomainModel(nn.Module):
    """Learned dynamics model over a 63-dim observation vector.

    ``self.net`` predicts a residual update for the 15 non-lidar entries of the
    observation (accelerometer, gyro, magnetometer, velocimeter, robot_pos,
    robot_angle). The goal / hazard / vase lidar entries of the next
    observation are then reconstructed from the predicted pose and the fixed
    object positions via ``self.lidar_predictor``.

    Args:
        env: environment object; this class reads ``lidar_num_bins``,
            ``lidar_max_dist`` and ``lidar_alias`` from it.
        hazards_pos: iterable of hazard positions (only ``pos[:2]`` is used).
        vases_pos: iterable of vase positions (only ``pos[:2]`` is used).
        goal_pos: goal position (only ``pos[:2]`` is used).
    """

    def __init__(self, env, hazards_pos, vases_pos, goal_pos):
        super().__init__()
        self.hazards_pos = hazards_pos
        self.vases_pos = vases_pos
        self.goal_pos = goal_pos
        self.env = env
        # 15 compact-state entries + 2 extra inputs -> residual for the 15
        # entries. Presumably the 2 is the action dimension and rl_utils.MLP
        # concatenates its positional inputs -- TODO confirm against rl_utils.
        self.net = rl_utils.MLP([15 + 2, 256, 256, 15])
        # (angle, dist) of one object -> 16 lidar bins.
        self.lidar_predictor = rl_utils.MLP([2, 64, 64, 16])
        # One-hot rows used by the analytic lidar in obj_lidar_obs.
        self.eye = torch.eye(self.env.lidar_num_bins, device=device)
        # Per-dimension log std of the 63-dim next-observation distribution.
        self.log_std = nn.Parameter(torch.zeros(63, device=device), True)

    def obj_lidar_obs(self, angle, dist):
        """Analytic lidar contribution of a single object.

        Not called by the other methods of this class; ``obs_pseudo_lidar``
        uses the learned ``lidar_predictor`` instead.

        The bin containing ``angle`` gets intensity
        ``relu(1 - dist / lidar_max_dist)``. When ``env.lidar_alias`` is set,
        the two neighbouring bins also receive that intensity, weighted by
        where ``angle`` falls inside its bin.

        Args:
            angle: angle in [0, 2*pi). It must be convertible with ``int()``
                after division, i.e. a single-element tensor.
            dist: distance to the object, as a tensor.

        Returns:
            A list of per-bin vectors. The caller is expected to reduce them,
            e.g. with an element-wise max.
        """
        num_bins = self.env.lidar_num_bins
        bin_size = (np.pi * 2) / num_bins
        raw_bin = int(angle / bin_size)
        # The alias fraction is measured from the *unwrapped* bin start, so it
        # stays in [0, 1] even when raw_bin == num_bins.
        bin_angle = bin_size * raw_bin
        # Float rounding can turn an angle just below 2*pi into exactly 2*pi,
        # giving raw_bin == num_bins; wrap it so the eye row index is valid.
        bin_idx = raw_bin % num_bins
        sensor = torch.relu(1 - dist / self.env.lidar_max_dist)
        ret = [self.eye[bin_idx] * sensor[..., None]]
        if self.env.lidar_alias:
            # Split the return onto the neighbouring bins (aliasing).
            alias = (angle - bin_angle) / bin_size
            assert 0 <= alias <= 1, f'bad alias {alias}, dist {dist}, angle {angle}, bin {raw_bin}'
            bin_plus = (bin_idx + 1) % num_bins
            bin_minus = (bin_idx - 1) % num_bins
            ret.append(self.eye[bin_plus] * alias * sensor[..., None])
            ret.append(self.eye[bin_minus] * (1 - alias) * sensor[..., None])
        return ret

    def obs_pseudo_lidar(self, positions, ego_xy):
        """Predict the lidar observation for one group of objects.

        Each object's robot-relative (angle, dist) is fed through
        ``self.lidar_predictor``. The per-object predictions are combined with
        an element-wise max over objects.

        Args:
            positions: iterable of object positions; only ``pos[:2]`` is used.
            ego_xy: callable mapping a world-frame 2-vector to a
                robot-relative ``(x, y)`` pair of tensors.

        Returns:
            The max-reduced prediction, presumably shaped ``(..., 16)``. If
            ``positions`` is empty, an all-zero tensor of shape
            ``batch_shape + (env.lidar_num_bins,)`` is returned.
        """
        obj_obs = []
        for pos in positions:
            pos = torch.as_tensor(pos[:2], dtype=torch.float32, device=device)
            x, y = ego_xy(pos)
            dist = torch.hypot(x, y)
            angle = torch.atan2(y, x) % (np.pi * 2)
            obj_obs.append(self.lidar_predictor(angle[..., None], dist[..., None]))
        if not obj_obs:
            # No objects in this group: report no lidar return (all zeros).
            # torch.stack would otherwise fail on an empty list. ego_xy is
            # probed with a dummy point only to recover the batch shape.
            x, _ = ego_xy(torch.zeros(2, device=device))
            return x.new_zeros((*x.shape, self.env.lidar_num_bins))
        return torch.max(torch.stack(obj_obs), dim=0).values

    def infer_obs_from_state(self, state):
        """Expand a compact 15-dim state into the full 63-dim observation.

        The state's last dimension is laid out as:
        accelerometer [0, 3), gyro [3, 6), magnetometer [6, 9),
        velocimeter [9, 12), robot_pos [12, 14), robot_angle [14].

        The goal, hazard and vase lidars are predicted from robot_pos and
        robot_angle. They are spliced in according to the layout annotated on
        the concatenation below.
        """
        accelerometer = state[..., :3]
        gyro = state[..., 3:6]
        magnetometer = state[..., 6:9]
        velocimeter = state[..., 9:12]
        robot_pos = state[..., 12:14]
        robot_angle = state[..., 14]

        sin_angle = robot_angle.sin()
        cos_angle = robot_angle.cos()

        def ego_xy(pos):
            # Rotate the world-frame offset (pos - robot_pos) into a
            # robot-relative frame defined by robot_angle.
            # NOTE(review): this sin/cos arrangement fixes a particular
            # heading convention; confirm it matches how robot_angle is
            # defined in the environment.
            direction = pos - robot_pos
            new_x = sin_angle * direction[..., 0] - cos_angle * direction[..., 1]
            new_y = cos_angle * direction[..., 0] + sin_angle * direction[..., 1]
            return new_x, new_y

        goal_lidar = self.obs_pseudo_lidar([self.goal_pos], ego_xy)
        hazard_lidar = self.obs_pseudo_lidar(self.hazards_pos, ego_xy)
        vase_lidar = self.obs_pseudo_lidar(self.vases_pos, ego_xy)

        return torch.cat([
            accelerometer,  # [0, 3)
            goal_lidar,     # [3, 19)
            gyro,           # [19, 22)
            hazard_lidar,   # [22, 38)
            magnetometer,   # [38, 41)
            vase_lidar,     # [41, 57)
            velocimeter,    # [57, 60)
            robot_pos,      # [60, 62)
            robot_angle[..., None],  # [62, 63)
        ], dim=-1)

    def forward(self, states, actions, det=True):
        """Predict the next 63-dim observation.

        Args:
            states: tensor whose last dimension is the 63-dim observation.
            actions: tensor passed to ``self.net`` together with the compact
                state.
            det: if True, return the predicted observation directly.
                Otherwise, return a ``torch.distributions.Normal`` centred on
                it, with per-dimension std ``exp(self.log_std)``.
        """
        # The 15 non-lidar entries, ordered as infer_obs_from_state expects.
        useful_indices = np.r_[0:3, 19:22, 38:41, 57:60, 60:63]
        compact = states[..., useful_indices]
        # self.net predicts a residual on top of the current compact state.
        output = self.net(compact, actions) + compact
        next_states = self.infer_obs_from_state(output)
        if det:
            return next_states
        return torch.distributions.Normal(next_states, self.log_std.exp())

    def extra_loss(self):
        """Return the additional regularization loss (none for this model)."""
        return 0
