"""Build script for Uniop operator"""

from __future__ import annotations
from dataclasses import dataclass
import os
from typing import List, Dict

from dmacompiler import DevGen, set_dev_gen, DmaPaddingMap

from kerneltest.uniop.uniop_1x1_dataflow import compile_uniop_1x1_dataflow

from scheduler.common import L3Alloc
from scheduler.uniop.uniop_l3 import compile_uniop_3x4_dataflow
from scheduler.uniop.uniop_common import UnaryShape, op_mapping, map_op_name, waic_mapping
from scheduler.uniop.gen_io_data import ConfigParam, generate_and_save_io_pairs

from tiler.uniop_tiler import get_uniop_mappings
from utils.utils_common import is_log_enabled, log, ReadBins
from buildscripts.common import (
    OperatorsRegistry, save_cfg_json,
    normalize_shape, get_kernel_id,
    ScheduleInputs, BaseOp, OpBuild,
    WGTFormatting, dtype_info, bytes_to_bits,
    OpRegistryGroupKey
    )
from utils.build_utils import is_qdq_fp16


# Absolute path of the directory that contains this build script.
CURRDIR = os.path.dirname(os.path.abspath(__file__))

# Import-time side effect: tell dmacompiler to use the Aie4 device generation.
set_dev_gen(DevGen.Aie4)

# Operator names of the standalone dequantize / quantize variants. register()
# registers them, and UnaryBuild._parse_from_dict uses membership in these
# lists to force the dequant/quant enable flags.
dq_waic_names = ["Dequant_uint16xbfloat16", "Dequant_uint16xfloat32"]
q_waic_names = ["Quant_bfloat16xuint16", "Quant_float32xuint16"]


def register() -> None:
    """Register every uniop operator variant with OperatorsRegistry.

    All variants share the ``uniop.cpp`` testbench, the ``uniop_l3.py``
    dataflow script and the ``build_uniop.py`` build script. They differ in
    the kernel wrapper, the kernel include list and, for some variants, the
    registry group key.
    """
    no_group = object()  # marker: call add_operator without a group_key
    qdq_headers = ["super.hh", "q/q.hpp", "dq/dq.hpp"]

    # (operator names, kernel wrapper, kernel includes, group key or no_group)
    variants = (
        (
            ["LpNormalization_qdq_uint16xuint16"],
            "run_l2norm_fp16x16",
            qdq_headers + ["l2norm_fp16x16/l2norm_fp16x16_wrapper.cc"],
            OpRegistryGroupKey.LP_NORM.value,
        ),
        (
            ["Softmax_qdq_uint16xuint16"],
            "run_softmax_fp16x16",
            qdq_headers + ["softmax_fp16x16/softmax_fp16x16_wrapper.cc"],
            no_group,
        ),
        (
            # NOTE: if silu/gelu move to the poly-based implementation, this entry
            # additionally needs the "run_silu" kernel name and the
            # "SiLU_exp2/SiLU_exp2_wrapper.cc" include (placed before the
            # linear_approx include).
            ["Silu_qdq_uint16xuint16", "PWLA_qdq_uint16xuint16", "Gelu_qdq_uint16xuint16", "Swish_qdq_uint16xuint16",
             "Tanh_qdq_uint16xuint16", "Sigmoid_qdq_uint16xuint16", "Elu_qdq_uint16xuint16"],
            "run_lut_fp16x16",
            qdq_headers + ["linear_approx_bf16/linear_approx_bf16_wrapper.cc"],
            no_group,
        ),
        (
            dq_waic_names,
            "run_dequant",
            ["super.hh", "dq/dq.hpp", "dq/dq_wrapper.cc"],
            OpRegistryGroupKey.DQ.value,
        ),
        (
            q_waic_names,
            "run_quant",
            ["super.hh", "q/q.hpp", "q/q_wrapper.cc"],
            OpRegistryGroupKey.Q.value,
        ),
        (
            ["Copy_uint16xuint16"],
            "run_copy_fp16x16",
            ["super.hh", "softmax_fp16x16/copy_fp16x16_wrapper.cc"],
            no_group,
        ),
        (
            ["LayerNorm_qdq_uint16xuint16", "LayerNormalization_qdq_uint16xuint8xuint16"],
            "run_layernorm_fp16x16",
            ["super.hh", "layer_norm_fp16x16/layer_norm_fp16x16_wrapper.cc"],
            no_group,
        ),
        (
            ["GroupNorm_qdq_uint16xuint16", "GroupNorm_qdq_uint16xuint16xuint16", "GroupNorm_qdq_uint16xint16xuint16",
             "GroupNormalization_qdq_uint16xuint16xuint16", "GroupNormalization_qdq_uint16xint16xuint16"],
            "run_group_norm_qdq",
            ["super.hh", "groupnorm/norm.cc"],
            OpRegistryGroupKey.GP_NORM.value,
        ),
    )

    for op_names, kernel, includes, group in variants:
        # A fresh spec dict per registration, with the same key order for all.
        spec = {
            "testbench": ["uniop.cpp"],
            "dataflow_script": "uniop_l3.py",
            "build_script": "build_uniop.py",
            "kernel_names": {kernel: get_kernel_id(kernel)},
            "kernel_includes": includes,
        }
        if group is no_group:
            OperatorsRegistry.add_operator(op_names, spec)
        else:
            OperatorsRegistry.add_operator(op_names, spec, group_key=group)


