import dataclasses
import logging
import os
from typing import Any, TypeAlias, Optional, Iterable
from collections import defaultdict
from dataclasses import dataclass

import numpy as np
import onnx
from google.protobuf.json_format import MessageToDict
from onnx import (
    ModelProto,
    numpy_helper,
    TensorProto,
    AttributeProto,
    ValueInfoProto,
    NodeProto,
)

from onnx.helper import (
    make_attribute,
    make_tensor_value_info,
)

from onnx.external_data_helper import load_external_data_for_tensor

from OGOAT.src.L1_fusion.py_match.helpers.common_type import (
    TensorShape,
    OnnxDType,
    NumpyDType,
    NamedArray,
)

from ml_dtypes import bfloat16

# NOTE(review): the result of this call is discarded; presumably it only
# verifies that importing ml_dtypes registered the "bfloat16" dtype name
# with numpy -- confirm.
np.dtype("bfloat16")


# Both aliases share the same underlying type (str -> list of node names);
# they differ only in what the key denotes.
TensorToNodes: TypeAlias = dict[str, list[str]]
"""TensorToNodes: a dict that maps a tensor name to a list of node-names"""
NodeToNodes: TypeAlias = dict[str, list[str]]
"""NodeToNodes: a dict that maps a node-name to a list of node-names"""


class ShapeMismatchError(ValueError):
    """Raised when the shapes of a node's operands and result do not agree."""


class NodeNotFound(LookupError):
    """
    Lookup error signalling that a requested node does not exist.

    When no (truthy) `msg` is supplied, a default message naming
    `node_name` is generated.
    """

    def __init__(
        self,
        msg=None,
        node_name: str = None,
    ) -> None:
        if not msg:
            msg = f"Node({node_name}) not found."
        super().__init__(msg)


@dataclasses.dataclass
class TensorInfo:
    """Name, shape and ONNX element type describing one tensor."""

    name: str
    shape: TensorShape
    dtype: OnnxDType

    def make_activation_tensor(self) -> ValueInfoProto:
        """Build a `ValueInfoProto` from this record's name, dtype and shape."""
        return make_tensor_value_info(self.name, self.dtype, self.shape)

    @staticmethod
    def from_initializer(tensor_: TensorProto):
        """Create a `TensorInfo` from a `TensorProto` (name, data_type, dims)."""
        dims = list(tensor_.dims)
        return TensorInfo(name=tensor_.name, dtype=tensor_.data_type, shape=dims)

    @staticmethod
    def from_activation_tensor(tensor_: ValueInfoProto):
        """
        Create a `TensorInfo` from a `ValueInfoProto`.

        A dimension that carries a `dim_param` is converted with `int()`,
        so a non-numeric parameter name raises ValueError; otherwise the
        dimension's `dim_value` is used.
        """
        tensor_type = tensor_.type.tensor_type
        dims = []
        for d in tensor_type.shape.dim:
            if d.dim_param:
                dims.append(int(d.dim_param))
            else:
                dims.append(d.dim_value)
        return TensorInfo(
            name=tensor_.name,
            dtype=tensor_type.elem_type,
            shape=dims,
        )


