# fmt: on
from OGOAT.src.L1_fusion.py_match.checkers import AttrValue, opType
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Matcher, MatcherError


class GemmMutliplier(Matcher):
    """Fold a Gemm's ``alpha``/``beta`` factors into upstream dequantize scales.

    ONNX Gemm computes ``Y = alpha * A' * B' + beta * C``. When ``A`` (or ``C``)
    is produced by a DequantizeLinear whose scale input is an initializer, the
    factor is multiplied into that scale and the Gemm attribute is reset to
    1.0, leaving the Gemm numerically unchanged.
    """

    def match(self) -> None:
        """Match a Gemm with a non-unit ``alpha`` and/or ``beta`` that can be folded.

        Stores the attribute values in ``self.alpha`` / ``self.beta`` for
        :meth:`modify`.

        Raises:
            MatcherError: when both factors are already 1.0 (nothing to fold);
                the ``require*`` calls are expected to raise when a non-unit
                factor's input is not a DequantizeLinear with an initializer
                scale.
        """
        n = self.n
        n.require(opType.Gemm)

        # NOTE(review): assumes get_attribute_value returns the ONNX default
        # (1.0) when the attribute is absent -- confirm; a None here would be
        # treated as "non-unit".
        self.alpha = n.get_attribute_value("alpha")
        self.beta = n.get_attribute_value("beta")
        if self.alpha == 1.0 and self.beta == 1.0:
            # Without this the matcher would "succeed" on every Gemm while
            # modify() does nothing.
            raise MatcherError("alpha and beta are already 1.0")

        if self.alpha != 1.0:
            n("A").require(opType.DequantizeLinear)
            # DequantizeLinear's scale input is named x_scale (y_scale is the
            # QuantizeLinear input). Validate it here so modify() cannot fail
            # half-way through the rewrite.
            n("A.x_scale").require_initializer()

        if self.beta != 1.0:
            n("C").require(opType.DequantizeLinear)
            n("C.x_scale").require_initializer()

    def modify(self) -> None:
        """Multiply the non-unit factors into the dequant scales and reset them to 1.0.

        NOTE(review): the scale initializer is scaled in place; if it is shared
        with other DequantizeLinear nodes they will be affected too -- confirm
        this cannot happen or that multiply() copies on write.
        """
        n = self.n
        if self.alpha != 1.0:
            initializer = n("A.x_scale").require_initializer()
            initializer.multiply(self.alpha)
            n.set_attribute("alpha", 1.0)

        if self.beta != 1.0:
            initializer = n("C.x_scale").require_initializer()
            initializer.multiply(self.beta)
            n.set_attribute("beta", 1.0)


class TransposeGemmInitializer(Matcher):
    """Pre-transpose constant Gemm inputs so ``transA``/``transB`` can be cleared.

    When a Gemm reads ``A`` (or ``B``) directly from an initializer and has the
    corresponding ``trans*`` attribute set to 1, a transposed copy of that
    initializer is added and wired into the Gemm in place of the original, and
    the attribute is set to 0.
    """

    def match(self) -> None:
        """Record in ``transpose_a`` / ``transpose_b`` which inputs can be pre-transposed.

        A side qualifies when the Gemm has ``trans*`` == 1, its paired factor
        (``alpha`` for A, ``beta`` for B) equals 1.0, and the input is an
        initializer.

        Raises:
            MatcherError: if neither input qualifies.
        """
        n = self.n
        n.require(opType.Gemm)
        # alpha/beta are attributes of the Gemm node, so they are checked on
        # `n` itself (as TransposeGemmQdqInitializer does), not on the input
        # tensor, which carries no attributes.
        self.transpose_a = (
            n.check(AttrValue("transA", 1))
            and n.check(AttrValue("alpha", 1.0))
            and n("A").check_initializer()
        )
        self.transpose_b = (
            n.check(AttrValue("transB", 1))
            and n.check(AttrValue("beta", 1.0))
            and n("B").check_initializer()
        )
        if not self.transpose_a and not self.transpose_b:
            raise MatcherError("no transposed input")

    def modify(self) -> None:
        """Replace each selected input with a transposed copy and clear its ``trans*`` flag."""
        n = self.n

        if self.transpose_a:
            new_name = n("A").get_name() + "_trans"
            initializer_new = self.add_transposed_initializer(n("A"), new_name)
            self.replace_input(n, n("A"), initializer_new)
            n.set_attribute("transA", 0)

        if self.transpose_b:
            new_name = n("B").get_name() + "_trans"
            # Transpose B itself (the original code transposed A here).
            initializer_new = self.add_transposed_initializer(n("B"), new_name)
            self.replace_input(n, n("B"), initializer_new)
            n.set_attribute("transB", 0)


class TransposeGemmQdqInitializer(Matcher):
    """Fold ``transA``/``transB`` into constant weights behind a DequantizeLinear.

    Targets a Gemm whose ``A`` (or ``B``) input is produced by a DequantizeLinear
    node whose ``x`` input is an initializer. A transposed copy of that
    initializer is added and wired into the DequantizeLinear, and the Gemm's
    matching ``trans*`` attribute is reset to 0.
    """

    def match(self) -> None:
        """Record in ``transpose_a`` / ``transpose_b`` which sides can be folded.

        A side qualifies only when the Gemm has ``trans*`` == 1, the paired
        factor (``alpha`` for A, ``beta`` for B) equals 1.0, the input comes
        from a DequantizeLinear, and that node's ``x`` input is an initializer.

        Raises:
            MatcherError: if neither side qualifies.
        """
        gemm = self.n
        gemm.require(opType.Gemm)

        def qualifies(dq_path, weight_path, trans_attr, scale_attr):
            # Checks are short-circuited in the same order for both sides.
            return (
                gemm.check(AttrValue(trans_attr, 1))
                and gemm.check(AttrValue(scale_attr, 1.0))
                and gemm(dq_path).check(opType.DequantizeLinear)
                and gemm(weight_path).check_initializer()
            )

        self.transpose_a = qualifies("A", "A.x", "transA", "alpha")
        self.transpose_b = qualifies("B", "B.x", "transB", "beta")
        if not (self.transpose_a or self.transpose_b):
            raise MatcherError("no transposed input")

    def modify(self) -> None:
        """Apply the transposes selected by :meth:`match`, handling A before B.

        NOTE(review): if the DequantizeLinear is per-axis quantized, its
        ``axis`` presumably needs updating once ``x`` is transposed; and if the
        DequantizeLinear output also feeds other nodes, they will see the
        transposed weight too -- confirm both cases are excluded elsewhere.
        """
        gemm = self.n
        sides = (
            (self.transpose_a, "A", "A.x", "transA"),
            (self.transpose_b, "B", "B.x", "transB"),
        )
        for selected, dq_path, weight_path, trans_attr in sides:
            if not selected:
                continue
            transposed_name = gemm(weight_path).get_name() + "_trans"
            transposed = self.add_transposed_initializer(
                gemm(weight_path), transposed_name
            )
            self.replace_input(gemm(dq_path), gemm(weight_path), transposed)
            gemm.set_attribute(trans_attr, 0)
