# pylint: disable=useless-return
'''Build script for PDI operator'''
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Tuple, Any

from scheduler.pdi.pdi import compile_pdi

from dmacompiler import DevGen, set_dev_gen
from buildscripts.common import OperatorsRegistry, OpBuild, ScheduleInputs
from utils.utils_common import ReadBins
set_dev_gen(DevGen.Aie4)


def register() -> None:
    '''Register the "pdi" operator and its build metadata with OperatorsRegistry'''
    # Kernel entry points mapped to the integer ids stored under "kernel_names".
    kernel_ids = {
        "run_conv_noqdq_a8w8": 0,
        "run_maxpool_int8x8": 1,
        "run_matadd_int8": 2,
    }
    # NOTE(review): PdiBuild.default_kernel_includes lists
    # "matadd/run_matadd_wrapper.cc" where this list has "binary/..." -
    # confirm which path is the intended one.
    include_files = [
        "super.hh",
        "conv/conv_noqdq_a8w8_wrapper.cc",
        "maxpool/maxpool_int8x8_wrapper.cc",
        "binary/run_matadd_wrapper.cc",
    ]
    pdi_config = {
        "testbench": ["pdi.cpp", "pdi.hpp"],
        "dataflow_script": "pdi.py",
        "build_script": "build_pdi.py",
        "kernel_names": kernel_ids,
        "kernel_includes": include_files,
    }
    OperatorsRegistry.add_operator("pdi", pdi_config)


def get_all_kernels(dtype: int) -> tuple[List[str], Dict[str, Any]]:
    '''Collect the kernel includes and kernel-name table for a set of operators.

    When ``dtype`` is 8 the fixed int8 operator list below is used. For any
    other value every operator known to ``OperatorsRegistry`` is used except
    the int8 list, the act x act MatMul variants and the broadcast Add/Mul
    variants.

    Args:
        dtype: Selector for the operator set; only the value 8 is special-cased.

    Returns:
        A ``(kernel_includes, kernel_names)`` pair. ``kernel_includes`` starts
        with "super.hh" and holds each operator include once, in first-seen
        order. ``kernel_names`` is the union of the operators' "kernel_names"
        dictionaries; on duplicate keys the operator visited last wins.
    '''
    int8_supported_ops = {
        "maxpool_noqdq_a8", "Add_qdq_EleWise_uint8xuint8xuint8",
        "conv_noqdq_a8w8", "gap", "pdi",
    }

    matmul_act_act = {
        "MatMul_qdq_actxact_uint16xuint16xuint16",
        "MatMul_qdq_actxact_int16xint16xint16",
        "MatMul_qdq_actxact_Transpose_uint16xuint16xuint16",
        "MatMul_qdq_actxact_Transpose_int16xint16xint16",
    }

    broadcast_ops = {
        "Add_qdq_BroadCast_uint8xuint8xuint8",
        "Add_qdq_BroadCast_int8xint8xint8",
        "Add_qdq_BroadCast_uint16xuint16xuint16",
        "Add_qdq_BroadCast_int16xint16xint16",
        "Mul_qdq_BroadCast_uint16xuint16xuint16",
        "Mul_qdq_BroadCast_int16xint16xint16",
    }

    # The registry is queried unconditionally, exactly as before, so any work
    # done inside get_operators() still happens on the int8 path.
    all_ops = set(OperatorsRegistry.get_operators().keys())
    if dtype == 8:
        supported_ops = int8_supported_ops
    else:
        supported_ops = all_ops - int8_supported_ops - matmul_act_act - broadcast_ops

    combined_kernel_includes: List[str] = ["super.hh"]
    combined_kernel_names: Dict[str, Any] = {}
    # Sets of strings iterate in a hash-dependent order that can differ between
    # interpreter runs; sorting keeps the include order and the winner of any
    # duplicate kernel-name key reproducible from build to build.
    for operator in sorted(supported_ops):
        op_cfg = OperatorsRegistry.get_operator(operator)
        combined_kernel_names.update(op_cfg["kernel_names"])
        for inc in op_cfg["kernel_includes"]:
            if inc not in combined_kernel_includes:
                combined_kernel_includes.append(inc)

    return combined_kernel_includes, combined_kernel_names


