# fmt: on
from typing import Optional, List

import numpy as np
from OGOAT.src.L1_fusion.py_match.basic.matmul import MatMul
from OGOAT.src.L1_fusion.py_match.basic.gemm import Gemm
from OGOAT.src.L1_fusion.py_match.checkers import (
    CategoryCheck,
    opType,
)
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.helpers.transpose_helper import TransposeHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    MatcherError,
    Node,
    Tensor,
)


class LinearSliceTranspose(Matcher, TransposeHelper, QDQHelper):
    """
    Pattern matcher for QKV projection pattern with Slice and Transpose.

    Matches:
        Gemm/MatMul -> [optional Reshapes] -> Split or Slice/Slice_qdq (x3)
            -> [Squeeze] -> Transpose (x3 for Q,K,V)

    Transforms to (as built by ``modify``):
        Gemm/MatMul -> Reshape -> Transpose -> Slice_runtime (x3)
        With additional Transpose for K branch

    ``match`` records the matched nodes and the derived dimensions on the
    instance; ``modify`` consumes that state, so it must only run after a
    successful ``match``.
    """

    dependencies = [MatMul(), Gemm()]

    def __init__(self):
        super().__init__()
        # Split node feeding Q/K/V (only set when the Split variant matched).
        self.split_node: Optional[Node] = None
        # The three Slice nodes feeding Q/K/V (only set for the Slice variant).
        self.slice_nodes: List[Node] = []
        # True when the Slice variant matched, False for the Split variant.
        self.use_slices: bool = False
        # One Transpose per Q/K/V branch, in the order of split_outputs.
        self.transpose_nodes: List[Node] = []
        # Outputs of the Split/Slice nodes, one per branch.
        self.split_outputs: List[Tensor] = []
        # Reshape nodes between the linear op and the Split/Slice nodes.
        self.skipped_nodes: List[Node] = []
        # Dimensions derived in match(); num_heads counts Q, K and V together.
        self.num_heads: int = 0
        self.head_dim: int = 0
        self.seq_len: int = 0
        self.hidden_size: int = 0

    def _create_tensor(self, name: str, shape: list[int], dtype: str) -> Tensor:
        """Create a new tensor named *name* and set its shape and dtype."""
        tensor = Tensor(
            self.n._model_dict,
            self.n._walk_cfg,
            name,
            None,
        )
        tensor.set_shape(shape, dtype)
        return tensor

    def match(self) -> None:
        """Match the QKV projection pattern.

        On success the matched nodes and derived dimensions are stored on
        the instance for ``modify``.

        Raises:
            MatcherError: if any part of the expected pattern is missing or
                the shapes are inconsistent.
        """
        n = self.n

        # Match Gemm or MatMul
        n.require(CategoryCheck(MatMul()) | CategoryCheck(Gemm()))

        # Check what follows the linear op
        # For fused MatMul nodes, Y is already a tensor (the quantized output)
        # For unfused nodes, Y should be a QuantizeLinear node
        if n.get_op_type().startswith(("MatMul_qdq", "Gemm_qdq")):
            # Already fused - Y is a tensor
            if not n("Y").check_tensor():
                raise MatcherError("LinearSliceTranspose: Expected tensor output for fused MatMul")
        elif not n("Y").check(opType.QuantizeLinear):
            # Not fused - Y should be QuantizeLinear
            raise MatcherError("LinearSliceTranspose: Expected quantized Gemm/MatMul")

        # Get the output shape to determine dimensions
        gemm_output_shape = n("Y").get_shape()
        if len(gemm_output_shape) != 2:
            raise MatcherError("LinearSliceTranspose: Expected 2D output from Gemm/MatMul")

        self.seq_len = gemm_output_shape[0]
        self.hidden_size = gemm_output_shape[1]

        # Skip to find main readers (handling optional DQ-Q pairs)
        main_readers = n("Y").skip().require_tensor().get_readers()

        # Skip extra DQ and Q after main node if present
        if (
            len(main_readers) == 1
            and n("Y").check(opType.DequantizeLinear)
            and n("Y.y").check(opType.QuantizeLinear)
            and self.check_qdq_equal_scale_zeropoint(n("Y"), n("Y.y"))
        ):
            main_readers = n("Y.y.y").skip().require_tensor().get_readers()

        # Check for Split node OR multiple Slice nodes (including fused Slice_qdq)
        # After SplitToSlice, we'll have 3 Slice nodes as readers
        potential_split: Optional[Node] = None
        if len(main_readers) == 3:
            # The "Slice" prefix also covers the fused Slice_qdq variants.
            if not all(
                reader.get_op_type().startswith("Slice") for reader in main_readers
            ):
                raise MatcherError("LinearSliceTranspose: Expected 3 Slice nodes")
            # NOTE(review): the readers are assumed to be in Q, K, V order;
            # nothing here verifies that against the slice start offsets.
            self.slice_nodes = main_readers
            self.use_slices = True
        elif len(main_readers) == 1:
            potential_split = main_readers[0]
            self.use_slices = False
        else:
            raise MatcherError("LinearSliceTranspose: Unexpected number of readers")

        # Check if some nodes were skipped (optional noops - support multiple Reshapes)
        # Walk forward from Y one reader at a time until the main readers are reached.
        skipped_readers = n("Y").get_readers()
        self.skipped_nodes = []
        while skipped_readers != main_readers:
            if len(skipped_readers) != 1:
                raise MatcherError("LinearSliceTranspose: Optional noop must be the only reader")

            skipped_node = skipped_readers[0]

            # Support Reshape as optional noop (can have multiple)
            if not skipped_node.check(opType.Reshape):
                raise MatcherError("LinearSliceTranspose: Unsupported optional noop")

            self.skipped_nodes.append(skipped_node)
            skipped_readers = skipped_node("reshaped").get_readers()

        self.split_outputs = []
        if self.use_slices:
            # Already have Slice nodes - verify they slice the correct axis
            for slice_node in self.slice_nodes:
                axes = slice_node.get_attribute_value("axes")
                if axes != [2] and axes != [-3]:
                    raise MatcherError("LinearSliceTranspose: Slice must be on third last dimension")
                # For Slice nodes, the output is named "output"
                self.split_outputs.append(slice_node("output"))
        else:
            # Verify it's a Split node
            if not potential_split.check(opType.Split):
                raise MatcherError("LinearSliceTranspose: Expected Split node")

            self.split_node = potential_split

            # Check Split parameters
            axis = self.split_node.get_attribute_value("axis")
            if axis != -3 and axis != 2:  # Should split on third last dimension
                raise MatcherError("LinearSliceTranspose: Split must be on third last dimension")

            # Get Split outputs (should be 3 for Q, K, V)
            # Split node outputs are named output_0, output_1, output_2
            for i in range(3):
                output_name = f"output_{i}"
                if not self.split_node(output_name).check_tensor():
                    raise MatcherError(f"LinearSliceTranspose: Missing Split output {i}")
                self.split_outputs.append(self.split_node(output_name))

        # Verify each output goes through Transpose
        # (presumably skip() steps over an intermediate Squeeze - TODO confirm)
        self.transpose_nodes = []
        for i, split_output in enumerate(self.split_outputs):
            # Skip through any intermediate nodes to find Transpose
            readers = split_output.skip().require_tensor().get_readers()

            if len(readers) != 1:
                raise MatcherError(f"LinearSliceTranspose: Slice output {i} should have single reader")

            transpose_node = readers[0]
            if not transpose_node.check(opType.Transpose):
                raise MatcherError(f"LinearSliceTranspose: Expected Transpose after Slice output {i}")

            self.transpose_nodes.append(transpose_node)

        # Extract head configuration from first Transpose output
        first_transpose_output = self.transpose_nodes[0]("transposed").get_shape()
        if len(first_transpose_output) != 4:
            raise MatcherError("LinearSliceTranspose: Expected 4D output from Transpose")

        # Shape should be [batch, heads, seq, head_dim]
        if first_transpose_output[0] != 1:
            raise MatcherError("LinearSliceTranspose: Batch dimension must be 1")

        self.num_heads = first_transpose_output[1] * 3  # Total heads for all Q, K, V
        self.head_dim = first_transpose_output[3]

        # Verify dimensions are consistent
        if self.hidden_size != self.num_heads * self.head_dim:
            raise MatcherError("LinearSliceTranspose: Inconsistent dimensions")

    def modify(self) -> None:
        """Apply the transformation to create the new pattern.

        Adds Reshape -> Transpose after the linear op's output, then one
        Slice_runtime per Q/K/V branch along the heads axis. Q and V slices
        write straight into the tensors previously produced by the original
        Transpose nodes; K gets an intermediate tensor followed by an extra
        Transpose. Finally the original Split/Slice, Transpose and skipped
        Reshape nodes are removed.
        """
        n = self.n.require_node()
        base_name = n.get_name()

        # NOTE(review): for the fused MatMul_qdq/Gemm_qdq case this is the
        # quantized output tensor; confirm this also holds for unfused nodes.
        gemm_output = n("Y")

        # Every new tensor keeps the dtype of the linear op's output.
        gemm_dtype = gemm_output.get_dtype()

        # Reshape to [1, seq_len, num_heads, head_dim]
        reshape_shape = self.add_initializer(
            base_name + "_reshape_shape",
            np.array([1, self.seq_len, self.num_heads, self.head_dim], dtype=np.int64)
        )

        reshape_output = self._create_tensor(
            base_name + "_reshaped",
            [1, self.seq_len, self.num_heads, self.head_dim],
            gemm_dtype,
        )

        self.add_node(
            type="Reshape",
            domain="ai.onnx.contrib",
            inputs={
                "data": gemm_output,
                "shape": reshape_shape,
            },
            outputs={"reshaped": reshape_output},
            attributes={},
            new_name=base_name + "_reshape",
        )

        # Transpose to [batch, heads, seq, head_dim]
        transpose_output = self._create_tensor(
            base_name + "_transposed",
            [1, self.num_heads, self.seq_len, self.head_dim],
            gemm_dtype,
        )

        self.add_node(
            type="Transpose",
            domain="ai.onnx.contrib",
            inputs={"data": reshape_output},
            outputs={"transposed": transpose_output},
            attributes={"perm": [0, 2, 1, 3]},
            new_name=base_name + "_transpose",
        )

        # Create Slice_runtime nodes for Q, K, V
        heads_per_part = self.num_heads // 3

        for idx, (name, needs_extra_transpose) in enumerate([
            ("Q", False),
            ("K", True),   # K needs extra transpose
            ("V", False),
        ]):
            start = idx * heads_per_part
            end = (idx + 1) * heads_per_part

            # Get the original output tensor that we need to connect to
            final_output = self.transpose_nodes[idx]("transposed")

            if needs_extra_transpose:
                # K is transposed afterwards, so the slice writes into a new
                # intermediate tensor.
                slice_output = self._create_tensor(
                    base_name + f"_{name}_sliced",
                    [1, heads_per_part, self.seq_len, self.head_dim],
                    gemm_dtype,
                )
            else:
                # Q and V: the slice writes directly into the final tensor,
                # so no intermediate tensor or throwaway node is created.
                slice_output = final_output

            # Create initializers for slice parameters
            axes_init = self.add_initializer(
                base_name + f"_{name}_axes",
                np.array([1], dtype=np.int64)
            )
            starts_init = self.add_initializer(
                base_name + f"_{name}_starts",
                np.array([start], dtype=np.int64)
            )
            ends_init = self.add_initializer(
                base_name + f"_{name}_ends",
                np.array([end], dtype=np.int64)
            )

            self.add_node(
                type="Slice_runtime",
                domain="ai.onnx.contrib",
                inputs={
                    "data": transpose_output,
                    "starts": starts_init,
                    "ends": ends_init,
                    "axes": axes_init,
                },
                outputs={"output": slice_output},
                # Both singular and plural spellings of start/end are emitted,
                # as in the original node definition.
                attributes={
                    "axes": [1],
                    "starts": [start],
                    "start": [start],
                    "ends": [end],
                    "end": [end],
                    "steps": [1],
                },
                new_name=base_name + f"_{name}_slice_runtime",
            )

            if needs_extra_transpose:
                # Transpose K with perm [0, 1, 3, 2], writing directly to the
                # final tensor.
                self.add_node(
                    type="Transpose",
                    domain="ai.onnx.contrib",
                    inputs={"data": slice_output},
                    outputs={"transposed": final_output},
                    attributes={"perm": [0, 1, 3, 2]},
                    new_name=base_name + f"_{name}_transpose",
                )

        # Remove original Split/Slice and Transpose nodes
        if self.use_slices:
            for slice_node in self.slice_nodes:
                self.remove_node(slice_node)
        else:
            self.remove_node(self.split_node)

        for transpose_node in self.transpose_nodes:
            self.remove_node(transpose_node)

        # Remove any skipped nodes (optional noops)
        for skipped_node in self.skipped_nodes:
            self.remove_node(skipped_node)
