# fmt: on
from typing import List
import logging
import math

from OGOAT.src.L1_fusion.py_match.helpers.common_type import Perm, TensorShape
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    MatcherError,
    Element,
)


class NoopHelper:
    """
    Shape-based checks telling whether an operator is a no-operation with
    respect to the memory layout in hardware.

    The ``is_noop_*`` methods taking a ``node`` read tensor shapes (and, where
    needed, attributes/initializers) through the matcher element; the
    remaining methods work on plain shapes and can be used stand-alone.
    """

    def is_noop_reshape(self, node: Element) -> bool:
        """
        Return True if the passed reshape node is a no-operation according to the memory layout in hardware.

        The shapes are taken from the "data" input and the "reshaped" output.
        The reshape is reported as a no-op if the last dimension is unchanged,
        or if the last dimension of both shapes is a multiple of 8.

        Return False if one of the tensors cannot be resolved (MatcherError)
        or if one of the shapes has rank 0.
        """

        try:
            input_shape = node("data").require_tensor().get_shape()
            output_shape = node("reshaped").require_tensor().get_shape()
        except MatcherError:
            return False

        logging.info(f"\n input_shape: {input_shape},   output_shape: {output_shape},")

        # A rank-0 shape has no last dimension to look at; indexing [-1] would
        # raise IndexError, so conservatively report "not a no-op" instead.
        if len(input_shape) == 0 or len(output_shape) == 0:
            return False

        last_in = input_shape[-1]
        last_out = output_shape[-1]
        # same last dimension, or both last dimensions are multiples of 8
        return last_in == last_out or (last_in % 8 == 0 and last_out % 8 == 0)

    def is_noop_transpose(self, node: Element) -> bool:
        """
        Return True if the passed transpose node is a no-operation according to
        the memory layout in hardware.

        The input shape is read from "data", the output shape from
        "transposed" and the permutation from the node's "perm" attribute.
        Return False if any of them cannot be resolved (MatcherError).
        """

        try:
            shape = node("data").require_tensor().get_shape()
            perm = node.require_node().get_attribute_value("perm")
            permuted_shape = node("transposed").require_tensor().get_shape()
        except MatcherError:
            return False

        logging.info(
            f"\n shape: {shape},   perm: {perm},   permuted_shape: {permuted_shape}"
        )

        if math.prod(permuted_shape) == 1:
            # exactly one element in total (this also covers rank 0) -->
            # permuting the dimensions cannot move anything, it is a no-op
            return True
        elif permuted_shape[-1] % 8 != 0:
            # if last dimension needs padding (not divisible by 8) then it is not a noop
            return False

        no_op = self.is_noop_transpose_shape_perm(shape, perm)
        logging.info(f" Is Transpose node no_op? : {no_op}")
        return no_op

    def is_noop_transpose_shape_perm(
        self, input_shape: TensorShape, perm: Perm
    ) -> bool:
        """
        input_shape -- input shape to transpose operator
        perm -- perm attribute of transpose operator

        Return True if the transpose does not change anything in the underlying
        memory layout of the values, i.e., the transpose is no-op in HW.

        Dimensions of size 1 do not influence the memory layout, so they may be
        moved freely; the transpose is a no-op exactly when the remaining
        dimensions keep their relative order.  Return False for an invalid
        perm (wrong length or not a permutation of 0..rank-1).

        Both arguments may be any indexable sequence (list or tuple).

        See TransposeHelperTests.test_is_noop_tranpose_shape_perm in
        py_match_fast_unit_tests.py for examples.
        """
        rank = len(input_shape)
        # invalid perm (wrong length or not a perm) --> don't report no-op
        if len(perm) != rank or set(perm) != set(range(rank)):
            return False
        # Drop every axis of size 1 from perm (1 entries in shape don't change
        # the memory layout).  Renumbering the surviving axes would be
        # monotone, so "reduced perm is the identity" is equivalent to "the
        # surviving axes appear in increasing order".
        moved_axes = [axis for axis in perm if input_shape[axis] != 1]
        # perm is a permutation, so its entries are distinct and a strict
        # comparison of neighbours is sufficient
        return all(
            earlier < later for earlier, later in zip(moved_axes, moved_axes[1:])
        )

    def is_noop_slice(self, node: Element) -> bool:
        """
        Return True if the passed slice node is a no-operation, i.e. the shape
        of its "output" tensor equals the shape of its "data" input.

        Return False if one of the tensors cannot be resolved (MatcherError).
        """
        try:
            input_shape = node("data").require_tensor().get_shape()
            output_shape = node("output").require_tensor().get_shape()
        except MatcherError:
            return False
        return input_shape == output_shape

    def is_noop_reduction(self, input: List[int], output: List[int]) -> bool:
        """
        Return True if the reduction is a no-operation (i.e. only dimensions of size 1 were reduced).
        This essentially checks if the number of elements remains the same.

        input -- shape of the tensor before the reduction
        output -- shape of the tensor after the reduction
        """
        return math.prod(input) == math.prod(output)

    def is_noop_unsqueeze(self, node: Element) -> bool:
        """
        Return True if the node is a Squeeze/Unsqueeze operation that does NOT
        touch the last dimension, i.e. neither the last index nor -1 is listed
        in its "axes" initializer.

        Return False for any other operator type.

        NOTE: unlike the other node-based checks in this class, a MatcherError
        raised by the require_* lookups is not caught here and propagates to
        the caller.
        """
        if node.require_node().get_op_type() == "Squeeze":
            # for Squeeze the axes are checked against the input shape
            shape = node("data").require_tensor().get_shape()
        elif node.require_node().get_op_type() == "Unsqueeze":
            # for Unsqueeze the axes are checked against the output shape
            shape = node("expanded").require_tensor().get_shape()
        else:
            return False

        axes = node("axes").require_initializer().get_value()

        # The last dimension can be addressed as len(shape) - 1 or as -1;
        # the operation is a no-op only if axes contains neither of them.
        return len(shape) - 1 not in axes and -1 not in axes
