"""Common utilities and functions for the AIE4 benchmark suite"""
import json
import os
import tempfile
from itertools import chain
from typing import Union

import onnx
from onnx import helper

import graph.L2L3_allocator as l2l3_alloc
from graph.utilities import logger


def set_node_attributes(
    model_path: str,
    op_type: str,
    attributes: Union[dict, list[tuple]],
    node_name: str | None = None,
    output_path: str | None = None
) -> str:
    """
    Add or update attributes on ONNX nodes of a given op type.

    Args:
        model_path: Path to the ONNX model
        op_type: ONNX operator type to match (e.g., "Gemm", "Conv")
        attributes: Dict or list of tuples with attribute name-value pairs.
                   Values can be int, float, list[int], list[float], str.
        node_name: Optional specific node name to match (if None, matches all nodes of op_type)
        output_path: Optional output path. If None, overwrites input model.

    Returns:
        Path to the modified model
    """
    output_path = output_path or model_path
    model = onnx.load(model_path)

    if isinstance(attributes, dict):
        attributes = list(attributes.items())

    for node in model.graph.node:
        if node.op_type != op_type:
            continue
        if node_name and node.name != node_name:
            continue

        # Remove existing attributes that we're updating
        attr_names_to_set = {name for name, _ in attributes}
        attrs_to_keep = [attr for attr in node.attribute if attr.name not in attr_names_to_set]
        del node.attribute[:]
        node.attribute.extend(attrs_to_keep)

        # Add new attributes
        for name, value in attributes:
            if isinstance(value, int):
                node.attribute.append(helper.make_attribute(name, value))
            elif isinstance(value, float):
                node.attribute.append(helper.make_attribute(name, value))
            elif isinstance(value, str):
                node.attribute.append(helper.make_attribute(name, value))
            elif isinstance(value, (list, tuple)):
                node.attribute.append(helper.make_attribute(name, value))
            else:
                raise ValueError(f"Unsupported attribute type for {name}: {type(value)}")

        logger.info("Set attributes on %s node '%s': %s", op_type, node.name, dict(attributes))

    onnx.save(model, output_path)
    return output_path


def change_node_op_type(
    model_path: str,
    from_op_type: str,
    to_op_type: str,
    node_name: str | None = None,
    output_path: str | None = None
) -> str:
    """
    Rewrite the op_type of matching ONNX nodes and save the model.

    A node matches when its op_type equals ``from_op_type`` and, when
    ``node_name`` is non-empty, its name equals ``node_name``. Only the
    op_type field is changed; inputs, outputs and attributes are untouched.

    Args:
        model_path: Path to the ONNX model
        from_op_type: Original op type to match (e.g., "Gemm")
        to_op_type: New op type to set (e.g., "Conv")
        node_name: Optional specific node name to match (if None, matches all nodes of from_op_type)
        output_path: Optional output path. If None, overwrites input model.

    Returns:
        Path the modified model was written to
    """
    destination = output_path or model_path
    model = onnx.load(model_path)

    def _matches(candidate) -> bool:
        if candidate.op_type != from_op_type:
            return False
        return not node_name or candidate.name == node_name

    for node in filter(_matches, model.graph.node):
        previous_op = node.op_type
        node.op_type = to_op_type
        logger.info("Changed node '%s' op_type: %s -> %s", node.name, previous_op, to_op_type)

    onnx.save(model, destination)
    return destination


