from rearrange import baseline_models
import torch

# encoder_path = "pretrained_model_ckpts/darla/cnn/epoch100/beta1/betavaemodel20.pt"
# weights = torch.load(encoder_path, map_location=torch.device('cpu'))

_betaVAE = baseline_models.VaeResnetEncoder(latent_size=512)

# The checkpoint is a full pickled model object rather than a plain
# state_dict: attributes are deleted from it below.
# NOTE(review): torch.load unpickles arbitrary Python objects, so only load
# trusted checkpoints. On PyTorch >= 2.6 the default `weights_only=True`
# refuses to unpickle a full model object; pass `weights_only=False` there.
encoder_path = "pretrained_model_ckpts/darla/resnet/epoch50/beta1/betavaemodel_final.pt"
weights = torch.load(encoder_path, map_location=torch.device('cpu'))

# Strip submodules from the loaded model (presumably the decoder and the
# variance head, which the encoder-only model lacks -- TODO confirm).
del weights.decoder
del weights.log_var

# Names of the encoder's parameters that are trainable.
betaVAE_namelist = [
    name for name, param in _betaVAE.named_parameters() if param.requires_grad
]
print(betaVAE_namelist)

# Debug stop. The builtin `exit()` is a site-module helper intended for the
# interactive shell; raising SystemExit directly behaves the same way in
# every interpreter mode. Remove this line to run the statements below.
raise SystemExit

state_dict = _betaVAE.state_dict()
for name, _ in _betaVAE.named_parameters():
    print(name)
betaVAE_paramlist = list(state_dict.keys())




