# fmt: on
import logging
from typing import Any, Optional
from OGOAT.src.L1_fusion.py_match.helpers.common_type import Perm, TensorShape
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Tensor
from collections import defaultdict


def compute_permutation(shape1: list[int], shape2: list[int]) -> Optional[Perm]:
    """Compute the permutation that maps the axes of ``shape1`` onto ``shape2``.

    The returned ``perm`` pairs every axis ``i`` of ``shape1`` with a distinct
    axis ``perm[i]`` of ``shape2`` such that ``shape2[perm[i]] == shape1[i]``
    wherever the sizes can be matched, i.e.
    ``PermutationHelper.permute(shape2, perm) == shape1`` when no broadcasting
    is involved.

    Broadcasting: a size of ``shape1`` that does not occur anywhere in
    ``shape2`` is treated like a size-1 axis. Such axes (and genuine size-1
    axes) receive the leftover ``shape2`` axes in ascending order.

    Matching proceeds in three steps:
        1. axes whose size is unchanged at the same position keep their index;
        2. every size with exactly one still-unmatched axis is paired with the
           first still-unused ``shape2`` axis of the same size;
        3. all remaining axes get the unused ``shape2`` indices, ascending.

    Args:
        shape1: source shape.
        shape2: target shape; must have the same rank as ``shape1``.

    Returns:
        The permutation as a list of ints, or ``None`` (after logging an error)
        when the ranks differ, or when after step 2 more than one group of
        unmatched axes is left (size-1/broadcast axes form one group, each
        repeated size forms another), making the mapping ambiguous.
    """
    if len(shape1) != len(shape2):
        # A permutation only exists between shapes of equal rank.
        logging.error(
            "cannot compute permutation: rank mismatch between %s and %s",
            shape1,
            shape2,
        )
        return None

    rank = len(shape1)
    perm = [-1] * rank
    used: set[int] = set()  # shape2 axes already claimed by an entry of perm

    # Count unmatched shape1 axes per size. Sizes absent from shape2 can only
    # be broadcast, so they are pooled together with the size-1 axes.
    dim_count: defaultdict[int, int] = defaultdict(int)
    for dim in shape1:
        key = dim if dim == 1 or dim in shape2 else 1
        dim_count[key] += 1

    # Step 1: keep axes whose size is unchanged at the same position.
    for i, (d1, d2) in enumerate(zip(shape1, shape2)):
        if d1 != d2:
            continue
        perm[i] = i
        used.add(i)
        dim_count[d1] -= 1
        if dim_count[d1] == 0:
            del dim_count[d1]

    # Step 2: pair sizes that have exactly one unmatched axis left. Only
    # unassigned shape1 positions and unused shape2 positions are searched;
    # a plain list.index() could pick an axis already fixed in step 1 and
    # yield an invalid permutation (duplicate or mismatched entries).
    for dim in list(dim_count):
        if dim == 1 or dim_count[dim] != 1:
            continue
        index1 = next(i for i, d in enumerate(shape1) if d == dim and perm[i] == -1)
        index2 = next(
            (j for j, d in enumerate(shape2) if d == dim and j not in used), None
        )
        if index2 is not None:
            perm[index1] = index2
            used.add(index2)
        # If every shape2 axis of this size is already taken, the axis is left
        # for the fill in step 3.
        del dim_count[dim]

    if len(dim_count) > 1:
        logging.error("mutiple same dims found, cannot compute permutation")
        return None

    # Step 3: fill the remaining axes with the unused shape2 indices, ascending.
    free = iter(j for j in range(rank) if j not in used)
    return [p if p != -1 else next(free) for p in perm]


class PermutationHelper:
    """
    This class provides helper functions related to Transpose's permutation.

    A permutation ``perm`` of rank n contains each of ``0, ..., n-1`` exactly
    once; applying it to a list ``x`` yields ``[x[i] for i in perm]`` (see
    ``permute``).
    """

    @staticmethod
    def get_identity_perm(
        x: Tensor | TensorShape | Perm | int | None,
    ) -> Optional[Perm]:
        """
        Returns the identity permutation ``[0, 1, ..., n-1]`` where
            - n = len(x.get_shape()), IF x is a Tensor, OR:
            - n = len(x), IF x is a TensorShape or Perm (list or tuple of ints), OR:
            - n = x, IF x is an int.

        Returns None when ``not x`` holds (e.g. None, 0, or an empty list/tuple).
        """
        if not x:
            return None
        # Builtin types are checked first; anything else is expected to be a
        # Tensor.
        if isinstance(x, int):
            rank = x
        elif isinstance(x, (list, tuple)):
            rank = len(x)
        elif isinstance(x, Tensor):
            rank = len(x.get_shape())
        else:
            # Unsupported input: range() below accepts any __index__-capable
            # object and raises TypeError otherwise, as it always did.
            rank = x
        return list(range(rank))

    @staticmethod
    def is_identity_perm(x: Perm) -> bool:
        """Return True if ``x`` equals ``[0, 1, ..., len(x)-1]``.

        Lists and tuples are both accepted; an empty permutation is identity.
        """
        return list(x) == list(range(len(x)))

    @staticmethod
    def permute(x: list[Any], perm: Perm) -> list[Any]:
        """Permute values of a list `x` using given permutation `perm`.

        Returns ``[x[i] for i in perm]``.

        Raises:
            ValueError: if ``perm`` does not have the same length as ``x`` or
                is not a permutation of ``0, ..., len(perm)-1``.
        """
        if len(perm) != len(x) or set(perm) != set(range(len(perm))):
            raise ValueError(f"perm={perm} is not applicable to x={x}")
        return [x[i] for i in perm]

    @staticmethod
    def get_inverse_perm(perm: Perm) -> Perm:
        """
        Returns the inverse permutation of the given permutation `perm`.
        The following holds true:
            - permute(inverse_perm, perm) == permute(perm, inverse_perm) == identity_perm(perm) == [0, ..., n-1]
            - where n := len(perm)

        ``perm`` is not validated: an out-of-range entry raises IndexError and
        duplicate entries silently produce a meaningless result.
        """
        inverse_perm = [0] * len(perm)
        for i, p in enumerate(perm):
            # perm sends output position i to input axis p, so the inverse
            # sends p back to i.
            inverse_perm[p] = i
        return inverse_perm

    @staticmethod
    def are_cancellable_perms(perm_A: Perm, perm_B: Perm) -> bool:
        """Return True if applying ``perm_A`` and then ``perm_B`` is a no-op.

        Applying ``perm_A`` then ``perm_B`` to ``x`` equals
        ``permute(x, permute(perm_A, perm_B))``, so the pair cancels exactly
        when that composed permutation is the identity.

        Raises:
            ValueError: propagated from ``permute`` when the two permutations
                are not compatible (e.g. different lengths).
        """
        perm_Y = PermutationHelper.permute(perm_A, perm_B)
        return PermutationHelper.is_identity_perm(perm_Y)

    @staticmethod
    def is_nontrivial_permutation(perm: list[int]) -> bool:
        """Return True if ``perm`` is not the identity permutation."""
        return not PermutationHelper.is_identity_perm(perm)
