from spaghettini import quick_register
from munch import Munch

from torch.optim import Adam, SGD, RMSprop, LBFGS
from torch.nn import CrossEntropyLoss, MSELoss, BCELoss
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor, Normalize, Compose
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.loggers import WandbLogger
from torchvision.datasets import MNIST
from torchvision.datasets import DatasetFolder
from torch.nn.functional import mse_loss, l1_loss, binary_cross_entropy_with_logits, cross_entropy
from torch.nn.functional import relu, softplus
from torch.nn.utils import spectral_norm
from torch.nn import Softplus, ReLU, ELU, Sigmoid, LeakyReLU
from torch import sigmoid, tanh
from torch.nn.functional import elu, leaky_relu
from torch.optim.lr_scheduler import StepLR, LambdaLR, MultiStepLR
from torch.nn import Linear
from transformers import get_constant_schedule_with_warmup

# Register basic, framework-agnostic containers.
quick_register(Munch)

# ____PyTorch Related____ #
# Data related: datasets, loaders and image transforms, registered in a
# fixed order so the registry sees them exactly as listed.
for _component in (MNIST, DataLoader, DatasetFolder, ToTensor, Normalize, Compose):
    quick_register(_component)
# Drop the loop variable so it does not linger in the module namespace.
del _component

# Optimizer related: optimizers first, then learning-rate schedulers
# (including the HuggingFace warmup schedule factory).
_optimizers = (Adam, SGD, RMSprop, LBFGS)
_lr_schedulers = (StepLR, MultiStepLR, LambdaLR, get_constant_schedule_with_warmup)
for _component in _optimizers + _lr_schedulers:
    quick_register(_component)
# Remove the temporaries so they do not linger in the module namespace.
del _component, _optimizers, _lr_schedulers

# Losses, in both module form (e.g. ``MSELoss``) and functional form
# (e.g. ``mse_loss``); the registration order matches the tuple below.
for _loss in (
    CrossEntropyLoss,
    mse_loss,
    cross_entropy,
    MSELoss,
    l1_loss,
    BCELoss,
    binary_cross_entropy_with_logits,
):
    quick_register(_loss)
# Drop the loop variable so it does not linger in the module namespace.
del _loss

# Activations. Naming convention: lowercase names are functional forms,
# capitalized names are ``nn.Module`` classes.
for _activation in (
    relu,
    softplus,
    sigmoid,
    Sigmoid,
    tanh,
    elu,
    Softplus,
    ReLU,
    ELU,
    LeakyReLU,
    leaky_relu,
):
    quick_register(_activation)
# Drop the loop variable so it does not linger in the module namespace.
del _activation

# Model building blocks: layers and layer wrappers.
for _building_block in (Linear, spectral_norm):
    quick_register(_building_block)
del _building_block

# ____PyTorch Lightning Related____ #
# Training loop, checkpointing callback and experiment logger.
for _component in (Trainer, ModelCheckpoint, WandbLogger):
    quick_register(_component)
# Drop the loop variable so it does not linger in the module namespace.
del _component
