'''
Regression test for the binary (element-wise Add/Mul) layers
'''
import os
import math
from enum import Enum
from typing import Optional
# import pytest
import typer

from buildtest.common import (
    process_simulation_results,
    create_hw_package,
    write_csv,
    default_row_mapper,
    change_dir,
    BuildTarget,
    run_hw_validation,
    clean_output_dir,
    Counter
)
from build_aie4 import compile_operator, out_dir_name_from_dict, get_combined_kernel_list
from utils.build_utils import set_datatype
from utils.utils_common import iceil
from graph.utilities import config_logger_from_env

# Absolute directory of this script; main() derives its default output root from it.
CURRDIR = os.path.split(os.path.abspath(__file__))[0]
# Configure logging at import time (from the environment, per the helper's name).
config_logger_from_env()
# Module-level call counter consulted by main() when deciding on standalone PDI generation.
pdi_counter = Counter()


# Shapes exercised by the regression, one [N, Y, X, C] row per case, grouped by
# the model they come from. main() selects rows by position (shape_index), so
# the order of the rows is significant.
BINARY_SHAPES = [
    # ResNet 50
    [1, 56, 56, 256], [1, 28, 28, 512], [1, 14, 14, 1024], [1, 7, 7, 2048],
    # YOLOv3
    [1, 52, 52, 256], [1, 26, 26, 512], [1, 13, 13, 1024],
    # PSU
    [1, 1, 64, 3072], [1, 1, 64, 9216], [1, 1, 64, 16384],
    # PSJ
    [1, 12, 77, 77], [1, 1, 77, 768], [1, 12, 77, 1], [1, 1, 1, 768],
    # PSD
    [1, 12, 512, 1],
]

# Template of build parameters shared by every binary-layer build. main() layers
# per-build values on top of it (dtypes, a_on_wgt/b_on_wgt, input_types and the
# L3 layout), so the dtype and L3 entries below are only starting values.
extra_table = {
    # Address tables; the integer keys/indices are interpreted by
    # compile_operator and their meaning is not visible in this file.
    "input0_addr": {0: 6144},
    "input1_addr": {1: 6144},
    "output_addr": {2: 6144},
    "wgt_addr": [
        [0, [4096, 5120]],
        [1, [4096, 5120]],
        [2, [4096, 5120]],
    ],
    "prm_addr": [[0, 0], [1, 0], [2, 0]],
    # Inputs come from DDR and the output goes back to DDR (per the flag names).
    "load_input_from_ddr": True,
    "store_output_to_ddr": True,
    "L3": {
        "ifm0": [1, 0],
        "ifm1": [1, 0],
        "ofm": [0, 0],
    },
    "attributes": {
        "disable_dq0": [1],
        "disable_q": [1],
    },
    "in_dtype_A": "uint16",
    "in_dtype_B": "uint16",
    "out_dtype_C": "uint16",
}


class Ops(str, Enum):
    """Binary element-wise operators this regression can build.

    Mixing in ``str`` makes every member compare equal to its value
    (``Ops.ADD_8 == "add_8"``), which lets members double as CLI choices.
    Member order matters to anything that iterates the enum.
    """

    # Add with uint8 operands
    ADD_8 = "add_8"
    # Add with uint16 operands
    ADD_16 = "add_16"
    # Multiply with uint16 operands
    MUL_16 = "mul_16"


# @pytest.mark.dma
# @pytest.mark.parametrize("shape_index", range(len(BINARY_SHAPES)))
# @pytest.mark.parametrize("op", Ops)
# @pytest.mark.parametrize("b_on_wgt", [False, True])
# def test_binary(shape_index: int, op: str, b_on_wgt: bool, target: BuildTarget, hwtest: bool, clean: bool, output_root: str) -> None:
#     """Test binary operations"""
#     main(target=target, operator=op, shape_index=shape_index, b_on_wgt=b_on_wgt, hwtest=hwtest, clean=clean, output_root=output_root, a_on_wgt=False)


