# fmt: on
from OGOAT.src.L1_fusion.py_match.helpers.common_type import (
    IndexArray,
    Perm,
    ShapeMask,
    TensorShape,
)
from OGOAT.src.L1_fusion.py_match.helpers.perm_helper import PermutationHelper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Tensor


class ReshapeTransposeHelper:
    """
    Helper functions for swapping adjacent Reshape and Transpose nodes.

    All methods are static. Shapes are handled as lists of dimension sizes and
    permutations as lists of axis indices.
    """

    @staticmethod
    def get_tensor_dim(x: Tensor | TensorShape | None) -> int:
        """
        Return the number of dimensions of `x`.

        `x` may be a Tensor (its `get_shape()` is used) or a shape. A falsy `x`
        (None or an empty shape) yields 0.
        """
        if not x:
            return 0
        shape = x.get_shape() if isinstance(x, Tensor) else x
        return len(shape)

    @staticmethod
    def get_tensor_rank(x: Tensor | TensorShape | None) -> int:
        """
        Return the rank of `x`, i.e. the number of non-degenerate dimensions (size != 1).

        `x` may be a Tensor (its `get_shape()` is used) or a shape. A falsy `x`
        (None or an empty shape) yields 0.
        """
        if not x:
            return 0
        shape = x.get_shape() if isinstance(x, Tensor) else x
        # size-1 dimensions do not contribute to the rank
        return sum(1 for d in shape if d != 1)

    @staticmethod
    def get_node_rank(in_shape: TensorShape, out_shape: TensorShape) -> int:
        """
        rank(Node) := rank(Node.output) - rank(Node.input)

        Positive when the node adds non-degenerate dimensions, negative when it
        removes them (size-1 dimensions are ignored, see `get_tensor_rank`).
        """
        in_rank = ReshapeTransposeHelper.get_tensor_rank(in_shape)
        out_rank = ReshapeTransposeHelper.get_tensor_rank(out_shape)
        return out_rank - in_rank

    @staticmethod
    def is_identity_reshape(in_shape: TensorShape, out_shape: TensorShape) -> bool:
        """
        Return True if the Reshape neither splits nor merges any dimension.

        This holds exactly when `in_shape` and `out_shape` are equal once all
        size-1 dimensions are dropped. A Reshape that only inserts or removes
        size-1 dimensions therefore counts as an identity.
        Returns False when `reshape_diff` cannot relate the two shapes.
        """
        is_valid_diff, in_mask, out_mask = ReshapeTransposeHelper.reshape_diff(
            in_shape, out_shape
        )
        if not is_valid_diff:
            return False

        # a group with more than one dim on either side means the Reshape
        # splits or joins dimensions, so it is not an identity
        return all(
            len(group_in) <= 1 and len(group_out) <= 1
            for group_in, group_out in zip(in_mask, out_mask)
        )

    @staticmethod
    def reshape_diff(
        in_shape: TensorShape, out_shape: TensorShape, allow_N2M=True,
    ) -> tuple[bool, ShapeMask, ShapeMask]:
        """
        Compute the dimension grouping that relates the input and output of a Reshape.

        A `ShapeMask` is a **list of lists of indices**, i.e. a **list of IndexArrays**.
        Each inner array is called a **group**. ShapeMasks track the mapping between
        the input and output tensors of a Reshape node. The two returned masks have the
        same number of groups and are paired position-wise:
                in_mask[i] <--> out_mask[i]
        where:
            - PROD(in_shape[in_mask[i]]) == PROD(out_shape[out_mask[i]]), where PROD([]) := 1
            - the shapes are scanned from the innermost (last) dimension outwards;
              equal-sized dimensions met at the same time form paired singleton groups:
                    x_mask[i] == [d]  AND  y_mask[i] == [e]  AND  x_shape[d] == y_shape[e]
            - an unmatched dimension of size 1 forms a singleton group paired with an
              empty group on the other side:
                    IF  x_mask[i] is empty  THEN  y_mask[i] == [d] AND y_shape[d] == 1
            - all other groups split (1:N), merge (N:1) or regroup (N:M) consecutive
              dimensions; N:M groups are rejected when `allow_N2M` is False
            - a size-1 dimension lying inside a merged/split run may be absorbed into
              that multi-dimension group
        Groups, and the indices within each group, are in ascending dimension order.

        **Example**:
            - in_shape=[4,5,6]  AND  out_shape=[20,2,3]
                ==> in_mask=[[0,1],[2]]  AND  out_mask=[[0],[1,2]]
            - in_shape=[1,7,12]  AND  out_shape=[7,3,4]
                ==> in_mask=[[0],[1],[2]]  AND  out_mask=[[],[0],[1,2]]
            - in_shape=[6,4]  AND  out_shape=[4,6]   (N:M group)
                ==> in_mask=[[0,1]]  AND  out_mask=[[0,1]]

        **Args**:
            - in_shape -- input shape of reshape
            - out_shape -- shape after reshape
            - allow_N2M -- if False, a group with more than one dimension on BOTH sides
              makes the diff invalid

        **Returns**:  masks for the input and output shapes that track the transformation done by Reshape.
            - is_valid_diff: bool -- False if the dimensions cannot be grouped into
              equal-size pairs (e.g. the element counts differ), or if an N:M group is
              found while `allow_N2M` is False. In that case the masks are partial and
              must not be used.
            - in_mask: ShapeMask -- mask for in_shape
            - out_mask: ShapeMask -- mask for out_shape
        """
        l, r = len(in_shape) - 1, len(out_shape) - 1
        in_mask: ShapeMask = []
        out_mask: ShapeMask = []

        def _result(is_valid: bool) -> tuple[bool, ShapeMask, ShapeMask]:
            # groups were collected innermost-first; restore ascending order
            return is_valid, in_mask[::-1], out_mask[::-1]

        while l >= 0 or r >= 0:
            if l >= 0 and r >= 0 and in_shape[l] == out_shape[r]:
                in_mask.append([l])
                out_mask.append([r])
                l -= 1
                r -= 1
            elif l >= 0 and in_shape[l] == 1:
                # unmatched `1` from input, match it to an empty group in output
                in_mask.append([l])
                out_mask.append([])
                l -= 1
            elif r >= 0 and out_shape[r] == 1:
                # unmatched `1` from output, match it to an empty group in input
                in_mask.append([])
                out_mask.append([r])
                r -= 1
            else:
                # in_shape[l] != out_shape[r] and neither is 1: merge/split dimensions
                if l < 0 or r < 0:
                    # One side is exhausted while the other still holds a non-unit
                    # dimension, so the sizes cannot be reconciled. Indexing with -1
                    # here would silently wrap around to the last dimension.
                    return _result(False)

                acc_shape_in, group_in = in_shape[l], [l]
                acc_shape_out, group_out = out_shape[r], [r]
                l -= 1
                r -= 1
                while (l >= 0 or r >= 0) and acc_shape_in != acc_shape_out:
                    if l >= 0 and acc_shape_in < acc_shape_out:
                        # merging input dimensions
                        group_in.append(l)
                        acc_shape_in *= in_shape[l]
                        l -= 1
                    elif r >= 0 and acc_shape_out < acc_shape_in:
                        # merging output dimensions
                        group_out.append(r)
                        acc_shape_out *= out_shape[r]
                        r -= 1
                    else:
                        # the side that needs to grow is exhausted
                        break

                in_mask.append(group_in[::-1])
                out_mask.append(group_out[::-1])

                if acc_shape_in != acc_shape_out:
                    # the group products differ: not a valid element-preserving reshape
                    return _result(False)
                if not allow_N2M and len(group_in) > 1 and len(group_out) > 1:
                    return _result(False)

        return _result(True)

    @staticmethod
    def get_perm_for_mask(mask: ShapeMask, perm: Perm) -> tuple[bool, Perm]:
        """
        Express `perm` as a permutation over the groups of `mask`.

        First checks that the Transpose does not separate grouped dimensions.
        Every multi-dimension group must appear in `perm` as one contiguous run,
        in the same order.

        **Args**:
            - mask -- groups of axis indices; every group must be non-empty
              (`group[0]` is read), and `perm.index` raises ValueError if the head
              of a multi-dimension group is absent from `perm`
            - perm -- permutation over the axes indexed by `mask`

        **Returns**:
            - (True, mask_perm): mask_perm[k] is the index in `mask` of the group
              whose head is the k-th group head encountered while walking `perm`
            - (False, []) if some group is separated or reordered by `perm`
        """
        for group in mask:
            if len(group) <= 1:
                continue

            start = perm.index(group[0])
            # the group must occupy perm[start : start + len(group)] in the same order
            if list(perm[start : start + len(group)]) != list(group):
                # the group is separated
                return False, []

        heads = [group[0] for group in mask]
        mask_perm = [heads.index(p) for p in perm if p in heads]
        return (True, mask_perm)

    @staticmethod
    def remove_ones_from_mask(
        shape: TensorShape, mask: ShapeMask
    ) -> tuple[ShapeMask, IndexArray]:
        """
        Drop every index of a size-1 dimension from `mask` and renumber the rest.

        Each kept index is shifted down by the number of size-1 indices removed
        before it. `mask` is therefore expected to list indices in ascending order,
        as `reshape_diff` produces. Groups that become empty are dropped.

        Returns (new_mask, removed_ones), where removed_ones are the original
        indices that were removed.
        """
        new_mask = []
        shift = 0
        removed_ones = []
        for group in mask:
            new_group = []
            for e in group:
                if shape[e] == 1:
                    shift += 1
                    removed_ones.append(e)
                else:
                    new_group.append(e - shift)
            if new_group:
                new_mask.append(new_group)
        return (new_mask, removed_ones)

    @staticmethod
    def remove_ones_from_shape(shape: TensorShape) -> tuple[TensorShape, IndexArray]:
        """
        Return (`shape` without its size-1 dimensions, indices of the removed dimensions).
        """
        new_shape = []
        removed_ones = []

        for i, d in enumerate(shape):
            if d == 1:
                removed_ones.append(i)
            else:
                new_shape.append(d)
        return new_shape, removed_ones

    @staticmethod
    def remove_ones_from_perm(perm: Perm, removed_ones: IndexArray) -> Perm:
        """
        Remove the axes listed in `removed_ones` from `perm`.

        The remaining axes are renumbered so that they stay contiguous
        (0..len(result)-1).
        """
        for i in sorted(removed_ones, reverse=True):
            # drop axis i and close the gap it leaves
            perm = [(p if p < i else (p - 1)) for p in perm if p != i]
        return perm

    @staticmethod
    def insert_ones_to_perm(perm: Perm, removed_ones: IndexArray) -> Perm:
        """
        Inverse of `remove_ones_from_perm`: re-insert the axes in `removed_ones` into `perm`.

        Each re-inserted axis `i` is placed at position `i` (it maps to itself).
        The axes of `perm` are renumbered around them and fill the remaining
        positions in order.
        """
        # FIXME: for correct semantics we need the positions of the removed ones in
        #        the original perm; here every removed axis is simply kept in place.
        new_perm: Perm = [None] * (len(perm) + len(removed_ones))
        for i in sorted(removed_ones):
            # make room for axis i: every axis at or above it moves up by one
            perm = [((p + 1) if p >= i else p) for p in perm]
            new_perm[i] = i

        remaining = iter(perm)
        for i, p in enumerate(new_perm):
            if p is None:
                new_perm[i] = next(remaining)
        return new_perm

    @staticmethod
    def insert_ones_to_perm_left(perm: Perm, num: int) -> Perm:
        """
        Insert `num` axes for size-1 dims into `perm` on the left (outer dims).

        The new axes map to themselves; the existing axes are shifted up by `num`.
        `perm` is returned unchanged when `num` is 0.
        """
        if not num:
            return perm
        return list(range(num)) + [p + num for p in perm]

    @staticmethod
    def insert_ones_to_shape_left(shape: TensorShape, num: int) -> TensorShape:
        """
        Prepend `num` dimensions of size 1 to `shape` (outer dims).

        `shape` is returned unchanged when `num` is 0.
        """
        if not num:
            return shape
        new_shape: TensorShape = ([1] * num) + shape
        return new_shape

    @staticmethod
    def trans_reshape_swap(
        trans_shape: TensorShape, perm: Perm, out_shape: TensorShape
    ) -> tuple[bool, TensorShape, Perm]:
        """
        Move a Transpose that precedes a Reshape behind it.

        original:
            - [trans_shape] -> Trans(perm) -> [in_shape] -> Reshape -> [out_shape]
            - [in_shape] == [trans_shape] * perm

        result:
            - [trans_shape] -> Reshape -> [new_tr_shape] -> Trans(new_perm) -> [out_shape]

        Implemented by running the reversed graph through `reshape_trans_swap`.

        Return:  `is_valid_perm`, `new_tr_shape`, `new_perm`
        (`(False, [], [])` if the swap is not possible).
        """
        # reversed:
        #       [out_shape] -> Reshape -> [in_shape] -> Trans(inverse_perm) -> [trans_shape]
        in_shape = PermutationHelper.permute(trans_shape, perm)
        inverse_perm = PermutationHelper.get_inverse_perm(perm)

        # reuse reshape_trans_swap to check if Reshape and Transpose can be swapped:
        #   [out_shape] -> Trans(new_inverse_perm) -> [new_tr_shape] -> Reshape -> [trans_shape]
        is_valid_perm, new_inverse_perm = ReshapeTransposeHelper.reshape_trans_swap(
            out_shape, in_shape, inverse_perm
        )
        if not is_valid_perm:
            return (False, [], [])

        # reverse back:
        #       [trans_shape] -> Reshape -> [new_tr_shape] -> Trans(new_perm) -> [out_shape]
        new_tr_shape = PermutationHelper.permute(out_shape, new_inverse_perm)
        new_perm = PermutationHelper.get_inverse_perm(new_inverse_perm)

        return (True, new_tr_shape, new_perm)

    @staticmethod
    def reshape_trans_swap(
        in_shape: TensorShape, out_shape: TensorShape, perm: Perm
    ) -> tuple[bool, Perm]:
        """
        Move a Transpose that follows a Reshape in front of it.

        original:
            - [in_shape] -> Reshape -> [out_shape] -> Transpose(perm)

        result (on success, same final tensor):
            - [in_shape] -> Transpose(new_perm) -> Reshape

        The swap is possible only when `perm` keeps every group of dimensions
        produced by the Reshape (see `reshape_diff`) contiguous and in order.
        Size-1 dimensions are ignored while matching groups. In `new_perm`, the
        size-1 dimensions of `in_shape` stay in place.

        Return (True, `new_perm`) on success, (False, []) otherwise.
        """
        is_valid_diff, in_mask, out_mask = ReshapeTransposeHelper.reshape_diff(
            in_shape, out_shape
        )
        if not is_valid_diff:
            return (False, [])

        in_mask, in_removed_ones = ReshapeTransposeHelper.remove_ones_from_mask(
            in_shape, in_mask
        )
        out_mask, out_removed_ones = ReshapeTransposeHelper.remove_ones_from_mask(
            out_shape, out_mask
        )

        # remove 1s from perm so it indexes the same axes as the reduced out_mask
        perm = ReshapeTransposeHelper.remove_ones_from_perm(perm, out_removed_ones)

        is_valid_perm, mask_perm = ReshapeTransposeHelper.get_perm_for_mask(
            out_mask, perm
        )
        if not is_valid_perm:
            return (False, [])

        # move whole input groups in the order the Transpose moves the output groups
        new_perm = []
        for p in mask_perm:
            new_perm += in_mask[p]

        # add 1s back to perm (kept in place)
        new_perm = ReshapeTransposeHelper.insert_ones_to_perm(new_perm, in_removed_ones)

        return (True, new_perm)
