'''Build script for the Resize (NNI) operator'''
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict, List

from buildscripts.common import (
    BaseOp,
    OpBuild,
    OperatorsRegistry,
    ScheduleInputs,
    WGTFormatting,
    normalize_shape,
    save_cfg_json,
    bytes_to_bits,
)
from dmacompiler import (
    BackEnd,
    DevGen,
    set_dev_gen,
    DmaPaddingMap,
)
from scheduler.common import L3Alloc
from scheduler.resize.nni import compile_dataflow
from tiler.resize_tiler import ResizeMapping, ResizeShape
from utils.utils_common import (
    L2Alloc, ReadBins,
)

# Select AIE4 as the dmacompiler device generation. This is a module-level
# call, so it executes as a side effect of importing this build script.
set_dev_gen(DevGen.Aie4)


def register() -> None:
    '''Register the Resize_qdq_uint16xuint16 (NNI) operator with OperatorsRegistry.

    The entry names the testbench sources, the dataflow and build scripts,
    and the kernel headers; no kernel names are registered.
    '''
    operator_name = "Resize_qdq_uint16xuint16"
    operator_info = {
        "testbench": ["resize.cpp", "resize.hpp"],
        "dataflow_script": "nni.py",
        "build_script": "build_resize.py",
        "kernel_names": [],
        "kernel_includes": ["super.hh"],
    }
    OperatorsRegistry.add_operator(operator_name, operator_info)


@dataclass(frozen=True)
class ResizeOp(BaseOp):
    """
    Immutable description of a Resize op, as produced by ResizeBuild.

    The shape, L2/L3 allocation, sign/dtype, padding and formatting fields
    come from BaseOp; only the Resize-specific fields are declared here.
    """
    # Resize-specific configuration
    # Bytes per input element (ResizeBuild always sets 2 for uint16 data).
    ifm_bytes: int
    # Integer scale factors as (scale_h, scale_w); ResizeBuild forwards them
    # to the config JSON as SCALE_Y / SCALE_X.
    num_interpolations: tuple[int, int]


