'''Build script for the MaxPool operator (registered as ``maxpool_noqdq_a8``).'''
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List

from buildscripts.common import (
    BaseOp,
    OpBuild,
    OperatorsRegistry,
    ScheduleInputs,
    WGTFormatting,
    get_kernel_id,
    normalize_shape,
    save_cfg_json,
    bytes_to_bits,
)
from dmacompiler import (
    BackEnd,
    DevGen,
    set_dev_gen,
    DmaPaddingMap,
)
from scheduler.common import L3Alloc
from scheduler.maxpool.maxpool import (
    MaxpoolShape,
    compile_maxpool_dataflow,
    generate_maxpool_mappings,
)
from utils.utils_common import L2Alloc, log, ReadBins

# Module-level side effect: runs once when this script is imported and selects
# DevGen.Aie4 via dmacompiler.set_dev_gen (presumably process-wide device
# generation state used by the DMA compiler -- confirm before reordering imports).
set_dev_gen(DevGen.Aie4)


def register() -> None:
    """Register the ``maxpool_noqdq_a8`` operator with OperatorsRegistry.

    The registry entry lists the C++ testbench sources, the dataflow and
    build scripts, the int8 maxpool kernel name mapped to its id (resolved
    through get_kernel_id), and the kernel sources to include.
    """
    kernel = "run_maxpool_int8x8"
    spec = {
        "testbench": ["maxpool.cpp", "maxpool.hpp"],
        "dataflow_script": "maxpool.py",
        "build_script": "build_maxpool.py",
        "kernel_names": {kernel: get_kernel_id(kernel)},
        "kernel_includes": ["super.hh", "maxpool/maxpool_int8x8_wrapper.cc"],
    }
    OperatorsRegistry.add_operator("maxpool_noqdq_a8", spec)


@dataclass(frozen=True)
class MaxPoolOp(BaseOp):
    """Immutable description of a single MaxPool layer.

    Adds the pooling-window parameters and the mapping selector on top of
    BaseOp. Tensor dims (Ni..Co), L2/L3 allocations, dtypes, padding info,
    etc. are passed as keyword arguments but not declared here, so they are
    presumably BaseOp fields. Instances are built by
    MaxPoolBuild._parse_from_dict.
    """
    # Tiling/mapping selection: index into the candidate list returned by
    # generate_maxpool_mappings (0 picks the first candidate).
    MappingRank: int

    # Pooling-window parameters, read from the op dict's "kernel", "stride"
    # and "pad" entries, in (Y, X) order.
    Ky: int  # window size along Y
    Kx: int  # window size along X
    Sy: int  # stride along Y
    Sx: int  # stride along X
    Py: int  # padding along Y (whether per side or total -- TODO confirm)
    Px: int  # padding along X (whether per side or total -- TODO confirm)

    # Testbench preprocessing inputs.
    # sign_A: activation signedness flag, written to the cfg as "SIGN".
    sign_A: int
    # dtype: element-type code; dtype == 1 sets MAXPOOL_NOQDQ_A8 in the cfg.
    # Presumably bytes per element (the same value feeds bytes_to_bits as
    # dtype_A) -- TODO confirm.
    dtype: int = 1