def operator_mapping(operator: str) -> str:
    """Translate a binary operator into the ONNX-style name used for the build.

    Args:
        operator: an ``Ops`` member, or its plain string value (``Ops`` is a
            ``str`` enum, so both hash and compare the same way).

    Returns:
        The ONNX operator name passed to ``compile_operator``.

    Raises:
        KeyError: if ``operator`` is not a supported ``Ops`` value.
    """
    onnx_name = {
        Ops.ADD_8: "Add_qdq_EleWise_uint8xuint8xuint8",
        Ops.ADD_16: "Add_qdq_EleWise_uint16xuint16xuint16",
        Ops.MUL_16: "Mul_qdq_EleWise_uint16xuint16xuint16",
    }.get(operator)
    if onnx_name is None:
        # Same exception (and argument) a plain dict subscript would raise.
        raise KeyError(operator)
    return onnx_name


def _layer_params(shape: list, operator: Ops, a_on_wgt: bool, b_on_wgt: bool, ifm_bytes: int) -> dict:
    """Return the per-build parameter dict, derived from a shallow copy of ``extra_table``.

    Copying keeps the module-level template untouched across builds and across
    repeated ``main()`` calls. Existing keys keep their template position and the
    new keys (``a_on_wgt``, ``b_on_wgt``, ``input_types``) are appended in that
    order. This is kept stable in case the build-directory name derived from the
    merged dict depends on key order.

    Args:
        shape: one ``shape_table`` row, ``[N, Y, X, C, ifm_bytes]``.
        operator: the binary operator being built.
        a_on_wgt / b_on_wgt: whether IFM A / IFM B is placed on the weight buffer.
        ifm_bytes: bytes per input element.
    """
    params = dict(extra_table)
    dtype = "uint16" if operator in (Ops.ADD_16, Ops.MUL_16) else "uint8"
    params["in_dtype_A"] = dtype
    params["in_dtype_B"] = params["out_dtype_C"] = dtype
    params["a_on_wgt"] = [int(a_on_wgt)]
    params["b_on_wgt"] = [int(b_on_wgt)]
    params["input_types"] = {"A": "const" if a_on_wgt else "act", "B": "const" if b_on_wgt else "act"}
    if a_on_wgt or b_on_wgt:
        # A single "ifm" L3 entry plus "wgt". Per the original note, the L3
        # addresses for the weight-buffer case are adjusted in build_broadcast.py.
        params["L3"] = {"ifm": [1, 0], "ofm": [0, 0], "wgt": [2, 0]}
    else:
        # Both activations share L3 buffer 1: IFM B starts right after IFM A.
        # IFM A's byte size uses iceil(C, 64) (presumably C rounded up to a
        # multiple of 64 — confirm against utils_common.iceil).
        ifm_a_bytes = math.prod(shape[:3]) * iceil(shape[3], 64) * ifm_bytes
        params["L3"] = {"ifm0": [1, 0], "ifm1": [1, ifm_a_bytes], "ofm": [0, 0], "wgt": [2, 0]}
    return params


