import numpy as np
from PIL import Image
from scipy.stats import special_ortho_group
from scipy.linalg import logm, expm
import torch
from torch.utils.data import Dataset, DataLoader

import matplotlib.pyplot as plt
import io
import matplotlib
matplotlib.use('agg')

def read_off(filename):
    """Parse a plain-text OFF (Object File Format) mesh.

    Blank lines and ``#`` comment lines are ignored. Besides the standard
    layout (``OFF`` alone on the first line, the counts on the next line),
    the variant where the counts follow the keyword on the same line
    (e.g. ``OFF8 6 0``) is also accepted.

    Args:
        filename: path to the ``.off`` file.

    Returns:
        ``(verts, faces)``:
        - ``verts`` is a list of ``n_verts`` lists of floats, one per vertex line.
        - ``faces`` is a list of ``n_faces`` lists of integer vertex indices.
          The leading per-face vertex count and any trailing per-face tokens
          (e.g. colours) are dropped.

    Raises:
        ValueError: if the header is not ``OFF``, the counts line is
            malformed, or the file has fewer lines than the counts announce.
    """
    with open(filename, 'r') as file:
        stripped = (line.strip() for line in file)
        lines = [line for line in stripped if line and not line.startswith('#')]

    if not lines or not lines[0].startswith('OFF'):
        # The original code did `raise('...')`, which is itself a TypeError.
        raise ValueError(f'Not a valid OFF header in {filename!r}')

    if lines[0] == 'OFF':
        counts_line = lines[1] if len(lines) > 1 else ''
        body = lines[2:]
    else:
        # Counts glued to the keyword on the header line.
        counts_line = lines[0][3:]
        body = lines[1:]

    counts = counts_line.split()
    if len(counts) < 2:
        raise ValueError(f'Malformed OFF counts line in {filename!r}: {counts_line!r}')
    n_verts, n_faces = int(counts[0]), int(counts[1])
    if len(body) < n_verts + n_faces:
        raise ValueError(
            f'Truncated OFF file {filename!r}: expected {n_verts} vertex and '
            f'{n_faces} face lines, found {len(body)} data lines'
        )

    # split() with no argument tolerates repeated/trailing whitespace.
    verts = [[float(s) for s in line.split()] for line in body[:n_verts]]
    faces = []
    for line in body[n_verts:n_verts + n_faces]:
        tokens = line.split()
        n_idx = int(tokens[0])
        # Use the declared vertex count, so optional colour values are ignored.
        faces.append([int(s) for s in tokens[1:1 + n_idx]])
    return verts, faces


# (elev, azim) camera angles rendered by `render`, in output order.
_VIEWS = ((0, 0), (0, 90), (90, 0))


def _draw_mesh(fig, vertices, triangles, R, lim):
    """Draw the mesh, rotated by ``R``, onto new black axis-less 3-D axes of ``fig``; return the axes."""
    rotated = (R @ vertices.T).T
    x, y, z = rotated[:, 0], rotated[:, 1], rotated[:, 2]

    fig.set_tight_layout(True)
    ax = fig.add_subplot(111, projection='3d')
    ax.set_facecolor('black')
    ax.set_xlim([-lim, lim])
    ax.set_ylim([-lim, lim])
    ax.set_zlim([-lim, lim])
    ax.axis('off')
    # The model's y coordinate is passed as plot_trisurf's Z (height) argument
    # and z as its Y argument, i.e. y and z are swapped for plotting.
    ax.plot_trisurf(
        x, z, triangles, y, shade=True, color=(0.5, 0.5, 0.5),
        edgecolor='none', linewidth=0, antialiased=False, alpha=1.0
    )
    return ax


def _grab_view(fig, ax, elev, azim):
    """Point the camera at (elev, azim) and return the rendered figure as an (H, W, C) uint8 array."""
    ax.view_init(elev=elev, azim=azim)
    # Raw-buffer capture adapted from https://stackoverflow.com/a/61443397/3090085
    with io.BytesIO() as io_buf:
        fig.savefig(io_buf, facecolor='black', format='raw', dpi=128)
        raw = io_buf.getvalue()
    # savefig uses the same dpi (128) the figure was created with, so the
    # figure bbox gives the pixel size of the raw buffer.
    _, _, width, height = fig.bbox.bounds
    # Positional shape: the `newshape=` keyword is deprecated since NumPy 2.1.
    return np.frombuffer(raw, dtype=np.uint8).reshape(int(height), int(width), -1)


def _to_thumbnail(frame):
    """Downsample a captured frame to a 48x48 greyscale (mode 'L') uint8 array."""
    return np.array(Image.fromarray(frame).resize((48, 48), resample=Image.BICUBIC).convert('L'))


