'''
Regression test for the depthwise convolution (DWC) layers
'''
import os
from typing import Dict, Optional
from enum import Enum
from common import (
    process_simulation_results,
    create_hw_package,
    write_csv,
    default_row_mapper,
    change_dir,
    BuildTarget,
    DataflowType,
    run_hw_validation,
    clean_output_dir,
    Counter,
)
import pytest
import typer
from build_aie4 import compile_operator, out_dir_name_from_dict
from graph.utilities import config_logger_from_env

# Absolute path of the directory that contains this test file.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Logging is configured once, at import time, by the project helper
# (presumably driven by environment variables - confirm in graph.utilities).
config_logger_from_env()
# NOTE(review): this counter is not referenced anywhere else in the visible
# part of this file; confirm whether it is still needed.
pdi_counter = Counter()


class DataType(str, Enum):
    """Data-type configurations selectable for the depthwise-conv tests."""

    # The string value is what the CLI accepts and what main() uses to look
    # up the operator name.
    A16W8 = "a16w8"


def L2_alloc_for_unit_test(shape: list) -> Dict:
    '''Build the fixed L2 address map used by the unit tests.

    Layout, in ascending address order: two 4096-byte parameter slots,
    four 256 KiB weight buffers, then the IFM region. The OFM is placed
    directly after the IFM; the running address is split into 3 MiB
    windows to pick the OFM tile (0, 1 or 2) and the offset inside it.
    A third weight/parameter set is placed at fixed addresses on tile 2.

    Args:
        shape: sequence whose first three entries are multiplied together
            to size the IFM region (never less than 1 MiB).

    Returns:
        Dict with the keys "input_addr", "output_addr", "wgt_addr" and
        "prm_addr". The leading integer of each "wgt_addr"/"prm_addr"
        entry is presumably a tile index - confirm against the consumer.
    '''
    one_mb = 2**20
    quarter_mb = one_mb // 4
    tile_window = 3 * one_mb

    # Parameter slots for col_0 and col_1 sit at the very bottom.
    prm_col_0, prm_col_1 = 0, 4096

    # Four consecutive 256 KiB weight buffers (two ping/pong pairs).
    first_wgt = prm_col_1 + 4096
    wgt_bufs = [first_wgt + quarter_mb * k for k in range(4)]

    # IFM starts right behind the last weight buffer and is at least 1 MiB.
    ifm_addr = wgt_bufs[-1] + quarter_mb
    ifm_span = max(one_mb, shape[0] * shape[1] * shape[2])

    # OFM follows the IFM; everything at or beyond 6 MiB lands on tile 2.
    ifm_end = ifm_addr + ifm_span
    if ifm_end >= 2 * tile_window:
        ofm_tile = 2
    elif ifm_end >= tile_window:
        ofm_tile = 1
    else:
        ofm_tile = 0
    ofm_addr = ifm_end - ofm_tile * tile_window

    # Column 2 lives on tile 2: weights at 2 MiB, parameters behind the
    # 512 KiB taken by its ping/pong weight pair.
    col_2_wgt = 2 * one_mb
    col_2_prm = col_2_wgt + one_mb // 2

    weight_map = [
        [0, wgt_bufs[0:2]],
        [0, wgt_bufs[2:4]],
        [2, [col_2_wgt, col_2_wgt + quarter_mb]],
    ]
    param_map = [[0, prm_col_0], [0, prm_col_1], [2, col_2_prm]]
    return {
        "input_addr": {0: ifm_addr},
        "output_addr": {ofm_tile: ofm_addr},
        "wgt_addr": weight_map,
        "prm_addr": param_map,
    }


