# (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All Rights Reserved.

from typing import Any

import numpy as np

from OGOAT.src.L1_fusion.py_match.basic.dataflow import Dataflow
from OGOAT.src.L1_fusion.py_match.basic.rowwise_runtime import RowWiseOpToRuntime
from OGOAT.src.L1_fusion.py_match.checkers import opType, CategoryCheck
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    MatcherError,
    Node,
    NoMatch,
    Tensor,
)


class MultiAxisSlice(RowWiseOpToRuntime):
    """
    Match ``Slice_qdq*`` nodes that slice along multiple axes and replace each
    of them with a chain of slice nodes that each slice along a single axis.

    A link of the chain becomes a ``Slice_runtime`` node when the inherited
    ``_validate_axis_and_shape`` check passes and the link actually changes
    the tensor shape; otherwise it keeps the original op type. The
    ``disable_q`` / ``disable_dq0`` attributes of the non-runtime links are
    then adjusted by ``set_disable_q_dq_attributes``.
    """

    def set_disable_q_dq_attributes(self, slice_nodes: list[Node]) -> None:
        """
        Set ``disable_q`` / ``disable_dq0`` on the non-runtime nodes of a chain.

        ``Slice_runtime`` nodes are skipped. For every other node:

        * first node: ``disable_q`` is set (``disable_dq0`` left unchanged);
        * last node: ``disable_dq0`` is set unless a ``Slice_runtime`` node
          appears anywhere before it (``disable_q`` left unchanged);
        * middle node: ``disable_q`` is set, and ``disable_dq0`` is set unless
          a ``Slice_runtime`` node appears anywhere before it.

        A node that is both first and last is handled as the first node.

        Args:
            slice_nodes: the generated slice nodes, in data-flow order.
        """
        runtime_indices = [
            i
            for i, node in enumerate(slice_nodes)
            if node.get_op_type() == "Slice_runtime"
        ]
        runtime_set = set(runtime_indices)
        # Every index strictly greater than the first runtime index has at
        # least one runtime node before it.
        first_runtime = runtime_indices[0] if runtime_indices else len(slice_nodes)
        last_index = len(slice_nodes) - 1

        for i, node in enumerate(slice_nodes):
            if i in runtime_set:
                continue
            # NOTE(review): this looks at *any* earlier runtime node, not only
            # the immediately preceding one — confirm this is the intended dq
            # handling for chains such as [qdq, runtime, qdq, qdq].
            after_runtime = i > first_runtime

            if i == 0:
                node.set_attribute("disable_q", True)  # keep disable_dq0 unchanged
                continue

            if i == last_index:
                if not after_runtime:
                    node.set_attribute("disable_dq0", True)  # keep disable_q unchanged
                continue

            node.set_attribute("disable_q", True)
            if not after_runtime:
                node.set_attribute("disable_dq0", True)

    def match(self) -> None:
        """
        Match a dataflow-category ``Slice_qdq*`` node with more than one axis.

        Raises:
            NoMatch: if the op type does not start with ``Slice_qdq``, or the
                ``axes`` attribute is absent (``None``) or has at most one entry.
        """
        n = self.n.require_node()

        n.require(CategoryCheck(Dataflow()))
        if not n.get_op_type().startswith("Slice_qdq"):
            raise NoMatch("No Slice_qdq node")
        axes = n.get_attribute_value("axes")
        # Treat a missing attribute (if get_attribute_value yields None) as
        # "no match" instead of failing with a TypeError on len(None).
        if axes is None or len(axes) <= 1:
            raise NoMatch("No mutiple axis")

    def modify(self) -> None:
        """
        Replace the matched node by a chain of single-axis slice nodes.

        Node ``idx`` of the chain slices ``axes[idx]`` of the original node and
        writes a new tensor named ``<input name>_<idx>``. That tensor's shape is
        the previous link's shape with the sliced axis resized to the original
        output's size on that axis. After the original node is removed, the
        last link's output is wired to the original output via
        ``replace_output``.
        """
        n = self.n.require_node()
        base_name = n.get_name()

        input_tensor: Tensor = n("data")
        input_tensor_name = input_tensor.get_name()
        tensor_dtype = input_tensor.get_dtype()
        # Work on a private copy: writing into the list returned by get_shape()
        # could otherwise alter the original input tensor's recorded shape
        # (and would fail outright if the shape were a tuple).
        running_shape = list(input_tensor.get_shape())

        old_output = n("output")
        final_shape = old_output.get_shape()
        axes = n.get_attribute_value("axes")

        # Snapshot every per-axis attribute list once, before any node is
        # built, so each iteration indexes the original multi-axis values even
        # if get_attributes() hands back a shared dict.
        per_axis_keys = ("starts", "start", "ends", "end", "axes", "steps")
        orig_attrs = n.get_attributes()
        per_axis_values = {key: list(orig_attrs[key]) for key in per_axis_keys}

        slice_nodes: list[Node] = []
        prev_tensor: Tensor = input_tensor
        for idx, axis in enumerate(axes):
            node_name = base_name + f"_{idx}"

            inputs = n.get_inputs_dict()
            inputs["data"] = prev_tensor

            new_tensor = Tensor(
                model_dict=n._model_dict,
                walk_cfg=n._walk_cfg,
                tensor_name=input_tensor_name + f"_{idx}",
            )
            running_shape[axis] = final_shape[axis]
            # Give each intermediate tensor its own list; sharing one list
            # would let later iterations resize earlier tensors as well.
            new_tensor.set_shape(list(running_shape), tensor_dtype)

            outputs: dict[str, Any] = {"output": new_tensor}

            # Shallow copy so the original node's attribute dict is never edited.
            attrs = dict(n.get_attributes())
            for key, values in per_axis_values.items():
                attrs[key] = [values[idx]]
            attrs["orig_name"] = node_name
            attrs["num_of_tensor_inputs"] = 1

            output_shape = new_tensor.require_tensor().get_shape()
            try:
                self._validate_axis_and_shape(axis, [prev_tensor], base_name)
                is_valid = True
            except MatcherError:
                is_valid = False

            can_be_runtime_op = is_valid and prev_tensor.get_shape() != output_shape
            if can_be_runtime_op:
                op_type = "Slice_runtime"
                # Runtime slices get only the data input and no q/dq flags.
                inputs = {"data": prev_tensor}
                attrs.pop("disable_q", None)
                attrs.pop("disable_dq0", None)
            else:
                op_type = n.get_op_type()

            slice_node = self.add_node(
                type=op_type,
                domain=n.get_domain(),
                inputs=inputs,
                outputs=outputs,
                attributes=attrs,
                new_name=node_name,
            )
            slice_nodes.append(slice_node)
            prev_tensor = new_tensor

        self.remove_node(n)
        self.replace_output(slice_nodes[-1], slice_nodes[-1]("output"), old_output)
        self.set_disable_q_dq_attributes(slice_nodes)
