"""
Regression test for the unitary ops
"""

import os
import struct
from typing import Optional
import typer
import pytest

from buildtest.common import (
    change_dir,
    create_hw_package,
    BuildTarget,
    process_simulation_results,
    write_csv,
    default_row_mapper,
    run_hw_validation,
    clean_output_dir,
    Counter
)
from build_aie4 import compile_operator, out_dir_name_from_dict, get_combined_kernel_list
from scheduler.uniop.uniop_common import op_mapping
from graph.utilities import config_logger_from_env
from utils.build_utils import set_datatype

# Absolute directory containing this test file; used as the base for the
# default output root and for the simulation-results CSV.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Import-time side effect: configures logging (presumably from environment
# variables, judging by the helper's name -- confirm in graph.utilities).
config_logger_from_env()
# Project Counter from buildtest.common (not collections.Counter). main() calls
# it and requests a standalone PDI only when the call returns 1.
# NOTE(review): presumably it increments on every call -- confirm.
pdi_counter = Counter()

# Dequantize test shapes. Each row is [N, Y, C, op, overlay]; main() turns a
# row into an [N, Y, 1, C] input/output tensor for the given op.
# NOTE(review): [1, 1, 768] appears twice; every row (including repeats) is a
# separate parametrized case, and row order defines the shape_index values.
DQ_TESTS = [
    [1,     1,    64, "dequant", "3x4"],
    [1,    63,    64, "dequant", "3x4"],
    [1,    64,    63, "dequant", "3x4"],
    [1,    64,    64, "dequant", "3x4"],
    [1,   128,   128, "dequant", "3x4"],
    [1,     1,   768, "dequant", "3x4"],
    [1,     1,  2048, "dequant", "3x4"],
    [1,     1,  3072, "dequant", "3x4"],
    [1,     1,  9216, "dequant", "3x4"],
    [1,     1, 32128, "dequant", "3x4"],
    [1,     1,     2, "dequant", "3x4"],
    [1, 12544,    32, "dequant", "3x4"],
    [1,  6272,    32, "dequant", "3x4"],
    [1,     1,    10, "dequant", "3x4"],
    [1,   256,   640, "dequant", "3x4"],
    [1,   256,  1280, "dequant", "3x4"],
    [1,  3136,    32, "dequant", "3x4"],
    [1,  1024,   320, "dequant", "3x4"],
    [1,  1024,   640, "dequant", "3x4"],
    [1,  1568,    32, "dequant", "3x4"],
    [1,    49,   768, "dequant", "3x4"],
    [1,    49,  1024, "dequant", "3x4"],
    [1,    50,   768, "dequant", "3x4"],
    [1,  4096,     4, "dequant", "3x4"],
    [1,  4096,   320, "dequant", "3x4"],
    [1,    64,  2048, "dequant", "3x4"],
    [1,    64,  3072, "dequant", "3x4"],
    [1,    64,  9216, "dequant", "3x4"],
    [1,    77,   768, "dequant", "3x4"],
    [1,    77,  1280, "dequant", "3x4"],
    [1,     1,   100, "dequant", "3x4"],
    [1,   128,   768, "dequant", "3x4"],
    [1,   196,   768, "dequant", "3x4"],
    [1,   197,   768, "dequant", "3x4"],
    [1,     1,   512, "dequant", "3x4"],
    [1,   512,     1, "dequant", "3x4"],
    [1,     1,   768, "dequant", "3x4"],
    [1,     1,  1000, "dequant", "3x4"],
    [1,    10,     1, "dequant", "3x4"],
    [1,   770,    77, "dequant", "3x4"],
    [1,   770,   512, "dequant", "3x4"],
    [1,    10,   512, "dequant", "3x4"],
]

