# (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All Rights Reserved.

from dataclasses import dataclass, field
from logging import Logger
from typing import ClassVar, Optional
import onnx
from OGOAT.src.L1_fusion.L1_utils.ops_definition_utils import OnnxOpsWrapper
from OGOAT.src.L1_fusion.py_match.model_dict import ModelDict
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Element,
    InputTensor,
    Matcher,
    Node,
    OutputTensor,
    Tensor,
    WalkCfgPlain,
)
from OGOAT.src.L1_fusion.py_match.checkers import opType


class DtypeFrozen:
    """Find the nodes whose data types must not be changed.

    A node is dtype-frozen when it lies on a path that runs from a graph
    input down to the first QuantizeLinear on that path, or from a graph
    output up to the first DequantizeLinear on that path.
    """

    # Names of the dtype-frozen nodes, shared by all instances. ``None`` until
    # find_dtype_frozen_nodes() has run once.
    # NOTE(review): because the result is cached on the class, a later call for
    # a different model returns early and keeps the first model's set; confirm
    # this is intended.
    _frozen_nodes: ClassVar[Optional[set[str]]] = None

    def __init__(self, model: onnx.ModelProto, onnx_ops_wrapper: OnnxOpsWrapper):
        self.model_dict = ModelDict(model, onnx_ops_wrapper)
        self.walk_cfg_plain = WalkCfgPlain()

    def find_dtype_frozen_nodes(self):
        """
        Find all the nodes for which no data type changes will happen.

        The names are stored in ``DtypeFrozen._frozen_nodes``. If that set has
        already been computed, this call returns immediately.
        """
        if DtypeFrozen._frozen_nodes is not None:
            return
        DtypeFrozen._frozen_nodes = set()

        for input_name in self.model_dict.get_graph_input_names():
            input_tensor = Tensor(self.model_dict, self.walk_cfg_plain, input_name)
            self.search_downwards_and_freeze(input_tensor, [])

        for output_name in self.model_dict.get_graph_output_names():
            output_tensor = Tensor(self.model_dict, self.walk_cfg_plain, output_name)
            self.search_upwards_and_freeze(output_tensor, [])

    def search_downwards_and_freeze(
        self,
        tensor_found: Tensor,
        node_list: list[Element],
        visited: Optional[set[str]] = None,
    ) -> None:
        """Freeze every node on a downstream path to a QuantizeLinear.

        Starting at the readers of ``tensor_found``, follow node outputs to
        their readers. The walk stops at (and includes) the first
        QuantizeLinear on each path. If at least one such path exists, the
        nodes in ``node_list`` (the path that led to ``tensor_found``) are
        frozen too.

        Args:
            tensor_found: tensor whose readers start the walk.
            node_list: nodes already on the current path.
            visited: names of already-explored nodes. Names present on entry
                are skipped; names explored by this call are added.
        """
        if visited is None:
            visited = set()
        self._freeze_paths(
            list(tensor_found.get_readers()), node_list, visited, downstream=True
        )

    def search_upwards_and_freeze(
        self,
        tensor_found: Tensor,
        node_list: list[Element],
        visited: Optional[set[str]] = None,
    ) -> None:
        """Freeze every node on an upstream path to a DequantizeLinear.

        Starting at the writer of ``tensor_found``, follow node inputs to
        their writers. The walk stops at "nowhere"/initializer writers and at
        (and including) the first DequantizeLinear on each path. If at least
        one such path exists, the nodes in ``node_list`` are frozen too.

        Args:
            tensor_found: tensor whose writer starts the walk.
            node_list: nodes already on the current path.
            visited: names of already-explored nodes. Names present on entry
                are skipped; names explored by this call are added.
        """
        if visited is None:
            visited = set()
        writer_node = self._producing_node(tensor_found)
        if writer_node is None:
            return
        self._freeze_paths([writer_node], node_list, visited, downstream=False)

    def _freeze_paths(
        self,
        start_nodes: list[Element],
        node_list: list[Element],
        visited: set[str],
        downstream: bool,
    ) -> None:
        """Freeze every node on a path from ``start_nodes`` to a terminal op.

        The terminal op is QuantizeLinear when ``downstream`` is True and
        DequantizeLinear otherwise. Terminal nodes are frozen but not expanded.
        The walk is iterative, so long chains cannot exceed Python's recursion
        limit. A node reached along several paths is explored once, but every
        edge into it is remembered. Each path that ends in a terminal node is
        therefore frozen in full, not only the first one explored.
        """
        # node name -> names of the nodes it was reached from during this walk
        predecessors: dict[str, set[str]] = {}
        terminals: list[str] = []
        to_expand: list[Element] = []

        def discover(node: Element, parent: Optional[str]) -> None:
            name = node.get_name()
            if name in predecessors:
                # Already reached in this walk: only record the extra edge.
                if parent is not None:
                    predecessors[name].add(parent)
                return
            if name in visited:
                return  # explored before this walk started
            visited.add(name)
            predecessors[name] = set() if parent is None else {parent}
            terminal_op = (
                opType.QuantizeLinear if downstream else opType.DequantizeLinear
            )
            if node.check(terminal_op):
                terminals.append(name)
            else:
                to_expand.append(node)

        for start in start_nodes:
            discover(start, None)
        while to_expand:
            current = to_expand.pop()
            neighbours = (
                self._readers_of_outputs(current)
                if downstream
                else self._writers_of_inputs(current)
            )
            for neighbour in neighbours:
                discover(neighbour, current.get_name())

        if not terminals:
            return

        # Walk the recorded edges backwards from each terminal. Every node found
        # lies on a path from a start node to that terminal.
        frozen = set(terminals)
        pending = list(terminals)
        while pending:
            for pred in predecessors[pending.pop()]:
                if pred not in frozen:
                    frozen.add(pred)
                    pending.append(pred)
        frozen.update(node.get_name() for node in node_list)
        DtypeFrozen._frozen_nodes.update(frozen)

    @staticmethod
    def _readers_of_outputs(node: Element) -> list[Element]:
        """Return the readers of every output tensor of ``node``."""
        return [
            reader for output in node.get_outputs() for reader in output.get_readers()
        ]

    @staticmethod
    def _producing_node(tensor: Tensor) -> Optional[Element]:
        """Return the writer of ``tensor``, or None if it is "nowhere" or an initializer."""
        writer = tensor.get_writer()
        if writer.check_nowhere() or writer.check_initializer():
            return None
        return writer

    @staticmethod
    def _writers_of_inputs(node: Element) -> list[Element]:
        """Return the producing node of each input of ``node`` that has one."""
        writers = []
        for input_tensor in node.get_inputs():
            writer = DtypeFrozen._producing_node(input_tensor)
            if writer is not None:
                writers.append(writer)
        return writers