@dataclass(init=False)
class model_dict:
    """
    Name-indexed lookup tables over an ONNX graph, plus query helpers.

    All tables are (re)built from a model by `update_dict`; use `create` to
    obtain a populated instance. The tables are built once per
    `update_dict` call and are not refreshed automatically when the model
    is modified afterwards.
    """

    ini: dict[str, TensorProto]  # initializer name -> initializer
    nodes: dict[str, NodeProto]  # node name -> node
    tensor_readers: TensorToNodes  # tensor name -> names of nodes reading it
    tensor_writers: TensorToNodes  # tensor name -> names of nodes writing it
    input_nodes: NodeToNodes  # node name -> names of nodes writing its inputs
    output_nodes: NodeToNodes  # node name -> names of nodes reading its outputs
    inputs: dict[str, ValueInfoProto]  # graph inputs by name
    outputs: dict[str, ValueInfoProto]  # graph outputs by name
    vinfo: dict[str, ValueInfoProto]  # value_info + graph inputs + graph outputs

    @staticmethod
    def create(model: onnx.ModelProto) -> "model_dict":
        """Return a new `model_dict` populated from `model`."""
        md = model_dict()
        md.update_dict(model)
        return md

    def update_dict(self, model: onnx.ModelProto) -> None:
        """Rebuild every lookup table from `model`."""
        self.ini = construct_initializer_dict(model)
        self.nodes = construct_nodes_dict(model)
        self.tensor_readers, self.tensor_writers = construct_tensor_in_out_dict(model)
        self.input_nodes, self.output_nodes = construct_node_in_out_dict(
            model, self.tensor_readers, self.tensor_writers
        )
        self.inputs = construct_graph_input_dict(model)
        self.outputs = construct_graph_output_dict(model)
        # vinfo covers intermediate tensors as well as graph inputs/outputs
        self.vinfo = construct_valinfo_dict(model)
        self.vinfo |= self.inputs
        self.vinfo |= self.outputs
        self._setup_initializer_nodes()

    def _setup_initializer_nodes(self) -> dict[str, NodeProto]:
        """
        Collect the nodes whose inputs are all initializers.

        Runs to a fixed point: a node found in one pass can turn the nodes
        reading its outputs into initializer nodes in the next pass.
        """
        self._initializer_nodes: dict[str, NodeProto] = dict()
        previous_count = None
        while previous_count != len(self._initializer_nodes):
            previous_count = len(self._initializer_nodes)
            for n_name, node in self.nodes.items():
                if self.is_initializer_node(node):
                    self._initializer_nodes[n_name] = node

        return self._initializer_nodes

    def is_node(self, name_: str) -> bool:
        """True if `name_` is the name of a node."""
        return name_ in self.nodes

    def is_tensor(self, name_: str) -> bool:
        """True if `name_` is an initializer or has value info."""
        return name_ in self.ini or name_ in self.vinfo

    def is_const_initializer(self, tensor_name: str) -> bool:
        """True IF t in md.inis"""
        return tensor_name in self.ini

    def is_initializer(self, tensor_name: str) -> bool:
        """True IF t in md.inis OR t.writer is ini_node"""
        if tensor_name in self.ini:
            return True
        if tensor_name not in self.tensor_writers:
            return False  # tensor is graph input -- activation
        return any(
            w in self._initializer_nodes for w in self.tensor_writers[tensor_name]
        )

    def is_activation(self, tensor_name: str) -> bool:
        """True if the tensor has value info and is not an initializer."""
        if tensor_name not in self.vinfo:
            return False
        return not self.is_initializer(tensor_name)

    def is_initializer_node(self, node: str | NodeProto) -> bool:
        """
        True if the node has at least one input and all inputs are initializers.

        Unknown nodes yield False.
        """
        node = self.get_node(node)
        if not node:
            return False
        return bool(node.input) and all(map(self.is_initializer, node.input))

    def is_model_input(self, name: str) -> bool:
        """True if `name` is a graph input."""
        return name in self.inputs

    def is_model_output(self, name: str) -> bool:
        """True if `name` is a graph output."""
        return name in self.outputs

    def is_root(self, node: str | Optional[NodeProto]) -> bool:
        """True if at least one input of the node is a graph input."""
        node = self.get_node(node)
        return bool(node and node.input and any(i in self.inputs for i in node.input))

    def is_leaf(self, node: str | Optional[NodeProto]) -> bool:
        """True if at least one output of the node is a graph output."""
        node = self.get_node(node)
        return bool(
            node and node.output and any(o in self.outputs for o in node.output)
        )

    def get_node(
        self, node: str | Optional[NodeProto], default: Any = None
    ) -> NodeProto | Any:
        """
        Resolve a node name or `NodeProto` to the node stored in `self.nodes`.

        Returns `default` when the argument is falsy or no node of that name
        is stored. Any other truthy object is returned unchanged.
        """
        if isinstance(node, str):
            found = self.nodes.get(node, None) if node else None
        elif node and isinstance(node, NodeProto):
            found = self.nodes.get(node.name, None)
        else:
            found = node
        if not found:
            return default
        return found

    def get_nodes(
        self, names: list[str] = None, *, op_types: Iterable[str] = None
    ) -> list[NodeProto]:
        """
        Return nodes selected by name or by op type.

        With neither argument, all nodes are returned. A non-empty `names`
        takes precedence over `op_types` and raises KeyError for unknown
        names.
        """
        if names is None and op_types is None:
            return list(self.nodes.values())
        if names:
            return [self.nodes[n] for n in names]
        if op_types:
            return [n for n in self.nodes.values() if n.op_type in op_types]
        return []

    def get_nodes_dict(self, *op_types) -> dict[str, NodeProto]:
        """
        Return nodes by name, optionally restricted to `op_types`.

        Without `op_types` the internal `self.nodes` dict itself is returned
        (not a copy).
        """
        if not op_types:
            return self.nodes
        return {name: n for name, n in self.nodes.items() if n.op_type in op_types}

    def get_roots(self) -> dict[str, NodeProto]:
        """Return all nodes for which `is_root` holds, keyed by name."""
        return {
            name_: node_ for name_, node_ in self.nodes.items() if self.is_root(node_)
        }

    def get_leaves(self) -> dict[str, NodeProto]:
        """Return all nodes for which `is_leaf` holds, keyed by name."""
        return {
            name_: node_ for name_, node_ in self.nodes.items() if self.is_leaf(node_)
        }

    def get_tensor_suppliers(
        self, tensor_name: str, activations_only=False
    ) -> list[str]:
        """
        Tensor suppliers of a tensor X -- are input tensors of X's  writer-nodes
        """
        if tensor_name not in self.tensor_writers:
            return []
        suppliers_ = [
            i for w in self.tensor_writers[tensor_name] for i in self.nodes[w].input
        ]
        if activations_only:
            suppliers_ = [t for t in suppliers_ if self.is_activation(t)]
        return suppliers_

    def get_tensor_consumers(self, tensor_name: str) -> list[str]:
        """
        Tensor consumers of a tensor X -- are output tensors of X's  reader-nodes
        """
        if tensor_name not in self.tensor_readers:
            return []
        return [
            o for r in self.tensor_readers[tensor_name] for o in self.nodes[r].output
        ]

    def _select_neighbours(
        self,
        node: str | Optional[NodeProto],
        table: NodeToNodes,
        op_types: tuple,
        exclude,
        first,
        quiet,
        kind: str,
    ) -> dict[str, NodeProto] | Optional[NodeProto]:
        """Shared implementation of `get_node_suppliers` / `get_node_consumers`."""
        node = self.get_node(node)
        if not node or node.name not in table:
            return None if first else dict()

        # With op_types given, `exclude` inverts the op-type filter.
        selected = {
            name: self.get_node(name)
            for name in table[node.name]
            if not op_types or (self.get_op_type(name) in op_types) ^ exclude
        }

        if first:
            return next(iter(selected.values()), None)
        if not selected and not quiet:
            raise NodeNotFound(
                f"Node({node.name}) has no {kind} with op_type={op_types if op_types else 'ANY'}"
            )
        return selected

    def get_node_suppliers(
        self,
        node: str | Optional[NodeProto],
        *op_types: str,
        exclude=False,
        first=False,
        quiet=True,
    ) -> dict[str, NodeProto] | Optional[NodeProto]:
        """
        Node suppliers of a node N -- are writer-nodes of N's input tensors, a.k.a. input nodes

        `op_types` restricts the result to those op types (or to all others
        when `exclude` is set). With `first`, a single node or None is
        returned instead of a dict. With `quiet=False`, an empty result
        raises NodeNotFound.
        """
        return self._select_neighbours(
            node, self.input_nodes, op_types, exclude, first, quiet, "node-suppliers"
        )

    def get_node_consumers(
        self,
        node: str | Optional[NodeProto],
        *op_types: str,
        exclude=False,
        first=False,
        quiet=True,
    ) -> dict[str, NodeProto] | Optional[NodeProto]:
        """
        Node consumers of a node N -- are reader-nodes of N's output tensors, a.k.a. output nodes

        `op_types` restricts the result to those op types (or to all others
        when `exclude` is set). With `first`, a single node or None is
        returned instead of a dict. With `quiet=False`, an empty result
        raises NodeNotFound.
        """
        return self._select_neighbours(
            node, self.output_nodes, op_types, exclude, first, quiet, "node-consumers"
        )

    def get_tensor(self, tensor_name: str) -> ValueInfoProto | TensorProto | None:
        """Return the initializer or value info of that name (initializer first)."""
        if tensor_name in self.ini:
            return self.ini[tensor_name]
        if tensor_name in self.vinfo:
            return self.vinfo[tensor_name]
        return None

    def get_activations(self) -> list[ValueInfoProto]:
        """Return the value infos of all activation tensors."""
        return [
            tensor
            for t_name, tensor in self.vinfo.items()
            if self.is_activation(t_name)
        ]

    def get_dynamic_initializers(self) -> list[ValueInfoProto]:
        """Return the value infos of all tensors for which `is_initializer` holds."""
        return [
            tensor
            for t_name, tensor in self.vinfo.items()
            if self.is_initializer(t_name)
        ]

    def get_node_activations(
        self,
        node: str | NodeProto,
        first=False,
    ) -> dict[str, ValueInfoProto] | Optional[ValueInfoProto]:
        """
        Return the activation inputs of a node, keyed by tensor name.

        With `first`, only the first activation (or None) is returned.
        """
        node = self.get_node(node)
        if not node:
            return None if first else dict()

        activations_ = {
            t_name: self.vinfo[t_name]
            for t_name in node.input
            if self.is_activation(t_name)
        }

        if first:
            return next(iter(activations_.values()), None)
        return activations_

    def get_node_activations_index(self, node: str | NodeProto) -> dict[str, int]:
        """Map each activation input of a node to its position in `node.input`."""
        node = self.get_node(node)
        if not node:
            return dict()

        return {
            t_name: idx
            for idx, t_name in enumerate(node.input)
            if self.is_activation(t_name)
        }

    def get_node_outputs(
        self,
        node: str | NodeProto,
        first=False,
    ) -> dict[str, ValueInfoProto] | Optional[ValueInfoProto]:
        """
        Return the outputs of a node that have value info, keyed by name.

        With `first`, only the first such output (or None) is returned.
        """
        node = self.get_node(node)
        if not node:
            return None if first else dict()

        outs_ = {
            t_name: self.vinfo[t_name] for t_name in node.output if t_name in self.vinfo
        }

        if first:
            return next(iter(outs_.values()), None)
        return outs_

    def get_node_outputs_index(self, node: str | NodeProto) -> dict[str, int]:
        """Map each output of a node to its position in `node.output`."""
        node = self.get_node(node)
        if not node:
            return dict()

        return {t_name: idx for idx, t_name in enumerate(node.output)}

    def get_writer(self, tensor: str) -> NodeProto | None:
        """
        Return the single node writing `tensor`, or None if nothing writes it.

        Asserts that the tensor has exactly one writer.
        """
        if tensor not in self.tensor_writers:
            return None
        writers_ = self.tensor_writers[tensor]
        assert len(writers_) == 1, f"Tensor({tensor}) has multiple writers: {writers_}"

        return self.get_node(writers_[0])

    def get_readers(
        self, tensor: str, *op_types: str, exclude=False
    ) -> list[NodeProto]:
        """
        Return the nodes reading `tensor`.

        `op_types` restricts the result to those op types (or to all others
        when `exclude` is set).
        """
        if tensor not in self.tensor_readers:
            return []

        return [
            self.get_node(r)
            for r in self.tensor_readers[tensor]
            if not op_types or (self.get_node(r).op_type in op_types) ^ exclude
        ]

    def is_output_of(self, tensor: str, *op_types: str) -> bool:
        """True if the tensor's writer exists and has one of `op_types`."""
        writer_ = self.get_writer(tensor)
        return bool(writer_ and writer_.op_type in op_types)

    def is_input_of(self, tensor: str, *op_types: str) -> bool:
        """True if the tensor has a reader (of one of `op_types`, if given)."""
        return bool(self.get_readers(tensor, *op_types))

    def get_shape(
        self, tensor_: str | ValueInfoProto | TensorProto | TensorInfo
    ) -> TensorShape:
        """
        Returns the shape (dimensions) for a tensor.

        Symbolic dimensions are returned as strings unless they consist of
        digits only, in which case they are converted to int. If a name is
        known both as initializer and as value info, the value info wins.
        An empty list is returned if no shape information is found.
        """
        if isinstance(tensor_, TensorInfo):
            return tensor_.shape

        shape_ = []
        if isinstance(tensor_, str) and tensor_ in self.ini:
            shape_ = list(self.ini[tensor_].dims)
        if isinstance(tensor_, TensorProto):
            shape_ = list(tensor_.dims)

        if isinstance(tensor_, str) and tensor_ in self.vinfo:
            tensor_ = self.vinfo[tensor_]
        if isinstance(tensor_, ValueInfoProto):
            dimensions = tensor_.type.tensor_type.shape.dim
            shape_ = [d.dim_param if d.dim_param else d.dim_value for d in dimensions]
            shape_ = [int(d) if str(d).isdigit() else d for d in shape_]
        return shape_

    def get_dim(self, tensor_: str | ValueInfoProto | TensorProto | TensorInfo) -> int:
        """Return the number of dimensions in the tensor's shape."""
        return len(self.get_shape(tensor_))

    def get_rank(self, tensor_: str | ValueInfoProto | TensorProto | TensorInfo) -> int:
        """
        Return the number of dimensions whose size is greater than 1.

        Raises TypeError if the shape contains a symbolic (string) dimension.
        """
        return len([d for d in self.get_shape(tensor_) if d > 1])

    def get_onnx_dtype(
        self, tensor_: str | ValueInfoProto | TensorProto | TensorInfo
    ) -> OnnxDType:
        """
        Return the ONNX element type of a tensor, or None if it is unknown.

        If a name is known both as initializer and as value info, the
        initializer wins.
        """
        if isinstance(tensor_, TensorInfo):
            return tensor_.dtype

        dtype_ = None
        if isinstance(tensor_, str) and tensor_ in self.ini:
            tensor_ = self.ini[tensor_]
        if isinstance(tensor_, TensorProto):
            dtype_ = tensor_.data_type

        if isinstance(tensor_, str) and tensor_ in self.vinfo:
            tensor_ = self.vinfo[tensor_]
        if isinstance(tensor_, ValueInfoProto):
            dtype_ = tensor_.type.tensor_type.elem_type

        return dtype_

    def get_shape_dtype(
        self, tensor: str | ValueInfoProto | TensorProto | TensorInfo
    ) -> tuple[TensorShape, OnnxDType]:
        """Return `(get_shape(tensor), get_onnx_dtype(tensor))`."""
        return self.get_shape(tensor), self.get_onnx_dtype(tensor)

    def get_op_type(self, node: str | NodeProto) -> Optional[str]:
        """Return the op type of a node, or None if the node is unknown."""
        node = self.get_node(node)
        if not node:
            return None
        return node.op_type


