# fmt: on
import logging

import numpy as np
from dataclasses import dataclass
from OGOAT.src.L1_fusion.py_match.basic.matmul import MatMul
from OGOAT.src.L1_fusion.py_match.checkers import (
    CategoryCheck,
    opType,
)
from OGOAT.src.L1_fusion.py_match.helpers.reshape_transpose_state import (
    ReshapeTransposeDownState,
    ReshapeTransposeUpState,
)
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    Tensor,
    MatcherError,
    WalkCfgPlain,
)
from OGOAT.src.L1_fusion.py_match.helpers.reshape_transpose_helper import (
    ReshapeTransposeHelper as RTRHelper,
)
from OGOAT.src.L1_fusion.py_match.helpers.transpose_helper import (
    TransposeHelper,
)
from OGOAT.src.L1_fusion.py_match.helpers.reshape_transpose_state import (
    dim,
)
from OGOAT.src.L1_fusion.py_match.helpers.perm_helper import PermutationHelper
from OGOAT.src.L1_fusion.py_match.basic.categories import matmul_category
from OGOAT.src.L1_fusion.py_match.helpers.fusion_configs import FusionConfigs

class MatMulTranspose(Matcher, TransposeHelper, PermutationHelper):
    """
    Fuse Reshape/Transpose chains adjacent to a MatMul-category node into it.

    For act x act MatMuls (``actxact`` attribute set) the matcher looks for
    Reshape/Transpose chains feeding inputs ``A`` and ``B``. It also looks for
    a Transpose on output ``Y``, or a Reshape/Transpose chain when
    ``ENABLE_OUTPUT_RTR_FUSION`` is set.

    For other MatMuls it only looks at input ``A`` and output ``Y``. It does so
    only when ``ENABLE_ACTxWGT_FUSION`` is set or ``mode`` is ``"ActWgt"``.

    Matched transposes are recorded as attributes (``InTransposeA``/``permA``,
    ...) on the fused node, and the fused node is rewired around them.
    """

    def __init__(self, mode: str = "ActAct"):
        super().__init__()
        # "ActWgt" forces the act x wgt branch of match() even when the
        # ENABLE_ACTxWGT_FUSION config flag is off.
        self.mode = mode

    @dataclass(frozen=True)
    class Configs:
        MAX_RANK: int
        """
        Maximum tensor rank (after optional leading-1 normalization)
        for which we attempt reshape+transpose fusion.
        """

        ENABLE_INNERMOST_DIM_FUSION: bool
        """
        If False (default), do NOT fuse transposes that move the innermost dimension
        (i.e., cases where perm[-1] != last). Set True to allow those fusions.
        """

        ENABLE_ACTxWGT_FUSION: bool
        """
        If True, allow fuse/match for act x wgt matmul cases (in addition to act x act).
        """

        ENABLE_OUTPUT_RTR_FUSION: bool = False
        """
        If True, enable matching of Reshape+Transpose chains on the output
        (not just direct Transpose). Controlled by prebuilt_mladf_mha flag.
        """

    @classmethod
    def get_configs(cls) -> "MatMulTranspose.Configs":
        """Build a ``Configs`` instance from the global ``MMT_configs`` fusion settings."""
        MMT_configs = FusionConfigs.get_fusion_configs().MMT_configs
        configs = cls.Configs(**MMT_configs)
        return configs

    @property
    def MAX_RANK(self):
        """
        Maximum tensor rank (after optional leading-1 normalization)
        for which we attempt reshape+transpose fusion.
        """
        configs = self.get_configs()
        return configs.MAX_RANK

    @property
    def ENABLE_INNERMOST_DIM_FUSION(self):
        """
        If False (default), do NOT fuse transposes that move the innermost dimension
        (i.e., cases where perm[-1] != last). Set True to allow those fusions.
        """
        configs = self.get_configs()
        return configs.ENABLE_INNERMOST_DIM_FUSION

    @property
    def ENABLE_ACTxWGT_FUSION(self):
        """
        If True, allow fuse/match for act x wgt matmul cases (in addition to act x act).
        """
        configs = self.get_configs()
        return configs.ENABLE_ACTxWGT_FUSION

    @property
    def ENABLE_OUTPUT_RTR_FUSION(self):
        """
        If True, enable matching of Reshape+Transpose chains on the output.
        Controlled by prebuilt_mladf_mha flag.
        """
        configs = self.get_configs()
        return configs.ENABLE_OUTPUT_RTR_FUSION

    ##########################################

    dependencies = [MatMul()]

    def match(self) -> None:
        """
        Match Reshape/Transpose chains around the current MatMul node.

        On success this sets ``fuse_A`` / ``fuse_B`` / ``fuse_out`` together
        with the corresponding ``state_*``, ``inA`` / ``inB`` / ``out`` and
        ``transpose_node_out`` attributes that ``modify`` consumes.

        Raises:
            MatcherError: if none of A, B or Y can be fused.
        """
        n = self.n = self.n.require_node().with_walk_cfg(WalkCfgPlain())
        n.require(CategoryCheck(matmul_category))

        # Reset all per-node state so nothing leaks over if this matcher
        # instance is reused for another node.
        self.fuse_A = self.fuse_B = self.fuse_out = False
        self.state_A = self.state_B = self.state_out = None
        self.transpose_node_out = None
        self.has_fused_transpose = False
        # Snapshot the output-mode flag once so match(), the attribute builder
        # and modify() all agree on what kind of object self.state_out is.
        self._output_rtr_enabled = self.ENABLE_OUTPUT_RTR_FUSION
        self.inA = n("A")
        self.inB = n("B")
        self.out = n("Y")
        node_name = n.get_name()

        if n.get_attribute_value("actxact"):
            # act x act: both inputs may carry a Reshape/Transpose chain.
            self.state_A = self.match_up_rtr_chain(self.inA, node_name, "A")
            self.state_B = self.match_up_rtr_chain(self.inB, node_name, "B")
            self.state_out = self._match_output_state(n)
        elif self.ENABLE_ACTxWGT_FUSION or self.mode == "ActWgt":
            # act x wgt: only input A is considered on the input side.
            self.state_A = self.match_up_rtr_chain(self.inA, node_name, "A")
            self.state_out = self._match_output_state(n)

        if self.state_out:
            if self._output_rtr_enabled:
                self._resolve_output_rtr_state()
            else:
                self._resolve_output_direct_transpose(n)

        self.state_A, self.inA, self.fuse_A = self._resolve_input_state(
            self.state_A, self.inA
        )
        self.state_B, self.inB, self.fuse_B = self._resolve_input_state(
            self.state_B, self.inB
        )

        if self.fuse_A and self.fuse_B:
            self._unify_input_dimensions()

        if not any((self.fuse_A, self.fuse_B, self.fuse_out)):
            raise MatcherError(
                f"No transposed inputs and outputs found at Node={n.get_name()}"
            )

    def _match_output_state(self, n):
        """
        Match output ``Y``.

        Matches a Reshape/Transpose chain when output RTR fusion is enabled,
        otherwise only a direct Transpose.
        """
        if self._output_rtr_enabled:
            return self.match_down_rtr_chain(self.out, n.get_name(), "Y")
        return n("Y").check(opType.Transpose)

    def _is_perm_fusable(self, perm) -> bool:
        """
        Return True if a transpose with *perm* may be fused.

        Unless ENABLE_INNERMOST_DIM_FUSION is set, only permutations that keep
        the last axis in place are accepted.
        """
        return self.ENABLE_INNERMOST_DIM_FUSION or perm[-1] == len(perm) - 1

    def _resolve_output_rtr_state(self) -> None:
        """Decide ``fuse_out`` for a matched output Reshape/Transpose chain."""
        state = self.state_out
        if not state.has_matched_transpose():
            return
        if not self._is_perm_fusable(state.perm):
            return
        # The tail is taken from the chain before optimize_state() replaces it.
        self.out = state.tail
        self.state_out = state.optimize_state()
        self.fuse_out = (
            self.state_out.has_nontrivial_transpose()
            and self._isStateOrShapeToND(self.state_out)
        )

    def _resolve_output_direct_transpose(self, n) -> None:
        """Decide ``fuse_out`` for a Transpose that consumes ``Y`` directly."""
        self.transpose_node_out = n("Y").require_node()
        perm_attr = self.transpose_node_out.get_attribute_value("perm")
        out_shape = n("Y").require_tensor().get_shape()
        # A missing/empty perm attribute is accepted as-is (original behavior).
        if not perm_attr or self._is_perm_fusable(perm_attr):
            self.fuse_out = PermutationHelper.is_nontrivial_permutation(
                perm_attr
            ) and self._isStateOrShapeToND(out_shape)

    def _resolve_input_state(self, state, current_input) -> tuple:
        """
        Decide whether a matched input chain can be fused.

        Returns:
            ``(state, input, fuse)``. When the chain is not fusable, this is
            the unchanged ``state`` and ``current_input`` with ``fuse=False``.
            Otherwise it is the optimized state, the chain's head (taken
            before optimization) and the fuse decision.
        """
        if not (state and state.has_matched_transpose()):
            return state, current_input, False
        if not self._is_perm_fusable(state.perm):
            return state, current_input, False
        head = state.head
        state = state.optimize_state()
        fuse = state.has_nontrivial_transpose() and self._isStateOrShapeToND(state)
        return state, head, fuse

    def _match_output_reshape_transpose(self, output):
        """
        Match Reshape followed by Transpose pattern on output.

        This is a simplified downstream matching for outputs.

        NOTE(review): match() in this class does not call this helper (it uses
        match_down_rtr_chain instead).
        """
        # Check if output goes to a Reshape node
        if output.check(opType.Reshape):
            reshape_node = output.require_node()
            # Check if Reshape output goes to a Transpose node
            reshape_output = reshape_node("reshaped")
            if reshape_output.check(opType.Transpose):
                transpose_node = reshape_output.require_node()

                # Create a simple state object to hold the information
                class OutputReshapeTransposeState:
                    def __init__(self, reshape_node, transpose_node):
                        self.reshape_node = reshape_node
                        self.transpose_node = transpose_node
                        self.tail = transpose_node
                        self.perm = transpose_node.get_attribute_value("perm")

                        # Get shapes for the transformation
                        # Capture perm in local variable for lambda closure
                        perm_value = self.perm
                        self.Transpose = type('obj', (object,), {
                            'input_shape': reshape_node("reshaped").require_tensor().get_shape(),
                            'output_shape': transpose_node("transposed").require_tensor().get_shape(),
                            'get_dim': lambda self: len(perm_value) if perm_value else 0
                        })()

                    def has_matched_transpose(self):
                        return True

                    def has_nontrivial_transpose(self):
                        # Always return a real bool, never the perm list / None.
                        return bool(
                            self.perm
                            and not PermutationHelper.is_identity_perm(self.perm)
                        )

                    def optimize_state(self):
                        # For now, just return self - no optimization
                        return self

                return OutputReshapeTransposeState(reshape_node, transpose_node)

        # If no Reshape+Transpose chain, check for direct Transpose
        return output.check(opType.Transpose)

    @staticmethod
    def _pad_state_rank(state, target_rank):
        """
        Left-pad *state*'s ReshapeIN output shape and transpose perm with 1s.

        Padding continues until the shape reaches *target_rank*. Returns
        *state* unchanged if it is already at least that rank.
        """
        reshape_in = state.ReshapeIN
        nullity = target_rank - dim(reshape_in.output_shape)
        if nullity <= 0:
            return state
        padded_shape = RTRHelper.insert_ones_to_shape_left(
            reshape_in.output_shape, nullity
        )
        padded_perm = RTRHelper.insert_ones_to_perm_left(state.Transpose.perm, nullity)
        return state.get_next_state(transp_shape=padded_shape, perm=padded_perm)

    def _unify_input_dimensions(self):
        """
        Insert leading 1s into the lower-rank input so A and B have equal rank.

        Only ``state_A`` / ``state_B`` are replaced. match() calls this only
        when both fuse_A and fuse_B are set, so both states are matched chains.
        """
        unified_dim = max(
            dim(self.state_A.ReshapeIN.output_shape),
            dim(self.state_B.ReshapeIN.output_shape),
        )
        self.state_A = self._pad_state_rank(self.state_A, unified_dim)
        self.state_B = self._pad_state_rank(self.state_B, unified_dim)

    def _add_input_state_reshape(
        self, state: ReshapeTransposeUpState, input_name: str, input_dtype: str
    ) -> Tensor:
        """
        Add a Reshape from ``state.head``'s tensor into the transpose's input shape.

        The new tensor is named ``<matmul>_input_<input_name>`` and has shape
        ``state.Transpose.input_shape``. The target shape is stored as an
        int64 initializer.

        Returns:
            The new tensor, to be used as the fused node's input.
        """
        matmul_name = self.n.get_name()
        state_input = Tensor(
            self.n._model_dict,
            self.n._walk_cfg,
            f"{matmul_name}_input_{input_name}",
            None,
        )
        state_input.set_shape(state.Transpose.input_shape.copy(), input_dtype)

        reshape_name = f"{matmul_name}_Reshape_{input_name}"
        shape_initializer = self.add_initializer(
            f"{reshape_name}_shape",
            np.array(state_input.get_shape()).astype(np.int64),
            input_dtype,
        )

        self.add_node(
            type="Reshape",
            domain="",
            inputs={
                "data": state.head.require_tensor(),
                "shape": shape_initializer,
            },
            outputs={"reshaped": state_input},
            attributes={
                "allowzero": 0,
                "orig_name": state.get_last_matched_name(reshape_name),
            },
            new_name=reshape_name,
            add_matcher_name=False,
        )
        return state_input

    def _add_output_state_reshape(
        self, state: ReshapeTransposeDownState, output_dtype: str
    ) -> Tensor:
        """
        Add a Reshape from a new fused-output tensor to ``state.tail``'s tensor.

        The new tensor is named ``<matmul>_output_Y`` and has shape
        ``state.Transpose.output_shape``. The Reshape target is
        ``state.out_shape``, stored as an int64 initializer.

        Returns:
            The new tensor, to be used as the fused node's ``Y`` output.
        """
        matmul_name = self.n.get_name()
        output_name = "Y"
        state_output = Tensor(
            self.n._model_dict,
            self.n._walk_cfg,
            f"{matmul_name}_output_{output_name}",
            None,
        )
        state_output.set_shape(state.Transpose.output_shape.copy(), output_dtype)

        reshape_name = f"{matmul_name}_Reshape_{output_name}"
        shape_initializer = self.add_initializer(
            f"{reshape_name}_shape",
            np.array(state.out_shape).astype(np.int64),
            output_dtype,
        )

        self.add_node(
            type="Reshape",
            domain="",
            inputs={
                "data": state_output,
                "shape": shape_initializer,
            },
            outputs={"reshaped": state.tail.require_tensor()},
            attributes={
                "allowzero": 0,
                "orig_name": state.get_last_matched_name(reshape_name),
            },
            new_name=reshape_name,
            add_matcher_name=False,
        )
        return state_output

    def _isStateOrShapeToND(
        self,
        inp: ReshapeTransposeUpState
        | ReshapeTransposeDownState
        | list[int]
        | tuple[int, ...],
    ) -> bool:
        """
        Return True if the effective rank of *inp* is at most ``MAX_RANK``.

        *inp* is either a plain shape (list or tuple) or a Reshape/Transpose
        state; for a state, the rank comes from ``state.Transpose``. A leading
        dimension of 1 is not counted when the rank is greater than 1.
        """
        if isinstance(inp, (list, tuple)):
            shape_dimension = len(inp)
            # Guard rank-0 shapes, which previously raised IndexError.
            start_dim = inp[0] if inp else None
        else:
            shape_dimension = inp.Transpose.get_dim()
            start_dim = inp.Transpose.input_shape[0]

        if shape_dimension > 1 and start_dim == 1:
            shape_dimension -= 1
        return shape_dimension <= self.MAX_RANK

    def get_matmul_transpose_attributes(self):
        """
        Build the fused node's attributes.

        Starts from the original MatMul attributes. It then adds transpose
        flags, shapes and perms for A, B and Y; unfused entries keep their
        identity defaults.

        Side effect: sets ``self.has_fused_transpose`` to whether any
        transpose was folded into the fused node.
        """
        self.has_fused_transpose = False
        # Copy in case get_attributes() returns the node's own dict, so the
        # original node is not mutated in place.
        attributes = dict(self.n.get_attributes())

        # Add the transpose and shape attribute for the inputs and output
        attributes |= {
            "InTransposeA": 0,
            "InTransposeB": 0,
            "OutTranspose": 0,
            "MatMul_InShapeA": self.n("A").require_tensor().get_shape(),
            "MatMul_InShapeB": self.n("B").require_tensor().get_shape(),
            "MatMul_OutShape": self.n("Y").require_tensor().get_shape(),
            "TransposeA_InShape": self.n("A").require_tensor().get_shape(),
            "TransposeA_OutShape": self.n("A").require_tensor().get_shape(),
            "TransposeB_InShape": self.n("B").require_tensor().get_shape(),
            "TransposeB_OutShape": self.n("B").require_tensor().get_shape(),
            "TransposeY_InShape": self.n("Y").require_tensor().get_shape(),
            "TransposeY_OutShape": self.n("Y").require_tensor().get_shape(),
            "permA": [0, 1, 2, 3],
            "permB": [0, 1, 2, 3],
            "permY": [0, 1, 2, 3],
        }

        # If one of the transpose was fuse update the relevant attributes value
        if self.fuse_A and self.state_A.has_nontrivial_transpose():
            attributes["InTransposeA"] = 1
            attributes["permA"] = self.state_A.perm
            attributes["TransposeA_InShape"] = self.state_A.Transpose.input_shape
            attributes["TransposeA_OutShape"] = self.state_A.Transpose.output_shape
            self.has_fused_transpose = True

        if self.fuse_B and self.state_B.has_nontrivial_transpose():
            attributes["InTransposeB"] = 1
            attributes["permB"] = self.state_B.perm
            attributes["TransposeB_InShape"] = self.state_B.Transpose.input_shape
            attributes["TransposeB_OutShape"] = self.state_B.Transpose.output_shape
            self.has_fused_transpose = True

        if self.fuse_out:
            attributes["OutTranspose"] = 1
            if self._output_rtr_enabled:
                attributes["permY"] = self.state_out.perm
                attributes["TransposeY_InShape"] = self.state_out.Transpose.input_shape
                attributes["TransposeY_OutShape"] = self.state_out.Transpose.output_shape
            else:
                # Handle direct transpose attributes
                attributes["permY"] = self.transpose_node_out.get_attribute_value("perm")
                attributes["TransposeY_InShape"] = (
                    self.transpose_node_out("data").require_tensor().get_shape()
                )
                attributes["TransposeY_OutShape"] = (
                    self.transpose_node_out("transposed").require_tensor().get_shape()
                )
            self.has_fused_transpose = True

        return attributes

    def modify(self) -> None:
        """
        Replace the matched MatMul with a fused ``ai.onnx.contrib`` node.

        Inputs whose transpose chains were matched are fed through new Reshape
        nodes built by _add_input_state_reshape. The output is rewired either
        through a new Reshape (Reshape/Transpose chain mode) or directly to the
        matched Transpose's ``transposed`` output.

        The fused op type is ``<prefix>_Transpose_<A>x<B>x<Y>`` when any
        transpose was folded in, else ``<prefix>_<A>x<B>x<Y>``. Here
        ``<prefix>`` is the original op type with its last ``_``-separated
        suffix removed, and A/B/Y are the zero-point tensor dtypes.
        """
        n = self.n.require_node().with_walk_cfg(WalkCfgPlain())

        # INPUT DTYPES
        dtype_A = n("A_zero_point").require_tensor().get_dtype()
        dtype_B = n("B_zero_point").require_tensor().get_dtype()
        dtype_Y = n("Y_zero_point").require_tensor().get_dtype()

        # FUSED NODE: OP-TYPE
        type_name = n.get_op_type()
        op_type_prefix, *_ = type_name.rsplit("_", 1)

        # FUSED NODE: I/O + ATTRIBUTES
        inputs = n.get_inputs_dict()
        outputs = n.get_outputs_dict()
        attributes = self.get_matmul_transpose_attributes()

        if self.fuse_A:  # input_A has matched a ReshapeTranspose pattern
            state_input_A = self._add_input_state_reshape(self.state_A, "A", dtype_A)
            inputs |= {"A": state_input_A}
            attributes["dbgA"] = self.state_A.ReshapeOUT.output_shape

        if self.fuse_B:  # input_B has matched a ReshapeTranspose pattern
            state_input_B = self._add_input_state_reshape(self.state_B, "B", dtype_B)
            inputs |= {"B": state_input_B}
            attributes["dbgB"] = self.state_B.ReshapeOUT.output_shape

        if self.fuse_out:  # output has matched a ReshapeTranspose pattern
            if self._output_rtr_enabled:
                # For Reshape+Transpose chain, add Reshape
                state_output = self._add_output_state_reshape(self.state_out, dtype_Y)
                outputs |= {"Y": state_output}
                attributes["dbgOut"] = self.state_out.ReshapeIN.input_shape
            else:
                # For direct transpose, use the existing connection
                outputs |= {"Y": n("Y.transposed")}

        new_type = f"{op_type_prefix}{'_Transpose_' if self.has_fused_transpose else '_'}{dtype_A}x{dtype_B}x{dtype_Y}"

        # FUSED NODE
        self.remove_node(n)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
        )


class MatMulTransposeActWgt(MatMulTranspose):
    """
    MatMulTranspose pinned to ``mode="ActWgt"``.

    With this mode, nodes whose ``actxact`` attribute is falsy take the
    act x wgt branch of ``match()`` regardless of the ENABLE_ACTxWGT_FUSION
    config flag. It is kept as a distinct class so that, per the original
    author, the ActAct and ActWgt patterns can run independently of each
    other's matchers_done tracking.
    """

    def __init__(self):
        # Same as the parent constructor, with the mode fixed to "ActWgt".
        super().__init__("ActWgt")
