# (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All Rights Reserved.
import os
import onnx
from onnx import AttributeProto, ModelProto, NodeProto
from typing import Callable, Iterable
import numpy as np

from OGOAT.src.L1_fusion.L1_utils.model_IR_utils import (
    tag_layernorm_fusion_type,
)
from OGOAT.src.L1_fusion.L1_utils.safe_runner import SafeRunner
from OGOAT.src.L1_fusion.L1_utils.ops_definition_utils import OnnxOpsWrapper
from OGOAT.src.L1_fusion.L1_utils.utils import (
    construct_initializer_dict,
    onnxTensorProto_to_array,
    remove_additional_attributes_from_graph,
)
from OGOAT.src.L1_fusion.graph_info_utils import (
    GraphInfoParams,
    construct_tensor_in_out_dict,
    get_act_signal_shapes,
    get_child_nodes_dict,
    get_initializer_shapes,
    get_parent_nodes_dict,
    DTYPE_BYTES_DICT,
    BYPASS_OP_TYPES,
    LOW_PRECISION_INPUT_OP_TYPES,
)

from OGOAT.src.L1_fusion.kernel_metadata_loader import (
    KernelMetadataLoader,
)
from OGOAT.src.L1_fusion.py_match.fusion_frozen import DtypeFrozen


