"""Build script for Broadcast operator"""
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from typing import Dict, List

from scheduler.common import L3Alloc
from scheduler.broadcast.folding import fold_bdcast
from scheduler.broadcast.broadcast import compile_L2_dataflow, compile_L3_dataflow

from utils.utils_common import L2Alloc, BaseShape, iceil, log

from tiler.broadcast_tiler import BroadcastShape, generate_broadcast_mappings, broadcast_mapping_key

from dmacompiler import (
    BackEnd,
    DevGen,
    set_dev_gen,
    config,
    DmaPaddingMap
)

from buildscripts.common import (
    OperatorsRegistry,
    save_cfg_json,
    normalize_shape,
    get_kernel_id,
    ScheduleInputs,
    BaseOp,
    OpBuild,
    WGTFormatting,
    dtype_info,
    ReadBins,
    OpRegistryGroupKey,
)

# Import-time side effect: select DevGen.Aie4 as the dmacompiler device
# generation before any of the definitions below are used.
set_dev_gen(DevGen.Aie4)


class BroadcastOpType(Enum):
    """Arithmetic operation kinds supported by broadcast ops. Used to unify host code.

    Members are looked up by the upper-cased operator-name prefix in
    ``op_type`` and the integer value is written to the ``OP_TYPE`` entry of
    the generated configuration.
    """

    ADD = 0
    SUB = 1
    MUL = 2
    DIV = 3


def op_type(name: str):
    """Resolve an operator name to its ``BroadcastOpType`` member.

    The text before the first underscore is taken as the operation
    (e.g. ``"Add_qdq_..."`` -> ``"Add"``) and matched case-insensitively
    against the enum member names.

    Raises:
        ValueError: if the prefix does not name a ``BroadcastOpType`` member.
    """
    prefix, _, _ = name.partition("_")
    member_name = prefix.upper()
    try:
        return BroadcastOpType[member_name]
    except KeyError as exc:
        raise ValueError(f"Could not determine op type for {name}") from exc


def dtype_act(disable_dq0: int, _bytes: int, op_name: str) -> str:
    """Return the activation element type name for an operator.

    Args:
        disable_dq0: truthy when the input dequantize step is disabled.
        _bytes: activation element size in bytes; must be 1 or 2.
        op_name: operator name; a name containing ``"uint"`` is unsigned.

    Returns:
        ``"int8"``/``"uint8"`` for 1-byte elements. For 2-byte elements,
        ``"bfloat16"`` when ``disable_dq0`` is set, otherwise
        ``"int16"``/``"uint16"``.
    """
    sign_prefix = "u" if "uint" in op_name else ""
    if _bytes == 1:
        # The dequantize flag has no effect on 8-bit activations.
        return f"{sign_prefix}int8"
    assert _bytes == 2
    return "bfloat16" if disable_dq0 else f"{sign_prefix}int16"


def dtype_ofm(disable_q: int, _bytes: int, op_name: str) -> str:
    """Return the output element type name for an operator.

    Args:
        disable_q: truthy when the output quantize step is disabled.
        _bytes: output element size in bytes; must be 1 or 2.
        op_name: operator name; a name containing ``"uint"`` is unsigned.

    Returns:
        ``"int8"``/``"uint8"`` for 1-byte elements. For 2-byte elements,
        ``"bfloat16"`` when ``disable_q`` is set, otherwise
        ``"int16"``/``"uint16"``.
    """
    unsigned = "uint" in op_name
    if _bytes == 1:
        # The quantize flag has no effect on 8-bit outputs.
        return "uint8" if unsigned else "int8"
    assert _bytes == 2
    if disable_q:
        return "bfloat16"
    return "uint16" if unsigned else "int16"


# Add/Sub operator names registered for the 8-bit and 16-bit groups. Every
# name has the form "<Op>_qdq_<Mode>_<dtype>x<dtype>x<dtype>", e.g.
# "Add_qdq_BroadCast_uint8xuint8xuint8". The list order is: all BroadCast
# names before all EleWise names, Add before Sub, unsigned before signed.
ADD8_KERNEL_NAMES = [
    f"{op}_qdq_{mode}_{dtype}x{dtype}x{dtype}"
    for mode in ("BroadCast", "EleWise")
    for op in ("Add", "Sub")
    for dtype in ("uint8", "int8")
]
ADD16_KERNEL_NAMES = [
    f"{op}_qdq_{mode}_{dtype}x{dtype}x{dtype}"
    for mode in ("BroadCast", "EleWise")
    for op in ("Add", "Sub")
    for dtype in ("uint16", "int16")
]