def construct_constant_dict(model: onnx.ModelProto) -> dict[str, NamedArray]:
    """
    Collect the tensor values of the graph's `Constant` nodes.

    Only the first attribute of each Constant node is inspected, and only
    when it is of TENSOR type. The result is keyed by the node's first
    output tensor name; each value holds the node name, the array and the
    array's dtype.
    """
    constants = {}
    for const_node in model.graph.node:
        if not (const_node.op_type == "Constant" and const_node.attribute):
            continue
        first_attr = const_node.attribute[0]
        if not first_attr or first_attr.type != AttributeProto.TENSOR:
            continue
        array = numpy_helper.to_array(first_attr.t)
        constants[const_node.output[0]] = NamedArray(
            const_node.name, array, array.dtype
        )
    return constants


def check_binary_shapes(md):
    """
    Run `check_binary_node_shapes` on every Add/Mul/Div/Sub node of `md`.

    A shape mismatch is logged as a warning rather than raised.
    """
    binary_ops = ["Add", "Mul", "Div", "Sub"]
    for node in md.nodes.values():
        if node.op_type in binary_ops:
            try:
                check_binary_node_shapes(md, node)
            except ShapeMismatchError as mismatch:
                logging.warning(mismatch)


def check_binary_node_shapes(md, node):
    """
    Check the innermost dimension of a binary node's operands and result.

    The check passes when the last dimensions of both inputs and the output
    are equal, or when one input's last dimension is 1 and the other
    input's last dimension equals the output's. Empty shapes (scalars) are
    treated as `[1]`.

    Raises:
        ShapeMismatchError: if none of these conditions holds.
    """
    shapeA = md.get_shape(node.input[0])
    shapeB = md.get_shape(node.input[1])
    shapeC = md.get_shape(node.output[0])

    # view scalars as 1-dim vectors
    shapeA = shapeA or [1]
    shapeB = shapeB or [1]
    shapeC = shapeC or [1]

    last_a, last_b, last_c = shapeA[-1], shapeB[-1], shapeC[-1]
    compatible = (
        (last_a == 1 and last_b == last_c)
        or (last_b == 1 and last_a == last_c)
        or (last_a == last_b == last_c)
    )
    if not compatible:
        raise ShapeMismatchError(
            f"Shape mismatch: {node.op_type.upper()}({node.name}) :: "
            f"A{shapeA} + B{shapeB} = C{shapeC}"
        )