# One row per test case. Column order matches the slices taken in main():
#   [Yi, Xi, Ci,  Yo, Xo, Co,  Ky, Kx,  Sy, Sx,  Py, Px]
A16W8_DWC_SHAPES = [
    [6 * 12, 4 * 12, 448, 4 * 12, 2 * 12, 448, 3, 3, 1, 1, 0, 0],
    [6 * 12, 4 * 12, 416, 4 * 12, 2 * 12, 416, 3, 3, 1, 1, 0, 0],
    [9 * 12, 5 * 12, 416, 4 * 12, 2 * 12, 416, 3, 3, 2, 2, 0, 0],
]

# Operator attributes merged into every build request. main() overrides
# "enable_over_compute" and supplies "group" for each build.
extra_table = {
    "MappingRank": 0,
    "out_shift": 8,
    "bias_shift": 2,
    "act_type": 0,
    "load_input_from_ddr": True,
    "store_output_to_ddr": True,
    "vector_coeff": 0,
    "enable_over_compute": 1,
    "L3": {
        "ifm": [1, 0],
        "ofm": [0, 0],
    },
}


@pytest.mark.dma
@pytest.mark.parametrize("shape_index", range(len(A16W8_DWC_SHAPES)))
def test_dwc_a16w8(
    shape_index: int,
    target: BuildTarget,
    hwtest: bool,
    clean: bool,
    output_root: str,
) -> None:
    '''Build (and optionally HW-validate) one A16W8 DWC shape with L3 dataflow.

    Parametrized over every index of A16W8_DWC_SHAPES. The remaining
    arguments are supplied by pytest (presumably fixtures / command-line
    options defined in a conftest - confirm).
    '''
    main(
        target=target,
        dataflow_type=DataflowType.L3,
        shape_index=shape_index,
        data_type=DataType.A16W8,
        # main() is called directly rather than through Typer, so any omitted
        # argument would keep its `typer.Option(...)` placeholder as value.
        # That placeholder is a truthy object, so leaving `vcd` out would
        # forward a truthy `dump_vcd` to the build. Pass it explicitly.
        vcd=False,
        hwtest=hwtest,
        clean=clean,
        output_root=output_root,
    )


