import torch
from torch.utils.data import DataLoader
import numpy as np
import scipy.sparse
import scipy.io

from . import file_utils


class BasicDatasetHandler:
    """Load a bag-of-words dataset from a directory and expose it as tensors.

    After construction the instance carries:

    * ``train_bow`` / ``test_bow`` -- dense float32 NumPy arrays read from
      ``train_bow.npz`` / ``test_bow.npz``.
    * ``pretrained_WE`` -- dense float32 array read from
      ``word_embeddings.npz``.
    * ``train_texts`` / ``test_texts`` / ``vocab`` -- whatever
      ``file_utils.read_text`` returns for the matching ``.txt`` files
      (presumably a list of lines -- confirm against ``file_utils``).
    * ``train_labels`` / ``test_labels`` -- integer arrays, set only when
      ``read_labels`` is true.
    * ``vocab_size`` -- ``len(self.vocab)``.
    * ``train_data`` / ``test_data`` -- the bag-of-words arrays as torch
      tensors placed on ``device``.
    * ``train_dataloader`` / ``test_dataloader`` -- ``DataLoader`` objects
      over those tensors; only the training loader shuffles.
    """

    def __init__(self, dataset_dir, batch_size=200, read_labels=False, device='cpu'):
        """Read the dataset from ``dataset_dir`` and build the data loaders.

        Args:
            dataset_dir: Directory containing the dataset files.
            batch_size: Batch size used by both data loaders.
            read_labels: When true, the label files are read as well.
            device: Target passed to ``Tensor.to`` for both data tensors.
        """
        self.load_data(dataset_dir, read_labels)
        self.vocab_size = len(self.vocab)

        # Wrap the dense arrays as tensors and move them to the target device.
        train_tensor = torch.from_numpy(self.train_bow).to(device)
        test_tensor = torch.from_numpy(self.test_bow).to(device)
        self.train_data = train_tensor
        self.test_data = test_tensor

        # Shuffle training batches only; test batches keep file order.
        self.train_dataloader = DataLoader(train_tensor, batch_size=batch_size, shuffle=True)
        self.test_dataloader = DataLoader(test_tensor, batch_size=batch_size, shuffle=False)

    def load_data(self, path, read_labels):
        """Read every dataset file under ``path`` into instance attributes.

        Args:
            path: Directory containing the ``.npz`` and ``.txt`` files.
            read_labels: When true, ``train_labels.txt`` and
                ``test_labels.txt`` are parsed into integer arrays; otherwise
                the label attributes are not created.
        """

        def dense_float32(npz_file):
            # Stored as a SciPy sparse matrix; densify and cast to float32.
            matrix = scipy.sparse.load_npz(npz_file)
            return matrix.toarray().astype('float32')

        self.train_bow = dense_float32(f'{path}/train_bow.npz')
        self.test_bow = dense_float32(f'{path}/test_bow.npz')
        self.pretrained_WE = dense_float32(f'{path}/word_embeddings.npz')

        read_text = file_utils.read_text
        self.train_texts = read_text(f'{path}/train_texts.txt')
        self.test_texts = read_text(f'{path}/test_texts.txt')

        if read_labels:
            self.train_labels = np.loadtxt(f'{path}/train_labels.txt', dtype=int)
            self.test_labels = np.loadtxt(f'{path}/test_labels.txt', dtype=int)

        self.vocab = read_text(f'{path}/vocab.txt')
