# fmt: on
import argparse
from collections import defaultdict
import logging
import onnx
from OGOAT.src.L1_fusion.L1_utils.ops_definition_utils import OnnxOpsWrapper
from OGOAT.src.L1_fusion.L1_utils.utils import save_model
from OGOAT.src.L1_fusion.py_match.helpers.qdq_helper import QDQHelper
from OGOAT.src.L1_fusion.py_match.model_dict import ModelDict
from OGOAT.src.L1_fusion.py_match.nodes_tensors import (
    InputTensor,
    MatcherError,
    Node,
    OutputTensor,
    WalkCfgPlain,
)
from OGOAT.src.L1_fusion.py_match.checkers import opType

# Base ONNX op types treated as linear: QdqTagging._is_non_linear_op rejects fused ops
# whose base type (the part of the op type before the first "_") is listed here.
LINEAR_ONNX_OPS = ["Conv", "MatMul", "MHA"]
# Binary op types: QdqTagging.initialize_qdq_tagging compares a node's op type and its
# "orig_type" attribute against this list before adding the "disable_dq1" selector.
BINARY_OPS = ["Add", "Mul", "Sub", "Div"]


class QdqTagging(QDQHelper):

    def __init__(self, model, qdq_optimization: bool):
        """
        Set up the lookup tables used by the tagging passes.

        model: ONNX model, wrapped in a ModelDict together with an OnnxOpsWrapper.
        qdq_optimization: when False, tag_qdq_nodes() only runs
            initialize_qdq_tagging() and skips the remaining passes.
        """
        # NOTE(review): QDQHelper.__init__ is not invoked here -- confirm the base class
        # needs no initialisation of its own.
        self._model_dict = ModelDict(model, OnnxOpsWrapper())
        self._qdq_optimization = qdq_optimization
        # node name -> Node, filled by initialize_qdq_tagging().
        # A plain dict: the previous factory-less defaultdict() behaved exactly like one.
        self._nodes_dict: dict[str, Node] = {}
        # node name -> list of names of nodes writing a tensor read by this node
        self._in_nodes_dict = self._model_dict._in_nodes_dict
        # node name -> list of names of nodes reading a tensor written by this node
        self._out_nodes_dict = self._model_dict._out_nodes_dict
        # tensor name -> True while the tensor is still a candidate for having the Q/DQ
        # around it disabled; names that were never stored read as False
        self.flag_dict: dict[str, bool] = defaultdict(lambda: False)
        # tensor name -> nodes that write the tensor (directly or through skipped ops)
        self.input_dict: dict[str, list[Node]] = defaultdict(list)
        # tensor name -> nodes that read the tensor (directly or through skipped ops)
        self.output_dict: dict[str, list[Node]] = defaultdict(list)
        # nodes for which is_skip_op() is False, in model order
        self.node_list: list[Node] = []

    def is_skip_op(self, node: Node) -> bool:
        """
        Tell whether `node` is looked through when producers/consumers are collected.

        Skipped:
            1. fused noop / runtime ops (op type ends with `_noop` or `_runtime`),
               e.g. reshape_noop, concat_runtime
            2. dataflow ops, fused or unfused, e.g. transpose, gather, concat, concat_qdq
        Not skipped:
            1. other unfused onnx ops: Dq/Q/Cast/Where/ ... etc.,
            2. linear, nonlinear and linear_non_linear fused ops:
               Add_qdq_uint16xuint16, Matmul_silu_qdq, etc
        """
        return self.is_noop_runtime_op(node) or self.is_dataflow_op(node)

    @staticmethod
    def is_noop_runtime_op(node: Node) -> bool:
        """Return True for a fused op whose op type ends with the `_noop` or `_runtime` suffix."""
        pieces = node.get_op_type().split("_")
        # a single piece means there is no suffix at all, i.e. an unfused op
        return len(pieces) > 1 and pieces[-1] in ("noop", "runtime")

    @staticmethod
    def is_unfused_or_noop_runtime_op(node: Node) -> bool:
        """
        Return True for nodes that must not receive Q/DQ selector attributes:
        unfused ops, fused noop/runtime ops, and ops whose base type is Quant or Dequant.
        """
        pieces = node.get_op_type().split("_")
        if len(pieces) == 1:
            return True  # no "_" in the op type: unfused op
        if pieces[-1] in ("noop", "runtime"):
            return True  # fused noop / runtime op
        # Quant / Dequant are fused, but should still not be tagged with dq/q selectors
        return pieces[0] in ("Quant", "Dequant")

    def is_dataflow_op(self, node: Node) -> bool:
        """
        Return True when `node` is a dataflow op, i.e. the native dtype reported for
        its op type is ["any"]. Op types ending in `noop` / `runtime` are never dataflow ops.
        """
        op_type = node.get_op_type()
        # text after the last "_" (the whole op type when there is no "_")
        suffix = op_type.rpartition("_")[2]
        if suffix == "noop" or suffix == "runtime":
            return False
        native_dtype = self.get_native_dtype(op_type)
        return native_dtype == ["any"]

    def is_float_reading_op(self, node: Node) -> bool:
        """
        Return True when `node` is a non-skip op that reads float data: native dtype
        ["bf16"] if a kernel is known for the op type, otherwise the non-linear-op fallback.
        """
        if self.is_skip_op(node):
            return False
        op_type = node.get_op_type()
        kernel = self.get_kernel_for_op(op_type)
        if kernel is not None:
            return self.get_native_dtype(op_type) == ["bf16"]
        # no kernel registered for this op type: classify from the op type name
        return self._is_non_linear_op(node)

    def is_float_writing_op(self, node: Node) -> bool:
        """
        Return True when `node` is a non-skip op that writes float data: native dtype
        ["bf16"] or ["int16", "bf16"] if a kernel is known for the op type, otherwise
        the non-linear / linear_non_linear fallbacks.
        """
        if self.is_skip_op(node):
            return False
        op_type = node.get_op_type()
        kernel = self.get_kernel_for_op(op_type)
        if kernel is not None:
            float_writing_dtypes = (["bf16"], ["int16", "bf16"])
            return self.get_native_dtype(op_type) in float_writing_dtypes
        # no kernel registered for this op type: classify from the op type name
        return self._is_non_linear_op(node) or self._is_linear_non_linear_op(node)

    # FIXME: remove the following two methods when all fused ops have corresponding kernels
    def _is_non_linear_op(self, node: Node) -> bool:
        """
        Name-based fallback: True for a fused, non-skip op whose base type is not in
        LINEAR_ONNX_OPS and which is not a dataflow op.
        """
        if self.is_skip_op(node):
            return False

        pieces = node.get_op_type().split("_")
        if len(pieces) == 1:
            return False  # it is unfused onnx op
        if pieces[0] in LINEAR_ONNX_OPS:
            return False  # linear base type
        return not self.is_dataflow_op(node)

    def _is_linear_non_linear_op(self, node: Node) -> bool:
        """
        Name-based fallback: True for a fused, non-skip MatMul whose op type mentions
        a fused "silu" or "gelu" activation (e.g. Matmul_silu_qdq style names).

        The activation test is explicitly grouped. The previous expression
        `base == "MatMul" and "silu" in t or "gelu" in t` parsed as
        `(base == "MatMul" and "silu" in t) or ("gelu" in t)`, so every fused op with
        "gelu" in its type was accepted regardless of its base type.
        """
        if self.is_skip_op(node):
            return False

        op_type = node.get_op_type()
        base_type, *fused_types = op_type.split("_")
        if not fused_types:
            return False  # it is unfused onnx op
        if base_type != "MatMul":
            return False
        return "silu" in op_type or "gelu" in op_type

    def tag_qdq_nodes(self):
        """
        Entry point: always seed the selector attributes; when qdq_optimization is
        enabled, additionally run the three analysis/update passes in order.
        """
        self.initialize_qdq_tagging()
        if self._qdq_optimization:
            self.check_src_dst_floating_ops()
            self.check_tensor_src_dst_has_same_tensor()
            self.update_qdq_tagging()

    def initialize_qdq_tagging(self):
        """
        Wrap every model node in a Node, collect the non-skip nodes in node_list and
        seed the `disable_q` / `disable_dq{i}` selector attributes with 0.

        Selector names come from the kernel metadata when a kernel exists for the op
        type; otherwise they are derived from the op classification (non-linear or
        linear_non_linear). Ops matching neither only produce a warning on stdout.
        """
        # node name -> selector attribute names registered for it (local bookkeeping only)
        selectors_by_node: dict[str, list[str]] = defaultdict(list)
        for node_name in self._model_dict.get_node_names():
            node = Node(self._model_dict, WalkCfgPlain(), node_name)
            self._nodes_dict[node_name] = node
            attributes = node.get_attributes()
            already_tagged = "disable_dq0" in attributes and "disable_q" in attributes
            if already_tagged:
                continue  # fused node that already carries disable_dq0/q attributes

            # everything except noop/runtime and dataflow ops takes part in the analysis
            if not self.is_skip_op(node):
                self.node_list.append(node)

            # unfused and noop/runtime nodes get no selectors, even if kernel_dict knows them
            if self.is_unfused_or_noop_runtime_op(node):
                continue

            op_type = node.get_op_type()

            # Case 1: kernel available -> selectors come from its metadata.
            # The metadata default value is ignored for now; setdefault_attribute is
            # relied upon to keep a selector value the node already has.
            if self.get_kernel_for_op(op_type):
                kernel_selectors = self.get_qdq_selector(op_type)
                selectors_by_node[node_name] = list(kernel_selectors)
                for selector_name in selectors_by_node[node_name]:
                    node.setdefault_attribute(selector_name, 0)
                continue

            # Case 2: non-linear op without kernel metadata.
            if self._is_non_linear_op(node):
                # FIXME when there are ops other than binary ops that take a const or
                # initializer input needing disable_dq0 tagging, revisit this count.
                # Currently only activation / graph inputs are counted.
                activation_count = sum(
                    1
                    for tensor in node.get_inputs()
                    if not (
                        tensor.check_initializer() or tensor.check(opType.Constant)
                    )
                )
                if node_name not in selectors_by_node:
                    new_selectors = (
                        [] if node.has_attribute("disable_q") else ["disable_q"]
                    )
                    if not node.has_attribute("disable_dq0"):
                        new_selectors += [
                            f"disable_dq{i}" for i in range(activation_count)
                        ]
                    selectors_by_node[node_name] = new_selectors
                    for selector_name in new_selectors:
                        node.setdefault_attribute(selector_name, 0)

                # A binary op (by op type or by orig_type) with a single counted input
                # still needs disable_dq1 = 0 in its attribute list.
                if activation_count != 1:
                    continue
                is_binary = node.get_op_type() in BINARY_OPS or (
                    "orig_type" in attributes
                    and attributes["orig_type"] in BINARY_OPS
                )
                if not is_binary:
                    continue
                if not (node("A").check_tensor() and node("B").check_tensor()):
                    continue
                if "disable_dq1" not in attributes:
                    selectors_by_node[node_name].insert(activation_count, "disable_dq1")
                    node.setdefault_attribute("disable_dq1", 0)
                continue

            # Case 3: linear op fused with a non-linear one -> only the output Q is tagged.
            if self._is_linear_non_linear_op(node):
                if node_name not in selectors_by_node:
                    selectors_by_node[node_name] = ["disable_q"]
                    node.setdefault_attribute("disable_q", 0)
                continue

            # Case 4: nothing matched.
            print(
                f"\033[33mWarning: Node='{node_name}' has unknown op_type='{op_type}' in qdq_tagging.\033[0m"
            )

    def check_src_dst_floating_ops(self):
        """
        Fill input_dict / output_dict and compute flag_dict for every non-initializer
        tensor touched by the nodes in node_list.

        A tensor's flag ends up True only if each node that sees it as an input is a
        float reader with a non-empty, all-float-writing producer list, and each node
        that sees it as an output is a float writer with a non-empty, all-float-reading
        consumer list. Once a flag is False it is not re-evaluated.
        """
        # Pass 1: register direct readers/writers, and look through skip ops
        # (noop, runtime, dataflow) next to float readers/writers.
        for node in self.node_list:
            reads_float = self.is_float_reading_op(node)
            writes_float = self.is_float_writing_op(node)

            for in_tensor in node.get_inputs():
                if in_tensor.check_initializer():
                    continue  # scale / weight style input, not a proper data input
                readers = self.output_dict[in_tensor.get_name()]
                if node not in readers:
                    readers.append(node)  # node is a reader of the tensor
                if reads_float:
                    # adds the non-skippable producers behind skip ops to input_dict
                    self.handle_prev_skip_nodes(in_tensor)

            for out_tensor in node.get_outputs():
                if out_tensor.check_initializer():
                    continue
                writers = self.input_dict[out_tensor.get_name()]
                if node not in writers:
                    writers.append(node)  # node is a writer of the tensor
                if writes_float:
                    # adds the non-skippable consumers behind skip ops to output_dict
                    self.handle_next_skip_nodes(out_tensor)

        # Pass 2: derive the flags from the completed dictionaries.
        for node in self.node_list:
            reads_float = self.is_float_reading_op(node)
            writes_float = self.is_float_writing_op(node)

            for in_tensor in node.get_inputs():
                if in_tensor.check_initializer():
                    continue
                name = in_tensor.get_name()
                producers = self.input_dict[name]
                # .get(): an unset flag counts as "still possible"; a False one is final
                if not self.flag_dict.get(name, True):
                    continue
                if reads_float and producers:
                    self.flag_dict[name] = all(
                        self.is_float_writing_op(producer) for producer in producers
                    )
                else:
                    self.flag_dict[name] = False

            for out_tensor in node.get_outputs():
                if out_tensor.check_initializer():
                    continue
                name = out_tensor.get_name()
                consumers = self.output_dict[name]
                if not self.flag_dict.get(name, True):
                    continue
                if writes_float and consumers:
                    self.flag_dict[name] = all(
                        self.is_float_reading_op(consumer) for consumer in consumers
                    )
                else:
                    self.flag_dict[name] = False

    def handle_next_skip_nodes(self, tensor: OutputTensor):
        """
        Look through skip ops (runtime, noop, dataflow) that read `tensor` and register
        the non-skip consumers behind them in output_dict. Unfused nodes are not skipped.

        If any reader of the tensor is not a node, the tensor's flag is set to False.
        """
        name = tensor.get_name()
        readers = tensor.get_readers()
        for reader in readers:
            # `== False` (rather than `not`) so only an explicit False result rejects
            if reader.check_node() == False:  # noqa: E712
                # every reader of the tensor has to be a node
                self.flag_dict[name] = False
                return

        has_skip_reader = any(
            self.is_skip_op(reader.require_node()) for reader in readers
        )
        if not has_skip_reader:
            return
        # at least one reader is skipped: walk downstream from every reader
        for reader in readers:
            self.collect_nodes_in_path(reader.require_node(), name, False)

    def handle_prev_skip_nodes(self, tensor: InputTensor):
        """
        Look through a skip op (runtime, noop, dataflow) that writes `tensor` and
        register the non-skip producers behind it in input_dict. Unfused nodes are
        not skipped.

        If the tensor's writer cannot be resolved to a node (MatcherError), the
        tensor's flag is set to False.
        """
        name = tensor.get_name()
        try:
            writer = tensor.require_tensor().get_writer().require_node()
        except MatcherError:
            # the writer of the tensor has to be a node
            self.flag_dict[name] = False
            return
        if self.is_skip_op(writer):
            self.collect_nodes_in_path(writer, name, True)

    def collect_nodes_in_path(
        self, _current: Node, tensor_name: str, collect_writers: bool
    ):
        """
        Depth-first walk from `_current` through skip ops.

        With collect_writers=True the walk follows producer edges and appends the first
        non-skip nodes it meets to input_dict[tensor_name]; otherwise it follows
        consumer edges and appends them to output_dict[tensor_name]. Nodes already
        present in the target list are not added twice.
        """
        if not self.is_skip_op(_current):
            # end of the path: record the node once
            found = (
                self.input_dict[tensor_name]
                if collect_writers
                else self.output_dict[tensor_name]
            )
            if _current not in found:
                found.append(_current)
            return

        # skip op: continue with its producers or consumers
        adjacency = self._in_nodes_dict if collect_writers else self._out_nodes_dict
        for neighbour_name in adjacency[_current.get_name()]:
            self.collect_nodes_in_path(
                self._nodes_dict[neighbour_name], tensor_name, collect_writers
            )

    def check_tensor_src_dst_has_same_tensor(self):
        """
        Clear flag_dict[tensor] unless producers and consumers agree on the
        quantization parameters of the tensor.

        For every tensor that is still flagged and has producers and consumers:
          * every producer must have at least two inputs, and all producers must share
            the same last two inputs (taken as scale = inputs[-2], zero point = inputs[-1]);
          * every consumer must have both of these tensors among its inputs, with
            initializer values equal according to check_tensor_equal_value.

        The flag stays True only when ALL consumers match. Previously the flag was
        overwritten per consumer, so the result depended on the last consumer only.

        NOTE(review): require_initializer() is called on the scale / zero-point
        tensors; presumably it raises when they are not initializers -- confirm.
        """
        for tensor_name in self.flag_dict:
            if not self.flag_dict[tensor_name]:
                continue
            producer_list = self.input_dict[tensor_name]
            if not producer_list:
                continue
            consumer_list = self.output_dict[tensor_name]
            if not consumer_list:
                continue

            # Add_qdq->(tensor_a)->Reshape_noop->(tensor_b)->Softmax_qdq
            # Add_qdq->(tensor_a)->Softmax_qdq
            # Add_qdq->(tensor_b)->Softmax_qdq
            if any(
                len(producer_node.get_inputs()) < 2 for producer_node in producer_list
            ):
                self.flag_dict[tensor_name] = False
                print(
                    f"warning: not all nodes producing {tensor_name} have at least 2 input tensors"
                )
                continue

            # the first producer is the reference; all others must agree with it
            reference_inputs = list(producer_list[0].get_inputs())
            zero_point_input = reference_inputs[-1]
            scale_input = reference_inputs[-2]

            same_last_two_values = all(
                zero_point_input == producer_node.get_inputs()[-1]
                and scale_input == producer_node.get_inputs()[-2]
                for producer_node in producer_list
            )
            if not same_last_two_values:
                self.flag_dict[tensor_name] = False
                print(
                    f"warning: not all producer node of {tensor_name} have the same last two tensor input value"
                )
                continue

            self.flag_dict[tensor_name] = all(
                self._consumer_matches_qparams(
                    output_node, scale_input, zero_point_input
                )
                for output_node in consumer_list
            )

    def _consumer_matches_qparams(
        self, output_node: Node, scale_input, zero_point_input
    ) -> bool:
        """Return True when `output_node` reads the given scale / zero-point tensors with equal values."""
        consumer_inputs = output_node.get_inputs()
        if (
            zero_point_input not in consumer_inputs
            or scale_input not in consumer_inputs
        ):
            # the producer's last two inputs do not exist on the consumer
            return False

        zero_point_output = next(
            (
                candidate
                for candidate in list(output_node.get_inputs())
                if candidate.get_name() == zero_point_input.get_name()
            ),
            None,
        )
        scale_output = next(
            (
                candidate
                for candidate in list(output_node.get_inputs())
                if candidate.get_name() == scale_input.get_name()
            ),
            None,
        )
        if zero_point_output is None or scale_output is None:
            # membership matched but no input carries the same name: treat as mismatch
            return False

        return bool(
            self.check_tensor_equal_value(
                zero_point_input.require_initializer(),
                zero_point_output.require_initializer(),
            )
            and self.check_tensor_equal_value(
                scale_input.require_initializer(),
                scale_output.require_initializer(),
            )
        )

    def check_initializer_scale_zeropoint_same(self, node: Node) -> bool:
        """
        Return True when `node` has at least one scale input and at least one
        zero-point input, and all scale inputs share one value and all zero-point
        inputs share one value.

        Inputs are classified by name: `_scale` / `_zero_point` as a suffix, or
        `_scale_` / `_zero_point_` anywhere in the name.

        NOTE(review): values are compared with `==` inside all(); if
        get_initializer_array() returns multi-element arrays this presumably needs
        an element-wise reduction -- confirm.
        """
        scale_tensors = [
            tensor
            for tensor in node.get_inputs()
            if tensor.get_name().endswith("_scale") or "_scale_" in tensor.get_name()
        ]
        zero_point_tensors = [
            tensor
            for tensor in node.get_inputs()
            if tensor.get_name().endswith("_zero_point")
            or "_zero_point_" in tensor.get_name()
        ]
        if not (scale_tensors and zero_point_tensors):
            return False

        # the first tensor of each kind is the reference value
        first_scale, *other_scales = scale_tensors
        first_zero_point, *other_zero_points = zero_point_tensors
        reference_scale = first_scale.get_initializer_array()
        reference_zero_point = first_zero_point.get_initializer_array()

        scales_agree = all(
            tensor.get_initializer_array() == reference_scale
            for tensor in other_scales
        )
        zero_points_agree = all(
            tensor.get_initializer_array() == reference_zero_point
            for tensor in other_zero_points
        )
        return scales_agree and zero_points_agree

    def update_qdq_tagging(self):
        """
        Set `disable_q` on producers and `disable_dq{i}` on consumers to 1 for every
        tensor whose flag is True and that has both producers and consumers.

        Nothing is changed for tensors with more than one producer. flag_dict is read
        with .get() inside the loop: it is a defaultdict, so indexing an unknown output
        name would insert a key while flag_dict is being iterated.
        """
        for tensor_name in self.flag_dict:
            if (
                not self.flag_dict[tensor_name]
                or not self.input_dict[tensor_name]
                or not self.output_dict[tensor_name]
            ):
                continue

            producer_list = self.input_dict[tensor_name]
            # Attributes are only set when the tensor has exactly one producer.
            # Because skipped writers are looked through, a tensor may have several
            # producers in input_dict:
            # Add_qdq_1->(tensor_a1)->Reshape_noop->(tensor_b1)->concat
            # Add_qdq_2->(tensor_a2)->Reshape_noop->(tensor_b2)->concat->tensor_c ->div_qdq
            #
            # input_dict:
            # [tensor_c, [Add_qdq_1, Add_qdq_2]]
            # [tensor_a1, Add_qdq_1]
            # [tensor_a2, Add_qdq_2]
            # For tensor_c the producer list is [Add_qdq_1, Add_qdq_2], so disable_q is
            # set for neither; it is set when tensor_a1 / tensor_a2 are processed.
            if len(producer_list) != 1:
                continue
            producer = producer_list[0]

            # disable the producer's Q only if every one of its outputs is flagged
            if all(
                self.flag_dict.get(out_tensor.get_name(), False)
                for out_tensor in producer.get_outputs()
            ):
                producer.set_attribute("disable_q", 1)

            # Add_qdq_1->(tensor_a1)->Div_1_qdq
            #          ->(tensor_a1)->Div_2_qdq
            #          ->(tensor_a1)->Div_3_qdq
            # Add_qdq_2->(tensor_a2)->Div_1_qdq
            # Div_1_qdq->input: tensor_a1, tensor_a2
            # when tensor_name is tensor_a1, disable_dq0 is set to 1
            # when tensor_name is tensor_a2, disable_dq1 is set to 1
            for output_node in self.output_dict[tensor_name]:
                # the dq index counts the non-initializer inputs only
                data_inputs = (
                    _input
                    for _input in output_node.get_inputs()
                    if not _input.check_initializer()
                )
                for dq_index, _input in enumerate(data_inputs):
                    if _input.get_name() == tensor_name:
                        # this consumer reads the tensor directly: disable its DQ
                        output_node.set_attribute(f"disable_dq{dq_index}", 1)
                        break


