# fmt: on
"""
QDQ Cleanup Pass - AIESW-18212

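This module implements a cleanup pass that removes an int16 QDQ (QuantizeLinear/DequantizeLinear)
pair when it is directly chained with an int8 QDQ pair in the graph. Unsigned variants
(uint8/uint16) are handled the same way. A chain is only considered when each Q/DQ pair uses
identical scale and zero-point values on its Q and DQ nodes.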

The optimization targets patterns like:
- int8 QDQ followed by int16 QDQ: Keep int8 QDQ, remove int16 QDQ
- int16 QDQ followed by int8 QDQ: Keep int8 QDQ, remove int16 QDQ

This pass should be executed before NHWC conversion as a cleanup optimization.
"""

import numpy as np
import onnx
from onnx import TensorProto

from OGOAT.src.utils.context import Logger
from OGOAT.src.L1_fusion.py_match.model_dict import ModelDict
from OGOAT.src.L1_fusion.L1_utils.ops_definition_utils import OnnxOpsWrapper
from OGOAT.src.L1_fusion.py_match.nodes_tensors import Node, WalkCfgPlain
from OGOAT.src.L1_fusion.py_match.checkers import opType, DTypes


def check_qdq_same(node_quantize: "Node", node_dequantize: "Node") -> bool:
    """Return True if a Q node and its DQ node use identical quantization params.

    Compares the Q node's ``y_scale`` / ``y_zero_point`` initializers against
    the DQ node's ``x_scale`` / ``x_zero_point`` initializers.

    ``numpy.array_equal`` is used rather than ``==`` so that per-channel
    (non-scalar) parameters are compared element-wise, and so that parameters
    of different shapes compare unequal instead of broadcasting or raising
    "truth value of an array is ambiguous".

    Returns False if any of the four values is ``None``. NOTE(review): this
    assumes ``get_initializer_array`` may return ``None`` for a non-initializer
    input; confirm against its implementation.
    """
    y_scale = node_quantize("y_scale").require_tensor().get_initializer_array()
    y_zero_point = (
        node_quantize("y_zero_point").require_tensor().get_initializer_array()
    )
    x_scale = node_dequantize("x_scale").require_tensor().get_initializer_array()
    x_zero_point = (
        node_dequantize("x_zero_point").require_tensor().get_initializer_array()
    )
    if any(v is None for v in (y_scale, y_zero_point, x_scale, x_zero_point)):
        # Unknown parameters: be conservative and report "not the same".
        return False
    return bool(
        np.array_equal(y_scale, x_scale)
        and np.array_equal(y_zero_point, x_zero_point)
    )


class QDQCleanupOptimizer:
    """Remove int16 Q->DQ pairs that are chained directly with an int8 Q->DQ pair.

    Both orderings of a ``Q -> DQ -> Q -> DQ`` chain are handled:

    * int8 pair then int16 pair: the trailing int16 pair is removed and its
      readers are rewired to the int8 DQ output.
    * int16 pair then int8 pair: the leading int16 pair is removed and the
      int8 Q reads the int16 Q's float input directly.

    A chain is only rewritten when each Q/DQ pair uses identical scale and
    zero-point values (see ``check_qdq_same``). Unsigned zero points
    (uint8/uint16) are treated like their signed counterparts.
    """

    def __init__(self, model: onnx.ModelProto, logger: Logger):
        self.md = ModelDict(model, OnnxOpsWrapper())
        # Stored so the optimizer (and subclasses) can report through it.
        self.logger = logger

    def cleanup(self) -> int:
        """Rewrite every matching chain; return the number of int16 pairs removed."""
        removed_count = 0
        md = self.md
        # Snapshot the names: matching chains remove nodes while we iterate.
        for node_name in list(md.get_node_names()):
            # Skip nodes that were already removed as part of an earlier chain.
            if node_name not in md.get_node_names():
                continue
            chain = self._match_qdq_chain(node_name)
            if chain is None:
                continue
            selection = self._select_int16_pair(*chain)
            if selection is None:
                continue
            q_unconnected, dq_unconnected, old_input_name, new_tensor = selection
            new_input_name = new_tensor.get_name()

            # A graph output must keep a producer once the int16 DQ is gone.
            if not self._redirect_graph_output(old_input_name, new_tensor):
                continue

            for reader in list(md.get_reader_names(old_input_name) or ()):
                md.replace_input(reader, old_input_name, new_input_name)

            md.remove_node(q_unconnected.get_name())
            md.remove_node(dq_unconnected.get_name())
            removed_count += 1
        return removed_count

    def _match_qdq_chain(self, node_name):
        """Return (q_pred, dq_pred, q_post, dq_post) if `node_name` starts a matching chain, else None."""
        md = self.md
        q_pred = Node(md, WalkCfgPlain(), node_name)
        if not q_pred.check(opType.QuantizeLinear):
            return None
        if not q_pred("y").check_node():
            return None
        dq_pred = q_pred("y").get_non_tensor()
        if not dq_pred.check(opType.DequantizeLinear):
            return None
        if not check_qdq_same(q_pred, dq_pred):
            return None
        if not dq_pred("y").check_node():
            return None
        q_post = dq_pred("y").get_non_tensor()
        if not q_post.check(opType.QuantizeLinear):
            return None
        # Same guard as the earlier hops before calling get_non_tensor().
        if not q_post("y").check_node():
            return None
        dq_post = q_post("y").get_non_tensor()
        if not dq_post.check(opType.DequantizeLinear):
            return None
        if not check_qdq_same(q_post, dq_post):
            return None
        return q_pred, dq_pred, q_post, dq_post

    def _select_int16_pair(self, q_pred, dq_pred, q_post, dq_post):
        """Choose which pair of the chain to drop.

        Returns ``(q_to_remove, dq_to_remove, old_input_name, new_tensor)``.
        Readers of ``old_input_name`` must be rewired to ``new_tensor``.
        Returns None when the chain is not an int8/int16 combination, or when
        rewriting it is unsafe.
        """
        q_pred_zp = q_pred("y_zero_point").require_tensor()
        q_post_zp = q_post("y_zero_point").require_tensor()
        if q_pred_zp.check(DTypes("uint8", "int8")) and q_post_zp.check(
            DTypes("uint16", "int16")
        ):
            # Before: fp32-> Q-> int8-> DQ-> fp32[new]-> Q[q_unconnected]-> int16-> DQ[dq_unconnected]-> fp32[old]-> readers
            # After:  fp32-> Q-> int8-> DQ-> fp32[new]-> readers
            return q_post, dq_post, dq_post("y").get_name(), dq_pred("y")
        if q_pred_zp.check(DTypes("uint16", "int16")) and q_post_zp.check(
            DTypes("uint8", "int8")
        ):
            # Before: fp32[new]-> Q[q_unconnected]-> int16-> DQ[dq_unconnected]-> fp32[old]-> Q[reader]-> int8-> DQ-> ...
            # After:  fp32[new]-> Q[reader]-> int8-> DQ-> ...
            old_input_name = dq_pred("y").get_name()
            # Any consumer of old_input_name other than q_post (including a
            # graph output) would be switched to the raw, un-quantized float
            # tensor and silently lose its QDQ, so only rewrite when q_post is
            # the sole reader.
            if old_input_name in self._graph_output_names():
                return None
            readers = set(self.md.get_reader_names(old_input_name) or ())
            if readers != {q_post.get_name()}:
                return None
            return q_pred, dq_pred, old_input_name, q_pred("x")
        return None

    def _graph_output_names(self):
        """Return the names of the graph outputs, in graph order."""
        return [output.name for output in self.md._model.graph.output]

    def _redirect_graph_output(self, old_name, new_tensor) -> bool:
        """Re-point a graph output named `old_name` at `new_tensor`.

        Returns True when `old_name` is not a graph output (nothing to do), or
        when the output was re-pointed. Returns False, leaving the graph
        untouched, when `new_tensor` is already a graph output, because
        re-pointing would duplicate an output name.
        """
        names = self._graph_output_names()
        if old_name not in names:
            return True
        new_name = new_tensor.get_name()
        if new_name in names:
            return False
        new_output = onnx.helper.make_tensor_value_info(
            new_name,
            new_tensor.get_dtype(),
            new_tensor.get_shape(),
        )
        # Overwrite in place so the graph's output order is preserved.
        self.md._model.graph.output[names.index(old_name)].CopyFrom(new_output)
        return True


def cleanup_int16_qdq(model: onnx.ModelProto, logger: Logger) -> int:
    """
    Drop int16 Q->DQ pairs that sit directly next to an int8 Q->DQ pair.

    Two chain orderings are recognised:

    - int8 first:  op1 -> Q(int8) -> DQ(int8) -> Q(int16) -> DQ(int16) -> op2
      becomes      op1 -> Q(int8) -> DQ(int8) -> op2
    - int16 first: op1 -> Q(int16) -> DQ(int16) -> Q(int8) -> DQ(int8) -> op2
      becomes      op1 -> Q(int8) -> DQ(int8) -> op2

    Args:
        model: ONNX model to rewrite. ``run_qdq_cleanup`` returns this same
            object afterwards, so the rewrite is presumably in place through
            ModelDict (TODO confirm).
        logger: Logger handed to ``QDQCleanupOptimizer``.

    Returns:
        How many int16 Q/DQ pairs were removed.
    """
    return QDQCleanupOptimizer(model, logger).cleanup()


def run_qdq_cleanup(
    model: onnx.ModelProto, logger: Logger, enable_cleanup: bool = True
) -> onnx.ModelProto:
    """
    Entry point for QDQ cleanup pass.

    Args:
        model: ONNX model to process
        logger: Logger instance for output
        enable_cleanup: If False, skip the cleanup (controlled by --qdq_int16_cleanup flag)

    Returns:
        The same ``model`` object, with int16 QDQ pairs removed (if enabled).
        The number of removed pairs is reported through ``logger``.
    """
    if not enable_cleanup:
        logger.info("Int16 QDQ cleanup pass disabled (--qdq_int16_cleanup 0)")
        return model

    removed_count = cleanup_int16_qdq(model, logger)
    # Report what the pass did, mirroring the message on the disabled path.
    logger.info(f"Int16 QDQ cleanup pass removed {removed_count} int16 QDQ pair(s)")
    return model