def _render_views(vertices, triangles, R, lim, views):
    """Render the rotated mesh once per (elev, azim) in ``views``; return the list of full-size frames."""
    fig = plt.figure(figsize=(4, 4), dpi=128)
    try:
        ax = _draw_mesh(fig, vertices, triangles, R, lim)
        return [_grab_view(fig, ax, elev, azim) for elev, azim in views]
    finally:
        # Close only our own figure, even on error. The previous
        # plt.close('all') also discarded any figures the caller had open.
        plt.close(fig)


def render(vertices, triangles, R, lim=0.95):
    """Render the mesh rotated by ``R`` from three fixed viewpoints.

    Args:
        vertices: ndarray with 3 columns (x, y, z), one row per vertex.
        triangles: per-face vertex index triples, forwarded to
            ``Axes3D.plot_trisurf``.
        R: 3x3 matrix applied to every vertex before drawing.
        lim: half-extent of the cubic axis box ``[-lim, lim]^3``.

    Returns:
        uint8 array of shape ``(3, 48, 48)``. Greyscale images for views
        (elev, azim) = (0, 0), (0, 90), (90, 0), in that order.
    """
    frames = _render_views(vertices, triangles, R, lim, _VIEWS)
    return np.stack([_to_thumbnail(frame) for frame in frames], axis=0)


def render_one(vertices, triangles, R, lim=0.95):
    """Render the mesh rotated by ``R`` from the (elev=0, azim=0) viewpoint only.

    Arguments are as for `render`. Returns a ``(48, 48)`` uint8 greyscale
    image, identical to ``render(...)[0]`` but drawing only one view.
    """
    (frame,) = _render_views(vertices, triangles, R, lim, ((0, 0),))
    return _to_thumbnail(frame)

def sample(data_dir, batch_size, num_actions):
    """Render every (action, starting pose) combination of the chair mesh.

    Args:
        data_dir: directory containing ``chair.off``.
        batch_size: number of random starting rotations.
        num_actions: number of random action rotations.

    Returns:
        ``(x_list, actions)``:
        - ``x_list[j][i]`` is the three-view ``render`` of the chair under
          ``actions[j] @ start_i``. Each ``x_list[j]`` is a float64 array of
          shape ``(batch_size, 3, 48, 48)``.
        - ``actions`` is the stacked array of the ``num_actions`` 3x3 action
          matrices.
    """
    verts, tris = read_off(data_dir + '/chair.off')
    # Centre the mesh on its vertex mean and scale to unit mean vertex radius.
    verts = np.asarray(verts, dtype=np.float64)
    verts = verts - verts.mean(axis=0)
    verts = verts / np.linalg.norm(verts, axis=1).mean()

    # Actions are drawn before the starting poses, which keeps the RNG
    # consumption order. (A smaller-action variant, expm(0.2 * logm(.)) of a
    # random rotation, was tried previously.)
    actions = [special_ortho_group.rvs(3) for _ in range(num_actions)]
    starts = [special_ortho_group.rvs(3) for _ in range(batch_size)]

    x_list = [np.empty((batch_size, 3, 48, 48)) for _ in actions]
    for i, start in enumerate(starts):
        for frames, action in zip(x_list, actions):
            frames[i] = render(verts, tris, action @ start)
    return x_list, np.array(actions)




def sample_batch_cont(data_dir, batch_size):
    """Render ``batch_size`` random poses of the chair, before and after one small shared rotation.

    The action is ``expm(0.01 * logm(Q))`` for a random rotation ``Q``.

    Returns:
        float64 array of shape ``(2 * batch_size, 3, 48, 48)``.
        - Rows ``[0, batch_size)`` hold the renders at the starting poses.
        - Rows ``[batch_size, 2 * batch_size)`` hold the same poses with the
          action applied.
    """
    verts, tris = read_off(data_dir + '/chair.off')
    # Centre the mesh on its vertex mean and scale to unit mean vertex radius.
    verts = np.asarray(verts, dtype=np.float64)
    verts = verts - verts.mean(axis=0)
    verts = verts / np.linalg.norm(verts, axis=1).mean()

    action = expm(0.01 * logm(special_ortho_group.rvs(3)))
    starts = [special_ortho_group.rvs(3) for _ in range(batch_size)]

    x = np.empty((batch_size * 2, 3, 48, 48))
    for i, start in enumerate(starts):
        x[i] = render(verts, tris, start)
        x[batch_size + i] = render(verts, tris, action @ start)
    return x