def register() -> None:
    """Register the broadcast/elementwise qdq operators with OperatorsRegistry.

    Six groups are registered, in this order: Add/Sub 8-bit, Add/Sub 16-bit,
    Mul 8-bit, Mul 16-bit, Div 8-bit, Div 16-bit. All groups share the same
    testbench, dataflow script and build script; they differ in the kernel
    name, the kernel wrapper source and, for the 16-bit groups, two extra
    q/dq implementation headers.
    """

    def _register(names, kernel, wrapper, with_qdq_impl, group_key):
        # Build a fresh info dict per registration so no list or dict object
        # is shared between registry entries.
        includes = ["super.hh", wrapper]
        if with_qdq_impl:
            includes.extend(["q/q_impl.hpp", "dq/dq_impl.hpp"])
        OperatorsRegistry.add_operator(
            names,
            {
                "testbench": ["broadcast.cpp", "broadcast.hpp"],
                "dataflow_script": "broadcast.py",
                "build_script": "build_broadcast.py",
                "kernel_names": {kernel: get_kernel_id(kernel)},
                "kernel_includes": includes,
            },
            group_key=group_key,
        )

    # NOTE(review): the 8-bit groups are registered with the "*_16" kernel
    # names as well -- presumably one kernel serves both widths; confirm.

    # Add / Sub
    _register(
        ADD8_KERNEL_NAMES,
        "run_bdcastadd_16",
        "broadcast/run_bdcastadd_wrapper.cc",
        False,
        OpRegistryGroupKey.BDCAST_ADD_A8.value,
    )
    _register(
        ADD16_KERNEL_NAMES,
        "run_bdcastadd_16",
        "broadcast/run_bdcastadd_wrapper.cc",
        True,
        OpRegistryGroupKey.BDCAST_ADD_A16.value,
    )

    # Mul
    _register(
        [
            "Mul_qdq_BroadCast_uint8xuint8xuint8",
            "Mul_qdq_BroadCast_int8xint8xint8",
            "Mul_qdq_EleWise_uint8xuint8xuint8",
            "Mul_qdq_EleWise_int8xint8xint8",
        ],
        "run_bdcastmul_16",
        "broadcast/run_bdcastmul_wrapper.cc",
        False,
        OpRegistryGroupKey.BDCAST_MUL_A8.value,
    )
    _register(
        [
            "Mul_qdq_BroadCast_uint16xuint16xuint16",
            "Mul_qdq_BroadCast_int16xint16xint16",
            "Mul_qdq_EleWise_uint16xuint16xuint16",
            "Mul_qdq_EleWise_int16xint16xint16",
        ],
        "run_bdcastmul_16",
        "broadcast/run_bdcastmul_wrapper.cc",
        True,
        OpRegistryGroupKey.BDCAST_MUL_A16.value,
    )

    # Div (only BroadCast names are registered, no EleWise variants)
    _register(
        [
            "Div_qdq_BroadCast_uint8xuint8xuint8",
            "Div_qdq_BroadCast_int8xint8xint8",
        ],
        "run_bdcastdiv_16",
        "broadcast/run_bdcastdiv_wrapper.cc",
        False,
        OpRegistryGroupKey.BDCAST_DIV_A8.value,
    )
    _register(
        [
            "Div_qdq_BroadCast_uint16xuint16xuint16",
            "Div_qdq_BroadCast_int16xint16xint16",
        ],
        "run_bdcastdiv_16",
        "broadcast/run_bdcastdiv_wrapper.cc",
        True,
        OpRegistryGroupKey.BDCAST_DIV_A16.value,
    )


