"""
This module provides common type definitions, constants, and data structures
"""

import os
from collections import defaultdict
from typing import Any, Dict, List, TypeVar

import onnx
from onnx import helper

# Absolute path of the directory that contains this source file.
CURRDIR = os.path.dirname(os.path.abspath(__file__))

NUM_AIE_ROWS = 4
BYTES_PER_ELEMENT = 1
REGION_SIZE = 3 * 1024 * 1024  # 3 MiB per region
REGION_COUNT = 3
SRAM_TOTAL = 9 * 1024 * 1024  # 9 MiB (numerically REGION_COUNT * REGION_SIZE)

# Memory (in bytes) reserved per core for layer parameters.
PARAM_CORE_SIZE = 1024
# Total memory reserved per column for layer parameters.
PARAM_SIZE = PARAM_CORE_SIZE * NUM_AIE_ROWS

# Memory (in bytes) reserved per core for weights.
WEIGHT_CORE_SIZE = 64 * 1024
# Total memory reserved per column for weights.
WEIGHT_SIZE = WEIGHT_CORE_SIZE * NUM_AIE_ROWS
# Total memory reserved per column for ping-pong weight buffers.
WEIGHT_PING_PONG_SIZE = WEIGHT_SIZE * 2

# Type of items in the list.
T = TypeVar("T")
# Type of the key used for comparison.
K = TypeVar("K")


def _overwrite_tensor_shape(
    value_info: onnx.ValueInfoProto, new_shape: List[int]
) -> None:
    """Replace all dimensions of ``value_info`` with the fixed sizes in ``new_shape``."""
    shape = value_info.type.tensor_type.shape
    # Drop whatever dimensions were recorded before writing the new ones.
    shape.dim.clear()  # type: ignore
    for dim_size in new_shape:
        dim = shape.dim.add()
        dim.dim_value = dim_size


def update_tensor_shape(
    model: onnx.ModelProto, tensor_info: Dict[str, List[int]]
) -> onnx.ModelProto:
    """
    Overwrite the shapes of named tensors in an ONNX model.

    Graph inputs, graph outputs and intermediate ``value_info`` entries are
    searched by name; every match has its dimensions replaced by the given
    fixed sizes and a message is printed. Names in ``tensor_info`` that match
    nothing are silently ignored. The passed ``model`` is modified in place.

    After the update the model is checked and shape inference is run. This is
    best effort: if either step raises, a warning is printed and the model is
    returned as modified so far.

    Args:
        model (ModelProto): input ONNX model (mutated in place)
        tensor_info (Dict[str, List[int]]): a dictionary of tensor names and their shapes

    Returns:
        ModelProto: the result of ``onnx.shape_inference.infer_shapes`` when
        checking and inference succeed, otherwise the modified input model.
    """
    tensor_groups = (
        ("input", model.graph.input),
        ("output", model.graph.output),
        # Tensors in the middle of the graph.
        ("intermediate", model.graph.value_info),
    )
    for label, value_infos in tensor_groups:
        for value_info in value_infos:
            if value_info.name not in tensor_info:
                continue
            new_shape = tensor_info[value_info.name]
            _overwrite_tensor_shape(value_info, new_shape)
            print(f"Updated {label} tensor '{value_info.name}' shape to {new_shape}")

    # Validation is deliberately non-fatal: report the problem and keep going.
    try:
        onnx.checker.check_model(model)
        model = onnx.shape_inference.infer_shapes(model)
    except Exception as e:  # pylint: disable=broad-except
        print(f"Warning: Model validation failed: {e}")

    return model


def simplify_model(model: onnx.ModelProto) -> onnx.ModelProto:
    """
    Simplify an ONNX model with `onnxsim`, then infer shapes and check the result.

    Args:
        model (onnx.ModelProto): The original ONNX model to simplify and validate.

    Returns:
        onnx.ModelProto: The simplified model after shape inference, having
        passed `onnx.checker.check_model`.

    Raises:
        ValueError: If `onnxsim.simplify` reports that the simplified model
            could not be validated.
        Exception: Anything raised by `onnx.checker.check_model` or by shape
            inference is propagated unchanged.
    """
    # Function-scope import, as in the original; presumably so this module can
    # be imported without onnxsim installed (confirm with the maintainers).
    import onnxsim  # pylint: disable=C0415

    simplified, passed = onnxsim.simplify(model)
    if passed:
        inferred = onnx.shape_inference.infer_shapes(simplified)
        onnx.checker.check_model(inferred)
        return inferred
    raise ValueError("Simplified ONNX model could not be validated")


def simplify_yolov3(
    model_path=os.path.join(CURRDIR, "YoloV3_INT8_Model.onnx"),
) -> onnx.ModelProto:
    """
    Load the YOLOv3 ONNX model, simplify it, and pin the shapes of four tensors.

    The model at ``model_path`` (by default ``YoloV3_INT8_Model.onnx`` next to
    this file) is loaded and passed through ``simplify_model``. The hard-coded
    shapes below are then applied with ``update_tensor_shape`` so that these
    tensors have known shapes.

    Args:
        model_path: Path of the ONNX file to load.

    Returns:
        onnx.ModelProto: The simplified model with the shapes applied.
    """
    # Tensor name -> fixed shape to write into the model.
    fixed_shapes = {
        "Resize__255:0": [1, 256, 26, 26],
        "model/concatenate/concat:0": [1, 768, 26, 26],
        "Resize__299:0": [1, 128, 52, 52],
        "model/concatenate_1/concat:0": [1, 384, 52, 52],
    }
    simplified = simplify_model(onnx.load(model_path))
    return update_tensor_shape(simplified, fixed_shapes)


