"""
Regression test for the convolutional layers
"""

import os
from typing import Dict, Optional
from enum import Enum
from common import (
    process_simulation_results,
    create_hw_package,
    write_csv,
    default_row_mapper,
    change_dir,
    BuildTarget,
    DataflowType,
    run_hw_validation,
    clean_output_dir,
    Counter,
)
import pytest
import typer
from build_aie4 import compile_operator, out_dir_name_from_dict, get_combined_kernel_list
from graph.utilities import config_logger_from_env

# Absolute directory containing this file; used below for the default output
# root and for the location of the simulation-results CSV.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Logging is configured once, at import time
# (presumably driven by environment variables -- see graph.utilities).
config_logger_from_env()
# Module-level call counter shared by every main() invocation in this process.
# main() requests a standalone PDI only when a CERT build sees it return 1.
pdi_counter = Counter()

# Shapes exercised by the int8 conv tests. Each row has 12 entries:
#   [Yi, Xi, Ci, Yo, Xo, Co, Ky, Kx, Sy, Sx, Py, Px]
# (column names taken from the CSV header in main()), which main() slices into
# input [0:3], output [3:6], kernel [6:8], stride [8:10] and pad [10:12].
INT8_CONV_SHAPES = [
    # Resnet50 - start index 0
    [224, 224, 8, 112, 112, 64, 7, 7, 2, 2, 3, 3],
    [56, 56, 64, 56, 56, 64, 1, 1, 1, 1, 0, 0],
    [56, 56, 64, 56, 56, 64, 3, 3, 1, 1, 1, 1],
    [56, 56, 64, 56, 56, 256, 1, 1, 1, 1, 0, 0],
    [56, 56, 256, 56, 56, 64, 1, 1, 1, 1, 0, 0],
    [56, 56, 256, 56, 56, 128, 1, 1, 1, 1, 0, 0],
    [56, 56, 128, 28, 28, 128, 3, 3, 2, 2, 1, 1],
    [56, 56, 256, 28, 28, 512, 1, 1, 2, 2, 0, 0],
    [28, 28, 128, 28, 28, 512, 1, 1, 1, 1, 0, 0],
    [28, 28, 512, 28, 28, 128, 1, 1, 1, 1, 0, 0],
    [28, 28, 128, 28, 28, 128, 3, 3, 1, 1, 1, 1],
    [28, 28, 512, 28, 28, 256, 1, 1, 1, 1, 0, 0],
    [28, 28, 256, 14, 14, 256, 3, 3, 2, 2, 1, 1],
    [28, 28, 512, 14, 14, 1024, 1, 1, 2, 2, 0, 0],
    [14, 14, 256, 14, 14, 1024, 1, 1, 1, 1, 0, 0],
    [14, 14, 1024, 14, 14, 256, 1, 1, 1, 1, 0, 0],
    [14, 14, 256, 14, 14, 256, 3, 3, 1, 1, 1, 1],
    [14, 14, 1024, 14, 14, 512, 1, 1, 1, 1, 0, 0],
    [14, 14, 512, 7, 7, 512, 3, 3, 2, 2, 1, 1],
    [14, 14, 1024, 7, 7, 2048, 1, 1, 2, 2, 0, 0],
    [7, 7, 512, 7, 7, 2048, 1, 1, 1, 1, 0, 0],
    [7, 7, 2048, 7, 7, 512, 1, 1, 1, 1, 0, 0],
    [7, 7, 512, 7, 7, 512, 3, 3, 1, 1, 1, 1],
    # Yolo - start index 23
    [13, 13, 512, 13, 13, 256, 1, 1, 1, 1, 0, 0],
    [13, 13, 1024, 13, 13, 512, 1, 1, 1, 1, 0, 0],
    [13, 13, 1024, 13, 13, 256, 1, 1, 1, 1, 0, 0],
    [26, 26, 256, 26, 26, 128, 1, 1, 1, 1, 0, 0],
    [26, 26, 512, 26, 26, 256, 1, 1, 1, 1, 0, 0],
    [52, 52, 256, 52, 52, 128, 1, 1, 1, 1, 0, 0],
    [52, 52, 256, 52, 52, 256, 1, 1, 1, 1, 0, 0],
    [104, 104, 128, 104, 104, 64, 1, 1, 1, 1, 0, 0],  # fail 30 (index 30; presumably a known failure -- TODO confirm)
    [208, 208, 64, 208, 208, 64, 1, 1, 1, 1, 0, 0],
    [13, 13, 512, 13, 13, 1024, 3, 3, 1, 1, 1, 1],
    [26, 26, 256, 26, 26, 512, 3, 3, 1, 1, 1, 1],
    [26, 26, 512, 13, 13, 1024, 3, 3, 2, 2, 1, 1],
    [52, 52, 128, 52, 52, 256, 3, 3, 1, 1, 1, 1],
    [52, 52, 256, 26, 26, 512, 3, 3, 2, 2, 1, 1],
    [104, 104, 64, 104, 104, 128, 3, 3, 1, 1, 1, 1],
    [104, 104, 128, 52, 52, 256, 3, 3, 2, 2, 1, 1],
    [208, 208, 64, 208, 208, 64, 3, 3, 1, 1, 1, 1],
    [208, 208, 64, 104, 104, 128, 3, 3, 2, 2, 1, 1],
    # The two entries below are not covered by test_conv_int8, whose
    # parametrized range stops at len(INT8_CONV_SHAPES) - 2.
    [416, 416, 8, 416, 416, 64, 3, 3, 1, 1, 1, 1],
    [416, 416, 64, 208, 208, 64, 3, 3, 2, 2, 1, 1],
]

