#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

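"""Train an image classifier with FLSim in a simulated federated learning environment.

The model is a torchvision SqueezeNet 1.1 trained on one of FEMNIST, Flowers102,
CUB-200-2011, Stanford Cars, EuroSAT or CIFAR-10, selected by the config's dataset name.

This example walks through the following key components of FLSim:
1. Data loading
2. Model construction
3. Trainer construction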

    Typical usage example:
    python3 cifar10_example.py --config-file configs/cifar10_config.json
"""
import flsim.configs  # noqa
import hydra
import numpy as np
import torch
from flsim.data.data_sharder import SequentialSharder, RandomSharder
from flsim.channels.base_channel import FLChannelConfig
from flsim.clients.base_client import ClientConfig
from flsim.servers.sync_servers import SyncServerConfig
from flsim.active_user_selectors.simple_user_selector import (
    SequentialActiveUserSelectorConfig, UniformlyRandomActiveUserSelectorConfig
    
)
from flsim.optimizers.local_optimizers import LocalOptimizerSGDConfig, LocalOptimizerAdamConfig
from flsim.interfaces.metrics_reporter import Channel
from flsim.utils.config_utils import fl_config_from_json
from utils import (
    DataLoader,
    DataProvider,
    FLModel,
    MetricsReporter,
    SimpleConvNet,
)
from dirichlet_sharder import DirichletSharder
from hydra.utils import instantiate
from omegaconf import DictConfig, OmegaConf
from torchvision import transforms
from torchvision.datasets.cifar import CIFAR10
from torchvision.datasets.flowers102 import Flowers102
from torchvision.datasets.stanford_cars import StanfordCars
from torchvision.datasets.eurosat import EuroSAT
from femnist_dataset import FEMNIST
from cub2011_dataset import Cub2011
from config import json_config
import wandb
from torchvision.models import squeezenet1_1
from torchvision.models.squeezenet import SqueezeNet1_1_Weights
from sync_trainer import SyncTrainer, SyncTrainerConfig
import torch.nn.functional as F
import os
from utils import validata_dataset_params
import ssl

#IMAGE_SIZE = 32

# ImageNet per-channel statistics, used for every dataset resized to 224x224.
_IMAGENET_MEAN = [0.485, 0.456, 0.406]
_IMAGENET_STD = [0.229, 0.224, 0.225]

# CIFAR-10 per-channel statistics.
_CIFAR10_MEAN = (0.491399689874, 0.482158419622, 0.446530924224)
_CIFAR10_STD = (0.247032237587, 0.243485133253, 0.261587846975)


def _train_test_transforms(size, mean, std):
    """Return ``(train_transform, test_transform)``: resize, [flip,] to-tensor, normalize.

    The train pipeline adds a random horizontal flip; otherwise both are identical.
    ``size`` is passed straight to ``transforms.Resize``.
    """
    train_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    test_transform = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    return train_transform, test_transform


def _load_femnist():
    """Load FEMNIST, normalizing with statistics computed from its raw training data."""
    data_dir = '../data/femnist/'
    # First load (no transform) is only used to compute normalization statistics.
    raw_train = FEMNIST(root=data_dir, train=True, download=True)
    # NOTE(review): mean/std come from the raw ``train_data`` tensor; if it holds
    # 0-255 pixel values they do not match the [0, 1] range produced by ToTensor().
    # Also, train images are cropped to 24x24 while test images are not resized.
    # Confirm both against femnist_dataset.FEMNIST.
    mean = raw_train.train_data.float().mean()
    std = raw_train.train_data.float().std()

    train_transform = transforms.Compose([
        transforms.RandomCrop(24, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)])

    train_dataset = FEMNIST(data_dir, train=True, download=False, transform=train_transform)
    test_dataset = FEMNIST(data_dir, train=False, download=False, transform=test_transform)
    return train_dataset, test_dataset


def _load_flowers():
    """Load Oxford Flowers102 (train/test splits) at 224x224 with ImageNet normalization."""
    train_transform, test_transform = _train_test_transforms(
        (224, 224), _IMAGENET_MEAN, _IMAGENET_STD)
    train_dataset = Flowers102('../data/flowers/train', split='train',
                               transform=train_transform, download=True)
    test_dataset = Flowers102('../data/flowers/test', split='test',
                              transform=test_transform, download=True)
    return train_dataset, test_dataset


def _load_cub():
    """Load CUB-200-2011 at 224x224 with ImageNet normalization."""
    train_transform, test_transform = _train_test_transforms(
        (224, 224), _IMAGENET_MEAN, _IMAGENET_STD)
    data_dir = '../data/cub2011'
    train_dataset = Cub2011(data_dir, train=True, transform=train_transform, download=True)
    test_dataset = Cub2011(data_dir, train=False, transform=test_transform, download=True)
    return train_dataset, test_dataset


def _load_cars():
    """Load Stanford Cars (train/test splits) at 224x224 with ImageNet normalization."""
    train_transform, test_transform = _train_test_transforms(
        (224, 224), _IMAGENET_MEAN, _IMAGENET_STD)
    data_dir = '../data/stanford_cars'
    train_dataset = StanfordCars(data_dir, split="train", transform=train_transform, download=True)
    test_dataset = StanfordCars(data_dir, split="test", transform=test_transform, download=True)
    return train_dataset, test_dataset


def _load_eurosat():
    """Load EuroSAT and draw a random, class-balanced 5000/1000 train/test split."""
    train_transform, test_transform = _train_test_transforms(
        (224, 224), _IMAGENET_MEAN, _IMAGENET_STD)
    data_dir = '../data/eurosat'
    # The EuroSAT download runs with TLS certificate verification disabled, as
    # before. The process-wide default HTTPS context is restored afterwards so
    # later stdlib HTTPS requests in this process are verified again.
    previous_https_context = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        train_dataset = EuroSAT(data_dir, transform=train_transform, download=True)
        test_dataset = EuroSAT(data_dir, transform=test_transform, download=True)
    finally:
        ssl._create_default_https_context = previous_https_context

    # Both objects above index the same images with different transforms. For
    # each of the 10 labels, sample 600 distinct images (unseeded np.random) and
    # send the first 500 to train and the remaining 100 to test.
    targets = np.asarray(train_dataset.targets)
    train_idxs, test_idxs = [], []
    for label in range(10):
        selected_idxs = np.random.choice(np.where(targets == label)[0], 600, False)
        train_idxs.extend(selected_idxs[:500])
        test_idxs.extend(selected_idxs[500:])
    return (torch.utils.data.Subset(train_dataset, train_idxs),
            torch.utils.data.Subset(test_dataset, test_idxs))


def _load_cifar():
    """Load CIFAR-10 upscaled with Resize(224) and CIFAR-10 normalization."""
    train_transform, test_transform = _train_test_transforms(
        224, _CIFAR10_MEAN, _CIFAR10_STD)
    train_dataset = CIFAR10(
        root="../data/cifar10", train=True, download=True, transform=train_transform)
    test_dataset = CIFAR10(
        root="../data/cifar10", train=False, download=True, transform=test_transform)
    return train_dataset, test_dataset


# Dataset name (as used in cfg.dataset.name) -> loader returning (train, test).
_DATASET_LOADERS = {
    'femnist': _load_femnist,
    'flowers': _load_flowers,
    'cub': _load_cub,
    'cars': _load_cars,
    'eurosat': _load_eurosat,
    'cifar': _load_cifar,
}


def build_data_provider(local_batch_size, num_clients, dataset, num_classes, alpha, drop_last: bool = False):
    """Load ``dataset``, shard it across clients and wrap it in an FLSim data provider.

    Args:
        local_batch_size: Batch size used by each client's local data.
        num_clients: Number of shards (one per simulated client) to create.
        dataset: One of 'femnist', 'flowers', 'cub', 'cars', 'eurosat', 'cifar'.
        num_classes: Number of labels; forwarded to ``DirichletSharder``.
        alpha: Dirichlet parameter forwarded to ``DirichletSharder``.
        drop_last: Forwarded to the project ``DataLoader``.

    Returns:
        ``(data_provider, fl_data_loader)``.

    Raises:
        ValueError: if ``dataset`` is not a supported name. ``ValueError`` is a
            subclass of ``Exception``, so handlers for the old type still match.
    """
    load_datasets = _DATASET_LOADERS.get(dataset)
    if load_datasets is None:
        raise ValueError(f'{dataset} is not a valid dataset')
    train_dataset, test_dataset = load_datasets()

    # One shard per simulated client.
    sharder = DirichletSharder(num_shards=num_clients, alpha=alpha, num_classes=num_classes)
    # The test split is passed twice. Presumably the arguments are
    # (train, eval, test), so the test set doubles as the eval set.
    # Confirm against utils.DataLoader.
    fl_data_loader = DataLoader(
        train_dataset, test_dataset, test_dataset, sharder, local_batch_size, drop_last
    )

    data_provider = DataProvider(fl_data_loader)
    print(f"Clients in total: {data_provider.num_train_users()}")

    return data_provider, fl_data_loader


def _ncm_init_classifier(global_model, train_dataset, num_classes, device):
    """Initialise the linear head's weight with L2-normalised per-class mean features.

    Every training example is passed through ``global_model.model.features`` and
    global-average-pooled. The pooled features are averaged per class, and the
    row-normalised class means are copied into the weight of the last module of
    ``global_model.model.classifier`` (the ``Linear`` head built in ``main``).
    The head's bias is left untouched.
    """
    # The Linear layer is the last element of the classifier Sequential. Index 0
    # is AdaptiveAvgPool2d, which has no weight.
    head = global_model.model.classifier[-1]
    trainloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=128, shuffle=True, num_workers=2)

    class_sums = torch.zeros((num_classes, head.in_features), device=device)
    class_counts = torch.zeros(num_classes, dtype=torch.long, device=device)
    with torch.no_grad():  # only forward passes are needed; no autograd graph
        for data, target in trainloader:
            data, target = data.to(device), target.to(device)
            features = global_model.model.features(data)
            # flatten(1) keeps the batch dimension even for a final batch of
            # size 1, where a plain .squeeze() would also drop it.
            features = F.adaptive_avg_pool2d(features, 1).flatten(1)
            class_sums.index_add_(0, target, features)
            class_counts += torch.bincount(target, minlength=num_classes)

    # clamp(min=1) avoids dividing by zero for classes absent from the data
    # (their rows stay zero).
    class_means = class_sums / class_counts.clamp(min=1).unsqueeze(1)
    print(global_model.model.classifier)

    head.weight.data = F.normalize(class_means)


def main(cfg, use_cuda_if_available: bool = True) -> None:
    """Train SqueezeNet 1.1 with FLSim in two phases, then evaluate the result.

    Phase 1 uses the hyper-parameters in ``cfg.trainer``: clients use Adam and
    are picked uniformly at random. If ``cfg.trainer.last_layer`` is set, the
    backbone is frozen and only the new classifier head is trained. If
    ``cfg.trainer.ncm_init`` is set, the head's weight is first initialised from
    class-mean features.

    Phase 2 unfreezes all parameters (when they were frozen) and fine-tunes
    with fixed settings: epochs=500, client SGD lr=0.001, 30 users per round,
    sequential user selection.

    Args:
        cfg: Config object read via ``cfg.trainer.*`` and ``cfg.dataset.*``.
        use_cuda_if_available: Train on ``cuda:0`` when CUDA is available.
    """
    cuda_enabled = torch.cuda.is_available() and use_cuda_if_available
    device = torch.device("cuda:0" if cuda_enabled else "cpu")

    if cfg.trainer.pretrained:
        model = squeezenet1_1(weights=SqueezeNet1_1_Weights.DEFAULT)
    else:
        model = squeezenet1_1()

    if cfg.trainer.last_layer:
        # Freeze every backbone parameter. The classifier head is replaced
        # after this loop, so its fresh parameters remain trainable.
        for param in model.parameters():
            param.requires_grad = False
        model.eval()

    # Replace SqueezeNet's classifier with a GAP -> flatten -> linear head.
    # The 512 input features match the channel count of squeezenet1_1's
    # feature extractor.
    model.classifier = torch.nn.Sequential(
        torch.nn.AdaptiveAvgPool2d((1, 1)),
        torch.nn.Flatten(),
        torch.nn.Linear(512, cfg.dataset.num_classes),
    )
    print(model)

    # pyre-fixme[6]: Expected `Optional[str]` for 2nd param but got `device`.
    global_model = FLModel(model, device)

    if cuda_enabled:
        global_model.fl_cuda()

    data_provider, fl_data_loader = build_data_provider(
        local_batch_size=cfg.trainer.client.local_bs,
        num_clients=cfg.trainer.total_users,
        dataset=cfg.dataset.name,
        num_classes=cfg.dataset.num_classes,
        alpha=cfg.dataset.alpha,
        drop_last=False,
    )

    # Phase 1: configured training (linear probe when last_layer is set).
    trainer = SyncTrainer(
        model=global_model,
        cuda_enabled=cuda_enabled,
        **OmegaConf.structured(SyncTrainerConfig(
            epochs=cfg.trainer.epochs,
            do_eval=cfg.trainer.do_eval,
            always_keep_trained_model=False,
            train_metrics_reported_per_epoch=cfg.trainer.train_metrics_reported_per_epoch,
            eval_epoch_frequency=1,
            report_train_metrics=cfg.trainer.report_train_metrics,
            report_train_metrics_after_aggregation=cfg.trainer.report_train_metrics_after_aggregation,
            client=ClientConfig(
                epochs=cfg.trainer.client.epochs,
                optimizer=LocalOptimizerAdamConfig(
                    lr=cfg.trainer.client.optimizer.lr,
                ),
                lr_scheduler=cfg.trainer.client.lr_scheduler,
                shuffle_batch_order=False,
            ),
            channel=FLChannelConfig(),
            server=SyncServerConfig(
                active_user_selector=UniformlyRandomActiveUserSelectorConfig()
            ),
            users_per_round=cfg.trainer.users_per_round,
            dropout_rate=cfg.trainer.dropout_rate,
        )),
    )

    if cfg.trainer.ncm_init:
        _ncm_init_classifier(
            global_model,
            fl_data_loader.train_dataset,
            cfg.dataset.num_classes,
            device,
        )

    print(f"\nClients in total: {data_provider.num_train_users()}")

    metrics_reporter = MetricsReporter([Channel.TENSORBOARD, Channel.STDOUT])

    print("Before training")
    final_model, _ = trainer.train(
        data_provider=data_provider,
        metrics_reporter=metrics_reporter,
        num_total_users=data_provider.num_train_users(),
        distributed_world_size=1,
    )

    # Phase 2: full fine-tuning with fixed hyper-parameters.
    # NOTE(review): cuda_enabled is hard-coded to False here even though
    # final_model may live on the GPU. Confirm this is intended.
    trainer_ft = SyncTrainer(
        model=final_model,
        cuda_enabled=False,
        **OmegaConf.structured(SyncTrainerConfig(
            epochs=500,
            do_eval=True,
            always_keep_trained_model=False,
            train_metrics_reported_per_epoch=100,
            eval_epoch_frequency=1,
            report_train_metrics=cfg.trainer.report_train_metrics,
            report_train_metrics_after_aggregation=cfg.trainer.report_train_metrics_after_aggregation,
            client=ClientConfig(
                epochs=1,
                optimizer=LocalOptimizerSGDConfig(
                    lr=0.001,
                ),
                lr_scheduler=cfg.trainer.client.lr_scheduler,
                shuffle_batch_order=False,
            ),
            channel=FLChannelConfig(),
            server=SyncServerConfig(
                active_user_selector=SequentialActiveUserSelectorConfig()
            ),
            users_per_round=30,
            dropout_rate=cfg.trainer.dropout_rate,
        )),
    )

    if cfg.trainer.last_layer:
        # Unfreeze the backbone that was frozen for phase 1.
        for param in final_model.model.parameters():
            param.requires_grad = True

    final_model, _ = trainer_ft.train(
        data_provider=data_provider,
        metrics_reporter=metrics_reporter,
        num_total_users=data_provider.num_train_users(),
        distributed_world_size=1,
    )

    # Evaluate with the trainer that produced the fine-tuned model.
    trainer_ft.test(
        data_provider=data_provider,
        metrics_reporter=MetricsReporter([Channel.STDOUT]),
    )

def run(cfg: DictConfig) -> None:
    """Print the resolved configuration as YAML, then hand it to ``main``."""
    rendered_config = OmegaConf.to_yaml(cfg)
    print(rendered_config)
    main(cfg)


if __name__ == "__main__":
    # Start a Weights & Biases run and record the raw JSON config on it.
    wandb.init(project="pretrained-fl-flsim", entity="bigbernnn")
    wandb.config.update(json_config)

    # Build the FLSim config, check its dataset parameters, then launch training.
    fl_cfg = fl_config_from_json(json_config)
    validata_dataset_params(fl_cfg)
    run(fl_cfg)