# fmt: on
import onnx

from OGOAT.src.L1_fusion.py_match.basic.reshape_to_RTR import ReshapeToRTR
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    MatcherOrCategory,
    NoMatch,
)

# Basic patterns
from OGOAT.src.L1_fusion.py_match.basic.binary_op import binary_op
from OGOAT.src.L1_fusion.py_match.basic.concat import (
    Concat,
    ConcatOpTypeChange,
    SplitConcat,
)
from OGOAT.src.L1_fusion.py_match.basic.conv import Conv
from OGOAT.src.L1_fusion.py_match.basic.dataflow import (
    Dataflow,
    TransposeOptQDQ,
    SplitToSlice,
)
from OGOAT.src.L1_fusion.py_match.basic.conv_transpose_to_conv import (
    ConvTransposeToConv,
)
from OGOAT.src.L1_fusion.py_match.basic.gelubasic import GeluBasic
from OGOAT.src.L1_fusion.py_match.basic.non_linear import Resize
from OGOAT.src.L1_fusion.py_match.basic.relu_pwla import ReluPWLA
from OGOAT.src.L1_fusion.py_match.basic.squeeze_unsqueeze_to_RTR import (
    SqueezeUnsqueezeToRTR,
)
from OGOAT.src.L1_fusion.py_match.basic.gemm import Gemm
from OGOAT.src.L1_fusion.py_match.basic.matmul import MatMul, MatMulNBits
from OGOAT.src.L1_fusion.py_match.basic.pool import Pool
from OGOAT.src.L1_fusion.py_match.basic.reduction import (
    GroupNorm,
    LayerNorm,
    LpNorm,
    Softmax,
    ReduceQdq,
    LpNormIndividual,
)
from OGOAT.src.L1_fusion.py_match.basic.rowwise_runtime import RowWiseOpToRuntime
from OGOAT.src.L1_fusion.py_match.basic.lut import lut
from OGOAT.src.L1_fusion.py_match.basic.layernormbasic import LayerNormBasic


# Advanced patterns
from OGOAT.src.L1_fusion.py_match.adv.linear_slice import LinearSlice
from OGOAT.src.L1_fusion.py_match.adv.linear_slice_transpose import LinearSliceTranspose
from OGOAT.src.L1_fusion.py_match.adv.swish import Swish
from OGOAT.src.L1_fusion.py_match.adv.linear import (
    LinearPlusEWBinary,
    LinearPlusNonLinear,
)
from OGOAT.src.L1_fusion.py_match.adv.rope import (
    RoPE,
    LinearPlusRoPE,
)
from OGOAT.src.L1_fusion.py_match.adv.attention import Attention
from OGOAT.src.L1_fusion.py_match.adv.attention_mladf import AttentionMladf
from OGOAT.src.L1_fusion.py_match.adv.linear import LinearPlusLut
from OGOAT.src.L1_fusion.py_match.adv.rope import RoPE
from OGOAT.src.L1_fusion.py_match.adv.silu import silu
from OGOAT.src.L1_fusion.py_match.adv.add import CascadeAdd
from OGOAT.src.L1_fusion.py_match.adv.binary_relu import EWBinaryPlusRelu
from OGOAT.src.L1_fusion.py_match.adv.batching_by_level import BatchingByLevel
from OGOAT.src.L1_fusion.py_match.adv.conv_to_matmul import ConvtoMatmul
from OGOAT.src.L1_fusion.py_match.adv.group_conv_to_conv import GroupConvToConv
from OGOAT.src.L1_fusion.py_match.adv.matmul_transpose import MatMulTranspose, MatMulTransposeActWgt
from OGOAT.src.L1_fusion.py_match.adv.rtr_optimize import RTROptimize
from OGOAT.src.L1_fusion.py_match.adv.rtr_cancellation import (
    MHA_RTRCancellation,
    RTRStateMatcher,
    TransposeCounter,
    TransposeCounterBefore,
    TransposeCounterAfter,
    TransposeCounterFinal,
)


# Cleanup patterns
from OGOAT.src.L1_fusion.py_match.clean.remove_category import remove_category
from OGOAT.src.L1_fusion.py_match.clean.remove_slice_concat_runtime import (
    RemoveSliceConcatRuntime,
)
from OGOAT.src.L1_fusion.py_match.clean.unfuse import (
    UnfuseSkipLayerNormalization,
    UnfuseSimplifiedLayerNormalization,
)
from OGOAT.src.L1_fusion.py_match.clean.multi_axis_slice import MultiAxisSlice
from OGOAT.src.L1_fusion.py_match.clean.conversion_dequant_quant import ConversionDqQ
from OGOAT.src.L1_fusion.py_match.clean.remove_qdq import Post_Remove_QDQ
from OGOAT.src.L1_fusion.py_match.clean.rename_qdq import Rename_QDQ
from OGOAT.src.L1_fusion.py_match.basic.add_noop import AddNoopSuffix
from OGOAT.src.L1_fusion.py_match.clean.merge_multiple_output import MergeMultipleOutput