def unsqueeze_tensor_shapes(
    model_path: str,
    tensor_names: list[str],
    axis: int = 1,
    output_path: str | None = None
) -> str:
    """
    Unsqueeze tensor shapes by inserting a dimension of size 1 at the specified axis.
    E.g., (1, N) with axis=1 becomes (1, 1, N).

    Only the shape metadata in the graph's value_info, inputs and outputs is
    edited; nodes and initializers are left untouched. Symbolic dimensions
    (``dim_param``) are preserved. Values with no ``shape`` on their tensor
    type (unknown rank, or not a tensor) are skipped.

    Args:
        model_path: Path to the ONNX model
        tensor_names: List of tensor names to unsqueeze
        axis: Position to insert the new dimension (default: 1). Used as a
              Python slice index into the current shape, so a negative value
              counts from the end of the input shape (axis=-1 inserts before
              the last dimension).
        output_path: Optional output path. If None, overwrites input model.

    Returns:
        Path to the modified model
    """
    output_path = output_path or model_path
    model = onnx.load(model_path)
    tensor_set = set(tensor_names)

    all_tensors = chain(model.graph.value_info, model.graph.input, model.graph.output)
    for item in all_tensors:
        if item.name not in tensor_set:
            continue
        tensor_type = item.type.tensor_type
        if not tensor_type.HasField("shape"):
            # Writing dims here would invent a rank-1 shape for a value whose
            # rank is unknown (or which is not a tensor at all).
            continue
        shape_dims = tensor_type.shape.dim

        # Snapshot whole Dimension protos rather than dim_value: dim_value reads
        # as 0 on a symbolic dim, and writing that back would make it a real 0.
        old_dims = []
        for dim in shape_dims:
            saved = onnx.TensorShapeProto.Dimension()
            saved.CopyFrom(dim)
            old_dims.append(saved)
        unit_dim = onnx.TensorShapeProto.Dimension(dim_value=1)
        new_dims = old_dims[:axis] + [unit_dim] + old_dims[axis:]

        del shape_dims[:]
        for dim in new_dims:
            shape_dims.add().CopyFrom(dim)

        logger.info(
            "Unsqueezed tensor '%s': %s -> %s",
            item.name,
            [d.dim_param or d.dim_value for d in old_dims],
            [d.dim_param or d.dim_value for d in new_dims],
        )

    onnx.save(model, output_path)
    return output_path


def squeeze_tensor_shapes(
    model_path: str,
    tensor_names: list[str],
    axis: int = 1,
    output_path: str | None = None
) -> str:
    """
    Squeeze tensor shapes by removing a dimension of size 1 at the specified axis.
    E.g., (1, 1, N) with axis=1 becomes (1, N).

    Only the shape metadata in the graph's value_info, inputs and outputs is
    edited. A tensor is changed only if ``axis`` is in range and the dimension
    there has a concrete value of 1; otherwise it is left as-is. The other
    dimensions, including symbolic ones (``dim_param``), are preserved.

    Args:
        model_path: Path to the ONNX model
        tensor_names: List of tensor names to squeeze
        axis: Position to remove the dimension (default: 1). Negative values
              count from the end of the shape.
        output_path: Optional output path. If None, overwrites input model.

    Returns:
        Path to the modified model
    """
    output_path = output_path or model_path
    model = onnx.load(model_path)
    tensor_set = set(tensor_names)

    all_tensors = chain(model.graph.value_info, model.graph.input, model.graph.output)
    for item in all_tensors:
        if item.name not in tensor_set:
            continue
        shape_dims = item.type.tensor_type.shape.dim
        old_shape = [d.dim_param or d.dim_value for d in shape_dims]
        # Normalize negative axis
        norm_axis = axis if axis >= 0 else len(old_shape) + axis
        if not 0 <= norm_axis < len(old_shape):
            continue
        # Symbolic dims have dim_value 0, so only a concrete size-1 dim qualifies.
        if shape_dims[norm_axis].dim_value != 1:
            continue
        # Delete in place so every remaining Dimension proto (dim_param,
        # denotation) is kept exactly as it was.
        del shape_dims[norm_axis]
        new_shape = old_shape[:norm_axis] + old_shape[norm_axis + 1:]
        logger.info("Squeezed tensor '%s': %s -> %s", item.name, old_shape, new_shape)

    onnx.save(model, output_path)
    return output_path


def unsqueeze_op_shapes(
    model_path: str,
    op_type: str,
    axis: int = 0,
    output_path: str | None = None
) -> str:
    """
    Insert a size-1 dimension into every input and output tensor of ``op_type`` nodes.

    Collects the input and output names of each node whose op_type matches
    and hands them to ``unsqueeze_tensor_shapes``, which reloads the model
    from ``model_path``, edits the shapes and saves.

    Args:
        model_path: Path to the ONNX model
        op_type: ONNX operator type (e.g., "Gemm", "MatMul")
        axis: Position to insert the new dimension (default: 0)
        output_path: Optional output path. If None, overwrites input model.

    Returns:
        Path the modified model was written to
    """
    graph = onnx.load(model_path).graph
    names = [
        tensor
        for node in graph.node
        if node.op_type == op_type
        for tensor in chain(node.input, node.output)
    ]
    return unsqueeze_tensor_shapes(model_path, names, axis, output_path)


