from typing import Any

from OGOAT.src.L1_fusion.py_match.checkers import opType
from OGOAT.src.L1_fusion.py_match.helpers.bias_helper import BinaryOpHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Matcher, Node, Tensor


class UnfuseSkipLayerNormalization(Matcher, BinaryOpHelper):
    """Split a SkipLayerNormalization node into explicit Add(s) + LayerNormalization.

    The matched node is replaced by::

        [bias_out = Add(input, bias)]            # only when a bias input exists
        add_out   = Add(bias_out or input, skip)
        output    = LayerNormalization(add_out, gamma, beta[, epsilon])
    """

    def match(self) -> None:
        """Match any node whose op type is SkipLayerNormalization."""
        n = self.n
        n.require(opType.SkipLayerNormalization)

    def _insert_add(self, name: str, lhs: Any, rhs: Any) -> Any:
        """Insert an ``ai.onnx`` ``Add`` node computing ``lhs + rhs`` and return it.

        Args:
            name: Name of the new node; its output tensor is named ``<name>_output``.
            lhs: First operand (``A``). Its dtype is used for the output tensor.
            rhs: Second operand (``B``).

        Returns:
            The node handle returned by ``add_node``; call it with ``"C"`` to get
            the output tensor.
        """
        n = self.n
        inputs = {"A": lhs, "B": rhs}
        outputs = {
            "C": Tensor(n._model_dict, n._walk_cfg, name + "_output", None)
        }
        # Output shape comes from BinaryOpHelper.get_output_shape applied to the
        # two actual operands (presumably their broadcast shape).
        shape = self.get_output_shape(
            lhs.require_tensor().get_shape(),
            rhs.require_tensor().get_shape(),
        )
        outputs["C"].require_tensor().set_shape(
            shape, lhs.require_tensor().get_dtype()
        )
        return self.add_node(
            type="Add",
            domain="ai.onnx",
            inputs=inputs,
            outputs=outputs,
            attributes={},
            new_name=name,
        )

    def modify(self) -> None:
        """Replace the matched node with the unfused Add(s) + LayerNormalization chain.

        The original ``output`` tensor is reused as the LayerNormalization output.
        ``epsilon`` is copied over when the original node carries it. Only the
        ``output`` tensor is reconnected.
        """
        n = self.n
        node = n.require_node()
        # Read the name before the original node is removed further down.
        base_name = n.get_name()

        # A fifth input means the optional ``bias`` is present (assumes the
        # input order input, skip, gamma, beta, bias — TODO confirm).
        has_bias = len(n.get_inputs()) == 5

        # Optional bias Add: input + bias.
        add_lhs = n("input")
        if has_bias:
            bias_node = self._insert_add(base_name + "_bias", n("input"), n("bias"))
            add_lhs = bias_node("C")

        # Skip Add: (input [+ bias]) + skip. The shape is derived from the
        # operands actually fed to this node.
        add_node = self._insert_add(base_name + "_add", add_lhs, n("skip"))

        # LayerNormalization over the summed tensor.
        layernorm_inputs = {
            "X": add_node("C"),
            "Scale": n("gamma"),
            "B": n("beta"),
        }
        layernorm_outputs = {"Y": n("output")}
        layernorm_attributes: dict[str, Any] = {}
        if "epsilon" in node.get_attributes():
            layernorm_attributes["epsilon"] = node.get_attribute_value("epsilon")

        self.remove_node(n)
        self.add_node(
            type="LayerNormalization",
            domain="ai.onnx.contrib",
            inputs=layernorm_inputs,
            outputs=layernorm_outputs,
            attributes=layernorm_attributes,
            new_name=base_name + "_layernorm",
        )


class UnfuseSimplifiedLayerNormalization(Matcher):
    """Rewrite SimplifiedLayerNormalization as LayerNormalization with ``simplified=True``."""

    def match(self) -> None:
        """Match any node whose op type is SimplifiedLayerNormalization."""
        n = self.n
        n.require(opType.SimplifiedLayerNormalization)

    def modify(self) -> None:
        """Replace the matched node with an equivalent LayerNormalization node.

        Domain, name, inputs (X, scale, B), output (Y) and all existing attributes
        are carried over unchanged. The only addition is the attribute
        ``simplified=True``.
        """
        node = self.n.require_node()

        domain = node.get_domain()
        # NOTE(review): "B" is forwarded as-is; confirm how a missing bias input
        # is represented by node("B").
        inputs = {"X": node("X"), "Scale": node("scale"), "B": node("B")}
        outputs = {"Y": node("Y")}
        # Copy first so the matched node's own attribute mapping is not mutated
        # (get_attributes() may hand back its live dict).
        attributes = dict(node.get_attributes())
        attributes["simplified"] = True
        new_name = node.get_name()

        self.remove_node(node)
        self.add_node(
            type="LayerNormalization",
            domain=domain,
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
            new_name=new_name,
        )
