"""Build script for Binary operator"""

from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from typing import Dict, List

from buildscripts.common import (
    BaseOp,
    OpBuild,
    OperatorsRegistry,
    ScheduleInputs,
    WGTFormatting,
    get_kernel_id,
    normalize_shape,
    save_cfg_json,
    dtype_info,
    bytes_to_bits,
    bits_to_bytes,
    # OpRegistryGroupKey,
)
from dmacompiler import (
    BackEnd,
    DevGen,
    set_dev_gen,
    DmaPaddingMap,
)
from scheduler.binary.binary import BinaryMapping, compile_dataflow
from scheduler.common import (
    L3Alloc,
)
from tiler.binary_tiler import BinaryL2Dims, BinaryShape
from utils.utils_common import L2Alloc, log, ReadBins, iceil

# Select the AIE4 device generation. This runs once at import time;
# presumably it configures dmacompiler state that every schedule compiled
# from this module relies on — TODO confirm scope of the setting.
set_dev_gen(DevGen.Aie4)


class BinaryOpType(Enum):
    """Element-wise binary operation kinds supported by binary ops.

    Used to unify host code: the member's integer value is emitted as
    ``OP_TYPE`` in the generated config, so changing a value changes the
    emitted configuration.
    """

    ADD = 0
    SUB = 1
    MUL = 2
    DIV = 3


def op_type(name: str):
    """Map an operator variant name to its ``BinaryOpType``.

    Args:
        name: Operator variant name (the node's ``"op"`` field).

    Returns:
        ``BinaryOpType.ADD`` for the Add variants, ``BinaryOpType.MUL`` for
        the Mul variant.

    Raises:
        ValueError: If ``name`` is not one of the recognised variants.
    """
    # Membership must be tested against tuples of exact names. A parenthesised
    # single string without a trailing comma is just a ``str``, and ``in`` on a
    # ``str`` is a substring test (e.g. "Mul" or "" would wrongly match).
    if name in ("Add_qdq_EleWise_uint8xuint8xuint8", "Add_qdq_EleWise_uint16xuint16xuint16", "Add"):
        return BinaryOpType.ADD
    if name in ("Mul_qdq_EleWise_uint16xuint16xuint16",):
        return BinaryOpType.MUL
    raise ValueError(f"Could not determine op type for {name}")


def register() -> None:
    """
    Registry hook for the element-wise binary op; intentionally registers nothing.

    Every element-wise binary operator is now compiled through the Broadcast
    dataflow, so no operator entries are added here. The function must still
    exist: the build system warns about any build<op>.py file that lacks a
    `register()` function.
    """
    return None
#     OperatorsRegistry.add_operator(
#         ["Add_qdq_EleWise_uint8xuint8xuint8", "Add"],
#         {
#             "testbench": ["binary.cpp", "binary.hpp"],
#             "dataflow_script": "binary.py",
#             "build_script": "build_binary.py",
#             "kernel_names": {"run_matadd_int8": get_kernel_id("run_matadd_int8")},
#             "kernel_includes": ["super.hh", "binary/run_matadd_int8_wrapper.cc"],
#         },
#     )
#     OperatorsRegistry.add_operator(
#         ["Add_qdq_EleWise_uint16xuint16xuint16"],
#         {
#             "testbench": ["binary.cpp", "binary.hpp"],
#             "dataflow_script": "binary.py",
#             "build_script": "build_binary.py",
#             "kernel_names": {"run_matadd_16": get_kernel_id("run_matadd_16")},
#             "kernel_includes": ["super.hh", "binary/run_matadd_wrapper.cc", "q/q_impl.hpp", "dq/dq_impl.hpp"],
#         },
#     )

#     OperatorsRegistry.add_operator(
#         ["Mul_qdq_EleWise_uint16xuint16xuint16"],
#         {
#             "testbench": ["binary.cpp", "binary.hpp"],
#             "dataflow_script": "binary.py",
#             "build_script": "build_binary.py",
#             "kernel_names": {"run_mul_16": get_kernel_id("run_mul_16")},
#             "kernel_includes": ["super.hh", "binary/run_mul_wrapper.cc", "q/q_impl.hpp", "dq/dq_impl.hpp"],
#         },
#         group_key=OpRegistryGroupKey.MUL.value
#     )