def squeeze_op_shapes(
    model_path: str,
    op_type: str,
    axis: int = 0,
    output_path: str | None = None
) -> str:
    """
    Remove a size-1 dimension from every input and output tensor of ``op_type`` nodes.

    Collects the input and output names of each node whose op_type matches
    and hands them to ``squeeze_tensor_shapes``, which reloads the model
    from ``model_path``, edits the shapes and saves.

    Args:
        model_path: Path to the ONNX model
        op_type: ONNX operator type (e.g., "Gemm", "MatMul")
        axis: Position to remove the dimension (default: 0)
        output_path: Optional output path. If None, overwrites input model.

    Returns:
        Path the modified model was written to
    """
    graph = onnx.load(model_path).graph
    names = [
        tensor
        for node in graph.node
        if node.op_type == op_type
        for tensor in chain(node.input, node.output)
    ]
    return squeeze_tensor_shapes(model_path, names, axis, output_path)


def generate_alloc_json(model_path) -> tuple[str, dict]:
    """
    Generate allocation JSON for a given model path.

    Runs the L2/L3 allocator on ``model_path`` and writes its output JSON into
    a fresh temporary directory. The file is not deleted afterwards.

    Args:
        model_path: Path to the ONNX model passed to the allocator.

    Returns:
        Tuple of (path to the allocation JSON, its parsed contents).

    Raises:
        AssertionError: If any operator entry lacks a truthy ``enable_L2_fusion``
            value (checked with ``assert``, so skipped under ``python -O``).
    """
    # tempfile.mktemp() is deprecated and race-prone: another process can claim
    # the name before it is written. A directory from mkdtemp() is created
    # atomically and is private to this process. The file path inside it does
    # not exist yet, so the allocator still receives a not-yet-existing path.
    alloc_json = os.path.join(tempfile.mkdtemp(), "alloc.json")
    l2l3_alloc.main(l2l3_alloc.Command(
        model_path=model_path,
        c64=False,
        fusion_json_path=alloc_json,
        both_l2l3=True,
        is_nonwaic=True
    ))

    logger.info("Allocation JSON generated at: %s", alloc_json)
    with open(alloc_json, "r", encoding="utf-8") as f:
        alloc_data = json.load(f)

    for op in alloc_data.values():
        assert op.get("enable_L2_fusion"), f"operator {op} must be l2 fused"

    return alloc_json, alloc_data


def patch_alloc_json(alloc_data: dict, mapping: dict) -> str:
    """
    Patch allocation json for operators.

    Replaces each entry's ``"op"`` value using ``mapping`` and writes the
    patched data to a new temporary ``.json`` file, which is not deleted.

    Args:
        alloc_data: Parsed allocation JSON whose values are dicts with an
            ``"op"`` key. Modified in place on success.
        mapping: Old operator name -> new operator name. Every operator found
            in ``alloc_data`` must have an entry.

    Returns:
        Path to the written JSON file.

    Raises:
        ValueError: If an operator has no entry in ``mapping``. In that case
            ``alloc_data`` is left unmodified.
    """
    # Validate every entry first so a failure cannot leave alloc_data half-patched.
    for op in alloc_data.values():
        if op["op"] not in mapping:
            raise ValueError(f"Unexpected operator {op['op']} found in allocation json")
    for op in alloc_data.values():
        op["op"] = mapping[op["op"]]

    # NamedTemporaryFile creates the file atomically, unlike the deprecated
    # mktemp(). delete=False keeps it on disk for the caller.
    with tempfile.NamedTemporaryFile(
        "w", suffix=".json", delete=False, encoding="utf-8"
    ) as f:
        json.dump(alloc_data, f, indent=4)
        patched_json = f.name
    logger.info("Patched JSON generated at: %s", patched_json)

    return patched_json