class ResizeBuild(OpBuild):
    """Resize build interface for build.py.

    Turns a node description into a ResizeOp, builds the tiler inputs,
    dispatches the NNI dataflow compiler and writes ``resize_cfg.json``.
    """

    def default_kernel_names(self) -> List[str]:
        """Return the kernel names for this op (none for Resize).

        Returns an empty list, matching the ``kernel_names`` entry used
        in ``register``. NOTE(review): confirm against OpBuild's declared
        return type.
        """
        return []

    def default_kernel_includes(self) -> List[str]:
        """Return the kernel headers the Resize build includes."""
        return ["super.hh"]

    def op_type(self):
        """Return the op dataclass this builder produces (ResizeOp)."""
        return ResizeOp

    def _parse_from_dict(self, data: dict, shim_prm_offset: int, shim_wgt_offset: int,
                         read_bins: ReadBins, read_model_data: bool, model_data_path: str) -> ResizeOp:
        """Build a ResizeOp from a node description dict.

        Args:
            data: Node description. Required keys read here: "input", "output",
                "attributes" ("scales_1".."scales_4"), "input_addr",
                "output_addr", "wgt_addr", "prm_addr" and "L3" ("ifm", "ofm").
                Optional keys: "load_input_from_ddr", "store_output_to_ddr",
                "enable_L2_fusion", "sign_A", "sign_W", "sign_O", "name", "op".
            shim_prm_offset: Offset stored in the L3 ``prm`` entry.
            shim_wgt_offset: Offset stored in the L3 ``wgt`` entry.
            read_bins: Passed through unchanged to the ResizeOp.
            read_model_data: Forwarded to WGTFormatting and ``_get_dma_pad``.
            model_data_path: Forwarded to WGTFormatting.

        Returns:
            The populated ResizeOp.

        Raises:
            AssertionError: If the N or C scale is not 1, or the H and W
                scales differ.
        """
        Ni, Yi, Xi, Ci = normalize_shape(data.get("input"))
        No, Yo, Xo, Co = normalize_shape(data.get("output"))

        # scales_1..scales_4 are taken as the N, H, W, C scale factors.
        # NOTE(review): int() truncates non-integer scales (e.g. 1.5 -> 1);
        # confirm only integer scales reach this op.
        scale_attrs = data["attributes"]
        interpolation_n = int(scale_attrs["scales_1"][0])
        interpolation_h = int(scale_attrs["scales_2"][0])
        interpolation_w = int(scale_attrs["scales_3"][0])
        interpolation_c = int(scale_attrs["scales_4"][0])

        # Explicit raises instead of ``assert False`` so the checks survive
        # ``python -O``; AssertionError is kept for existing callers.
        if interpolation_n != 1:
            raise AssertionError("Interpolation cannot be performed in the n dimension")
        if interpolation_c != 1:
            raise AssertionError("Interpolation cannot be performed in the c dimension")
        if interpolation_h != interpolation_w:
            raise AssertionError("Interpolation scale has to equal in the h and w dimension")

        # L2 locations for the single input and the output
        ifm_tile, ifm_off = self._get_tile_offset(data["input_addr"])
        ofm_tile, ofm_off = self._get_tile_offset(data["output_addr"])

        # Read once; drives both the L2 allocation and the dataflow type.
        enable_l2_fusion = bool(data.get("enable_L2_fusion", False))

        L2_alloc = L2Alloc(
            ifm_L2_loc=(ifm_tile, ifm_off),
            ofm_L2_loc=(ofm_tile, ofm_off),
            wgt_l2_loc=self._get_wgt_addr(data["wgt_addr"]),
            prm_l2_loc=self._get_prm_addr(data["prm_addr"]),
            enable_ifm_fill=bool(data.get("load_input_from_ddr", True)),
            enable_ofm_spill=bool(data.get("store_output_to_ddr", True)),
            enable_L2_fusion=enable_l2_fusion,
        )

        # NOTE(review): the leading 2 / 3 look like fixed L3 buffer indices
        # for wgt / prm — confirm against L3Alloc.
        L3_alloc = L3Alloc(
            ifm=data["L3"]["ifm"],
            ofm=data["L3"]["ofm"],
            wgt=[2, shim_wgt_offset],
            prm=[3, shim_prm_offset],
        )

        # uint16xuint16 op: 2 bytes per element
        ifm_bytes = 2

        # Signs default to 0 when absent; the dtype flag is 1 for 2-byte data.
        # The W fields are presumably required by BaseOp even though Resize
        # has no weights.
        sign_A = data.get("sign_A", 0)
        sign_W = data.get("sign_W", 0)
        sign_O = data.get("sign_O", 0)
        dtype_A = 1 if ifm_bytes == 2 else 0
        dtype_W = 1 if ifm_bytes == 2 else 0
        dtype_O = 1 if ifm_bytes == 2 else 0

        # dataflow_type is 0 when L2 fusion is enabled, 1 otherwise.
        dataflow_type = 0 if enable_l2_fusion else 1
        wgt_fmt = WGTFormatting(
            node_name=data.get("name") or data.get("op"),
            model_data_path=model_data_path,
            read_model_data=read_model_data
        )

        pad_value, is_dma_pad = self._get_dma_pad(data, read_model_data)

        return ResizeOp(
            Ni=Ni, Yi=Yi, Xi=Xi, Ci=Ci,
            No=No, Yo=Yo, Xo=Xo, Co=Co,
            L2=L2_alloc, L3=L3_alloc,
            dataflow_type=int(dataflow_type),
            read_bins=read_bins,
            sign_A=int(sign_A), sign_W=int(sign_W), sign_O=int(sign_O),
            dtype_A=int(dtype_A), dtype_W=int(dtype_W), dtype_O=int(dtype_O),
            wgt_fmt=wgt_fmt, pad_value=pad_value, is_dma_pad=is_dma_pad,
            ifm_bytes=int(ifm_bytes),
            num_interpolations=(interpolation_h, interpolation_w),
        )

    def shape(self, op_class: ResizeOp) -> ResizeShape:
        """Build the ResizeShape (ifm/ofm NYXC dims, bytes, scales) from the op."""
        shape = ResizeShape(
            ifm=(op_class.Ni, op_class.Yi, op_class.Xi, op_class.Ci),
            ofm=(op_class.No, op_class.Yo, op_class.Xo, op_class.Co),
            ifm_bytes=op_class.ifm_bytes,
            num_interpolations=op_class.num_interpolations
        )
        return shape

    def tiler(self, dims_shape: ResizeShape, op_class: ResizeOp) -> ScheduleInputs:
        """Map the shape with ResizeMapping and bundle the scheduler inputs.

        The DMA padding map uses the op's pad value, the element width in bits
        (from ``ifm_bytes``) and the ``is_dma_pad`` flag.
        """
        mapping = ResizeMapping(shape=dims_shape)
        return ScheduleInputs(
            dims_shape,
            mapping,
            op_class.dataflow_type,
            op_class.L2,
            op_class.L3,
            DmaPaddingMap(op_class.pad_value, bytes_to_bits(op_class.ifm_bytes), op_class.is_dma_pad),
        )

    def L2_schedule(self, schedule_input: ScheduleInputs):
        """Compile the NNI dataflow (same entry point as L3_schedule)."""
        return compile_dataflow(schedule_input)

    def L3_schedule(self, schedule_input: ScheduleInputs):
        """Compile the NNI dataflow (same entry point as L2_schedule)."""
        return compile_dataflow(schedule_input)

    def preproc(self, schedule_input: ScheduleInputs, op_class: ResizeOp):
        """Write the Resize configuration to ``resize_cfg.json``.

        The config holds the AIE grid size, mapped in/out dims, the H/W scale
        factors, data-type flags, read flags and model-data location.
        ASM_MODE is 1 for any backend other than ``BackEnd.Adf``.
        """
        shape: ResizeShape = schedule_input.shape
        mapping: ResizeMapping = schedule_input.mapping
        back_end: BackEnd = schedule_input.backend
        read_bins: ReadBins = op_class.read_bins
        wgt_fmt: WGTFormatting = op_class.wgt_fmt
        asm_mode = int(back_end != BackEnd.Adf)
        cfg = {
            "AIE_COLS": mapping.aie_cols,
            "AIE_ROWS": mapping.aie_rows,
            "H_IN": mapping.Yi,
            "W_IN": mapping.Xi,
            "C_IN": mapping.Ci,
            "SCALE_Y": shape.num_interpolations[0],
            "SCALE_X": shape.num_interpolations[1],
            "H_OUT": mapping.Yo,
            "W_OUT": mapping.Xo,
            "C_OUT": mapping.Co,
            "ASM_MODE": asm_mode,
            # Hard-coded for the uint16xuint16 op (ifm_bytes is always 2).
            "INT_16": 1,
            "BFLOAT_16": 0,
            "READ_IFM": read_bins.read_ifm,
            "READ_WGT": read_bins.read_wgt,
            "NODE_NAME": wgt_fmt.node_name,
            "MD_PATH": wgt_fmt.model_data_path,
            "READ_MD": int(wgt_fmt.read_model_data),
        }
        save_cfg_json(cfg, "resize_cfg.json")


def get_op():
    """Return a new ResizeBuild; this is the hook build.py calls for Resize."""
    op_builder = ResizeBuild()
    return op_builder
