import math
import numpy as np

import torch
import torch.nn as nn
from libs.models.decomp_model import AcousticRayTracing as model_1
# from libs.models.decomp_model_v2 import AcousticRayTracing as model_2
from libs.models.decomp_model_v3_all import AcousticRayTracing as model_3_all
from libs.models.decomp_model_v3 import AcousticRayTracing as model_3
from libs.models.decomp_model_v6_all import AcousticRayTracing as model_6
from libs.models.decomp_model_v6_new import AcousticRayTracing as model_6_new
class Renderer(nn.Module):
    """Run a wrapped model and flatten its output into a 2-D tensor.

    Args:
        model: Callable (normally an ``nn.Module``) that ``render`` invokes
            with six positional arguments.
        is_train: Stored as ``self.is_train``; nothing in this class reads it.
    """

    def __init__(self, model, is_train=True):
        super().__init__()
        self.model = model
        self.is_train = is_train

    def render(self, source_points, points, norm_source_points, norm_points, dirs, b_range):
        """Evaluate the model and reshape its prediction.

        The six inputs are forwarded to ``self.model`` in the given order.
        On the model's output, dim 3 is dropped with ``squeeze(3)``, which is
        a no-op if that dim is not size 1. Dims 1 and 2 are then swapped, and
        the result is reshaped to ``(2 * len(source_points), -1)``.

        Returns:
            The reshaped prediction tensor.
        """
        # Computed before the model call, mirroring the original order.
        # NOTE(review): the factor 2 is hard-coded; presumably the model emits
        # two outputs per source point (e.g. two channels) — confirm.
        n_rows = len(source_points) * 2
        raw = self.model(source_points, points, norm_source_points, norm_points, dirs, b_range)
        swapped = raw.squeeze(3).transpose(1, 2)
        return swapped.reshape(n_rows, -1)


def build_render(cfg, n_bins, patches):
    """Instantiate the model variant named by ``cfg`` and wrap it in a Renderer.

    ``cfg`` is printed to stdout. It is then compared with ``==``, in order,
    against the variant names below. The first match selects the model class,
    which is built as ``cls(n_bins, patches)``. A value that matches none of
    the names (including older variant names) silently falls back to
    ``model_1`` (``decomp_model.AcousticRayTracing``).

    Returns:
        A ``Renderer`` wrapping the constructed model, with its default
        ``is_train``.
    """
    print(cfg)
    # Ordered (name, class) pairs. A linear scan with ``cfg == name`` keeps the
    # exact comparison semantics of an if/elif chain (no hashing required).
    variants = (
        ('decomp_v6', model_6),
        ('decomp_v6_new', model_6_new),
        ('decomp_v3', model_3),
        ('decomp_v3_all', model_3_all),
    )
    model_cls = next((cls for name, cls in variants if cfg == name), model_1)
    return Renderer(model=model_cls(n_bins, patches))


if __name__ == "__main__":
    # Ad-hoc smoke test for Renderer.
    # NOTE(review): this block appears stale relative to the code above and
    # will not run as written:
    #   * Renderer.__init__ accepts only (model, is_train), so the
    #     time_bin_size / fs keywords below raise TypeError.
    #   * Renderer.render takes six positional arguments (source_points,
    #     points, norm_source_points, norm_points, dirs, b_range) but is
    #     called with two.
    #   * render() returns a single tensor; unpacking it into two names only
    #     works if its first dimension happens to be 2.
    # libs.models.BaseRayTracing is not used anywhere else in this file —
    # TODO confirm that module still exists.
    from libs.datasets.ReplicaDataset import ReplicaDataset
    from libs.models.BaseRayTracing import AcousticRayTracing

    dataset = ReplicaDataset(data_root='data/room_0/', dir=0, source_id=41)
    # First three fields of item 0; presumably (points, bounces, dirs) as the
    # names suggest — confirm against ReplicaDataset.__getitem__.
    points, bounces, dirs = dataset.__getitem__(0)[:3]
    model = AcousticRayTracing(n_bins=dataset.max_n_bins)
    render = Renderer(model, time_bin_size=dataset.time_bin_size, fs=dataset.fs)
    pred_ir, pred_np_ir = render.render(points, bounces)