def main(args):
    """
    Load an ONNX model, run QDQ tagging on it and save it next to the input.

    args: mapping with the keys
        "model_path": path of the model to load,
        "qdq_optimization": 0/1 (int or numeric string), enables the optimization passes,
        "load_data": 0/1 (int or numeric string), forwarded as the external-data
            flag to both loading and saving.

    The output path is the input path with a trailing ".onnx" (any letter case)
    replaced by "_updated.onnx"; a path without that suffix gets "_updated.onnx"
    appended instead of losing its last five characters.
    """
    model_path = args["model_path"]
    qdq_optimization = bool(int(args["qdq_optimization"]))
    load_data = int(args["load_data"])

    # load model
    model = onnx.load_model(model_path, load_external_data=load_data)

    suffix = ".onnx"
    if model_path.lower().endswith(suffix):
        stem = model_path[: -len(suffix)]
    else:
        stem = model_path
    out_model_path = stem + "_updated.onnx"

    qdq_tagging_obj = QdqTagging(model, qdq_optimization)
    qdq_tagging_obj.tag_qdq_nodes()

    save_model(model, out_model_path, external_data=load_data)


if __name__ == "__main__":
    # Command-line entry point: parse the options, configure logging, run main().
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-d",
        "--debug",
        help="Print lots of debugging statements",
        action="store_const",
        dest="loglevel",
        const=logging.DEBUG,
    )
    parser.add_argument(
        "-mp",
        "--model_path",
        help="path to onnx model and output destination.Required Field",
    )
    # main() parses this value with int() and uses it as the external-data flag,
    # so the help text describes a 0/1 switch rather than a path.
    parser.add_argument(
        "-ld",
        "--load_data",
        help="set to 1 to load and save external model data for large models. Optional Field. Default value = 0",
        default="0",
    )
    parser.add_argument(
        "--qdq_optimization",
        type=int,
        choices=[0, 1],
        default=0,
        help="Enable QDQ optimization at end of L1 fusion. Default is 0 (disabled).",
    )

    args = parser.parse_args()
    if not args.model_path:
        parser.error(
            "Please pass path/to/onnx/model using -mp or --model_path flags.\npython3 parse_onnx_model.py --help\n\t\t\tfor further info."
        )
    # loglevel is None unless -d/--debug was given
    logging.basicConfig(level=args.loglevel)
    logging.debug("Debug mode is enabled!")

    main(vars(args))