class FusionFrozen(Matcher):
    """Tag chains of data-movement ops next to graph inputs/outputs.

    A chain qualifies when it consists only of GatherElements, Gather, Slice,
    DepthToSpace, Transpose or Reshape nodes and either:
      * runs from a graph input down to a QuantizeLinear, or
      * runs from a graph output up to a DequantizeLinear.
    ``match()`` collects such chains and ``modify()`` sets the
    ``L1_fusion_frozen`` attribute to 1 on each of their nodes.
    """

    def __init__(self, model: onnx.ModelProto, onnx_ops_wrapper: OnnxOpsWrapper):
        # NOTE(review): Matcher.__init__ is not called here; confirm the base
        # class needs no initialisation for add_node().
        self.model_dict = ModelDict(model, onnx_ops_wrapper)
        self.walk_cfg_plain = WalkCfgPlain()
        # Chain nodes found by match(), in discovery order, without duplicates.
        self.frozen_nodes_from_input: list[Node] = []
        self.frozen_nodes_from_output: list[Node] = []

    @staticmethod
    def _is_dataflow_op(element: Element) -> bool:
        """Return True if ``element``'s node is one of the data-movement ops a frozen chain may contain."""
        return element.require_node().check(
            opType.GatherElements
            | opType.Gather
            | opType.Slice
            | opType.DepthToSpace
            | opType.Transpose
            | opType.Reshape
        )

    def match(self):
        """Search for frozen chains from every graph input and graph output."""
        for input_name in self.model_dict.get_graph_input_names():
            reader_names = self.model_dict.get_reader_names(input_name)
            if not reader_names:
                continue  # unused graph input: no chain can start here
            input_tensor = InputTensor(
                self.model_dict, self.walk_cfg_plain, input_name, reader_names[0]
            )
            self.search_downwards_and_freeze(input_tensor, [])

        for output_name in self.model_dict.get_graph_output_names():
            writer_names = self.model_dict.get_writer_names(output_name)
            if not writer_names:
                continue  # output not produced by a node (e.g. a passthrough)
            output_tensor = OutputTensor(
                self.model_dict, self.walk_cfg_plain, output_name, writer_names[0]
            )
            self.search_upwards_and_freeze(output_tensor, [])

    def search_downwards_and_freeze(
        self,
        tensor_found: Tensor,
        node_list: list[Element],
        visited: Optional[set[str]] = None,
    ) -> None:
        """Follow data-movement ops downwards and record chains ending in QuantizeLinear.

        Args:
            tensor_found: tensor whose readers are examined.
            node_list: chain nodes collected so far on this path.
            visited: names of nodes already examined. Shared across the
                recursion so that no node is examined twice.
        """
        if visited is None:
            visited = set()

        for reader_node in tensor_found.get_readers():
            # A graph output is listed among the readers but is not a node, so
            # require_node() cannot be applied and nothing follows it.
            if reader_node.check_graph_output():
                continue

            # Prevent infinite loops by checking if we've already visited this node
            node_name = reader_node.require_node().get_name()
            if node_name in visited:
                continue
            visited.add(node_name)

            path = node_list + [reader_node]
            if self._is_dataflow_op(reader_node):
                next_tensor = reader_node.get_outputs()[0]
                self.search_downwards_and_freeze(next_tensor, path, visited)
            elif reader_node.require_node().check(opType.QuantizeLinear):
                # If the chain ends with Q and contains dataflow nodes between Q
                # and the input, those nodes get L1_fusion_frozen=1.
                nodes_to_add = [
                    n for n in path if n not in self.frozen_nodes_from_input
                ]
                self.frozen_nodes_from_input.extend(nodes_to_add)

    def search_upwards_and_freeze(
        self,
        tensor_found: Tensor,
        node_list: list[Element],
        visited: Optional[set[str]] = None,
    ) -> None:
        """Follow data-movement ops upwards and record chains starting at DequantizeLinear.

        Args:
            tensor_found: tensor whose writer is examined.
            node_list: chain nodes collected so far on this path.
            visited: names of nodes already examined. Shared across the
                recursion so that no node is examined twice.
        """
        if visited is None:
            visited = set()

        writer_node = tensor_found.get_writer()
        # No producing node (graph input / "nowhere", or an initializer such as
        # Gather's data table): the chain cannot continue upwards.
        if (
            writer_node is None
            or writer_node.check_nowhere()
            or writer_node.check_initializer()
        ):
            return

        # Prevent infinite loops by checking if we've already visited this node
        node_name = writer_node.require_node().get_name()
        if node_name in visited:
            return
        visited.add(node_name)

        path = node_list + [writer_node]
        if self._is_dataflow_op(writer_node):
            next_tensor = writer_node.get_inputs()[0]
            self.search_upwards_and_freeze(next_tensor, path, visited)
        elif writer_node.require_node().check(opType.DequantizeLinear):
            # DequantizeLinear is the start of the chain.
            nodes_to_add = [n for n in path if n not in self.frozen_nodes_from_output]
            self.frozen_nodes_from_output.extend(nodes_to_add)

    def modify(self) -> None:
        """Set L1_fusion_frozen=1 on the collected chains.

        For output-side chains, non-frozen readers of a chain node's first
        output are first redirected to a copy of that node. The frozen chain
        then keeps only its own consumers.
        """
        for node in self.frozen_nodes_from_input:
            node.require_node().set_attribute("L1_fusion_frozen", 1)

        for node in self.frozen_nodes_from_output:
            output_tensor = node.get_outputs()[0].require_tensor()
            readers = output_tensor.get_readers()
            if len(readers) > 1:
                # out<-slice<- gather <-concat  <- Dq <-Q
                #             /                    /
                #        sub <-                sub<-
                # out<-slice<-    gather    <-concat  <- Dq  <- Q
                #                           /                /
                #      sub  <- gather_copy <-    sub<-Dq_copy<-
                filtered_readers = [
                    r
                    for r in readers
                    if r not in self.frozen_nodes_from_output
                    and not r.check_graph_output()
                ]
                self.create_temp_node(output_tensor, node, filtered_readers)
            node.require_node().set_attribute("L1_fusion_frozen", 1)

    def create_temp_node(
        self, tensor: Tensor, node_forked: Node, readers: list[Node]
    ) -> None:
        """Give each reader in ``readers`` its own copy of ``node_forked``.

        Each reader's input ``tensor`` is replaced by a new tensor of the same
        shape and dtype. A copy of ``node_forked`` (same op, domain, inputs and
        attributes, tagged with ``orig_name``) writes that new tensor. The first
        copy uses the suffix ``_copy``. Later copies use ``_copy_<i>`` so that
        tensor and node names stay unique in the graph. ``self.temp_node``
        holds the last copy created.
        """
        shape = tensor.get_shape()
        dtype = tensor.get_dtype()
        schema_output_name = node_forked.get_schema_output_names()[0]
        for index, reader in enumerate(readers):
            suffix = "_copy" if index == 0 else f"_copy_{index}"
            new_tensor = Tensor(
                tensor._model_dict, tensor._walk_cfg, tensor.get_name() + suffix, None
            )
            new_tensor.set_shape(shape, dtype)
            reader._model_dict.replace_input(
                reader.get_name(), tensor.get_name(), new_tensor.get_name()
            )
            self.temp_node = self.add_node(
                type=node_forked.get_op_type(),
                domain=node_forked.get_domain(),
                inputs=node_forked.get_inputs_dict(),
                outputs={schema_output_name: new_tensor},
                attributes=node_forked.get_attributes(),
                new_name=node_forked.get_name() + suffix,
                model_dict=tensor._model_dict,
                required_attr={"orig_name": node_forked.get_name()},
            )


