"""Tests for L2L3 allocator"""
import os
import tempfile
import subprocess
import pytest
from graph.clean_onnx_graph import main as clean_onnx_model
from graph.utilities import logger
from aie4_bench.common import patch_alloc_json, generate_alloc_json, squeeze_tensor_shapes, unsqueeze_tensor_shapes, set_node_attributes, change_node_op_type


CURRDIR = os.path.dirname(os.path.abspath(__file__))


@pytest.mark.e2e_model_bench
@pytest.mark.xdist_group("serial")
def test_resnet50():
    """Benchmark ResNet50 end-to-end"""
    # Clean the model
    model_path = os.path.join(CURRDIR, "../graph/ResNet50_INT8_Model.onnx")
    cleaned_model_path = tempfile.mktemp(suffix="ResNet50_INT8_Model.onnx")
    clean_onnx_model(
        input_model_path=model_path,
        output_model_path=cleaned_model_path,
        is_dtype_int8=True,
        ops_to_remove=["Relu", "LeakyRelu", "QuantizeLinear", "DequantizeLinear", "Flatten", "Identity"]
    )

    # Unsqueeze Gemm input/output shapes: (1, N) -> (1, 1, 1, N)
    unsqueeze_tensor_shapes(cleaned_model_path, ["output_identity"], axis=0)
    unsqueeze_tensor_shapes(cleaned_model_path, ["output_identity"], axis=0)

    # Squeeze GlobalAveragePool input/output shapes: (1, N, 1, 1) -> (1, 1, 1, N)
    squeeze_tensor_shapes(cleaned_model_path, ["/avgpool/GlobalAveragePool_output_0"], axis=-1)
    squeeze_tensor_shapes(cleaned_model_path, ["/avgpool/GlobalAveragePool_output_0"], axis=-1)
    unsqueeze_tensor_shapes(cleaned_model_path, ["/avgpool/GlobalAveragePool_output_0"], axis=0)
    unsqueeze_tensor_shapes(cleaned_model_path, ["/avgpool/GlobalAveragePool_output_0"], axis=0)

    # Change GEMM act x wgt (int8) to Conv (a8w8) attributes
    change_node_op_type(cleaned_model_path, "Gemm", "Conv")
    set_node_attributes(model_path, "Gemm", [
        ("dilations", [1, 1]),
        ("group", 1),
        ("kernel_shape", [1, 1]),
        ("pads", [0, 0, 0, 0]),
        ("strides", [1, 1])
    ])

    # Generate allocation JSON
    _, alloc_data = generate_alloc_json(model_path=cleaned_model_path)

    # Patch the json
    patched_json = patch_alloc_json(alloc_data, {
        "Add": "Add_qdq_EleWise_uint8xuint8xuint8",
        "Conv": "conv_noqdq_a8w8",
        "MaxPool": "maxpool_noqdq_a8",
        "GlobalAveragePool": "gap",
    })

    # call build_aie4.py
    try:
        result = subprocess.call([
            "python", "build_aie4.py",
            "--json", str(patched_json),
            "--target", "cert",
            "--clean"
        ])
        assert result == 0, f"Build failed with exit code {result}"
    except Exception as e:  # pylint: disable=broad-except
        logger.info("DI_FAIL")
        pytest.fail(f"Build process failed: {e}")
