# fmt: on
from typing import Optional

import numpy as np
from OGOAT.src.L1_fusion.py_match.adv.conv_to_matmul import ConvtoMatmul
from OGOAT.src.L1_fusion.py_match.basic.dataflow import Dataflow
from OGOAT.src.L1_fusion.py_match.basic.matmul import MatMul
from OGOAT.src.L1_fusion.py_match.checkers import (
    CategoryCheck,
    opType,
)
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.helpers.reshape_transpose_helper import (
    ReshapeTransposeHelper,
)
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Element,
    Matcher,
    MatcherError,
    Node,
    Tensor,
)
from OGOAT.src.L1_fusion.py_match.basic.gemm import Gemm


class LinearSlice(Matcher, QDQHelper):
    """
    Split a linear op whose output is consumed only by equal-part Slice nodes.

    Matched pattern (see the diagrams in ``match``): a MatMul / Gemm /
    Conv-as-MatMul node with an initializer ``B`` whose output ``Y`` is read,
    optionally through a DQ -> Q pair with equal scale/zero point and/or a
    chain of Reshape/Transpose nodes, by ``split_factor`` Slice nodes that each
    extract one equal, contiguous part along a single axis.

    ``modify`` adds ``split_factor`` copies of the linear node, each using the
    matching part of ``B`` (and of ``Bias`` and the quantization initializers)
    and writing into the corresponding Slice output, and removes the Slice
    nodes.
    """

    dependencies = [Dataflow()]

    @staticmethod
    def _get_position_if_valid_slice(node: Node, axis: int, split_factor: int) -> int:
        """
        Return the index of the equal part that a Slice node extracts.

        The sliced dimension of the Slice input ``data`` is divided into
        ``split_factor`` equal, contiguous parts; the node is valid when its
        ``starts``/``ends`` attributes select exactly one of them.

        Args:
            node: the Slice node to inspect.
            axis: the single axis every Slice must act on.
            split_factor: number of equal parts (the number of Slice readers).

        Returns:
            Zero-based position of the selected part (always 0 when
            ``split_factor == 1``).

        Raises:
            MatcherError: if the node slices along other/multiple axes, the
                sliced dimension cannot be cut into ``split_factor`` equal
                non-empty parts, or the bounds match none of the parts.
        """
        if split_factor == 1:
            return 0

        axes = node.get_attribute_value("axes")
        if axes != [axis]:
            raise MatcherError("unsupported slice parameters")

        starts = node.get_attribute_value("starts")
        ends = node.get_attribute_value("ends")
        dim = node("data").get_shape()[axis]

        # Equal parts require an exact, non-empty division; this also avoids a
        # zero step (which would make `range` raise ValueError instead of the
        # MatcherError callers expect).
        if dim <= 0 or dim % split_factor != 0:
            raise MatcherError("unsupported slice parameters")

        step = dim // split_factor
        for idx in range(split_factor):
            if starts == [idx * step] and ends == [(idx + 1) * step]:
                return idx

        raise MatcherError("unsupported slice parameters")

    def _create_tensor(self, name: str, shape: list[int], dtype: str) -> Tensor:
        """
        Create a ``Tensor`` bound to the matched node's model dict / walk config
        and set its shape and dtype.
        """
        tensor = Tensor(
            self.n._model_dict,
            self.n._walk_cfg,
            name,
            None,
        )
        tensor.set_shape(shape, dtype)
        return tensor

    def match(self) -> None:
        """
        Check whether the current node heads a splittable linear -> Slice pattern.

        On success the following attributes are set for ``modify``:
        ``split_factor``, ``axis``, ``slices`` (part index -> Slice node),
        ``skipped_nodes`` (Reshape/Transpose chain between the linear op and
        the Slices), the ``has_bias*`` flags and the split initializer lists
        ``b``, ``b_zero_point``, ``b_scale`` (plus ``bias``,
        ``bias_zero_point``, ``bias_scale`` when present).

        Raises:
            MatcherError: if any part of the pattern does not hold.
        """
        n = self.n
        n.require(
            CategoryCheck(MatMul())
            | CategoryCheck(Gemm())
            | CategoryCheck(ConvtoMatmul())
        )
        n("B").require_initializer()
        self.has_bias = n("Bias").check_tensor()
        # Only meaningful with a bias; default them so they are always defined.
        self.has_bias_zero_point = False
        self.has_bias_scale = False
        if self.has_bias:
            self.has_bias_zero_point = n("Bias_zero_point").check_initializer()
            self.has_bias_scale = n("Bias_scale").check_initializer()

        main_readers = n("Y").skip().require_tensor().get_readers()
        # skip extra DQ and Q after main node, if present
        #                                                   /---> ...
        # linear_qdq ---> (optional noop) ---> DQ ---> Q ---
        #                                                   \---> ...
        if (
            len(main_readers) == 1
            and n("Y").check(opType.DequantizeLinear)
            and n("Y.y").check(opType.QuantizeLinear)
            and self.check_qdq_equal_scale_zeropoint(n("Y"), n("Y.y"))
        ):
            main_readers = n("Y.y.y").skip().require_tensor().get_readers()

        #                                   /---> slice_qdq_0
        # linear_qdq --- (optional noop) ---
        #                                   \---> slice_qdq_1

        self.split_factor = len(main_readers)
        self.slices: dict[int, Node] = {}
        self.axis: Optional[int] = None
        for out_node in main_readers:
            out_node.require(CategoryCheck(Dataflow()))
            if not out_node.get_op_type().startswith("Slice"):
                raise MatcherError("LinearSlice: Slice node is required.")

            # Take the slicing axis from the first Slice only. `is None` (not
            # falsiness) so axis 0 is not re-read from every later Slice, which
            # would let Slices on different axes be accepted together.
            if self.axis is None:
                axes = out_node.get_attribute_value("axes")
                if not axes or len(axes) != 1:
                    raise MatcherError("LinearSlice: wrong axes attribute for Slice")
                self.axis = axes[0]

            position = LinearSlice._get_position_if_valid_slice(
                out_node, self.axis, self.split_factor
            )
            if position in self.slices:
                # Two Slices select the same part, so another part would have
                # no Slice and `modify` would fail looking it up.
                raise MatcherError("LinearSlice: duplicate Slice of the same part.")
            self.slices[position] = out_node

        # check if some nodes were skipped (optional noops are presented)
        skipped_readers = n("Y").require_tensor().get_readers()
        self.skipped_nodes: list[Node] = []
        while skipped_readers != main_readers:
            if len(skipped_readers) != 1:
                raise MatcherError(
                    "LinearSlice: the optional noop must be the only reader"
                )

            # TODO: support for not only Reshape
            # NOTE(review): Transpose is accepted here, but `modify` rebuilds
            # every skipped node as a Reshape via its "shape"/"reshaped" ports;
            # confirm a skipped Transpose is really handled there.
            skipped_reader = skipped_readers[0].require_node()
            skipped_reader.require(opType.Reshape | opType.Transpose)
            self.skipped_nodes.append(skipped_reader)
            skipped_readers = skipped_reader.get_outputs()[0].get_readers()

        # verify that slice is on the last dimension of the linear op output tensor
        if self.skipped_nodes:
            _, _, out_mask = ReshapeTransposeHelper.reshape_diff(
                n("Y").get_shape(),
                self.skipped_nodes[-1].get_outputs()[0].get_shape(),
                True,
            )
            if self.axis not in out_mask[-1]:
                raise MatcherError("Slice acts not on the last dimension.")

        # try to split `B` and `Bias` initializers; if splitting is not possible,
        # `get_initializers_for_split` will raise MatcherError
        self.b = self.get_initializers_for_split(n("B"), self.split_factor)
        self.b_zero_point = self.get_initializers_for_split(
            n("B_zero_point"), self.split_factor
        )
        self.b_scale = self.get_initializers_for_split(n("B_scale"), self.split_factor)

        if self.has_bias:
            self.bias = self.get_initializers_for_split(n("Bias"), self.split_factor)

            if self.has_bias_zero_point:
                self.bias_zero_point = self.get_initializers_for_split(
                    n("Bias_zero_point"), self.split_factor
                )

            if self.has_bias_scale:
                self.bias_scale = self.get_initializers_for_split(
                    n("Bias_scale"), self.split_factor
                )

            # verify the split result: the split bias must match the split B
            for i in range(self.split_factor):
                if self.b[i].get_shape()[-1] != self.bias[i].get_shape()[-1]:
                    raise MatcherError("B and Bias shape are not matching.")

    def modify(self) -> None:
        """
        Replace the matched Slice fan-out with one linear node per part.

        For each part ``idx``:
        - add a node with the same op type and attributes as the matched
          linear node (``orig_name`` and node name suffixed with
          ``_slice_{idx}``), fed with the ``idx``-th split ``B``/``Bias``
          initializers and the Slice's ``output_scale``/``output_zero_point``;
        - if Reshape/Transpose nodes were skipped in ``match``, write the new
          node to a fresh tensor whose last dim is divided by ``split_factor``
          and rebuild the chain as Reshape nodes ending at the Slice output;
          otherwise write directly to the Slice output;
        - remove the Slice node.

        The original linear node is not removed here (presumably the framework
        drops it once it has no readers — TODO confirm).
        """
        n = self.n.require_node()

        for idx in range(self.split_factor):
            inputs = n.get_inputs_dict()

            inputs["B"] = self.b[idx]
            if self.has_bias:
                inputs["Bias"] = self.bias[idx]
                if self.has_bias_zero_point:
                    inputs["Bias_zero_point"] = self.bias_zero_point[idx]
                if self.has_bias_scale:
                    inputs["Bias_scale"] = self.bias_scale[idx]

            inputs["B_zero_point"] = self.b_zero_point[idx]
            inputs["B_scale"] = self.b_scale[idx]

            inputs["Y_scale"] = self.slices[idx]("output_scale")
            inputs["Y_zero_point"] = self.slices[idx]("output_zero_point")

            if self.skipped_nodes:
                # Copy before mutating so the original tensor's shape list is
                # never modified (and never divided again on the next idx).
                out_shape = list(n("Y").get_shape())
                out_shape[-1] //= self.split_factor

                state_output = self._create_tensor(
                    n("Y").get_name() + f"_{idx}",
                    out_shape,
                    n("Y").require_tensor().get_dtype(),
                )
                outputs = {"Y": state_output}
            else:
                outputs = {"Y": self.slices[idx]("output")}

            linear_node = self.add_node(
                type=n.get_op_type(),
                domain="ai.onnx.contrib",
                inputs=inputs,
                outputs=outputs,
                attributes=n.get_attributes()
                | dict(orig_name=n.get_attributes()["orig_name"] + f"_slice_{idx}"),
                new_name=n.get_name() + f"_slice_{idx}",
            )

            for i, reshape in enumerate(self.skipped_nodes):
                if i == 0:
                    data = linear_node("Y")

                # if the skipped node isn't the last one, a new tensor must be created
                if i < len(self.skipped_nodes) - 1:
                    # Copy for the same reason as above.
                    out_shape = list(reshape("reshaped").get_shape())

                    # split the last dimension (the last in context of the linear op)
                    for j in range(len(out_shape) - 1, -1, -1):
                        if out_shape[j] % self.split_factor == 0:
                            out_shape[j] //= self.split_factor
                            break

                    output = self._create_tensor(
                        reshape("reshaped").get_name() + f"_{idx}",
                        out_shape,
                        reshape("reshaped").require_tensor().get_dtype(),
                    )
                else:
                    output = self.slices[idx]("output")

                shape = self.add_initializer(
                    reshape("shape").get_name() + f"_{idx}",
                    np.array(output.get_shape()),
                )

                self.add_node(
                    type="Reshape",
                    domain="ai.onnx.contrib",
                    inputs={
                        "data": data,
                        "shape": shape,
                    },
                    outputs={"reshaped": output},
                    attributes=reshape.get_attributes(),
                    new_name=reshape.get_name() + "ls" + f"_{idx}",
                    add_matcher_name=False,
                )

                data = output

            self.remove_node(self.slices[idx])
