# (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All Rights Reserved.
from OGOAT.src.L1_fusion.py_match.checkers import opType
from OGOAT.src.L1_fusion.py_match.helpers.elementwise_helper import (
    is_elementwise,
    is_unary,
)
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    NoMatch,
    Tensor,
    WalkCfgPlain,
)


class MoveReshape(Matcher):
    """
    Moves a reshape node after a unary elementwise node.

    Rewrites
        T1 [shape X1] -> Reshape -> T2 [shape X2] -> Elem -> T3 [shape X2]
    into
        T1 [shape X1] -> Elem -> T2 [shape X1] -> Reshape -> T3 [shape X2]

    The existing tensors are reused: T2 becomes the elementwise output (with
    T1's shape) and T3 becomes the reshape output.
    """

    def match(self) -> None:
        """
        Match a Reshape whose output is consumed by a unary elementwise node.

        On success, stores the matched nodes in ``self.reshape`` and
        ``self.elemwise`` for use by ``modify``.

        Raises:
            NoMatch: if the consumer is not an elementwise operator, is not
                unary, the reshape input has more than one reader, or the
                reshape output has more than one reader.
        """
        n = self.n.with_walk_cfg(WalkCfgPlain())
        self.reshape = n.require(opType.Reshape).require_node()
        self.elemwise = n("reshaped").require_node()

        if not is_elementwise(self.elemwise):
            raise NoMatch("not an elementwise operator")

        if not is_unary(self.elemwise):
            raise NoMatch("not an unary operator")

        # NOTE(review): T1 keeps its value after the rewrite, so extra readers
        # of it look safe; the reason for this restriction is not visible
        # here -- confirm before relaxing.
        if len(self.reshape("data").get_readers()) > 1:
            raise NoMatch("more than 1 readers are not supported")

        # After the rewrite T2 holds the elementwise result in the
        # *un-reshaped* shape. Any other reader of the reshape output would be
        # silently rewired to that different value, so the elementwise node
        # must be its only reader.
        if len(self.reshape("reshaped").get_readers()) > 1:
            raise NoMatch("reshape output with more than 1 reader is not supported")

    def modify(self) -> None:
        """
        Swap the Reshape and the elementwise node, reusing the existing tensors.

        Before: T1 [shape X1] -> Reshape -> T2 [shape X2] -> Elem -> T3 [shape X2]
        After:  T1 [shape X1] -> Elem -> T2 [shape X1] -> Reshape -> T3 [shape X2]
        """
        tensor1: Tensor = self.reshape("data")
        tensor2: Tensor = self.reshape("reshaped")
        tensor3: Tensor = self.elemwise.get_outputs()[0]

        # T2 becomes the elementwise output. It takes T1's (pre-reshape) shape,
        # but its dtype must be the elementwise op's output dtype (T3's), which
        # can differ from T1's for dtype-changing unary ops.
        new_shape = tensor1.get_shape()
        new_dtype = tensor3.get_dtype()

        # Elem now reads T1 instead of T2.
        self.replace_input(self.elemwise, self.elemwise.get_inputs()[0], tensor1)
        # Reshape now writes T3 instead of T2.
        self.replace_output(self.reshape, tensor2, tensor3)
        # Elem now writes T2 instead of T3.
        self.replace_output(self.elemwise, tensor3, tensor2)
        # Reshape now reads T2 instead of T1.
        self.replace_input(self.reshape, tensor1, tensor2)

        tensor2.set_shape(new_shape, new_dtype)
