# fmt: on
from OGOAT.src.L1_fusion.py_match.helpers.batch_helper import BatchHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    Node,
)
from OGOAT.src.L1_fusion.py_match.helpers.bias_helper import BinaryOpHelper
from OGOAT.src.L1_fusion.py_match.helpers.transpose_helper import TransposeHelper
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.checkers import (
    CategoryCheck,
    opType,
)
from OGOAT.src.L1_fusion.py_match.basic.matmul import MatMul
from OGOAT.src.L1_fusion.py_match.nodes_tensors import NoMatch, Tensor


import numpy as np
from typing import Optional, Dict, Tuple
from dataclasses import dataclass
from enum import IntEnum


class MHAMode(IntEnum):
    """
    Kernel variant selected for a fused attention pattern.

    The value is returned by `Attention.select_mha` (first buffer layout that
    fits wins) and is emitted as the integer "mha_mode" attribute of the fused
    node.
    """

    # No layout fits: the pattern must not be fused.
    UNSUPPORTED = 0
    # Fused op type "MHA_2p1"; selected when only the third layout of
    # `select_mha` fits.
    TWO_P_ONE = 1
    # Fused op type "MHA_2p1"; selected when the second layout fits.
    TWO_P_ONE_MINI = 2
    # Fused op type "MHA_3p0_1col"; the Softmax x V MatMul is fused as well.
    THREE_P_ZERO_MINI = 3


# FIXME: copied from OGOAT.src.Tiler.layer, which has already been ported to C++ (so the two implementations are separate anyway)
def reshape_into_3D_tensor(shape: list[int], batch_size: int) -> list[int]:
    """
    Reshape ``shape`` into a 3D shape ``[batch_size, X, last_dim]``.

    The shape is first left-padded with 1s up to rank 4. Its outermost
    dimensions are then multiplied together, one at a time, until their
    product equals ``batch_size``. All remaining dimensions except the last
    one are folded into the middle dimension; the last dimension is kept
    intact.

    Args:
        shape: Input shape. Any sequence is accepted; it is never modified.
        batch_size: Requested outermost dimension, must be >= 1.

    Returns:
        The 3D shape as a list. The folded middle dimension is a plain Python
        int (not a NumPy scalar).

    Raises:
        RuntimeError: If folding the outer dimensions never yields exactly
            ``batch_size``, or leaves fewer than two trailing dimensions.
    """
    assert batch_size >= 1
    # list() rather than .copy() so tuples and other sequences are accepted
    new_shape = list(shape)

    # Make sure that we have at least a shape of rank 4 so we can fold the batch
    # dimension or the second to second last dimensions if needed.
    while len(new_shape) < 4:
        new_shape.insert(0, 1)

    # Fold the outermost dimensions into a single batch dimension until it
    # matches the requested batch size (or nothing is left to fold).
    outer_dim = new_shape.pop(0)
    while outer_dim != batch_size and new_shape:
        outer_dim *= new_shape.pop(0)

    new_shape.insert(0, outer_dim)
    if len(new_shape) < 3 or new_shape[0] != batch_size:
        raise RuntimeError(
            f"Could not reshape {shape} into a 3d shape with requested batch size '{batch_size}'"
        )

    # Now that the outermost dimension equals the batch size, fold the middle
    # dimensions into one and keep the last one intact. int() keeps the
    # result a list of Python ints instead of leaking a NumPy scalar.
    if len(new_shape) > 3:
        new_shape = [new_shape[0], int(np.prod(new_shape[1:-1])), new_shape[-1]]

    return new_shape