@dataclass(frozen=True)
class UnaryOp(BaseOp):
    '''Immutable uniop description built by UnaryBuild._parse_from_dict.

    Adds the uniop-specific settings on top of BaseOp. UnaryBuild.preproc
    writes most of them into the testbench config (uniop_cfg.json).
    '''
    # External op string (e.g., "Softmax_qdq_uint16xuint16")
    op_name: str
    # Internal functional tag ("softmax", "l2norm", "silu", "dequant", "quant", "copy")
    function: str
    # When True, preproc generates and saves input/output test data.
    gen_io: bool

    # Overlay choice: True makes tiler() skip the mapping step (1x1 overlay path).
    overlay_is_1x1: bool = False

    # dequant/quant parameters and enable flags, forwarded to the IO generator
    # config and to the testbench config by preproc.
    dequant_zero_point: float = 0.0
    dequant_scale: float = 0.0
    dequant_enable: bool = False
    quant_zero_point: float = 0.0
    quant_scale: float = 0.0
    quant_enable: bool = False
    # Written to the testbench config as NLF_ENABLE.
    nlf_enable: bool = True
    # Written to the testbench config as HAS_MASK.
    hasMask: bool = False
    # Written to the testbench config as PWLA_ALPHA. Note that _parse_from_dict
    # always passes a value explicitly (1.7 unless the op attributes supply "Alpha").
    pwla_alpha_val: float = 1.0