class GraphInfo:
    def __init__(
        self,
        model: ModelProto,
        graph_info_params: GraphInfoParams,
        runner: SafeRunner,
    ):
        self.graph = model.graph
        self.init_dict = construct_initializer_dict(model)
        self.graph_info_params = graph_info_params
        self._runner = runner
        self.onnx_wrapper = OnnxOpsWrapper()
        self.kernel_metadata = KernelMetadataLoader()

    def get_graph_info(self) -> dict:
        # get shapes and data types of initializers and activations
        self.model_initializer_shapes, self.model_initializer_dtypes = (
            get_initializer_shapes(self.graph)
        )
        self.all_act_signal_shapes, self.all_act_signal_dtypes = get_act_signal_shapes(
            self.graph
        )
        self.global_inputs = [x.name for x in self.graph.input]
        graph_ = {}
        for node in self.graph.node:
            node_info = self._runner.run(self.get_node_info, node)
            if not node_info:
                continue
            graph_[node.name] = node_info

        # add children and parents
        [in_tensors_dict, out_tensors_dict] = construct_tensor_in_out_dict(self.graph)
        graph_ = self._runner.run(
            get_child_nodes_dict,
            graph_,
            in_tensors_dict,
            BYPASS_OP_TYPES,
            bypass_ops=True,
        )
        graph_ = self._runner.run(
            get_parent_nodes_dict,
            graph_,
            out_tensors_dict,
            BYPASS_OP_TYPES,
            bypass_ops=True,
        )

        if self.graph_info_params.assign_new_dtypes:
            self._runner.run(self.assign_new_types, graph_)
        cwd = os.path.dirname(os.path.abspath(__file__))
        if graph_:
            graph_ = tag_layernorm_fusion_type(
                graph_,
                cwd + "/../../Collaterals/",
                device="strix",
                overlay="4x4",
            )

        # Remove unwanted attributes from graph which later populates in the IR files
        remove_additional_attributes_from_graph(graph_)
        return graph_

    def _update_hardware_datatype(self, node_info: dict[str, any]):
        """
        Update hardware datatype information in the node_info dictionary.
        Ignore marked frozen nodes and keep original datatypes for those nodes.
        """

        name = node_info.get("attributes", {}).get("orig_name", "") or node_info.get(
            "node_name", ""
        )
        if isinstance(name, list):
            name = name[0]

        if DtypeFrozen._frozen_nodes and name in DtypeFrozen._frozen_nodes:
            for input in node_info["inputs"]:
                input["hw_dtype"] = input["dtype"]
                input["hw_dtype_bytes"] = input["dtype_bytes"]

            for output in node_info["outputs"]:
                output["hw_dtype"] = output["dtype"]
                output["hw_dtype_bytes"] = output["dtype_bytes"]
            return

        # Normal processing for non-frozen nodes
        # update op_type
        idx = node_info["op_type"].rfind("_")
        hw_dtypes: list[str] = []
        for dtype in node_info["op_type"][idx + 1 :].split("x"):
            hw_dtypes.append(self.kernel_metadata.get_hardware_datatype(dtype))
        node_info["op_type"] = node_info["op_type"][: idx + 1] + "x".join(hw_dtypes)

        # update inputs
        for input in node_info["inputs"]:
            hw_dtype = self.kernel_metadata.get_hardware_datatype(input["dtype"])
            input["hw_dtype"] = hw_dtype
            input["hw_dtype_bytes"] = DTYPE_BYTES_DICT[hw_dtype]

        # update outputs
        for output in node_info["outputs"]:
            hw_dtype = self.kernel_metadata.get_hardware_datatype(output["dtype"])
            output["hw_dtype"] = hw_dtype
            output["hw_dtype_bytes"] = DTYPE_BYTES_DICT[hw_dtype]

        # Update datatype and bytes fields for in/wgt/wgt1/out
        for prefix in ["in", "wgt", "wgt1", "out"]:
            datatype_key = f"{prefix}_datatype"
            if datatype_key not in node_info:
                continue
            node_info[datatype_key] = self.kernel_metadata.get_hardware_datatype(
                node_info[datatype_key]
            )
            bytes_key = f"{prefix}_bytes"
            node_info[bytes_key] = DTYPE_BYTES_DICT[node_info[datatype_key]]

    def get_node_info(self, node: NodeProto) -> dict:
        node_info = {"node_name": node.name, "op_type": node.op_type}
        # Make a dictionary of attributes
        layer_params_ = self.add_layer_params(node)
        node_info["attributes"] = layer_params_
        if node.op_type.startswith("Concat_qdq_"):
            num_inputs = next((a.i for a in node.attribute if a.name == "num_inputs"), None)
            assert num_inputs is not None, f"Concat_qdq_ node '{node.name}' missing 'num_inputs' attribute."
            splitted_parts = node.op_type.split("_", 1)
            input_names = lambda _: self.onnx_wrapper.get_input_names(f"{splitted_parts[0]}{num_inputs}_{splitted_parts[1]}")
            output_names = lambda _: self.onnx_wrapper.get_output_names(f"{splitted_parts[0]}{num_inputs}_{splitted_parts[1]}")
        
        else:
            input_names = self.onnx_wrapper.get_input_names
            output_names = self.onnx_wrapper.get_output_names
            
        node_info["inputs"] = self.build_input_or_outputs(
            node.input, node.op_type, input_names
        )  # inputs to a node can be multiple
        node_info["outputs"] = self.build_input_or_outputs(
            node.output, node.op_type, output_names
        )  # there will be only one output from a node
        self.add_input_info(node_info)
        self.add_output_info(node_info)
        self.add_act_wgt_bias_info(node, node_info)
        self.add_out_act_shape(node, node_info)

        # DepthWise Conv Distinction, op_type update
        if (
            node.op_type.split("_")[0] == "Conv"
            and "group" in node_info["attributes"].keys()
        ):
            self.update_node_info_conv(node, node_info)

        if self.graph_info_params.assign_new_dtypes:
            # assign datatype to node i/o signals,
            # in datatypes are fixed for an op. Out datatypes will change based on child nodes, these will change later
            self.assign_datatypes_io_signals(node_info)
        else:
            # take datatypes as it is from the model
            self.assign_datatypes_from_model(node, node_info)

        # Special cases for matmul_qdq, this will override some
        # of the entry that were computed in a generic mattern
        # for all nodes
        if "MatMul_qdq" in node.op_type:
            self.add_matmul_node_info(node, node_info)

        # add dataype bytes
        self.add_bytes_to_node_info(node_info)
        ## add dummy residency params #TODO: remove it later
        node_info["in_act_residency"] = "L3"
        node_info["out_act_residency"] = "L3"
        if not self.graph_info_params.no_dtype_downcast:
            self._update_hardware_datatype(node_info)
        if self.graph_info_params.device and self.graph_info_params.device != "strix":
            node_info = self.add_const_padding_value(node_info)
        return node_info

    def add_const_padding_value(self, node_info: dict) -> dict:
        """
        Add padding value for txn padding to node_info attributes
        """
        param_name_map = {
            "Conv": "A_zero_point",
            "Softmax": "input_zero_point",
            "LayerNormalization": "X_zero_point",
            "GroupNormalization": "X_zero_point",
            "Transpose": "0",
            "Slice": "0",
            "Split": "0",
            "Concat": "0",
            "Quant": "0",
            "Dequant": "0",
            "MatMul_qdq_actxact_": "0",
            "MatMul_qdq_": "0",
        }
        op_type = node_info["op_type"]
        for key, val in param_name_map.items():
            if op_type == key or op_type.startswith(key):
                if val == "0":
                    node_info["attributes"]["const_padding_value"] = [str(0)]
                    break
                else:
                    inp_info = [
                        inp["name"]
                        for inp in node_info["inputs"]
                        if "param_name" in inp and inp["param_name"] == val
                    ]
                    if len(inp_info) == 0:
                        continue
                    const_name = inp_info[0]
                    pad_value = onnxTensorProto_to_array(self.init_dict[const_name])[0]
                    node_info["attributes"]["const_padding_value"] = [str(pad_value)]
                    break
            else:
                node_info["attributes"]["const_padding_value"] = ["NONE"]

        return node_info

    def add_layer_params(self, node: NodeProto) -> dict:
        """
        Add layer parameters to the node_info dictionary.
        This function is a placeholder and should be implemented based on the specific requirements.
        """
        layer_params_ = {}
        for attribute in node.attribute:
            if attribute.type == AttributeProto.FLOAT:  # "FLOAT"
                layer_params_[attribute.name] = [attribute.f]
            elif attribute.type == AttributeProto.FLOATS:  # "FLOATS"
                layer_params_[attribute.name] = [x for x in attribute.floats]
            elif attribute.type == AttributeProto.INT:  # "INT"
                layer_params_[attribute.name] = [attribute.i]
            elif attribute.type == AttributeProto.INTS:  # "INTS"
                layer_params_[attribute.name] = [x for x in attribute.ints]
            elif attribute.type == AttributeProto.TENSOR:  # "TENSOR"
                layer_params_[attribute.name] = onnx.numpy_helper.to_array(
                    attribute.t
                ).tolist()  # [MessageToDict(attribute.t)] #values are not proper
            elif attribute.type == AttributeProto.TENSORS:  # "TENSORS"
                layer_params_[attribute.name] = [
                    onnx.numpy_helper.to_array(t).tolist() for t in attribute.tensors
                ]
            elif attribute.type == AttributeProto.STRING:  # "STRING"
                layer_params_[attribute.name] = [attribute.s.decode("utf-8")]
            elif attribute.type == AttributeProto.STRINGS:  # "STRINGS"
                layer_params_[attribute.name] = [
                    s.decode("utf-8") for s in attribute.strings
                ]
            else:
                raise ValueError(f"Unknown attribute type: {attribute.type}")
        return layer_params_

    def add_input_info(self, node_info: dict[str, str]) -> None:
        for in_dict in node_info["inputs"]:
            signal_ = in_dict["name"]
            if (
                signal_ in self.model_initializer_shapes
                and signal_ in self.all_act_signal_shapes
            ):
                raise ValueError(
                    f"signal_ '{signal_}' found in both model_initializer_shapes and all_act_signal_shapes"
                )

            in_dict["type"] = (
                "const"
                if signal_ in self.model_initializer_shapes
                else (
                    "act"
                    if signal_ in self.global_inputs
                    or signal_ in self.all_act_signal_shapes
                    else ""
                )
            )
            in_dict["shape"] = self.get_tensor_shape(signal_)
            in_dict["dtype"] = str(
                self.all_act_signal_dtypes[signal_]
                if signal_ in self.all_act_signal_dtypes
                else (
                    self.model_initializer_dtypes[signal_]
                    if signal_ in self.model_initializer_dtypes
                    else ""
                )
            )
            in_dict["dtype_bytes"] = DTYPE_BYTES_DICT[in_dict["dtype"]]

    def add_output_info(self, node_info: dict[str, str]) -> None:
        for out_dict in node_info["outputs"]:
            signal_ = out_dict["name"]
            out_dict["type"] = (
                "act"  # global_out" if signal_ in global_outputs else "act" # model global out or act
            )
            out_dict["shape"] = self.get_tensor_shape(signal_)
            out_dict["dtype"] = str(
                self.all_act_signal_dtypes[signal_]
                if signal_ in self.all_act_signal_dtypes
                else ""
            )
            out_dict["dtype_bytes"] = DTYPE_BYTES_DICT[out_dict["dtype"]]

    def get_tensor_shape(self, tensor_name: str) -> list[int]:
        """Get the shape of a tensor by its name.

        Args:
            tensor_name (str): The name of the tensor.

        Returns:
            list[int]: The shape of the tensor. If no shape found, we assume that it's a scalar variable and we by default return [1].
        """
        shape = self.all_act_signal_shapes.get(tensor_name)
        if shape is None:
            shape = self.model_initializer_shapes.get(tensor_name)
        if shape is None:
            return None
        return [1] if shape == [] else shape

    def add_act_wgt_bias_info(self, node: NodeProto, node_info: dict[str, str]) -> None:
        """Add in_act_shape, in_wgt_shape, and in_wgt1_shape information to the node.

        Args:
            node (NodeProto): The ONNX node.
            node_info (dict[str, str]): The dictionary to store node information.
        """

        num_of_tensor_inputs = next(
            (attr.i for attr in node.attribute if attr.name == "num_of_tensor_inputs"),
            3,
        )
        node_info["in_act_shape"] = None
        node_info["in_wgt_shape"] = None
        node_info["in_wgt1_shape"] = None

        for i in range(min(len(node_info["inputs"]), num_of_tensor_inputs)):
            if i == 0:
                node_info["in_act_signal_name"] = node.input[0]
                node_info["in_act_shape"] = self.get_tensor_shape(node.input[0])
            if i == 1:
                node_info["in_wgt_shape"] = self.get_tensor_shape(node.input[1])
            if i == 2:
                node_info["in_wgt1_shape"] = self.get_tensor_shape(node.input[2])

    def add_out_act_shape(self, node: NodeProto, node_info: dict[str, str]) -> None:

        if not node_info["outputs"]:
            raise AssertionError(
                f"Node '{node.name}' , op_type={node.op_type} has no outputs; this is invalid."
            )

        SignalOut = node.output[0]  # assumption: each node will have one output
        out_act_shape = self.get_tensor_shape(SignalOut)
        node_info["out_act_signal_name"] = SignalOut
        node_info["out_act_shape"] = out_act_shape

    def add_matmul_node_info(self, node: NodeProto, node_info: dict[str, str]) -> None:
        # special case of matmul with qdq and without bias
        if "bias" in node.op_type:
            act_scale_ind = 3
            act_zp_ind = 4
            wgt_scale_ind = 5
            wgt_zp_ind = 6
        else:
            act_scale_ind = 2
            act_zp_ind = 3
            wgt_scale_ind = 4
            wgt_zp_ind = 5

        # Actual dtype was not annotated in the graph during fusion
        # because onnx only support int4 type when bits < 8.
        # Change the dtype, shape and element byte size for the weight
        # When the matmul_qdq node come from a MatMulNBits (presence
        # of the bits attribute).
        if "bits" in node_info["attributes"]:
            bits = node_info["attributes"]["bits"][0]
            dtype = f"uint{bits}"
            shape = [node_info["attributes"]["K"][0], node_info["attributes"]["N"][0]]
            bytes_size = DTYPE_BYTES_DICT[dtype]

            node_info["inputs"][1]["dtype"] = dtype
            node_info["inputs"][1]["dtype_bytes"] = bytes_size
            node_info["inputs"][1]["shape"] = shape

            node_info["inputs"][wgt_zp_ind]["dtype"] = dtype
            node_info["inputs"][wgt_zp_ind]["dtype_bytes"] = bytes_size
            node_info["inputs"][wgt_zp_ind]["shape"] = shape

            node_info["in_wgt_shape"] = shape
            node_info["wgt_datatype"] = dtype
            node_info["wgt_bytes"] = bytes_size
        # 0: act, 1: wgt, 2: act_scale, 3:act_zp, 4: wgt_scale, 5: wgt_zp, 6: out_scale, 7: out_zp
        act_scaler_qdq = (
            True
            if (
                node_info["inputs"][act_scale_ind]["shape"] == [1]
                and node_info["inputs"][act_zp_ind]["shape"] == [1]
            )
            else False
        )  # act scale and zp
        wgt_scaler_qdq = (
            True
            if (
                node_info["inputs"][wgt_scale_ind]["shape"] == [1]
                and node_info["inputs"][wgt_zp_ind]["shape"] == [1]
            )
            else False
        )  # wgt scale and zp

        N = node_info["inputs"][1]["shape"][
            -1
        ]  # wgt shape last element is N for matmul

        if act_scaler_qdq and wgt_scaler_qdq:  # both have scaler qdq params
            coeff_shape = [N]
        elif act_scaler_qdq ^ wgt_scaler_qdq:  # one of them has vector qdq params
            coeff_shape = [2 * N]
        else:  # both have vector qdq params
            coeff_shape = [4 * N]

        node_info["coeff_shape"] = coeff_shape

        ##qdq_symmetry attribute
        zero_act_zp = np.all(
            onnxTensorProto_to_array(self.init_dict[node.input[act_zp_ind]])[0] == 0
        )
        zero_wgt_zp = np.all(
            onnxTensorProto_to_array(self.init_dict[node.input[wgt_zp_ind]])[0] == 0
        )

        if zero_act_zp == False and zero_wgt_zp == False:
            # act_zp non zero, wgt_zp non zero
            # asymmetric quantization - ifmsum enabled
            node_info["qdq_symmetry"] = 0
        elif zero_act_zp == True and zero_wgt_zp == False:
            # act_zp zero, wgt_zp non zero
            # asymmetric quantization - ifmsum enabled
            node_info["qdq_symmetry"] = 1
        elif zero_act_zp == False and zero_wgt_zp == True:
            # act_zp non zero, wgt_zp zero
            # symmetric quantization - ifmsum disabled
            node_info["qdq_symmetry"] = 2
        else:
            # act_zp zero, wgt_zp zero
            # symmetric quantization - ifmsum disabled
            node_info["qdq_symmetry"] = 3

    def build_input_or_outputs(
        self,
        input_or_output_names: Iterable[str],
        op_type: str,
        param_name_lookup: Callable[[str], list[str]],
    ) -> list[dict[str, str]]:
        """
        Build initial list of all inputs or outputs of a node. Include the name of
        the input/output tensor and the parameter name if available.
        input_or_output_names -- names of input or output tensors
        op_type -- operator type of the node
        param_name_lookup -- function to get parameter names for the operator,
                            may raise KeyError if operator not known
        return -- [{"name": tensor_name, "param_name": param_name}, ...
                {"name": tensor_name}, ...]
        """
        try:
            param_names = param_name_lookup(op_type)
        except KeyError:
            param_names = []
        if len(param_names) < len(input_or_output_names):
            param_names += [None] * (len(input_or_output_names) - len(param_names))
        in_or_out_list: list[dict[str, str]] = []
        for name, param_name in zip(input_or_output_names, param_names):
            entry = {"name": name}
            if param_name is not None:
                entry["param_name"] = param_name
            in_or_out_list.append(entry)
        return in_or_out_list

    def assign_new_types(self, graph_: dict):
        # update the initialized dtypes, based on a heuristic
        # out_tensor_datatype_dict = {}
        # update out datatype of nodes based on children nodes
        for node_name, node in graph_.items():
            # update out_datatype of node based on its children
            out_low_precision = True  # set it to false if any of the children not in low_precision_input_op_types
            for i in range(len(node["children_names"])):
                if node["children_op_types"][i] not in LOW_PRECISION_INPUT_OP_TYPES:
                    out_low_precision = False
                    break
            node["out_datatype"] = (
                self.graph_info_params.low_precision_act_dtype
                if out_low_precision
                else self.graph_info_params.high_precision_act_dtype
            )

            # #store out dataype in out_tensor_datatype_dict # create a function for it later
            # out_tensor_datatype_dict[node["out_act_signal_name"]] = node["out_datatype"]

            # # update in_datatype of children nodes based on current node's out_datatype
            # # TODO: this assumes we dont have data type convertors and low precision in datatype of children will change to high if out datatype of parent is high precision
            # if node["out_datatype"] == low_precision_act_dtype:
            #   for i in range(len(node["children_names"])):
            #     node_to_update = graph_[ node["children_names"][i] ]
            #     node_to_update["in_datatype"] = low_precision_act_dtype
            # else:
            #   for i in range(len(node["children_names"])):
            #     node_to_update = graph_[ node["children_names"][i] ]
            #     node_to_update["in_datatype"] = high_precision_act_dtype

        # update in_datatype, wgt_datatype, wgt1_datatype of nodes based on parent node's out_datatype
        # TODO: this assumes we dont have data type convertors and low precision in datatype of children will change to high if out datatype of parent is high precision
        for node_name, node in graph_.items():
            for i in range(len(node["parent_names"])):
                parent_name = node["parent_names"][i]
                parent_op_type = node["parent_op_types"][i]
                if parent_op_type not in [
                    "QuantizeLinear",
                    "DequantizeLinear",
                ]:  # parent optype dq means signal is an initializer
                    if i == 0:
                        node["in_datatype"] = graph_[parent_name]["out_datatype"]
                    elif i == 1:
                        node["wgt_datatype"] = graph_[parent_name]["out_datatype"]
                    elif i == 2:
                        node["wgt1_datatype"] = graph_[parent_name]["out_datatype"]
        # TODO: add a functionality to handle in, wgt, wgt1 datatypes of quant and dequant blocks, if possible extract from the model
        # # update wgt_datatype for nodes for which both inputs come from previous node. By looking at all node out act datatypes
        # for node_name, node in graph_.items():
        #   if node["op_type"] == "MatMul" or node["op_type"] == "Gemm" or node["op_type"] == "Add":
        #     if node["out_act_signal_name"][1] in output_avtivations:
        #       node["wgt_datatype"] = out_tensor_datatype_dict[node["out_act_signal_name"][1]] # we assume 0 index to be act and 1 to be weight

        # update wgt_datatype for nodes for which both inputs come from previous nodes, and not from initializers
        # has_weights = False
        # if node["op_type"] == "MatMul" or node["op_type"] == "Gemm" or node["op_type"] == "Add":
        #   if(node["inputs"][0] in output_avtivations) and (node["inputs"][1] in output_avtivations):
        #     node["wgt_datatype"] = high_precision_wgt_dtype
        #   else:
        #     has_weights = True

    def assign_datatypes_io_signals(self, node_info: dict[str, str]) -> None:
        if (
            node_info["op_type"] in LOW_PRECISION_INPUT_OP_TYPES
        ):  # these ops will surely have atleast 2 inputs
            node_info["in_datatype"] = self.graph_info_params.low_precision_act_dtype
            node_info["wgt_datatype"] = (
                self.graph_info_params.low_precision_wgt_dtype  # wgt_datatype behaviour assumed similar to in_datatype
            )
            if len(node_info["inputs"]) > 2:
                node_info["wgt1_datatype"] = (
                    self.graph_info_params.low_precision_wgt_dtype  # bias datatype, only useful when present
                )
            node_info["out_datatype"] = self.graph_info_params.high_precision_act_dtype
        else:  # rest of the ops may have any number of input ports
            if len(node_info["inputs"]) == 1:
                node_info["in_datatype"] = (
                    self.graph_info_params.high_precision_act_dtype
                )
            elif len(node_info["inputs"]) == 2:
                node_info["in_datatype"] = (
                    self.graph_info_params.high_precision_act_dtype
                )
                node_info["wgt_datatype"] = (
                    self.graph_info_params.high_precision_wgt_dtype  # wgt_datatype behaviour assumed similar to in_datatype
                )
            else:  # 3 or more
                node_info["in_datatype"] = (
                    self.graph_info_params.high_precision_act_dtype
                )
                node_info["wgt_datatype"] = (
                    self.graph_info_params.high_precision_wgt_dtype  # wgt_datatype behaviour assumed similar to in_datatype
                )
                node_info["wgt1_datatype"] = (
                    self.graph_info_params.high_precision_wgt_dtype
                )

            if len(node_info["outputs"]) > 0:
                node_info["out_datatype"] = (
                    self.graph_info_params.high_precision_act_dtype
                )
            else:
                node_info["out_datatype"] = ""

    def assign_datatypes_from_model(
        self, node: NodeProto, node_info: dict[str, str]
    ) -> None:
        num_of_tensor_inputs = next(
            (attr.i for attr in node.attribute if attr.name == "num_of_tensor_inputs"),
            3,
        )
        input_length = min(len(node_info["inputs"]), num_of_tensor_inputs)
        if input_length == 1:
            node_info["in_datatype"] = str(
                self.all_act_signal_dtypes[node.input[0]]
                if node.input[0] in self.all_act_signal_dtypes
                else (
                    self.model_initializer_dtypes[node.input[0]]
                    if node.input[0] in self.model_initializer_dtypes
                    else ""
                )
            )
        elif input_length == 2:
            node_info["in_datatype"] = str(
                self.all_act_signal_dtypes[node.input[0]]
                if node.input[0] in self.all_act_signal_dtypes
                else (
                    self.model_initializer_dtypes[node.input[0]]
                    if node.input[0] in self.model_initializer_dtypes
                    else ""
                )
            )
            node_info["wgt_datatype"] = str(
                self.all_act_signal_dtypes[node.input[1]]
                if node.input[1] in self.all_act_signal_dtypes
                else (
                    self.model_initializer_dtypes[node.input[1]]
                    if node.input[1] in self.model_initializer_dtypes
                    else ""
                )
            )
        else:  # 3 or more
            node_info["in_datatype"] = str(
                self.all_act_signal_dtypes[node.input[0]]
                if node.input[0] in self.all_act_signal_dtypes
                else (
                    self.model_initializer_dtypes[node.input[0]]
                    if node.input[0] in self.model_initializer_dtypes
                    else ""
                )
            )
            node_info["wgt_datatype"] = str(
                self.all_act_signal_dtypes[node.input[1]]
                if node.input[1] in self.all_act_signal_dtypes
                else (
                    self.model_initializer_dtypes[node.input[1]]
                    if node.input[1] in self.model_initializer_dtypes
                    else ""
                )
            )
            node_info["wgt1_datatype"] = str(
                self.all_act_signal_dtypes[node.input[2]]
                if node.input[2] in self.all_act_signal_dtypes
                else (
                    self.model_initializer_dtypes[node.input[2]]
                    if node.input[2] in self.model_initializer_dtypes
                    else ""
                )
            )

        if len(node_info["outputs"]) > 0:
            node_info["out_datatype"] = str(
                self.all_act_signal_dtypes[node.output[0]]
                if node.output[0] in self.all_act_signal_dtypes
                else ""
            )
        else:
            node_info["out_datatype"] = ""

    def update_node_info_conv(self, node: NodeProto, node_info: dict[str, str]) -> None:
        x = node.op_type.split("_")
        if len(node_info["in_act_shape"]) > 1:
            Cin = node_info["in_act_shape"][1]  # NCHW
            if (
                "group" in node_info["attributes"]
                and len(node_info["attributes"]["group"]) > 0
                and node_info["attributes"]["group"][0] == Cin
                and len(x) > 1
            ):
                x[1] = x[1] + "_dwc"
        node_info["op_type"] = "_".join(x)

    def add_bytes_to_node_info(self, node_info: dict[str, str]) -> None:
        if "in_datatype" in node_info.keys():
            in_dtype = node_info["in_datatype"]
            if in_dtype in DTYPE_BYTES_DICT:
                node_info["in_bytes"] = DTYPE_BYTES_DICT[in_dtype]
        if "wgt_datatype" in node_info.keys():
            wgt_dtype = node_info["wgt_datatype"]
            if wgt_dtype in DTYPE_BYTES_DICT:
                node_info["wgt_bytes"] = DTYPE_BYTES_DICT[wgt_dtype]
        if "wgt1_datatype" in node_info.keys():
            wgt1_dtype = node_info["wgt1_datatype"]
            if wgt1_dtype in DTYPE_BYTES_DICT:
                node_info["wgt1_bytes"] = DTYPE_BYTES_DICT[wgt1_dtype]
        if "out_datatype" in node_info.keys():
            out_dtype = node_info["out_datatype"]
            if out_dtype in DTYPE_BYTES_DICT:
                node_info["out_bytes"] = DTYPE_BYTES_DICT[out_dtype]