def get_fixed_shapes_from_params(
    model_path, shape_params: dict
) -> tuple[dict[str, list[int]], dict[str, list[int]]]:
    """
    Resolve the shapes of the graph inputs and outputs of a model file.

    The model is loaded without its external data. A dimension with a
    non-zero `dim_value` is used as is; any other dimension is looked up in
    `shape_params` by its `dim_param` (KeyError if absent).

    Returns:
        `(input_shapes, output_shapes)`, each mapping tensor name -> shape.
    """

    def resolve(value_infos):
        # name -> list of concrete dimension sizes
        resolved = {}
        for value_info in value_infos:
            dims = []
            for d in value_info.type.tensor_type.shape.dim:
                dims.append(d.dim_value if d.dim_value else shape_params[d.dim_param])
            resolved[value_info.name] = dims
        return resolved

    model = onnx.load_model(model_path, load_external_data=False)
    graph_inputs = resolve(model.graph.input)
    graph_outputs = resolve(model.graph.output)
    return graph_inputs, graph_outputs


def get_shape_params_from_model(
    model: onnx.ModelProto,
) -> list[str]:
    """
    Returns the shape parameters (`dim_param`) of the graph inputs.

    The parameters are listed in order of appearance; a parameter used by
    several dimensions appears once per use.
    """
    return [
        dim.dim_param
        for graph_input in model.graph.input
        for dim in graph_input.type.tensor_type.shape.dim
        if dim.dim_param
    ]


def onnxTensorProto_to_array(
    onnx_tensor_proto: TensorProto,
    transpose=0,
    out_dir="",
    permutation: Optional[list[int]] = None,
) -> tuple[np.ndarray, NumpyDType]:
    """
    Converts a tensor def object to a numpy array. Supports int4

    int4 tensors (data_type 22) are unpacked into an int8 array and reported
    with the dtype name "int4". Two storage layouts are handled:
      * `raw_data`: two values per byte, low nibble first;
      * `int32_data`: eight values per entry, lowest nibble first.
    All other tensors are converted with `numpy_helper.to_array`; `out_dir`
    is passed along as base directory when the data is stored externally.

    Args:
        onnx_tensor_proto: tensor to convert.
        transpose: if truthy, the array is transposed before it is returned.
        out_dir: base directory for externally stored tensor data.
        permutation: axis order for the transposition; all axes are
            reversed when None.

    Returns:
        `(array, original_dtype_name)`.

    Raises:
        ValueError: if an int4 tensor holds fewer values than its dims need.
    """
    if onnx_tensor_proto.data_type == 22:  # TensorProto.INT4
        og_dtype = "int4"

        if onnx_tensor_proto.raw_data:  # two nibbles per byte
            packed = np.frombuffer(onnx_tensor_proto.raw_data, dtype=np.uint8)
            lower_nibble = (packed & 0xF).astype(np.int8)
            higher_nibble = (packed >> 4).astype(np.int8)
            # interleave so that the low nibble precedes the high nibble
            nibbles = np.stack((lower_nibble, higher_nibble), axis=1).flatten()
        elif onnx_tensor_proto.int32_data:  # eight nibbles per int32 entry
            words = np.array(onnx_tensor_proto.int32_data, dtype=np.int32)
            shifts = 4 * np.arange(8)
            nibbles = ((words[:, None] >> shifts) & 0xF).flatten()
        else:
            # no payload in either field: only valid for empty tensors
            nibbles = np.empty(0, dtype=np.int8)

        # reinterpret the unsigned nibbles 8..15 as -8..-1
        signed = np.where(nibbles > 7, nibbles - 16, nibbles)

        shape = tuple(onnx_tensor_proto.dims)
        # np.prod(()) is the float 1.0, which is not usable as a slice bound
        count = int(np.prod(shape))
        if signed.size < count:
            raise ValueError(
                f"int4 tensor '{onnx_tensor_proto.name}' holds {signed.size} "
                f"values but its dims {shape} require {count}"
            )
        # drop padding nibbles beyond the declared element count
        initnp = signed[:count].reshape(shape).astype(np.int8)
    else:
        if onnx_tensor_proto.data_location == TensorProto.EXTERNAL:
            initnp = numpy_helper.to_array(
                onnx_tensor_proto, out_dir
            )  # get numpy array from external data
        else:
            initnp = numpy_helper.to_array(onnx_tensor_proto)  # get numpy array
        og_dtype = initnp.dtype.name

    if not transpose:
        return initnp, og_dtype
    if permutation is None:
        return initnp.T, og_dtype
    return np.transpose(initnp, permutation), og_dtype


