"""
Regression test for the gemm layers
"""

from enum import Enum
import os
from typing import Optional
import math
import pytest
import typer

from common import (
    process_simulation_results,
    create_hw_package,
    write_csv,
    default_row_mapper,
    change_dir,
    BuildTarget,
    DataflowType,
    run_hw_validation,
    clean_output_dir,
    Counter,
)
from build_aie4 import compile_operator, out_dir_name_from_dict, get_combined_kernel_list
from graph.utilities import config_logger_from_env

# Directory containing this test file; used for the default output root and
# for the location of the AIE simulator results CSV.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Project helper; presumably configures logging from environment variables.
config_logger_from_env()
# Shared across main() calls in this process: main() requests a standalone PDI
# only when the target is CERT and this counter returns 1 (presumably the
# first CERT call) — see the gen_pdi computation in main().
pdi_counter = Counter()


# ACT x ACT GEMM shapes, used for GemmMode.ACT and GemmMode.ACT_T.
# Each row is [Yi, Xi, Ci, Yo, Xo, Co, Ky, Kx, Sy, Sx, Py, Px]; a row's
# position in this list is the ``shape_index`` that main() and the pytest
# parametrization select. Rows are passed through pad_shape() before building.
# Commented-out rows are shapes disabled because of the noted failure.
GEMM_SHAPES_ACT_ACT = [
    # ['Yi', 'Xi', 'Ci', 'Yo', 'Xo', 'Co', 'Ky', 'Kx', 'Sy', 'Sx', 'Py', 'Px']
    # ACT x ACT - start index 0
    [1, 16 * 1, 128, 1, 16 * 1, 64, 1, 1, 1, 1, 0, 0],
    [1, 16 * 2, 128, 1, 16 * 2, 64, 1, 1, 1, 1, 0, 0],
    [1, 16 * 3, 128, 1, 16 * 3, 64, 1, 1, 1, 1, 0, 0],
    [1, 16 * 4, 128, 1, 16 * 4, 64, 1, 1, 1, 1, 0, 0],
    [1, 16, 128, 1, 16, 64 * 2, 1, 1, 1, 1, 0, 0],
    [1, 16, 128, 1, 16, 64 * 3, 1, 1, 1, 1, 0, 0],
    [1, 16, 128, 1, 16, 64 * 4, 1, 1, 1, 1, 0, 0],
    [1, 16 * 2, 128, 1, 16 * 2, 64 * 2, 1, 1, 1, 1, 0, 0],
    [1, 16 * 3, 128, 1, 16 * 3, 64 * 2, 1, 1, 1, 1, 0, 0],
    [1, 16 * 4, 128, 1, 16 * 4, 64 * 2, 1, 1, 1, 1, 0, 0],
    [1, 16 * 2, 128, 1, 16 * 2, 64 * 3, 1, 1, 1, 1, 0, 0],
    [1, 16 * 3, 128, 1, 16 * 3, 64 * 3, 1, 1, 1, 1, 0, 0],
    [1, 16 * 4, 128, 1, 16 * 4, 64 * 3, 1, 1, 1, 1, 0, 0],
    [1, 16 * 2, 128, 1, 16 * 2, 64 * 4, 1, 1, 1, 1, 0, 0],
    [1, 16 * 3, 128, 1, 16 * 3, 64 * 4, 1, 1, 1, 1, 0, 0],
    [1, 16 * 4, 128, 1, 16 * 4, 64 * 4, 1, 1, 1, 1, 0, 0],
    [1, 16 * 1, 128 * 2, 1, 16 * 1, 64, 1, 1, 1, 1, 0, 0],
    [1, 16 * 2, 128 * 2, 1, 16 * 2, 64, 1, 1, 1, 1, 0, 0],
    [1, 16 * 3, 128 * 2, 1, 16 * 3, 64, 1, 1, 1, 1, 0, 0],
    [1, 16 * 4, 128 * 2, 1, 16 * 4, 64, 1, 1, 1, 1, 0, 0],
    [1, 16 * 1, 128 * 2, 1, 16 * 1, 64 * 2, 1, 1, 1, 1, 0, 0],
    [1, 16 * 1, 128 * 2, 1, 16 * 1, 64 * 3, 1, 1, 1, 1, 0, 0],
    [1, 16 * 1, 128 * 2, 1, 16 * 1, 64 * 4, 1, 1, 1, 1, 0, 0],
    [1, 16 * 2, 128 * 2, 1, 16 * 2, 64 * 2, 1, 1, 1, 1, 0, 0],
    [1, 16 * 3, 128 * 2, 1, 16 * 3, 64 * 2, 1, 1, 1, 1, 0, 0],
    [1, 16 * 4, 128 * 2, 1, 16 * 4, 64 * 2, 1, 1, 1, 1, 0, 0],
    [1, 16 * 2, 128 * 2, 1, 16 * 2, 64 * 3, 1, 1, 1, 1, 0, 0],
    [1, 16 * 3, 128 * 2, 1, 16 * 3, 64 * 3, 1, 1, 1, 1, 0, 0],
    [1, 16 * 4, 128 * 2, 1, 16 * 4, 64 * 3, 1, 1, 1, 1, 0, 0],
    [1, 16 * 2, 128 * 2, 1, 16 * 2, 64 * 4, 1, 1, 1, 1, 0, 0],
    [1, 16 * 3, 128 * 2, 1, 16 * 3, 64 * 4, 1, 1, 1, 1, 0, 0],
    [1, 16 * 4, 128 * 2, 1, 16 * 4, 64 * 4, 1, 1, 1, 1, 0, 0],
    [1, 16 * 5, 128 * 2, 1, 16 * 5, 64 * 5, 1, 1, 1, 1, 0, 0],
    # Additional shapes for better coverage - start index 33
    [1, 16 * 4, 192, 1, 16 * 4, 64 * 4, 1, 1, 1, 1, 0, 0],
    [1, 16 * 4, 128 * 2 + 64, 1, 16 * 4, 64 * 4, 1, 1, 1, 1, 0, 0],
    [1, 16 * 5, 128 * 2 + 64, 1, 16 * 5, 64 * 5, 1, 1, 1, 1, 0, 0],
    [1, 7, 88, 1, 7, 19, 1, 1, 1, 1, 0, 0],
    [1, 16, 77, 1, 16, 19, 1, 1, 1, 1, 0, 0],
    [1, 16, 77, 1, 16, 64, 1, 1, 1, 1, 0, 0],
    # Disabled (DI fail): [1, 18, 90, 1, 18, 77, 1, 1, 1, 1, 0, 0],
    # PSMU_ST0 - index 39
    [1, 64, 64, 1, 64, 64, 1, 1, 1, 1, 0, 0],
    # PSMU_ST1 - index 40
    [1, 1, 64, 1, 1, 64, 1, 1, 1, 1, 0, 0],
    # PSR - start index 41
    [1, 16, 256, 1, 16, 64, 1, 1, 1, 1, 0, 0],
    [1, 16, 64, 1, 16, 256, 1, 1, 1, 1, 0, 0],
    [1, 16, 64, 1, 16, 77, 1, 1, 1, 1, 0, 0],
    [1, 16, 77, 1, 16, 64, 1, 1, 1, 1, 0, 0],  # NOTE(review): same shape as index 38
    [1, 32, 1024, 1, 32, 64, 1, 1, 1, 1, 0, 0],
    [1, 32, 64, 1, 32, 1024, 1, 1, 1, 1, 0, 0],
    [1, 32, 64, 1, 32, 77, 1, 1, 1, 1, 0, 0],
    [1, 32, 77, 1, 32, 64, 1, 1, 1, 1, 0, 0],
    [1, 64, 4096, 1, 64, 64, 1, 1, 1, 1, 0, 0],
    [1, 64, 64, 1, 64, 4096, 1, 1, 1, 1, 0, 0],
    [1, 64, 64, 1, 64, 77, 1, 1, 1, 1, 0, 0],
    [1, 64, 77, 1, 64, 64, 1, 1, 1, 1, 0, 0],
    [1, 8, 64, 1, 8, 64, 1, 1, 1, 1, 0, 0],
    [1, 8, 64, 1, 8, 77, 1, 1, 1, 1, 0, 0],
    [1, 8, 77, 1, 8, 64, 1, 1, 1, 1, 0, 0],
    # PSJ - index 56
    [12, 77, 96, 12, 77, 77, 1, 1, 1, 1, 0, 0],
    # PSD1 - start index 57
    [12, 77, 64, 12, 77, 77, 1, 1, 1, 1, 0, 0],
    [12, 77, 77, 12, 77, 64, 1, 1, 1, 1, 0, 0],
    # PSD2 - start index 59
    [20, 77, 64, 20, 77, 77, 1, 1, 1, 1, 0, 0],
    [20, 77, 77, 20, 77, 64, 1, 1, 1, 1, 0, 0],
    # PSD3/4/5 - start index 61
    [10, 1024, 1024, 10, 1024, 64, 1, 1, 1, 1, 0, 0],
    [10, 1024, 64, 10, 1024, 1024, 1, 1, 1, 1, 0, 0],
    [10, 1024, 64, 10, 1024, 77, 1, 1, 1, 1, 0, 0],
    [10, 1024, 77, 10, 1024, 64, 1, 1, 1, 1, 0, 0],
    [20, 256, 256, 20, 256, 64, 1, 1, 1, 1, 0, 0],
    [20, 256, 64, 20, 256, 256, 1, 1, 1, 1, 0, 0],
    [20, 256, 64, 20, 256, 77, 1, 1, 1, 1, 0, 0],
    [20, 256, 77, 20, 256, 64, 1, 1, 1, 1, 0, 0],
    # clip common 16/32/laion - start index 69
    [1, 10, 512, 1, 10, 1, 1, 1, 1, 1, 0, 0],
    [12, 50, 50, 12, 50, 64, 1, 1, 1, 1, 0, 0],
    [12, 50, 64, 12, 50, 50, 1, 1, 1, 1, 0, 0],
    [80, 77, 64, 80, 77, 77, 1, 1, 1, 1, 0, 0],
    [80, 77, 77, 80, 77, 64, 1, 1, 1, 1, 0, 0],
    [12, 197, 197, 12, 197, 64, 1, 1, 1, 1, 0, 0],
    [12, 197, 64, 12, 197, 197, 1, 1, 1, 1, 0, 0],
    # Google BERT - start index 76
    [12, 128, 128, 12, 128, 64, 1, 1, 1, 1, 0, 0],
    [12, 128, 64, 12, 128, 128, 1, 1, 1, 1, 0, 0],
    # PSH - start index 78
    [12, 512, 512, 12, 512, 64, 1, 1, 1, 1, 0, 0],
    [12, 512, 96, 12, 512, 512, 1, 1, 1, 1, 0, 0],
    # PSI - start index 80
    [16, 32, 196, 16, 32, 32, 1, 1, 1, 1, 0, 0],
    # Disabled (DI fail, Co pattern shim to mem): [16, 32, 32, 16, 32, 196, 1, 1, 1, 1, 0, 0],
    [32, 32, 32, 32, 32, 49, 1, 1, 1, 1, 0, 0],
    [32, 32, 49, 32, 32, 32, 1, 1, 1, 1, 0, 0],
    [32, 49, 32, 32, 49, 49, 1, 1, 1, 1, 0, 0],
    [32, 49, 49, 32, 49, 32, 1, 1, 1, 1, 0, 0],
    [4, 32, 3136, 4, 32, 32, 1, 1, 1, 1, 0, 0],
    [4, 32, 32, 4, 32, 3136, 1, 1, 1, 1, 0, 0],
    [8, 32, 32, 8, 32, 784, 1, 1, 1, 1, 0, 0],
    # Disabled (DMAC fail): [8, 32, 784, 8, 32, 32, 1, 1, 1, 1, 0, 0],
    [16 * 8, 49, 32, 16 * 8, 49, 49, 1, 1, 1, 1, 0, 0],
    [16 * 8, 49, 49, 16 * 8, 49, 32, 1, 1, 1, 1, 0, 0],
    [4 * 16, 49, 32, 4 * 16, 49, 49, 1, 1, 1, 1, 0, 0],
    [64 * 4, 49, 32, 64 * 4, 49, 49, 1, 1, 1, 1, 0, 0],
    [64 * 4, 49, 49, 64 * 4, 49, 32, 1, 1, 1, 1, 0, 0],
]