@dataclass(frozen=True)
class BroadcastOp(BaseOp):
    """
    Dataclass for elementwise Broadcast ops (Add/Sub/Mul/Div qdq variants).

    Immutable description of a single broadcast op as produced by
    ``BroadcastBuild._parse_from_dict``. The fields below extend whatever
    ``BaseOp`` already declares.
    """

    # Broadcast-specific configuration
    # Element size in bytes of the inputs and of the output.
    ifm_bytes: int
    ofm_bytes: int
    # Flags (0/1) taken from the op attributes "disable_dq0" / "disable_q".
    disable_dq0: int
    disable_q: int
    # Flags (0/1): input A / input B is declared as "const" in input_types.
    a_on_wgt: int
    b_on_wgt: int
    # Forwarded unchanged to the generated config as DEBUG_MODE.
    debug_mode: int

    # Name of the op variant (used by preproc to derive op_type)
    op_name: str
    # Flag (0/1): input B has a single channel while the output has more.
    has_scalar_broadcast: int


class BroadcastBuild(OpBuild):
    """Broadcast Build Interface for build.py

    Parses a broadcast/elementwise op description into a ``BroadcastOp``,
    derives the padded and folded ``BroadcastShape``, selects a tiling mapping,
    delegates to the L2/L3 dataflow compilers and writes the host-side
    configuration JSON.
    """

    def default_kernel_names(self) -> Dict[str, int]:
        """Return the default kernel-name -> kernel-id mapping (``run_bdcastadd_8``)."""
        return {"run_bdcastadd_8": get_kernel_id("run_bdcastadd_8")}

    def default_kernel_includes(self) -> List[str]:
        """Return the default list of kernel include files."""
        # NOTE(review): register() lists the wrapper under "broadcast/", while
        # this default points at "kernel/" -- confirm which path is intended.
        return ["super.hh", "kernel/run_bdcastadd_wrapper.cc"]

    def op_type(self):
        """Return the op dataclass type handled by this build (``BroadcastOp``).

        This method is unrelated to the module-level ``op_type(name)`` function,
        which maps an operator name to a ``BroadcastOpType``.
        """
        return BroadcastOp

    @staticmethod
    def _check_broadcast_compatible(ifm_a: tuple, ifm_b: tuple, ofm: tuple) -> None:
        """Raise ValueError unless every dimension triple is broadcast compatible.

        A triple ``(a, b, o)`` is compatible when ``a == b == o``, or when one
        input dimension is 1 and the other input dimension equals ``o``.
        """
        for dim_a, dim_b, dim_o in zip(ifm_a, ifm_b, ofm):
            if not (
                dim_a == dim_b == dim_o or
                (dim_a == 1 and dim_b == dim_o) or
                (dim_b == 1 and dim_a == dim_o)
            ):
                raise ValueError(f"Shapes are not broadcast compatible: {ifm_a}, {ifm_b}, {ofm}")

    def _parse_from_dict(
        self,
        data: dict,
        shim_prm_offset: int,
        shim_wgt_offset: int,
        read_bins: ReadBins,
        read_model_data: bool,
        model_data_path: str,
    ) -> BroadcastOp:
        """Build a ``BroadcastOp`` from the op description dictionary.

        Args:
            data: op description. Shapes are read from ``input0``/``input1``
                (falling back to ``input``) and ``output``; dtypes, attributes,
                L2 addresses, ``input_types`` and ``L3`` entries are read as
                shown below.
            shim_prm_offset: offset used for the parameter entry of the L3
                allocation.
            shim_wgt_offset: base offset used for the weight entry of the L3
                allocation (and for a const input placed ahead of it).
            read_bins: stored unchanged on the returned op.
            read_model_data: forwarded to ``WGTFormatting`` and ``_get_dma_pad``.
            model_data_path: forwarded to ``WGTFormatting``.

        Returns:
            The populated, frozen ``BroadcastOp``.

        Raises:
            ValueError: if the shapes are not broadcast compatible, the two
                inputs have different dtypes, DQ is disabled for 1-byte inputs,
                or Q is disabled for 1-byte outputs.
            KeyError: if a key accessed by subscript (``wgt_addr``,
                ``prm_addr``, ``input_types``, ``L3`` ...) is missing.
        """
        Nia, Yia, Xia, Cia = normalize_shape(data.get("input0") or data.get("input"))
        Nib, Yib, Xib, Cib = normalize_shape(data.get("input1") or data.get("input"))
        No, Yo, Xo, Co = normalize_shape(data.get("output"))

        # raise error if there are any dims that are not broadcast compatible
        self._check_broadcast_compatible(
            (Nia, Yia, Xia, Cia), (Nib, Yib, Xib, Cib), (No, Yo, Xo, Co)
        )

        # Signs/dtypes (broadcast path often unquantized at codegen level)
        dtype_in_a = data.get("in_dtype_A")
        dtype_in_b = data.get("in_dtype_B")
        dtype_out = data.get("out_dtype_C") or data.get("out_dtype_output")
        dtype_A, sign_A = dtype_info(dtype_in_a)
        dtype_W, sign_W = dtype_info(dtype_in_b)
        dtype_O, sign_O = dtype_info(dtype_out)
        # Explicit raise instead of assert: this validates external input and
        # must not disappear when Python runs with -O.
        if dtype_A != dtype_W:
            raise ValueError("Broadcast inputs must have same dtype")
        # dtype_info's first value is divided by 8 to get a byte size, so it is
        # treated as a bit width here.
        ifm_bytes = dtype_A // 8
        ofm_bytes = dtype_O // 8

        # Deduce element size and qdq mode
        op_name = data.get("op", "")
        disable_dq0 = data.get("attributes", {}).get("disable_dq0", [0])[0]
        disable_q = data.get("attributes", {}).get("disable_q", [0])[0]
        debug_mode = data.get("debug_mode", 1)

        if ifm_bytes == 1 and disable_dq0:
            raise ValueError("DQ cannot be disabled for int8 input broadcast ops.")
        # The Q step produces the output, so the restriction is on the output
        # element size (the original check tested ifm_bytes by mistake).
        if ofm_bytes == 1 and disable_q:
            raise ValueError("Q cannot be disabled for int8 output broadcast ops.")

        # L2 addresses (two inputs + ofm)
        ifm_a_tile, ifm_a_off = self._get_tile_offset(data.get("input_addr") or data.get("input0_addr"))
        if data.get("input1_addr"):
            ifm_b_tile, ifm_b_off = self._get_tile_offset(data.get("input1_addr"))
        else:
            # We do NOT use these values in schedule_broadcast, these get overwritten
            # if input1_addr is not provided in the scheduler. These are just placeholders.
            ifm_b_tile, ifm_b_off = self._get_tile_offset({1: 0})
        ofm_tile, ofm_off = self._get_tile_offset(data.get("output_addr"))

        L2_alloc = L2Alloc(
            ifm_L2_loc=[(ifm_a_tile, ifm_a_off), (ifm_b_tile, ifm_b_off)],
            ofm_L2_loc=(ofm_tile, ofm_off),
            wgt_l2_loc=self._get_wgt_addr(data["wgt_addr"]),
            prm_l2_loc=self._get_prm_addr(data["prm_addr"]),
            enable_ifm_fill=bool(data.get("load_input_from_ddr", True)),
            enable_ofm_spill=bool(data.get("store_output_to_ddr", True)),
        )

        # IFM B is padded on L3 in all cases
        ifm_a_channels = iceil(Cia, 64)
        ifm_b_channels = iceil(Cib, 64)
        ifm_a_bytes = Nia * Yia * Xia * ifm_a_channels * ifm_bytes
        ifm_b_bytes = Nib * Yib * Xib * ifm_b_channels * ifm_bytes

        a_on_wgt = data["input_types"]["A"] == "const"
        b_on_wgt = data["input_types"]["B"] == "const"

        # L3 allocation. A const input is placed at the start of the weight
        # region and the weight entry is moved past it.
        # The leading 2 / 3 in these lists are passed through as given;
        # presumably buffer indices -- TODO confirm against L3Alloc.
        if not a_on_wgt and not b_on_wgt:
            # activation x activation: both inputs come from their own L3 entry
            ifm_l3_alloc = [data["L3"]["ifm0"], data["L3"]["ifm1"]]
            wgt_l3_alloc = [2, shim_wgt_offset]
        elif a_on_wgt:
            # NOTE(review): this branch is also taken when BOTH inputs are
            # const; B is then still read from data["L3"]["ifm"]. The original
            # code had an unreachable "Invalid L3 allocation" guard that may
            # have been meant for that case -- confirm whether it can occur.
            const_bytes = iceil(ifm_a_bytes, 4)
            ifm_l3_alloc = [[2, shim_wgt_offset, const_bytes], data["L3"]["ifm"]]
            wgt_l3_alloc = [2, shim_wgt_offset + const_bytes]
        else:
            const_bytes = iceil(ifm_b_bytes, 4)
            ifm_l3_alloc = [data["L3"]["ifm"], [2, shim_wgt_offset, const_bytes]]
            wgt_l3_alloc = [2, shim_wgt_offset + const_bytes]
        L3_alloc = L3Alloc(
            ifm=ifm_l3_alloc,
            ofm=data["L3"]["ofm"],
            wgt=wgt_l3_alloc,
            prm=[3, shim_prm_offset],
        )

        # 0 when L2 fusion is enabled, 1 otherwise (converted to int below).
        dataflow_type = not data.get("enable_L2_fusion", False)

        wgt_fmt = WGTFormatting(
            node_name=data.get("name") or data.get("op"),
            model_data_path=model_data_path,
            read_model_data=read_model_data,
            # we (you) will probably need to change this
            dtype_act=dtype_in_a,
            dtype_ofm=dtype_out,
        )

        pad_value, is_dma_pad = self._get_dma_pad(data, read_model_data)

        # Input B has a single channel that is broadcast over a wider output.
        has_scalar_broadcast = Cib == 1 and Co != 1

        return BroadcastOp(
            # base
            Ni=[Nia, Nib],
            Yi=[Yia, Yib],
            Xi=[Xia, Xib],
            Ci=[Cia, Cib],
            No=No,
            Yo=Yo,
            Xo=Xo,
            Co=Co,
            L2=L2_alloc,
            L3=L3_alloc,
            dataflow_type=int(dataflow_type),
            read_bins=read_bins,
            sign_A=int(sign_A),
            sign_W=int(sign_W),
            sign_O=int(sign_O),
            dtype_A=int(dtype_A),
            dtype_W=int(dtype_W),
            dtype_O=int(dtype_O),
            wgt_fmt=wgt_fmt,
            pad_value=pad_value,
            is_dma_pad=is_dma_pad,
            # broadcast-specific
            disable_dq0=int(disable_dq0),
            disable_q=int(disable_q),
            a_on_wgt=int(a_on_wgt),
            b_on_wgt=int(b_on_wgt),
            ifm_bytes=ifm_bytes,
            ofm_bytes=ofm_bytes,
            debug_mode=int(debug_mode),
            op_name=str(op_name),
            has_scalar_broadcast=int(has_scalar_broadcast),
        )

    def shape(self, op_class: BroadcastOp) -> BaseShape:
        """Compute the padded, folded ``BroadcastShape`` for the scheduler/tiler.

        Channels are padded with ``iceil(..., 64)`` (input B keeps a single
        channel for scalar broadcast), the shapes are folded with
        ``fold_bdcast`` and the result is checked for broadcast compatibility.

        Raises:
            ValueError: if ``dataflow_type`` is 0 and the padded inputs plus
                output exceed ``config.MAX_MEMTILE_ADDR * config.NUM_AIE_COLS``
                bytes, or if the folded shapes are not broadcast compatible.
            AssertionError: if folding changed the total input or output size.
        """
        Nia, Nib = op_class.Ni
        Yia, Yib = op_class.Yi
        Xia, Xib = op_class.Xi
        Cia, Cib = op_class.Ci
        No, Yo, Xo, Co = op_class.No, op_class.Yo, op_class.Xo, op_class.Co

        # Pad channels
        Co = iceil(Co, 64)
        Cia = iceil(Cia, 64)
        Cib = 1 if op_class.has_scalar_broadcast else iceil(Cib, 64)

        # compute total size of ifm a + ifm b + ofm with padding
        ofm_size_in_bytes = No * Yo * Xo * Co * op_class.ofm_bytes
        total_ifm_size_in_bytes = (
            Nia * Yia * Xia * Cia * op_class.ifm_bytes +
            Nib * Yib * Xib * Cib * op_class.ifm_bytes
        )
        total_data_size_in_bytes = ofm_size_in_bytes + total_ifm_size_in_bytes
        if op_class.dataflow_type == 0:
            # Everything must fit in L2 when dataflow_type is 0.
            l2_capacity = config.MAX_MEMTILE_ADDR * config.NUM_AIE_COLS
            if total_data_size_in_bytes > l2_capacity:
                raise ValueError(
                    f"Total data size (ifm a + ifm b + ofm) {total_data_size_in_bytes} exceeds "
                    f"available memory on L2 {l2_capacity}. Only L3 dataflow is supported."
                )

        (Nia, Yia, Xia, Cia), (Nib, Yib, Xib, Cib), (No, Yo, Xo, Co) = fold_bdcast(
            (Nia, Yia, Xia, Cia), (Nib, Yib, Xib, Cib), (No, Yo, Xo, Co)
        )
        log(f"Broadcast shape after folding step: {(Nia, Yia, Xia, Cia)}, {(Nib, Yib, Xib, Cib)}, {(No, Yo, Xo, Co)}")

        # assert that total ifm size is the same after folding
        # (internal invariant of fold_bdcast, not input validation)
        assert total_ifm_size_in_bytes == (
            Nia * Yia * Xia * Cia * op_class.ifm_bytes +
            Nib * Yib * Xib * Cib * op_class.ifm_bytes
        )
        assert ofm_size_in_bytes == No * Yo * Xo * Co * op_class.ofm_bytes

        # confirm shape is a valid broadcast shape
        self._check_broadcast_compatible(
            (Nia, Yia, Xia, Cia), (Nib, Yib, Xib, Cib), (No, Yo, Xo, Co)
        )

        # The first registered kernel name for this op is the one to call.
        kernel_names = OperatorsRegistry.get_operator(op_class.op_name)["kernel_names"]

        return BroadcastShape(
            ifm=[(Yia, Xia, Cia,), (Yib, Xib, Cib), ],
            ofm=(Yo, Xo, Co,),
            ifm_bytes=op_class.ifm_bytes,
            ofm_bytes=op_class.ofm_bytes,
            ifm_bits=op_class.ifm_bytes * 8,
            wgt_bits=op_class.ifm_bytes * 8,
            ofm_bits=op_class.ofm_bytes * 8,
            bias_bits=0,
            sign_A=op_class.sign_A,
            sign_W=op_class.sign_W,
            sign_O=op_class.sign_O,
            op_name=op_class.op_name,
            b_on_wgt=op_class.b_on_wgt,
            call_kernel=list(kernel_names.keys())[0],
            has_scalar_broadcast=op_class.has_scalar_broadcast,
        )

    def tiler(
        self,
        dims_shape: BroadcastShape,
        op_class: BroadcastOp,
    ) -> ScheduleInputs:
        """Pick the first generated broadcast mapping and bundle the schedule inputs."""
        enable_over_compute, kernel_gran, kernel_loop_range = True, 64, 8
        # Only the first mapping returned by the tiler is used.
        mapping = generate_broadcast_mappings(dims_shape, enable_over_compute, kernel_gran, kernel_loop_range)[0]
        log(f"TILER: key: {broadcast_mapping_key(mapping, dims_shape)} mapping: {mapping}")

        return ScheduleInputs(
            shape=dims_shape,
            mapping=mapping,
            dataflow_type=op_class.dataflow_type,
            L2_alloc=op_class.L2,
            L3_alloc=op_class.L3,
            dma_pad=DmaPaddingMap(op_class.pad_value, dims_shape.ifm_bits, op_class.is_dma_pad)
        )

    def L2_schedule(self, schedule_input: ScheduleInputs):
        """Delegate to ``compile_L2_dataflow`` and return its result."""
        return compile_L2_dataflow(schedule_input)

    def L3_schedule(self, schedule_input: ScheduleInputs):
        """Delegate to ``compile_L3_dataflow`` and return its result."""
        return compile_L3_dataflow(schedule_input)

    def preproc(self, schedule_input: ScheduleInputs, op_class: BroadcastOp):
        """Write the host configuration for this op to ``broadcast_cfg.json``.

        The ``*_ORIG`` entries and N/X/Y entries carry the op's own dimensions;
        the ``C_IN_*`` / ``C_OUT`` entries carry channels padded with
        ``iceil(..., 64)``.

        Raises:
            ValueError: if ``op_class.op_name`` does not map to a
                ``BroadcastOpType`` (raised by the module-level ``op_type``).
        """
        back_end: BackEnd = schedule_input.backend
        read_bins: ReadBins = op_class.read_bins
        # Module-level op_type(name), not the op_type() method of this class.
        op_type_enum: BroadcastOpType = op_type(op_class.op_name)
        wgt_fmt: WGTFormatting = op_class.wgt_fmt
        # 1 for every backend except BackEnd.Adf.
        asm_mode = int(back_end != BackEnd.Adf)
        QDQ_PRM_SIZE = 128  # initial qdq parameters
        DQ_BUF_SIZE = Q_BUF_SIZE = 512
        # qdq params pad the first 128 bytes of the wgt buffer before the 1024 of scratch space
        q_buf_offset = QDQ_PRM_SIZE
        dq_buf_offset = q_buf_offset + Q_BUF_SIZE
        wgt_size = QDQ_PRM_SIZE + DQ_BUF_SIZE + Q_BUF_SIZE
        # Ci/Co is unpadded host does padding on its own. Padding is added
        # to BroadcastShape for scheduler/tiler
        cfg = {
            "AIE_COLS": config.NUM_AIE_COLS,
            "AIE_ROWS": config.NUM_AIE_ROWS,
            "N_IN_A": op_class.Ni[0],
            "X_IN_A_ORIG": op_class.Xi[0],
            "X_IN_A": op_class.Xi[0],
            "Y_IN_A_ORIG": op_class.Yi[0],
            "Y_IN_A": op_class.Yi[0],
            "C_IN_A_ORIG": op_class.Ci[0],
            "C_IN_A": iceil(op_class.Ci[0], 64),
            "N_IN_B": op_class.Ni[1],
            "X_IN_B_ORIG": op_class.Xi[1],
            "X_IN_B": op_class.Xi[1],
            "Y_IN_B_ORIG": op_class.Yi[1],
            "Y_IN_B": op_class.Yi[1],
            "C_IN_B_ORIG": op_class.Ci[1],
            "C_IN_B": iceil(op_class.Ci[1], 64),
            "N_OUT": op_class.No,
            "X_OUT_ORIG": op_class.Xo,
            "X_OUT": op_class.Xo,
            "Y_OUT_ORIG": op_class.Yo,
            "Y_OUT": op_class.Yo,
            "C_OUT_ORIG": op_class.Co,
            "C_OUT": iceil(op_class.Co, 64),
            "IFM_BYTES": schedule_input.shape.ifm_bytes,
            "OFM_BYTES": schedule_input.shape.ofm_bytes,
            "WGT_SIZE": wgt_size,
            "ASM_MODE": asm_mode,
            "READ_IFM": read_bins.read_ifm,
            "READ_WGT": read_bins.read_wgt,
            "DISABLE_DQ0": op_class.disable_dq0,
            "DISABLE_Q": op_class.disable_q,
            "A_ON_WGT": op_class.a_on_wgt,
            "B_ON_WGT": op_class.b_on_wgt,
            "DQ_BUF_OFFSET": dq_buf_offset,
            "Q_BUF_OFFSET": q_buf_offset,
            "OP_TYPE": op_type_enum.value,
            "NODE_NAME": wgt_fmt.node_name,
            "MD_PATH": wgt_fmt.model_data_path,
            "READ_MD": int(wgt_fmt.read_model_data),
            "DTYPE_ACT": dtype_act(
                op_class.disable_dq0, schedule_input.shape.ifm_bytes, op_class.op_name
            ),
            "DTYPE_OFM": dtype_ofm(
                op_class.disable_q, schedule_input.shape.ofm_bytes, op_class.op_name
            ),
            "SIGN_A": op_class.sign_A,
            "SIGN_W": op_class.sign_W,
            "SIGN_O": op_class.sign_O,
            "DEBUG_MODE": op_class.debug_mode,
            "HAS_SCALAR_BROADCAST": op_class.has_scalar_broadcast,
        }
        save_cfg_json(cfg, "broadcast_cfg.json")


def get_op():
    """Return a new ``BroadcastBuild`` instance; entry point used by build.py."""
    builder = BroadcastBuild()
    return builder
