'''Build script for Conv operator'''
from __future__ import annotations

from dataclasses import dataclass
from math import ceil, log2
from typing import Any, Dict

import numpy as np

from buildscripts.common import (
    BaseOp,
    OpBuild,
    OperatorsRegistry,
    ScheduleInputs,
    WGTFormatting,
    get_kernel_id,
    normalize_shape,
    save_cfg_json,
    dtype_info,
    DTYPE_C0,
    OpRegistryGroupKey,
    BaseKernelSelector,
)
from dmacompiler import BackEnd, DevGen, set_dev_gen, DmaPaddingMap
from scheduler.common import (
    L3Alloc,
    LinearOpType,
)
from scheduler.conv.conv_config_builders import (
    ActMode,
    ConvMapping,
    ConvShape,
)
from scheduler.conv.conv_L2_schedule import compile_L2_dataflow
from scheduler.conv.conv_L3_schedule import compile_L3_dataflow
from scheduler.dwc.dwc_L3_schedule import compile_DWC_L3_dataflow
from tiler.conv_tiler import generate_conv_mappings
from tiler.dwc_tiler import generate_dwc_mappings
from utils.utils_common import L2Alloc, iceil, log, is_log_enabled, ReadBins

# Select the AIE4 device generation in dmacompiler. This runs once, as an
# import-time side effect of loading this module.
# NOTE(review): presumably must happen before any tiling/scheduling below — confirm.
set_dev_gen(DevGen.Aie4)


class ConvKernelSelector(BaseKernelSelector):
    """Pick the kernel metadata field (regular vs. depthwise) for a Conv node."""

    def select(self, field_name, attrs, operator, metadata) -> str:
        """
        Return the metadata key from which this Conv node's kernel info is read.

        A node whose attrs['group'] is greater than 1 is a depthwise convolution.
        For such a node, if metadata also has a '<field_name>_group' entry (e.g.
        'kernel_names_group', 'kernel_includes_group'), that grouped key is
        returned. In every other case the base field_name is returned unchanged.
        Those cases are: group missing, None, 0 or 1, or no grouped entry present.
        """
        group_count = attrs.get("group") or 1
        grouped_field = f"{field_name}_group"
        is_depthwise = group_count > 1
        if not is_depthwise or grouped_field not in metadata:
            return field_name
        return grouped_field


def register() -> None:
    '''Register the Conv operator flavors with OperatorsRegistry.

    Three registrations are made, in this order:
      * "conv_noqdq_a8w8" / "Conv": a8w8 no-QDQ conv kernel (no explicit group_key).
      * the a16w8 QDQ conv variants (group_key CONV_A16): the conv kernel in the
        base fields and the depthwise kernel in the "*_group" fields, chosen per
        node by ConvKernelSelector.
      * "Conv_DWC": standalone a16w8 depthwise kernel (group_key CONV_DWC_A16).
    '''
    noqdq_a8w8_spec = {
        "testbench": ["conv.cpp", "conv.hpp"],
        "dataflow_script": "conv.py",
        "build_script": "build_conv.py",
        "kernel_names": {"run_conv_noqdq_a8w8": get_kernel_id("run_conv_noqdq_a8w8")},
        "kernel_includes": ["super.hh", "conv/conv_noqdq_a8w8_wrapper.cc"],
    }
    OperatorsRegistry.add_operator(["conv_noqdq_a8w8", "Conv"], noqdq_a8w8_spec)

    qdq_a16w8_names = [
        "Conv_qdq_int16xint8xint16",
        "Conv_qdq_uint16xint8xuint16",
        "Conv_qdq_uint16xuint8xuint16",
        "Conv_qdq_bias_uint16xuint8xuint16",
        "Conv_qdq_bias_uint16xint8xuint16",
        "Conv_qdq_bias_int16xint8xint16",
    ]
    qdq_a16w8_spec = {
        "testbench": ["conv.cpp", "conv.hpp"],
        "dataflow_script": "conv_L3_schedule.py",
        "build_script": "build_conv.py",
        "kernel_names": {"run_conv_qdq_a16w8": get_kernel_id("run_conv_qdq_a16w8")},
        "kernel_includes": [
            "super.hh",
            "conv_qdq_int16x8/conv_qdq_a16w8_wrapper.cc",
        ],
        # Depthwise (group > 1) variant, picked by ConvKernelSelector.
        "kernel_names_group": {"run_dwc_qdq_a16w8": get_kernel_id("run_dwc_qdq_a16w8")},
        "kernel_includes_group": [
            "super.hh",
            "dwc_int16x8/dwc_qdq_a16w8_wrapper.cc",
        ],
        "kernel_selector": ConvKernelSelector(),
    }
    OperatorsRegistry.add_operator(
        qdq_a16w8_names,
        qdq_a16w8_spec,
        group_key=OpRegistryGroupKey.CONV_A16.value,
    )

    dwc_spec = {
        "testbench": ["conv.cpp", "conv.hpp"],
        "dataflow_script": "conv_L3_schedule.py",
        "build_script": "build_conv.py",
        "kernel_names": {"run_dwc_qdq_a16w8": get_kernel_id("run_dwc_qdq_a16w8")},
        "kernel_includes": ["super.hh", "dwc_int16x8/dwc_qdq_a16w8_wrapper.cc"],
    }
    OperatorsRegistry.add_operator(
        ["Conv_DWC"],
        dwc_spec,
        group_key=OpRegistryGroupKey.CONV_DWC_A16.value,
    )


