# (c) Copyright 2022 - 2025 Advanced Micro Devices, Inc. All Rights Reserved.


from dataclasses import dataclass
import logging
import onnx
from OGOAT.src.L1_fusion.L1_utils.utils import onnxTensor_dtype_to_np_dtype

## TODO: need to keep updating these lists as we encounter new ops
# Op types treated as linear operators. Judging by the name, these are the
# ops whose inputs use the low-precision dtype — NOTE(review): confirm
# against the code that consumes this list.
LOW_PRECISION_INPUT_OP_TYPES = [
    "Conv",
    "Gemm",
    "MatMul",
    "Add",
    "qdq_matmul_uint16_uint8_cstm",
    "qdq_conv2d_weightsZPeq0_uint16_cstm",
    "qdq_matmul_uint16_uint16_cstm",
]
# Shape-changing / no-compute op types, plus ops that are expected to be fused
# later. Used in the child and parent node calculation: presumably passed as
# `ops_to_bypass` so neighbour lookups skip over these ops to the nearest
# compute node — verify against callers.
BYPASS_OP_TYPES = [
    "Concat",
    "Slice",
    "Reshape",
    "Transpose",
    "Shape",
    "Gather",
    "Squeeze",
    "Unsqueeze",
    "QuantizeLinear",
    "DequantizeLinear",
]
# Bytes per element for each dtype name.
# - Sub-byte integer types are fractional: int4/uint4 = 1/2, int2/uint2 = 1/4.
# - mx9 and bfp16 are counted as 9 bits (9/8 bytes) per element.
# - "fp32" and "float32" are aliases (both 4 bytes).
# NOTE(review): bool is counted as 1 bit (1/8 byte), i.e. assumed bit-packed;
#   ONNX/numpy store bool as one byte per element — confirm this is intended.
# The "" entry maps to 0 bytes — presumably a placeholder for a missing or
#   unknown dtype; verify against callers.
DTYPE_BYTES_DICT = {
    "bool": 1 / 8,
    "mx9": 9 / 8,
    "bfp16": 9 / 8,
    "float16": 2,
    "bfloat16": 2,
    "fp32": 4,
    "float32": 4,
    "double": 8,
    "int2": 1 / 4,
    "uint2": 1 / 4,
    "int4": 1 / 2,
    "uint4": 1 / 2,
    "int8": 1,
    "uint8": 1,
    "int16": 2,
    "uint16": 2,
    "int32": 4,
    "uint32": 4,
    "int64": 8,
    "uint64": 8,
    "": 0,
}


@dataclass(frozen=True)
class GraphInfoParams:
    """
    Immutable (frozen) bundle of dtype-assignment and device settings.

    The per-field comments record the values used for the win24 and
    sdxl turbo configurations.
    """

    # flag: whether to assign new datatypes (declared as int, used as a flag)
    assign_new_dtypes: int
    # activation datatype for linear layers.
    # Options: "int8", "mx9", "fp32", "bfloat16", "bfp16";
    # "uint16" for win24, "bfp16" for sdxl turbo
    low_precision_act_dtype: str
    # activation datatype for non-linear layers;
    # "uint16" for win24, "bfloat16" for sdxl turbo
    high_precision_act_dtype: str
    # weight datatype (presumably for linear layers, mirroring
    # low_precision_act_dtype); "uint8" for win24, "bfp16" for sdxl turbo
    low_precision_wgt_dtype: str
    # weight datatype for non-linear layers;
    # "uint16" for win24, "bfloat16" for sdxl turbo
    high_precision_wgt_dtype: str
    # flag to avoid any downcasting for all ops in the model
    no_dtype_downcast: bool
    # the chosen device type
    device: str