# Quantize test shapes. Each row is [N, Y, C, op, overlay]; main() turns a row
# into an [N, Y, 1, C] input/output tensor for the given op.
# NOTE(review): [1, 1, 3072] and [1, 64, 3072] each appear twice; every row
# (including repeats) is a separate parametrized case.
Q_TESTS = [
    [1,     1,     1, "quant", "3x4"],
    [1,     1,    64, "quant", "3x4"],
    [1,     1,   128, "quant", "3x4"],
    [1,     1,  3072, "quant", "3x4"],
    [1,    63,    64, "quant", "3x4"],
    [1,    64,    63, "quant", "3x4"],
    [1,    64,    64, "quant", "3x4"],
    [1,    64,  3072, "quant", "3x4"],
    [1,   128,   128, "quant", "3x4"],
    [1,     1,  1024, "quant", "3x4"],
    [1,     1,  1536, "quant", "3x4"],
    [1,     1,  2048, "quant", "3x4"],
    [1,     1,  3072, "quant", "3x4"],
    [1,   672,   224, "quant", "3x4"],
    [1, 12544,    32, "quant", "3x4"],
    [1,  6272,    32, "quant", "3x4"],
    [1,   256,   640, "quant", "3x4"],
    [1,   256,  1280, "quant", "3x4"],
    [1,  3136,    32, "quant", "3x4"],
    [1,  1024,   320, "quant", "3x4"],
    [1,  1024,   640, "quant", "3x4"],
    [1,  1568,    32, "quant", "3x4"],
    [1,    50,   768, "quant", "3x4"],
    [1,  4096,     4, "quant", "3x4"],
    [1,  4096,     5, "quant", "3x4"],
    [1,  4096,   320, "quant", "3x4"],
    [1,    64,  1024, "quant", "3x4"],
    [1,    64,  1536, "quant", "3x4"],
    [1,    64,  2048, "quant", "3x4"],
    [1,    64,  3072, "quant", "3x4"],
    [1,     1,    77, "quant", "3x4"],
    [1,    77,   768, "quant", "3x4"],
    [1,    77,  1280, "quant", "3x4"],
    [1,    77,  2048, "quant", "3x4"],
    [1,   197,   768, "quant", "3x4"],
    [1,     1,   512, "quant", "3x4"],
    [1, 262144,    3, "quant", "3x4"],
    [1,   512,   768, "quant", "3x4"],
    [1,     1,  1280, "quant", "3x4"],
    [1,   770,    77, "quant", "3x4"],
]

# Layernorm test shapes. Each row is [N, Y, C, op, overlay]; main() turns a
# row into an [N, Y, 1, C] input/output tensor for the given op.
LRN_TESTS = [
    [1,    49,  1024, "layernorm", "3x4"],
    [1,    50,   768, "layernorm", "3x4"],
    [1,    77,   768, "layernorm", "3x4"],
    [1,    77,  1280, "layernorm", "3x4"],
    [1,   128,   768, "layernorm", "3x4"],
    [1,   196,   512, "layernorm", "3x4"],
    [1,   197,   768, "layernorm", "3x4"],
    [1,   256,  1280, "layernorm", "3x4"],
    [1,   512,   768, "layernorm", "3x4"],
    [1,     1,   768, "layernorm", "3x4"],
    [1,   784,   256, "layernorm", "3x4"],
    [1,     1,  1024, "layernorm", "3x4"],
    [1,  1024,   640, "layernorm", "3x4"],
    [1,  3136,   128, "layernorm", "3x4"],
    [1,   770,   512, "layernorm", "3x4"],
]

# L2-norm test shapes. Each row is [N, Y, C, op, overlay]; main() turns a row
# into an [N, Y, 1, C] input/output tensor for the given op.
# NOTE(review): [1, 1, 3072] and [1, 64, 3072] each appear twice; every row
# (including repeats) is a separate parametrized case.
LPN_TESTS = [
    [1,     1,  3072, "l2norm", "3x4"],
    [1,    64,  3072, "l2norm", "3x4"],
    [1,     1,  1024, "l2norm", "3x4"],
    [1,     1,  1536, "l2norm", "3x4"],
    [1,     1,  2048, "l2norm", "3x4"],
    [1,     1,  3072, "l2norm", "3x4"],
    [1,    64,  1024, "l2norm", "3x4"],
    [1,    64,  1536, "l2norm", "3x4"],
    [1,    64,  2048, "l2norm", "3x4"],
    [1,    64,  3072, "l2norm", "3x4"],
    [1,     1,   100, "l2norm", "3x4"],
    [1,     1,   512, "l2norm", "3x4"],
    [1,     1,   768, "l2norm", "3x4"],
    [1,    10,   512, "l2norm", "3x4"],
]

# Groupnorm test shapes. Each row is [N, Y, C, op, overlay]; main() turns a
# row into an [N, Y, 1, C] input/output tensor for the given op.
GPN_TESTS = [
    [1,   256,   640, "groupnorm", "3x4"],
    [1,   256,  1280, "groupnorm", "3x4"],
    [1,   256,  1920, "groupnorm", "3x4"],
    [1,   256,  2560, "groupnorm", "3x4"],
    [1,  1024,   320, "groupnorm", "3x4"],
    [1,  1024,   640, "groupnorm", "3x4"],
    [1,  1024,   960, "groupnorm", "3x4"],
    [1,  1024,  1280, "groupnorm", "3x4"],
    [1,  1024,  1920, "groupnorm", "3x4"],
    [1,  4096,   320, "groupnorm", "3x4"],
    [1,  4096,   640, "groupnorm", "3x4"],
    [1,  4096,   960, "groupnorm", "3x4"],
]