def onnxTensorProto_from_array(
    numpy_array: np.ndarray, proto_name: str, og_dtype: NumpyDType, change_to_int8=False
) -> TensorProto:
    """
    Converts a numpy array to tensor def object. Supports int4
    og_type is original TensorProto's datatype, before converting to np array

    When `og_dtype` is "int4" and `change_to_int8` equals False, the values
    are packed two per byte (first value in the low nibble) into `raw_data`
    of an INT4 tensor that keeps the array's shape. In every other case the
    array is converted with `numpy_helper.from_array`.
    """
    # `== False` is intentional: only values equal to False select packing
    pack_as_int4 = og_dtype == "int4" and change_to_int8 == False
    if not pack_as_int4:
        return numpy_helper.from_array(numpy_array, proto_name)

    values = numpy_array.flatten()

    # an odd number of values needs one padding value to fill the last byte
    if len(values) % 2 != 0:
        values = np.append(values, 0)

    # map -8..-1 onto 8..15 so that every value fits into one nibble
    nibbles = np.where(values < 0, values + 16, values).astype(np.uint8)

    # even positions fill the low nibble, odd positions the high nibble
    low = nibbles[0::2] & 0xF
    high = (nibbles[1::2] & 0xF) << 4
    packed_bytes = high | low

    return TensorProto(
        name=proto_name,
        data_type=TensorProto.INT4,
        dims=numpy_array.shape,
        raw_data=packed_bytes.tobytes(),
    )


def convert_int4_inits_to_int8(model):
    """
    converts all int4 initializers of a model to int8, this is required for custom ops as int4 support is not there yet in custom ops

    The initializers are replaced in place; nothing is returned.
    """
    print("Converting int4 initializers to int8")
    for position, initializer in enumerate(model.graph.initializer):
        if initializer.data_type != 22:  # 22 is the int4 data type
            continue
        values, source_dtype = onnxTensorProto_to_array(initializer, transpose=0)
        replacement = onnxTensorProto_from_array(
            values, initializer.name, og_dtype=source_dtype, change_to_int8=True
        )
        model.graph.initializer[position].CopyFrom(replacement)


def onnxTensor_dtype_to_np_dtype(tensor_type: OnnxDType) -> NumpyDType:
    """
    Return the numpy dtype name for an ONNX tensor element type.

    Type 22 yields "int4" and type 0 yields "undefined"; every other type
    is translated by `onnx.helper.tensor_dtype_to_np_dtype`.
    """
    if tensor_type == 0:
        return "undefined"
    if tensor_type == 22:
        return "int4"
    np_dtype = onnx.helper.tensor_dtype_to_np_dtype(tensor_type)
    return str(np_dtype)


def onnxTensor_np_dtype_to_dtype(element_type: NumpyDType) -> OnnxDType:
    """Return the ONNX tensor element type for a numpy dtype (or dtype name)."""
    np_dtype = np.dtype(element_type)
    return onnx.helper.np_dtype_to_tensor_dtype(np_dtype)


def right_broadcasting(arr, target):
    """
    Append size-1 axes to `arr` until it has as many axes as `target`.

    `arr` is returned reshaped to its own shape when it already has at
    least as many axes as `target`.
    """
    missing_axes = target.ndim - arr.ndim
    padded_shape = arr.shape + (1,) * missing_axes
    return arr.reshape(padded_shape)


def remove_node(model__, node_name):
    """
    Remove every node named `node_name` from the graph of `model__`.

    The model is modified in place and also returned. A name that matches
    no node leaves the graph unchanged.
    """
    # Collect the matches first: removing from the container while iterating
    # over it skips the element that follows each removed one.
    doomed = [node for node in model__.graph.node if node.name == node_name]
    for node in doomed:
        model__.graph.node.remove(node)
    return model__


def change_node_input(model__, node_name, input_indx, input_name):
    """
    Set input number `input_indx` of every node named `node_name` to `input_name`.

    The model is modified in place and also returned.
    """
    matching = (n for n in model__.graph.node if n.name == node_name)
    for target in matching:
        target.input[input_indx] = input_name
    return model__


def extract_instnorm_eps(node):
    """
    Return the node's "epsilon" attribute as `np.float32`.

    None is returned when the node has no attribute of that name.
    """
    epsilon_attr = next((a for a in node.attribute if a.name == "epsilon"), None)
    if epsilon_attr is None:
        return None
    return np.float32(epsilon_attr.f)


def construct_initializer_dict(model: ModelProto) -> dict[str, TensorProto]:
    """
    Map initializer names to the initializers of the model's graph.

    If a name occurs more than once, the first initializer is kept and a
    warning is logged for each later one.
    """
    by_name = {}
    for tensor in model.graph.initializer:
        if tensor.name in by_name:
            logging.warning(
                f"\033[33m!! Duplicated Initializer = Tensor({tensor.name}) !!\033[0m"
            )
            continue
        by_name[tensor.name] = tensor
    return by_name


def construct_nodes_dict(model) -> dict[str, NodeProto]:
    """
    Map node names to the nodes of the model's graph.

    If several nodes share a name, the last one wins.
    """
    return {node.name: node for node in model.graph.node}