# ACT x WGT GEMM shapes, used for GemmMode.WGT. These rows are built unpadded
# (the padded_* config fields reuse the raw shape). Same row layout as
# GEMM_SHAPES_ACT_ACT: [Yi, Xi, Ci, Yo, Xo, Co, Ky, Kx, Sy, Sx, Py, Px]; the
# position in this list is the ``shape_index``.
GEMM_SHAPES_ACT_WGT = [
    # Index - 0
    [1, 1, 64, 1, 1, 1024, 1, 1, 1, 1, 0, 0],
    [1, (32 * 5), 256, 1, (32 * 5), (64 * 3), 1, 1, 1, 1, 0, 0],
    [1, (32 * 1), 256, 1, (32 * 1), (64 * 11), 1, 1, 1, 1, 0, 0],
    [1, 32, 256, 1, 32, 64, 1, 1, 1, 1, 0, 0],
    [1, 64, 256, 1, 64, 64, 1, 1, 1, 1, 0, 0],
    [1, 64, 512, 1, 64, 512, 1, 1, 1, 1, 0, 0],
    [1, 64, 512, 1, 64, 64, 1, 1, 1, 1, 0, 0],
    # PSU0 - start index 7
    [1, 64, 3072, 1, 64, 3072, 1, 1, 1, 1, 0, 0],
    [1, 64, 8192, 1, 64, 3072, 1, 1, 1, 1, 0, 0],
    [1, 64, 3072, 1, 64, 9216, 1, 1, 1, 1, 0, 0],
    [1, 64, 3072, 1, 64, 16384, 1, 1, 1, 1, 0, 0],
    # PSU1 - start index 11
    [1, 1, 3072, 1, 1, 3072, 1, 1, 1, 1, 0, 0],
    [1, 1, 8192, 1, 1, 3072, 1, 1, 1, 1, 0, 0],
    [1, 1, 3072, 1, 1, 9216, 1, 1, 1, 1, 0, 0],
    [1, 1, 3072, 1, 1, 16384, 1, 1, 1, 1, 0, 0],
    # PSD3 - start index 15
    [1, 4096, 640, 1, 4096, 320, 1, 1, 1, 1, 0, 0],
    [1, 1, 1280, 1, 1, 320, 1, 1, 1, 1, 0, 0],
    [1, 1, 1280, 1, 1, 640, 1, 1, 1, 1, 0, 0],
    [1, 1024, 320, 1, 1024, 640, 1, 1, 1, 1, 0, 0],
    [1, 1024, 640, 1, 1024, 640, 1, 1, 1, 1, 0, 0],
    [1, 77, 2048, 1, 77, 640, 1, 1, 1, 1, 0, 0],
    [1, 1024, 640, 1, 1024, 2560, 1, 1, 1, 1, 0, 0],
    [1, 1024, 2560, 1, 1024, 640, 1, 1, 1, 1, 0, 0],
    [1, 1, 1280, 1, 1, 1280, 1, 1, 1, 1, 0, 0],
    [1, 256, 640, 1, 256, 1280, 1, 1, 1, 1, 0, 0],  # index 24: fail (originally annotated "23 fail"; confirm which row fails)
    [1, 256, 1280, 1, 256, 1280, 1, 1, 1, 1, 0, 0],
    [1, 77, 2048, 1, 77, 1280, 1, 1, 1, 1, 0, 0],
    [1, 256, 1280, 1, 256, 5120, 1, 1, 1, 1, 0, 0],
    [1, 256, 5120, 1, 256, 1280, 1, 1, 1, 1, 0, 0],
    [1, 256, 640, 1, 256, 640, 1, 1, 1, 1, 0, 0],
    [1, 1024, 320, 1, 1024, 320, 1, 1, 1, 1, 0, 0],
    [1, 4096, 320, 1, 4096, 320, 1, 1, 1, 1, 0, 0],  # index 31: hang (originally annotated "30 hang"; confirm which row hangs)
    [1, 64, 2048, 1, 64, 11008, 1, 1, 1, 1, 0, 0],
    # PSI - start index 33
    [1, 32*37, 128, 1, 32*37, 128, 1, 1, 1, 1, 0, 0],
    [1, 3136, 128, 1, 3136, 128, 1, 1, 1, 1, 0, 0],
    [1, 3136, 128, 1, 3136, 384, 1, 1, 1, 1, 0, 0],
    [1, 3136, 128, 1, 3136, 512, 1, 1, 1, 1, 0, 0],
    [1, 3136, 512, 1, 3136, 128, 1, 1, 1, 1, 0, 0],
]