def tag_l1_fusion_frozen(
    model: onnx.ModelProto,
    onnx_ops_wrapper: OnnxOpsWrapper,
    logger: Logger,
) -> None:
    """
    Tag frozen data-movement chains in ``model`` with ``L1_fusion_frozen``.

    Builds a FusionFrozen matcher over the model, then runs its ``match``
    phase followed by its ``modify`` phase. A log line is emitted before and
    after the two phases.
    """
    matcher = FusionFrozen(model, onnx_ops_wrapper)
    logger.info("Starting frozen L1 fusion patterns...")
    for phase in (matcher.match, matcher.modify):
        phase()
    logger.info("Frozen L1 fusion patterns completed.")


def find_dtype_frozen_nodes(
    model: onnx.ModelProto,
    onnx_ops_wrapper: OnnxOpsWrapper,
    logger: Logger,
) -> None:
    """
    Find all nodes whose input dtypes must stay frozen.

    Runs ``DtypeFrozen.find_dtype_frozen_nodes`` for ``model``. The resulting
    node names are kept in the class-level ``DtypeFrozen._frozen_nodes`` set,
    which that method only fills the first time it is called.
    """
    run_search = DtypeFrozen(model, onnx_ops_wrapper).find_dtype_frozen_nodes
    logger.info("Starting frozen dtype fusion patterns...")
    run_search()
    logger.info("Frozen dtype fusion patterns completed.")
