# fmt: on
import re
from OGOAT.src.L1_fusion.py_match.checkers import CategoryCheck, opType, DTypeAny
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import (
    QDQHelper,InitName
)
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Element,
    Initializer,
    Matcher,
    Node,
    NoMatch,
    Tensor,
)

from OGOAT.src.L1_fusion.py_match.helpers.rowwise_helper import RowWiseHelper

class SplitConcat(Matcher, QDQHelper, RowWiseHelper):
    """
    Split concat nodes with more than two inputs and different scale and zero_points into cascading concat nodes with two inputs

    For a Concat with inputs ``input0 .. inputN`` the rewrite emits ``N``
    two-input Concat nodes. The first one joins ``input0`` and ``input1``;
    every following one joins the previous intermediate result with the next
    original input. Each intermediate result is passed through a freshly added
    QuantizeLinear/DequantizeLinear pair that reuses the scale and zero point
    found at the original concat output. The last new Concat takes over the
    output tensor of the original node, which is removed.
    """

    def compute_concat_shape(self, node: Node) -> list[int]:
        """
        Return the output shape of the Concat ``node``.

        The result equals the shape of the first input, except that the entry
        at the node's ``axis`` attribute is the sum of that entry over all
        inputs.
        """
        inputs = node.get_inputs()
        axis = node.get_attribute_value("axis")
        # Work on a copy: writing into the list returned by get_shape() could
        # otherwise alter the shape recorded for the first input tensor.
        shape = list(inputs[0].get_shape())
        shape[axis] = sum(t.get_shape()[axis] for t in inputs)
        return shape

    def add_qdq(
        self, input_tensor: Tensor, scale: Initializer, zero_point: Initializer
    ) -> Tensor:
        """
        Append a QuantizeLinear -> DequantizeLinear pair behind ``input_tensor``.

        Both nodes use the given ``scale`` and ``zero_point``. The new tensors
        and nodes are named after ``input_tensor`` (``_quantize`` /
        ``_dequantize`` suffixes, with ``_out`` appended for the tensors).

        Returns the output tensor of the added DequantizeLinear node.
        """
        base_name = input_tensor.get_name()
        shape = input_tensor.get_shape()

        q_name = base_name + "_quantize"
        q_output = Tensor(
            input_tensor._model_dict, input_tensor._walk_cfg, q_name + "_out"
        )
        # The quantized tensor carries the dtype of the zero point.
        q_output.set_shape(shape, zero_point.get_dtype())
        self.add_node(
            type="QuantizeLinear",
            domain="ai.onnx",
            inputs={"x": input_tensor, "y_scale": scale, "y_zero_point": zero_point},
            outputs={"y": q_output},
            attributes={"orig_name": q_name},
        )

        dq_name = base_name + "_dequantize"
        dq_output = Tensor(
            input_tensor._model_dict, input_tensor._walk_cfg, dq_name + "_out"
        )
        # Dequantizing yields the de-quantized (pre-quantization) dtype again,
        # i.e. the dtype of the tensor that entered the Q/DQ pair - not the
        # quantized dtype of the zero point.
        dq_output.set_shape(shape, input_tensor.get_dtype())
        self.add_node(
            type="DequantizeLinear",
            domain="ai.onnx",
            inputs={"x": q_output, "x_scale": scale, "x_zero_point": zero_point},
            outputs={"y": dq_output},
            attributes={"orig_name": dq_name},
        )

        return dq_output

    def match(self) -> None:
        """
        Accept Concat nodes that have to be split.

        Raises NoMatch when the node is not a Concat, has fewer than three
        inputs, is not reported by ``check_input_output_qdq`` as having qdq at
        all inputs/outputs, is reported as having identical qdq parameters
        everywhere, or is classified as rowwise by ``is_rowwise_op``.
        """
        n = self.n
        n.require(opType.Concat)
        inputs = n.get_inputs()
        # Remembered for modify(), which creates number_inputs - 1 concats.
        self.number_inputs = len(inputs)
        if self.number_inputs < 3:
            raise NoMatch("Splitting only works for Concat with at least 3 inputs")

        has_qdq, has_same = self.check_input_output_qdq(n)
        if not has_qdq:
            raise NoMatch("Pattern only supports nodes with qdq at all inputs/outputs")
        if has_same:
            raise NoMatch("All same qdq parameters")
        if self.is_rowwise_op(
            axis=n.get_attribute_value("axis"),
            inputs=inputs,
            node_name=n.get_name(),
        ):
            raise NoMatch("Do not split rowwise quantized ops")

    def modify(self) -> None:
        """
        Replace the matched Concat by a cascade of two-input Concat nodes.

        Relies on ``self.number_inputs`` having been set by ``match``.
        """
        n = self.n.require_node()

        # Lookups on the original node that do not change while the cascade
        # is being built.
        original_inputs = n.get_inputs_dict()
        result_name = n("concat_result").get_name()
        out_scale = n("concat_result.y_scale")
        out_zero_point = n("concat_result.y_zero_point")

        next_input_tensor = original_inputs["input0"]
        # All new concat outputs get the dtype of the first original input.
        dtype = next_input_tensor.get_dtype()

        last_index = self.number_inputs - 2
        concat = None
        for i in range(self.number_inputs - 1):
            stage_inputs = {
                "input0": next_input_tensor,
                "input1": original_inputs[f"input{i + 1}"],
            }

            next_input_tensor = Tensor(n._model_dict, n._walk_cfg, f"{result_name}_{i}")

            concat = self.add_node(
                type=n.get_op_type(),
                domain=n.get_domain(),
                inputs=stage_inputs,
                outputs={"concat_result": next_input_tensor},
                attributes=n.get_attributes() | {"orig_name": n.get_name() + f"_{i}"},
            )

            next_input_tensor.set_shape(self.compute_concat_shape(concat), dtype)

            # last added concat does not need qdq nodes added at the output
            if i != last_index:
                next_input_tensor = self.add_qdq(
                    next_input_tensor, out_scale, out_zero_point
                )

        last_output_tensor = n("concat_result")
        self.remove_node(n)
        # Hand the original output tensor over to the final concat of the
        # cascade so that existing consumers stay connected.
        self.replace_output(concat, concat("concat_result"), last_output_tensor)