def ceil_log2(x: int) -> int:
    """Return ceil(log2(x)) for x > 0, and 0 otherwise (including NaN).

    Integer inputs (Python ``int`` or NumPy integer scalars) are computed exactly
    with ``int.bit_length``. Going through ``math.log2`` first converts the value
    to float, which rounds large ints. For example, ``2**53 + 1`` would wrongly
    yield 53 instead of 54. Non-integer inputs fall back to ``ceil(log2(x))``.
    """
    if not x > 0:
        return 0
    if isinstance(x, (int, np.integer)):
        # For n >= 1, (n - 1).bit_length() is the smallest b with 2**b >= n.
        return (int(x) - 1).bit_length()
    return int(ceil(log2(x)))


def compute_shift_from_bin(bin_path, cin, k, bias_shift=2, b_max=8, dtype=np.uint8) -> int:
    """
    Estimate a conv's output right-shift from activations dumped to a binary file.

    The worst-case accumulator magnitude is approximated as
    ``a_max * b_max * cin * k * k``. Its bit width ``bC`` (ceil log2) gives
    ``shift_out = max(0, bC + bias_shift - 8)``.

    bin_path   : path to the raw activation dump (e.g. ofm.bin)
    cin        : number of input channels for next conv
    k          : kernel size (a square k x k kernel is assumed)
    bias_shift : extra bits added to the accumulator width (default 2)
    b_max      : max weight magnitude (default 8)
    dtype      : dtype used when writing binary (np.uint8, np.uint16, np.int32, etc.)

    Returns the non-negative integer shift.
    Raises ValueError if the file contains no elements of ``dtype``.
    """
    data = np.fromfile(bin_path, dtype=dtype)
    if data.size == 0:
        raise ValueError(
            f"compute_shift_from_bin: no {np.dtype(dtype).name} data found in {bin_path}"
        )
    # Convert to Python scalars (.item()) before any arithmetic. A NumPy scalar
    # multiplied by a Python int keeps its narrow dtype (e.g. uint8) under NumPy 2
    # promotion rules and silently wraps. Use the largest magnitude so that the
    # negative values of signed dtypes are not ignored.
    a_max = max(abs(data.max().item()), abs(data.min().item()))
    num_accum = cin * k * k
    bC = ceil_log2(a_max * b_max * num_accum)
    shift_out = max(0, bC + bias_shift - 8)
    log(f"a_max: {int(a_max)}")
    log(f"num_accum: {num_accum}")
    log(f"bC_bits: {bC}")
    log(f"shift_out: {shift_out}")
    return shift_out