def extract_attr_value(attr: AttributeProto) -> Any:
    """
    Return the Python value stored in an attribute.

    Scalars are returned as is, repeated fields as lists and tensors as
    numpy arrays. A falsy `attr` yields None.

    Raises:
        ValueError: if the attribute's type is not one of the handled ones.
    """
    if not attr:
        return None

    # (type code, reader) pairs; a reader is only invoked for the matching type
    readers = (
        (attr.FLOAT, lambda: attr.f),
        (attr.INT, lambda: attr.i),
        (attr.STRING, lambda: attr.s),
        (attr.TENSOR, lambda: numpy_helper.to_array(attr.t)),
        (attr.FLOATS, lambda: list(attr.floats)),
        (attr.INTS, lambda: list(attr.ints)),
        (attr.STRINGS, lambda: list(attr.strings)),
        (attr.TENSORS, lambda: [numpy_helper.to_array(t) for t in attr.tensors]),
    )
    for type_code, read in readers:
        if attr.type == type_code:
            return read()
    raise ValueError(f"Attribute={attr.name} has unknown type {attr.type}")


def get_attrs(node: NodeProto) -> dict[str, Any]:
    """
    Return all attributes of a node as a name -> value dict.

    Scalars are returned as is, repeated fields as lists and tensors as
    numpy arrays. Attributes of any other type are silently left out.
    """
    attrs = dict()
    for attr in node.attribute:
        # (type code, reader) pairs; a reader is only invoked for the matching type
        readers = (
            (attr.FLOAT, lambda: attr.f),
            (attr.INT, lambda: attr.i),
            (attr.STRING, lambda: attr.s),
            (attr.TENSOR, lambda: numpy_helper.to_array(attr.t)),
            (attr.FLOATS, lambda: list(attr.floats)),
            (attr.INTS, lambda: list(attr.ints)),
            (attr.STRINGS, lambda: list(attr.strings)),
            (attr.TENSORS, lambda: [numpy_helper.to_array(t) for t in attr.tensors]),
        )
        for type_code, read in readers:
            if attr.type == type_code:
                attrs[attr.name] = read()
                break
    return attrs


def get_attribute(node: NodeProto, attr_name: str) -> Any:
    """
    Return the value of the node's first attribute named `attr_name`.

    None is returned when `node` or `attr_name` is falsy or when the node
    has no such attribute.
    """
    if not (node and attr_name):
        return None
    match = next((a for a in node.attribute if a.name == attr_name), None)
    if match is None:
        return None
    return extract_attr_value(match)


def remove_attribute(node: NodeProto, attr_name: str) -> Any:
    """
    Removes attribute `attr_name` from node.
    Returns previous value if any.

    Only the first attribute of that name is removed. If there is none, a
    debug message is logged and None is returned.
    """
    if not (node and attr_name):
        return None

    for candidate in node.attribute:
        if candidate.name != attr_name:
            continue
        previous = extract_attr_value(candidate)
        node.attribute.remove(candidate)
        return previous

    logging.debug(
        f"utils.remove_attribute: Node {node.name} has no attribute with name {attr_name}"
    )
    return None


def set_attribute(node: NodeProto, attr_name: str, attr_value: Any) -> Any:
    """
    Set the value of the attribute (overwrites value if it exists)
    Returns previous value if any.
    """
    previous = remove_attribute(node, attr_name)
    replacement = make_attribute(attr_name, attr_value)
    node.attribute.append(replacement)
    return previous


def add_attrs_to_node(node, attrs):
    """
    Append one attribute per entry of `attrs` to `node` and return the node.

    A numpy array, or a list consisting only of numpy arrays, is converted
    to tensor(s) first; every other value is handed to `make_attribute`
    unchanged. Existing attributes of the same name are not replaced.
    """
    for key, value in attrs.items():
        payload = value
        if isinstance(value, np.ndarray):
            payload = numpy_helper.from_array(value)
        elif isinstance(value, list) and all(
            isinstance(x, np.ndarray) for x in value
        ):
            payload = [numpy_helper.from_array(x) for x in value]
        node.attribute.append(make_attribute(key, payload))
    return node


def construct_tensor_in_out_dict(model) -> tuple[TensorToNodes, TensorToNodes]:
    """
    Maps tensors to their node-readers and node-writers
    Returns:
        * tensor_readers: dict[str, list[str]]
            -- {tensor name -> list of node-names reading this tensor}
        * tensor_writers: dict[str, list[str]]
            -- {tensor name -> list of node-names writing this tensor}

    Both results are `defaultdict(list)` instances.
    """
    readers, writers = defaultdict(list), defaultdict(list)

    for node in model.graph.node:
        node_name = node.name
        for consumed in node.input:
            readers[consumed].append(node_name)
        for produced in node.output:
            writers[produced].append(node_name)

    return readers, writers


def construct_node_in_out_dict(
    model, tensor_readers: TensorToNodes, tensor_writers: TensorToNodes
) -> tuple[NodeToNodes, NodeToNodes]:
    """
    Map every node to its neighbouring nodes.

    Returns:
        * in-nodes: {node name -> first writer of each input tensor that has
          a writer}
        * out-nodes: {node name -> all readers of each output tensor}
    """
    in_nodes = {}
    out_nodes = {}
    for node in model.graph.node:
        # only the first writer of an input tensor is recorded
        in_nodes[node.name] = [
            tensor_writers[tensor][0]
            for tensor in node.input
            if tensor in tensor_writers
        ]

        readers = []
        for tensor in node.output:
            if tensor in tensor_readers:
                readers.extend(tensor_readers[tensor])
        out_nodes[node.name] = readers

    return in_nodes, out_nodes


def construct_graph_input_dict(model) -> dict[str, ValueInfoProto]:
    """Map the names of the graph inputs to their value infos."""
    by_name = {}
    for graph_input in model.graph.input:
        by_name[graph_input.name] = graph_input
    return by_name


def construct_graph_output_dict(model) -> dict[str, ValueInfoProto]:
    """Map the names of the graph outputs to their value infos."""
    by_name = {}
    for graph_output in model.graph.output:
        by_name[graph_output.name] = graph_output
    return by_name


def construct_valinfo_dict(model) -> dict[str, ValueInfoProto]:
    """Map tensor names to the entries of the graph's `value_info`."""
    by_name = {}
    for entry in model.graph.value_info:
        by_name[entry.name] = entry
    return by_name


def construct_attr_dict(attr):
    """Map the name of each given attribute to its `ints` field."""
    return {entry.name: entry.ints for entry in attr}