# Registry mapping a pattern name (string key) to the Python object that
# implements it; values are typed as ``MatcherOrCategory``.
# NOTE(review): judging by the variable name, the keys are presumably the
# pattern names used in the YAML fusion configuration -- confirm against the
# code that consumes this dict.
#
# Two kinds of values appear below:
#   * instances created right here (e.g. ``Gemm()``); every such constructor
#     runs once, in the order written, when this module is imported;
#   * already-built module-level objects imported as-is (``binary_op``,
#     ``silu``, ``lut``, ``remove_category``), which are referenced, not called.
#
# Keys are runtime strings and must be unique: in a dict literal a repeated
# key silently overwrites the earlier entry. Entry order is also the dict's
# iteration order, so keep it stable unless a reordering is intended.
yaml_to_py: dict[str, MatcherOrCategory] = {
    "Gemm": Gemm(),  # basic flex
    "GroupNorm": GroupNorm(),  # basic flex
    "ConvTransposeToConv": ConvTransposeToConv(),
    "ConvtoMatMul": ConvtoMatmul(),  # -> advanced 1.5
    "GroupConvToConv": GroupConvToConv(),
    "LinearPlusNonLinear": LinearPlusNonLinear(),
    "Conv": Conv(),  # -> basic flex
    "Concat": Concat(),  # basic flex
    "MatMul": MatMul(),  # basic flex
    "MatMulNBits": MatMulNBits(),  # basic flex
    "Pool": Pool(),
    "BinaryOp": binary_op,  # basic flex
    "GeluBasic": GeluBasic(),  # basic flex
    "LayerNormBasic": LayerNormBasic(),  # basic flex
    "Lp_Norm": LpNorm(),  # basic flex
    "Silu": silu,  # basic flex -> advanced 1.5
    "Swish": Swish(),  # basic flex -> advanced 1.5
    "Softmax_qdq": Softmax(),  # -> basic flex
    "LayerNorm": LayerNorm(),  # -> basic flex
    "Resize": Resize(),
    "Dataflow": Dataflow(),  # basic flex
    "Post_Remove_QDQ": Post_Remove_QDQ(),  # basic
    "LinearPlusLut": LinearPlusLut(),  # -> advanced, linear + lut
    # Disabled entry; no ``LinearPlusNorm`` name is imported in this module.
    # "LinearPlusNorm": LinearPlusNorm(),
    "RoPE": RoPE(),
    "CascadeAdd": CascadeAdd(),
    "LinearPlusRoPE": LinearPlusRoPE(),
    "Attention": Attention(),  # -> advanced
    "Attention_mladf": AttentionMladf(),
    "MatMulTranspose": MatMulTranspose(),
    "MatMulTransposeActWgt": MatMulTransposeActWgt(),
    "RTROptimize": RTROptimize(),
    "LinearSlice": LinearSlice(),
    "LinearSliceTranspose": LinearSliceTranspose(),
    # NOTE(review): the commented-out line below is stale -- the
    # "LinearPlusEWBinary" key IS registered further down in this dict.
    # Do not uncomment it (that would create a duplicate key).
    # "LinearPlusEWBinary" : LinearPlusEWBinary(), # Linear plus EW binary op advanced fusion
    "Lut": lut,  # basic flex
    "EWBinaryPlusRelu": EWBinaryPlusRelu(),
    "ReluPWLA": ReluPWLA(),
    "AddNoopSuffix": AddNoopSuffix(),
    "RowWiseOpToRuntime": RowWiseOpToRuntime(),
    "RemoveCategory": remove_category,
    "RenameQDQ": Rename_QDQ(),
    "ConversionDqQ": ConversionDqQ(),
    "TransposeOptQDQ": TransposeOptQDQ(),
    # Disabled entry; no ``BatchMatMulTranspose`` name is imported in this module.
    # "BatchMatMulTranspose": BatchMatMulTranspose(),
    "Batching": BatchingByLevel(),
    "RemoveSliceConcatRuntime": RemoveSliceConcatRuntime(),
    "LinearPlusEWBinary": LinearPlusEWBinary(),
    "UnfuseSkipLayerNormalization": UnfuseSkipLayerNormalization(),
    "UnfuseSimplifiedLayerNormalization": UnfuseSimplifiedLayerNormalization(),
    "SplitToSlice": SplitToSlice(),
    "MultiAxisSlice": MultiAxisSlice(),
    "SplitConcat": SplitConcat(),
    "MHA_RTRCancellation": MHA_RTRCancellation(),
    "CountTransposesBefore": TransposeCounterBefore(),
    # The only entry constructed with an argument (the string "MHA_RTR").
    "CountTransposes": TransposeCounter("MHA_RTR"),
    "CountTransposesAfter": TransposeCounterAfter(),
    "CountTransposesFinal": TransposeCounterFinal(),
    "RTRPointers": RTRStateMatcher(),
    "ConcatOpTypeChange": ConcatOpTypeChange(),
    "Reduce_qdq": ReduceQdq(),
    "SqueezeUnsqueezeToRTR": SqueezeUnsqueezeToRTR(),
    "MergeMultipleOutput": MergeMultipleOutput(),
    "ReshapeToRTR": ReshapeToRTR(),
    "Lp_NormIndividual": LpNormIndividual(),
}
