# fmt: on
from typing import Any
from OGOAT.src.L1_fusion.py_match.basic.dataflow import DataflowAttributeGenerator
from OGOAT.src.L1_fusion.py_match.checkers import (
    CategoryCheck,
    DTypeAny,
    opType,
)
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Element,
    Matcher,
    Node,
    OutputTensor,
)


class MinMax(Matcher, QDQHelper):
    """Fuse a QDQ-wrapped variadic ``Min``/``Max`` node into one custom op.

    ``match`` validates the surrounding QDQ pattern via ``check_qdq`` and
    remembers the node's first schema output name. ``modify`` replaces the
    node with a single ``ai.onnx.contrib`` node whose type encodes the op
    name, the number of inputs and the output dtype.

    NOTE(review): ``match`` does not check the op type itself; it is
    presumably only reached for Min/Max nodes through the category mechanism
    (``CategoryCheck(MinMax())``) — confirm against the dispatcher.
    """

    def match(self) -> None:
        n = self.n.require_node()
        self.new_dtypes, self.qdq_attributes = self.check_qdq(n, DTypeAny())
        # Name of the node's first schema output; used in ``modify`` as the
        # prefix for the output-side quantization tensors and as the output key.
        self.orig_out_name = n.get_schema_output_names()[0]

    def modify(self) -> None:
        n = self.n.require_node()
        node_inputs = n.get_inputs()
        num_inputs = len(node_inputs)
        # new_dtypes is an "x"-separated dtype string; the last part is the
        # output dtype.
        output_tensor_dtype = self.new_dtypes.split("x")[-1]
        # The input count must be converted to str before concatenation
        # (adding the int directly raises TypeError).
        # NOTE(review): "_" + "_qdq_" produces a double underscore
        # (e.g. "Max2__qdq_..."); kept as written — confirm the expected
        # kernel name.
        new_type = (
            n.get_op_type() + str(num_inputs) + "_" + "_qdq_" + output_tensor_dtype
        )

        inputs: dict[str, Element] = {}
        for i, inp in enumerate(node_inputs):
            # Per-input tensors get the input index appended: x0, x_scale0, ...
            inputs[f"x{i}"] = inp("x")
            inputs[f"x_scale{i}"] = inp("x_scale")
            inputs[f"x_zero_point{i}"] = inp("x_zero_point")

        # Presumably navigates from the named output to the consuming
        # quantize node's scale / zero point — verify against Node.__call__.
        inputs["y_scale"] = n(self.orig_out_name + ".y_scale")
        inputs["y_zero_point"] = n(self.orig_out_name + ".y_zero_point")

        outputs = {self.orig_out_name: n(self.orig_out_name + ".y")}
        copy_attributes = n.get_attributes()
        copy_attributes["num_inputs"] = num_inputs
        copy_attributes["num_of_tensor_inputs"] = 1
        self.remove_node(n)

        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=copy_attributes | self.qdq_attributes,
        )


class Clip(Matcher, QDQHelper):
    """Fuse a QDQ-wrapped ``Clip`` node into a single ``Clip_qdq_*`` custom op.

    The new op type is ``"Clip_qdq_<first>x<last>"``, where ``<first>`` and
    ``<last>`` are the first and last ``x``-separated parts of the dtype
    string returned by ``check_qdq``.
    """

    def match(self):
        n = self.n.require_node()
        n.require(opType.Clip)
        self.new_dtypes, self.qdq_attributes = self.check_qdq(n, DTypeAny())

    def modify(self):
        # Narrow to a Node as match() and the other matchers' modify() do;
        # match() has already established that self.n is a node.
        n = self.n.require_node()
        dtype_parts = self.new_dtypes.split("x")
        new_type = "Clip_qdq_" + dtype_parts[0] + "x" + dtype_parts[-1]

        inputs, outputs = self.get_in_out_dict_for_qdq_node(n)

        copy_attributes = n.get_attributes()
        copy_attributes["num_of_tensor_inputs"] = 1
        self.remove_node(n)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=copy_attributes | self.qdq_attributes,
        )


class Relu(Matcher, QDQHelper):
    """Rewrite a QDQ-wrapped ``Relu`` into one ``Relu_qdq_<dtypes>`` node.

    ``match`` records the dtype string and QDQ attributes produced by
    ``check_qdq``. ``modify`` swaps the matched node for a custom op in the
    ``ai.onnx.contrib`` domain that carries the original attributes plus the
    QDQ attributes.
    """

    def match(self) -> None:
        relu = self.n
        relu.require(opType.Relu)
        self.new_dtypes, self.qdq_attributes = self.check_qdq(relu, DTypeAny())

    def modify(self) -> None:
        relu = self.n.require_node()
        fused_type = "Relu_qdq_" + self.new_dtypes

        # Inputs/outputs for the fused node, as built by the QDQ helper.
        fused_inputs, fused_outputs = self.get_in_out_dict_for_qdq_node(relu)

        attrs = relu.get_attributes()
        attrs["num_of_tensor_inputs"] = 1

        self.remove_node(relu)
        self.add_node(
            type=fused_type,
            domain="ai.onnx.contrib",
            inputs=fused_inputs,
            outputs=fused_outputs,
            attributes=attrs | self.qdq_attributes,
        )