def construct_dict(model):
    """
    Build all lookup tables of a model in one call.

    Returns:
        A tuple `(initializers, nodes, in_nodes, out_nodes, value_info,
        graph_inputs, graph_outputs)`, where `value_info` also contains the
        graph inputs and outputs.
    """
    initializers = construct_initializer_dict(model)
    nodes = construct_nodes_dict(model)
    tensor_readers, tensor_writers = construct_tensor_in_out_dict(model)
    # in_nodes lists input nodes, out_nodes lists output nodes
    in_nodes, out_nodes = construct_node_in_out_dict(
        model, tensor_readers, tensor_writers
    )
    graph_inputs = construct_graph_input_dict(model)
    graph_outputs = construct_graph_output_dict(model)
    # value_info = m.graph.value_info + m.graph.input + m.graph.output
    value_info = construct_valinfo_dict(model) | graph_inputs | graph_outputs
    return (
        initializers,
        nodes,
        in_nodes,
        out_nodes,
        value_info,
        graph_inputs,
        graph_outputs,
    )


def find_closest_shifted_int16(float_val, shift_max=np.inf):
    """
    Approximate `float_val` by `int_val / 2**shift`.

    Shifts 0, 1, 2, ... are tried for as long as `float_val * 2**shift`
    does not exceed 32767 and the shift does not exceed `shift_max`. The
    candidate with the smallest relative error wins; on ties the smallest
    shift is kept.

    Returns:
        `[int_val, shift]`. For `float_val == 0` this is `[0, 0]`.
        NOTE(review): when no shift qualifies (e.g. `float_val > 32767`),
        both entries are the placeholder `np.int16` type object, as before.
    """
    INT16_MAX = 32767

    # Zero is represented exactly; the relative error below would divide by zero.
    if float_val == 0:
        return [0, 0]

    # placeholders returned unchanged when the loop body never runs
    best_int = np.int16
    best_shift_val = np.int16

    lowest_rel_err = 1e9
    scaled_val = float_val
    shift_val = 0
    while (scaled_val <= INT16_MAX) and (shift_val <= shift_max):
        candidate = round(scaled_val)
        rel_err = abs(float_val - candidate / (2**shift_val)) / float_val

        if rel_err < lowest_rel_err:
            lowest_rel_err = rel_err
            best_int = candidate
            best_shift_val = shift_val

        scaled_val *= 2
        shift_val += 1

    return [best_int, best_shift_val]


def find_closest_shifted_int32(float_val, INT32_MAX=8388607, shift_max=np.inf):
    """Approximate ``float_val`` as ``integer / 2**shift``.

    Starting at shift 0, the value is doubled once per step for as long as
    its magnitude stays within ``INT32_MAX`` and the shift does not exceed
    ``shift_max``. At every step the scaled value is rounded to the nearest
    integer, and the candidate with the smallest relative error is kept.
    On equal error the candidate with the smaller shift wins.

    Args:
        float_val: Finite number to approximate. Zero and negative values
            are accepted.
        INT32_MAX: Largest allowed magnitude of the scaled value. The
            default is 8388607 (``2**23 - 1``), which is smaller than the
            full signed 32-bit range.
        shift_max: Largest shift to try (inclusive). Unbounded by default.

    Returns:
        ``[best_int, best_shift]`` as Python ints; ``[0, 0]`` for zero.

    Raises:
        ValueError: If no candidate exists, i.e. ``abs(float_val)`` exceeds
            ``INT32_MAX``, ``float_val`` is NaN/infinite, or ``shift_max``
            is negative.
    """
    # Zero is exactly representable; handling it here also avoids dividing
    # by zero in the relative-error computation below.
    if float_val == 0:
        return [0, 0]

    magnitude = abs(float_val)
    best_int = None
    best_shift = None
    best_rel_err = float("inf")

    scaled = float_val
    shift = 0
    while abs(scaled) <= INT32_MAX and shift <= shift_max:
        candidate = int(round(scaled))
        rel_err = abs(float_val - candidate / (2**shift)) / magnitude

        # Strict comparison keeps the smallest shift among equal errors.
        if rel_err < best_rel_err:
            best_rel_err = rel_err
            best_int = candidate
            best_shift = shift

        scaled *= 2
        shift += 1

    if best_int is None:
        raise ValueError(
            f"Cannot represent {float_val!r} as integer / 2**shift with "
            f"|integer| <= {INT32_MAX} and 0 <= shift <= {shift_max}"
        )
    return [best_int, best_shift]


def extract_tensor_shape(graph_value_dict, tensor_name):
    """Return the ``dimValue`` of every dimension of ``tensor_name``.

    The entry ``graph_value_dict[tensor_name]`` is converted with
    ``MessageToDict`` and its ``type.tensorType.shape.dim`` list is read.
    An empty list is returned when that list is missing or empty. A
    dimension without a ``dimValue`` key contributes ``None``.

    NOTE(review): protobuf's JSON mapping renders 64-bit integers as
    strings, so the returned entries are presumably ``str`` rather than
    ``int`` - confirm against callers.
    """
    as_dict = MessageToDict(graph_value_dict[tensor_name])
    dims = as_dict.get("type").get("tensorType").get("shape").get("dim")
    return [dim.get("dimValue") for dim in dims] if dims else []


def extract_tensor_type(graph_value_dict, tensor_name):
    """Return the ``TensorProto.DataType`` name of ``tensor_name``.

    The entry ``graph_value_dict[tensor_name]`` is converted with
    ``MessageToDict`` and its ``type.tensorType.elemType`` value is
    translated to the corresponding enum name.
    """
    as_dict = MessageToDict(graph_value_dict[tensor_name])
    elem_type = as_dict.get("type").get("tensorType").get("elemType")
    return TensorProto.DataType.Name(elem_type)


def extract_tensor_shape_dict(graph_value_dict, tensor_name):
    """Return the ``dimValue`` of every dimension of ``tensor_name``.

    Unlike ``extract_tensor_shape``, an unknown tensor name does not raise:
    ``None`` is returned instead.

    Args:
        graph_value_dict: Mapping from tensor name to a protobuf message
            accepted by ``MessageToDict``.
        tensor_name: Name of the tensor to look up.

    Returns:
        ``None`` when ``tensor_name`` is not a key of ``graph_value_dict``;
        otherwise the list of ``dimValue`` entries, which is empty when the
        shape carries no ``dim`` entries.

    NOTE(review): protobuf's JSON mapping renders 64-bit integers as
    strings, so the returned entries are presumably ``str`` rather than
    ``int`` - confirm against callers.
    """
    # Direct lookup instead of scanning every key for an equal name.
    if tensor_name not in graph_value_dict:
        return None

    as_dict = MessageToDict(graph_value_dict[tensor_name])
    dims = as_dict.get("type").get("tensorType").get("shape").get("dim")
    # "dim" is absent when the shape has no dimensions; return an empty
    # shape rather than failing while iterating over None.
    return [dim.get("dimValue") for dim in dims] if dims else []


