import numpy as np
import torch
from sklearn.svm import LinearSVC
from tqdm import tqdm

def random_rotate(X):
    """Rotate point coordinates by a random multiple of 90 degrees.

    One angle is drawn from {0, 90, 180, 270} degrees and the first three
    columns of every array are right-multiplied by the matching rotation
    matrix, which leaves the second coordinate unchanged (rotation about
    that axis). Any further columns are carried over untouched.

    Args:
        X: A single 2-D array, or a list of such arrays. When a list is
            given, the same rotation is applied to every element.

    Returns:
        A rotated copy with the same container type as ``X``; the inputs
        are never modified.
    """
    turn_fraction = np.random.choice([0, 0.25, 0.5, 0.75])
    theta = turn_fraction * 2 * np.pi
    c = np.cos(theta)
    s = np.sin(theta)
    rot = np.array([[c, 0, s],
                    [0, 1, 0],
                    [-s, 0, c]])

    def rotated_copy(points):
        # Work on a copy so the caller's array stays intact.
        out = np.copy(points)
        out[:, :3] = np.dot(out[:, :3].reshape((-1, 3)), rot)
        return out

    if isinstance(X, list):
        return [rotated_copy(points) for points in X]
    return rotated_copy(X)

def scale_to_unit_ball(x):
    """Centre an array on its column means and scale it into the unit ball.

    All columns take part in both the centring and the norm. The divisor is
    the largest row norm plus 10e-3 (0.01), so every row ends up with norm
    strictly below 1 and an all-equal input does not divide by zero.

    Args:
        x: 2-D array with one point per row.

    Returns:
        A new, rescaled array; an empty input is returned as an unchanged copy.
    """
    pts = np.copy(x)
    if len(pts) == 0:
        return pts
    pts -= np.mean(pts, axis=0)
    row_norms = np.sqrt(np.sum(np.abs(pts) ** 2, axis=-1))
    pts /= np.max(row_norms) + 10e-3
    return pts

def scale_to_unit_cube(x):
    """Centre an array on its column means and scale it into the unit cube.

    After centring, everything is divided by the largest absolute entry
    plus 10e-3 (0.01), so every value lies strictly inside (-1, 1).

    Args:
        x: 2-D array with one point per row.

    Returns:
        A new, rescaled array; an empty input is returned as an unchanged copy.
    """
    pts = np.copy(x)
    if len(pts) == 0:
        return pts
    pts -= np.mean(pts, axis=0)
    largest_abs = np.max(np.abs(pts))
    pts /= largest_abs + 10e-3
    return pts

def scale_to_unit_cube_s3dis(x):
    """Rescale the first three columns of an S3DIS sample.

    Columns 0 and 2 are divided by 0.5. Column 1 is divided by its largest
    absolute value and then mapped through ``(v - 0.5) * 2``. Columns past
    the third are left as they are.

    Args:
        x: 2-D array with at least three columns.
            # NOTE(review): presumably columns 0 and 2 already span
            # [-0.5, 0.5] and column 1 is non-negative - confirm with the
            # S3DIS loader.

    Returns:
        A new, rescaled array; an empty input is returned as an unchanged copy.
    """
    pts = np.copy(x)
    if len(pts) == 0:
        return pts
    col1_extent = np.max(np.abs(pts[:, 1]))
    divisors = np.array([0.5, col1_extent, 0.5])
    pts[:, :3] /= divisors
    pts[:, 1] = (pts[:, 1] - 0.5) * 2
    return pts

def jitter(x, sigma=0.01, clip=0.05):
    """Add clipped Gaussian noise to every entry of a 2-D array.

    Args:
        x: 2-D array; it is not modified.
        sigma: Standard deviation of the noise before clipping.
        clip: Positive bound; each noise value is limited to [-clip, clip].

    Returns:
        A noisy copy of ``x``.
    """
    noisy = np.copy(x)
    rows, cols = noisy.shape
    assert (clip > 0)
    noise = sigma * np.random.randn(rows, cols)
    noisy += np.clip(noise, -1 * clip, clip)
    return noisy

def shift(batch_data, sigma=0.01):
    """Translate all rows by one random Gaussian offset.

    A single 3-vector is drawn from N(0, sigma^2) and added to the first
    three columns of every row; remaining columns receive a zero offset.

    Args:
        batch_data: 2-D array with at least three columns.
        sigma: Standard deviation of the offset.

    Returns:
        A new array holding the translated data.
    """
    offset = np.zeros(batch_data.shape[1])
    # Only the first three columns are moved.
    offset[:3] += np.random.normal(scale=sigma, size=3)
    return batch_data + offset