class Concat(Matcher, QDQHelper):
    """
    Replace a Concat node by a fused ``Concat<N>_qdq_<dtype>`` node.

    The fused node lives in the ``ai.onnx.contrib`` domain and receives, in
    this order: all data inputs, then scale and zero point for every input,
    then scale and zero point for the output. Concat nodes for which
    ``check_input_output_qdq`` reports identical qdq parameters everywhere are
    not matched.
    """

    dependencies = [SplitConcat()]

    def match(self) -> None:
        """
        Require a Concat node and collect what ``modify`` needs.

        Stores the inputs, their count, the result of ``check_qdq``
        (``new_dtypes`` and ``qdq_attributes``), the ``concat_result`` element
        and the ``has_qdq`` flag on ``self``.

        Raises NoMatch if the node is not a Concat or if all qdq parameters
        are the same.
        """
        n = self.n
        n.require(opType.Concat)
        self.inputs = n.get_inputs()
        self.inp_cnt = len(self.inputs)
        self.new_dtypes, self.qdq_attributes = self.check_qdq(n, DTypeAny())
        self.output = n("concat_result")
        self.has_qdq, has_same_qdq = self.check_input_output_qdq(n)
        if has_same_qdq:
            raise NoMatch("All same qdq parameters should be removed around concat")

    def modify(self) -> None:
        """Remove the matched Concat and add the fused replacement node."""
        n = self.n
        dequantize = opType.DequantizeLinear
        concat_inputs = n.get_inputs()
        inputs: dict[str, Element] = {}

        # First pass: data inputs. Kept separate from the qdq parameters so
        # that all data inputs come first in the fused node's input order.
        for i, tensor in enumerate(concat_inputs):
            # For inputs passing the DequantizeLinear check, "x" is used
            # (presumably the quantized tensor feeding the DQ node - confirm
            # against the Element navigation semantics).
            inputs[f"input_{i}"] = tensor("x") if tensor.check(dequantize) else tensor

        # Second pass: per-input scale and zero point.
        for i, tensor in enumerate(concat_inputs):
            if tensor.check(dequantize):
                scale = tensor("x_scale")
                zero_point = tensor("x_zero_point")
            else:
                # No DQ at this input: let the helper pick the initializer or
                # a dummy replacement.
                scale = self._get_initializer_or_dummy(
                    tensor("x_scale"), n, InitName.SCALE
                )
                zero_point = self._get_initializer_or_dummy(
                    tensor("x_zero_point"), n, InitName.SCALE_ZERO_POINT
                )
            inputs[f"input_{i}_scale"] = scale
            inputs[f"input_{i}_zero_point"] = zero_point

        if self.output.check(opType.QuantizeLinear):
            # Output passes the QuantizeLinear check: take over its parameters
            # and write directly to its "y" output.
            inputs["concat_result_scale"] = n("concat_result.y_scale")
            inputs["concat_result_zero_point"] = n("concat_result.y_zero_point")
            outputs = {"concat_result": n("concat_result.y")}
        else:
            output = n.get_outputs()[0]
            inputs["concat_result_scale"] = self._get_initializer_or_dummy(
                output("y_scale"), n, InitName.OUTPUT_SCALE
            )
            inputs["concat_result_zero_point"] = self._get_initializer_or_dummy(
                output("y_zero_point"), n, InitName.OUTPUT_ZERO_POINT
            )
            outputs = n.get_outputs_dict()

        dtype = n.get_model_activation_dtype()
        if self.has_qdq:
            # The last "x"-separated component of new_dtypes names the op
            # (presumably the output dtype - confirm against check_qdq).
            dtype = self.new_dtypes.split("x")[-1]

        new_type = f"Concat{self.inp_cnt}" + "_qdq_" + dtype

        # Build a new attribute dict instead of writing into the one returned
        # by the node, mirroring how SplitConcat merges attributes.
        fused_attributes = (
            n.get_attributes()
            | {"num_inputs": self.inp_cnt, "num_of_tensor_inputs": self.inp_cnt}
            | self.qdq_attributes
        )

        self.remove_node(n)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=fused_attributes,
        )


class ConcatOpTypeChange(Matcher):
    """
    Drop the digits that directly follow the leading ``Concat`` in the op
    type of nodes matched by the ``Concat`` category, e.g.
    ``Concat3_qdq_<dtype>`` becomes ``Concat_qdq_<dtype>``.
    """

    dependencies = [Concat()]

    def match(self) -> None:
        """Require that the current node belongs to the ``Concat`` category."""
        category = CategoryCheck(Concat())
        self.n.require(category)

    def modify(self) -> None:
        """Rewrite the op type of the matched node in place."""
        node = self.n.require_node()
        current_type = node.get_op_type()
        # Only a prefix anchored at the start of the op type is rewritten.
        generic_type = re.sub(r"^Concat\d*", "Concat", current_type)
        node.change_op_type(generic_type)
