# fmt: on
from typing import Any

import numpy as np
from OGOAT.src.L1_fusion.py_match.checkers import opType
from OGOAT.src.L1_fusion.py_match.helpers.dataflow_attribute_generator import (
    DataflowAttributeGenerator,
)

from OGOAT.src.L1_fusion.py_match.helpers.noop_helper import NoopHelper
from OGOAT.src.L1_fusion.py_match.helpers.rowwise_helper import RowWiseHelper

from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    MatcherError,
    Node,
    Tensor,
    NoMatch,
)


class RowWiseOpToRuntime(Matcher, DataflowAttributeGenerator, NoopHelper, RowWiseHelper):
    """
    Match a row-wise Concat/Slice/Gather op and replace it with its runtime
    variant (Concat_runtime/Slice_runtime/Gather_runtime), see AIESW-776.

    Input layouts this matcher is intended to accept:

    1x1x1x64
    1x1x63x64                     1x63x1x64                      63x1x1x64
    1x1xRowxCol (axis 2,-2)       1xRowx1xCol (axis 1,-3)        Rowx1x1xCol (axis 0,-4)

    1x1x64x64
    1x63x64x64
    1xRowxCol(64x64)

    The axis/shape check itself is done by ``_validate_axis_and_shape``, which
    is not defined in this class (presumably provided by ``RowWiseHelper`` —
    TODO confirm).
    """

    def _get_axis_attribute(self, node: Node) -> Any:
        """Return the ``axis`` attribute of ``node``.

        Raises:
            MatcherError: if the node has no ``axis`` attribute.
        """

        if "axis" not in node.get_attributes():
            raise MatcherError(
                f"axis attribute is missing for the node {node.get_name()}."
            )
        return node.get_attribute_value("axis")

    def _match_concat(self, n: Node, node_name: str) -> tuple:
        """Validate a Concat node; return ``(axis, inputs)``.

        Raises:
            MatcherError: if ``axis`` is missing, there are fewer than two
                inputs, or any input is a constant.
        """
        axis = self._get_axis_attribute(n)
        inputs = n.get_inputs()
        if len(inputs) <= 1:
            raise MatcherError(
                f"number of inputs of Concat must be bigger than 1 for the node {node_name}."
            )

        for inp in inputs:
            # Reject an input that is itself an initializer, or whose "x"
            # input is one (presumably a constant fed through a single-input
            # producer — TODO confirm the meaning of inp("x")).
            if inp.check_initializer() or inp("x").check_initializer():
                raise MatcherError(
                    f"concat_runtime requires act for inputs but const found {inp}."
                )
        return axis, inputs

    def _match_gather(self, n: Node) -> tuple:
        """Validate a Gather node with constant, contiguous indices.

        Stores the indices on ``self.indices_value`` for use by ``modify``.
        Returns ``(axis, [data_input])``.

        Raises:
            MatcherError: if ``axis`` is missing, the indices are not a scalar
                or 1-D array, or they do not form a contiguous ascending run.
        """
        axis = self._get_axis_attribute(n)
        self.indices_value = n("indices").require_initializer().get_value_as_array()
        indices_array = np.atleast_1d(np.asarray(self.indices_value))
        if indices_array.ndim > 1:
            raise MatcherError(
                f"indices_value must be scalar or 1D array (got : {indices_array.shape})"
            )

        if indices_array.size > 1:
            if not np.all(np.diff(indices_array) == 1):
                raise MatcherError(
                    "indices_value must contain contiguous (consecutive) values"
                )
            # Gather resolves negative indices from the end of the axis, so a
            # unit-step run crossing from -1 to 0 (e.g. [-1, 0]) addresses the
            # last row and then the first row: not a contiguous range.
            if indices_array.min() < 0 <= indices_array.max():
                raise MatcherError(
                    "indices_value must not mix negative and non-negative values"
                )
        return axis, [n("data")]

    def _match_slice(self, n: Node, node_name: str) -> tuple:
        """Validate a single-axis, non-no-op Slice node; return ``(axis, [data_input])``.

        Raises:
            NoMatch: if the slice is a no-op.
            MatcherError: if ``axes`` or ``data`` is missing, or ``axes`` does
                not contain exactly one axis.
        """
        if self.is_noop_slice(n):
            raise NoMatch("Slice runtime cannot be a no operation")
        if "axes" not in n.get_schema_input_names():
            raise MatcherError(f"axes input is missing for the node {node_name}.")
        # atleast_1d so a 0-d axes array yields a MatcherError/match instead
        # of a TypeError from len().
        axes = np.atleast_1d(n("axes").require_tensor().get_initializer_array())
        if len(axes) != 1:
            raise MatcherError(
                f"expected single axis, got axes {axes} for the node {node_name}."
            )
        axis = axes[0]
        if n("data") is None:
            raise MatcherError(f"data input is missing for the node {node_name}.")
        return axis, [n("data")]

    def match(self) -> None:
        """Check that the node is a row-wise Concat/Gather/Slice convertible to a runtime op.

        Raises:
            NoMatch: if the node is a no-op Slice.
            MatcherError: if the op type is unsupported or its attributes,
                inputs, axis or shapes violate the runtime constraints.
        """
        n = self.n.require_node()
        node_name = n.get_name()

        if n.check(opType.Concat):
            axis, inputs = self._match_concat(n, node_name)
        elif n.check(opType.Gather):
            axis, inputs = self._match_gather(n)
        elif n.check(opType.Slice):
            axis, inputs = self._match_slice(n, node_name)
        else:
            raise MatcherError(
                f"pattern doesn't match row wise to runtime convert for the node {node_name}."
            )

        self._validate_axis_and_shape(axis=axis, inputs=inputs, node_name=node_name)

    def modify(self) -> None:
        """Replace the matched node with a ``<OpType>_runtime`` node.

        The new node lives in the ``ai.onnx.contrib`` domain, keeps the same
        inputs/outputs, and is named ``<old name>_runtime``. Slice attributes
        come from ``generate_slice_attributes``. Concat/Gather keep a copy of
        the original attributes, and Gather also receives the ``indices``
        recorded by ``match``.
        """
        n = self.n
        new_type = n.require_node().get_op_type() + "_runtime"
        # Capture everything derived from the node before it is removed.
        new_name = n.get_name() + "_runtime"
        inputs = n.get_inputs_dict()
        outputs = n.get_outputs_dict()

        if n.check(opType.Slice):
            copy_attributes = self.generate_slice_attributes(n)
        else:
            # Copy so adding "indices" never mutates the source node's attributes.
            copy_attributes = dict(n.get_attributes())
            if n.check(opType.Gather):
                # np.asarray tolerates a non-ndarray value; a 0-d array still
                # converts to a Python scalar, a 1-D array to a list.
                copy_attributes["indices"] = np.asarray(self.indices_value).tolist()
        self.remove_node(n)

        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=copy_attributes,
            new_name=new_name,
        )