class Resize(Matcher, QDQHelper):
    """Fuse a QDQ-wrapped ``Resize`` node into a ``Resize_qdq_*`` custom op.

    Constant ``roi`` and ``scales`` (or, failing that, ``sizes``) inputs are
    flattened into scalar attributes on the new node by
    ``generate_resize_attributes``.
    """

    def generate_resize_attributes(
        self, node: Node, input_names_to_gen_attr: list[str]
    ) -> dict[str, Any]:
        """Flatten constant initializer inputs of ``node`` into attributes.

        For every name in ``input_names_to_gen_attr`` whose connection has a
        non-empty name, the initializer array is read. Each element is stored
        under ``"<name>_<k>"``, with ``k`` counting from 1.

        Connections with an empty name are skipped. Presumably these are
        unset optional inputs — confirm.
        """
        attributes: dict[str, Any] = {}
        for in_name in input_names_to_gen_attr:
            input_tensor = node.get_connection(in_name)
            if not input_tensor.get_name():
                continue
            init_val = input_tensor.require_tensor().get_initializer_array()
            for k, value in enumerate(init_val, start=1):
                attributes[f"{in_name}_{k}"] = value

        return attributes

    @staticmethod
    def _initializer_or_none(tensor: Any) -> Any:
        """Return ``tensor`` when it is a constant initializer, else ``None``."""
        return tensor if tensor.check_initializer() else None

    def match(self) -> None:
        n = self.n
        n.require(opType.Resize)
        self.input = n("X")
        self.output = n("Y")
        # NOTE(review): the meaning of the trailing [1] is defined by
        # QDQHelper.check_qdq — confirm which tensors it selects.
        self.new_dtype, self.qdq_attributes = self.check_qdq(n, DTypeAny(), [1])

    def modify(self) -> None:
        n = self.n.require_node()
        new_type = n.get_op_type() + "_qdq_" + self.new_dtype

        # Take the tensor feeding the DequantizeLinear when one is present,
        # otherwise the Resize input itself.
        data = (
            self.input("x")
            if self.input.check(opType.DequantizeLinear)
            else self.input
        )
        # Quantization parameters are only forwarded when they are constants.
        inputs = {
            "data": data,
            "data_scale": self._initializer_or_none(self.input("x_scale")),
            "data_zero_point": self._initializer_or_none(self.input("x_zero_point")),
            "Y_scale": self._initializer_or_none(self.output("y_scale")),
            "Y_zero_point": self._initializer_or_none(self.output("y_zero_point")),
        }

        # Take the QuantizeLinear's output when one follows, otherwise the
        # Resize output itself.
        outputs = {
            "Y": (
                self.output("y")
                if self.output.check(opType.QuantizeLinear)
                else self.output
            )
        }

        # Prefer constant "scales"; fall back to constant "sizes".
        input_attributes = ["roi"]
        if n("scales").check_initializer():
            input_attributes.append("scales")
        elif n("sizes").check_initializer():
            input_attributes.append("sizes")
        else:
            # NOTE(review): this only prints and continues, so the fused node
            # gets no scales/sizes attributes — confirm this is intended.
            print(f"Not found scales or sizes for {n.get_name()}")

        attributes_from_input = self.generate_resize_attributes(n, input_attributes)
        copy_attributes = n.get_attributes()
        copy_attributes["num_of_tensor_inputs"] = 1
        self.remove_node(n)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=copy_attributes | attributes_from_input | self.qdq_attributes,
        )


def get_non_linear_output_tensor(node: Node) -> "OutputTensor | None":
    """Return the main output tensor of a supported non-linear node.

    Supported categories and the output each one yields:

    - MinMax: ``"min"`` if that output is connected, otherwise ``"max"`` if
      that output is connected.
    - Clip: ``"output"``.
    - Relu: ``"Y"``.

    Returns ``None`` when ``node`` matches none of these categories.
    """
    # Evaluate the MinMax category once; it is needed by two branches.
    is_min_max = node.check(CategoryCheck(MinMax()))
    if is_min_max and not node("min").check_nowhere():
        return node("min")
    if is_min_max and not node("max").check_nowhere():
        # Must use the same (lowercase) name checked above; ONNX names the
        # Max output "max", so "Max" does not address it.
        return node("max")
    if node.check(CategoryCheck(Clip())):
        return node("output")
    if node.check(CategoryCheck(Relu())):
        return node("Y")
    return None