# Softmax test shapes. Each row is [N, Y, C, op, overlay]; main() turns a row
# into an [N, Y, 1, C] input/output tensor for the given op.
# NOTE(review): several rows repeat (e.g. [1, 10240, 77], [1, 10240, 1024],
# [1, 924, 77], [1, 2364, 197], [1, 5120, 77], [1, 5120, 256]); every row
# (including repeats) is a separate parametrized case.
SFMX_TESTS = [
    [1,     1,    64, "softmax", "3x4"],
    [1,    64,    64, "softmax", "3x4"],
    [1,   128,    32, "softmax", "3x4"],
    [1,   256,    32, "softmax", "3x4"],
    [1, 10240,    77, "softmax", "3x4"],
    [1, 10240,  1024, "softmax", "3x4"],
    [1,   924,    77, "softmax", "3x4"],
    [1,  1536,   128, "softmax", "3x4"],
    [1,  2364,   197, "softmax", "3x4"],
    [1,  6144,   512, "softmax", "3x4"],
    [1,   512,    32, "softmax", "3x4"],
    [1,  5120,    77, "softmax", "3x4"],
    [1,  5120,   256, "softmax", "3x4"],
    [1,  1024,    32, "softmax", "3x4"],
    [1,  1568,    49, "softmax", "3x4"],
    [1,  3136,    49, "softmax", "3x4"],
    [1, 10240,    77, "softmax", "3x4"],
    [1, 10240,  1024, "softmax", "3x4"],
    [1,   600,    50, "softmax", "3x4"],
    [1,   924,    77, "softmax", "3x4"],
    [1,  2364,   197, "softmax", "3x4"],
    [1,  6272,    49, "softmax", "3x4"],
    [1,  1540,    77, "softmax", "3x4"],
    [1,  5120,    77, "softmax", "3x4"],
    [1,  5120,   256, "softmax", "3x4"],
    [1, 12544,    49, "softmax", "3x4"],
    [1,  6160,    77, "softmax", "3x4"],
]

# SiLU test shapes. Each row is [N, Y, C, op, overlay]; main() turns a row
# into an [N, Y, 1, C] input/output tensor for the given op.
# NOTE(review): [1, 1, 8192] and [1, 64, 8192] each appear twice; every row
# (including repeats) is a separate parametrized case.
SILU_TESTS = [
    [1,     1,  8192, "silu", "3x4"],
    [1,    64,  8192, "silu", "3x4"],
    [1,     1,  8192, "silu", "3x4"],
    [1,     1,  8960, "silu", "3x4"],
    [1,   256,   640, "silu", "3x4"],
    [1,   256,  1280, "silu", "3x4"],
    [1,   256,  1920, "silu", "3x4"],
    [1,   256,  2560, "silu", "3x4"],
    [1,  1024,   320, "silu", "3x4"],
    [1,  1024,   640, "silu", "3x4"],
    [1,  1024,   960, "silu", "3x4"],
    [1,  1024,  1280, "silu", "3x4"],
    [1,  1024,  1920, "silu", "3x4"],
    [1,  4096,    16, "silu", "3x4"],
    [1,  4096,    32, "silu", "3x4"],
    [1,  4096,    96, "silu", "3x4"],
    [1,  4096,   256, "silu", "3x4"],
    [1,  4096,   320, "silu", "3x4"],
    [1,  4096,   640, "silu", "3x4"],
    [1,  4096,   960, "silu", "3x4"],
    [1,    64,  8192, "silu", "3x4"],
    [1,    64,  8960, "silu", "3x4"],
    [1, 16384,    96, "silu", "3x4"],
    [1, 65536,    32, "silu", "3x4"],
    [1, 262144,   16, "silu", "3x4"],
    [1,     1,  1280, "silu", "3x4"],
]