@dataclass(frozen=True)
class ConvOp(BaseOp):
    """Frozen description of one Conv or depthwise-Conv (DWC) node.

    Adds the Conv-specific fields below to those inherited from BaseOp. The
    inherited fields include activation/output dims, L2/L3 allocations, dtypes
    and signs, weight formatting and padding info, as passed by
    ConvBuild._parse_from_dict.
    """
    # Tiling/mapping selection
    # Index into the list of candidate mappings produced by the conv/DWC tiler.
    MappingRank: int

    # Op hyper-params
    # Kernel (Ky, Kx), stride (Sy, Sx) and padding (Py, Px) along Y and X.
    Ky: int
    Kx: int
    Sy: int
    Sx: int
    Py: int
    Px: int

    # Bias bit width (DTYPE_C0 for QDQ ops); part of the LinearOpType lookup key.
    dtype_Bias: int

    # Linear-op flavor selected from the dtype combination and the group count.
    linear_op_type: LinearOpType

    # 1 when the op name contains "_qdq_", else 0; emitted as QDQ in the config.
    qdq: int

    # Activation mode and compute policy
    # Debug flag taken from the node dict.
    debug_mode: int = 0
    # Forwarded to ConvShape and emitted as COEFF_VECTOR.
    vector_coeff: int = 0
    # Activation mode; emitted as ACT_MODE.
    act_type: ActMode = ActMode(0)
    # Forwarded to ConvShape and the conv/DWC tilers.
    # NOTE(review): ConvBuild._parse_from_dict defaults this to True when the key is
    # absent, unlike the False default here. Confirm which default is intended.
    enable_over_compute: bool = False

    # Preproc knobs
    # Output and bias shifts, emitted as SHIFT_OUT and BIAS_SHIFT.
    out_shift: int = 0
    bias_shift: int = 0

    # NOTE: A regular Conv has group == 1. group > 1 is treated as depthwise (DWC),
    # which requires group == Cout.
    group: int = 1


