import re
from functools import cmp_to_key
from OGOAT.src.L1_fusion.py_match.adv.attention import Attention
from OGOAT.src.L1_fusion.py_match.adv.rope import LinearPlusRoPE
from OGOAT.src.L1_fusion.py_match.checkers import AttrValue, CategoryCheck
from OGOAT.src.L1_fusion.py_match.helpers.batch_helper import (
    BatchDataGenerator,
    BatchHelper,
)
from OGOAT.src.L1_fusion.py_match.helpers.fusion_configs import FusionConfigs
from OGOAT.src.L1_fusion.py_match.model_dict import ModelDict
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    Matcher,
    MatcherError,
    Node,
    Tensor,
    WalkCfgBase,
    WalkCfgPlain,
)

from OGOAT.src.L1_fusion.py_match.basic.categories import (
    batchable_linear_category,
)
from OGOAT.src.utils.context import Logger
from OGOAT.src.L1_fusion.L1_utils.safe_runner import SafeRunner


class BatchingByLevel(Matcher, BatchHelper):
    """
    Batch MHA/MatMul nodes level by level.

    Candidate nodes are grouped by the graph level reported by the model
    dictionary. Within each level, groups of batchable nodes are matched and
    each group is replaced by a single batched node.
    """

    dependencies = [batchable_linear_category, LinearPlusRoPE()]
    # Defined on the class, so the set is shared by all instances. It holds
    # the names of nodes whose batch group failed; see the warning emitted in
    # check_and_handle_error ("skipped if batching is restarted").
    node_blacklist: set[str] = set()

    def get_pre_computed_batched_shape_for_act_input(
        self, unique_act_inputs: dict[str, Tensor]
    ) -> list[int]:
        """
        Compute the batched shape for one activation input.

        unique_act_inputs: the unique tensors feeding that activation input.

        With exactly one unique tensor, the shape comes from
        get_individual_shape_for_batching; otherwise the shapes of all the
        tensors are combined by get_batched_shape.
        """
        input_tensors = list(unique_act_inputs.values())
        if len(input_tensors) == 1:
            # Only one unique input tensor exists
            return self.get_individual_shape_for_batching(input_tensors[0])
        return self.get_batched_shape([t.get_shape() for t in input_tensors])

    def batch_match(self, model_dict: ModelDict) -> dict[int, list[Node]]:
        """
        Collect the batchable MHA/MatMul nodes, grouped by graph level.

        Only op types matching one of the configured `enable_batch_operator`
        entries (see op_match_op_type) are kept.

        Returns a dict mapping level -> nodes at that level, ordered from the
        highest level to the lowest. The dict is empty when there is no
        candidate.
        """
        mha_list = model_dict.get_node_names_starts_with("MHA_")
        # MatMul_qdq MatMul_bias_qdq, MatMul_act_act
        matmul_list = model_dict.get_node_names_starts_with("MatMul_")
        combined_dict = mha_list | matmul_list
        if not combined_dict:
            return {}

        config = FusionConfigs.get_fusion_configs()
        pre_batch_nodes: list[str] = []
        for op_type, node_names in combined_dict.items():
            if not node_names:
                continue
            # MatMul_qdq_biasgelu_uint16xint8xuint16 in linear plus lut is not batchable
            if any(
                BatchingByLevel.op_match_op_type(base_op_type, op_type)
                for base_op_type in config.enable_batch_operator
            ):
                pre_batch_nodes.extend(node_names)

        node_to_level = model_dict.get_levels(pre_batch_nodes)
        level_to_nodes: dict[int, list[Node]] = {}
        for node_name, level in node_to_level.items():
            node = Node(model_dict, WalkCfgPlain(), node_name)
            level_to_nodes.setdefault(level, []).append(node)

        # Highest level first; dicts preserve insertion order.
        return {
            level: level_to_nodes[level]
            for level in sorted(level_to_nodes, reverse=True)
        }

    def contains_blacklisted_node(self, batch_nodes: list[Node]) -> bool:
        """Return True if at least one node of the group is blacklisted."""
        return any(node._node_name in self.node_blacklist for node in batch_nodes)

    def run_single_batch_group(self, batch_nodes: list[Node], logger: Logger) -> bool:
        """
        Return True if batching for the group of nodes was successful, and False otherwise
        """
        # The whole group is stored for match()/modify(); matching itself is
        # started from the first node of the group.
        self.batch_nodes = batch_nodes
        try:
            self.do_match(batch_nodes[0])
        except MatcherError:
            return False
        self.do_modify(logger)
        return True

    def check_and_handle_error(
        self, runner: SafeRunner, batch_nodes: list[Node], logger: Logger
    ) -> None:
        """
        Check if an error was caught by the runner. If that is the case, add
        all nodes from that group to the blacklist, emit a warning to inform
        the user and let the runner raise its error for the caller
        (runner.raise_error()).
        """
        if not runner.has_failed:
            return

        batch_node_names = [node._node_name for node in batch_nodes]
        # In-place update: mutates the shared class-level blacklist.
        self.node_blacklist.update(batch_node_names)

        logger.warning(
            f"Error when batching group of nodes '{batch_node_names}'. "
            "They will be skipped if batching is restarted."
        )
        logger.reset_indentation()
        runner.raise_error()

    def run(
        self,
        _model_dict: ModelDict,
        walk_cfg: "WalkCfgBase",
        runner: SafeRunner,
        logger: Logger,
    ) -> int:
        """
        Batch every batchable group of nodes, one level at a time.

        Groups containing a blacklisted node are skipped. Each remaining
        group is processed through `runner`; a failure recorded by the runner
        is handled by check_and_handle_error.

        `walk_cfg` is not used by this method.

        Returns the number of groups that were batched successfully.
        """
        self.match_cnt = 0
        level_to_nodes = self.batch_match(_model_dict)
        if not level_to_nodes:
            return 0
        for nodes in level_to_nodes.values():
            batch_lists = self.nodes_list_at_level(nodes)
            for batch_nodes in batch_lists:
                if self.contains_blacklisted_node(batch_nodes):
                    logger.debug(f"Skipping batch list {batch_nodes} as it contains a blacklisted node")
                    continue

                matched = runner.run(self.run_single_batch_group, batch_nodes, logger)
                self.check_and_handle_error(runner, batch_nodes, logger)

                if matched:
                    self.match_cnt += 1

        return self.match_cnt

    @staticmethod
    def op_match_op_type(base_op_type: str, op_type: str) -> bool:
        """
        Tell whether `op_type` is `base_op_type`, optionally followed by a
        data-type suffix.

        The match succeeds when `op_type` equals `base_op_type`, or when the
        remainder after `base_op_type` starts with `_` and one of
        int4/int8/int16/uint4/uint8/uint16. The pattern is not anchored at
        the end, so a longer suffix such as `_uint16xint8xuint16` matches.
        """
        # base: matmul_qdq, matmul_qdq_uint16
        if not op_type.startswith(base_op_type):
            return False
        suffix = op_type[len(base_op_type) :]
        if not suffix:
            return True
        return re.match(r"^_(int4|int8|int16|uint4|uint8|uint16)", suffix) is not None

    def nodes_list_at_level(self, nodes_at_same_level: list[Node]) -> list[list[Node]]:
        """
        Partition the nodes of one level into groups that can be batched.

        Returns a list of groups; every group holds at least two nodes.
        Nodes for which no group of two or more is found are left out.
        """
        # Loop-invariant configuration value, looked up once.
        batch_by_out_tensor = FusionConfigs.get_fusion_configs().batch_by_out_tensor

        batch_nodes_at_same_level: list[list[Node]] = []
        nodes_to_batch = set(nodes_at_same_level)
        for node in nodes_at_same_level:
            # break if fewer than two nodes are left in nodes_to_batch
            if len(nodes_to_batch) < 2:
                break
            # skip node already in batch
            if node not in nodes_to_batch:
                continue

            # Skip nodes with multiple outputs, as their batching behavior may be ambiguous or unsupported
            # NOTE(review): such a node is only skipped as the seed of a
            # group and stays in nodes_to_batch; whether get_batch_nodes can
            # still return it as a member of another group is not visible
            # here - confirm.
            if len(node.get_outputs()) > 1:
                continue

            batchable_nodes = self.get_batch_nodes(
                node,
                nodes_at_same_level,
                batch_by_out_tensor,
            )
            if len(batchable_nodes) < 2:
                nodes_to_batch.remove(node)
                continue
            batch_nodes_at_same_level.append(batchable_nodes)
            nodes_to_batch -= set(batchable_nodes)
        return batch_nodes_at_same_level

    def check_activation_inputs_order(self, act_input_names: list[str]) -> None:
        """
        For each activation input name, check if the sequence id is ordered or not
        Raises MatcherError if any are unordered.
        """
        for act_input_name in act_input_names:
            seq_id_list = self.batch_data_generator.get_act_inputs_sig_id_list(
                act_input_name
            )
            if not self._is_ordered(seq_id_list):
                raise MatcherError(
                    f"Batching by level requires ordered activation inputs, but found unordered for act_input : {act_input_name}"
                )

    def match(self) -> None:
        """
        Validate the current batch group and prepare the data used by modify().

        Sets `batch_nodes_sorted`, `batch_data_generator` and
        `act_input_names`; for attention nodes also `k_nb_heads` and
        `v_nb_heads`.

        Raises MatcherError when the activation inputs are not ordered or
        when the K and V inputs disagree on their first batched dimension.
        An attention node is additionally required to have `num_heads` == 1.
        """
        # Evaluated once and reused below; assumes check() is a pure
        # predicate - TODO confirm.
        is_attention = self.n.check(CategoryCheck(Attention()))
        if is_attention:
            self.n.require(AttrValue("num_heads", 1))

        self.batch_nodes_sorted = sorted(
            self.batch_nodes, key=cmp_to_key(self.cmp_batch_position)
        )

        self.batch_data_generator = BatchDataGenerator(self.batch_nodes_sorted)
        self.batch_data_generator.generate_data()

        self.act_input_names = self.batch_data_generator.get_activation_input_names()
        self.check_activation_inputs_order(self.act_input_names)

        if not is_attention:
            return

        k_unique_inputs = self.batch_data_generator.get_unique_act_inputs_dict("K")
        self.k_nb_heads = self.get_pre_computed_batched_shape_for_act_input(
            k_unique_inputs
        )[0]
        if "V" in self.act_input_names:
            v_unique_inputs = self.batch_data_generator.get_unique_act_inputs_dict(
                "V"
            )
            self.v_nb_heads = self.get_pre_computed_batched_shape_for_act_input(
                v_unique_inputs
            )[0]
        else:
            # No V activation input: reuse the value computed for K.
            self.v_nb_heads = self.k_nb_heads

        if self.k_nb_heads != self.v_nb_heads:
            raise MatcherError(
                f"Group heads should be the same for K and V inputs: k {self.k_nb_heads}, v {self.v_nb_heads}"
            )

    def modify(self) -> None:
        """
        Replace the sorted batch nodes with one batched node.

        The new node keeps the op type and attributes of the first sorted
        node, plus the batching attributes: `num_heads` for attention nodes
        (with group/mask/bias information), `num_batches` otherwise.
        """
        n = self.n = self.batch_nodes_sorted[0]
        copy_attributes = n.get_attributes()
        length = len(self.batch_nodes_sorted)
        # Evaluated once and reused below; assumes check() is a pure
        # predicate - TODO confirm.
        is_attention = n.check(CategoryCheck(Attention()))
        if is_attention:
            copy_attributes["num_heads"] = length
            self.set_batch_dimension(4)
        else:
            copy_attributes["num_batches"] = length
            self.set_batch_dimension(3)

        # Get unique activation inputs dict for all batch nodes for each activation input
        #  For example, if we have Q, K, V, M, B activation inputs for the batch nodes,
        #  we will get a dictionary with keys as activation input names and values as unique tensors for those activation inputs.
        unique_act_inputs_dict = (
            self.batch_data_generator.get_all_unique_act_inputs_dict()
        )
        inputs = self.get_concated_batch_inputs(
            self.batch_nodes_sorted, unique_act_inputs_dict
        )
        outputs = self.get_splitted_batch_outputs(self.batch_nodes_sorted)
        new_type = n.get_op_type()

        if is_attention:
            copy_attributes["num_groups"] = self.k_nb_heads
            copy_attributes["groups_sequence"] = (
                self.batch_data_generator.get_act_inputs_sig_id_list("K")
            )

            # Mask (M) and bias (B) attributes are only set when the
            # corresponding batched input is present.
            if inputs.get("M"):
                copy_attributes["num_mask"] = inputs["M"].get_shape()[0]
                copy_attributes["mask_sequence"] = (
                    self.batch_data_generator.get_act_inputs_sig_id_list("M")
                )
            if inputs.get("B"):
                copy_attributes["num_bias"] = inputs["B"].get_shape()[0]
                copy_attributes["bias_sequence"] = (
                    self.batch_data_generator.get_act_inputs_sig_id_list("B")
                )

        for node in self.batch_nodes_sorted:
            self.remove_node(node)

        self.add_node(
            type=new_type,
            domain="ai.onnx.contrib",
            inputs=inputs,
            outputs=outputs,
            attributes=copy_attributes,
        )