def get_act_signal_shapes(
    graph: onnx.GraphProto,
) -> tuple[dict[str, list[int]], dict[str, str]]:
    """
    Collect shape and dtype information for the activation tensors of *graph*.

    Only entries of ``graph.value_info`` are inspected. An entry whose name
    matches an initializer is skipped, since initializers are constants
    rather than activations.

    return -- (shapes, dtypes)
      shapes[act_signal_name] = list of the ``dim_value`` of every dimension
        (a dimension with no concrete value presumably reads back as 0,
        the protobuf default)
      dtypes[act_signal_name] = onnxTensor_dtype_to_np_dtype(elem_type)
    """
    constant_names = {init.name for init in graph.initializer}

    shapes: dict[str, list[int]] = {}
    dtypes: dict[str, str] = {}
    for value_info in graph.value_info:
        if value_info.name in constant_names:
            continue
        tensor_type = value_info.type.tensor_type
        shapes[value_info.name] = [d.dim_value for d in tensor_type.shape.dim]
        dtypes[value_info.name] = onnxTensor_dtype_to_np_dtype(
            tensor_type.elem_type
        )

    print(f"Got shapes for {len(shapes):d} activation signals...")

    return shapes, dtypes


def get_initializer_shapes(
    graph: onnx.GraphProto,
) -> tuple[dict[str, list[int]], dict[str, str]]:
    """
    Collect shape and dtype information for the initializers of *graph*.

    If several initializers share a name, the first one encountered wins.

    return -- (shapes, dtypes)
      shapes[initializer_name] = list of dimensions (as Python ints)
      dtypes[initializer_name] = onnxTensor_dtype_to_np_dtype(data_type)
    """
    shapes = {}
    dtypes = {}
    for tensor in graph.initializer:
        name = tensor.name
        if name in shapes:
            # duplicate name: keep the first definition
            continue
        shapes[name] = [int(dim) for dim in tensor.dims]
        dtypes[name] = onnxTensor_dtype_to_np_dtype(tensor.data_type)

    print(f"Got shapes for {len(shapes):d} model initializers..")

    return shapes, dtypes


def get_top_level_details_from_model(model: onnx.ModelProto):
    """
    Summarize the top level of *model*.

    Returns (op_types, node_count, graph_inputs, graph_outputs):
      op_types      -- sorted list of the distinct op types in the graph
                       (sorted so the result is stable between runs)
      node_count    -- number of nodes in the graph
      graph_inputs  -- input name -> list of dims
      graph_outputs -- output name -> list of dims
    Each dim is its ``dim_value`` when non-zero, otherwise its ``dim_param``
    (the symbolic name).
    """

    def _dims_of(value_info):
        # concrete size if set, else fall back to the symbolic parameter name
        return [
            dim.dim_value or dim.dim_param
            for dim in value_info.type.tensor_type.shape.dim
        ]

    graph = model.graph
    graph_inputs = {vi.name: _dims_of(vi) for vi in graph.input}
    graph_outputs = {vi.name: _dims_of(vi) for vi in graph.output}

    op_types = set()
    node_count = 0
    for node in graph.node:
        op_types.add(node.op_type)
        node_count += 1

    return sorted(op_types), node_count, graph_inputs, graph_outputs


# FIXME same implemented in L1_utils/utils.py
def construct_tensor_in_out_dict(graph: onnx.GraphProto):
    """
    Index the tensors of *graph* by the nodes that consume and produce them.

    Returns (in_tensors_dict, out_tensors_dict):
      in_tensors_dict[tensor_name]  -- names of the nodes that take the tensor
                                       as an input, in graph order (a node that
                                       lists the same tensor twice appears twice)
      out_tensors_dict[tensor_name] -- names of the nodes that list the tensor
                                       as an output, in graph order

    Empty names are skipped. ONNX uses "" to mark an omitted optional input or
    output, so it is not a real tensor. Indexing it would make every node with
    an omitted optional output look like a producer for every node with an
    omitted optional input.
    NOTE(review): the copy in L1_utils/utils.py may not skip empty names —
    keep the two in sync.
    """
    in_tensors_dict = {}
    out_tensors_dict = {}
    for node in graph.node:
        for node_input in node.input:
            if node_input:
                in_tensors_dict.setdefault(node_input, []).append(node.name)
        for node_output in node.output:
            if node_output:
                out_tensors_dict.setdefault(node_output, []).append(node.name)

    return in_tensors_dict, out_tensors_dict