def main(
    target: BuildTarget = typer.Option(default=BuildTarget.DATAFLOW, help="Build target for the operator"),
    operator: Ops = typer.Option(default=Ops.ADD_8, help="Binary operator to test"),
    shape_index: Optional[int] = typer.Option(default=None, help="Index of the shape to test"),
    a_on_wgt: bool = typer.Option(default=bool(False), help="Put IFM A on weight buffer", is_flag=True),
    b_on_wgt: bool = typer.Option(default=bool(False), help="Put IFM B on weight buffer", is_flag=True),
    vcd: bool = typer.Option(default=False, help="Dump VCD trace", is_flag=True),
    hwtest: bool = typer.Option(default=False, help="Run HW_test after builds", is_flag=True),
    clean: bool = typer.Option(default=False, help="Clean output directory before running", is_flag=True),
    output_root: str = typer.Option(default=os.path.join(CURRDIR, "..", "Output"), help="Root directory for output"),
    fp_16: bool = typer.Option(default=False, help="Is QDQ FP16 Datatype?", is_flag=True),
) -> None:
    '''Build the binary operator for each selected shape, then post-process per target.

    For the SIM target, the AIE simulator log of each build is parsed and a
    results CSV is written next to this script. For the CERT target, a HW
    package is created per build. When ``hwtest`` is set, HW validation runs
    over ``output_root`` afterwards.

    Args:
        target: build target passed to ``compile_operator``.
        operator: binary operator to build.
        shape_index: index into ``BINARY_SHAPES``; ``None`` builds every shape.
        a_on_wgt / b_on_wgt: place IFM A / IFM B on the weight buffer.
        vcd: dump a VCD trace during the build.
        hwtest: run HW validation after all builds.
        clean: clean ``output_root`` before building.
        output_root: root directory for build outputs.
        fp_16: forwarded to ``set_datatype`` before each build.
    '''
    operator_name = operator_mapping(operator)
    is_cert_backend = target == BuildTarget.CERT
    # pdi_counter() is only consulted for CERT builds. Presumably it returns an
    # incrementing call count (confirm), so only the first main() call can request
    # a standalone PDI, and only for its first build (gen_pdi is cleared below).
    gen_pdi = is_cert_backend and pdi_counter() == 1

    # Bytes per IFM element: 1 for the uint8 add, 2 for the uint16 ops.
    ifm_bytes = 1 if operator == Ops.ADD_8 else 2
    selected_shapes = BINARY_SHAPES if shape_index is None else [BINARY_SHAPES[shape_index]]
    # Append ifm_bytes to fresh copies so BINARY_SHAPES itself is never mutated.
    # Appending in place made the rows grow on every repeated main() call.
    shape_table = [list(shape) + [ifm_bytes] for shape in selected_shapes]

    # Initialize parameters
    results_list = [''] * len(shape_table)
    simtime_list = [0.0] * len(shape_table)
    # Resolve now: the builds run inside change_dir("../"), so a relative
    # output_root would otherwise not match the directory cleaned here and later
    # validated by run_hw_validation().
    output_root = os.path.abspath(str(output_root))
    clean_output_dir(output_root, clean)

    combined_kernel_includes, combined_kernel_names = get_combined_kernel_list("build_binary.py")

    with change_dir("../"):
        for shape_idx, shape in enumerate(shape_table):
            merged_shape = {
                "input": shape[:4],
                "output": shape[:4],
                "op": operator_name,
                **_layer_params(shape, operator, a_on_wgt, b_on_wgt, ifm_bytes),
            }
            print(f"Building {operator_name} with shape: {shape}")
            os.environ["LOG_ENABLED"] = "false"
            set_datatype(fp_16)
            compile_operator(operator_name,
                             merged_shape,
                             target,
                             output_root,
                             combined_kernel_includes=combined_kernel_includes,
                             combined_kernel_names=combined_kernel_names,
                             gen_standalone_pdi=gen_pdi,
                             gen_op_elf=is_cert_backend,
                             dump_vcd=vcd,
                             )
            build_dir = os.path.join(output_root, f"op_{operator_name}_shape_{out_dir_name_from_dict(merged_shape)}")
            if target == BuildTarget.SIM:
                sim_log = os.path.join(build_dir, 'AIESimulator.log')
                # Use the loop index directly. shape_table.index() was O(n) and
                # picked the wrong slot when two shapes compare equal.
                process_simulation_results(sim_log, shape_idx,
                                           results_list, simtime_list)
            elif target == BuildTarget.CERT:
                create_hw_package(build_dir)
            gen_pdi = False

    # Write results to CSV
    if target == BuildTarget.SIM:
        fieldnames = ['N', 'Y', 'X', 'C', 'ifm_bytes']
        csv_file_name = os.path.join(CURRDIR, 'binary_aiesim_results.csv')
        write_csv(shape_table, results_list,
                  simtime_list, csv_file_name,
                  fieldnames, default_row_mapper(fieldnames))

    if hwtest:
        run_hw_validation(
            output_root,
            dtype="int8" if operator == Ops.ADD_8 else "int16",
            debug=False,
        )


if __name__ == "__main__":
    # Run main() as a typer CLI; each typer.Option parameter becomes a command-line option.
    typer.run(main)