# Shapes exercised by the a16w8 conv tests. Same 12-entry row layout as the
# int8 table: [Yi, Xi, Ci, Yo, Xo, Co, Ky, Kx, Sy, Sx, Py, Px].
A16W8_CONV_SHAPES = [
    # PSD5 - start index 0
    [16, 16, 1280, 16, 16, 1280, 3, 3, 1, 1, 1, 1],
    [16, 16, 1920, 16, 16, 1280, 3, 3, 1, 1, 1, 1],
    [16, 16, 2560, 16, 16, 1280, 3, 3, 1, 1, 1, 1],
    [16, 16, 640, 16, 16, 1280, 3, 3, 1, 1, 1, 1],
    [32, 32, 1280, 32, 32, 1280, 3, 3, 1, 1, 1, 1],
    [32, 32, 1280, 32, 32, 640, 3, 3, 1, 1, 1, 1],
    [32, 32, 1920, 32, 32, 640, 3, 3, 1, 1, 1, 1],
    [32, 32, 320, 32, 32, 640, 3, 3, 1, 1, 1, 1],
    # NOTE(review): this row uses pad 2 with a 3x3 stride-1 kernel, which would
    # not normally keep a 32x32 output -- confirm the pad values are intended.
    [32, 32, 640, 32, 32, 640, 3, 3, 1, 1, 2, 2],
    [32, 32, 640, 32, 32, 640, 3, 3, 1, 1, 1, 1],
    # PSD3 - start index 10
    [128, 128, 96, 128, 128, 96, 3, 3, 1, 1, 1, 1],
    [512, 512, 64, 512, 512, 64, 3, 3, 1, 1, 1, 1],
]

# Extra operator attributes merged into every conv configuration that main()
# hands to compile_operator. "enable_over_compute" is only a placeholder here:
# main() overrides it per dataflow type (0 for L2, 1 otherwise).
extra_table = {
    "MappingRank": 0,
    "out_shift": 8,
    "bias_shift": 2,
    "act_type": 0,
    "load_input_from_ddr": True,
    "store_output_to_ddr": True,
    "vector_coeff": 0,
    "enable_over_compute": 1,
    "L3": {"ifm": [1, 0], "ofm": [0, 0]},
    "attributes": {"const_padding_value": ["0"]},
}


