from typing import Any
from OGOAT.src.L1_fusion.py_match.adv.matmul_transpose import (
    MatMulTranspose,
    MatMulTranspose4D,
)
from OGOAT.src.L1_fusion.py_match.checkers import CategoryCheck
from OGOAT.src.L1_fusion.py_match.helpers.batch_helper import BatchHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Matcher, WalkCfgPlain


class BatchMatMulTranspose(Matcher, BatchHelper):
    """Fuse a group of batchable MatMulTranspose(-4D) nodes into a single node.

    The anchor node must already be categorized as ``MatMulTranspose`` or
    ``MatMulTranspose4D``. Its batch siblings are collected with
    ``BatchHelper.get_batch_nodes``. The whole group is then replaced by one
    node in the ``ai.onnx.contrib`` domain. That node keeps the anchor's op type
    and attributes, consumes the concatenated batch inputs, and produces the
    split batch outputs.
    """

    # Every category required in ``match`` is listed here. Previously only
    # ``MatMulTranspose`` was listed, although ``MatMulTranspose4D`` nodes are
    # matched too. NOTE(review): this presumably makes the framework run the
    # 4D matcher before this one; confirm against Matcher's dependency handling.
    dependencies = [MatMulTranspose(), MatMulTranspose4D()]

    def match(self):
        """Require a MatMulTranspose(-4D) anchor and collect its batch group.

        Stores the collected nodes in ``self.batched_node_list`` for use by
        ``modify``.
        """
        n = self.n.with_walk_cfg(WalkCfgPlain())
        n.require(
            CategoryCheck(MatMulTranspose()) | CategoryCheck(MatMulTranspose4D())
        )
        # NOTE(review): the meaning of the boolean flag is defined by
        # BatchHelper.get_batch_nodes; it is not visible here.
        self.batched_node_list = self.get_batch_nodes(n, True)

    def modify(self):
        """Replace the batched nodes collected in ``match`` with one fused node.

        The new node's name is ``"BatchMatMulTranspose"`` followed by the
        anchor's ``orig_name`` attribute.
        """
        n = self.n
        # Read everything needed from the existing nodes before they are
        # removed below; the helpers and the anchor are queried first.
        inputs = self.get_concated_batch_inputs(self.batched_node_list)
        outputs = self.get_splitted_batch_outputs(self.batched_node_list)
        new_type = n.get_op_type()
        new_name = "BatchMatMulTranspose" + n.get_attribute_value("orig_name")
        attributes = n.get_attributes()
        for node in self.batched_node_list:
            self.remove_node(node)

        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
            new_name=new_name,
        )
