# fmt: on
import collections
import numpy as np
import onnx
from onnx import helper
from typing import Any, Iterable, Literal, Optional
from dataclasses import dataclass

from OGOAT.src.L1_fusion.L1_utils.ops_definition_utils import (
    OnnxOpsWrapper,
    dtype_to_ops_type,
)
from OGOAT.src.L1_fusion.L1_utils.utils import (
    construct_dict,
)
from OGOAT.src.L1_fusion.L1_utils.utils import (
    onnxTensor_dtype_to_np_dtype,
    onnxTensor_np_dtype_to_dtype,
    onnxTensorProto_to_array,
    onnxTensorProto_from_array,
)
from OGOAT.src.utils.context import Logger


@dataclass()
class QdqInfo:
    """
    Quantization parameters attached to a single tensor.

    ``ModelDict`` stores instances as graph quantization annotations
    (``add_quantization_information``) and parses them back
    (``get_quantization_information``).
    """

    # quantization scale
    scale: float
    # data type of the scale, stored as an integer code
    # (presumably an ONNX TensorProto data type — confirm)
    scale_dtype: int
    # quantization zero point
    zero_point: float
    # data type of the zero point, stored as an integer code
    zero_point_dtype: int
    # output type descriptor stored alongside the parameters
    output_type: str


class LostInitializerHelper:
    """
    Helper class for detecting initializers which possibly get lost during fusion.

    ``inc``/``dec`` adjust a per-initializer reference counter
    (``ModelDict.append_node`` calls ``inc`` for every initializer input of a
    new node). ``check`` logs initializers whose counter dropped below zero,
    at most once per matcher.
    """

    def __init__(self) -> None:
        # matcher name -> initializers already reported for that matcher
        self.reported_initializers: dict[str, list[str]] = collections.defaultdict(list)
        self.reset()

    def reset(self) -> None:
        """Reset all counters to zero; already reported initializers are kept."""
        # initializer name -> reference counter
        self.counter: dict[str, int] = collections.defaultdict(int)

    def inc(self, initializer_name: str) -> None:
        """Increment the counter of the initializer by one."""
        self.counter[initializer_name] += 1

    def dec(self, initializer_name: str) -> None:
        """Decrement the counter of the initializer by one."""
        self.counter[initializer_name] -= 1

    def check(self, matcher_name: str, logger: Logger) -> None:
        """
        Log (at debug level) every initializer with a negative counter that
        has not been reported for ``matcher_name`` yet.
        """
        for initializer, count in self.counter.items():
            # The number of nodes reading an initializer cannot be calculated
            # precisely: e.g. the zero_point and scale of a DequantizeLinear
            # node are read only once, but how often they are actually used
            # depends on the output y and on how many nodes use it as an input.
            # Hence a negative counter is only a hint.
            if (
                count < 0
                and initializer not in self.reported_initializers[matcher_name]
            ):
                self.reported_initializers[matcher_name].append(initializer)
                logger.debug(
                    f"warning: possibly losing initializer with name='{initializer}', counter={count}"
                )