def pad_shape(shape, ci_align: int = 64, x_align: int = 1, co_align: int = 64):
    """Return a padded copy of a GEMM shape row.

    ``shape`` uses the [Yi, Xi, Ci, Yo, Xo, Co, Ky, Kx, Sy, Sx, Py, Px] layout
    of the GEMM_SHAPES tables. Ci is rounded up to a multiple of ``ci_align``,
    Xi and Xo to a multiple of ``x_align``, and Co to a multiple of
    ``co_align``; every other entry is copied unchanged. With the defaults,
    Ci and Co are padded to multiples of 64 and Xi/Xo are left as-is.

    Args:
        shape: Sequence of integer dimensions; it is not modified.
        ci_align: Alignment for Ci (index 2).
        x_align: Alignment for Xi (index 1) and Xo (index 4).
        co_align: Alignment for Co (index 5).

    Returns:
        A new list with the padded dimensions.
    """

    def _round_up(value: int, align: int) -> int:
        # Integer ceil-division: exact for arbitrarily large ints, unlike
        # math.ceil(value / align), which goes through float.
        return -(-value // align) * align

    padded = list(shape)
    padded[1] = _round_up(padded[1], x_align)  # Xi
    padded[2] = _round_up(padded[2], ci_align)  # Ci
    padded[4] = _round_up(padded[4], x_align)  # Xo
    padded[5] = _round_up(padded[5], co_align)  # Co
    return padded


def L2_alloc_for_unit_test(shape: list, gemm_mode) -> dict:
    """Build a fixed L2 address map for a single-operator unit test.

    Only ``shape[0] * shape[1] * shape[2]`` (the IFM element count) and
    ``gemm_mode`` influence the result. Layout:

    * column-0 parameter addresses 0 and 4096,
    * four 256 KiB column-0 weight buffers starting at 8192,
    * the IFM immediately after the fourth weight buffer,
    * the OFM immediately after the IFM, reported as ``{tile: offset}``.
      Tiles are treated as 3 MiB each, and any absolute address at or past
      6 MiB maps to tile 2.

    ``GemmMode.WGT`` returns:

    * ``input_addr``,
    * a second column-0 weight pair,
    * a column/tile-2 weight pair at 2 MiB / 2.25 MiB,
    * a column/tile-2 parameter address at 2.5 MiB.

    Any other mode (ACT x ACT) returns ``input0_addr`` with only the first
    weight pair and the two column-0 parameter addresses.

    NOTE(review): the IFM extent added here is an element count, not a byte
    count. If these are byte addresses, 16-bit inputs would need twice the
    room before the OFM — confirm against the allocator. The OFM offset is
    not bounds-checked either.
    """
    mib = 1 << 20
    quarter_mib = mib // 4
    tile_span = 3 * mib

    prm_col0 = [0, 4096]
    wgt_start = prm_col0[1] + 4096
    wgt_bufs = [wgt_start + quarter_mib * k for k in range(4)]

    ifm_start = wgt_bufs[-1] + quarter_mib
    ifm_end = ifm_start + shape[0] * shape[1] * shape[2]

    # Map the absolute end-of-IFM address onto a (tile, offset) pair.
    if ifm_end >= 2 * tile_span:
        out_tile = 2
    elif ifm_end >= tile_span:
        out_tile = 1
    else:
        out_tile = 0
    output_addr = {out_tile: ifm_end - out_tile * tile_span}

    first_wgt_pair = [0, [wgt_bufs[0], wgt_bufs[1]]]
    col0_prms = [[0, prm_col0[0]], [0, prm_col0[1]]]

    if gemm_mode == GemmMode.WGT:
        col2_wgt = 2 * mib
        col2_prm = col2_wgt + mib // 2  # after both tile-2 weight buffers
        return {
            "input_addr": {0: ifm_start},
            "output_addr": output_addr,
            "wgt_addr": [
                first_wgt_pair,
                [0, [wgt_bufs[2], wgt_bufs[3]]],
                [2, [col2_wgt, col2_wgt + quarter_mib]],
            ],
            "prm_addr": col0_prms + [[2, col2_prm]],
        }

    # ACT x ACT: a single weight pair and the column-0 parameters only.
    return {
        "input0_addr": {0: ifm_start},
        "output_addr": output_addr,
        "wgt_addr": [first_wgt_pair],
        "prm_addr": col0_prms,
    }


class DataType(str, Enum):
    """Data types for gemm tests.

    In this file's operator tables:

    * UINT16 and INT16 are the ACT x ACT operand and output types.
    * INT4, INT8 and UINT8 are the ACT x WGT weight (``in_dtype_B``) types.
    """

    UINT16 = "uint16"
    INT16 = "int16"
    INT8 = "int8"
    UINT8 = "uint8"
    INT4 = "int4"


class GemmMode(str, Enum):
    """Gemm modes for gemm tests.

    ACT and ACT_T run the ACT x ACT shape table (GEMM_SHAPES_ACT_ACT).
    WGT runs the ACT x WGT shape table (GEMM_SHAPES_ACT_WGT).
    """

    ACT = "act"  # activation x activation
    # Activation x activation as well. In this file the *_Transpose_* operator
    # variant is chosen by the transpose_wgts flag, not by this mode value.
    ACT_T = "act_t"
    WGT = "wgt"  # activation x weight


class Transpose(int, Enum):
    """IFM2 (second operand) transpose flag for gemm tests.

    main() maps it to the 0/1 ``transpose_wgts`` field of the operator config.
    For ACT x ACT it also selects the ``*_Transpose_*`` operator name.
    """

    NO_TRANSPOSE = 0
    TRANSPOSE = 1


@pytest.mark.dma
@pytest.mark.parametrize("shape_index", range(len(GEMM_SHAPES_ACT_ACT)))
@pytest.mark.parametrize("transpose_wgts", [Transpose.NO_TRANSPOSE])
@pytest.mark.parametrize("data_type", [DataType.UINT16, DataType.INT16])
def test_gemm_act_t_l3(
    shape_index: int,
    data_type: DataType,
    transpose_wgts: int,
    target: BuildTarget,
    hwtest: bool,
    clean: bool,
    output_root: str,
):
    """Build one ACT x ACT shape in GemmMode.ACT_T with the L3 dataflow.

    ``shape_index``, ``data_type`` and ``transpose_wgts`` come from the
    parametrization above. ``target``, ``hwtest``, ``clean`` and
    ``output_root`` are presumably conftest fixtures (not defined here).

    NOTE(review): this test is only parametrized with NO_TRANSPOSE, so the
    non-transposed operator variant is built despite the ``_t`` name.
    """
    fixed_args = {"dataflow_type": DataflowType.L3, "gemm_mode": GemmMode.ACT_T}
    main(
        shape_index=shape_index,
        data_type=data_type,
        transpose_wgts=transpose_wgts,
        target=target,
        hwtest=hwtest,
        clean=clean,
        output_root=output_root,
        **fixed_args,
    )


@pytest.mark.dma
@pytest.mark.parametrize("shape_index", range(len(GEMM_SHAPES_ACT_WGT)))
@pytest.mark.parametrize("data_type", [DataType.INT4, DataType.INT8])
def test_gemm_wgt_l3(
    shape_index: int,
    data_type: DataType,
    target: BuildTarget,
    hwtest: bool,
    clean: bool,
    output_root: str,
):
    """Build one ACT x WGT shape with transposed weights and the L3 dataflow.

    ``shape_index`` and ``data_type`` (the weight dtype) come from the
    parametrization above. The remaining arguments are presumably conftest
    fixtures (not defined here).
    """
    fixed_args = {
        "dataflow_type": DataflowType.L3,
        "gemm_mode": GemmMode.WGT,
        "transpose_wgts": Transpose.TRANSPOSE,
    }
    main(
        shape_index=shape_index,
        data_type=data_type,
        target=target,
        hwtest=hwtest,
        clean=clean,
        output_root=output_root,
        **fixed_args,
    )


def _select_op_type(gemm_mode, data_type, transpose_wgts) -> str:
    """Return the operator name to compile for a mode / dtype / transpose combo.

    Raises:
        ValueError: if ``data_type`` has no operator for ``gemm_mode``.
    """
    if gemm_mode == GemmMode.WGT:
        # ACT x WGT: data_type selects the weight (B) dtype.
        op_type_mapping = {
            "int8": "MatMul_qdq_int16xint8xint16",
            "uint8": "MatMul_qdq_uint16xuint8xuint16",
            "int4": "MatMul_qdq_uint16xint4xuint16",
        }
    else:
        # ACT x ACT: the transpose flag (not ACT vs ACT_T) picks the
        # *_Transpose_* operator variant.
        transposed = transpose_wgts == Transpose.TRANSPOSE
        op_type_mapping = {
            "uint16": (
                "MatMul_qdq_actxact_Transpose_uint16xuint16xuint16"
                if transposed
                else "MatMul_qdq_actxact_uint16xuint16xuint16"
            ),
            "int16": (
                "MatMul_qdq_actxact_Transpose_int16xint16xint16"
                if transposed
                else "MatMul_qdq_actxact_int16xint16xint16"
            ),
        }
    try:
        return op_type_mapping[data_type]
    except KeyError:
        raise ValueError(
            f"data_type {data_type!r} is not supported for gemm_mode "
            f"{gemm_mode!r}; expected one of {sorted(op_type_mapping)}"
        ) from None


def _wgt_merged_shape(
    shape, op_type, data_type, dataflow_type, transpose_flag, l2_alloc
) -> dict:
    """Assemble the compile_operator config for an ACT x WGT build.

    The raw (unpadded) shape is used for both the logical and padded fields.
    Key order is significant: the dict is also fed to out_dir_name_from_dict.
    """
    extra_table = {
        "MappingRank": 0,
        "act_type": 0,
        "load_input_from_ddr": True,
        "store_output_to_ddr": True,
        "vector_coeff": 0,
        "L3": {"ifm": [1, 0], "ofm": [0, 0]},
    }
    wgt_fmt = {
        "name": op_type,
        "in_dtype_A": "uint16",
        "in_dtype_B": data_type,
        "in_dtype_Bias": "int32",
        "out_dtype_Y": "uint16",
    }
    return {
        "input": shape[0:3],
        "output": shape[3:6],
        "padded_input": shape[0:3],
        "padded_output": shape[3:6],
        "kernel": shape[6:8],
        "stride": shape[8:10],
        "pad": shape[10:12],
        "transpose_wgts": transpose_flag,
        "enable_L2_fusion": dataflow_type == "l2",
        **extra_table,
        **l2_alloc,
        **wgt_fmt,
        "dataflow_type": (dataflow_type.name,),
        "op": op_type,
    }


def _act_merged_shape(
    shape, padded_shape, op_type, data_type, dataflow_type, transpose_flag, l2_alloc
) -> dict:
    """Assemble the compile_operator config for an ACT x ACT build.

    Both operands and the output use ``data_type``. Key order is significant:
    the dict is also fed to out_dir_name_from_dict.
    """
    ifm_bytes = 2 if data_type in [DataType.UINT16, DataType.INT16] else 1
    # Byte size of the padded first input, stored in the ifm1 L3 entry.
    # NOTE(review): presumably ifm1 sits right after ifm0 in the same DDR
    # buffer (entries look like [buffer, offset]) — confirm the L3 convention.
    ifm1_offset = padded_shape[0] * padded_shape[1] * padded_shape[2] * ifm_bytes
    extra_table = {
        "MappingRank": 0,
        "act_type": 0,
        "load_input_from_ddr": True,
        "store_output_to_ddr": True,
        "vector_coeff": 0,
        "L3": {"ifm0": [1, 0], "ifm1": [1, ifm1_offset], "ofm": [0, 0]},
    }
    wgt_fmt = {
        "name": op_type,
        "in_dtype_A": data_type,
        "in_dtype_B": data_type,
        "in_dtype_Bias": "int32",
        "out_dtype_Y": data_type,
    }
    return {
        "input0": shape[0:3],
        "output": shape[3:6],
        "padded_input0": padded_shape[0:3],
        "padded_output": padded_shape[3:6],
        "kernel": shape[6:8],
        "stride": shape[8:10],
        "pad": shape[10:12],
        "transpose_wgts": transpose_flag,
        "enable_L2_fusion": dataflow_type == "l2",
        **extra_table,
        **l2_alloc,
        **wgt_fmt,
        "dataflow_type": (dataflow_type.name,),
        "op": op_type,
    }


def main(
    target: BuildTarget = typer.Option(
        default=BuildTarget.DATAFLOW, help="Build target for the operator"
    ),
    dataflow_type: DataflowType = typer.Option(
        default=DataflowType.L3, help="Dataflow type"
    ),
    shape_index: Optional[int] = typer.Option(
        default=None, help="Index of the shape to test"
    ),
    data_type: DataType = typer.Option(
        default=DataType.INT16, help="Data type to test"
    ),
    gemm_mode: GemmMode = typer.Option(default=GemmMode.ACT, help="Gemm mode to test"),
    transpose_wgts: Transpose = typer.Option(
        default=Transpose.TRANSPOSE, help="IFM2 transpose flag"
    ),
    vcd: bool = typer.Option(default=False, help="Dump VCD trace", is_flag=True),
    hwtest: bool = typer.Option(
        default=False, help="Run HW_test after builds", is_flag=True
    ),
    clean: bool = typer.Option(
        default=False, help="Clean output directory before running", is_flag=True
    ),
    output_root: str = typer.Option(
        default=os.path.join(CURRDIR, "..", "Output"), help="Root directory for output"
    ),
) -> None:
    """Function for running gemm regression testing using build script.

    Shapes come from GEMM_SHAPES_ACT_ACT for ``GemmMode.ACT``/``ACT_T``
    (padded with pad_shape) and from GEMM_SHAPES_ACT_WGT for ``GemmMode.WGT``.
    ``shape_index`` selects one row; otherwise every row is built.

    For each shape an operator config is assembled and passed to
    ``compile_operator``. Then, depending on ``target``:

    * SIM: AIE simulator results are processed and written to
      ``gemm_aiesim_results.csv`` next to this file.
    * CERT: a HW package is created.

    ``run_hw_validation`` runs once at the end when ``hwtest`` is set.

    Raises:
        ValueError: if ``data_type`` is not supported for ``gemm_mode``. This
            is checked before the output directory is cleaned.
    """
    is_act_mode = gemm_mode in [GemmMode.ACT, GemmMode.ACT_T]
    gemm_shapes = GEMM_SHAPES_ACT_ACT if is_act_mode else GEMM_SHAPES_ACT_WGT
    if shape_index is not None:
        shape_table = [gemm_shapes[shape_index]]
    else:
        shape_table = gemm_shapes

    # Fail fast on an unsupported dtype/mode combination, before the output
    # directory is cleaned or pdi_counter() is called. The operator name does
    # not depend on the shape, so it is resolved once here.
    op_type = _select_op_type(gemm_mode, data_type, transpose_wgts)

    is_cert_backend = target == BuildTarget.CERT
    # pdi_counter() == 1 presumably marks the first CERT call in this process.
    gen_pdi = is_cert_backend and pdi_counter() == 1

    results_list = [""] * len(shape_table)
    simtime_list = [0.0] * len(shape_table)

    _transpose_wgts = 0 if transpose_wgts == Transpose.NO_TRANSPOSE else 1
    # Make the path absolute *before* change_dir("../"). Otherwise a relative
    # path would name cwd/<root> for cleaning and HW validation but
    # cwd/../<root> for the builds.
    output_root = os.path.abspath(str(output_root))
    clean_output_dir(output_root, clean)

    combined_kernel_includes, combined_kernel_names = get_combined_kernel_list(
        "build_gemm.py"
    )

    # Loop-invariant inputs for the SIM-target CSV report.
    csv_fieldnames = [
        "gemm_mode",
        "data_type",
        "Yi",
        "Xi",
        "Ci",
        "Yo",
        "Xo",
        "Co",
        "Ky",
        "Kx",
        "Sy",
        "Sx",
        "Py",
        "Px",
    ]
    csv_file_name = os.path.join(CURRDIR, "gemm_aiesim_results.csv")
    csv_shape_table = None
    if target == BuildTarget.SIM:
        csv_shape_table = [
            [gemm_mode.value, data_type.value] + row for row in shape_table
        ]

    with change_dir("../"):
        for shape_idx, shape in enumerate(shape_table):
            if is_act_mode:
                padded_shape = pad_shape(shape)
                l2_alloc = L2_alloc_for_unit_test(padded_shape, gemm_mode)
            else:
                padded_shape = None
                l2_alloc = L2_alloc_for_unit_test(shape, gemm_mode)

            if gemm_mode == GemmMode.WGT:
                merged_shape = _wgt_merged_shape(
                    shape, op_type, data_type, dataflow_type, _transpose_wgts, l2_alloc
                )
            else:
                merged_shape = _act_merged_shape(
                    shape,
                    padded_shape,
                    op_type,
                    data_type,
                    dataflow_type,
                    _transpose_wgts,
                    l2_alloc,
                )
            print(f"Building gemm with shape: {shape}")

            os.environ["LOG_ENABLED"] = "True"
            compile_operator(
                op_type,
                merged_shape,
                target,
                output_root,
                combined_kernel_includes=combined_kernel_includes,
                combined_kernel_names=combined_kernel_names,
                gen_standalone_pdi=gen_pdi,
                gen_op_elf=is_cert_backend,
                dump_vcd=vcd,
            )
            build_dir = os.path.join(
                output_root,
                f"op_{op_type}_shape_{out_dir_name_from_dict(merged_shape)}",
            )
            if target == BuildTarget.SIM:
                sim_log = os.path.join(build_dir, "AIESimulator.log")
                process_simulation_results(
                    sim_log, shape_idx, results_list, simtime_list
                )
                # Written after every shape with the results gathered so far.
                write_csv(
                    csv_shape_table,
                    results_list,
                    simtime_list,
                    csv_file_name,
                    csv_fieldnames,
                    default_row_mapper(csv_fieldnames),
                )
            elif target == BuildTarget.CERT:
                create_hw_package(build_dir)
            gen_pdi = False  # Only generate PDI for the first one

    if hwtest:
        run_hw_validation(
            output_root,
            dtype=data_type.value,
            debug=False,
        )


if __name__ == "__main__":
    # Run main() as a command-line tool; typer builds the CLI options from its
    # typer.Option parameters.
    typer.run(main)