class MaxPoolBuild(OpBuild):
    """MaxPool build interface used by build.py.

    Turns an operator description dict into a MaxPoolOp, selects a tiling
    mapping, compiles the dataflow schedule and writes the testbench
    configuration file ``maxpool_cfg.json``.
    """

    def default_kernel_names(self) -> Dict[str, int]:
        """Return the kernel-name -> kernel-id map for the int8 maxpool kernel."""
        return {"run_maxpool_int8x8": get_kernel_id("run_maxpool_int8x8")}

    def default_kernel_includes(self) -> List[str]:
        """Return the kernel source files that must be included for this op."""
        return ["super.hh", "maxpool/maxpool_int8x8_wrapper.cc"]

    def op_type(self) -> type[MaxPoolOp]:
        """Return the op dataclass this builder produces."""
        return MaxPoolOp

    def _parse_from_dict(
        self, data: dict, shim_prm_offset: int, shim_wgt_offset: int,
        read_bins: ReadBins, read_model_data: bool, model_data_path: str
    ) -> MaxPoolOp:
        """Build a MaxPoolOp from an operator description dict.

        Args:
            data: Operator description. Required keys: "input_addr",
                "output_addr", "wgt_addr", "prm_addr" and "L3" (with "ifm"
                and "ofm"). "input" and "output" shapes go through
                normalize_shape. Optional keys and their defaults:
                "kernel" [1, 1], "stride" [1, 1], "pad" [0, 0], "dtype" 1,
                "sign_A" 0, "MappingRank" 0, "load_input_from_ddr" True,
                "store_output_to_ddr" True, "enable_L2_fusion" False, and
                "name" (falls back to "op") for the node name.
            shim_prm_offset: Shim offset of the parameter buffer, paired with
                index 3 in L3Alloc.
            shim_wgt_offset: Shim offset of the weight buffer, paired with
                index 2 in L3Alloc.
            read_bins: Passed through unchanged to the op.
            read_model_data: Whether model data is read (also used when
                resolving DMA padding).
            model_data_path: Location of the model data for WGTFormatting.

        Returns:
            The populated MaxPoolOp (dataflow_type 0, i.e. L2).

        Raises:
            KeyError: If a required key is missing from ``data``.
        """
        Ni, Yi, Xi, Ci = normalize_shape(data.get("input"))
        No, Yo, Xo, Co = normalize_shape(data.get("output"))
        Ky, Kx = data.get("kernel", [1, 1])
        Sy, Sx = data.get("stride", [1, 1])
        Py, Px = data.get("pad", [0, 0])
        # A single element-type code covers activations, "weights" and
        # outputs, so parse it once.
        dtype = int(data.get("dtype", 1))

        ifm_tile, ifm_off = self._get_tile_offset(data["input_addr"])
        ofm_tile, ofm_off = self._get_tile_offset(data["output_addr"])

        L2_alloc = L2Alloc(
            (ifm_tile, ifm_off),
            (ofm_tile, ofm_off),
            self._get_wgt_addr(data["wgt_addr"]),
            self._get_prm_addr(data["prm_addr"]),
            data.get("load_input_from_ddr", True),
            data.get("store_output_to_ddr", True),
            data.get("enable_L2_fusion", False),
        )

        L3_alloc = L3Alloc(
            data["L3"]["ifm"],
            data["L3"]["ofm"],
            [2, shim_wgt_offset],
            [3, shim_prm_offset],
        )

        wgt_fmt = WGTFormatting(
            node_name=data.get("name") or data.get("op"),
            model_data_path=model_data_path,
            read_model_data=read_model_data
        )

        pad_value, is_dma_pad = self._get_dma_pad(data, read_model_data)

        return MaxPoolOp(
            Ni=Ni, Yi=Yi, Xi=Xi, Ci=Ci, No=No, Yo=Yo, Xo=Xo, Co=Co,
            L2=L2_alloc, L3=L3_alloc,
            dataflow_type=0,  # L2
            read_bins=read_bins,
            sign_A=int(data.get("sign_A", 0)),
            sign_W=0, sign_O=0,
            dtype_A=dtype,
            dtype_W=dtype,
            dtype_O=dtype,
            wgt_fmt=wgt_fmt, pad_value=pad_value, is_dma_pad=is_dma_pad,
            MappingRank=int(data.get("MappingRank", 0)),
            Ky=Ky, Kx=Kx, Sy=Sy, Sx=Sx, Py=Py, Px=Px,
            dtype=dtype,
        )

    def shape(self, op_class: MaxPoolOp) -> List[int]:
        """Return the flat dims list consumed by ``tiler``, and log it.

        Order: Yi, Xi, Ci, Yo, Xo, Ky, Kx, Sy, Sx, Py, Px.
        """
        dims_shape = [
            op_class.Yi, op_class.Xi, op_class.Ci,
            op_class.Yo, op_class.Xo,
            op_class.Ky, op_class.Kx,
            op_class.Sy, op_class.Sx,
            op_class.Py, op_class.Px,
        ]
        log(f"maxpool dims: {dims_shape}")
        return dims_shape

    def tiler(self, dims_shape: List[int], op_class: MaxPoolOp) -> ScheduleInputs:
        """Select the mapping chosen by ``op_class.MappingRank`` and package it.

        Args:
            dims_shape: Dims list from ``shape``, forwarded to ScheduleInputs.
            op_class: The parsed op. Its MappingRank picks one of the
                candidates from generate_maxpool_mappings.

        Returns:
            ScheduleInputs holding the chosen mapping, the dims, the dataflow
            type, the L2/L3 allocations and the DMA padding map.

        Raises:
            IndexError: If MappingRank is negative or not smaller than the
                number of candidate mappings.
        """
        mappings = generate_maxpool_mappings(
            op_class.Yi, op_class.Xi, op_class.Ci,
            op_class.Yo, op_class.Xo,
            op_class.Ky, op_class.Kx,
            op_class.Sy, op_class.Sx,
            op_class.Py, op_class.Px,
        )
        rank = op_class.MappingRank
        # Check explicitly: Python's negative indexing would otherwise silently
        # pick a mapping counted from the end of the candidate list.
        if not 0 <= rank < len(mappings):
            raise IndexError(
                f"MappingRank {rank} is out of range: "
                f"{len(mappings)} maxpool mapping(s) available"
            )
        mapping = mappings[rank]
        padding = DmaPaddingMap(op_class.pad_value,
                                bytes_to_bits(op_class.dtype_A),
                                op_class.is_dma_pad)
        return ScheduleInputs(mapping, dims_shape,
                              op_class.dataflow_type,
                              op_class.L2, op_class.L3,
                              padding)

    def L2_schedule(self, schedule_input: ScheduleInputs):
        """Compile the maxpool dataflow for the L2 scheduling path."""
        return compile_maxpool_dataflow(schedule_input)

    def L3_schedule(self, schedule_input: ScheduleInputs):
        """Compile the maxpool dataflow for the L3 scheduling path.

        Uses the same compile_maxpool_dataflow entry point as L2_schedule.
        """
        return compile_maxpool_dataflow(schedule_input)

    def preproc(self, schedule_input: ScheduleInputs, op_class: MaxPoolOp) -> None:
        """Write the testbench configuration to ``maxpool_cfg.json``.

        Layer dims come from the scheduled shape in ``schedule_input``.
        Signedness, dtype flag, read flags and model-data info come from
        ``op_class``.
        """
        shape: MaxpoolShape = schedule_input.shape
        back_end: BackEnd = schedule_input.backend
        read_bins: ReadBins = op_class.read_bins
        wgt_fmt = op_class.wgt_fmt
        Yi, Xi, _ = shape.ifm_dims
        Yo, Xo, Co = shape.ofm_dims
        Ky, Kx = shape.filter_dims
        Sy, Sx = shape.stride
        Py, Px = shape.padding
        cfg = {
            "Y_IN": Yi,
            "X_IN": Xi,
            "C_OUT": Co,
            "Y_OUT": Yo,
            "X_OUT": Xo,
            "KERNEL_Y": Ky,
            "KERNEL_X": Kx,
            "STRIDE_Y": Sy,
            "STRIDE_X": Sx,
            "PAD_Y": Py,
            "PAD_X": Px,
            "SIGN": op_class.sign_A,
            # 1 only for dtype code 1, otherwise 0.
            "MAXPOOL_NOQDQ_A8": int(op_class.dtype == 1),
            # 1 for every backend other than BackEnd.Adf.
            "ASM_MODE": int(back_end != BackEnd.Adf),
            "READ_IFM": read_bins.read_ifm,
            "READ_WGT": read_bins.read_wgt,
            "NODE_NAME": wgt_fmt.node_name,
            "MD_PATH": wgt_fmt.model_data_path,
            "READ_MD": int(wgt_fmt.read_model_data),
        }
        save_cfg_json(cfg, "maxpool_cfg.json")


def get_op():
    """Entry point used by build.py: return a new MaxPoolBuild instance."""
    builder = MaxPoolBuild()
    return builder