def L2_alloc_for_unit_test(shape: list) -> Dict:
    """Build the fixed L2 address map used when a conv shape is tested alone.

    Only the first three entries of ``shape`` are read; their product (clamped
    to a minimum of 1 MB) is the span reserved for the input feature map.

    Layout, in increasing address order:
      * two 4 KB parameter slots at 0 and 4096,
      * four 256 KB weight buffers (two ping/pong pairs),
      * the input feature map,
      * the output feature map, placed right after the input and expressed as
        ``{tile: offset}`` where each tile covers a 3 MB window and the tile
        index is capped at 2.
    A third weight pair and parameter slot live on tile 2 at a fixed 2 MB
    offset.

    Returns:
        Dict with the keys ``input_addr``, ``output_addr``, ``wgt_addr`` and
        ``prm_addr``.
    """
    one_mb = 1 << 20
    quarter_mb = one_mb // 4
    param_slot = 4096
    tile_window = 3 * one_mb

    # Parameter slots come first, weights start right after the second slot.
    param_offsets = [0, param_slot]
    weights_start = param_offsets[-1] + param_slot
    weight_offsets = [weights_start + slot * quarter_mb for slot in range(4)]

    # Input feature map follows the last weight buffer; reserve at least 1 MB.
    ifm_offset = weight_offsets[-1] + quarter_mb
    ifm_span = shape[0] * shape[1] * shape[2]
    if ifm_span < one_mb:
        ifm_span = one_mb

    # The output starts where the input ends; fold that linear address into a
    # (tile, offset-within-tile) pair, with everything past 6 MB kept on tile 2.
    linear_end = ifm_offset + ifm_span
    out_tile = min(linear_end // tile_window, 2)
    out_offset = linear_end - out_tile * tile_window

    # Third column: weights at a fixed 2 MB offset on tile 2, parameters placed
    # after both of its 256 KB ping/pong buffers.
    col2_weights = 2 * one_mb
    col2_params = col2_weights + 2 * quarter_mb

    layout = {}
    layout["input_addr"] = {0: ifm_offset}
    layout["output_addr"] = {out_tile: out_offset}
    layout["wgt_addr"] = [
        [0, weight_offsets[0:2]],
        [0, weight_offsets[2:4]],
        [2, [col2_weights, col2_weights + quarter_mb]],
    ]
    layout["prm_addr"] = [
        [0, param_offsets[0]],
        [0, param_offsets[1]],
        [2, col2_params],
    ]
    return layout


class DataType(str, Enum):
    """Data types selectable for the conv regression tests.

    The ``str`` mix-in lets members compare equal to their plain string
    values, which main() relies on when comparing against ``"int8"``.
    """

    # Selects INT8_CONV_SHAPES in main()
    INT8 = "int8"
    # Selects A16W8_CONV_SHAPES in main()
    A16W8 = "a16w8"


@pytest.mark.dma
@pytest.mark.parametrize("shape_index", range(len(INT8_CONV_SHAPES) - 2))
@pytest.mark.parametrize("dataflow_type", DataflowType)
def test_conv_int8(
    shape_index: int,
    dataflow_type: DataflowType,
    target: BuildTarget,
    hwtest: bool,
    clean: bool,
    output_root: str,
) -> None:
    """Build one int8 conv shape for one dataflow type.

    Parametrized over every ``DataflowType`` member and over all entries of
    ``INT8_CONV_SHAPES`` except the last two. ``target``, ``hwtest``, ``clean``
    and ``output_root`` are not parametrized here (presumably pytest fixtures
    supplied by a conftest -- TODO confirm).
    """
    main(
        target=target,
        dataflow_type=dataflow_type,
        shape_index=shape_index,
        data_type=DataType.INT8,
        # main()'s defaults are typer.Option objects that only typer.run()
        # resolves; leaving vcd out would forward a truthy OptionInfo instance
        # as dump_vcd instead of a bool.
        vcd=False,
        hwtest=hwtest,
        clean=clean,
        output_root=output_root,
    )


@pytest.mark.dma
@pytest.mark.parametrize("shape_index", range(len(A16W8_CONV_SHAPES)))
def test_conv_a16w8(
    shape_index: int, target: BuildTarget, hwtest: bool, clean: bool, output_root: str
) -> None:
    """Build one a16w8 conv shape using the L3 dataflow.

    Parametrized over every entry of ``A16W8_CONV_SHAPES``. ``target``,
    ``hwtest``, ``clean`` and ``output_root`` are not parametrized here
    (presumably pytest fixtures supplied by a conftest -- TODO confirm).
    """
    main(
        target=target,
        dataflow_type=DataflowType.L3,
        shape_index=shape_index,
        data_type=DataType.A16W8,
        # main()'s defaults are typer.Option objects that only typer.run()
        # resolves; leaving vcd out would forward a truthy OptionInfo instance
        # as dump_vcd instead of a bool.
        vcd=False,
        hwtest=hwtest,
        clean=clean,
        output_root=output_root,
    )


def main(
    target: BuildTarget = typer.Option(
        default=BuildTarget.DATAFLOW, help="Build target for the operator"
    ),
    dataflow_type: DataflowType = typer.Option(
        default=DataflowType.L2, help="Dataflow type"
    ),
    shape_index: Optional[int] = typer.Option(
        default=None, help="Index of the shape to test"
    ),
    data_type: DataType = typer.Option(default=DataType.INT8, help="Data type to test"),
    vcd: bool = typer.Option(default=False, help="Dump VCD trace", is_flag=True),
    hwtest: bool = typer.Option(
        default=False, help="Run HW_test after builds", is_flag=True
    ),
    clean: bool = typer.Option(
        default=False, help="Clean output directory before running", is_flag=True
    ),
    output_root: str = typer.Option(
        default=os.path.join(CURRDIR, "..", "Output"), help="Root directory for output"
    ),
) -> None:
    """Run the conv regression through the build script.

    For each selected shape the operator is compiled with ``compile_operator``.
    With the SIM target the simulator log of every build is parsed and the
    results are written to ``conv_aiesim_results.csv`` next to this file; with
    the CERT target a HW package is created for every build. HW validation is
    run afterwards when ``hwtest`` is set.

    Note: the parameter defaults are ``typer.Option`` objects which are only
    turned into real values by ``typer.run``. When calling this function
    directly from Python, pass every argument explicitly.

    Args:
        target: Build target forwarded to ``compile_operator``.
        dataflow_type: Dataflow type; L2 enables L2 fusion and disables
            over-compute in the operator configuration.
        shape_index: Index into the shape table of ``data_type``; ``None``
            runs every shape in the table.
        data_type: Selects the shape table, operator name and tensor dtypes.
        vcd: Forwarded to ``compile_operator`` as ``dump_vcd``.
        hwtest: Run ``run_hw_validation`` on ``output_root`` after the builds.
        clean: Forwarded to ``clean_output_dir``.
        output_root: Root directory for all build outputs.

    Raises:
        IndexError: If ``shape_index`` is outside the selected shape table.
    """
    is_int8 = data_type == DataType.INT8
    conv_shapes = INT8_CONV_SHAPES if is_int8 else A16W8_CONV_SHAPES
    if shape_index is not None:
        shape_table = [conv_shapes[shape_index]]
    else:
        shape_table = conv_shapes
    is_cert_backend = target == BuildTarget.CERT
    # The counter is only advanced for CERT builds (short-circuit); a
    # standalone PDI is requested when it returns 1.
    gen_pdi = is_cert_backend and pdi_counter() == 1
    # Initialize parameters
    results_list = [""] * len(shape_table)
    simtime_list = [0.0] * len(shape_table)
    op_type_mapping = {
        "int8": "conv_noqdq_a8w8",
        "a16w8": "Conv_qdq_int16xint8xint16",
    }

    output_root = str(output_root)
    print(output_root)
    clean_output_dir(output_root, clean)

    combined_kernel_includes, combined_kernel_names = get_combined_kernel_list(
        "build_conv.py"
    )

    # Everything below depends only on data_type / dataflow_type, so it is
    # computed once instead of once per shape.
    op_type = op_type_mapping[data_type]
    wgt_fmt = {
        "name": op_type,
        "in_dtype_A": "uint8" if is_int8 else "uint16",
        "in_dtype_B": "uint8",
        "in_dtype_Bias": "int16" if is_int8 else "int32",
        "out_dtype_Y": "uint8" if is_int8 else "uint16",
    }
    is_l2_dataflow = dataflow_type == DataflowType.L2
    # Override on a copy so the module-level extra_table is never mutated;
    # replacing an existing key keeps its position in the dict.
    op_attrs = {**extra_table, "enable_over_compute": 0 if is_l2_dataflow else 1}

    with change_dir("../"):
        for shape_pos, shape in enumerate(shape_table):
            merged_shape = {
                "input": shape[0:3],
                "output": shape[3:6],
                "kernel": shape[6:8],
                "stride": shape[8:10],
                "pad": shape[10:12],
                "enable_L2_fusion": is_l2_dataflow,
                **op_attrs,
                **L2_alloc_for_unit_test(shape),
                **wgt_fmt,
                # NOTE(review): the value is a 1-tuple, kept exactly as before;
                # confirm whether a plain string was intended.
                "dataflow_type": (dataflow_type.name,),
                "op": op_type,
            }
            print(f"Building conv with shape: {shape}")
            os.environ["LOG_ENABLED"] = "false"
            compile_operator(
                op_type,
                merged_shape,
                target,
                output_root,
                combined_kernel_includes=combined_kernel_includes,
                combined_kernel_names=combined_kernel_names,
                gen_standalone_pdi=gen_pdi,
                gen_op_elf=is_cert_backend,
                dump_vcd=vcd,
            )
            build_dir = os.path.join(
                output_root,
                f"op_{op_type}_shape_{out_dir_name_from_dict(merged_shape)}",
            )
            if target == BuildTarget.SIM:
                sim_log = os.path.join(build_dir, "AIESimulator.log")
                # Use the loop position rather than shape_table.index(shape):
                # index() is a linear search and would return the first slot
                # for a duplicated shape.
                process_simulation_results(
                    sim_log, shape_pos, results_list, simtime_list
                )
            elif target == BuildTarget.CERT:
                create_hw_package(build_dir)
            # Only the first build of this call may generate the PDI.
            gen_pdi = False

        # Write results to CSV
        if target == BuildTarget.SIM:
            fieldnames = [
                "conv_type",
                "Yi",
                "Xi",
                "Ci",
                "Yo",
                "Xo",
                "Co",
                "Ky",
                "Kx",
                "Sy",
                "Sx",
                "Py",
                "Px",
            ]
            csv_file_name = os.path.join(CURRDIR, "conv_aiesim_results.csv")
            # Prefix each shape with the conv type to match the first column
            csv_rows = [[data_type] + shape for shape in shape_table]
            write_csv(
                csv_rows,
                results_list,
                simtime_list,
                csv_file_name,
                fieldnames,
                default_row_mapper(fieldnames),
            )

    if hwtest:
        # NOTE(review): int8 builds are named "op_conv_noqdq_a8w8_..." (lower
        # case); confirm that "op_Conv*" is matched case-insensitively,
        # otherwise int8 outputs would be filtered out here.
        run_hw_validation(
            output_root,
            dtype="int8" if is_int8 else "int16",
            filter_patterns="op_Conv*",  # optional filtering, can be removed
            debug=False,
        )


# Script entry point: expose main() as a Typer command-line interface.
if __name__ == "__main__":
    typer.run(main)