def random_scale(X, sigma=0.05):
    """Scale the first three columns by independent random factors.

    Each of the three factors is ``1 + e`` with ``e`` drawn from
    N(0, sigma^2); the same factors are applied to every row.

    Args:
        X: 2-D array with at least three columns; it is not modified.
        sigma: Standard deviation of the deviation from 1.

    Returns:
        A scaled copy of ``X``.
    """
    scaled = np.copy(X)
    factors = 1 + np.random.normal(scale=sigma, size=3)
    scaled[:, :3] *= factors
    return scaled

def pairwise_dist(x):
    """Compute pairwise squared Euclidean distances within each batch item.

    Uses the expansion ||a - b||^2 = ||a||^2 - 2 a.b + ||b||^2. No square
    root is taken, so the entries are squared distances.

    Args:
        x: Tensor of shape (B, N, D).

    Returns:
        Tensor of shape (B, N, N) whose entry [b, i, j] is the squared
        distance between rows i and j of batch item b.
    """
    sq_norms = x.pow(2).sum(dim=2, keepdim=True)
    cross_term = -2 * torch.bmm(x, x.transpose(2, 1))
    return sq_norms + cross_term + sq_norms.transpose(2, 1)

def train_svm(X_train, y_train, X_test, y_test):
    """Fit a class-balanced linear SVM and score it on the test split.

    Label arrays with more than one dimension are reduced to their first
    column before use.

    Args:
        X_train: Training features.
        y_train: Training labels, 1-D or with the label in column 0.
        X_test: Test features.
        y_test: Test labels, 1-D or with the label in column 0.

    Returns:
        The value reported by ``LinearSVC.score`` on the test split.
    """
    def first_column(labels):
        # Multi-column label arrays keep the class id in column 0.
        return labels[:, 0] if len(labels.shape) > 1 else labels

    classifier = LinearSVC(class_weight="balanced", max_iter=50000)
    train_labels = first_column(y_train)
    test_labels = first_column(y_test)
    classifier.fit(X_train, train_labels)
    return classifier.score(X_test, test_labels)

def load_h5(h5_filename):
    """Read the embedding datasets stored in an HDF5 file.

    Args:
        h5_filename: Path of the HDF5 file to read.

    Returns:
        Tuple ``(point_embs, point_max_embs, max_emb_index, labels, pid)``
        with the full contents of the datasets of those names.
    """
    # Imported locally: h5py is not among this module's top-level imports,
    # so the bare name would otherwise be undefined when this is called.
    import h5py

    # Open read-only and close the handle deterministically. `[:]` reads
    # each dataset completely, so the results stay valid after closing.
    with h5py.File(h5_filename, 'r') as f:
        point_embs = f['point_embs'][:]
        point_max_embs = f['point_max_embs'][:]
        max_emb_index = f['max_emb_index'][:]
        labels = f['labels'][:]
        pid = f['pid'][:]
    return (point_embs, point_max_embs, max_emb_index, labels, pid)

def tqdm_(loader):
    """Enumerate ``loader`` behind an ASCII tqdm progress bar.

    The bar is created with ``leave=False``.

    Args:
        loader: Any iterable.

    Returns:
        An ``enumerate`` iterator yielding ``(index, item)`` pairs.
    """
    progress_bar = tqdm(loader, ascii=True, leave=False)
    return enumerate(progress_bar)


def maybe_rotate(x, p=0.15):
    """With probability ``p``, rotate ``x`` about its own column means.

    When the rotation is applied, the column means are subtracted,
    ``random_rotate`` is applied, and the means are added back.

    Args:
        x: 2-D array with one point per row.
        p: Probability of applying the rotation.

    Returns:
        A rotated copy, or ``x`` itself (not a copy) when no rotation is drawn.
    """
    apply_rotation = np.random.uniform() < p
    if not apply_rotation:
        return x
    center = np.mean(x, axis=0)
    return random_rotate(x - center) + center

# Table of cell centres for an n x n x n grid tiling the cube [-1, 1]^3.
# Row (i * n + j) * n + k holds the centre for indices (i, j, k); along each
# axis the centres run from the high end (index 0) to the low end.
n = 3
d = 2 / n  # edge length of one cell
lookup = np.empty((n ** 3, 3))
for i in range(n):
    for j in range(n):
        for k in range(n):
            lookup[(i * n + j) * n + k] = (
                1 - d * (i + 0.5),
                1 - d * (j + 0.5),
                1 - d * (k + 0.5),
            )