@dataclass
class Attention(Matcher, TransposeHelper, QDQHelper):
    dependencies = [MatMul()]

    # Matmul Q x K
    matmul_in: Optional[Node] = None

    # Matmul Softmax_output x  V
    matmul_out: Optional[Node] = None

    # The add that is an element wise on the two last dimension
    mask_add: Optional[Node] = None

    # The add that is a broadcast on the two last dimension
    bias_add: Optional[Node] = None

    # The softmax part of the attention pattern
    softmax: Optional[Node] = None
    sfm_quant: Optional[Node] = None
    sfm_dq: Optional[Node] = None

    # The add that needs to be hoisted out,
    # with the add to which it needs to be connected
    hoist_out: Optional[Node] = None
    connection: Optional[Node] = None

    # True if the attention pattern match with a support transpose at its K input
    has_k_transpose: bool = False

    # Number of heads matched in the attention pattern
    nb_heads: int = 1

    # Batch size of the K and V input, should be:
    # nb_heads >= nb_groups and nb_heads % nb_groups == 0
    nb_groups: int = 1

    # Batch size of the bias and mask if present
    # nb_heads >= nb_mask and nb_heads % nb_mask == 0
    # nb_heads >= nb_bias and nb_heads % nb_bias == 0
    nb_mask: int = 1
    nb_bias: int = 1

    # MHA mode detected using the QxK input shape, this will be used to derive
    # the op type of the fused node
    mha_mode: MHAMode = MHAMode.UNSUPPORTED

    def _reset(self):
        # reset state of the parent class
        super()._reset()

        # reset state of the attributes specific to the attention pattern
        self.matmul_in = None
        self.matmul_out = None
        self.softmax = None
        self.sfm_quant = None
        self.sfm_dq = None
        self.mask_add = None
        self.bias_add = None
        self.hoist_out = None
        self.connection = None
        self.has_k_transpose = False
        self.nb_heads = 1
        self.nb_groups = 1
        self.nb_mask = 1
        self.nb_bias = 1
        self.mha_mode = MHAMode.UNSUPPORTED

    def is_matmul_actxact_wo_bias(self, node: Node) -> bool:
        has_no_bias = node.get_attribute_value("bias") is None
        is_actxact = node.get_attribute_value("actxact") == 1

        return has_no_bias and is_actxact

    def get_mha_mode(self, q_input: Tensor, k_input: Tensor) -> MHAMode:
        reshaped_q = reshape_into_3D_tensor(q_input.get_shape(), self.nb_heads)
        reshaped_k = reshape_into_3D_tensor(k_input.get_shape(), self.nb_heads)
        assert len(reshaped_k) == 3 and len(reshaped_q) == 3  # sanity check

        k_dim = reshaped_q[2]  # sin_dh
        n_dim = reshaped_k[2]  # sin_kv

        has_bias = self.mask_add is not None
        has_mask = self.bias_add is not None
        mha_mode, _ = self.select_mha(
            Sin_kv=n_dim, Sin_dh=k_dim, Mask=has_mask, Bias=has_bias
        )

        if mha_mode == MHAMode.UNSUPPORTED:
            raise NoMatch("Unsupported mha mode found when computing from input shapes")

        return mha_mode

    @staticmethod
    def reshape_add_input(
        mask_shape: list[int], matmul_output_shape: list[int], num_heads: int
    ) -> list[int]:
        # Reshape the matmul output using the num_heads as batch dimension (outer most dimension)
        # After reshape the tensor will have the form: H x M x N
        matmul_out_shape_3d = reshape_into_3D_tensor(matmul_output_shape, num_heads)

        # mask should be at least a 3d tensor for running the comparison
        if len(mask_shape) < 3:
            mask_shape = reshape_into_3D_tensor(mask_shape, 1)

        # Expect the mask to be element wise or broadcastable on the matmul output.
        # We currently only support broadcast on the M and H dimension.
        # The N dimension is expected to be same on "mask" input and Matmul output.
        # Supported shapes for "mask" input are:
        # - H x M x N
        # - H x 1 x N
        # - factor_of(H) x 1 x N
        # - factor_of(H) x M x N
        if mask_shape[-1] != matmul_out_shape_3d[-1]:
            raise NoMatch(
                "Last dimension of the attention mask is expected to be equal to N"
            )

        M_dim = matmul_out_shape_3d[-2]
        while len(mask_shape) >= 3:
            # Broadcastable second dimension found
            if mask_shape[-2] == 1 or mask_shape[-2] == M_dim:
                break

            mask_shape = (
                mask_shape[:-3] + [np.prod(mask_shape[-3:-1])] + mask_shape[-1:]
            )

        if len(mask_shape) < 3:
            raise NoMatch(
                "Could not reshape mask input into a 3 tensor with M dim broadcastable"
            )

        # Now that the last and second last dimension are correct, fold the outer most dims into a single dimension
        if len(mask_shape) > 3:
            mask_shape = [np.prod(mask_shape[:-2])] + mask_shape[-2:]

        if mask_shape[0] != 1 and num_heads % mask_shape[0] != 0:
            raise NoMatch("Non broadcastable outer dimension")
        return mask_shape

    def extract_add_type(self, B_shape_3d: list[int]) -> str:
        """
        Extract the type of the Add node (different from the type computed by
        the binary op pattern). Here we are only interested about the dimensions that
        will be used in the kernel, so we disregard the first dimension (H).
        If:
         - MxN -> Element wise over the two last dim: mask
         - 1xN -> Broadcast on the second last dim: bias
        A input is the output of the QxK matmul
        B input is the input coming from outside the attention pattern
        """
        # We expect the following last two dim for:
        # - A -> MxN
        # - B MxN (mask) or 1xN (bias)
        if B_shape_3d[-2] != 1:
            return "mask"
        return "bias"

    def check_hoisting_add_validity(self, qxk_nb_heads: int) -> int:
        """
        Check that hoisting out the first add and connect it to the second input
        of the second add is valid.
        - We are expecting that a multi directional broadcast can happens on the two
          inputs of the first add after hoisting.
        - We are expecting the activations inputs type not change
        - We are expecting that the second input of the second add is an activation

        If all the expectation are met this function will return the new number of heads
        expected for the second add.
        """
        # Check that we can broadcast the new input if we hoist out add_1
        if not BinaryOpHelper.is_valid_multidirectional_broadcast(
            self.hoist_out("B").get_shape(), self.connection("B").get_shape()
        ):
            raise NoMatch(
                "Hoisting out the first add should give a valid multibroadcast Add"
            )

        # Check that the new output shape is meeting the requirement for the
        # "M" and "B" inputs of the attention pattern.
        new_hoist_add_out_shape = BinaryOpHelper.get_output_shape(
            self.hoist_out("B").get_shape(), self.connection("B").get_shape()
        )
        new_head_size = self.reshape_add_input(
            new_hoist_add_out_shape, self.connection("A").get_shape(), qxk_nb_heads
        )[0]

        if (
            self.connection("B").get_dtype() != self.hoist_out("B").get_dtype()
            or self.connection("B").get_dtype() != self.hoist_out("C").get_dtype()
        ):
            raise NoMatch(
                "Hoisting out the first add should not change the tensor data type"
            )

        # To hoist the first add out of the attention pattern we connect it
        # the second input of the second add (the one coming from outside the pattern).
        # That's why we need it be an activation input.
        if self.connection("B").check_initializer():
            raise NoMatch(
                "Second input of the second Add needs to be an activation in order to perform the required modification"
            )

        return new_head_size

    def match_add_nodes_types(
        self, first_add: Optional[Node], second_add: Optional[Node], qxk_nb_heads: int
    ) -> None:
        """
        Extract the type of the Add nodes that were matched and make sure
        that they do not have the same type.
        If that's the case we will hoist out the first add and check
        that the modification is valid.
        """
        # Get the type of the second add and save it for modify method
        second_add_type = None
        if second_add:
            B_shape_3d = Attention.reshape_add_input(
                second_add("B").get_shape(),
                second_add("A").get_shape(),
                qxk_nb_heads,
            )
            second_add_type = self.extract_add_type(B_shape_3d)

        if second_add_type == "bias":
            self.bias_add = second_add
            self.nb_bias = B_shape_3d[0]
        elif second_add_type == "mask":
            self.mask_add = second_add
            self.nb_mask = B_shape_3d[0]

        # Get the type of the first add and save it for modify method
        # If the first add has the same type as the second add, enable
        # hoisting it out before fusing.
        first_add_type = None
        if first_add:
            B_shape_3d = Attention.reshape_add_input(
                first_add("B").get_shape(),
                first_add("A").get_shape(),
                qxk_nb_heads,
            )
            first_add_type = self.extract_add_type(B_shape_3d)

        if second_add and first_add and second_add_type == first_add_type:
            self.hoist_out = first_add
            self.connection = second_add
        elif first_add_type == "bias":
            self.bias_add = first_add
            self.nb_bias = B_shape_3d[0]
        elif first_add_type == "mask":
            self.mask_add = first_add
            self.nb_mask = B_shape_3d[0]

        # If not asked to hoist out the add stop here, otherwise check that
        # hoisting is possible and what would be the new head count for the
        # second add.
        if not self.hoist_out:
            return

        new_head_size = self.check_hoisting_add_validity(qxk_nb_heads)
        if self.bias_add:
            self.nb_bias = new_head_size
        else:
            self.nb_mask = new_head_size

    def match_add_softmax_chain(self, qxk_nb_heads: int) -> None:
        """
        Match the add(s) + softmax chain and check if the first add needs to be
        hoisted out of the attention pattern.

                                          [bf16]
                                      act3 ----   act_2 -
                                               |         | [bf16]
                                               |         |
            Detect a chain of:         [bf16]  v         v
                act_1 -> DequantizeLinear -> Add_1 -> Add_2 -> Softmax -> QuantizeLinear

            And hoist out the first Add (in modify method):

                                    [bf16]
                           act_3 -----------
                                            |
                                  [bf16]    v
                           act_2 -------> Add_1 --
                                                  |
                                          [bf16]  v
                act_1 -> DequantizeLinear -----> Add_2 -> Softmax -> QuantizeLinear
        """
        # Match the two optional Adds and the required softmax
        curr_node = self.matmul_in("Y").require_node()
        first_add = None
        second_add = None

        self.sfm_dq = curr_node.require(opType.DequantizeLinear).require_node()
        curr_node = curr_node("y")
        if curr_node.check(opType.Add):
            first_add = curr_node.require_node()
            curr_node = first_add("C")

        # If the output is quantize make sure that it is followed by a dequantized with the same scale and zp
        if curr_node.check(opType.QuantizeLinear):
            if not curr_node("y").check(
                opType.DequantizeLinear
            ) or not self.check_qdq_equal_scale_zeropoint(curr_node("y"), curr_node):
                raise NoMatch("The first and second add ops needs to be bf16")
            curr_node = curr_node("y.y")

        if curr_node.check(opType.Add):
            second_add = curr_node.require_node()
            curr_node = second_add("C")

        # If the output is quantize make sure that it is followed by a dequantized with the same scale and zp
        if curr_node.check(opType.QuantizeLinear):
            if not curr_node("y").check(
                opType.DequantizeLinear
            ) or not self.check_qdq_equal_scale_zeropoint(curr_node("y"), curr_node):
                raise NoMatch("The second add and the softmax needs to be bf16")
            curr_node = curr_node("y.y")

        self.softmax = curr_node.require(opType.Softmax).require_node()
        self.sfm_quant = (
            curr_node("output").require(opType.QuantizeLinear).require_node()
        )

        # Extract the add node types and make sure that they do not have the same type
        self.match_add_nodes_types(first_add, second_add, qxk_nb_heads)

    def match_opt_k_transpose(self) -> None:
        """
        Match an optional transpose at the K input. This transpose is currently
        only supported if it is swapping the last two dimensions of the tensor.
        """
        self.has_k_transpose = self.matmul_in("B").check(opType.Transpose)
        if not self.has_k_transpose:
            return

        # check that the transpose is swapping the the last two dimension
        permutation = self.matmul_in("B").require_node().get_attributes()["perm"]
        permutation[-1], permutation[-2] = permutation[-2], permutation[-1]
        self.has_k_transpose &= permutation == sorted(permutation)

    def match(self) -> None:
        """
        Pattern:
                                  [ m ] -    [ b ] -       V -
                                         |          |         |
                                         v          v         v
            q, k -> MatMul_actxact -> [ Add ] -> [ Add] -> Softmax -> MatMul_actxact -> output
        """
        self.matmul_in = self.n.require(CategoryCheck(MatMul())).require_node()
        if not self.is_matmul_actxact_wo_bias(self.matmul_in):
            raise NoMatch("QxK Matmul should be an actxact w/o bias operation")

        q_shape = self.matmul_in("A").get_shape()
        k_shape = self.matmul_in("B").get_shape()
        qxk_nb_heads = BatchHelper.extract_matmul_batch_nb(q_shape, k_shape)

        # Match an optional transpose for the K input, with matching qdq nodes
        # and swapping the last two dimension.
        self.match_opt_k_transpose()

        # Check for the "Dequant -> [[Add ->] Add ->] Softmax -> Quant" chain after the QxK matmul
        self.match_add_softmax_chain(qxk_nb_heads)

        self.matmul_out = (
            self.sfm_quant("y").require(CategoryCheck(MatMul())).require_node()
        )
        if not self.is_matmul_actxact_wo_bias(self.matmul_out):
            raise NoMatch("SfmxV Matmul should be an actxact w/o bias operation")
        self.v_input_name = BinaryOpHelper.get_other_input_name(
            self.matmul_out, self.sfm_quant("y")
        )
        self.matmul_out_B_name = BinaryOpHelper.get_input_name(
            self.matmul_out, self.sfm_quant("y")
        )

        qxk_shape = self.matmul_out("A").get_shape()
        v_shape = self.matmul_out("B").get_shape()
        smf_cross_v_nb_heads = BatchHelper.extract_matmul_batch_nb(qxk_shape, v_shape)

        # extract the number of heads from the batch dimensions of the matmul
        if smf_cross_v_nb_heads != qxk_nb_heads:
            raise NoMatch("Different batching number for the matmul not supported")
        self.nb_heads = qxk_nb_heads
        self.nb_groups = qxk_nb_heads

        # Extract the MHA mode given the input shape of the QxK Matmul
        self.mha_mode = self.get_mha_mode(self.matmul_in("A"), self.matmul_in("B"))

    def hoist_out_add(self) -> None:
        """
        Hoist out the first add of the chain if it is has the same type as the second add.
        To do so we will connect the matmul QxK output to the second Add and the first Add
        will be performed before the second output of the second Add.

        Info: Drawing of the matched chain + transformation can be found in the docstr of the
          'match_add_softmax_chain' method.
        """
        # FIXME: turn it into a logger message when the logger is available in the matcher class
        print(
            f"Addition op '{self.hoist_out.get_name()}' will be hoisted of the attention pattern based on the node {self.softmax.get_name()} and should be connected the \"{'M' if self.mask_add else 'B'}\" input"
        )

        self.replace_input(self.hoist_out, self.hoist_out("A"), self.connection("B"))
        self.replace_input(self.connection, self.connection("B"), self.hoist_out("C"))

        new_output_shape = BinaryOpHelper.get_output_shape(
            self.hoist_out("A").get_shape(), self.hoist_out("B").get_shape()
        )
        self.hoist_out("C").set_shape(new_output_shape, self.hoist_out("C").get_dtype())

    def get_mha_op_type(self) -> str:
        """
        Compute the final op type using the mha mode, the bias and mask infos
        and the inputs and output type.
        """

        match self.mha_mode:
            case MHAMode.TWO_P_ONE | MHAMode.TWO_P_ONE_MINI:
                op_type = "MHA_2p1"
            case MHAMode.THREE_P_ZERO_MINI:
                op_type = "MHA_3p0_1col"
            case _:
                assert False, "Unreachable cannot have a unsupported mha mode"

        if self.bias_add:
            op_type += "_bias"
        if self.mask_add:
            op_type += "_mask"

        # Select output type depending on whether the second matmul is fused or
        # not.
        out_type = self.sfm_quant("y_zero_point").get_dtype()
        if self.mha_mode == MHAMode.THREE_P_ZERO_MINI:
            out_type = self.matmul_out("Y_zero_point").get_dtype()

        in1_type = self.matmul_in("A_zero_point").get_dtype()
        in2_type = self.matmul_in("B_zero_point").get_dtype()
        return op_type + "_qdq_" + in1_type + "x" + in2_type + "x" + out_type

    def modify(self) -> None:
        # Hoist the add out of the pattern before fusing if requested
        if self.hoist_out:
            self.hoist_out_add()

        # Minimal set of inputs which can be needed:
        # - MHA_2p1 w/o bias
        inputs = {
            # Dequantize input nodes
            "Q": self.matmul_in("A"),
            "K": self.matmul_in("B"),
            "M": None,  # input for the mask add
            "B": None,  # input for the bias add
            "V": None,
            # Q/K matmul dequantize scale and zp
            "Q_dq_scale": self.matmul_in("A_scale"),
            "Q_dq_zero_point": self.matmul_in("A_zero_point"),
            "K_dq_scale": self.matmul_in("B_scale"),
            "K_dq_zero_point": self.matmul_in("B_zero_point"),
            # Q/K matmul quantize scale and zp
            "QxK_q_scale": self.matmul_in("Y_scale"),
            "QxK_q_zero_point": self.matmul_in("Y_zero_point"),
            # Add(s) + Softmax chain qdq params
            "Softmax_dq_scale": self.sfm_dq("x_scale"),
            "Softmax_dq_zero_point": self.sfm_dq("x_zero_point"),
            "Softmax_q_scale": self.sfm_quant("y_scale"),
            "Softmax_q_zero_point": self.sfm_quant("y_zero_point"),
            # Sfm/V matmul dequantize scale and zp
            "V_dq_scale": None,
            "V_dq_zero_point": None,
            "SfmxV_B_dq_scale": None,
            "SfmxV_B_dq_zero_point": None,
            # Sfm/V matmul output scale and zp
            "SfmxV_q_scale": None,
            "SfmxV_q_zero_point": None,
        }
        outputs = {
            "output": self.sfm_quant("y"),
        }

        # Determine whether the q and k input shapes are too big or not to fuse the matmul out tensor
        # If we are in mha 3p0 mode the matmul out is fused with the other nodes
        nb_inputs = len(inputs)
        if self.mha_mode == MHAMode.THREE_P_ZERO_MINI:
            # The Sfm/V matmul can be fused
            inputs["V"] = self.matmul_out(self.v_input_name)

            # Add the qdq scale and zp of the Sfm/V matmul
            inputs["V_dq_scale"] = self.matmul_out(self.v_input_name + "_scale")
            inputs["V_dq_zero_point"] = self.matmul_out(
                self.v_input_name + "_zero_point"
            )

            inputs["SfmxV_B_dq_scale"] = self.matmul_out(
                self.matmul_out_B_name + "_scale"
            )
            inputs["SfmxV_B_dq_zero_point"] = self.matmul_out(
                self.matmul_out_B_name + "_zero_point"
            )

            inputs["SfmxV_q_scale"] = self.matmul_out("Y_scale")
            inputs["SfmxV_q_zero_point"] = self.matmul_out("Y_zero_point")
            assert len(inputs) == nb_inputs  # sanity check

            # New output of the fused node is now the output of the SfmxV matmul
            outputs = {
                "output": self.matmul_out("Y"),
            }

        attributes = dict()
        new_op_type = self.get_mha_op_type()

        # Add the mask input
        if self.mask_add:
            inputs["M"] = self.mask_add("B")
            attributes["num_mask"] = self.nb_mask
            attributes["mask_sequence"] = list(range(self.nb_mask)) * (
                self.nb_heads // self.nb_mask
            )

        # Add the bias input
        if self.bias_add:
            inputs["B"] = self.bias_add("B")
            attributes["num_bias"] = self.nb_bias
            attributes["bias_sequence"] = list(range(self.nb_bias)) * (
                self.nb_heads // self.nb_bias
            )

        copy_attributes = self.softmax.get_attributes()
        num_of_tensor_inputs = 3 if self.mha_mode == MHAMode.THREE_P_ZERO_MINI else 2
        attributes |= {
            "trans_to_nhwc": 0,
            # FIXME: remove mul_const attribute when not used anymore.
            "mul_const": 1,
            "num_heads": self.nb_heads,
            "InTransposeK": 0,
            "axis": copy_attributes["axis"],
            "orig_name": self.softmax.get_name(),
            "num_groups": self.nb_groups,
            "groups_sequence": list(range(self.nb_groups))
            * (self.nb_heads // self.nb_groups),
            "num_of_tensor_inputs": num_of_tensor_inputs,
            "mha_mode": int(self.mha_mode),
        }

        if self.has_k_transpose:
            # Add the relevant attributes related to the transpose of the K input
            attributes |= self.get_unrolled_perm_attribute(
                self.matmul_in("B").require_node(), "permK"
            )
            attributes["InTransposeK"] = 1

            # input of the fused node become the input of the transpose
            inputs["K"] = self.matmul_in("B.data")

        self.remove_node(self.softmax)
        self.add_node(
            type=new_op_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
        )

    def select_mha(
        self,
        Sin_kv: int,
        Sin_dh: int,
        Mask: bool,
        Bias: bool,
    ) -> Tuple[int, Optional[Dict[str, int]]]:
        stack_limit = 60352
        mask_flag = 1 if Mask else 0
        bias_flag = 1 if Bias else 0
        IfmBytes = 2
        OutBytes = 2
        TdmBytes = 4
        C0Bytes = 8
        QdqNodes = 6
        QdqPrm = 16
        QdqPrmBytes = 4
        CoreQdqPrmSize = QdqNodes * QdqPrm * QdqPrmBytes
        CoreAlignSize = 64
        Sq = 16
        Dh = Sin_dh
        Skv = (((Sin_kv - 1) // 8) + 1) * 8
        CoreQuerySize = Sq * Dh * IfmBytes
        CoreKeySize = Skv * Dh * IfmBytes
        CoreValSize = Skv * Dh * IfmBytes
        CoreMaskSize = mask_flag * Skv * IfmBytes
        CoreBiasSize = (Sq * Skv * IfmBytes) if bias_flag else 0

        def iceil(x: int, d: int) -> int:
            return ((x + d - 1) // d) * d

        CoreAct1SumSize = iceil(Skv * TdmBytes + 1, 256)
        CoreAct2SumSize = iceil(Skv * TdmBytes + 1, 256)
        CoreC0Size = Skv * C0Bytes
        CoreTdmBufSize = 2 * Skv * Sq * TdmBytes

        def _base_dict() -> Dict[str, int]:
            return {
                "Skv": Skv,
                "CoreQuerySize": CoreQuerySize,
                "CoreKeySize": CoreKeySize,
                "CoreValSize": CoreValSize,
                "CoreMaskSize": CoreMaskSize,
                "CoreBiasSize": CoreBiasSize,
                "CoreAct1SumSize": CoreAct1SumSize,
                "CoreAct2SumSize": CoreAct2SumSize,
                "CoreC0Size": CoreC0Size,
                "CoreTdmBufSize": CoreTdmBufSize,
                "mask_flag": mask_flag,
                "bias_flag": bias_flag,
            }

        layout1 = _base_dict()
        layout1["out_dim"] = Dh
        layout1["CoreOutSize"] = Sq * Dh * OutBytes
        layout1["CoreQueryPingAddr"] = 0
        layout1["CoreKeyPingAddr"] = iceil(
            layout1["CoreQueryPingAddr"] + CoreQuerySize, CoreAlignSize
        )
        layout1["CoreValPingAddr"] = iceil(
            layout1["CoreKeyPingAddr"] + CoreKeySize, CoreAlignSize
        )
        layout1["CoreMaskPingAddr"] = iceil(
            layout1["CoreValPingAddr"] + CoreValSize, CoreAlignSize
        )
        layout1["CoreQdqPingAddr"] = iceil(
            layout1["CoreMaskPingAddr"] + CoreMaskSize, CoreAlignSize
        )
        layout1["CoreTdm1Size"] = CoreTdmBufSize // 2
        layout1["CoreTdm1Addr"] = iceil(
            layout1["CoreQdqPingAddr"] + CoreQdqPrmSize, CoreAlignSize
        )
        layout1["CoreTdm2Size"] = CoreTdmBufSize // 2
        layout1["CoreTdm2Addr"] = iceil(
            layout1["CoreTdm1Addr"] + layout1["CoreTdm1Size"], CoreAlignSize
        )
        layout1["CoreOutPingAddr"] = layout1["CoreTdm1Addr"]
        layout1["CoreAct1SumAddr"] = iceil(
            layout1["CoreTdm2Addr"] + layout1["CoreTdm2Size"], CoreAlignSize
        )
        layout1["CoreAct2SumAddr"] = iceil(
            layout1["CoreAct1SumAddr"] + CoreAct1SumSize, CoreAlignSize
        )
        layout1["CoreC0Addr"] = iceil(
            layout1["CoreAct2SumAddr"] + CoreAct2SumSize, CoreAlignSize
        )
        layout1["CoreScratchAddr"] = iceil(
            layout1["CoreC0Addr"] + CoreC0Size, CoreAlignSize
        )
        if layout1["CoreScratchAddr"] < stack_limit:
            return MHAMode.THREE_P_ZERO_MINI, layout1

        layout2 = _base_dict()
        layout2["out_dim"] = Skv
        layout2["CoreOutSize"] = Sq * Skv * OutBytes
        layout2["CoreValPingAddr"] = 0
        layout2["CoreKeyPingAddr"] = 0
        layout2["CoreMaskPingAddr"] = iceil(
            layout2["CoreKeyPingAddr"] + CoreKeySize, CoreAlignSize
        )
        layout2["CoreQdqPingAddr"] = iceil(
            layout2["CoreMaskPingAddr"] + CoreMaskSize, CoreAlignSize
        )
        layout2["CoreTdm1Size"] = CoreTdmBufSize // 2
        layout2["CoreTdm1Addr"] = iceil(
            layout2["CoreQdqPingAddr"] + CoreQdqPrmSize, CoreAlignSize
        )
        layout2["CoreTdm2Size"] = CoreTdmBufSize // 2
        layout2["CoreTdm2Addr"] = iceil(
            layout2["CoreTdm1Addr"] + layout2["CoreTdm1Size"], CoreAlignSize
        )
        layout2["CoreOutPingAddr"] = layout2["CoreTdm2Addr"]
        layout2["CoreAct2SumAddr"] = iceil(
            layout2["CoreTdm2Addr"] + layout2["CoreTdm2Size"], CoreAlignSize
        )
        core_c0_after_act2 = iceil(
            layout2["CoreAct2SumAddr"] + CoreAct2SumSize, CoreAlignSize
        )
        layout2["CoreScratchAddr"] = iceil(
            core_c0_after_act2 + CoreC0Size, CoreAlignSize
        )
        layout2["CoreQueryPingAddr"] = layout2["CoreTdm2Addr"]
        layout2["CoreC0Addr"] = iceil(
            layout2["CoreTdm2Addr"] + CoreQuerySize, CoreAlignSize
        )
        layout2["CoreAct1SumAddr"] = iceil(
            layout2["CoreC0Addr"] + CoreC0Size, CoreAlignSize
        )
        if (
            layout2["CoreAct1SumAddr"] + CoreAct1SumSize <= layout2["CoreAct2SumAddr"]
        ) and (layout2["CoreScratchAddr"] < stack_limit):
            return MHAMode.TWO_P_ONE_MINI, layout2
        AieRows = 4
        AieCols = 8
        Num4x4 = 2 if AieCols == 8 else 1
        NumAieCompCols = AieCols // Num4x4
        Ngran = 8
        N_next = iceil(Sin_kv, AieRows * NumAieCompCols * Ngran)
        Skv_alt = N_next // (AieRows * NumAieCompCols)
        CoreTdm1Size = Sq * Skv_alt * OutBytes
        CoreTdm2Size = Sq * Skv_alt * OutBytes

        layout3: Dict[str, int] = {
            "Skv": Skv_alt,
            "out_dim": Skv_alt,
            "CoreQuerySize": (Sq * Dh * IfmBytes) * Num4x4,
            "CoreKeySize": Skv_alt * Dh * IfmBytes,
            "CoreValSize": 0,
            "CoreMaskSize": Skv_alt * IfmBytes,
            "CoreBiasSize": (bias_flag * Sq * Skv_alt * IfmBytes),
            "CoreOutSize": Sq * Skv_alt * OutBytes,
            "CoreAct1SumSize": iceil(TdmBytes * Sq, 512),
            "CoreAct2SumSize": iceil(Skv_alt * TdmBytes, 512),
            "CoreC0Size": Skv_alt * C0Bytes,
            "CoreTdmBufSize": 2 * Skv_alt * Sq * TdmBytes,
            "mask_flag": mask_flag,
            "bias_flag": bias_flag,
            "CoreTdm1Size": CoreTdm1Size,
            "CoreTdm2Size": CoreTdm2Size,
            "CoreValPingAddr": 0,
        }
        layout3["CoreQueryPingAddr"] = 0
        layout3["CoreKeyPingAddr"] = iceil(
            layout3["CoreQueryPingAddr"] + layout3["CoreQuerySize"], CoreAlignSize
        )
        layout3["CoreTdm1Addr"] = iceil(
            layout3["CoreKeyPingAddr"] + layout3["CoreKeySize"], CoreAlignSize
        )
        layout3["CoreTdm2Addr"] = iceil(
            layout3["CoreTdm1Addr"] + CoreTdm1Size, CoreAlignSize
        )
        layout3["CoreOutPingAddr"] = layout3["CoreTdm2Addr"]
        layout3["CoreQdqPingAddr"] = iceil(
            layout3["CoreTdm2Addr"] + CoreTdm2Size, CoreAlignSize
        )
        layout3["CoreAct1SumAddr"] = iceil(
            layout3["CoreQdqPingAddr"] + CoreQdqPrmSize, CoreAlignSize
        )
        layout3["CoreAct2SumAddr"] = iceil(
            layout3["CoreAct1SumAddr"] + layout3["CoreAct1SumSize"], CoreAlignSize
        )
        layout3["CoreC0Addr"] = iceil(
            layout3["CoreAct2SumAddr"] + layout3["CoreAct2SumSize"], CoreAlignSize
        )
        layout3["CoreMaskPingAddr"] = iceil(
            layout3["CoreC0Addr"] + layout3["CoreC0Size"], CoreAlignSize
        )
        layout3["CoreScratchAddr"] = iceil(
            layout3["CoreMaskPingAddr"] + layout3["CoreMaskSize"], CoreAlignSize
        )
        if (
            layout3["CoreOutPingAddr"] + layout3["CoreOutSize"]
            <= layout3["CoreQdqPingAddr"]
        ) and (layout3["CoreScratchAddr"] < stack_limit):
            return MHAMode.TWO_P_ONE, layout3

        return MHAMode.UNSUPPORTED, None
