# fmt: on
from typing import List
import logging
import math

from OGOAT.src.L1_fusion.py_match.helpers.common_type import Perm, TensorShape
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    MatcherError,
    Element,
    Tensor
)


class RowWiseHelper:
    """Helpers for deciding whether a node's operation is row-wise.

    As checked here, an op is row-wise when its ``axis`` is not the last
    ("Col") dimension of the input shape and, for an axis > 0, every
    dimension before ``axis`` is 1 in every input.
    """

    def is_rowwise_op(
        self, axis: int, inputs: list[Tensor], node_name: str
    ) -> bool:
        """Check if the operation is row-wise.

        Args:
            axis: Axis attribute of the node; negative values count from the
                end of the shape of ``inputs[0]``.
            inputs: Input tensors of the node.
            node_name: Name of the node, used only in diagnostic messages.

        Returns:
            True if ``_validate_axis_and_shape`` accepts the axis and shapes,
            False if it raises ``MatcherError`` (the reason is logged at
            DEBUG level).
        """
        try:
            self._validate_axis_and_shape(axis, inputs, node_name)
        except MatcherError as e:
            # Lazy %-formatting: the message is only rendered if DEBUG is on.
            logging.debug("Row-wise op validation failed: %s", e)
            return False
        return True

    def _validate_axis_and_shape(
        self, axis: int, inputs: list[Tensor], node_name: str
    ) -> None:
        """Validate axis and input shapes for row-wise ops.

        Args:
            axis: Axis attribute of the node; negative values are normalized
                against the rank of ``inputs[0]``.
            inputs: Input tensors of the node.
            node_name: Name of the node, used only in error messages.

        Raises:
            MatcherError: if ``inputs`` is empty; if ``axis`` designates the
                last (Col) dimension; if ``axis`` is outside
                ``[-rank, rank - 1]`` for the rank of ``inputs[0]``; or, for
                a normalized axis > 0, if any input has too few dimensions,
                has a dimension other than 1 before ``axis``, or has a value
                < 1 at ``axis``.
        """
        if not inputs:
            # Without this, inputs[0] below would raise IndexError, which
            # is_rowwise_op does not catch.
            raise MatcherError(f"No inputs found for the node {node_name}.")

        num_dimensions = len(inputs[0].get_shape())

        # The last dimension is the Col; an op along it is not row-wise.
        if axis == -1 or axis == num_dimensions - 1:
            raise MatcherError(
                f"The axis must not be the index of the Col, i.e. the last index of the shape for the node {node_name}."
            )

        # Normalize into a separate name so error messages keep reporting
        # the axis exactly as the node specifies it.
        normalized_axis = axis + num_dimensions if axis < 0 else axis
        if not 0 <= normalized_axis < num_dimensions:
            # Covers both too-negative and too-large axes; a too-large axis
            # previously surfaced as an uncaught IndexError.
            raise MatcherError(
                f"Incorrect axis - {axis} found for the node {node_name}."
            )

        if normalized_axis > 0:
            for tensor in inputs:
                shape = tensor.get_shape()
                # Inputs other than inputs[0] may have a smaller rank; guard
                # the indexing below so this is reported as a MatcherError.
                if len(shape) <= normalized_axis:
                    raise MatcherError(
                        f"Input rank {len(shape)} is too small for axis {axis} for the node {node_name}."
                    )
                if (
                    any(value != 1 for value in shape[:normalized_axis])
                    or shape[normalized_axis] < 1
                ):
                    raise MatcherError(
                        f"The value before the value at the axis must be all equals 1 for the node {node_name}."
                    )