def get_chained_removal_map(graph: Any, ops_to_remove: Any) -> Dict[str, Any]:
    """
    Group removable nodes by the anchor they are chained behind.

    An anchor is either a node whose op type is not in ``ops_to_remove`` or a
    graph input. From each anchor the consumers of its outputs are followed
    depth-first for as long as their op type is in ``ops_to_remove``. Every
    removable node reached this way is recorded under that anchor. Node
    anchors are processed first (in graph order), then graph inputs, and each
    node is recorded at most once, under the first anchor that reaches it.

    Args:
        graph: Object exposing ``node`` (items with ``name``, ``op_type``,
            ``input`` and ``output``) and ``input`` (items with ``name``),
            such as an ONNX graph.
        ops_to_remove: Container of op-type names; only ``in`` is used on it.

    Returns:
        Dict[str, Any]: Maps an anchor name (node name, or graph input name)
        to the list of removable node names chained after it, in DFS visiting
        order. Anchors with no removable successors are absent.
    """
    # Tensor name -> nodes that read it, in graph order.
    input_consumers = defaultdict(list)
    for node in graph.node:
        for input_name in node.input:
            input_consumers[input_name].append(node)

    # NOTE: visited is keyed by node name, so nodes sharing a name (for
    # example several nodes with an empty name) are treated as one node.
    visited = set()
    removal_chains = defaultdict(list)

    def dfs(start_node: Any, from_node: Any) -> None:
        """Record start_node under from_node if removable, then follow its consumers."""
        if start_node.name in visited:
            return
        visited.add(start_node.name)

        # A non-removable node ends the chain: nothing is recorded and the
        # walk does not continue past it.
        if start_node.op_type in ops_to_remove:
            key = from_node if isinstance(from_node, str) else from_node.name
            removal_chains[key].append(start_node.name)

            for out in start_node.output:
                for consumer in input_consumers.get(out, []):
                    dfs(consumer, from_node)

    # DFS from non-removable nodes
    for node in graph.node:
        if node.op_type not in ops_to_remove:
            for out in node.output:
                for consumer in input_consumers.get(out, []):
                    dfs(consumer, node)

    # DFS from input tensors (to catch QDQ before first op)
    for input_info in graph.input:
        input_name = input_info.name
        for consumer in input_consumers.get(input_name, []):
            dfs(consumer, input_name)  # Use real input name as key

    return dict(removal_chains)


def remove_nodes_by_op_type_chained(
    onnx_model_path: str, ops_to_remove: Any
) -> onnx.ModelProto:
    """
    Load an ONNX model and remove every node whose op type is in ``ops_to_remove``.

    For each removed node, every consumer reference to one of its outputs is
    replaced by the removed node's inputs, so the graph stays connected. When
    a removed node is a ``Relu``, the node producing its first input gets an
    ``is_relu = 1`` attribute. Afterwards every remaining ``Conv`` node that
    has no ``is_relu`` attribute gets ``is_relu = 0``.

    Args:
        onnx_model_path: Path of the ONNX file to load.
        ops_to_remove: Container of op-type names; only ``in`` is used on it.

    Returns:
        onnx.ModelProto: The loaded model with the nodes removed. The model is
        not saved or re-validated here.
    """
    model = onnx.load(onnx_model_path)
    graph = model.graph

    # Producer and consumer maps. They are built once, before any rewiring,
    # and are not refreshed while nodes are being removed.
    producer_of = {}
    consumers_of = defaultdict(list)
    for candidate in graph.node:
        for produced in candidate.output:
            producer_of[produced] = candidate
        for consumed in candidate.input:
            consumers_of[consumed].append(candidate)

    # Snapshot the removable nodes first, because graph.node is mutated below.
    doomed = [candidate for candidate in graph.node if candidate.op_type in ops_to_remove]

    for victim in doomed:
        replacement = list(victim.input)
        dead_outputs = list(victim.output)

        # A removed ReLU is remembered as a flag on whichever node feeds it.
        if victim.op_type == "Relu" and replacement:
            producer = producer_of.get(replacement[0])
            if producer is not None:
                producer.attribute.append(helper.make_attribute("is_relu", 1))

        # NOTE(review): ALL inputs of the removed node are spliced into each
        # consumer, not just the first; confirm this is intended for
        # multi-input ops. graph.output entries are not rewritten here.
        for dead_name in dead_outputs:
            for consumer in consumers_of.get(dead_name, ()):
                rewired = []
                for name in consumer.input:
                    rewired.extend(replacement if name == dead_name else [name])
                consumer.input[:] = rewired

        graph.node.remove(victim)

    # Conv nodes that were not tagged above get an explicit is_relu = 0.
    for conv in graph.node:
        if conv.op_type != "Conv":
            continue
        if all(attr.name != "is_relu" for attr in conv.attribute):
            conv.attribute.append(helper.make_attribute("is_relu", 0))

    return model
