# fmt: on
from itertools import accumulate
from typing import NamedTuple
import onnx

import numpy as np
from OGOAT.src.L1_fusion.py_match.checkers import opType, DTypeAny
from OGOAT.src.L1_fusion.py_match.helpers.dataflow_attribute_generator import (
    DataflowAttributeGenerator,
)
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Initializer,
    Matcher,
    MatcherError,
    NoMatch,
    Node,
)
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import (
    QDQHelper, InitName
)
from OGOAT.src.L1_fusion.py_match.helpers.transpose_helper import TransposeHelper
from OGOAT.src.L1_fusion.py_match.helpers.noop_helper import NoopHelper
from OGOAT.src.L1_fusion.py_match.py_match_utils import (
    get_value_from_dequantize_linear,
)

class Dataflow(
    Matcher,
    DataflowAttributeGenerator,
    TransposeHelper,
    QDQHelper,
    NoopHelper,
):
    """Fuse a data-movement op and its surrounding DQ/Q nodes into one node.

    Handles Gather, GatherElements, Slice, DepthToSpace and Pad. The fused node
    is emitted in the ``ai.onnx.contrib`` domain with type
    ``<OpType>_qdq_<dtype>``, where ``<dtype>`` is the value returned by
    ``check_qdq`` during matching.
    """

    def _get_pad_constant_value_initializer(
        self,
        n: Node,
    ) -> "Initializer | None":
        """Return an initializer holding the Pad node's constant value, if any.

        When ``constant_value`` is produced by a DequantizeLinear, its
        dequantized value is materialized as a new FLOAT initializer named
        ``<dq output name>_const``. When it is already an initializer it is
        returned unchanged.

        Args:
            n: The Pad node being fused.

        Returns:
            The constant-value initializer, or ``None`` when the pad mode is not
            ``"constant"`` or no usable ``constant_value`` input is present.
        """
        mode = n.get_attribute_value("mode")
        # The ONNX Pad spec defaults ``mode`` to "constant". Treat a missing
        # attribute the same way so an explicit constant_value is not dropped.
        # NOTE(review): assumes get_attribute_value returns None for an absent
        # attribute; confirm against Node.get_attribute_value.
        if mode is not None and mode != "constant":
            return None

        constant_value = n("constant_value")
        if constant_value.check(opType.DequantizeLinear):
            value = get_value_from_dequantize_linear(constant_value)
            return self.add_initializer(
                initializer_name=constant_value.get_name() + "_const",
                value=value,
                dtype=onnx.TensorProto.FLOAT,
            )
        if constant_value.check_initializer():
            return constant_value
        return None

    def match(self) -> None:
        """Match a supported dataflow op and record its DQ/Q context.

        Rejects no-op Slices. Requires initializer ``starts``/``ends`` for Slice
        and initializer ``pads`` for Pad, and a tensor ``indices`` for
        Gather/GatherElements. Sets ``self.input``, ``self.output``,
        ``self.has_dq``, ``self.has_q``, ``self.new_dtype`` and
        ``self.qdq_attributes`` for use by :meth:`modify`.

        Raises:
            NoMatch: if the node is a Slice that ``is_noop_slice`` reports as a
                no-op. The ``require*`` helpers may also reject the match.
        """
        n = self.n
        n.require(
            opType.Gather
            | opType.GatherElements
            | opType.Slice
            | opType.DepthToSpace
            | opType.Pad
        )

        if n.check(opType.Slice):
            if self.is_noop_slice(n):
                raise NoMatch("Slice_qdq cannot be a no operation")
            else:
                n("starts").require_initializer()
                n("ends").require_initializer()
        elif n.check(opType.Pad):
            n("pads").require_initializer()
        elif n.check(opType.Gather | opType.GatherElements):
            n("indices").require_tensor()

        # For all the dataflow ops, always the first input tensor is the data tensor
        self.input = n.get_inputs()[0]
        self.output = n.get_outputs()[0]
        self.has_dq = self.input.check(opType.DequantizeLinear)
        self.has_q = self.output.check(opType.QuantizeLinear)

        # Only the data tensor (input 0) should drive the dtype check. The list
        # presumably names input positions to exclude from that check:
        # - Gather/GatherElements ``indices``, which may be an initializer or a
        #   tensor;
        # - Slice ``starts``/``ends``/``axes``;
        # - Pad ``pads``/``constant_value``.
        # NOTE(review): confirm the exact meaning against QDQHelper.check_qdq.
        self.new_dtype, self.qdq_attributes = self.check_qdq(n, DTypeAny(), [1, 2, 3])

        if n.check(opType.DepthToSpace):
            # Replace input/output with the tensors returned by
            # require_nchw_conversion (presumably looking through layout
            # transposes; confirm in TransposeHelper), then require matching
            # scale/zero-point across any QDQ pair found there.
            self.input, self.output = self.require_nchw_conversion(
                self.input, self.output
            )
            if self.input.has_qdq():
                self.require_qdq_equal_scale_zeropoint(
                    n("input"), self.input.quantize_node
                )
            if self.output.has_qdq():
                self.require_qdq_equal_scale_zeropoint(
                    self.output.dequantize_node, n("output")
                )

    def modify(self) -> None:
        """Replace the matched node with a fused ``<OpType>_qdq_<dtype>`` node.

        The data input is taken from the DequantizeLinear's ``x`` when one
        feeds the op, and the output from the QuantizeLinear's ``y`` when one
        consumes it. Missing scale/zero-point inputs are filled via
        ``_get_initializer_or_dummy``. The original node's attributes are copied
        and merged with ``num_of_tensor_inputs``, any Slice attributes and the
        QDQ attributes collected in :meth:`match`.
        """
        n = self.n.require_node()
        op_type = n.get_op_type()
        node_name = n.get_name()

        inputs = {
            "data": self.input("x") if self.has_dq else self.input,
            "indices": None,
            "pads": None,
            "data_scale": self._get_initializer_or_dummy(
                self.input("x_scale"), n, InitName.SCALE
            ),
            "data_zero_point": self._get_initializer_or_dummy(
                self.input("x_zero_point"), n, InitName.SCALE_ZERO_POINT
            ),
            "output_scale": self._get_initializer_or_dummy(
                self.output("y_scale"), n, InitName.OUTPUT_SCALE
            ),
            "output_zero_point": self._get_initializer_or_dummy(
                self.output("y_zero_point"), n, InitName.OUTPUT_ZERO_POINT
            ),
        }
        outputs = {"output": (self.output("y") if self.has_q else self.output)}

        # Op-specific extra inputs; unused slots stay None.
        if n.check(opType.Gather | opType.GatherElements):
            inputs["indices"] = n("indices")
        elif n.check(opType.Pad):
            inputs["pads"] = n("pads")
            inputs["constant_value"] = self._get_pad_constant_value_initializer(n)

        attributes = {"num_of_tensor_inputs": 1}

        if n.check(opType.Slice):
            attributes |= self.generate_slice_attributes(n)

        copy_attributes = n.get_attributes()
        new_type = op_type + "_qdq_" + self.new_dtype
        self.remove_node(n)
        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=copy_attributes | attributes | self.qdq_attributes,
            new_name=node_name + "_" + new_type,
        )


