'''
Regression test for the maxpool operator
'''
import os
from typing import Optional
import typer
import pytest

from common import (
    process_simulation_results,
    create_hw_package,
    write_csv,
    default_row_mapper,
    change_dir,
    BuildTarget,
    run_hw_validation,
    clean_output_dir,
    Counter
)
from build_aie4 import compile_operator, out_dir_name_from_dict
from graph.utilities import config_logger_from_env

# Absolute directory containing this test script; used for default output
# and CSV locations.
_THIS_FILE = os.path.abspath(__file__)
CURRDIR = os.path.dirname(_THIS_FILE)

# Configure logging from the environment once, at import time.
config_logger_from_env()

# Module-wide counter consulted by main() when deciding whether to request a
# standalone PDI.
pdi_counter = Counter()


# One row per maxpool test case.  Column order matches the CSV header written
# by main(): Yi, Xi, Ci, Yo, Xo, Co, Ky, Kx, Sy, Sx, Py, Px, i.e.
# input = row[0:3], output = row[3:6], kernel = row[6:8],
# stride = row[8:10], pad = row[10:12].
MAXPOOL_SHAPES = [
    [
        112, 112, 64,   # input
        56, 56, 64,     # output
        3, 3,           # kernel
        2, 2,           # stride
        1, 1,           # pad
    ],
]

# Address constants used by extra_table below.
# NOTE(review): these look like byte offsets into DDR — confirm against the
# build scripts.
_MIB = 2**20
_WGT_BASE = 2 * _MIB                # first weight address
_WGT_NEXT = _WGT_BASE + 262144      # second weight address
_PRM_BASE = _WGT_NEXT + 262144      # parameter address

# Fixed attributes that main() merges into every per-shape dictionary before
# handing it to compile_operator.  The same weight and parameter addresses
# are repeated for indices 0, 1 and 2.
extra_table = {
    "MappingRank": 0,
    "sign_A": 0,
    "dtype": 1,
    "input_addr": {1: 0},
    "output_addr": {1: _MIB},
    "wgt_addr": [[idx, [_WGT_BASE, _WGT_NEXT]] for idx in range(3)],
    "prm_addr": [[idx, _PRM_BASE] for idx in range(3)],
    "load_input_from_ddr": True,
    "store_output_to_ddr": True,
    "L3": {
        "ifm": [1, 0],
        "ofm": [0, 0],
    },
}


@pytest.mark.dma
@pytest.mark.parametrize("shape_index", range(len(MAXPOOL_SHAPES)))
def test_matpool(shape_index: int, target: BuildTarget, hwtest: bool, clean: bool, output_root: str) -> None:
    '''Run the maxpool regression for one entry of MAXPOOL_SHAPES.

    shape_index is parametrized over every row of MAXPOOL_SHAPES.  The other
    arguments are forwarded unchanged to main(); presumably they are supplied
    by pytest fixtures/options defined outside this file.
    '''
    # vcd must be passed explicitly: when main() is called as a plain Python
    # function, an omitted argument takes its signature default, which is the
    # typer.Option(...) object itself (truthy), not the boolean False.
    main(target=target, shape_index=shape_index, vcd=False, hwtest=hwtest, clean=clean, output_root=output_root)


def main(
    target: BuildTarget = typer.Option(default=BuildTarget.DATAFLOW, help="Build target for the operator"),
    shape_index: Optional[int] = typer.Option(default=None, help="Index of the shape to test"),
    vcd: bool = typer.Option(default=False, help="Dump VCD trace", is_flag=True),
    hwtest: bool = typer.Option(default=False, help="Run HW_test after builds", is_flag=True),
    clean: bool = typer.Option(default=False, help="Clean output directory before running", is_flag=True),
    output_root: str = typer.Option(default=os.path.join(CURRDIR, "..", "Output"), help="Root directory for output")
) -> None:
    '''Build the maxpool operator for the selected shapes and collect results.

    Args:
        target: Build target passed to compile_operator. SIM additionally
            parses the simulator log and writes a CSV summary; CERT
            additionally creates a HW package per shape.
        shape_index: Index into MAXPOOL_SHAPES of the single shape to run,
            or None to run every shape.
        vcd: Forwarded to compile_operator as dump_vcd.
        hwtest: When true, run_hw_validation is invoked on output_root after
            all builds have finished.
        clean: Forwarded to clean_output_dir together with output_root.
        output_root: Root directory that receives one build directory per
            shape.

    Raises:
        IndexError: If shape_index is outside MAXPOOL_SHAPES.
    '''
    if shape_index is not None:
        shape_table = [MAXPOOL_SHAPES[shape_index]]
    else:
        shape_table = MAXPOOL_SHAPES
    is_cert_backend = target == BuildTarget.CERT
    is_sim_backend = target == BuildTarget.SIM
    # A standalone PDI is requested only for CERT builds, and only when the
    # module-wide counter returns 1; it is cleared after the first shape below.
    gen_pdi = is_cert_backend and pdi_counter() == 1
    # Per-shape result slots, filled in by process_simulation_results.
    results_list = [''] * len(shape_table)
    simtime_list = [0.0] * len(shape_table)
    output_root = str(output_root)
    clean_output_dir(output_root, clean)

    with change_dir("../"):
        # enumerate() gives each row its own slot even when two rows hold
        # equal values; list.index() would always return the first match.
        for row_index, shape in enumerate(shape_table):
            merged_shape = {
                "input": shape[0:3],
                "output": shape[3:6],
                "kernel": shape[6:8],
                "stride": shape[8:10],
                "pad": shape[10:12],
                **extra_table
                }
            os.environ["LOG_ENABLED"] = "true"
            compile_operator("maxpool_noqdq_a8",
                             merged_shape,
                             target,
                             output_root,
                             gen_standalone_pdi=gen_pdi,
                             gen_op_elf=is_cert_backend,
                             dump_vcd=vcd,
                             )
            build_dir = os.path.join(output_root, f"op_maxpool_noqdq_a8_shape_{out_dir_name_from_dict(merged_shape)}")
            if is_sim_backend:
                sim_log = os.path.join(build_dir, 'AIESimulator.log')
                process_simulation_results(sim_log, row_index,
                                           results_list, simtime_list)
            elif is_cert_backend:
                create_hw_package(build_dir)
            # Only the first shape of a run may request the standalone PDI.
            gen_pdi = False

        # Write results to CSV
        if is_sim_backend:
            fieldnames = ['Yi', 'Xi', 'Ci', 'Yo', 'Xo', 'Co', 'Ky', 'Kx', 'Sy', 'Sx', 'Py', 'Px']
            csv_file_name = os.path.join(CURRDIR, 'maxpool_aiesim_results.csv')
            write_csv(shape_table,
                      results_list,
                      simtime_list,
                      csv_file_name,
                      fieldnames,
                      default_row_mapper(fieldnames))

    if hwtest:
        run_hw_validation(
            output_root,
            dtype="int8",
        )


# Script entry point: hand main() to Typer, which runs it as a command-line
# program.
if __name__ == '__main__':
    typer.run(main)
