"""Build script for the GAP (global average pool) operator."""

from __future__ import annotations

from dataclasses import dataclass
from typing import List, Dict

from utils.utils_common import L2Alloc, ReadBins
from scheduler.common import (
    L3Alloc,
)
from scheduler.gap.gap_comman import GAPDims
from scheduler.gap.gap_l2 import compile_gap_dataflow_l2
from scheduler.gap.gap_l3 import compile_gap_dataflow_l3
from dmacompiler import BackEnd, DevGen, DmaPaddingMap, set_dev_gen

from buildscripts.common import (
    OperatorsRegistry, save_cfg_json,
    normalize_shape, get_kernel_id,
    BaseOp, OpBuild, ScheduleInputs,
    WGTFormatting, bits_to_bytes,
    bytes_to_bits,
    )

# Select the AIE4 device generation in dmacompiler. This runs at import time,
# so importing this build script is enough to set it (presumably process-wide
# state read by the GAP L2/L3 schedulers -- confirm).
set_dev_gen(DevGen.Aie4)


def register() -> None:
    """Register the ``gap`` operator (global average pool, int8x8 kernel).

    Records the testbench sources, the dataflow and build scripts, the kernel
    entry point with its id, and the kernel include files under the ``gap``
    key of ``OperatorsRegistry``.
    """
    kernel_name = "run_globalavgpool_int8x8"
    operator_spec = {
        "testbench": ["gap.cpp", "gap.hpp"],
        "dataflow_script": "gap.py",
        "build_script": "build_gap.py",
        "kernel_names": {kernel_name: get_kernel_id(kernel_name)},
        "kernel_includes": ["super.hh", "gap/globalavgpool_int8x8_wrapper.cc"],
    }
    OperatorsRegistry.add_operator("gap", operator_spec)


@dataclass(frozen=True)
class GAPOp(BaseOp):
    """
    Immutable description of a GAP (global average pool) op.

    Declares no fields of its own: every attribute (shapes, L2/L3
    allocations, dtypes, sign flags, padding, ...) is inherited from
    ``BaseOp``. The distinct type lets the build flow identify GAP ops.
    """