@dataclass(frozen=True)
class BinaryOp(BaseOp):
    """
    Dataclass for elementwise Binary ops (Add/Mul qdq variants).

    Extends ``BaseOp`` with the binary-specific fields below. It is populated by
    ``BinaryBuild._parse_from_dict``. Field order is significant because it
    defines the generated ``__init__`` signature.
    """
    # Input channel count after ``iceil(Ci, 64)`` (presumably rounded up to a
    # multiple of 64); used as the channel dim of the binary shape.
    Cpad: int
    # Bytes per input element, computed as ``bits_to_bytes(dtype_A)``.
    ifm_bytes: int

    # Name of the op variant. Mapped to a ``BinaryOpType`` by ``op_type()`` in
    # preproc, and used as the ``OperatorsRegistry`` lookup key in ``shape()``.
    op_name: str

    # Dequantize / quantize enable flags, derived by inverting the node's
    # ``disable_dq0`` / ``disable_q`` attributes; treated as truthy flags.
    dq_enable: int
    q_enable: int
    # 1 when input B (resp. A) is a "const" input that is laid out in the L3
    # weight buffer instead of being read as an activation, else 0.
    b_on_wgt: int
    a_on_wgt: int


class BinaryBuild(OpBuild):
    """Binary Build Interface for build.py.

    Parses an element-wise binary node (Add/Mul qdq variants) into a
    ``BinaryOp``, derives the L2 dims and mapping, compiles the dataflow and
    writes the op configuration to ``binary_cfg.json``.
    """

    def default_kernel_names(self) -> Dict[str, int]:
        """Return the default kernel-name -> kernel-id map (int8 element-wise matadd)."""
        return {"run_matadd_int8": get_kernel_id("run_matadd_int8")}

    def default_kernel_includes(self) -> List[str]:
        """Return the default kernel source includes for the matadd kernel."""
        return ["super.hh", "kernel/run_matadd_wrapper.cc"]

    def op_type(self):
        """Return the op dataclass produced by this builder.

        Not to be confused with the module-level ``op_type(name)`` helper,
        which ``preproc`` uses to map an op name to a ``BinaryOpType``.
        """
        return BinaryOp

    @staticmethod
    def _build_l3_alloc(data: dict, a_on_wgt: bool, b_on_wgt: bool, total_ifm_bytes: int, shim_wgt_offset: int, shim_prm_offset: int) -> L3Alloc:
        """Lay out the L3 buffers for both operands, the output, weights and params.

        The regular case is activation x activation, where both inputs come from
        ``data["L3"]["ifm0"/"ifm1"]``. When exactly one operand is a constant, it
        is placed at ``shim_wgt_offset`` in the weight buffer (index 2), and the
        weight region starts right after it.

        Raises:
            RuntimeError: If both operands are constants. That layout is not
                supported because only one constant can occupy the weight region.
        """
        if a_on_wgt and b_on_wgt:
            raise RuntimeError("Invalid L3 allocation for binary op: inputs A and B cannot both be constants.")

        # NOTE(review): 2 / 3 look like the L3 buffer indices for the weight and
        # param buffers respectively — confirm against L3Alloc's contract.
        prm = [3, shim_prm_offset]
        if not a_on_wgt and not b_on_wgt:
            return L3Alloc(
                ifm=[data["L3"]["ifm0"], data["L3"]["ifm1"]],
                ofm=data["L3"]["ofm"],
                wgt=[2, shim_wgt_offset],
                prm=prm,
            )

        # Exactly one operand is constant. Its bytes are rounded with
        # iceil(.., 4) (presumably 4-byte alignment) before the weights follow.
        const_bytes = iceil(total_ifm_bytes, 4)
        const_ifm = [2, shim_wgt_offset, const_bytes]
        act_ifm = data["L3"]["ifm"]
        ifm = [const_ifm, act_ifm] if a_on_wgt else [act_ifm, const_ifm]
        return L3Alloc(
            ifm=ifm,
            ofm=data["L3"]["ofm"],
            wgt=[2, shim_wgt_offset + const_bytes],
            prm=prm,
        )

    def _parse_from_dict(self, data: dict, shim_prm_offset: int, shim_wgt_offset: int, read_bins: ReadBins, read_model_data: bool, model_data_path: str) -> BinaryOp:
        """Build a ``BinaryOp`` from a node description dict.

        Args:
            data: Node description. Reads the shapes (``input``/``input0``,
                ``input1``, ``output``), ``attributes``, dtypes, L2 addresses,
                ``input_types`` and the ``L3`` buffer descriptors.
            shim_prm_offset: Offset used for the params entry of the L3 allocation.
            shim_wgt_offset: Offset used for the weight (and constant-operand)
                entry of the L3 allocation.
            read_bins: Stored unchanged on the resulting op.
            read_model_data: Forwarded to ``WGTFormatting`` and ``_get_dma_pad``.
            model_data_path: Forwarded to ``WGTFormatting``.

        Returns:
            The populated ``BinaryOp``.

        Raises:
            AssertionError: If ``input1`` is present with a shape different from
                the first input.
            RuntimeError: If both inputs are constants.
        """
        Ni, Yi, Xi, Ci = normalize_shape(data.get("input") or data.get("input0"))
        No, Yo, Xo, Co = normalize_shape(data.get("output"))
        if "input1" in data:
            assert normalize_shape(data.get("input1")) == (Ni, Yi, Xi, Ci), "For elementwise binary ops, both input shapes must be identical"

        op_name = data.get("op", "")

        # The attributes carry *disable* flags; invert them into enable flags.
        # Stored as int to match the BinaryOp field types (like a_on_wgt/b_on_wgt).
        attributes = data["attributes"]
        dequant_enable = int(not int(attributes["disable_dq0"][0]))
        quant_enable = int(not int(attributes["disable_q"][0]))

        # Both operands are described by a single activation dtype.
        dtype_in = data.get("in_dtype_A") or data.get("in_dtype_B")
        dtype_out = data.get("out_dtype_C") or data.get("out_dtype_output")
        dtype_A, sign_A = dtype_info(dtype_in)
        dtype_W, sign_W = dtype_info(dtype_in)
        dtype_O, sign_O = dtype_info(dtype_out)

        ifm_bytes = dtype_A // 8
        # L2 addresses (two inputs + ofm)
        ifm_a_tile, ifm_a_off = self._get_tile_offset(data.get("input_addr") or data.get("input0_addr"))
        if data.get("input1_addr"):
            ifm_b_tile, ifm_b_off = self._get_tile_offset(data["input1_addr"])
        else:
            # Placeholder only: the scheduler overwrites IFM B's location when
            # input1_addr is not provided.
            ifm_b_tile, ifm_b_off = self._get_tile_offset({1: 0})
        ofm_tile, ofm_off = self._get_tile_offset(data["output_addr"])

        L2_alloc = L2Alloc(
            ifm_L2_loc=[(ifm_a_tile, ifm_a_off), (ifm_b_tile, ifm_b_off)],
            ofm_L2_loc=(ofm_tile, ofm_off),
            wgt_l2_loc=self._get_wgt_addr(data["wgt_addr"]),
            prm_l2_loc=self._get_prm_addr(data["prm_addr"]),
            enable_ifm_fill=bool(data.get("load_input_from_ddr", True)),
            enable_ofm_spill=bool(data.get("store_output_to_ddr", True)),
            enable_L2_fusion=bool(data.get("enable_L2_fusion", False)),
        )

        c_pad = iceil(Ci, 64)
        # IFM B is padded on L3 in all cases, so size the constant operand
        # with the padded channel count.
        total_ifm_bytes = Ni * Yi * Xi * c_pad * ifm_bytes

        input_types = data["input_types"]
        a_on_wgt = input_types["A"] == "const"
        b_on_wgt = input_types["B"] == "const"
        L3_alloc = self._build_l3_alloc(data, a_on_wgt, b_on_wgt, total_ifm_bytes, shim_wgt_offset, shim_prm_offset)

        # dataflow_type 0 when L2 fusion is enabled, otherwise 1.
        dataflow_type = 0 if data.get("enable_L2_fusion", False) else 1
        wgt_fmt = WGTFormatting(node_name=data.get("name") or data.get("op"),
                                model_data_path=model_data_path,
                                read_model_data=read_model_data,
                                dtype_act=dtype_in,
                                dtype_ofm=dtype_out)

        pad_value, is_dma_pad = self._get_dma_pad(data, read_model_data)

        return BinaryOp(
            # base
            Ni=Ni,
            Yi=Yi,
            Xi=Xi,
            Ci=Ci,
            No=No,
            Yo=Yo,
            Xo=Xo,
            Co=Co,
            L2=L2_alloc,
            L3=L3_alloc,
            dataflow_type=int(dataflow_type),
            read_bins=read_bins,
            sign_A=int(sign_A),
            sign_W=int(sign_W),
            sign_O=int(sign_O),
            dtype_A=bits_to_bytes(dtype_A),
            dtype_W=bits_to_bytes(dtype_W),
            dtype_O=bits_to_bytes(dtype_O),
            wgt_fmt=wgt_fmt,
            pad_value=pad_value, is_dma_pad=is_dma_pad,
            # binary-specific
            Cpad=c_pad,
            ifm_bytes=bits_to_bytes(dtype_A),
            op_name=str(op_name),
            dq_enable=dequant_enable,
            q_enable=quant_enable,
            a_on_wgt=int(a_on_wgt),
            b_on_wgt=int(b_on_wgt),
        )

    def shape(self, op_class: BinaryOp) -> BinaryL2Dims:
        """Build the ``BinaryL2Dims`` that ``BinaryMapping`` expects for this op."""
        shape = BinaryShape(
            Ci=op_class.Cpad,
            Yi=op_class.Yi,
            Xi=op_class.Xi,
            Co=op_class.Co,
            Yo=op_class.Yo,
            Xo=op_class.Xo,
            ifm_bytes=op_class.ifm_bytes,
        )
        # The first kernel listed in the op's registry entry is the one called.
        # NOTE(review): this module's register() adds no entries, so op_name must
        # be registered elsewhere (per register()'s docstring, by the Broadcast flow).
        kernel_names = OperatorsRegistry.get_operator(op_class.op_name)["kernel_names"]
        dims = BinaryL2Dims(
            shape,
            q_enable=op_class.q_enable,
            dq_enable=op_class.dq_enable,
            call_kernel=list(kernel_names.keys())[0],
            b_on_wgt=op_class.b_on_wgt,
        )
        log(f"binary dims: {dims}")
        return dims

    def tiler(self, dims_shape: BinaryL2Dims, op_class: BinaryOp) -> ScheduleInputs:
        """Bundle dims, mapping, allocations and DMA padding into ``ScheduleInputs``."""
        mapping = BinaryMapping(dims_shape)
        return ScheduleInputs(
            dims_shape,
            mapping,
            op_class.dataflow_type,
            op_class.L2,
            op_class.L3,
            DmaPaddingMap(op_class.pad_value, bytes_to_bits(op_class.dtype_A), op_class.is_dma_pad),
        )

    def L2_schedule(self, schedule_input: ScheduleInputs):
        """Compile the L2 schedule (same unified binary dataflow compiler as L3)."""
        return compile_dataflow(schedule_input)

    def L3_schedule(self, schedule_input: ScheduleInputs):
        """Compile the L3 schedule (same unified binary dataflow compiler as L2)."""
        return compile_dataflow(schedule_input)

    def preproc(self, schedule_input: ScheduleInputs, op_class: BinaryOp):
        """Collect the op's dims, flags and model-data settings into ``binary_cfg.json``."""
        dims: BinaryL2Dims = schedule_input.shape
        back_end: BackEnd = schedule_input.backend
        read_bins: ReadBins = op_class.read_bins
        op_type_enum: BinaryOpType = op_type(op_class.op_name)
        wgt_fmt: WGTFormatting = op_class.wgt_fmt
        asm_mode = int(back_end != BackEnd.Adf)
        cfg = {
            "AIE_COLS": dims.aie_cols,
            "AIE_ROWS": dims.aie_rows,
            "IFM_CHS": op_class.Ci,
            "IFM_CHS_PAD": op_class.Cpad,
            "IFM_ROWS": dims.shape.Yi,
            "IFM_COLS": dims.shape.Xi,
            "IFM_BYTES": dims.shape.ifm_bytes,
            "OFM_CHS": dims.shape.Co,
            "OFM_ROWS": dims.shape.Yo,
            "OFM_COLS": dims.shape.Xo,
            "OFM_BYTES": dims.ofm_bytes,
            "WGT_SIZE": dims.wgt_size,
            "ASM_MODE": asm_mode,
            "READ_IFM": read_bins.read_ifm,
            "READ_WGT": read_bins.read_wgt,
            "OP_TYPE": op_type_enum.value,
            "NODE_NAME": wgt_fmt.node_name,
            "MD_PATH": wgt_fmt.model_data_path,
            "READ_MD": int(wgt_fmt.read_model_data),
            "DTYPE_ACT": wgt_fmt.dtype_act,
            "DTYPE_OFM": wgt_fmt.dtype_ofm,
            "Q_ENABLE": int(op_class.q_enable),
            "DQ_ENABLE": int(op_class.dq_enable),
            "B_ON_WGT": int(op_class.b_on_wgt),
            "A_ON_WGT": int(op_class.a_on_wgt),
        }
        save_cfg_json(cfg, "binary_cfg.json")


def get_op():
    """Entry point looked up by build.py; hands back a new ``BinaryBuild``."""
    builder = BinaryBuild()
    return builder
