"""Build script for the GEMM (MatMul) operator family."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict

from buildscripts.common import (
    BaseOp,
    OpBuild,
    OperatorsRegistry,
    ScheduleInputs,
    WGTFormatting,
    get_kernel_id,
    normalize_shape,
    save_cfg_json,
    dtype_info,
    DTYPE_C0,
    OpRegistryGroupKey,
)
from dmacompiler import BackEnd, DevGen, set_dev_gen, DmaPaddingMap
from scheduler.common import (
    L3Alloc,
    LinearOpType,
)
from scheduler.conv.conv_config_builders import (
    ActMode,
    ConvMapping,
    ConvShape,
)
from scheduler.conv.conv_L2_schedule import compile_L2_dataflow
from scheduler.conv.conv_L3_schedule import compile_L3_dataflow
from tiler.gemm_tiler import generate_gemm_mappings
from utils.utils_common import L2Alloc, log, is_log_enabled, ReadBins

# Module-level side effect: importing this build script calls dmacompiler's
# set_dev_gen with DevGen.Aie4 (presumably selecting the AIE4 device generation
# for all subsequent DMA compilation — NOTE(review): confirm semantics).
set_dev_gen(DevGen.Aie4)


def register() -> None:
    """Register every GEMM/MatMul operator flavor with ``OperatorsRegistry``.

    Three families are registered. Each family shares one testbench, one
    kernel and one registry group key:

    * activation x 8-bit weight flavors (``run_gemm_int16x8``),
    * activation x 4-bit weight flavors (``run_gemm_int16x4``),
    * activation x activation flavors, optionally transposed
      (``run_gemm_int16x16_transpose``).
    """
    # (op names, testbench files, kernel name, kernel wrapper source, group key)
    flavors = (
        (
            [
                "MatMul_qdq_int16xint8xint16",
                "MatMul_qdq_uint16xint8xuint16",
                "MatMul_qdq_uint16xuint8xuint16",
                "MatMul_qdq_bias_uint16xuint8xuint16",
            ],
            ["gemm.cpp", "gemm.hpp"],
            "run_gemm_int16x8",
            "gemm_qdq_int16x8/gemm_int16x8_wrapper.cc",
            OpRegistryGroupKey.MATMUL_A16W8,
        ),
        (
            [
                "MatMul_qdq_uint16xint4xuint16",
                "MatMul_qdq_uint16xuint4xuint16",
                "MatMul_qdq_int16xint4xint16",
                "MatMul_qdq_bias_uint16xint4xuint16",
                "MatMul_qdq_bias_uint16xuint4xuint16",
                "MatMul_qdq_bias_int16xint4xint16",
            ],
            ["gemm.cpp", "gemm.hpp"],
            "run_gemm_int16x4",
            "gemm_qdq_int16x4/gemm_int16x4_wrapper.cc",
            OpRegistryGroupKey.MATMUL_A16W4,
        ),
        (
            [
                "MatMul_qdq_actxact_uint16xuint16xuint16",
                "MatMul_qdq_actxact_int16xint16xint16",
                "MatMul_qdq_actxact_Transpose_uint16xuint16xuint16",
                "MatMul_qdq_actxact_Transpose_int16xint16xint16",
            ],
            ["gemm_act.cpp", "gemm_act.hpp"],
            "run_gemm_int16x16_transpose",
            "gemm_qdq_int16x16_transpose/gemm_int16x16_transpose_wrapper.cc",
            OpRegistryGroupKey.ACTXACT_A16,
        ),
    )

    for op_names, testbench, kernel, wrapper_src, group in flavors:
        # The kernel id is looked up right before each registration, in the
        # same order as the families above.
        spec = {
            "testbench": testbench,
            "dataflow_script": "conv_L3_schedule.py",
            "build_script": "build_gemm.py",
            "kernel_names": {kernel: get_kernel_id(kernel)},
            "kernel_includes": ["super.hh", wrapper_src],
        }
        OperatorsRegistry.add_operator(op_names, spec, group_key=group.value)


@dataclass(frozen=True)
class GeMMOp(BaseOp):
    """Frozen description of one GEMM/MatMul op, extending BaseOp.

    Instances are produced by ``GemmBuild._parse_from_dict``. The GEMM is
    scheduled as a convolution, so the conv hyper-parameters below are set to
    a 1x1 kernel, stride 1 and zero padding by that parser.
    """

    # Tiling/mapping selection: index into the candidate list returned by
    # generate_gemm_mappings (see GemmBuild.tiler).
    MappingRank: int = 0

    # Op hyper-params (conv view of the GEMM: kernel, stride, padding).
    Ky: int = 1
    Kx: int = 1
    Sy: int = 1
    Sx: int = 1
    Py: int = 0
    Px: int = 0

    # Padding for input/output: padded Y/X/C dims, parsed from the
    # "padded_input*"/"padded_output" shapes of the op dict.
    Yip: int = 0
    Xip: int = 0
    Cip: int = 0
    Yop: int = 0
    Xop: int = 0
    Cop: int = 0

    # Bias bit width (DTYPE_C0 for "_qdq_" flavors).
    dtype_Bias: int = 0
    debug_mode: int = 0
    # Scheduler op type, chosen from the (A, W, O, Bias) dtype combination.
    linear_op_type: LinearOpType = LinearOpType.gemm_A16W8_qdq

    # 1 for op names containing "_qdq_", else 0; emitted as "QDQ" in gemm_cfg.json.
    qdq: int = 0

    # Activation mode and compute policy
    act_type: ActMode = ActMode(0)
    enable_over_compute: bool = False

    # GEMM-specific knobs: forwarded as ConvShape.vector_coeff and "COEFF_VECTOR".
    vector_coeff: int = 0

    # Preproc knobs
    # NOTE(review): out_shift/bias_shift are parsed from the op dict but not
    # read anywhere else in this build script — confirm whether they are used.
    out_shift: int = 0
    bias_shift: int = 0
    # Whether to transpose weights for both dataflow and kernel.
    # The parser overrides this default: 1 only for "Transpose" op names, else 0.
    transpose_wgts: int = 1

    # Note: this dict is used to map exception cases to specific M, K, N values
    # It is left unannotated, so @dataclass treats it as a plain class attribute
    # shared by all instances, not as an __init__ field. Annotating it with a
    # dict type would make @dataclass reject the mutable default.
    exception_mapping = {
        # key: (M, K, N): Value is a list of splits to skip
        # (1024, 640, 640): [(1, 1, 12, 1)]
    }


class GemmBuild(OpBuild):
    """GeMM Build Interface for build.py.

    Turns a GEMM/MatMul op description (a dict) into a frozen ``GeMMOp``.
    It then derives the ``ConvShape`` handed to the conv scheduler; the GEMM
    is modelled as a 1x1, stride-1, group=1 convolution. Finally it selects a
    tiling mapping, compiles the L2/L3 dataflow and writes ``gemm_cfg.json``.
    """

    # (dtype_A, dtype_W, dtype_O, dtype_Bias) bit widths -> LinearOpType.
    # Each activation/weight/output combination appears twice: once with a
    # 32-bit bias and once with a bias width of 0.
    LINEAR_OP_TYPE_MAP = {
        (8, 8, 8, 32): LinearOpType.gemm_A8W8_qdq,
        (16, 8, 16, 32): LinearOpType.gemm_A16W8_qdq,
        (16, 4, 16, 32): LinearOpType.gemm_A16W4_qdq,
        # (16, 16, 16, 32): LinearOpType.gemm_A16A16_v1,
        (16, 16, 16, 32): LinearOpType.gemm_A16A16_v2,
        (8, 8, 8, 0): LinearOpType.gemm_A8W8_qdq,
        (16, 8, 16, 0): LinearOpType.gemm_A16W8_qdq,
        (16, 4, 16, 0): LinearOpType.gemm_A16W4_qdq,
        # (16, 16, 16, 0): LinearOpType.gemm_A16A16_v1,
        (16, 16, 16, 0): LinearOpType.gemm_A16A16_v2,
    }

    # Activation x activation flavors. Their op dict uses "input0"-style keys,
    # and their two operands come from L3 "ifm0"/"ifm1".
    _ACTXACT_OP_NAMES = frozenset({
        "MatMul_qdq_actxact_uint16xuint16xuint16",
        "MatMul_qdq_actxact_Transpose_uint16xuint16xuint16",
        "MatMul_qdq_actxact_int16xint16xint16",
        "MatMul_qdq_actxact_Transpose_int16xint16xint16",
    })

    # Unpadded (Yi, Xi, Ci, Yo, Xo, Co) shapes this build currently rejects.
    _UNSUPPORTED_SHAPES = frozenset({
        (16, 32, 32, 16, 32, 196),
        (8, 32, 784, 8, 32, 32),
    })

    def _get_linear_op_type(self, dtype_A: int, dtype_W: int, dtype_O: int, dtype_Bias: int) -> LinearOpType:
        """Look up the LinearOpType for a dtype bit-width combination.

        Raises:
            ValueError: if the combination is not in ``LINEAR_OP_TYPE_MAP``.
        """
        try:
            return self.LINEAR_OP_TYPE_MAP[(dtype_A, dtype_W, dtype_O, dtype_Bias)]
        except KeyError:
            raise ValueError(
                f"Unsupported combination of dtypes GEMM: A={dtype_A}, W={dtype_W}, O={dtype_O}, Bias={dtype_Bias}"
            ) from None

    def default_kernel_names(self) -> Dict[str, int]:
        """Kernel name -> kernel id used when no flavor-specific entry applies."""
        return {"run_gemm_int16x8": get_kernel_id("run_gemm_int16x8")}

    def default_kernel_includes(self) -> list[str]:
        """Kernel sources included when no flavor-specific entry applies."""
        return ["super.hh", "gemm_qdq_int16x8/gemm_int16x8_wrapper.cc"]

    def op_type(self):
        """Return the op dataclass produced by this build (``GeMMOp``)."""
        return GeMMOp

    def _parse_from_dict(
        self,
        data: dict,
        shim_prm_offset: int,
        shim_wgt_offset: int,
        read_bins: ReadBins,
        read_model_data: bool,
        model_data_path: str,
    ) -> GeMMOp:
        """Build a ``GeMMOp`` from an op description dictionary.

        Args:
            data: Op description. Required keys include ``op``, ``name``, the
                input/output shapes and addresses, ``L3``, ``wgt_addr``,
                ``prm_addr``, ``in_dtype_A``, ``in_dtype_B`` and
                ``out_dtype_Y``. ``in_dtype_Bias`` is required for non-QDQ
                flavors only.
            shim_prm_offset: Offset stored in the L3 params allocation.
            shim_wgt_offset: Offset stored in the L3 weights allocation.
            read_bins: Forwarded unchanged to the op; ``preproc`` reads its
                ``read_ifm``/``read_wgt`` flags.
            read_model_data: Forwarded to ``WGTFormatting`` and ``_get_dma_pad``.
            model_data_path: Forwarded to ``WGTFormatting``.

        Returns:
            The populated, frozen ``GeMMOp``.

        Raises:
            AssertionError: for shapes this op does not support.
            ValueError: for unsupported dtype combinations.
            KeyError: if a required key is missing from ``data``.
        """
        op_name = data["op"]
        if op_name in self._ACTXACT_OP_NAMES:
            Ni, Yi, Xi, Ci = normalize_shape(data.get("input0"))
            No, Yo, Xo, Co = normalize_shape(data.get("output"))
            # Batch dims of the padded shapes are not used.
            _, Yip, Xip, Cip = normalize_shape(data.get("padded_input0"))
            _, Yop, Xop, Cop = normalize_shape(data.get("padded_output"))
            ifm_tile, ifm_off = self._get_tile_offset(data["input0_addr"])
            ofm_tile, ofm_off = self._get_tile_offset(data["output_addr"])
            L3_ifm = [data["L3"]["ifm0"], data["L3"]["ifm1"]]
        else:
            Ni, Yi, Xi, Ci = normalize_shape(data.get("input"))
            assert Ni == 1, "ActXWgt GEMM Op requires Ni=1"
            assert Yi == 1, "ActXWgt GEMM Op requires Yi=1"
            No, Yo, Xo, Co = normalize_shape(data.get("output"))
            _, Yip, Xip, Cip = normalize_shape(data.get("padded_input"))
            _, Yop, Xop, Cop = normalize_shape(data.get("padded_output"))
            ifm_tile, ifm_off = self._get_tile_offset(data["input_addr"])
            ofm_tile, ofm_off = self._get_tile_offset(data["output_addr"])
            # Accept either "ifm" or "ifm0" as the single L3 input entry.
            L3_ifm = data["L3"].get("ifm", data["L3"].get("ifm0"))

        assert (Yi, Xi, Ci, Yo, Xo, Co) not in self._UNSUPPORTED_SHAPES, \
            f"GEMM Op currently does not support shape: Yi={Yi}, Xi={Xi}, Ci={Ci}, Yo={Yo}, Xo={Xo}, Co={Co}"

        # Read once: this flag drives both the L2 allocation and the dataflow type.
        enable_L2_fusion = data.get("enable_L2_fusion", False)

        L2_alloc = L2Alloc(
            (ifm_tile, ifm_off),
            (ofm_tile, ofm_off),
            self._get_wgt_addr(data["wgt_addr"]),
            self._get_prm_addr(data["prm_addr"]),
            data.get("load_input_from_ddr", True),
            data.get("store_output_to_ddr", True),
            enable_L2_fusion,
        )

        if "Transpose" in op_name:
            assert Yi == 1, "Transpose BMM requires Yi=1, till Co padding is supported"
            transpose_wgts = 1
        else:
            transpose_wgts = 0

        # NOTE(review): 2 and 3 look like fixed L3 buffer indices for the
        # weights and params — confirm against L3Alloc.
        L3_alloc = L3Alloc(
            L3_ifm,
            data["L3"]["ofm"],
            [2, shim_wgt_offset],
            [3, shim_prm_offset],
        )

        # Activation, weight and output dtypes as (bit width, sign).
        dtype_A, sign_A = dtype_info(data["in_dtype_A"])
        dtype_W, sign_W = dtype_info(data["in_dtype_B"])
        dtype_O, sign_O = dtype_info(data["out_dtype_Y"])

        if "_qdq_" in op_name:
            # QDQ MatMuls always carry their bias/c0 as float.
            qdq = 1
            dtype_Bias = DTYPE_C0
        else:
            # Only non-QDQ flavors read an explicit bias dtype.
            qdq = 0
            dtype_Bias, _ = dtype_info(data["in_dtype_Bias"])

        linear_op_type = self._get_linear_op_type(dtype_A, dtype_W, dtype_O, dtype_Bias)

        dataflow_type = 0 if bool(enable_L2_fusion) else 1
        debug_mode = data.get("debug_mode", 0)

        wgt_fmt = WGTFormatting(
            data["name"],
            model_data_path,
            read_model_data,
            data["in_dtype_A"],
            data["in_dtype_B"],
            data.get("in_dtype_Bias", "0"),
            data["out_dtype_Y"]
        )

        pad_value, is_dma_pad = self._get_dma_pad(data, read_model_data)

        return GeMMOp(
            # base
            Ni=Ni,
            Yi=Yi,
            Xi=Xi,
            Ci=Ci,
            No=No,
            Yo=Yo,
            Xo=Xo,
            Co=Co,
            Yip=Yip,
            Xip=Xip,
            Cip=Cip,
            Yop=Yop,
            Xop=Xop,
            Cop=Cop,
            L2=L2_alloc,
            L3=L3_alloc,
            transpose_wgts=transpose_wgts,
            dataflow_type=int(dataflow_type),
            read_bins=read_bins,
            sign_A=sign_A, sign_W=sign_W, sign_O=sign_O,
            dtype_A=dtype_A, dtype_W=dtype_W, dtype_O=dtype_O,
            wgt_fmt=wgt_fmt, pad_value=pad_value, is_dma_pad=is_dma_pad,
            MappingRank=int(data.get("MappingRank", 0)),
            # GEMM as a 1x1, stride-1, unpadded convolution.
            Ky=1, Kx=1, Sy=1, Sx=1, Py=0, Px=0,
            act_type=ActMode(int(data.get("act_type", 0))),
            enable_over_compute=bool(data.get("enable_over_compute", True)),
            vector_coeff=int(data.get("vector_coeff", 0)), debug_mode=debug_mode,
            out_shift=int(data.get("out_shift", 0)),
            bias_shift=int(data.get("bias_shift", 0)),
            dtype_Bias=dtype_Bias,
            linear_op_type=linear_op_type,
            qdq=qdq,
        )

    def shape(self, op_class: GeMMOp) -> Any:
        """Build the ``ConvShape`` describing this GEMM as a group=1 convolution.

        The padded (Y, X, C) dims are used for ifm/ofm; the unpadded Ci/Co
        are passed as ``Ci_orig``/``Co_orig``.
        """
        dims_shape = ConvShape(
            ifm=(op_class.Yip, op_class.Xip, op_class.Cip),  # NOTE(review): pad in L2 alloc?
            ofm=(op_class.Yop, op_class.Xop, op_class.Cop),
            kernel=(op_class.Ky, op_class.Kx),
            stride=(op_class.Sy, op_class.Sx),
            padding=(op_class.Py, op_class.Px),
            vector_coeff=op_class.vector_coeff,
            ifm_bits=op_class.dtype_A,
            ofm_bits=op_class.dtype_O,
            linear_op_type=op_class.linear_op_type,
            enable_over_compute=op_class.enable_over_compute,
            wgt_bits=op_class.dtype_W,
            bias_bits=op_class.dtype_Bias,
            sign_A=op_class.sign_A,
            sign_W=op_class.sign_W,
            sign_O=op_class.sign_O,
            group=1,    # GEMM is group=1 conv
            transpose_wgts=op_class.transpose_wgts,
            Ci_orig=op_class.Ci,
            Co_orig=op_class.Co,
        )
        log(f"dims_shape: {dims_shape}")
        return dims_shape

    def tiler(self, dims_shape: ConvShape, op_class: GeMMOp) -> ScheduleInputs:
        """Pick the tiling mapping for this GEMM and bundle the schedule inputs.

        ``op_class.MappingRank`` indexes the candidates returned by
        ``generate_gemm_mappings``. In "wgt" mode, the chosen candidate's
        ``spatial_split`` may be listed in ``op_class.exception_mapping`` for
        this (M, K, N). In that case the first candidate whose split is not
        listed is used instead.

        Raises:
            IndexError: if ``MappingRank`` is out of range for the candidates.
            ValueError: if the requested candidate is excluded and every other
                candidate is excluded as well.
        """
        # dtype_W == 16 only occurs for the A16A16 (act x act) combinations.
        gemm_mode = "act" if op_class.dtype_W == 16 else "wgt"
        mappings = generate_gemm_mappings(dims_shape, gemm_mode)
        # (M, K, N) from the padded shapes; ifm and ofm are (Y, X, C).
        shape_key = (dims_shape.ifm[1], dims_shape.ifm[2], dims_shape.ofm[2])
        skipped_tilings = op_class.exception_mapping.get(shape_key, [])

        requested = mappings[op_class.MappingRank]
        mapping = requested
        if gemm_mode == "wgt" and requested.spatial_split in skipped_tilings:
            for idx, candidate in enumerate(mappings):
                if candidate.spatial_split not in skipped_tilings:
                    log(f"Skipping MappingRank {op_class.MappingRank} with spatial_split {requested.spatial_split} "
                        f"due to exception mapping. Using MappingRank {idx} with spatial_split {candidate.spatial_split}")
                    mapping = candidate
                    break
            else:
                # Every candidate is excluded: fail with a clear message
                # instead of continuing without a mapping.
                raise ValueError(
                    f"No valid GEMM mapping for (M, K, N)={shape_key}: every candidate "
                    f"spatial_split is listed in exception_mapping {skipped_tilings}"
                )
        return ScheduleInputs(dims_shape, mapping, op_class.dataflow_type,
                              op_class.L2, op_class.L3,
                              DmaPaddingMap(op_class.pad_value, op_class.dtype_A, op_class.is_dma_pad))

    def L2_schedule(self, schedule_input: ScheduleInputs):
        """Compile the L2 dataflow for the given schedule inputs."""
        return compile_L2_dataflow(schedule_input)

    def L3_schedule(self, schedule_input: ScheduleInputs):
        """Compile the L3 dataflow for the given schedule inputs."""
        return compile_L3_dataflow(schedule_input)

    def preproc(self, schedule_input: ScheduleInputs, op_class: GeMMOp):
        """Write the GEMM kernel configuration to ``gemm_cfg.json``.

        The B/M/K/N sizes come from the padded shape (Y, X, Cin, Cout).
        The M/K/N subvolumes come from the selected mapping. The ``*_ORIG``
        values are the unpadded dims from ``op_class``.
        """
        shape: ConvShape = schedule_input.shape
        mapping: ConvMapping = schedule_input.mapping
        back_end: BackEnd = schedule_input.backend
        wgt_fmt: WGTFormatting = op_class.wgt_fmt
        asm_mode = int(back_end != BackEnd.Adf)
        Yi, Xi, Ci = shape.ifm
        _, _, Co = shape.ofm
        _, _, Cis = mapping.ifm_subv
        _, Xos, Cos = mapping.ofm_subv
        _, _, _, C_split = mapping.spatial_split
        is_int4 = int(op_class.dtype_W == 4)
        cfg = {
            "B_GEMM_A16W8": Yi,
            "M_GEMM_A16W8": Xi,
            "K_GEMM_A16W8": Ci,
            "N_GEMM_A16W8": Co,
            "M_SUBV_A16W8": Xos,
            "K_SUBV_A16W8": Cis,
            "N_SUBV_A16W8": Cos,
            # NOTE: set this properly
            "TRANSPOSE_WGTS": op_class.transpose_wgts,
            "M_GEMM_ORIG": op_class.Xi,
            "K_GEMM_ORIG": op_class.Ci,
            "N_GEMM_ORIG": op_class.Co,
            "ASM_MODE": asm_mode,
            "SIGN_ACT": op_class.sign_A,
            "SIGN_WGT": op_class.sign_W,
            "SIGN_OUT": op_class.sign_O,
            "COEFF_VECTOR": op_class.vector_coeff,
            "IS_INT4_WGT": is_int4,
            "READ_IFM": op_class.read_bins.read_ifm,
            "READ_WGT": op_class.read_bins.read_wgt,
            "N_SPLIT": C_split,
            "DEBUG": is_log_enabled(),
            "NODE_NAME": wgt_fmt.node_name,
            "MD_PATH": wgt_fmt.model_data_path,
            "READ_MD": int(wgt_fmt.read_model_data),
            "DTYPE_ACT": wgt_fmt.dtype_act,
            "DTYPE_WGT": wgt_fmt.dtype_wgt,
            "DTYPE_BIAS": wgt_fmt.dtype_bias,
            "DTYPE_OFM": wgt_fmt.dtype_ofm,
            "QDQ": op_class.qdq
        }
        save_cfg_json(cfg, "gemm_cfg.json")


def get_op():
    """Entry point used by build.py: return a new GemmBuild instance."""
    builder = GemmBuild()
    return builder