class ChairSampleDataset(Dataset):
    """On-the-fly pairs of chair renders related by a small random rotation.

    Each item draws a random starting pose ``S`` and an action
    ``A = expm(0.02 * logm(Q))`` for a random rotation ``Q``. It returns the
    three-view renders at ``S`` and ``A @ S`` as float tensors.
    """
    def __init__(self, data_dir, len_dataset):
        mesh_verts, mesh_tris = read_off(data_dir + '/chair.off')
        # Centre the mesh on its vertex mean and scale to unit mean vertex radius.
        mesh_verts = np.asarray(mesh_verts, dtype=np.float64)
        mesh_verts = mesh_verts - mesh_verts.mean(axis=0)
        self.vertices = mesh_verts / np.linalg.norm(mesh_verts, axis=1).mean()
        self.triangles = mesh_tris
        # Reported length only; items are random and do not depend on idx.
        self.len_dataset = len_dataset

    def __len__(self):
        return self.len_dataset

    def __getitem__(self, idx):
        # idx is ignored: every access draws fresh random rotations.
        step = expm(0.02 * logm(special_ortho_group.rvs(3)))
        start = special_ortho_group.rvs(3)
        rotations = (start, step @ start)
        return tuple(torch.Tensor(render(self.vertices, self.triangles, rot)) for rot in rotations)

class ChairDRLIMDataset(Dataset):
    """Image pairs loaded from ``<data_dir>/images.npy``.

    Item ``idx`` is ``(images[idx, 0], images[idx, 1])`` converted to float
    tensors. The array is presumably laid out as (num_pairs, 2, ...) —
    TODO confirm against the script that writes it.
    """
    def __init__(self, data_dir):
        self.images = np.load(data_dir + '/images.npy')
        print('dataset loaded for training')

    def __len__(self):
        return self.images.shape[0]

    def __getitem__(self, idx):
        first, second = (torch.Tensor(self.images[idx, k]) for k in (0, 1))
        return first, second


class ChairTrainDataset(Dataset):
    """Random clusters of chair images observed under several actions.

    Loads ``images.npy`` and ``actions.npy`` from ``data_dir``.
    - ``images`` is indexed as ``images[action, sample, channel, ...]``.
    - ``actions`` is indexed as ``actions[action]``.

    This layout is inferred from the indexing in ``__getitem__``; confirm it
    against the script that writes the files.
    """
    def __init__(self, data_dir, len_dataset, cluster_size, num_actions, image_channels):
        """
        Args:
            data_dir (string): directory containing ``images.npy`` and ``actions.npy``.
            len_dataset (int): value reported by ``__len__``. Items are random
                draws, so this only sets the epoch length.
            cluster_size (int): number of distinct samples drawn per item.
            num_actions (int): number of distinct actions drawn per item.
            image_channels (int): only the first ``image_channels`` channels
                of each image are returned.
        """
        self.data_dir = data_dir
        self.cluster_size = cluster_size
        self.num_actions = num_actions
        self.images = np.load(self.data_dir + '/images.npy')
        self.actions = np.load(self.data_dir + '/actions.npy')
        self.num_samples = self.images.shape[1]
        self.total_num_actions = self.images.shape[0]
        self.len_dataset = len_dataset
        self.image_channels = image_channels

    def __len__(self):
        return self.len_dataset

    def __getitem__(self, idx):
        """Return ``(img, actions)`` for a random set of samples and actions.

        - ``img`` has shape ``(num_actions, cluster_size, image_channels, ...)``.
        - ``actions`` holds the matching rows of the actions array.

        ``idx`` is ignored, because every call is an independent random draw.
        """
        idx_cluster = np.random.choice(self.num_samples, self.cluster_size, replace=False)
        idx_actions = np.random.choice(self.total_num_actions, self.num_actions, replace=False)
        # Broadcast the two index arrays to gather the (num_actions, cluster_size)
        # grid directly. The former images[idx_actions][:, idx_cluster] first
        # copied every sample of each chosen action, then discarded most of them.
        img = self.images[idx_actions[:, None], idx_cluster[None, :], :self.image_channels]
        actions = self.actions[idx_actions]
        return torch.Tensor(img), torch.Tensor(actions)



class ChairTestDataset(Dataset):
    """Fixed evaluation triples ``(image, action, next_image)``.

    Loads ``images.npy``, ``actions.npy`` and ``next_images.npy`` from
    ``data_dir``. Both image arrays are divided by 255 on load. They are
    presumably uint8 pixels being mapped to [0, 1]; confirm the stored dtype.
    """
    def __init__(self, data_dir):
        """
        Args:
            data_dir (string): directory containing ``images.npy``,
                ``actions.npy`` and ``next_images.npy``. All three are
                indexed along their first axis, and the dataset length is
                taken from ``images``.
        """
        self.data_dir = data_dir
        self.images = np.load(self.data_dir + '/images.npy') / 255.
        self.actions = np.load(self.data_dir + '/actions.npy')
        self.next_images = np.load(self.data_dir + '/next_images.npy') / 255.
        self.num_samples = self.images.shape[0]

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return torch.Tensor(self.images[idx]), torch.Tensor(self.actions[idx]), torch.Tensor(self.next_images[idx])