class GAPBuild(OpBuild):
    """GAP (global average pool) build interface for build.py.

    Parses an operator description into a :class:`GAPOp`, derives the
    :class:`GAPDims` used for scheduling, dispatches to the GAP L2/L3
    dataflow compilers and writes the kernel configuration JSON.
    """

    # Kernel entry point and include list, kept in one place so the two
    # ``default_*`` accessors below cannot drift apart.
    _KERNEL_NAME = "run_globalavgpool_int8x8"
    _KERNEL_INCLUDES = ("super.hh", "gap/globalavgpool_int8x8_wrapper.cc")

    def default_kernel_names(self) -> Dict[str, int]:
        """Map the GAP kernel entry-point name to its kernel id."""
        return {self._KERNEL_NAME: get_kernel_id(self._KERNEL_NAME)}

    def default_kernel_includes(self) -> List[str]:
        """Return the kernel source files to include (a new list per call)."""
        return list(self._KERNEL_INCLUDES)

    def op_type(self):
        """Return the op dataclass produced by this builder."""
        return GAPOp

    def _parse_from_dict(self, data: dict,
                         shim_prm_offset: int,
                         shim_wgt_offset: int,
                         read_bins: ReadBins,
                         read_model_data: bool,
                         model_data_path: str) -> GAPOp:
        """Build a :class:`GAPOp` from one operator entry of the model description.

        Args:
            data: Operator description. Required keys: ``input``, ``output``,
                ``input_addr``, ``output_addr``, ``wgt_addr``, ``prm_addr``.
                Optional keys: ``load_input_from_ddr`` and ``store_output_to_ddr``
                (both default True), ``enable_L2_fusion`` (default False), and
                ``L3`` with optional ``ifm``/``ofm`` entries.
                ``name`` (falling back to ``op``) is used as the node name.
                ``_get_dma_pad`` may read further keys.
            shim_prm_offset: Offset stored in the L3 ``prm`` entry (buffer index 3).
            shim_wgt_offset: Offset stored in the L3 ``wgt`` entry (buffer index 2).
            read_bins: Forwarded unchanged into the op.
            read_model_data: Forwarded to ``WGTFormatting`` and ``_get_dma_pad``.
            model_data_path: Forwarded to ``WGTFormatting``.

        Returns:
            The populated :class:`GAPOp`.
        """
        # NOTE(review): the input shape is unpacked as (N, C, Y, X) but the
        # output shape as (N, Y, X, C). Confirm this matches the layout that
        # ``normalize_shape`` returns for each tensor.
        Ni, Ci, Yi, Xi = normalize_shape(data.get("input"))
        No, Yo, Xo, Co = normalize_shape(data.get("output"))

        # Read once so the L2 fusion flag and the dataflow type cannot disagree.
        enable_l2_fusion = bool(data.get("enable_L2_fusion", False))
        # An explicit ``"L3": null`` is treated the same as a missing section.
        l3_cfg = data.get("L3") or {}

        # L2 addresses (ifm + ofm)
        ifm_tile, ifm_off = self._get_tile_offset(data["input_addr"])
        ofm_tile, ofm_off = self._get_tile_offset(data["output_addr"])

        L2_alloc = L2Alloc(
            ifm_L2_loc=(ifm_tile, ifm_off),
            ofm_L2_loc=(ofm_tile, ofm_off),
            wgt_l2_loc=self._get_wgt_addr(data["wgt_addr"]),
            prm_l2_loc=self._get_prm_addr(data["prm_addr"]),
            enable_ifm_fill=bool(data.get("load_input_from_ddr", True)),
            enable_ofm_spill=bool(data.get("store_output_to_ddr", True)),
            enable_L2_fusion=enable_l2_fusion,
        )

        L3_alloc = L3Alloc(
            ifm=l3_cfg.get("ifm", [1, 0]),
            ofm=l3_cfg.get("ofm", [0, 0]),
            wgt=[2, shim_wgt_offset],
            prm=[3, shim_prm_offset],
        )

        # 0 when L2 fusion is enabled, 1 otherwise.
        dataflow_type = 0 if enable_l2_fusion else 1
        wgt_fmt = WGTFormatting(node_name=data.get("name") or data.get("op"),
                                model_data_path=model_data_path,
                                read_model_data=read_model_data)

        # (bit width, sign flag) for activations, weights and outputs; the
        # widths are converted to bytes when stored on the op.
        dtype_A, sign_A = 8, 1
        dtype_W, sign_W = 8, 1
        dtype_O, sign_O = 8, 1

        pad_value, is_dma_pad = self._get_dma_pad(data, read_model_data)

        return GAPOp(
            # base
            Ni=Ni,
            Yi=Yi,
            Xi=Xi,
            Ci=Ci,
            No=No,
            Yo=Yo,
            Xo=Xo,
            Co=Co,
            L2=L2_alloc,
            L3=L3_alloc,
            dataflow_type=dataflow_type,
            read_bins=read_bins,
            sign_A=sign_A,
            sign_W=sign_W,
            sign_O=sign_O,
            dtype_A=bits_to_bytes(dtype_A),
            dtype_W=bits_to_bytes(dtype_W),
            dtype_O=bits_to_bytes(dtype_O),
            wgt_fmt=wgt_fmt,
            pad_value=pad_value,
            is_dma_pad=is_dma_pad,
        )

    def shape(self, op_class: GAPOp) -> GAPDims:
        """Build the :class:`GAPDims` for ``op_class``.

        The full tensor dims, activation/output bit widths and sign flags come
        from the op. The parameter bit width, AIE array size and subvolume
        sizes are fixed constants.
        """
        # Hardcoded for L2 Fusion.
        # NOTE(review): the subvolumes (7x7x512 in, 1x1x512 out) and the 4x3
        # AIE array are constants, not derived from op_class. Confirm they hold
        # for every GAP shape routed through this builder.
        prm_bits, bits_per_byte = 8, 8
        aie_rows, aie_cols = 4, 3
        Yis, Xis, Cis = 7, 7, 512
        Yos, Xos, Cos = 1, 1, 512

        return GAPDims(
            Yi=op_class.Yi,
            Xi=op_class.Xi,
            Ci=op_class.Ci,
            Yo=op_class.Yo,
            Xo=op_class.Xo,
            Co=op_class.Co,
            act_bits=bytes_to_bits(op_class.dtype_A),
            out_bits=bytes_to_bits(op_class.dtype_O),
            sign_act=op_class.sign_A,
            sign_out=op_class.sign_O,
            prm_bits=prm_bits,
            bits_per_byte=bits_per_byte,
            aie_rows=aie_rows,
            aie_cols=aie_cols,
            Yis=Yis,
            Xis=Xis,
            Cis=Cis,
            Yos=Yos,
            Xos=Xos,
            Cos=Cos,
            fusion_param=op_class.L2,
        )

    def tiler(self, dims_shape: GAPDims, op_class: GAPOp) -> ScheduleInputs:
        """Bundle dims, dataflow type, L2/L3 allocations and DMA padding for scheduling."""
        # ``dims_shape`` fills both of the first two ScheduleInputs slots;
        # presumably GAP uses one dims object for both roles -- confirm.
        return ScheduleInputs(
            dims_shape,
            dims_shape,
            op_class.dataflow_type,
            op_class.L2,
            op_class.L3,
            DmaPaddingMap(op_class.pad_value, bytes_to_bits(op_class.dtype_A), op_class.is_dma_pad),
        )

    def L2_schedule(self, schedule_input: ScheduleInputs):
        """Compile the GAP dataflow at L2 level."""
        return compile_gap_dataflow_l2(schedule_input)

    def L3_schedule(self, schedule_input: ScheduleInputs):
        """Compile the GAP dataflow at L3 level."""
        return compile_gap_dataflow_l3(schedule_input)

    def preproc(self, schedule_input: ScheduleInputs, op_class: GAPOp) -> None:
        """Write the GAP kernel/testbench configuration to ``gap_cfg.json``.

        ``ASM_MODE`` is 1 for every backend other than ``BackEnd.Adf``.
        ``SHIFT`` and ``SCALE`` are read from the dims' ``shift_res`` and
        ``param_value`` attributes.
        """
        dims: GAPDims = schedule_input.shape
        read_bins: ReadBins = op_class.read_bins
        back_end: BackEnd = schedule_input.backend
        asm_mode = int(back_end != BackEnd.Adf)
        cfg = {
            "Y_IN": dims.Yi,
            "X_IN": dims.Xi,
            "C_IN": dims.Ci,
            "Y_OUT": dims.Yo,
            "X_OUT": dims.Xo,
            "C_OUT": dims.Co,
            "ACT_BITS": dims.act_bits,
            "OUT_BITS": dims.out_bits,
            "SIGN_ACT": dims.sign_act,
            "SIGN_OUT": dims.sign_out,
            "PRM_BITS": dims.prm_bits,
            "BITS_PER_BYTE": dims.bits_per_byte,
            "AIE_ROWS": dims.aie_rows,
            "AIE_COLS": dims.aie_cols,
            "SHIFT": dims.shift_res,
            "SCALE": dims.param_value,
            "ASM_MODE": asm_mode,
            "READ_IFM": read_bins.read_ifm,
            "READ_WGT": read_bins.read_wgt,
        }
        save_cfg_json(cfg, "gap_cfg.json")


def get_op():
    """Create and return a new GAPBuild; this is the hook build.py calls."""
    builder = GAPBuild()
    return builder