class UnaryBuild(OpBuild):
    """Uniop Build Interface for build.py"""
    def __init__(self):
        self.norm_ops_with_wgt = {"LayerNorm", "GroupNorm"}

    def default_kernel_names(self) -> Dict[str, int]:
        return {}

    def default_kernel_includes(self) -> List[str]:
        return []

    def op_type(self):
        """Return the op description class used by this build interface (UnaryOp)."""
        return UnaryOp

    def kernel_wrapper_name(self, function: str) -> str:
        """Get name of kernel wrapper per op"""
        op_key = self._op_key_for_function(function)
        return list(OperatorsRegistry.get_operators(op_key)["kernel_names"].keys())[0]

    def kernel_includes(self, function: str) -> List[str]:
        """Get name of kernel include per op"""
        op_key = self._op_key_for_function(function)
        return OperatorsRegistry.get_operators(op_key)["kernel_includes"]

    def _op_key_for_function(self, function: str) -> str:
        if function in op_mapping:
            return map_op_name(function)
        raise ValueError(f"Unknown function: {function}")

    def _parse_from_dict(self, data: dict, shim_prm_offset: int, shim_wgt_offset: int,
                         read_bins: ReadBins, read_model_data: bool, model_data_path: str) -> UnaryOp:
        Ni, Yi, Xi, Ci = normalize_shape(data.get("input"))
        No, Yo, Xo, Co = normalize_shape(data.get("output"))
        _ = (Ni, No, Yo, Xo, Co)

        if Ni != 1:
            Yi = Yi * Ni
            Ni = 1

        op_name = data["op"]

        # Fall-back Alpha value for swish and elu ops is : 1.7
        pwla_alpha_val = 1.7
        if op_name == "PWLA_qdq_uint16xuint16":
            assert "attributes" in data and "pwla_type" in data["attributes"]
            true_op_type = data["attributes"]["pwla_type"][0].lower()
            op_name = map_op_name(true_op_type)
            if true_op_type in {"swish", "elu"} and "Alpha" in data["attributes"]:
                pwla_alpha_val = data["attributes"]["Alpha"][0]

        assert op_name in waic_mapping, f"Unsupported uniop: {op_name}"
        function = map_op_name(op_name)

        overlay_is_1x1 = bool(data.get("overlay_is_1x1", False))

        dequant_zero_point = float(data.get("dequant_zero_point", 5))
        dequant_scale = float(data.get("dequant_scale", 2.0))
        if op_name in dq_waic_names:
            dequant_enable = True
            quant_enable = False
        elif op_name in q_waic_names:
            dequant_enable = False
            quant_enable = True
        else:
            dequant_enable = not data["attributes"]["disable_dq0"][0]
            quant_enable = not data["attributes"]["disable_q"][0]
        quant_zero_point = float(data.get("quant_zero_point", 3))
        quant_scale = float(data.get("quant_scale", 0.000015259 * 15))

        if function == "swish":
            dequant_zero_point = float(data.get("dequant_zero_point", 6724))
            dequant_scale = float(data.get("dequant_scale", 0.0039354744))
            quant_zero_point = float(data.get("quant_zero_point", 46))
            quant_scale = float(data.get("quant_scale", 0.0035341859))

        nlf_enable = bool(data.get("nlf_enable", True))
        hasMask = bool(data.get("hasMask", False))

        L2_alloc = None

        L3_alloc = L3Alloc(
            ifm=data.get("L3", {}).get("ifm", [1, 0]),
            ofm=data.get("L3", {}).get("ofm", [0, 0]),
            wgt=[2, shim_wgt_offset],
            prm=[3, shim_prm_offset],
        )

        dtype_A = data.get("in_dtype_X") or data.get("in_dtype_x") or data.get("in_dtype_input") or data.get("in_dtype_data")
        dtype_W = "uint16"
        dtype_O = data.get("out_dtype_Y") or data.get("out_dtype_y") or data.get("out_dtype_output")

        _, sign_A = dtype_info(dtype_A)
        _, sign_W = dtype_info(dtype_W)
        _, sign_O = dtype_info(dtype_O)

        # Uniop currently on supports L3, once L2 is supported the line below can be uncommented
        # dataflow_type = 0 if data.get("enable_L2_fusion", False) else 1
        dataflow_type = 1
        wgt_fmt = WGTFormatting(
            dtype_act=dtype_A,
            dtype_ofm=dtype_O,
            node_name=data.get("name") or data.get("op"),
            model_data_path=model_data_path,
            read_model_data=read_model_data,
        )

        pad_value, is_dma_pad = self._get_dma_pad(data, read_model_data)

        if any(norm in op_name for norm in self.norm_ops_with_wgt):

            wgt_fmt.dtype_gamma = data.get("in_dtype_gamma") or data.get("in_dtype_mul_B")
            wgt_fmt.dtype_beta = data.get("in_dtype_beta") or data.get("in_dtype_add_B")
            if wgt_fmt.dtype_gamma is None:
                wgt_fmt.dtype_gamma = "Missing_Value"
            if wgt_fmt.dtype_beta is None:
                wgt_fmt.dtype_beta = "Missing_Value"

        return UnaryOp(
            Ni=Ni, Yi=Yi, Xi=Xi, Ci=Ci,
            No=No, Yo=Yo, Xo=Xo, Co=Co,
            L2=L2_alloc, L3=L3_alloc,
            dataflow_type=dataflow_type,
            read_bins=read_bins,
            sign_A=sign_A, sign_W=sign_W, sign_O=sign_O,
            dtype_A=dtype_A, dtype_W=dtype_W, dtype_O=dtype_O,
            wgt_fmt=wgt_fmt, pad_value=pad_value, is_dma_pad=is_dma_pad,
            op_name=op_name, function=function, gen_io=data.get("gen_io", wgt_fmt.read_model_data != 1),
            overlay_is_1x1=overlay_is_1x1,
            dequant_zero_point=dequant_zero_point,
            dequant_scale=dequant_scale,
            dequant_enable=dequant_enable,
            quant_zero_point=quant_zero_point,
            quant_scale=quant_scale,
            quant_enable=quant_enable,
            nlf_enable=nlf_enable,
            hasMask=hasMask,
            pwla_alpha_val=pwla_alpha_val
        )

    def shape(self, op_class: UnaryOp) -> UnaryShape:
        N = op_class.Ni
        X = op_class.Yi * op_class.Xi
        C = op_class.Ci
        ifmbytes = 4 if op_class.wgt_fmt.dtype_act in {"float32", "int32", "uint32"} else 2
        ofmbytes = 4 if op_class.wgt_fmt.dtype_ofm in {"float32", "int32", "uint32"} else 2
        ifmSign = 1 if op_class.wgt_fmt.dtype_act in {"int8", "int16", "int32"} else 0
        ofmSign = 1 if op_class.wgt_fmt.dtype_ofm in {"int8", "int16", "int32"} else 0

        return UnaryShape(op_class.function, (N, X, C), ifmbytes, ofmbytes, ifmSign, ofmSign, "N1X12C1")

    def tiler(self, dims_shape: UnaryShape, op_class: UnaryOp) -> ScheduleInputs:
        # For 1x1 overlay, we skip mapping and call the 1x1 compiler later.
        mapping = None if op_class.overlay_is_1x1 else get_uniop_mappings(dims_shape)

        return ScheduleInputs(
            dims_shape,
            mapping,
            op_class.dataflow_type,
            op_class.L2,
            op_class.L3,
            DmaPaddingMap(op_class.pad_value, bytes_to_bits(dims_shape.ifmbytes), op_class.is_dma_pad),
            kernel_names=None,
            kernel_includes=None,
        )

    def L3_schedule(self, schedule_input: ScheduleInputs):
        dims: UnaryShape = schedule_input.shape
        # Ensure kernels/includes exist (prefer from build system, else registry)
        kernel_names = schedule_input.kernel_names
        kernel_includes = schedule_input.kernel_includes
        if not kernel_names or not kernel_includes:
            function = dims.function if hasattr(dims, "function") else None
            if function:
                kernel_names = {self.kernel_wrapper_name(function): get_kernel_id(self.kernel_wrapper_name(function))}
                kernel_includes = self.kernel_includes(function)
                schedule_input.kernel_names = kernel_names
                schedule_input.kernel_includes = kernel_includes

        # Two paths: 1x1 overlay (kerneltest) vs 3x4 L3 scheduler
        if schedule_input.mapping is None:
            # 1x1 overlay: SubVolumeDim = TensorDim = (1, N*X, C)
            N, X, C = dims.N, dims.X, dims.C
            return (0, 0) if compile_uniop_1x1_dataflow(1, N * X, C) is None else (0, 0)

        # 3x4 L3 overlay
        return compile_uniop_3x4_dataflow(schedule_input)

    def L2_schedule(self, schedule_input: ScheduleInputs):
        _ = schedule_input
        raise NotImplementedError("UnaryBuild has no L2-only scheduler path.")

    def preproc(self, schedule_input: ScheduleInputs, op_class: UnaryOp):
        _ = schedule_input

        # we use the true dim to form folder name
        process_id = os.getpid()
        aie4_model_repo_root = os.environ.get("AIE4_ROOT_DIR")
        kernel_test_dir = os.path.join(aie4_model_repo_root, "kerneltest")
        test_case_data_dir = op_class.wgt_fmt.node_name.replace(".", "_").replace("/", "_").replace("\\", "_")
        test_data_dir = os.path.join(
            kernel_test_dir,
            op_class.function,
            f"test_data_{process_id}",
            test_case_data_dir,
        )
        test_data_dir = test_data_dir + "/"
        os.makedirs(test_data_dir, exist_ok=True)
        test_data_dir = os.path.join(test_data_dir, "")

        # IO generation
        cfg = ConfigParam()
        if op_class.dequant_enable:
            cfg.enable_dQ()
        if op_class.quant_enable:
            cfg.enable_Q()
        cfg.set_dQ_zero_point(op_class.dequant_zero_point)
        cfg.set_dQ_scale(op_class.dequant_scale)
        cfg.set_Q_zero_point(op_class.quant_zero_point)
        cfg.set_Q_scale(op_class.quant_scale)

        ifmbytes = 4 if op_class.wgt_fmt.dtype_act in {"float32", "int32", "uint32"} else 2
        ofmbytes = 4 if op_class.wgt_fmt.dtype_ofm in {"float32", "int32", "uint32"} else 2

        log("ifmbytes:", ifmbytes)
        log("ofmbytes:", ofmbytes)

        PadC = schedule_input.mapping.PaddedTDim[2]

        reshaped_N = schedule_input.mapping.TensorDim[0]
        reshaped_X = schedule_input.mapping.TensorDim[1]
        reshaped_C = schedule_input.mapping.TensorDim[2]

        log("Reshaped N,X,C:", (reshaped_N, reshaped_X, reshaped_C))
        if op_class.gen_io:
            generate_and_save_io_pairs(
                reshaped_N * reshaped_X,  # here this is the "reshaped" dimension
                reshaped_C,               # here this is the "reshaped" dimension
                PadC,
                config=cfg,
                test_data_dir=test_data_dir,  # we use the true dim to form folder name
                function=op_class.function,
                overlay_is_1x1=op_class.overlay_is_1x1,
                qdq_input_fp32=(ifmbytes == 4),
                qdq_output_fp32=(ofmbytes == 4),
                qdq_floating_is_fp16=(is_qdq_fp16())
            )

        # preproc cfg
        cfg = {
            "_Y": reshaped_N,
            "_X": reshaped_X,
            "_C": PadC,
            "_trueC": reshaped_C,
            "TEST_DATA_DIR": test_data_dir,
            "DEQUANT_ZERO_POINT": op_class.dequant_zero_point,
            "DEQUANT_SCALE": op_class.dequant_scale,
            "DEQUANT_ENABLE": int(op_class.dequant_enable),
            "QUANT_ZERO_POINT": op_class.quant_zero_point,
            "QUANT_SCALE": op_class.quant_scale,
            "QUANT_ENABLE": int(op_class.quant_enable),
            "NLF_ENABLE": int(op_class.nlf_enable),
            "HAS_MASK": int(op_class.hasMask),
            "DEBUG_MODE": is_log_enabled(),
            "NODE_NAME": op_class.wgt_fmt.node_name,
            "MD_PATH": op_class.wgt_fmt.model_data_path,
            "READ_MD": int(op_class.wgt_fmt.read_model_data),
            "FUNCTION": op_class.function,
            "DTYPE_ACT": op_class.wgt_fmt.dtype_act,
            "DTYPE_OUT": op_class.wgt_fmt.dtype_ofm,
            "SIGN_ACT": op_class.sign_A,
            "SIGN_OUT": op_class.sign_O,
            "PWLA_ALPHA": op_class.pwla_alpha_val,
            "gen_io": int(op_class.gen_io)
        }

        if any(norm in op_class.op_name for norm in self.norm_ops_with_wgt):
            cfg["DTYPE_GAMMA"] = op_class.wgt_fmt.dtype_gamma
            cfg["DTYPE_BETA"] = op_class.wgt_fmt.dtype_beta
        save_cfg_json(cfg, "uniop_cfg.json")


def get_op():
    """Return a new UnaryBuild; this is the uniop entry point used by build.py."""
    builder = UnaryBuild()
    return builder
