# fmt: on
from collections import defaultdict
from typing import Any
import numpy as np
from OGOAT.src.L1_fusion.py_match.helpers.batch_helper import BatchHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Element,
    InputTensor,
    Matcher,
    Node,
    NoMatch,
)
from OGOAT.src.L1_fusion.py_match.checkers import (
    AttrValue,
    CategoryCheck,
    FusedWithQDQNode,
    OpType,
)
from OGOAT.src.L1_fusion.py_match.basic.binary_op import binary_op
from OGOAT.src.L1_fusion.L1_utils.model_IR_utils import sanity_check


class CascadeAdd(Matcher, BatchHelper):
    """Replace a cascade of same-type Add nodes by a single batched node.

    The matcher anchors on the last Add of a cascade (an Add whose output
    "C" does not satisfy the same-op-type check), collects every operand
    feeding the cascade together with its scale and zero point, and
    rewrites the anchor into one node of the same op type that receives
    the batched operands and a ``num_batches`` attribute.
    """

    dependencies = [binary_op]

    def require_add_node(self, node: Element) -> None:
        """Require that *node* is a QDQ-fused binary op whose orig_type is Add."""
        node = node.require_node()
        node.require(CategoryCheck(binary_op))
        node.require(FusedWithQDQNode())
        node.require(AttrValue("orig_type", "Add"))

    def get_inputs(self) -> list[tuple[InputTensor, InputTensor, InputTensor]]:
        """
        Collect the operands of the Add cascade that ends at the matched node.

        Returns a list of tuples. Each tuple consists of the input which needs
        to be summed up, the scale and the zero_point for that input.
        The list is ordered such that if the operands are added from front to
        back, the execution order is the same as in the original graph.

        An operand that passes the same-op-type check is only descended into
        when its tensor has exactly one reader. Otherwise the tensor is kept
        as a regular operand of the batched add, so that no summand is lost.
        """
        pending = [self.n]

        inputs: list[tuple[InputTensor, InputTensor, InputTensor]] = []
        check_same_op_type = OpType(self.n.get_op_type())
        while pending:
            current_node = pending.pop()

            # at the end of the iteration, True means that both inputs to the
            # current_node are not Add nodes of the same type
            non_add_inputs = True

            # "A" is handled before "B"; together with the LIFO stack and the
            # final reverse() this yields the documented operand order.
            for port in ("A", "B"):
                operand = current_node(port)
                if operand.check(check_same_op_type):
                    non_add_inputs = False
                    if len(operand.get_readers()) == 1:
                        # sole reader: fold the producing Add into the batch
                        pending.append(operand.get_non_tensor())
                        continue
                    # More than one reader: the tensor is not descended
                    # into, but it still has to take part in the sum, so it
                    # falls through and is recorded as an operand.
                inputs.append(
                    (
                        operand,
                        current_node(f"{port}_scale"),
                        current_node(f"{port}_zero_point"),
                    )
                )

            if non_add_inputs and pending:
                # two non-Add inputs not at the end: input is not linear
                sanity_check(
                    False,
                    f"input order for node {self.n.get_name()} is not linear",
                    "Warning",
                )

        inputs.reverse()
        return inputs

    def batch_input(self) -> list[Node]:
        """Build the batched data, scale and zero_point inputs.

        Relies on ``self.inputs`` as populated by ``match``. The returned
        list holds three entries in this order: the result of
        ``create_concat_runtime_node`` for the data operands, then the
        results of ``add_initializer`` for the concatenated scales and for
        the concatenated zero points.
        """
        n = self.n
        batch_inputs = []

        initializer_value_arrays = defaultdict[str, dict[str, Any]](
            lambda: {"initializer_vals": [], "dtype": None, "name": None}
        )

        for input_type, input_idx in [("A", 0), ("scale", 1), ("zero_point", 2)]:
            # pick the data / scale / zero_point member of every operand tuple
            inputs = {
                f"inputs{i}": operand[input_idx]
                for i, operand in enumerate(self.inputs)
            }

            if input_type == "A":
                batched = self.create_concat_runtime_node(n, input_type, inputs)
            else:
                node_name = n.get_attribute_value("orig_name")
                # fills initializer_value_arrays[input_type] in place
                self.get_init_vals(
                    node_name, inputs, initializer_value_arrays, input_type
                )

                val_dict = initializer_value_arrays[input_type]
                batched = self.add_initializer(
                    val_dict["name"],
                    np.concatenate(val_dict["initializer_vals"]),
                    val_dict["dtype"],
                )
            batch_inputs.append(batched)

        return batch_inputs

    def match(self) -> None:
        """Match the last Add of a cascade that has more than two operands.

        Raises:
            NoMatch: if output "C" passes the same-op-type check (the node is
                not the last node of the adder tree) or if at most two
                operands were collected.
        """
        n = self.n
        self.require_add_node(n)

        if n("C").check(OpType(n.get_op_type())):
            raise NoMatch("not the last node of adder tree")

        self.inputs = self.get_inputs()
        if len(self.inputs) <= 2:
            raise NoMatch("no batch add")
        self.set_batch_dimension(len(n("A").get_shape()))

    def modify(self) -> None:
        """Replace the matched node by one batched node of the same op type."""
        n = self.n

        concat_nodes = self.batch_input()
        new_type = n.get_op_type()
        inputs = {
            "A": concat_nodes[0]("concat_result"),
            "A_scale": concat_nodes[1],
            "A_zero_point": concat_nodes[2],
            # the output quantization parameters of the matched node are kept
            "output_scale": n("C_scale"),
            "output_zero_point": n("C_zero_point"),
        }

        outputs = {
            "output": n("C"),
        }
        attributes = n.get_attributes()
        attributes["num_batches"] = len(self.inputs)
        self.remove_node(n)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=attributes,
        )