@dataclass(frozen=True)
class PdiOp:
    """Immutable description of a PDI-only build request"""
    # Value read from the layer's "input" entry by PdiBuild._parse_from_dict;
    # PdiBuild.tiler forwards it unchanged as the schedule shape.
    inputs: int


class PdiBuild(OpBuild):
    """PDI build interface for build.py.

    Implements the OpBuild hooks for an operator whose only job is to compile
    a PDI: parsing yields a PdiOp, tiling wraps its ``inputs`` value in a
    ScheduleInputs, and both schedule hooks call ``compile_pdi``.
    """

    def default_kernel_names(self) -> Dict[str, int]:
        """Return the default kernel-name to kernel-id table."""
        return {
            "run_conv_noqdq_a8w8": 0,
            "run_maxpool_int8x8": 1,
            "run_matadd_int8": 2,
        }

    def default_kernel_includes(self) -> List[str]:
        """Return the default list of kernel include files."""
        # NOTE(review): register() lists "binary/run_matadd_wrapper.cc" where
        # this list has "matadd/..." - confirm which path is the intended one.
        return [
            "super.hh",
            "conv/conv_noqdq_a8w8_wrapper.cc",
            "maxpool/maxpool_int8x8_wrapper.cc",
            "matadd/run_matadd_wrapper.cc",
        ]

    def op_type(self):
        """Return the dataclass type that describes this operator (PdiOp)."""
        return PdiOp

    def _parse_from_dict(self, data: dict, shim_prm_offset: int, shim_wgt_offset: int,
                         read_bins: ReadBins, read_model_data: bool, model_data_path: str) -> PdiOp:
        """Build a PdiOp from a layer dictionary.

        Only the "input" entry of ``data`` is read (defaulting to 0 when it is
        absent); every other argument is accepted for interface compatibility
        and ignored.
        """
        _ = (shim_prm_offset, shim_wgt_offset, read_bins, read_model_data, model_data_path)
        return PdiOp(inputs=data.get("input", 0))

    def shape(self, op_class: PdiOp) -> Any:
        """Return None: no shape object is derived for a PDI build."""
        _ = op_class
        return None

    def tiler(self, dims_shape: Any, op_class: PdiOp) -> ScheduleInputs:
        """Wrap ``op_class.inputs`` in a ScheduleInputs; ``dims_shape`` is ignored."""
        _ = dims_shape
        return ScheduleInputs(
            shape=op_class.inputs,
            mapping=None,
            dataflow_type=1,
            L2_alloc=None,
            L3_alloc=None,
        )

    def L2_schedule(self, schedule_input: ScheduleInputs):
        """Delegate to L3_schedule; both schedule hooks do the same work."""
        return self.L3_schedule(schedule_input)

    def L3_schedule(self, schedule_input: ScheduleInputs) -> Tuple[int, int]:
        """Compile the PDI for the selected kernels.

        The selector is ``schedule_input.shape`` when it is a plain int,
        otherwise its first element. A selector of 0 compiles the kernel lists
        already carried by ``schedule_input``; any other value is handed to
        ``get_all_kernels`` to gather the kernels.

        Returns:
            The (shim_prm_offset, shim_wgt_offset) pair for the next layer,
            which is always (0, 0).
        """
        shape = schedule_input.shape
        # tiler() stores PdiOp.inputs as the shape and that field is declared
        # as an int (default 0), which cannot be indexed; accept both a scalar
        # and an indexable shape so the scalar case no longer raises TypeError.
        dtype = shape if isinstance(shape, int) else shape[0]
        if dtype == 0:
            kernel_includes, kernel_names = schedule_input.kernel_includes, schedule_input.kernel_names
        else:
            kernel_includes, kernel_names = get_all_kernels(dtype)
        compile_pdi(kernel_names, list(kernel_includes), schedule_input.backend, schedule_input.layer_file_name)
        shim_prm_offset_next_layer, shim_wgt_offset_next_layer = 0, 0
        return shim_prm_offset_next_layer, shim_wgt_offset_next_layer

    def preproc(self, schedule_input: ScheduleInputs, op_class: PdiOp):
        """Return None: a PDI build has no preprocessing step."""
        _ = (schedule_input, op_class)
        return None


def get_op():
    """Entry point used by build.py: return a new PdiBuild instance"""
    builder = PdiBuild()
    return builder