class ModelDict:
    def __init__(
        self,
        model: onnx.ModelProto,
        onnx_ops: OnnxOpsWrapper,
    ) -> None:
        self._model = model
        self._onnx_ops = onnx_ops

        (
            # initializer name -> ONNX initializer (onnx.TensorProto)
            self._ini_dict,
            # node name -> ONNX node (onnx.NodeProto)
            self._nodes_dict,
            # node name -> list of names of nodes writing a tensor read by this node
            self._in_nodes_dict,
            # node name -> list of names of nodes reading a tensor written by this node
            self._out_nodes_dict,
            # element name -> ONNX value info (onnx.ValueInfoProto; entries in that dictionary are optional)
            self._value_info_dict,
            # input name -> ONNX value info (onnx.ValueInfoProto)
            self._input_names,
            # output name -> ONNX value info (onnx.ValueInfoProto)
            self._output_names,
        ) = construct_dict(model)
        # tensor name -> set of names of nodes reading this tensor
        self._readers: dict[str, set[str]] = collections.defaultdict(set)
        # tensor name -> set of names of nodes writing this tensor
        # operator type -> set of names of nodes with the operator type
        self._op_type_dict: dict[str, set[str]] = collections.defaultdict(set)
        # (while working on the graph, it is temporarily okay to have multiple writers)
        self._writers: dict[str, set[str]] = collections.defaultdict(set)
        for node_name, node in self._nodes_dict.items():
            for tensor_name in node.input:
                if tensor_name == "":
                    continue  # unused optional input
                self._readers[tensor_name].add(node_name)
            for tensor_name in node.output:
                self._writers[tensor_name].add(node_name)

            self._op_type_dict[node.op_type].add(node_name)

        # determine overall "activation data type" of graph
        zero_point_dtypes: dict[str, int] = collections.defaultdict(lambda: 0)

        for op_type in ("QuantizeLinear", "DequantizeLinear"):
            for name in self._op_type_dict.get(op_type, []):
                node = self._nodes_dict[name]
                # Skip if any of the node input is missing
                if len(node.input) < 3:
                    continue
                # Skip if node x/y iput is initializer,not an actiavation input
                if self.is_initializer(node.input[0]):
                    continue
                zero_point = node.input[2]
                dtype = self.get_data_type(zero_point)
                zero_point_dtypes[dtype] += 1

        # default activation type is chosen by majority vote
        self._activation_type: Optional[str] = None
        self._activation_dtype_sorted_list: list[str] = []

        if zero_point_dtypes:
            self._activation_type = max(zero_point_dtypes, key=zero_point_dtypes.get)
            self._activation_dtype_sorted_list = sorted(
                zero_point_dtypes, key=zero_point_dtypes.get, reverse=True
            )

        # nodes with no output consumed or any input not produced
        self._nodes_unconnected: set[str] = set()
        self._lost_ini_helper = LostInitializerHelper()

    def _check_node_unconnected(self, node_name: str) -> bool:
        """
        Check if the node with the passed name has no output consumed or any
        input not produced and return True if so.
        """
        for input_name in self._nodes_dict[node_name].input:
            if input_name == "":
                continue  # unused optional input
            if (
                not self._writers[input_name]
                and input_name not in self._ini_dict
                and input_name not in self._input_names
            ):
                return True
        for output_name in self._nodes_dict[node_name].output:
            if self._readers[output_name] or output_name in self._output_names:
                return False
        return True

    def _get_consumed_tensors(self) -> set[str]:
        """
        Return set of names of tensors read by a node or output as a graph
        output.
        """
        return {
            tensor_name
            for tensor_name, reader_names in self._readers.items()
            if reader_names
        } | self._output_names.keys()

    def _get_produced_tensors(self) -> set[str]:
        """
        Return set of names of tensors written by a node, available from an
        initializer, or available as graph input.
        """
        return (
            {
                tensor_name
                for tensor_name, writer_names in self._writers.items()
                if writer_names
            }
            | self._ini_dict.keys()
            | self._input_names.keys()
        )

    def _update_in_dict_node(self, node_name: str) -> None:
        node = self._nodes_dict[node_name]
        in_node_names: list[str] = []
        for input_ in node.input:
            if input_ == "":
                continue  # unused optional input
            for writer in self._writers[input_]:
                if writer not in in_node_names:
                    in_node_names.append(writer)
        self._in_nodes_dict[node_name] = in_node_names

    def _update_in_dict_nodes(self, node_names: Iterable[str]) -> None:
        for node_name in node_names:
            self._update_in_dict_node(node_name)

    def _update_in_out_dict_node(self, node_name: str) -> None:
        self._update_in_dict_node(node_name)
        self._update_out_dict_node(node_name)

    def _update_in_out_dict_nodes(self, node_names: Iterable[str]) -> None:
        for node_name in node_names:
            self._update_in_out_dict_node(node_name)

    def _update_node_unconnected(self, node_name: str) -> None:
        """
        Check if the node with the passed name has no output consumed or any
        input not produced and update self._nodes_unconnected (i.e. add or
        remove this node).
        """
        unconnected = self._check_node_unconnected(node_name)
        if unconnected:
            self._nodes_unconnected.add(node_name)
        else:
            self._nodes_unconnected.discard(node_name)

    def _update_nodes_unconnected(self, node_names: Iterable[str]) -> None:
        for node_name in node_names:
            self._update_node_unconnected(node_name)

    def _update_out_dict_node(self, node_name: str) -> None:
        node = self._nodes_dict[node_name]
        out_node_names: list[str] = []
        for output in node.output:
            for reader in self._readers[output]:
                if reader not in out_node_names:
                    out_node_names.append(reader)
        self._out_nodes_dict[node_name] = out_node_names

    def _update_out_dict_nodes(self, node_names: Iterable[str]) -> None:
        for node_name in node_names:
            self._update_out_dict_node(node_name)

    def add_initializer(self, initializer: onnx.TensorProto) -> None:
        """
        Adds an new initializer to the model
        """
        if initializer.name in self._ini_dict:
            if initializer != self._ini_dict[initializer.name]:
                raise Exception(
                    f"initializer with name={initializer.name} already exists."
                )
            else:
                return

        self._model.graph.initializer.append(initializer)
        # initializer is copied while appending to the onnx model
        self._ini_dict[initializer.name] = self._model.graph.initializer[-1]
        # update potentially unconnected nodes that are now connected
        for reader in self._readers[initializer.name]:
            self._update_node_unconnected(reader)

    def update_initializer_value(self, initializer_name: str, new_value: any) -> None:
        # Update an initializer value
        initializer = self.get_initializer(initializer_name)
        data_type = helper.tensor_dtype_to_np_dtype(initializer.data_type)
        new_values = np.array(new_value, dtype=data_type)
        new_values_list = new_values.flatten().tolist()
        if initializer.float_data:
            initializer.float_data[:] = new_values_list
        elif initializer.int32_data:
            initializer.int32_data[:] = new_values_list
        elif initializer.int64_data:
            initializer.int64_data[:] = new_values_list
        elif initializer.double_data:
            initializer.double_data[:] = new_values_list
        elif initializer.uint64_data:
            initializer.uint64_data[:] = new_values_list
        elif initializer.raw_data:
            initializer.raw_data = new_values.tobytes()
        else:
            raise ValueError(
                f"Unsupported data format for initializer '{initializer_name}'"
            )

    def add_opschema(
        self,
        name: str,
        inputs: dict[str, str],
        outputs: dict[str, str],
        domain: str,
        version: int,
    ) -> None:
        """
        Creates an OpSchema with the given name.
        Derives the number, names and types of inputs/outputs from the concrete tensors.
        """
        assert (
            name not in self._onnx_ops.get_operator_names()
        ), f"operator {name} is already defined"  # sanity check
        inputs_parameters: list[onnx.defs.OpSchema.FormalParameter] = []
        for input_name, tensor_name in inputs.items():
            type_str = dtype_to_ops_type(
                self.get_data_type(tensor_name), self.get_shape(tensor_name)
            )
            parameter = onnx.defs.OpSchema.FormalParameter(
                name=input_name,
                type_str=type_str,
                # TODO How can we handle variadic and optional parameter
                # param_option=
            )
            inputs_parameters.append(parameter)

        outputs_parameters: list[onnx.defs.OpSchema.FormalParameter] = []
        for output_name, tensor_name in outputs.items():
            type_str = dtype_to_ops_type(
                self.get_data_type(tensor_name), self.get_shape(tensor_name)
            )
            parameter = onnx.defs.OpSchema.FormalParameter(
                name=output_name,
                type_str=type_str,
                # TODO How can we handle variadic and optional parameter
                # param_option=
            )
            outputs_parameters.append(parameter)

        schema = onnx.defs.OpSchema(
            name=name,
            domain=domain,
            since_version=version,
            inputs=inputs_parameters,
            outputs=outputs_parameters,
        )
        self.get_onnx_ops().register_schema(schema)

    def add_transposed_initializer(
        self,
        initializer_name: str,
        initializer_name_new: str,
        permutation: Optional[list[int]] = None,
    ) -> None:
        tensor, dtype = onnxTensorProto_to_array(
            self.get_initializer(initializer_name), transpose=1, permutation=permutation
        )
        initializer_new = onnxTensorProto_from_array(
            tensor, initializer_name_new, og_dtype=dtype
        )
        self.add_initializer(initializer_new)

    def add_split_initializer(
        self, initializer_name: str, split_factor: int
    ) -> list[str] | Literal["scalar", "impossible_split"]:
        initializer_np, dtype = onnxTensorProto_to_array(
            self.get_initializer(initializer_name)
        )
        axis = -1

        # a scalar or a 1-D tensor of size 1
        if initializer_np.ndim == 0 or (
            initializer_np.ndim == 1 and initializer_np.shape[0] == 1
        ):
            return "scalar"

        if initializer_np.shape[axis] % split_factor != 0:
            return "impossible_split"

        split_initializers = np.split(initializer_np, split_factor, axis)

        initializer_names = []
        for i, initializer in enumerate(split_initializers):
            name = initializer_name + "_" + str(i)
            initializer_new = onnxTensorProto_from_array(
                initializer, name, og_dtype=dtype
            )
            self.add_initializer(initializer_new)
            initializer_names.append(name)
        return initializer_names

    def remove_unused_ini_nodes(self):
        ini_unused = set(self._ini_dict.keys()) - self._get_consumed_tensors()
        for ini in ini_unused:
            self._model.graph.initializer.remove(self._ini_dict[ini])
            del self._ini_dict[ini]

    def add_quantization_information(self, tensor_name: str, qdq_info: QdqInfo) -> None:
        """
        Adds quantization information for a tensor(name=tensor_name).
        Fails if quantization information for this tensor are stored.
        """
        assert (
            self.get_quantization_information(tensor_name) is None
        ), f"quantization information for tensor(name={tensor_name}) already exists"  # sanity check

        annotation = onnx.TensorAnnotation()
        annotation.tensor_name = tensor_name
        annotation.quant_parameter_tensor_names.append(
            onnx.StringStringEntryProto(key="SCALE_TENSOR", value=str(qdq_info.scale))
        )
        annotation.quant_parameter_tensor_names.append(
            onnx.StringStringEntryProto(
                key="SCALE_TENSOR_DTYPE", value=str(qdq_info.scale_dtype)
            )
        )
        annotation.quant_parameter_tensor_names.append(
            onnx.StringStringEntryProto(
                key="ZERO_POINT_TENSOR", value=str(qdq_info.zero_point)
            )
        )
        annotation.quant_parameter_tensor_names.append(
            onnx.StringStringEntryProto(
                key="ZERO_POINT_TENSOR_DTYPE", value=str(qdq_info.zero_point_dtype)
            )
        )
        annotation.quant_parameter_tensor_names.append(
            onnx.StringStringEntryProto(key="OUTPUT_TYPE", value=qdq_info.output_type)
        )
        self._model.graph.quantization_annotation.append(annotation)

    def append_attribute(self, node_name: str, attr_name: str, attr_value: Any) -> None:
        """
        Append a new attrubute to the node.
        """
        onnx_node = self.get_onnx_node(node_name)
        onnx_node.attribute.append(onnx.helper.make_attribute(attr_name, attr_value))

    def append_input(self, node_name: str, input_name: str) -> None:
        """
        Appends existing input name to the node
        """
        onnx_node = self.get_onnx_node(node_name)
        onnx_node.input.append(input_name)

        self._readers[input_name].add(node_name)

        self._update_node_unconnected(node_name)
        update_nodes = self._writers[input_name]
        self._update_nodes_unconnected(update_nodes)
        self._update_out_dict_nodes(update_nodes)

    def append_node(
        self,
        name: str,
        op_type: str,
        inputs: dict[str, str],
        outputs: dict[str, str],
        domain: str = "ai.onnx.contrib",
        since_version: int = 1000,
    ) -> None:
        """
        Creates a new onnx node and appends it to the graph.
        """
        if name in self._nodes_dict:
            raise Exception(f"Node with name = {name} already exists.")
        onnx_node = onnx.helper.make_node(
            op_type=op_type, inputs=[], outputs=[], name=name, domain=domain
        )
        self._model.graph.node.append(onnx_node)
        self._nodes_dict[name] = self._model.graph.node[-1]

        self._in_nodes_dict[name] = []
        self._out_nodes_dict[name] = []
        # hint: does not return onnx_node (append creates a copy!)
        # id(self._model.graph.node[-1]) != id(onnx_node)

        if op_type not in self.get_onnx_ops().get_operator_names():
            self.add_opschema(op_type, inputs, outputs, domain, since_version)

        for tensor_name in inputs.values():
            if tensor_name == "":
                continue  # unused optional input
            self.append_input(name, tensor_name)  # Process tensor_name directly

        for _, tensor_name in outputs.items():
            self.append_output(name, tensor_name)

        for input_parameter in inputs:
            input_argument = inputs[input_parameter]
            if self.is_initializer(input_argument):
                self._lost_ini_helper.inc(input_argument)

        self._op_type_dict[op_type].add(name)

    def append_output(self, node_name: str, output_name: str) -> None:
        """
        Appends existing output to the node

        Output can only be produced by one node. If some other node has `output_name` as an output, it is _not_ going to be removed from the node.
        Instead the original_writer name is returned and the caller has to make sure it is going to be removed
        """
        onnx_node = self.get_onnx_node(node_name)
        onnx_node.output.append(output_name)

        self._writers[output_name].add(node_name)

        self._update_node_unconnected(node_name)
        update_nodes = self._readers[output_name]
        self._update_nodes_unconnected(update_nodes)
        self._update_in_dict_nodes(update_nodes)

    def change_op_type(self, node_name: str, new_op_type: str) -> None:
        """
        Change the op_type of a node.
        Updates both the ONNX node and the internal op_type_dict tracking.
        """
        onnx_node = self.get_onnx_node(node_name)
        old_op_type = onnx_node.op_type

        # Update the ONNX node's op_type
        onnx_node.op_type = new_op_type

        # Update the op_type_dict tracking
        self._op_type_dict[old_op_type].discard(node_name)
        if not self._op_type_dict[old_op_type]:  # Remove empty set
            del self._op_type_dict[old_op_type]
        self._op_type_dict[new_op_type].add(node_name)

    def extract_filtered_model(self, node_names: Iterable[str]) -> "ModelDict":
        """
        Extract a filtered copy of the model that contains just the passed
        nodes.
        """
        extracted_model = onnx.ModelProto()
        extracted_model.ir_version = self._model.ir_version
        extracted_model.producer_name = self._model.producer_name
        extracted_model.producer_version = self._model.producer_version
        extracted_model.domain = self._model.domain
        extracted_model.model_version = self._model.model_version
        extracted_model.doc_string = self._model.doc_string
        extracted_model.opset_import.extend(self._model.opset_import)
        extracted_model.graph.name = self._model.graph.name
        # collect output tensors produced by the selected nodes
        produced_tensors: set[str] = set()
        for node_name in node_names:
            node = self._nodes_dict[node_name]
            produced_tensors |= set(node.output)
        # add copies of selected nodes to model and connect their inputs
        consumed_tensors: set[str] = set()
        initializers_added: set[str] = set()
        inputs_added: set[str] = set()
        for node_name in node_names:
            node = self._nodes_dict[node_name]
            extracted_model.graph.node.append(node)  # node is copied
            consumed_tensors |= set(node.input)
            for in_tensor_name in node.input:
                # nothing to do if input is produced by another node,
                # initializer already added or input already added
                if (
                    in_tensor_name in produced_tensors
                    or in_tensor_name in initializers_added
                    or in_tensor_name in inputs_added
                ):
                    continue
                # use initializer if present
                initializer = self._ini_dict.get(in_tensor_name)
                if initializer is not None:
                    extracted_model.graph.initializer.append(
                        initializer
                    )  # initializer is copied
                    initializers_added.add(initializer.name)
                    continue
                # add as input as graph input if not produced or initializer
                in_vi = self._value_info_dict[in_tensor_name]
                extracted_model.graph.input.append(in_vi)
                inputs_added.add(in_tensor_name)
        # make all produced tensors outputs
        for out_tensor_name in produced_tensors:
            out_vi = self._value_info_dict[out_tensor_name]
            extracted_model.graph.output.append(out_vi)
        # copy value infos
        for tensor_name in consumed_tensors | produced_tensors:
            vi = self._value_info_dict.get(tensor_name)
            if vi is not None:
                extracted_model.graph.value_info.append(vi)
        return ModelDict(extracted_model, self._onnx_ops)

    def get_sanitized_attribute_value(self, attribute: onnx.AttributeProto) -> Any:
        attribute_value = onnx.helper.get_attribute_value(attribute)

        # sanitize the attribute value by converting the data type when necessary
        if isinstance(attribute_value, bytes):
            attribute_value = attribute_value.decode("utf-8")
        if isinstance(attribute_value, onnx.TensorProto):
            data, _ = onnxTensorProto_to_array(attribute_value)
            attribute_value = data.tolist()

        return attribute_value

    def get_attributes(self, node_name: str) -> dict[str, Any]:
        attributes = self.get_onnx_node(node_name).attribute
        return {
            attr.name: self.get_sanitized_attribute_value(attr) for attr in attributes
        }

    def get_raw_attributes(self, node_name: str) -> dict[str, Any]:
        """
        Returns the raw attributes of the node
        """
        attributes = self.get_onnx_node(node_name).attribute
        return {attr.name: attr for attr in attributes}

    def get_raw_attribute(self, node_name: str, attribute_name: str) -> Any:
        """
        Returns the raw named attribute of a node
        """
        raw_attributes = self.get_raw_attributes(node_name)
        return raw_attributes.get(attribute_name)

    def get_attribute_value(self, node_name: str, attribute_name: str) -> Any:
        """
        Returns the value of the corresponding attribute, if the attribute with
        the given name does not exist and the node has an onnx schema,
        the default value is returned. Otherwise None is returned.
        """
        onnx_node = self.get_onnx_node_or_none(node_name)
        if onnx_node is None:
            return None

        value = self.get_attributes(node_name).get(attribute_name)
        if value is not None:
            return value

        if onnx.defs.has(onnx_node.op_type):
            attributes = onnx.defs.get_schema(onnx_node.op_type).attributes
            if attribute_name not in attributes:
                return None
            return self.get_sanitized_attribute_value(
                attributes[attribute_name].default_value
            )
        return None

    def get_data_type(self, tensor_name: str) -> str:
        """
        Return name of data type of tensor as string.
        Raises KeyError if tensor not found.
        """
        data_type = self.get_data_type_raw(tensor_name)
        # TODO raise Exception when not found
        return onnxTensor_dtype_to_np_dtype(data_type)

    def get_data_type_raw(self, tensor_name: str) -> Any:
        """
        Return raw data type of tensor.
        Raises KeyError if tensor not found.
        """
        data_type: Optional[int] = None
        if tensor_name in self._ini_dict:
            data_type = self._ini_dict[tensor_name].data_type
        elif tensor_name in self._value_info_dict:
            data_type = self._value_info_dict[tensor_name].type.tensor_type.elem_type
        # TODO check if the same input in both?
        return data_type

    def get_domain(self, node_name: str) -> str:
        """
        Get domain of node.
        Raises KeyError if node not found.
        """
        return self._nodes_dict[node_name].domain

    def get_graph_output_names(self) -> list[str]:
        return sorted(self._output_names.keys())

    def get_graph_input_names(self) -> list[str]:
        return sorted(self._input_names.keys())

    def get_initializer(self, initializer_name: str) -> onnx.TensorProto:
        return self._ini_dict[initializer_name]

    def get_initializer_value(self, initializer_name: str) -> Any:
        data, _ = onnxTensorProto_to_array(self.get_initializer(initializer_name))
        return data.tolist()

    def get_input_names(self, node_name: str) -> list[str]:
        """
        Return list of names of inputs of node.
        Raises KeyError if node not found.
        """
        return list(self._nodes_dict[node_name].input)

    def get_levels(self, node_names: list[str]) -> dict[str, int]:
        """
        Returns a dict mapping the node_names to their level in the graph.
        The level is the maximum path length of all paths from graph inputs to the node.
        """
        # stack of nodes for which the level needs to be calculated
        todo: list[str] = node_names.copy()
        # nodename to level
        level_map: dict[str, int] = {}
        while todo:
            current_name = todo.pop()
            # the current maximum level for the node; -1 indicates that for one of the inputs the level is not calculated yet
            max_input_level = 0
            for input_ in self.get_input_names(current_name):
                if input_ == "":  # optional input not provided
                    continue
                # graph input and initializer are at level 0 (implicitly covered by initialization)
                if self.is_graph_input(input_) or self.is_initializer(input_):
                    continue
                input_node_name = self.get_writer_name(input_)

                if input_node_name in level_map:
                    max_input_level = max(max_input_level, level_map[input_node_name])
                else:
                    todo.append(current_name)
                    todo.append(input_node_name)
                    max_input_level = -1
                    break

            if max_input_level == -1:
                continue

            level_map[current_name] = max_input_level + 1

        # remove entry from the list which weren't requested
        for name in list(level_map.keys()):
            if name not in node_names:
                del level_map[name]
        return level_map

    def get_node_names(self, op_type: Optional[str] = None) -> list[str]:
        """
        Returns a sorted list of names of nodes with the given op_type.
        If the op_type is None, returns all node names
        """
        if op_type is None:
            return sorted(self._nodes_dict.keys())
        return sorted(self._op_type_dict[op_type])

    def get_node_names_starts_with(self, op_type: str) -> dict[str, list[str]]:
        """
        Returns a list starts with op_type
        """
        op_types = [key for key in self._op_type_dict.keys() if key.startswith(op_type)]
        node_names: dict[str, list[str]] = {}
        for op_type in op_types:
            # MHA_2p0, MatMul_act_act as key, node names as value
            node_names[op_type] = sorted(self._op_type_dict[op_type])
        return node_names

    def get_onnx_node(self, node_name: str) -> onnx.NodeProto:
        onnx_node = self.get_onnx_node_or_none(node_name)
        assert (
            onnx_node is not None
        ), f"Onnx node {node_name} should be defined in the node dict"  # sanity check

        return onnx_node

    def get_onnx_node_or_none(self, node_name: str) -> Optional[onnx.NodeProto]:
        return self._nodes_dict.get(node_name)

    def get_onnx_ops(self) -> OnnxOpsWrapper:
        return self._onnx_ops

    def get_op_type(self, node_name: str) -> str:
        """
        Get operator type of node.
        Raises KeyError if node not found.
        """
        return self._nodes_dict[node_name].op_type

    def get_output_names(self, node_name: str) -> list[str]:
        """
        Return list of names of outputs of node.
        Raises KeyError if node not found.
        """
        return list(self._nodes_dict[node_name].output)

    @staticmethod
    def _qdq_info_from_annotation(quantization_annotation) -> QdqInfo:
        """
        Build a QdqInfo from a single graph quantization annotation.

        Reads the SCALE_TENSOR, SCALE_TENSOR_DTYPE, ZERO_POINT_TENSOR,
        ZERO_POINT_TENSOR_DTYPE and OUTPUT_TYPE entries of the annotation's
        `quant_parameter_tensor_names` key/value list. If a key occurs more than
        once, its first value is used. Numeric entries are parsed as int when
        possible, otherwise as float; OUTPUT_TYPE is kept as a string.

        A missing numeric key yields None, which makes the number parsing fail
        with a TypeError; a missing OUTPUT_TYPE gives output_type=None.
        """

        def string_to_number(string: str):
            """Converts a string to an int or, failing that, to a float."""
            try:
                return int(string)
            except ValueError:
                return float(string)

        # first value per key wins (same result as a linear "find first" lookup)
        params: dict[str, str] = {}
        for entry in quantization_annotation.quant_parameter_tensor_names:
            params.setdefault(entry.key, entry.value)

        return QdqInfo(
            scale=string_to_number(params.get("SCALE_TENSOR")),
            scale_dtype=string_to_number(params.get("SCALE_TENSOR_DTYPE")),
            zero_point=string_to_number(params.get("ZERO_POINT_TENSOR")),
            zero_point_dtype=string_to_number(params.get("ZERO_POINT_TENSOR_DTYPE")),
            output_type=params.get("OUTPUT_TYPE"),
        )

    def get_quantization_information(self, tensor_name: str) -> Optional[QdqInfo]:
        """
        Returns the quantization information stored for this tensor.
        Returns None if no information exists for the tensor.
        If the tensor is annotated more than once, the first annotation is used.
        """
        for quantization_annotation in self._model.graph.quantization_annotation:
            if quantization_annotation.tensor_name == tensor_name:
                return self._qdq_info_from_annotation(quantization_annotation)
        return None

    def get_quantization_informations(self) -> dict[str, QdqInfo]:
        """
        Returns a dict for all stored quantization information mapping tensor names to information.

        Built in a single pass over the annotations (linear in their number).
        For tensors annotated more than once, the first annotation wins,
        consistent with get_quantization_information.
        """
        infos: dict[str, QdqInfo] = {}
        for quantization_annotation in self._model.graph.quantization_annotation:
            tensor_name = quantization_annotation.tensor_name
            if tensor_name not in infos:
                infos[tensor_name] = self._qdq_info_from_annotation(
                    quantization_annotation
                )
        return infos

    def get_reader_names(self, tensor_name: str) -> list[str]:
        """
        Give the names of all nodes reading `tensor_name`, in sorted order.
        A KeyError propagates when the tensor is unknown to the reader index.
        """
        reader_names = list(self._readers[tensor_name])
        reader_names.sort()
        return reader_names

    def get_shape(self, tensor_name: str) -> Optional[list[Any]]:
        """
        Returns the shape (dimensions) for a tensor.

        - Initializers: their `dims` as a list.
        - Tensors with a value_info entry: one element per dimension, holding the
          symbolic `dim_param` string when set, otherwise the `dim_value`.
        - Anything else: None (no exception is raised).
        """
        if tensor_name in self._ini_dict:
            return list(self._ini_dict[tensor_name].dims)
        if tensor_name in self._value_info_dict:
            dimensions = self._value_info_dict[tensor_name].type.tensor_type.shape.dim
            # a symbolic dimension name takes precedence over the numeric value
            return [
                dim.dim_param if dim.dim_param else dim.dim_value
                for dim in dimensions
            ]
        return None

    def get_writer_name(self, tensor_name: str) -> str:
        """
        Return the single node that writes `tensor_name`.

        A KeyError propagates when the tensor is unknown to the writer index;
        a ValueError is raised when there are zero or several writers.
        """
        writer_names = self._writers[tensor_name]
        if len(writer_names) == 1:
            (only_writer,) = writer_names
            return only_writer
        raise ValueError(
            f"tensor {tensor_name} does not have a unique writer: {sorted(writer_names)}"
        )

    def get_writer_names(self, tensor_name: str) -> list[str]:
        """
        Give the names of all nodes writing `tensor_name`, in sorted order.
        A KeyError propagates when the tensor is unknown to the writer index.
        """
        writer_set = self._writers[tensor_name]
        ordered_writers = sorted(writer_set)
        return ordered_writers

    def has_node(self, node_name: str) -> bool:
        """Tell whether a node called `node_name` is registered in the node index."""
        registered_nodes = self._nodes_dict
        return node_name in registered_nodes

    def has_writer(self, tensor_name: str) -> bool:
        """
        Tell whether `tensor_name` is written by exactly one node.
        Unknown tensors and tensors with zero or several writers yield False.
        """
        if tensor_name not in self._writers:
            return False
        return len(self._writers[tensor_name]) == 1

    def is_graph_input(self, tensor_name: str) -> bool:
        """Tell whether `tensor_name` is one of the graph's input names."""
        graph_inputs = self._input_names
        return tensor_name in graph_inputs

    def is_initializer(self, tensor_name: str) -> bool:
        """
        Tell whether `tensor_name` names an initializer of the model.
        """
        initializers = self._ini_dict
        return tensor_name in initializers

    def initializer_multiplication(self, initializer_name: str, factor: float) -> None:
        """
        Scale the initializer `initializer_name` by `factor`, in place.

        The initializer is converted to an array, multiplied by `factor`, and the
        product is stored back under the same name via update_initializer_value.
        """
        scaled_values = factor * onnxTensorProto_to_array(
            self.get_initializer(initializer_name)
        )
        self.update_initializer_value(initializer_name, scaled_values)

    @staticmethod
    def load(model_path: str) -> "ModelDict":
        """
        Read the ONNX model stored at `model_path` and wrap it in a ModelDict
        that uses a fresh OnnxOpsWrapper.
        """
        onnx_model = onnx.load_model(model_path)
        return ModelDict(onnx_model, onnx_ops=OnnxOpsWrapper())

    def remove_attribute(self, node_name: str, attr_name: str) -> None:
        """
        Delete the first attribute called `attr_name` from node `node_name`.
        Raises ValueError when the node carries no such attribute.
        """
        attributes = self.get_onnx_node(node_name).attribute
        match = next((attr for attr in attributes if attr.name == attr_name), None)
        if match is None:
            raise ValueError(f"node {node_name} has no attribute with name {attr_name}")
        attributes.remove(match)

    def remove_input(self, node_name: str, input_name: str) -> None:
        """
        Removes input from node.

        Only the first occurrence of `input_name` in the node's input list is
        removed (later inputs shift one slot left). The node stays registered as
        a reader of the tensor when it still consumes it through another input
        slot, e.g. ``Mul(x, x)``.
        Raises AssertionError if the node does not read `input_name`.
        """
        onnx_node = self.get_onnx_node(node_name)
        assert input_name in onnx_node.input  # sanity check
        onnx_node.input.remove(input_name)

        # a node may read the same tensor in several slots; only unregister it
        # once no slot refers to the tensor anymore
        if input_name not in onnx_node.input:
            self._readers[input_name].discard(node_name)

        self._update_node_unconnected(node_name)
        update_nodes = self._writers[input_name]
        self._update_nodes_unconnected(update_nodes)
        self._update_out_dict_nodes(update_nodes)

    def remove_node(self, node_name: str) -> None:
        """
        Remove node `node_name` from the ONNX graph and from every bookkeeping structure.

        In order: decrements the lost-initializer counter once per initializer
        input occurrence, drops the node from the op-type index, unregisters it as
        reader of its inputs and writer of its outputs, removes it from the graph
        and from the node / in-nodes / out-nodes dicts, and finally refreshes the
        unconnected state and in/out adjacency of its neighbours (writers of its
        inputs and readers of its outputs).

        An unknown `node_name` fails in get_onnx_node (exception type not visible
        here -- presumably KeyError, TODO confirm).
        """
        onnx_node = self.get_onnx_node(node_name)

        # the node no longer reads its initializer inputs: keep the lost-initializer counter in sync
        for input in onnx_node.input:
            if self.is_initializer(input):
                self._lost_ini_helper.dec(input)

        self._op_type_dict[onnx_node.op_type].remove(node_name)

        # collect neighbours whose connectivity depends on this node
        update_nodes: set[str] = set()
        for input_name in onnx_node.input:
            if input_name == "":
                continue  # unused optional input
            self._readers[input_name].discard(node_name)
            update_nodes |= self._writers[input_name]
        for output_name in onnx_node.output:
            self._writers[output_name].discard(node_name)
            update_nodes |= self._readers[output_name]

        self._model.graph.node.remove(onnx_node)

        del self._nodes_dict[node_name]
        del self._in_nodes_dict[node_name]
        del self._out_nodes_dict[node_name]

        # neighbours are refreshed only after the node is gone from all indices
        self._update_nodes_unconnected(update_nodes)
        self._update_in_out_dict_nodes(update_nodes)

        # the removed node must not remain tracked as unconnected
        self._nodes_unconnected.discard(node_name)

    def remove_output(self, node_name: str, output_name: str) -> None:
        """
        Detach output `output_name` from node `node_name`.

        The node is unregistered as writer of the tensor; afterwards the node's own
        unconnected state, and the unconnected state and in-node adjacency of every
        reader of the tensor, are refreshed.
        Raises AssertionError if the node does not produce `output_name`.
        """
        outputs = self.get_onnx_node(node_name).output
        assert output_name in outputs  # sanity check
        outputs.remove(output_name)

        writers_of_output = self._writers[output_name]
        writers_of_output.discard(node_name)

        self._update_node_unconnected(node_name)
        affected_readers = self._readers[output_name]
        self._update_nodes_unconnected(affected_readers)
        self._update_in_dict_nodes(affected_readers)

    def remove_shape(self, tensor_name: str) -> None:
        """
        Delete the value_info entry (shape information) of `tensor_name`, both from
        the graph and from the value-info index.
        Raises KeyError when the graph holds no value_info for the tensor.
        """
        value_infos = self._model.graph.value_info
        found = next((info for info in value_infos if info.name == tensor_name), None)
        if found is None:
            raise KeyError(f"tensor {tensor_name} has no shape information")
        value_infos.remove(found)
        del self._value_info_dict[tensor_name]

    def remove_unconnected(self) -> None:
        """
        Repeatedly delete every node that has no consumed output or an input that
        is not produced, until no such node is left. Removing a node can leave
        its neighbours unconnected, hence the outer loop.
        """
        while len(self._nodes_unconnected) != 0:
            # iterate over a snapshot: remove_node() mutates self._nodes_unconnected
            snapshot = self._nodes_unconnected.copy()
            for doomed_node in snapshot:
                self.remove_node(doomed_node)

    def replace_input(
        self, node_name: str, input_name_old: str, input_name_new: str
    ) -> None:
        """
        Rewire node `node_name` so that every input slot holding `input_name_old`
        reads `input_name_new` instead, then update reader bookkeeping and refresh
        the writers of both tensors.
        A KeyError is raised when the node is not recorded as a reader of
        `input_name_old`.
        """
        onnx_node = self.get_onnx_node(node_name)
        for slot, current_name in enumerate(self.get_input_names(node_name)):
            if current_name == input_name_old:
                onnx_node.input[slot] = input_name_new

        old_readers = self._readers[input_name_old]
        old_readers.remove(node_name)
        self._readers[input_name_new].add(node_name)

        self._update_node_unconnected(node_name)
        affected_writers = self._writers[input_name_old] | self._writers[input_name_new]
        self._update_nodes_unconnected(affected_writers)
        self._update_out_dict_nodes(affected_writers)

    def replace_output(
        self, node_name: str, output_name_old: str, output_name_new: str
    ) -> None:
        """
        Replace output named `output_name_old` with `output_name_new`.

        The node is moved from the writer set of the old tensor to that of the new
        one. Readers of either tensor see a different producer afterwards, so their
        unconnected state and their *in*-node adjacency are refreshed (mirroring
        remove_output; replace_input refreshes the writers' out-node adjacency).
        Raises ValueError if `output_name_old` is not an output of the node.
        """
        onnx_node = self.get_onnx_node(node_name)
        idx = self.get_output_names(node_name).index(output_name_old)
        onnx_node.output[idx] = output_name_new
        self._writers[output_name_old].remove(node_name)
        self._writers[output_name_new].add(node_name)

        self._update_node_unconnected(node_name)
        update_nodes = self._readers[output_name_old] | self._readers[output_name_new]
        self._update_nodes_unconnected(update_nodes)
        # the producer of these readers' inputs changed -> refresh their in-nodes
        self._update_in_dict_nodes(update_nodes)

    def set_attribute(self, node_name: str, attr_name: str, attr_value: Any):
        """
        Give node `node_name` the attribute `attr_name` with value `attr_value`,
        replacing a previously stored value of that attribute if there is one.
        """
        # remove_attribute signals "attribute absent" with ValueError; that case
        # simply means there is nothing to overwrite
        try:
            self.remove_attribute(node_name, attr_name)
        except ValueError:
            pass  # nothing to overwrite
        self.append_attribute(node_name, attr_name, attr_value)

    def set_shape(self, tensor_name: str, shape: list[int], element_type: int) -> None:
        """
        Set the shape of the tensor (overwrites value if it exists).

        An existing value_info entry for the tensor is removed first. A new
        value_info holding `shape` is then appended to the graph and registered in
        the value-info index. `element_type` is passed through
        onnxTensor_np_dtype_to_dtype before use (presumably a numpy dtype despite
        the int annotation -- TODO confirm).
        """
        # Check the value-info index directly instead of get_shape(): get_shape()
        # also reports initializer dims, and remove_shape() raises KeyError for an
        # initializer that has no value_info entry.
        if tensor_name in self._value_info_dict:
            self.remove_shape(tensor_name)

        value_info = onnx.helper.make_tensor_value_info(
            tensor_name, onnxTensor_np_dtype_to_dtype(element_type), shape
        )
        self._model.graph.value_info.append(value_info)
        # store the graph-owned element (protobuf append stores a copy)
        self._value_info_dict[tensor_name] = self._model.graph.value_info[-1]

    def sanity_check(self, logger: Logger) -> None:
        """
        Perform sanity checks of readers and writers.
        """
        consumed_tensors = self._get_consumed_tensors()
        produced_tensors = self._get_produced_tensors()
        uninitialized = consumed_tensors - produced_tensors
        assert (
            len(uninitialized) == 0
        ), f"the following tensors are consumed but not produced : {sorted(uninitialized)}"  # sanity check
        for tensor_name, writer_names in self._writers.items():
            assert len(writer_names) <= 1, (
                f"multiple writers for {tensor_name}:" f" {sorted(writer_names)}"
            )  # sanity check
            assert (
                len(writer_names) == 0 or tensor_name not in self._input_names.keys()
            ), f"node {list(writer_names)[0]} writes tensor {tensor_name}, but it is also a graph input"  # sanity check
            assert (
                len(writer_names) == 0 or tensor_name not in self._ini_dict.keys()
            ), f"node {list(writer_names)[0]} writes tensor {tensor_name}, but it is also an initializer"  # sanity check
        in_and_ini = self._ini_dict.keys() & self._in_nodes_dict.keys()
        assert (
            len(in_and_ini) == 0
        ), f"tensors {sorted(in_and_ini)} are graph inputs and initializers"  # sanity check

        for tensor in consumed_tensors | produced_tensors:
            if self.get_shape(tensor) is None:
                logger.debug(f"warning: tensor {tensor} has no shape information")