# GELU test shapes. Each row is [N, Y, C, op, overlay]; main() turns a row
# into an [N, Y, 1, C] input/output tensor for the given op.
GELU_TESTS = [
    [1,     1,  4096, "gelu", "3x4"],
    [1,    49,  4096, "gelu", "3x4"],
    [1,    50,  3072, "gelu", "3x4"],
    [1,    64,  4096, "gelu", "3x4"],
    [1,    77,  3072, "gelu", "3x4"],
    [1,    77,  5120, "gelu", "3x4"],
    [1,   128,  3072, "gelu", "3x4"],
    [1,   196,  2048, "gelu", "3x4"],
    [1,   197,  3072, "gelu", "3x4"],
    [1,   256,  5120, "gelu", "3x4"],
    [1,   512,  3072, "gelu", "3x4"],
    [1,   784,  1024, "gelu", "3x4"],
    [1,  1024,  2560, "gelu", "3x4"],
    [1,  3136,   512, "gelu", "3x4"],
    [1,   770,  2048, "gelu", "3x4"],
]

# Sigmoid test shapes. Each row is [N, Y, C, op, overlay]; main() turns a row
# into an [N, Y, 1, C] input/output tensor for the given op.
SIGMOID_TESTS = [
    [1,     1,     1, "sigmoid", "3x4"],
    [1,   924,     2, "sigmoid", "3x4"],
    [1,  6144,     2, "sigmoid", "3x4"],
    [1,   512,     1, "sigmoid", "3x4"],
]

# Swish test shapes. Each row is [N, Y, C, op, overlay]; main() turns a row
# into an [N, Y, 1, C] input/output tensor for the given op.
SWISH_TESTS = [
    [1,    77,  3072, "swish", "3x4"],
]

# Tanh test shapes. Each row is [N, Y, C, op, overlay]; main() turns a row
# into an [N, Y, 1, C] input/output tensor for the given op.
TANH_TESTS = [
    [1,     1,   768, "tanh", "3x4"],
]


# Aggregate tables built by list concatenation. UNIOP_TESTS is the table that
# test_uniop parametrizes over and that main() indexes with shape_index, so the
# concatenation order below fixes which shape each index refers to.
QDQ_TESTS = DQ_TESTS + Q_TESTS
NORM_OP_TESTS = LRN_TESTS + LPN_TESTS + GPN_TESTS + SFMX_TESTS
LUT_OP_TESTS = SILU_TESTS + GELU_TESTS + SIGMOID_TESTS + SWISH_TESTS + TANH_TESTS
UNIOP_TESTS = QDQ_TESTS + NORM_OP_TESTS + LUT_OP_TESTS
# Debug override: uncomment to restrict the run to the LUT-based ops only.
# UNIOP_TESTS = LUT_OP_TESTS


def to_int_samebin(f32_x):
    '''Reinterpret a float as the integer with the same 32-bit pattern.

    Args:
        f32_x: Value to encode as an IEEE-754 single-precision float.

    Returns:
        The unsigned integer whose 32 bits equal the float32 encoding of
        ``f32_x`` (for example ``1.0`` -> ``0x3F800000``).

    Raises:
        OverflowError: If ``f32_x`` is too large to be packed as a float32.
    '''
    # Pack explicitly as little-endian so the byte order agrees with the
    # little-endian decode below on every host; the bare "f" format uses the
    # native byte order and would give a byte-swapped result on big-endian
    # machines.
    binary_representation = struct.pack("<f", f32_x)
    return int.from_bytes(binary_representation, byteorder="little")


@pytest.mark.dma
@pytest.mark.parametrize("shape_index", range(len(UNIOP_TESTS)))
def test_uniop(shape_index: int, target: BuildTarget, hwtest: bool, clean: bool, output_root: str) -> None:
    '''Build (and optionally validate) the single UNIOP_TESTS entry at ``shape_index``.

    ``target``, ``hwtest``, ``clean`` and ``output_root`` are supplied by pytest
    (presumably fixtures/options defined in a conftest -- not visible here) and
    are forwarded unchanged to ``main``.
    '''
    # main() is a Typer command: any parameter omitted here would keep its
    # ``typer.Option(...)`` placeholder as value, and that placeholder object
    # is truthy. Pass the flag options explicitly so VCD dumping and FP16 mode
    # stay at their declared default (False) under pytest.
    main(
        target=target,
        shape_index=shape_index,
        vcd=False,
        hwtest=hwtest,
        clean=clean,
        output_root=output_root,
        fp_16=False,
    )


def _option_default(value):
    """Return the declared default if ``value`` is an unresolved Typer option, else ``value``."""
    # When main() is called as a plain function (e.g. from pytest) instead of
    # through typer.run, omitted parameters still hold their OptionInfo
    # placeholder, which is truthy and must not be used as the real value.
    if isinstance(value, typer.models.OptionInfo):
        return value.default
    return value