class L1FusionTempAttributes:
    """Holder for the names of temporary node attributes.

    ``TEMPORARY_ATTRS`` is the list consulted by
    ``remove_additional_attributes`` and
    ``remove_additional_attributes_from_graph`` to decide which attributes
    to strip.
    """

    TEMPORARY_ATTRS = [
        "matcher_name",
        "actxact",
        "bias",
        "orig_type",
        # dbg* entries
        "dbgA",
        "dbgB",
        "dbgOut",
        "native_dtype",
        "create_by_batch",
        # *_const_value entries
        "Mul_const_value",
        "Div_const_value",
        "Add_const_value",
        "Sub_const_value",
        "L1_fusion_frozen",
        "orig_x_zero_point_dtype",
        "orig_y_zero_point_dtype",
        "num_of_tensor_inputs",
        "orig_name",
        # RTR_* entries
        "RTR_subgraph_id",
        "RTR_subgraph_root",
        "RTR_subgraph_leaf",
        "RTR_is_inverse",
        "RTR_shape",
        "RTR_perm",
        "RTR_transposed",
        "RTR_origin",
    ]


def remove_additional_attributes(model: ModelProto) -> None:
    """Strip temporary attributes from every node of ``model`` in place.

    An attribute is removed when its name is listed in
    ``L1FusionTempAttributes.TEMPORARY_ATTRS``.
    """
    temp_names = L1FusionTempAttributes.TEMPORARY_ATTRS
    for node in model.graph.node:
        # Decide what to drop before mutating, so the container is never
        # modified while it is being iterated.
        doomed = [attr for attr in node.attribute if attr.name in temp_names]
        for attr in doomed:
            node.attribute.remove(attr)


def remove_additional_attributes_from_graph(graph: dict[str, dict]) -> None:
    """Strip temporary attributes from a dict-based graph in place.

    For every node dict in ``graph`` that has a non-empty ``"attributes"``
    mapping, each key listed in ``L1FusionTempAttributes.TEMPORARY_ATTRS``
    is dropped. Keys that are not present are ignored.
    """
    temp_names = L1FusionTempAttributes.TEMPORARY_ATTRS
    for node_dict in graph.values():
        attrs = node_dict.get("attributes")
        if not attrs:
            continue
        for name in temp_names:
            attrs.pop(name, None)


def save_model(
    model: onnx.ModelProto, out_model_path: str, external_data: bool
) -> None:
    """Write ``model`` to ``out_model_path``, replacing any existing file.

    With ``external_data`` set, all tensors are stored in a single side
    file named ``<out_model_path without extension>.onnx.data`` next to the
    model, and external tensor data is loaded back into ``model`` afterwards
    if any initializer is marked as external.
    """
    if os.path.exists(out_model_path):
        os.remove(out_model_path)

    if not external_data:
        onnx.save(model, out_model_path)
        return

    stem, _ = os.path.splitext(out_model_path)
    data_path = stem + ".onnx.data"
    # A stale data file must go first; otherwise the new tensors would be
    # appended to the existing file.
    if os.path.exists(data_path):
        os.remove(data_path)

    onnx.save_model(
        model,
        out_model_path,
        save_as_external_data=True,
        all_tensors_to_one_file=True,
        location=os.path.basename(data_path),
    )

    # Saving to a new file may leave the in-memory model without its tensor
    # payloads; if any initializer is now external, load the data back in.
    has_external_initializer = any(
        tensor.data_location == onnx.TensorProto.EXTERNAL
        for tensor in model.graph.initializer
    )
    if has_external_initializer:
        onnx.load_external_data_for_model(
            model, base_dir=os.path.dirname(out_model_path)
        )


def remove_model(model_path: str, external_data: bool) -> None:
    """
    Delete a model file from disk if it is present.

    With ``external_data`` set, the external data file is deleted as well,
    under both naming patterns: ``<model_path without extension>.onnx.data``
    and ``<model_path>.data``. Paths that are not regular files are left
    untouched.
    """
    candidates = []
    if external_data:
        candidates.append(os.path.splitext(model_path)[0] + ".onnx.data")
        # Alternative data file naming pattern.
        candidates.append(model_path + ".data")
    candidates.append(model_path)

    for path in candidates:
        if os.path.isfile(path):
            os.unlink(path)


def collect_unused_ini_nodes(model: onnx.ModelProto, md: model_dict):
    """Return the initializer tensors that no node references.

    An entry of ``md.ini`` counts as unused when it appears in neither the
    inputs nor the outputs of any node in ``model.graph.node``.

    Args:
        model: Model whose nodes and initializers are inspected.
        md: Object whose ``ini`` attribute iterates over initializer names.

    Returns:
        The matching tensors from ``model.graph.initializer``, ordered by
        the iteration order of ``md.ini``; tensors sharing a name keep
        their order within ``model.graph.initializer``.
    """
    referenced = set()
    for node in model.graph.node:
        referenced.update(node.input)
        referenced.update(node.output)

    # Index the initializers once so each unused name is resolved by lookup
    # instead of rescanning the whole initializer list.
    tensors_by_name = {}
    for tensor in model.graph.initializer:
        tensors_by_name.setdefault(tensor.name, []).append(tensor)

    return [
        tensor
        for name in md.ini
        if name not in referenced
        for tensor in tensors_by_name.get(name, ())
    ]


def find_invalid_graph_output(model: onnx.ModelProto) -> list[str]:
    """
    Return the names of unused initializers that are also graph outputs.

    An initializer is unused when ``collect_unused_ini_nodes`` reports it;
    its name is listed once for every entry of ``model.graph.output`` that
    carries the same name.
    """
    md = model_dict.create(model)
    unused = collect_unused_ini_nodes(model, md)
    return [
        ini.name
        for ini in unused
        for graph_output in model.graph.output
        if ini.name == graph_output.name
    ]