class SliceInitializers(NamedTuple):
    """Per-output parameters for an ONNX Slice node, as plain integer lists."""

    # Start offsets, one per sliced axis.
    starts: list[int]
    # Exclusive end offsets, one per sliced axis.
    ends: list[int]
    # Axes the starts/ends above apply to.
    axes: list[int]


class SplitToSlice(Matcher):
    """Rewrite an ONNX Split node into one ``ai.onnx`` Slice node per output."""

    @staticmethod
    def compute_slice_initializers_from_split(
        output_shapes: list[list[int]], axis: int, has_split: bool = False
    ) -> list[SliceInitializers]:
        """Compute Slice node initializers from Split node parameters.

        Output ``i`` becomes a Slice over ``axis`` covering
        ``[sum(sizes[:i]), sum(sizes[:i + 1]))``, where ``sizes`` are the
        outputs' extents along ``axis``.

        Args:
            output_shapes: Static shape of every Split output, in output order.
            axis: Split axis. Negative values count from the back, as the ONNX
                Split spec allows (``[-rank, rank - 1]``). The value is copied
                unchanged into each result's ``axes``.
            has_split: True when the Split has an explicit ``split`` initializer.
                In that case the cross-output check of the non-split dimensions
                is skipped.

        Returns:
            One ``SliceInitializers`` per output, in output order.

        Raises:
            MatcherError: if there are no outputs, ranks differ, ``axis`` is out
                of range, or (when ``has_split`` is False) any non-split
                dimension differs between outputs.
        """
        if not output_shapes:
            raise MatcherError("Split must have at least one output shape")

        rank = len(output_shapes[0])
        if not all(len(shape) == rank for shape in output_shapes):
            raise MatcherError("All output shapes must have the same rank")

        if not -rank <= axis < rank:
            raise MatcherError(f"Split axis {axis} is out of range for rank {rank}")
        # Normalize so a negative axis is still recognized as the split
        # dimension in the validation loop below.
        split_dim = axis + rank if axis < 0 else axis

        # Validate that all dimensions except the split axis are identical
        if not has_split:
            for dim in range(rank):
                if dim == split_dim:
                    continue
                first_dim_size = output_shapes[0][dim]
                if not all(shape[dim] == first_dim_size for shape in output_shapes):
                    raise MatcherError(
                        f"All output shapes must have identical dimensions except along axis {axis}"
                    )

        split_sizes = [shape[split_dim] for shape in output_shapes]

        # Build offset positions: [0, first_size, first_size+second_size, ...]
        # Example: split_sizes=[512, 512, 512] → offsets=[0, 512, 1024, 1536]
        offsets = [0, *accumulate(split_sizes)]
        return [
            SliceInitializers(starts=[start], ends=[end], axes=[axis])
            for start, end in zip(offsets, offsets[1:])
        ]

    def match(self) -> None:
        """Match any Split node."""
        n = self.n.require_node()
        n.require(opType.Split)

    def modify(self) -> None:
        """Replace the Split with Slice nodes writing to the original outputs.

        Slice ``i`` is named ``<split name>_slice_<i>``. Its int64 ``starts``,
        ``ends`` and ``axes`` initializers are named
        ``<split name>_slice_<i>_{starts,ends,axes}``.
        """
        n = self.n.require_node()
        split_node_name = n.get_name()
        has_split = n("split").check_initializer()

        # Get split configuration from attributes
        axis = n.get_attribute_value("axis")
        if axis is None:
            # The ONNX Split spec defaults ``axis`` to 0.
            # NOTE(review): assumes a missing attribute reads back as None.
            axis = 0

        output_tensors = [output.require_tensor() for output in n.get_outputs()]
        output_shapes = [tensor.get_shape() for tensor in output_tensors]

        # Compute slice initializers for all outputs
        slice_initializers = self.compute_slice_initializers_from_split(
            output_shapes, axis, has_split
        )

        data = n("input")

        # Create individual Slice nodes for each split output
        for i, (output_tensor, slice_init) in enumerate(
            zip(output_tensors, slice_initializers)
        ):
            slice_node_name = split_node_name + "_slice_" + str(i)

            # Create initializers for slice parameters
            starts_init = self.add_initializer(
                initializer_name=slice_node_name + "_starts",
                value=np.array(slice_init.starts, dtype=np.int64),
            )
            ends_init = self.add_initializer(
                initializer_name=slice_node_name + "_ends",
                value=np.array(slice_init.ends, dtype=np.int64),
            )
            axes_init = self.add_initializer(
                initializer_name=slice_node_name + "_axes",
                value=np.array(slice_init.axes, dtype=np.int64),
            )

            self.add_node(
                type="Slice",
                domain="ai.onnx",
                inputs={
                    "data": data,
                    "starts": starts_init,
                    "ends": ends_init,
                    "axes": axes_init,
                },
                outputs={"output": output_tensor},
                attributes={},
                new_name=slice_node_name,
            )

        self.remove_node(n)