def get_child_nodes_dict(
    graph_info_dict, in_tensors_dict, ops_to_bypass, bypass_ops=True
):
    """
    Annotate every node in graph_info_dict with its consumer (child) nodes.

    Arguments:
      graph_info_dict -- node name -> node-info dict; mutated in place. Each
        info dict must have "outputs" (a list of dicts with a "name" key);
        "op_type" and "node_name" are read when present.
      in_tensors_dict -- tensor name -> names of the nodes consuming that
        tensor (as built by construct_tensor_in_out_dict).
      ops_to_bypass -- op types to look through.
      bypass_ops -- when True, every child whose op type is in ops_to_bypass
        (and which itself has children) is replaced by the first node reached,
        by repeatedly following the *first* child, whose op type is not
        bypassed. Only the first child is followed; this is a known limitation.
        If the walk reaches a name missing from graph_info_dict, the entry is
        left unchanged.

    Sets on every node:
      "children_names"    -- names of the consuming nodes
      "children_op_types" -- op types of those children, in the same order.
        Children missing from graph_info_dict or lacking "op_type" are
        skipped, so the two lists only line up when every child is known.

    Returns graph_info_dict (the same object).
    """
    # Pass 1: direct children, found through each output tensor's consumers.
    for node in graph_info_dict.values():
        child_name_list = []
        for node_output in node["outputs"]:
            # extend() copies, so pass 2 never mutates in_tensors_dict's lists
            child_name_list.extend(in_tensors_dict.get(node_output["name"], []))

        child_op_type_list = []
        for child_node_name in child_name_list:
            child_node_info = graph_info_dict.get(child_node_name)
            if child_node_info and "op_type" in child_node_info:
                # append (not overwrite) so the list parallels children_names
                child_op_type_list.append(child_node_info["op_type"])

        node["children_names"] = child_name_list
        node["children_op_types"] = child_op_type_list

    if not bypass_ops:
        return graph_info_dict

    # Pass 2: replace bypassed children (shape-changing ops etc.) with the
    # first useful node below them.
    for node in graph_info_dict.values():
        children_names = node["children_names"]
        children_op_types = node["children_op_types"]
        for index, child_name in enumerate(children_names):
            resolved = graph_info_dict.get(child_name)
            while (
                resolved
                and "op_type" in resolved
                and resolved["op_type"] in ops_to_bypass
                and resolved["children_names"]
            ):
                first_child_name = resolved["children_names"][0]
                if first_child_name not in graph_info_dict:
                    resolved = None  # dead end: keep the original entry
                    break
                resolved = graph_info_dict[first_child_name]

            if not resolved:
                continue
            if "node_name" in resolved:
                children_names[index] = resolved["node_name"]
            # guard: the op-type list is shorter when some child was unknown
            if index < len(children_op_types) and "op_type" in resolved:
                children_op_types[index] = resolved["op_type"]

    return graph_info_dict