def main(
    target: BuildTarget = typer.Option(default=BuildTarget.DATAFLOW, help="Build target for the operator"),
    dataflow_type: DataflowType = typer.Option(default=DataflowType.L2, help="Dataflow type"),
    shape_index: Optional[int] = typer.Option(default=None, help="Index of the shape to test"),
    data_type: DataType = typer.Option(default=DataType.A16W8, help="Data type to test"),
    vcd: bool = typer.Option(default=False, help="Dump VCD trace", is_flag=True),
    hwtest: bool = typer.Option(
        default=False, help="Run HW_test after builds", is_flag=True
    ),
    clean: bool = typer.Option(
        default=False, help="Clean output directory before running", is_flag=True
    ),
    output_root: str = typer.Option(
        default=os.path.join(CURRDIR, "..", "Output"), help="Root directory for output"
    ),
) -> None:
    '''Build the DWC operator for the selected shapes and post-process the result.

    For every selected shape a build request is assembled (shape slices,
    extra_table attributes, the L2 address map and the dtype description)
    and handed to compile_operator(). Depending on the target, the
    simulator log is parsed and a CSV summary is written (SIM) or a
    hardware package is created (CERT). HW validation is run afterwards
    when requested.

    Args:
        target: build target passed to compile_operator().
        dataflow_type: L2 enables L2 fusion and disables over-compute;
            any other value does the opposite.
        shape_index: index into the shape table; None builds every shape.
        data_type: selects the shape table and the operator name.
        vcd: forwarded to compile_operator() as dump_vcd.
        hwtest: run run_hw_validation() on output_root after the builds.
        clean: forwarded to clean_output_dir().
        output_root: root directory for build output. A relative path is
            resolved against the working directory at call time.

    Raises:
        IndexError: if shape_index is outside the shape table.

    NOTE: every default is a typer.Option placeholder, so direct (non-CLI)
    callers must pass all arguments explicitly.
    '''
    dwc_shapes = A16W8_DWC_SHAPES if data_type == DataType.A16W8 else []
    shape_table = dwc_shapes if shape_index is None else [dwc_shapes[shape_index]]

    # Compare against the enum member, like the SIM/CERT checks further down.
    is_cert_backend = target == BuildTarget.CERT

    # Result slots, filled by index while the SIM logs are processed.
    results_list = [''] * len(shape_table)
    simtime_list = [0.0] * len(shape_table)

    op_type_mapping = {
        'a16w8': 'Conv_qdq_int16xint8xint16',
    }
    op_type = op_type_mapping[data_type]

    # dtype description; identical for every shape, so built once.
    is_int8 = data_type == 'int8'
    wgt_fmt = {
        "name": op_type,
        "in_dtype_A": "uint8" if is_int8 else "uint16",
        "in_dtype_B": "uint8",
        "in_dtype_Bias": "int16" if is_int8 else "int32",
        "out_dtype_Y": "uint8" if is_int8 else "uint16",
    }

    # Pin the root to an absolute path before changing directory, so that
    # cleaning, building and HW validation all address the same location.
    output_root = os.path.abspath(str(output_root))
    print(output_root)
    clean_output_dir(output_root, clean)

    l2_dataflow = dataflow_type == DataflowType.L2

    with change_dir("../"):
        os.environ["LOG_ENABLED"] = "true"
        for idx, shape in enumerate(shape_table):
            l2_alloc = L2_alloc_for_unit_test(shape)
            # Work on a per-shape copy: the module-level extra_table must
            # not be modified as a side effect of a build.
            op_extra = {
                **extra_table,
                "enable_over_compute": 0 if l2_dataflow else 1,
                "group": shape[2],  # Make group == Cin == Cout to map to DWC flavor conv
            }
            merged_shape = {
                "input": shape[0:3],
                "output": shape[3:6],
                "kernel": shape[6:8],
                "stride": shape[8:10],
                "pad": shape[10:12],
                "enable_L2_fusion": l2_dataflow,
                **op_extra, **l2_alloc, **wgt_fmt,
                "dataflow_type": (dataflow_type.name, ),
                "op": op_type,
            }
            print(f"Building dwc with shape: {shape}")
            compile_operator(
                op_type,
                merged_shape,
                target,
                output_root,
                # We build PDI only for the first shape to save time
                gen_standalone_pdi=is_cert_backend and idx == 0,
                gen_op_elf=is_cert_backend,
                dump_vcd=vcd,
            )
            # Look for the build where compile_operator() was told to put it,
            # not under a hard-coded "Output" directory.
            build_dir = os.path.join(
                output_root, f"op_{op_type}_{out_dir_name_from_dict(merged_shape)}"
            )
            if target == BuildTarget.SIM:
                sim_log = os.path.join(build_dir, 'AIESimulator.log')
                # Use the loop index so duplicate shapes keep separate slots.
                process_simulation_results(sim_log, idx,
                                           results_list, simtime_list)
            elif target == BuildTarget.CERT:
                create_hw_package(build_dir)

        # Write results to CSV
        if target == BuildTarget.SIM:
            fieldnames = ['conv_type', 'Yi', 'Xi', 'Ci', 'Yo', 'Xo', 'Co', 'Ky', 'Kx', 'Sy', 'Sx', 'Py', 'Px']
            csv_file_name = os.path.join(CURRDIR, 'conv_aiesim_results.csv')
            # Prefix every row with the conv type to match `fieldnames`
            csv_rows = [[data_type] + shape for shape in shape_table]
            write_csv(csv_rows,
                      results_list,
                      simtime_list,
                      csv_file_name,
                      fieldnames,
                      default_row_mapper(fieldnames))

    if hwtest:
        run_hw_validation(
            output_root,
            dtype="int16",
            filter_patterns="op_Conv*",  # optional filtering, can be removed
            debug=False,
        )


# Script entry point: Typer builds the command-line interface from main()'s
# signature and invokes it with the parsed options.
if __name__ == '__main__':
    typer.run(main)