class TransposeOptQDQ(Dataflow):
    """Fuse a Transpose with its optional DequantizeLinear/QuantizeLinear pair.

    The fused node is emitted in the ``ai.onnx.contrib`` domain with type
    ``Transpose_qdq_<input dtype>x<output dtype>``. Missing scale/zero-point
    inputs are filled via ``_get_initializer_or_dummy``, and each side is
    flagged through ``disable_dq0`` / ``disable_q``.
    """

    def match(self) -> None:
        """Match a Transpose node and record its DQ/Q context for :meth:`modify`."""
        node = self.n
        node.require(opType.Transpose)

        qdq = self.check_dq_and_q(node("data"), node("transposed"))
        self.qdq = qdq
        dq_info = qdq.dq
        q_info = qdq.q
        self.input = dq_info.quant_or_orig_tensor
        self.output = q_info.quant_or_orig_tensor

        # Dtype suffixes for the fused op name: use the zero point's dtype when
        # the DQ/Q node is present, otherwise the dtype of the bare tensor.
        if dq_info.present:
            self.x_zero_point_type = dq_info.zero_point.get_dtype()
        else:
            self.x_zero_point_type = self.input.require_tensor().get_dtype()

        if q_info.present:
            self.y_zero_point_type = q_info.zero_point.get_dtype()
        else:
            self.y_zero_point_type = self.output.require_tensor().get_dtype()

        # A side is disabled when its DQ/Q node is absent, or when q_prm_equal
        # is set (presumably: DQ and Q parameters are identical; confirm in
        # QDQHelper).
        self.disable_dq = qdq.q_prm_equal or not dq_info.present
        self.disable_q = qdq.q_prm_equal or not q_info.present

        self.attributes = node.get_attributes()

    def modify(self) -> None:
        """Add the fused Transpose node, then remove the original one."""
        node = self.n
        qdq = self.qdq

        fused_type = "".join(
            ("Transpose_qdq_", self.x_zero_point_type, "x", self.y_zero_point_type)
        )

        # Resolve (or synthesize) quantization parameters, input side first.
        in_scale = self._get_initializer_or_dummy(qdq.dq.scale, node, InitName.SCALE)
        in_zero_point = self._get_initializer_or_dummy(
            qdq.dq.zero_point, node, InitName.SCALE_ZERO_POINT
        )
        out_scale = self._get_initializer_or_dummy(
            qdq.q.scale, node, InitName.OUTPUT_SCALE
        )
        out_zero_point = self._get_initializer_or_dummy(
            qdq.q.zero_point, node, InitName.OUTPUT_ZERO_POINT
        )

        fused_inputs = {
            "data": self.input,
            "data_scale": in_scale,
            "data_zero_point": in_zero_point,
            "Y_scale": out_scale,
            "Y_zero_point": out_zero_point,
        }
        fused_attributes = self.attributes | {
            "disable_dq0": self.disable_dq,
            "disable_q": self.disable_q,
            "num_of_tensor_inputs": 1,
        }

        self.add_node(
            type=fused_type,
            domain="ai.onnx.contrib",
            inputs=fused_inputs,
            outputs={"transposed": self.output},
            attributes=fused_attributes,
        )
        self.remove_node(node)