def _op_specific_config(op: str) -> dict:
    """Return the quantization parameters, dtypes and attributes used for ``op``.

    A fresh dict is built on every call. Key order is kept stable per op
    because the result is later passed to ``out_dir_name_from_dict``.
    NOTE(review): ``disable_dq0`` / ``disable_q`` presumably switch off the
    dequantize / quantize stage of the kernel -- confirm against the scheduler.
    """
    if op in {"softmax", "l2norm"}:
        return {
            "dequant_zero_point": 45875,
            "dequant_scale": 0.35,
            "quant_zero_point": 2,
            "quant_scale": 0.0000152659,
            "in_dtype_X": "uint16",
            "out_dtype_Y": "uint16",
            "attributes": {
                "disable_dq0": [0],
                "disable_q": [0],
            },
        }
    if op in {"swish", "tanh", "sigmoid", "elu"}:
        return {
            "dequant_zero_point": 6724,
            "dequant_scale": 0.0039354744,
            "quant_zero_point": 46,
            "quant_scale": 0.0035341859,
            "in_dtype_X": "float16",
            "out_dtype_Y": "float16",
            "attributes": {
                "disable_dq0": [1],
                "disable_q": [1],
            },
        }
    if op == "dequant":
        return {
            "dequant_zero_point": 5,
            "dequant_scale": 2,
            "quant_zero_point": 3,
            "quant_scale": 0.000015259 * 15,
            "attributes": {
                "disable_dq0": [0],
                "disable_q": [1],
            },
            "in_dtype_X": "uint16",
            "out_dtype_Y": "float32",
        }
    if op == "quant":
        return {
            "dequant_zero_point": 5,
            "dequant_scale": 2,
            "quant_zero_point": 3,
            "quant_scale": 0.000015259 * 15,
            "attributes": {
                "disable_dq0": [1],
                "disable_q": [0],
            },
            "in_dtype_X": "float32",
            "out_dtype_Y": "uint16",
        }
    if op == "layernorm":
        return {
            "dequant_zero_point": 5,
            "dequant_scale": 2,
            "quant_zero_point": 3,
            "quant_scale": 0.000015259 * 15,
            "attributes": {
                "disable_dq0": [0],
                "disable_q": [0],
            },
            "in_dtype_X": "uint16",
            "out_dtype_Y": "uint16",
            "in_dtype_gamma": "int32",
            "in_dtype_beta": "int32",
        }
    if op == "groupnorm":
        return {
            "dequant_zero_point": 54546,
            "dequant_scale": 0.0010424006031826138,
            "quant_zero_point": 39880,
            "quant_scale": 0.00028025248320773244,
            "attributes": {
                "disable_dq0": [0],
                "disable_q": [0],
            },
            "in_dtype_X": "uint16",
            "out_dtype_Y": "uint16",
            "in_dtype_gamma": "int32",
            "in_dtype_beta": "int32",
        }
    # Every remaining op (e.g. silu, gelu) runs float16 in and out with both
    # the dequantize and quantize stages disabled.
    return {
        "dequant_zero_point": 5,
        "dequant_scale": 2,
        "quant_zero_point": 3,
        "quant_scale": 0.000015259 * 15,
        "attributes": {
            "disable_dq0": [1],
            "disable_q": [1],
        },
        "in_dtype_X": "float16",
        "out_dtype_Y": "float16",
    }