def _cell_mask(points, ix, iy, iz, d):
    """Return a boolean row mask for the points inside grid cell (ix, iy, iz).

    Along each of the first three axes a cell is the half-open interval
    ``[-1 + idx * d, -1 + (idx + 1) * d)``, so neighbouring cells never
    overlap and a coordinate exactly on an interior boundary belongs to the
    upper cell instead of to no cell at all.
    """
    mask = np.ones(points.shape[0], dtype=bool)
    for axis, idx in enumerate((ix, iy, iz)):
        lower = -1 + idx * d
        upper = -1 + (idx + 1) * d
        coord = points[:, axis]
        mask &= (lower <= coord) & (coord < upper)
    return mask


def split(X, R, n=3):
    """Partition a point cloud into an n x n x n grid and perturb each cell.

    Performs all splits at once while keeping track of the cell each
    returned point came from, which is necessary for semantic/part
    segmentation.

    Steps:
      1. Points of ``X`` are assigned to cells of the cube [-1, 1]^3 using
         their first three columns (clipped just inside the cube so that
         points on the outer faces are kept).
      2. One random cell of ``R`` is selected. If it contains points, one
         randomly chosen occupied cell of ``X`` is replaced by the same
         number of points sampled with replacement from that cell of ``R``.
      3. Every non-empty cell is passed through ``shift`` and then
         ``maybe_rotate``, and the cells are concatenated in cell order.

    Args:
        X: 2-D array with at least three columns, one point per row.
        R: 2-D array laid out like ``X``, used as the source of replacement
            points.
        n: Number of cells along each axis.

    Returns:
        Tuple ``(points, y)``. ``points`` is the concatenation of the
        processed cells; ``y`` is a float array with one entry per returned
        row giving its cell index ``(ix * n + iy) * n + iz``, or ``n ** 3``
        for rows that were swapped in from ``R``.

    Raises:
        ValueError: If no point of ``X`` falls in any cell (nothing to
            concatenate or to choose from).
    """
    d = 2 / n
    X_clip = np.clip(X, -0.99999999, 0.99999999)

    cells = [
        X[_cell_mask(X_clip, ix, iy, iz, d)]
        for ix in range(n)
        for iy in range(n)
        for iz in range(n)
    ]

    # Build the labels from the actual cell sizes so that `y` always lines up
    # row-for-row with the concatenated output.
    sizes = [cell.shape[0] for cell in cells]
    y = np.repeat(np.arange(n ** 3, dtype=float), sizes)

    x_r, y_r, z_r = np.random.randint(n, size=(3,))
    R_clip = np.clip(R, -0.99999999, 0.99999999)
    R = R[_cell_mask(R_clip, x_r, y_r, z_r, d)]

    if len(R) > 0:
        # np.unique(y) lists only occupied cells, so the chosen cell is
        # guaranteed to hold at least one point.
        r = int(np.random.choice(np.unique(y)))
        ind = y == r
        count = int(np.sum(ind))
        cells[r] = R[np.random.choice(len(R), size=count, replace=True)]
        y[ind] = n ** 3
    return np.concatenate([maybe_rotate(shift(cell)) for cell in cells if len(cell) > 0]), y

def randomize(X, y, lookup, n=3):
    """Re-centre every labelled cell and offset it by a randomly assigned table row.

    A random permutation of the ``n ** 3`` cell indices is drawn. For each
    cell index that occurs in ``y``, the mean of that cell's first three
    columns plus the permuted ``lookup`` row is subtracted from those
    columns. Rows whose label is outside ``range(n ** 3)`` are left alone.

    Args:
        X: 2-D array with at least three columns; modified in place.
        y: Array of per-row cell labels.
        lookup: Array with one 3-vector per cell index.
        n: Number of cells along each axis.

    Returns:
        Tuple ``(X, y)``: the same ``X`` object after modification and the
        unchanged ``y``.
    """
    assignment = np.random.permutation(n ** 3)

    for cell_id, target in enumerate(assignment):
        rows = y == cell_id
        if not np.any(rows):
            continue
        cell_mean = np.mean(X[rows][:, :3], axis=0)
        X[rows, :3] -= cell_mean + lookup[target]
    return X, y