class ConvBuild(OpBuild):
    """Conv Build Interface for build.py.

    Implements the OpBuild hooks for Conv and depthwise Conv (DWC):
      * parse a node dict into a ConvOp;
      * derive the padded ConvShape;
      * pick a tiling mapping;
      * run the L2/L3 dataflow schedulers;
      * write the ``conv_cfg.json`` configuration.
    """

    # (dtype_A, dtype_W, dtype_O, dtype_Bias) -> LinearOpType for regular convs (group == 1).
    CONV_LINEAR_OP_TYPE_MAP = {
        (8, 8, 8, 32): LinearOpType.conv_A8W8_qdq,
        (16, 8, 16, 32): LinearOpType.conv_A16W8_qdq,
        (8, 8, 8, 0): LinearOpType.conv_A8W8_qdq,
        (16, 8, 16, 0): LinearOpType.conv_A16W8_qdq,
        (8, 8, 8, 16): LinearOpType.conv_A8W8_noqdq,
    }

    # Same key layout, for depthwise convs (group > 1).
    DWC_LINEAR_OP_TYPE_MAP = {
        (8, 8, 8, 32): LinearOpType.dwc_A8W8_qdq,
        (16, 8, 16, 32): LinearOpType.dwc_A16W8_qdq,
        (8, 8, 8, 0): LinearOpType.dwc_A8W8_qdq,
        (16, 8, 16, 0): LinearOpType.dwc_A16W8_qdq,
        (8, 8, 8, 16): LinearOpType.dwc_A8W8_noqdq,
    }

    def _get_linear_op_type(self, dtype_A: int, dtype_W: int, dtype_O: int, dtype_Bias: int, group: int) -> LinearOpType:
        """Determine LinearOpType based on dtype combination and group.

        group == 1 uses the regular-conv table; any other group uses the DWC table.
        A regular conv can therefore never resolve to a DWC op type.

        Raises:
            ValueError: the dtype combination is not in the selected table.
        """
        dtype_combo = (dtype_A, dtype_W, dtype_O, dtype_Bias)
        op_type_map = self.CONV_LINEAR_OP_TYPE_MAP if group == 1 else self.DWC_LINEAR_OP_TYPE_MAP
        if dtype_combo not in op_type_map:
            raise ValueError(f"Unsupported combination of dtypes conv: A={dtype_A}, W={dtype_W}, O={dtype_O}, Bias={dtype_Bias}, group={group}")
        return op_type_map[dtype_combo]

    def default_kernel_names(self) -> Dict[str, int]:
        """Return the default kernel-name -> kernel-id mapping (a8w8 no-QDQ conv)."""
        return {"run_conv_noqdq_a8w8": get_kernel_id("run_conv_noqdq_a8w8")}

    def default_kernel_includes(self) -> list[str]:
        """Return the default kernel source includes (a8w8 no-QDQ conv)."""
        return ["super.hh", "conv/conv_noqdq_a8w8_wrapper.cc"]

    def op_type(self):
        """Return the BaseOp subclass produced by this builder."""
        return ConvOp

    def _parse_from_dict(
        self, data: dict, shim_prm_offset: int, shim_wgt_offset: int,
        read_bins: ReadBins, read_model_data: bool, model_data_path: str
    ) -> ConvOp:
        """Build a ConvOp from a node description dict.

        Args:
            data: the node dict.
                Required keys: "input", "output", "input_addr", "output_addr",
                "wgt_addr", "prm_addr", "op", "L3" (with "ifm"/"ofm"),
                "in_dtype_A", "in_dtype_B" and "out_dtype_Y".
                "in_dtype_Bias" is also required when the op name lacks "_qdq_".
                Other keys are optional and fall back to the defaults used below.
            shim_prm_offset: passed to L3Alloc as ``[3, shim_prm_offset]``.
            shim_wgt_offset: passed to L3Alloc as ``[2, shim_wgt_offset]``.
            read_bins: when ``read_bins.read_ifm`` is non-zero, out_shift is derived
                from ``../intermediate_bins/ifm1.bin`` instead of the node dict.
            read_model_data, model_data_path: forwarded to WGTFormatting.

        Returns:
            The populated, frozen ConvOp.

        Raises:
            AssertionError: group > 1 but group != Cout (when asserts are enabled).
            ValueError: unsupported dtype/group combination.
        """
        Ni, Yi, Xi, Ci = normalize_shape(data.get("input"))
        No, Yo, Xo, Co = normalize_shape(data.get("output"))
        Ky, Kx = data.get("kernel", [1, 1])
        Sy, Sx = data.get("stride", [1, 1])
        Py, Px = data.get("pad", [0, 0])
        group = data.get("group", 1)
        if group > 1:
            # Grouped conv is only supported as depthwise: one group per output channel.
            assert group == Co, \
                f"When group > 1 it is infered as a DWC OP and group == Cout but recieved group={group}, Cout={Co}"

        ifm_tile, ifm_off = self._get_tile_offset(data["input_addr"])
        ofm_tile, ofm_off = self._get_tile_offset(data["output_addr"])

        L2_alloc = L2Alloc(
            (ifm_tile, ifm_off),
            (ofm_tile, ofm_off),
            self._get_wgt_addr(data["wgt_addr"]),
            self._get_prm_addr(data["prm_addr"]),
            data.get("load_input_from_ddr", True),
            data.get("store_output_to_ddr", True),
            data.get("enable_L2_fusion", False),
        )

        op_name = data["op"]

        L3_alloc = L3Alloc(
            data["L3"]["ifm"],
            data["L3"]["ofm"],
            [2, shim_wgt_offset],
            [3, shim_prm_offset],
        )

        # Parse input activation dtype and sign
        dtype_A, sign_A = dtype_info(data["in_dtype_A"])

        # Parse weight dtype and sign
        dtype_W, sign_W = dtype_info(data["in_dtype_B"])

        # Parse output dtype and sign
        dtype_O, sign_O = dtype_info(data["out_dtype_Y"])

        # Parse bias dtype and determine QDQ flag. QDQ ops carry no explicit bias
        # dtype and use DTYPE_C0 instead.
        qdq = 1
        if "_qdq_" in op_name:
            dtype_Bias = DTYPE_C0
        else:
            dtype_Bias, _ = dtype_info(data["in_dtype_Bias"])
            qdq = 0

        # Determine linear operation type
        linear_op_type = self._get_linear_op_type(dtype_A, dtype_W, dtype_O, dtype_Bias, group)

        # Parse other parameters; shifts are normalized to int like every other knob.
        bias_shift = int(data.get("bias_shift", 0))
        if read_bins.read_ifm == 0:
            out_shift = int(data.get("out_shift", 8))
        else:
            # Derive the shift from the dumped input activations.
            # NOTE(review): compute_shift_from_bin assumes a square Ky x Ky kernel,
            # so Kx is ignored for non-square kernels. Confirm this is acceptable.
            out_shift = compute_shift_from_bin("../intermediate_bins/ifm1.bin", Ci, Ky, bias_shift)
        # dataflow_type is 0 when L2 fusion is enabled and 1 otherwise.
        dataflow_type = 0 if data.get("enable_L2_fusion", False) else 1
        debug_mode = data.get("debug_mode", 0)

        wgt_fmt = WGTFormatting(
            node_name=data.get("name") or data.get("op"),
            model_data_path=model_data_path,
            read_model_data=read_model_data,
            dtype_act=data["in_dtype_A"],
            dtype_wgt=data["in_dtype_B"],
            dtype_bias=data.get("in_dtype_Bias", "0"),
            dtype_ofm=data["out_dtype_Y"],
        )

        pad_value, is_dma_pad = self._get_dma_pad(data, read_model_data)

        return ConvOp(
            Ni=Ni, Yi=Yi, Xi=Xi, Ci=Ci, No=No, Yo=Yo, Xo=Xo, Co=Co,
            L2=L2_alloc, L3=L3_alloc,
            dataflow_type=dataflow_type,
            read_bins=read_bins,
            sign_A=sign_A,
            sign_W=sign_W,
            sign_O=sign_O,
            dtype_A=dtype_A, dtype_W=dtype_W, dtype_O=dtype_O,
            dtype_Bias=dtype_Bias,
            linear_op_type=linear_op_type,
            qdq=qdq,
            wgt_fmt=wgt_fmt, pad_value=pad_value, is_dma_pad=is_dma_pad,
            MappingRank=int(data.get("MappingRank", 0)),
            Ky=Ky, Kx=Kx, Sy=Sy, Sx=Sx, Py=Py, Px=Px,
            act_type=ActMode(int(data.get("act_type", 0))),
            enable_over_compute=bool(data.get("enable_over_compute", True)),
            vector_coeff=int(data.get("vector_coeff", 0)), debug_mode=debug_mode,
            out_shift=out_shift, bias_shift=bias_shift,
            group=group,
        )

    def shape(self, op_class: ConvOp) -> Any:
        """Return the ConvShape for op_class with channel counts padded for the kernel.

        Cout is padded with ``iceil(Co, 64)``. Cin is padded the same way, except
        for the a8w8 no-QDQ conv with Cin <= 8, which uses Cin = 8.
        """
        if op_class.Ci <= 8 and op_class.linear_op_type == LinearOpType.conv_A8W8_noqdq:
            # NOTE: The kernel can do Ci = 8 with Kx folding into Cin on the weights formatting
            Ci = 8
        else:
            Ci = iceil(op_class.Ci, 64)
        dims_shape = ConvShape(
            ifm=(op_class.Yi, op_class.Xi, Ci),
            ofm=(op_class.Yo, op_class.Xo, iceil(op_class.Co, 64)),
            kernel=(op_class.Ky, op_class.Kx),
            stride=(op_class.Sy, op_class.Sx),
            padding=(op_class.Py, op_class.Px),
            vector_coeff=op_class.vector_coeff,
            ifm_bits=op_class.dtype_A,
            ofm_bits=op_class.dtype_O,
            linear_op_type=op_class.linear_op_type,
            enable_over_compute=op_class.enable_over_compute,
            wgt_bits=op_class.dtype_W,
            bias_bits=op_class.dtype_Bias,
            sign_A=op_class.sign_A,
            sign_W=op_class.sign_W,
            sign_O=op_class.sign_O,
            group=op_class.group,
        )
        log(f"dims_shape: {dims_shape}")
        return dims_shape

    def tiler(self, dims_shape: ConvShape, op_class: ConvOp) -> ScheduleInputs:
        """Pick a tiling mapping and bundle everything the schedulers need.

        The conv tiler is used for group == 1 and the DWC tiler otherwise. The
        mapping at index ``op_class.MappingRank`` is selected.

        Raises:
            IndexError: MappingRank does not index the generated mappings.
        """
        if op_class.group == 1:
            mappings = generate_conv_mappings(dims_shape, op_class.enable_over_compute)
        else:
            mappings = generate_dwc_mappings(dims_shape, op_class.enable_over_compute)
        try:
            mapping = mappings[op_class.MappingRank]
        except IndexError as err:
            raise IndexError(
                f"MappingRank={op_class.MappingRank} is out of range: "
                f"the tiler produced {len(mappings)} mapping(s)"
            ) from err
        return ScheduleInputs(dims_shape, mapping,
                              op_class.dataflow_type,
                              op_class.L2, op_class.L3,
                              DmaPaddingMap(op_class.pad_value, op_class.dtype_A, op_class.is_dma_pad))

    def L2_schedule(self, schedule_input: ScheduleInputs):
        """Compile the L2 dataflow; DWC (group > 1) is not supported and raises ValueError."""
        if schedule_input.shape.group == 1:
            return compile_L2_dataflow(schedule_input)
        raise ValueError("DWC L2 scheduling is not supported")

    def L3_schedule(self, schedule_input: ScheduleInputs):
        """Compile the L3 dataflow with the conv (group == 1) or DWC scheduler."""
        if schedule_input.shape.group == 1:
            return compile_L3_dataflow(schedule_input)
        return compile_DWC_L3_dataflow(schedule_input)

    def preproc(self, schedule_input: ScheduleInputs, op_class: ConvOp):
        """Write ``conv_cfg.json`` describing this conv.

        The config holds:
          * the padded shape from schedule_input, and the original unpadded dims
            from op_class (the *_ORIG keys);
          * the chosen subvolumes and output-channel split;
          * the quantization knobs and weight-formatting metadata.
        """
        shape: ConvShape = schedule_input.shape
        mapping: ConvMapping = schedule_input.mapping
        back_end: BackEnd = schedule_input.backend
        wgt_fmt: WGTFormatting = op_class.wgt_fmt
        # ASM_MODE is 1 for every backend other than Adf.
        asm_mode = int(back_end != BackEnd.Adf)
        Yi, Xi, Ci = shape.ifm
        Yo, Xo, Co = shape.ofm
        Ky, Kx = shape.kernel
        Sy, Sx = shape.stride
        Py, Px = shape.padding
        Yis, Xis, Cis = mapping.ifm_subv
        Yos, Xos, Cos = mapping.ofm_subv
        _, _, _, C_split = mapping.spatial_split
        # NOTE(review): ACT_MODE stores the ActMode member itself. This assumes
        # save_cfg_json / ActMode serialize it as an int; confirm.
        cfg = {
            "C_IN": Ci,
            "Y_IN": Yi,
            "X_IN": Xi,
            "C_OUT": Co,
            "Y_OUT": Yo,
            "X_OUT": Xo,
            "C_IN_ORIG": op_class.Ci,
            "Y_IN_ORIG": op_class.Yi,
            "X_IN_ORIG": op_class.Xi,
            "C_OUT_ORIG": op_class.Co,
            "Y_OUT_ORIG": op_class.Yo,
            "X_OUT_ORIG": op_class.Xo,
            "KERNEL_Y": Ky,
            "KERNEL_X": Kx,
            "STRIDE_Y": Sy,
            "STRIDE_X": Sx,
            "CIS": Cis,
            "YIS": Yis,
            "XIS": Xis,
            "COS": Cos,
            "YOS": Yos,
            "XOS": Xos,
            "PAD_Y": Py,
            "PAD_X": Px,
            "C_OUT_SPLIT": C_split,
            "SHIFT_OUT": op_class.out_shift,
            "BIAS_SHIFT": op_class.bias_shift,
            "SIGN_ACT": op_class.sign_A,
            "SIGN_WGT": op_class.sign_W,
            "SIGN_OUT": op_class.sign_O,
            "COEFF_VECTOR": op_class.vector_coeff,
            "ACT_MODE": op_class.act_type,
            "QDQ": op_class.qdq,
            "ASM_MODE": asm_mode,
            "READ_IFM": op_class.read_bins.read_ifm,
            "READ_WGT": op_class.read_bins.read_wgt,
            "NODE_NAME": wgt_fmt.node_name,
            "MD_PATH": wgt_fmt.model_data_path,
            "READ_MD": int(wgt_fmt.read_model_data),
            "DTYPE_ACT": wgt_fmt.dtype_act,
            "DTYPE_WGT": wgt_fmt.dtype_wgt,
            "DTYPE_BIAS": wgt_fmt.dtype_bias,
            "DTYPE_OFM": wgt_fmt.dtype_ofm,
            "DEBUG": is_log_enabled(),
            "GROUP": op_class.group,
        }
        save_cfg_json(cfg, "conv_cfg.json")


def get_op():
    """Entry point used by build.py: return a new ConvBuild instance."""
    conv_builder = ConvBuild()
    return conv_builder