def main(
    target: BuildTarget = typer.Option(default=BuildTarget.DATAFLOW, help="Build target for the operator"),
    shape_index: Optional[int] = typer.Option(default=None, help="Index of the shape to test"),
    vcd: bool = typer.Option(default=False, help="Dump VCD trace", is_flag=True),
    hwtest: bool = typer.Option(default=False, help="Run HW_test after builds", is_flag=True),
    clean: bool = typer.Option(default=False, help="Clean output directory before running", is_flag=True),
    output_root: str = typer.Option(default=os.path.join(CURRDIR, "..", "Output"), help="Root directory for output"),
    fp_16: bool = typer.Option(default=False, help="Is QDQ FP16 Datatype?", is_flag=True),
) -> None:
    """Build the unary-op regression shapes and optionally simulate/validate them.

    Args:
        target: Build target handed to ``compile_operator``. ``SIM`` additionally
            collects simulator results into a CSV; ``CERT`` additionally creates
            a hardware package per shape.
        shape_index: Index into ``UNIOP_TESTS`` of the single shape to build, or
            ``None`` to build every shape.
        vcd: Forwarded to ``compile_operator`` as ``dump_vcd``.
        hwtest: When true, run ``run_hw_validation`` on ``output_root`` after all
            builds have finished.
        clean: Forwarded to ``clean_output_dir`` together with ``output_root``.
        output_root: Root directory that receives one build directory per shape.
        fp_16: Forwarded to ``set_datatype`` before each build.

    Raises:
        IndexError: If ``shape_index`` is outside ``UNIOP_TESTS``.
        KeyError: If a table row names an op missing from ``op_mapping``.
        AssertionError: If a "1x1" overlay row has an unsupported shape.
    """
    # Resolve Typer placeholders so the function behaves the same whether it is
    # invoked through typer.run or called directly with some arguments omitted.
    target = _option_default(target)
    shape_index = _option_default(shape_index)
    vcd = _option_default(vcd)
    hwtest = _option_default(hwtest)
    clean = _option_default(clean)
    fp_16 = _option_default(fp_16)
    output_root = str(_option_default(output_root))

    is_cert_backend = target == BuildTarget.CERT
    # The counter is only consulted for CERT builds; a standalone PDI is
    # requested only when that call returns 1.
    gen_pdi = is_cert_backend and pdi_counter() == 1

    if shape_index is not None:
        shape_table = [UNIOP_TESTS[shape_index]]
    else:
        shape_table = UNIOP_TESTS

    # Per-shape result slots, filled in by process_simulation_results.
    results_list = [''] * len(shape_table)
    simtime_list = [0.0] * len(shape_table)
    print(output_root)
    clean_output_dir(output_root, clean)

    combined_kernel_includes, combined_kernel_names = get_combined_kernel_list("build_uniop.py")

    with change_dir("../"):
        # enumerate() gives each row its own slot; looking the row up with
        # list.index() would map duplicate rows onto the first occurrence.
        for test_idx, test in enumerate(shape_table):
            N, Y, C = test[0:3]
            TensorDim = (N, Y, C)
            op = test[3]
            overlay_is_1x1 = test[4] == "1x1"
            test_dict = {
                "input": [N, Y, 1, C],
                "output": [N, Y, 1, C],
                "op": op_mapping[op][0],
                "overlay_is_1x1": overlay_is_1x1,
                "hasMask": overlay_is_1x1,
                "gen_io": True,
                "name": f"{op}_shape_{N}_{Y}_{C}"
            }
            # Op-specific keys first, then the shape keys -- same key order as
            # the original inline dict literals.
            merged_dict = {**_op_specific_config(op), **test_dict}

            if overlay_is_1x1:
                assert TensorDim in {(1, 4, 2048), (1, 4, 256), (1, 8, 1024)}
            os.environ["LOG_ENABLED"] = "true"
            set_datatype(fp_16)
            compile_operator(op_mapping[op][0],
                             merged_dict,
                             target,
                             output_root,
                             combined_kernel_includes=combined_kernel_includes,
                             combined_kernel_names=combined_kernel_names,
                             gen_standalone_pdi=gen_pdi,
                             gen_op_elf=is_cert_backend,
                             dump_vcd=vcd,
                             )

            build_dir = os.path.join(output_root, f"op_{op_mapping[op][0]}_shape_{out_dir_name_from_dict(merged_dict)}")
            if target == BuildTarget.SIM:
                sim_log = os.path.join(build_dir, 'AIESimulator.log')
                process_simulation_results(sim_log, test_idx,
                                           results_list, simtime_list)
            elif target == BuildTarget.CERT:
                create_hw_package(build_dir)
            # Only the first shape of a run may request the standalone PDI.
            gen_pdi = False

    # Write results to CSV
    if target == BuildTarget.SIM:
        fieldnames = ['N', 'Y', 'C', 'op', 'overlay']
        csv_file_name = os.path.join(CURRDIR, 'uniop_aiesim_results.csv')
        write_csv(shape_table, results_list,
                  simtime_list, csv_file_name,
                  fieldnames, default_row_mapper(fieldnames))

    if hwtest:
        # NOTE: the validation dtype is derived from the LAST shape that was
        # built (``op`` / ``merged_dict`` leak out of the loop above).
        if op == "dequant":
            dtype = "fp32" if merged_dict["out_dtype_Y"] == "float32" else "bf16"
        else:
            dtype = "uint16"
        run_hw_validation(
            output_root,
            dtype=dtype,
            debug=False,
            host="10.228.200.219"
        )


if __name__ == "__main__":
    typer.run(main)