def get_parent_nodes_dict(
    graph_info_dict, out_tensors_dict, ops_to_bypass, bypass_ops=True
):
    """
    Annotate every node in graph_info_dict with its producer (parent) nodes.

    Arguments:
      graph_info_dict -- node name -> node-info dict; mutated in place. Each
        info dict must have "inputs" (a list of dicts with a "name" key);
        "op_type" and "node_name" are read when present.
      out_tensors_dict -- tensor name -> names of the nodes producing that
        tensor (as built by construct_tensor_in_out_dict).
      ops_to_bypass -- op types to look through.
      bypass_ops -- when True, every parent whose op type is in ops_to_bypass
        (and which itself has parents) is replaced by the first node reached,
        by repeatedly following the *first* parent, whose op type is not
        bypassed. Only the first parent is followed; this is a known
        limitation. If the walk reaches a name missing from graph_info_dict,
        the entry is left unchanged.

    Sets on every node:
      "parent_names"    -- names of the producing nodes, one lookup per input
      "parent_op_types" -- op types of those parents, in the same order.
        Parents missing from graph_info_dict or lacking "op_type" are
        skipped, so the two lists only line up when every parent is known.

    Returns graph_info_dict (the same object).
    """
    # Pass 1: direct parents, found through each input tensor's producers.
    for node in graph_info_dict.values():
        parent_name_list = []
        for node_input in node["inputs"]:
            # extend() copies, so pass 2 never mutates out_tensors_dict's lists
            parent_name_list.extend(out_tensors_dict.get(node_input["name"], []))

        parent_op_type_list = []
        for parent_node_name in parent_name_list:
            parent_node_info = graph_info_dict.get(parent_node_name)
            if parent_node_info and "op_type" in parent_node_info:
                # append (not overwrite) so the list parallels parent_names
                parent_op_type_list.append(parent_node_info["op_type"])

        node["parent_names"] = parent_name_list
        node["parent_op_types"] = parent_op_type_list

    if not bypass_ops:
        return graph_info_dict

    # Pass 2: replace bypassed parents (shape-changing ops etc.) with the
    # first useful node above them.
    for node in graph_info_dict.values():
        parent_names = node["parent_names"]
        parent_op_types = node["parent_op_types"]
        for index, parent_name in enumerate(parent_names):
            resolved = graph_info_dict.get(parent_name)
            while (
                resolved
                and "op_type" in resolved
                and resolved["op_type"] in ops_to_bypass
                and resolved["parent_names"]
            ):
                first_parent_name = resolved["parent_names"][0]
                if first_parent_name not in graph_info_dict:
                    resolved = None  # dead end: keep the original entry
                    break
                resolved = graph_info_dict[first_parent_name]

            if not resolved:
                continue
            if "node_name" in resolved:
                parent_names[index] = resolved["node_name"]
            # guard: the op-type list is shorter when some parent was unknown
            if index < len(parent_op_types) and "op_type" in resolved:
                parent_op_types[index] = resolved["op_type"]

    return graph_info_dict


# checks the size of activations at each stage and tags them with residency (l2 or l3)
# outputs 1 if it can reside in memtile
def get_activations_residency(
    in_samples,
    out_samples,
    node_name,
    overlay,
    in_type,
    out_type,
    has_weights,
    weight_reserve_kb=64,
    dtype_bytes=None,
):
    """
    Decide whether a node's input + output activations fit in MemTile (L2).

    Arguments:
      in_samples, out_samples -- number of input / output activation elements
      node_name -- node name, used only in debug log messages
      overlay -- object providing memtile_capacity, memtile_rows and
        num_columns. memtile_capacity is treated as KiB: it is multiplied by
        1024 to compare against a byte count.
      in_type, out_type -- dtype names, looked up in dtype_bytes
      has_weights -- when truthy, weight_reserve_kb KiB of every memtile is
        held back for weights (e.g. Conv, or MatMul with a constant input)
      weight_reserve_kb -- KiB reserved per memtile for weights; defaults to
        the previously hard-coded 64
      dtype_bytes -- dtype name -> bytes per element; defaults to
        DTYPE_BYTES_DICT. This gives the same widths the old local table used
        for int8/mx9/fp32/bfloat16/bfp16, and also covers uint8, uint16, etc.

    Returns 1 if the activations fit in L2, 0 if they do not (i.e. they must
    live in L3), or None if either sample count is None (size unknown).

    Raises KeyError if in_type or out_type is not in dtype_bytes.
    """
    if in_samples is None or out_samples is None:
        return None
    if dtype_bytes is None:
        dtype_bytes = DTYPE_BYTES_DICT

    # sum of input and output sizes in bytes
    total_samples_in_bytes = (
        in_samples * dtype_bytes[in_type] + out_samples * dtype_bytes[out_type]
    )

    capacity_kb = overlay.memtile_capacity
    if has_weights:
        capacity_kb -= weight_reserve_kb
    mem_tile_size = (
        capacity_kb * overlay.memtile_rows * overlay.num_columns * 1024
    )

    if total_samples_in_bytes > mem_tile_size:
        logging.debug(
            "Activations for node %s can't fit in MemTile (L2)", node_name
        )
        return 0

    logging.debug("Activations for node %s can fit in MemTile (L2)", node_name)
    return 1
