"""
Regression test for the broadcast layers
"""

from enum import Enum
import os
import random
from typing import Optional
import pytest
import typer

from common import (
    process_simulation_results,
    create_hw_package,
    write_csv,
    default_row_mapper,
    change_dir,
    BuildTarget, DataflowType,
    run_hw_validation,
    clean_output_dir,
    Counter
)
from build_aie4 import compile_operator, out_dir_name_from_dict
from utils.build_utils import set_datatype
from utils.utils_common import log, iceil
from buildscripts.common import dtype_info

# Directory holding this script. It anchors the default output root and the
# CSV results path, independent of the current working directory.
CURRDIR = os.path.split(os.path.abspath(__file__))[0]
# Shared across main() calls in this process. main() generates a standalone PDI
# only when pdi_counter() returns 1 on a CERT build. Presumably that is the
# first such call; confirm against common.Counter.
pdi_counter = Counter()


class BroadcastOpType(str, Enum):
    """Binary broadcast operators, each bound to a single element dtype.

    Member values have the form ``"<op>_<dtype>"`` (e.g. ``"add_uint8"``). The
    ONNX operator name repeats that dtype for both inputs and the output.
    """

    ADD_UINT8 = "add_uint8"
    ADD_INT8 = "add_int8"
    ADD_UINT16 = "add_uint16"
    ADD_INT16 = "add_int16"
    MUL_UINT16 = "mul_uint16"
    MUL_INT16 = "mul_int16"
    MUL_UINT8 = "mul_uint8"
    MUL_INT8 = "mul_int8"
    SUB_INT8 = "sub_int8"
    SUB_UINT8 = "sub_uint8"
    SUB_INT16 = "sub_int16"
    SUB_UINT16 = "sub_uint16"
    DIV_INT8 = "div_int8"
    DIV_UINT8 = "div_uint8"
    DIV_INT16 = "div_int16"
    DIV_UINT16 = "div_uint16"

    @property
    def onnx_name(self) -> str:
        """Full ONNX operator name, e.g. ``"Add_qdq_BroadCast_uint8xuint8xuint8"``."""
        # Every member follows the same naming scheme, so derive the name
        # rather than keeping a lookup table in sync with the members.
        op, _, dtype = self.value.partition("_")
        return f"{op.capitalize()}_qdq_BroadCast_{dtype}x{dtype}x{dtype}"

    @property
    def friendly_name(self) -> str:
        """Human-readable operator name (the raw enum value)."""
        name: str = self.value
        return name

    @property
    def dtype(self) -> str:
        """Element dtype suffix of the value, e.g. ``"int16"`` for ``"sub_int16"``."""
        _, _, suffix = self.value.partition("_")
        return suffix


# NOTE: ALL UNCOMMENTED SHAPES SHOULD PASS DI
# Each row holds 12 ints: (N, Y, X, C) for input A, then input B, then the
# output. The "index" comments give a row's position in the final, concatenated
# BROADCAST_SHAPES list, which is the value passed as shape_index.
BROADCAST_SHAPES = [
    #  Na,Ya,Xa,Ca         Nb,Yb,Xb,Cb       No,Yo,Xo,Co
    # ===================================================
    [10, 8, 77, 77,      10, 1, 77, 77,   10, 8, 77, 77],
    [1, 12, 77, 77,      1, 1, 77, 77,    1, 12, 77, 77],
    [1, 20, 77, 77,      1, 1, 77, 77,     1, 20, 77, 77],
    [1, 16, 16, 1280,    1, 1, 1, 1280,   1, 16, 16, 1280],
    [1, 32, 32, 640,     1, 1, 1, 640,    1, 32, 32, 640],
    [1, 64, 64, 320,     1, 1, 1, 320,    1, 64, 64, 320],
    [1, 16, 16, 1280,    1, 1, 1, 1280,   1, 16, 16, 1280],
    [1, 32, 32, 640,     1, 1, 1, 640,    1, 32, 32, 640],
    [1, 8, 8, 1280,      1, 1, 1, 1280,   1, 8, 8, 1280],
    [1, 1, 64, 1024,     1, 1, 1, 1024,   1, 1, 64, 1024],
    [1, 12, 512, 1,      1, 1, 1, 1,    1, 12, 512, 1],
    # Other broadcast cases that aren't tested, start index 11
    [1, 1, 65, 65,       1, 1, 65, 1,    1, 1, 65, 65],
    [1, 1, 1, 65,       1, 1, 65, 1,    1, 1, 65, 65],
    [1, 1, 1, 65,       1, 1, 65, 65,    1, 1, 65, 65],
    [1, 1, 65, 65,      1, 1, 1, 65,    1, 1, 65, 65],
    [1, 1, 65, 65,       1, 1, 65, 65,    1, 1, 65, 65],
    [1, 1, 64, 64,      1, 1, 1, 64,    1, 1, 64, 64],
    [1, 1, 64, 64,       1, 1, 64, 64,    1, 1, 64, 64],
    [1, 64, 64, 1,       1, 64, 64, 1,     1, 64, 64, 1],
    [1, 1, 64, 1,       1, 64, 64, 1,     1, 64, 64, 1],
    [1, 64, 64, 1,       1, 1, 64, 1,     1, 64, 64, 1],   # start index 20
    [1, 1, 256, 256,       1, 1, 256, 1,    1, 1, 256, 256],
    [1, 1, 1, 256,       1, 1, 256, 1,    1, 1, 256, 256],
    [1, 1, 1, 256,       1, 1, 256, 256,    1, 1, 256, 256],
    [1, 1, 256, 256,      1, 1, 1, 256,    1, 1, 256, 256],
    [1, 1, 256, 256,       1, 1, 256, 256,    1, 1, 256, 256],
    [1, 1, 257, 257,       1, 1, 257, 1,    1, 1, 257, 257],
    [1, 1, 1, 257,       1, 1, 257, 1,    1, 1, 257, 257],
    [1, 1, 1, 257,       1, 1, 257, 257,    1, 1, 257, 257],
    [1, 1, 257, 257,      1, 1, 1, 257,    1, 1, 257, 257],
    [1, 1, 257, 257,       1, 1, 257, 257,    1, 1, 257, 257],  # start index 30
    [1, 1, 257, 513,       1, 1, 257, 1,    1, 1, 257, 513],
    [1, 1, 1, 513,       1, 1, 257, 1,    1, 1, 257, 513],
    [1, 1, 1, 513,       1, 1, 257, 513,    1, 1, 257, 513],
    [1, 1, 257, 513,      1, 1, 1, 513,    1, 1, 257, 513],
    [1, 1, 257, 513,       1, 1, 257, 513,    1, 1, 257, 513],
    [1, 1, 257, 1,       1, 1, 1, 1,    1, 1, 257, 1],
    [1, 1, 1, 1,        1, 1, 257, 1,    1, 1, 257, 1],
]

BROADCAST_SHAPES_L3 = [
    [1, 1, 4096, 4096,   1, 1, 1, 4096, 1, 1, 4096, 4096],    # PSMU 0 - start index 38
    [1, 1, 4096, 1024,   1, 1, 1, 1024,   1, 1, 4096, 1024],      # L3 ONLY, PSMU 0
    [1, 12, 512, 512,       1, 1, 1, 512,   1, 12, 512, 512],  # L3 ONLY, PSH
]

# Elementwise (non-broadcast) cases. Each is listed once as a 4-D shape and
# expanded below so that A, B and the output all share it.
ELEWISE_SHAPES = [
    [1, 56, 56, 256],  # start index 41
    [1, 28, 28, 512],
    [1, 14, 14, 1024],
    [1, 7, 7, 2048],
    [1, 52, 52, 256],
    [1, 26, 26, 512],
    [1, 13, 13, 1024],
    [1, 1, 64, 3072],
    [1, 1, 64, 9216],
    [1, 1, 64, 16384],
    [1, 12, 77, 77],
    [1, 1, 77, 768],
    [1, 12, 77, 1],
    [1, 1, 1, 768],
    [1, 12, 512, 1],  # index 55
]

BROADCAST_SHAPES_L3_DIV = [[1, 1, 1, 768,   1, 1, 1, 1,   1, 1, 1, 768]]  # PSJ
BROADCAST_SHAPES_ON_N = [[1, 1, 1, 768,   10, 1, 1, 768,   10, 1, 1, 768]]

# Repeat each 4-D shape to 12 entries (A, B, output). Extending a list with
# itself twice would double it twice (4 -> 8 -> 16), so append two copies at once.
for e in ELEWISE_SHAPES:
    e.extend(e * 2)

BROADCAST_SHAPES = BROADCAST_SHAPES + BROADCAST_SHAPES_L3 + ELEWISE_SHAPES + BROADCAST_SHAPES_L3_DIV + BROADCAST_SHAPES_ON_N

# main() slices rows as [:4], [4:8], [8:12] and writes them to a 13-column CSV
# (12 dims + ifm_bytes), so every row must hold exactly 12 dims.
assert all(len(row) == 12 for row in BROADCAST_SHAPES), "every shape row must have 12 entries"

# Build options shared by every shape. main() merges this dict into each
# build's shape dict. Before each build it overwrites the dtype, "attributes"
# and "debug_mode" entries, and adds further keys (including the "L3" address
# map, whose contents depend on a_on_wgt / b_on_wgt).
# NOTE: the L2 addresses below are hard-coded in scheduler/broadcast as well;
# if L2 fusion is disabled, the two must match.
extra_table = dict(
    input0_addr={0: 20000},
    input1_addr={1: 20000},
    output_addr={2: 20000},
    wgt_addr=[[idx, [4096, 5120]] for idx in range(3)],
    prm_addr=[[idx, 0] for idx in range(3)],
    load_input_from_ddr=True,
    store_output_to_ddr=True,
    # Defaults for main()'s --disable-dq0 / --disable-q flags.
    attributes=dict(disable_dq0=[0], disable_q=[0]),
    in_dtype_A="uint16",
    in_dtype_B="uint16",
    out_dtype_C="uint16",
    debug_mode=1,
)


def random_broadcast_shapes(n: int, rng: Optional[random.Random] = None):
    """
    Generate ``n`` random broadcast shape rows.

    Each row has 12 ints laid out as (N, Y, X, C) for input A, input B and the
    output, in that order. Every row follows one of the patterns below. The
    placeholders are drawn per row: N and Y from [2, 8), X and C from [2, 128).

        A             B             Output
        (1, 1, 1, C)  (1, 1, X, C)  (1, 1, X, C)   <- A broadcast on X
        (1, 1, X, C)  (1, 1, 1, C)  (1, 1, X, C)   <- B broadcast on X
        (1, 1, X, C)  (1, 1, X, 1)  (1, 1, X, C)   <- B broadcast on C
        (1, 1, 1, C)  (1, 1, X, 1)  (1, 1, X, C)   <- A on X, B on C
        (1, Y, X, C)  (1, Y, 1, C)  (1, Y, X, C)   <- B on X, with Y > 1
        (1, 1, X, C)  (1, Y, 1, C)  (1, Y, X, C)   <- A on Y, B on X
        (N, 1, X, C)  (N, Y, 1, C)  (N, Y, X, C)   <- as above, with N > 1

    Every placeholder is >= 2, so no pattern degenerates into a plain
    elementwise (non-broadcast) shape.

    :param n: number of rows to generate.
    :param rng: optional ``random.Random`` for reproducible output. Defaults to
        the module-level ``random`` functions, so ``random.seed`` still applies.
    :returns: a list of ``n`` twelve-element integer lists.
    """
    source = random if rng is None else rng
    patterns = [
        [1, 1, 1, 'C', 1, 1, 'X', 'C', 1, 1, 'X', 'C'],
        [1, 1, 'X', 'C', 1, 1, 1, 'C', 1, 1, 'X', 'C'],
        [1, 1, 'X', 'C', 1, 1, 'X', 1, 1, 1, 'X', 'C'],
        [1, 1, 1, 'C', 1, 1, 'X', 1, 1, 1, 'X', 'C'],
        [1, 'Y', 'X', 'C', 1, 'Y', 1, 'C', 1, 'Y', 'X', 'C'],
        [1, 1, 'X', 'C', 1, 'Y', 1, 'C', 1, 'Y', 'X', 'C'],
        ['N', 1, 'X', 'C', 'N', 'Y', 1, 'C', 'N', 'Y', 'X', 'C'],
    ]
    n_choices = list(range(2, 8))
    y_choices = list(range(2, 8))
    x_choices = list(range(2, 128))
    c_choices = list(range(2, 128))
    # Shuffling does not change what choice() can return. It is kept so that the
    # sequence of RNG draws, and hence seeded output, matches earlier revisions.
    source.shuffle(x_choices)
    random_shapes = []
    for _ in range(n):
        pattern = source.choice(patterns)
        # Draw order N, Y, X, C is preserved for reproducibility under a seed.
        dims = {
            'N': source.choice(n_choices),
            'Y': source.choice(y_choices),
            'X': source.choice(x_choices),
            'C': source.choice(c_choices),
        }
        # Placeholders are replaced by their drawn value; literal 1s pass through.
        shape = [dims.get(p, p) for p in pattern]
        assert all(isinstance(v, int) for v in shape), f"Shape contains non-int values: {shape}"
        if shape[0:4] == shape[4:8] and shape[0:4] == shape[8:12]:
            x, c = dims['X'], dims['C']
            raise AssertionError(f"Should not be generating binary shapes, got X: {x}, C: {c} for shape {shape}")
        random_shapes.append(shape)
    return random_shapes


# Shape indices exercised in CI: 0 and 3..12. Some shapes hit timeouts or
# stochastic failures that do not appear to be real issues. Presumably those
# are indices 1 and 2, which are skipped here; confirm before re-enabling them.
CI_BROADCAST_SHAPES = [0, *range(3, 13)]


@pytest.mark.dma
@pytest.mark.parametrize("shape_index", CI_BROADCAST_SHAPES)
@pytest.mark.parametrize("operator", [BroadcastOpType.ADD_UINT16, BroadcastOpType.ADD_INT8])
@pytest.mark.parametrize("b_on_wgt", [True, False])
def test_broadcast_add_l3(shape_index: int, operator: BroadcastOpType, b_on_wgt: bool, target: BuildTarget, hwtest: bool, clean: bool, output_root: str) -> None:
    """Build ADD on the L3 dataflow for every CI shape, with IFM B on and off the weight buffer.

    Every main() option is passed explicitly. main() is a Typer command, so any
    omitted argument would receive its (truthy) ``typer.Option`` object instead
    of the declared default value.
    """
    main(
        target=target,
        operator=operator,
        a_on_wgt=False,
        b_on_wgt=b_on_wgt,
        shape_index=shape_index,
        dataflow_type=DataflowType.L3,
        hwtest=hwtest,
        clean=clean,
        output_root=output_root,
        random_shapes=False,
        vcd=False,
        disable_dq0=False,
        disable_q=False,
        debug_mode=1,
        fp_16=False,
    )


@pytest.mark.dma
@pytest.mark.parametrize("operator", [BroadcastOpType.MUL_UINT16, BroadcastOpType.MUL_INT8])
@pytest.mark.parametrize("a_on_wgt", [True, False])
def test_broadcast_mul_l2(operator: BroadcastOpType, a_on_wgt: bool, target: BuildTarget, hwtest: bool, clean: bool, output_root: str,) -> None:
    """Test the L2 dataflow using MUL on a single shape (index 8), with IFM A on and off the weight buffer.

    Every main() option is passed explicitly. main() is a Typer command, so any
    omitted argument would receive its (truthy) ``typer.Option`` object instead
    of the declared default value.
    """
    main(
        target=target,
        operator=operator,
        a_on_wgt=a_on_wgt,
        b_on_wgt=False,
        shape_index=8,
        dataflow_type=DataflowType.L2,
        hwtest=hwtest,
        clean=clean,
        output_root=output_root,
        random_shapes=False,
        vcd=False,
        disable_dq0=False,
        disable_q=False,
        debug_mode=1,
        fp_16=False,
    )


@pytest.mark.dma
@pytest.mark.parametrize("operator", [BroadcastOpType.SUB_UINT16])
def test_broadcast_sub_qdq(operator: BroadcastOpType, target: BuildTarget, hwtest: bool, clean: bool, output_root: str,) -> None:
    """Build SUB on the L2 dataflow for shape index 8.

    NOTE(review): despite the test name, QDQ stays *enabled* here
    (disable_dq0=False, disable_q=False). Confirm whether the disabled path was
    meant to be exercised.

    Every main() option is passed explicitly. main() is a Typer command, so any
    omitted argument would receive its (truthy) ``typer.Option`` object instead
    of the declared default value.
    """
    main(
        target=target,
        operator=operator,
        a_on_wgt=False,
        b_on_wgt=False,
        shape_index=8,
        dataflow_type=DataflowType.L2,
        hwtest=hwtest,
        clean=clean,
        disable_dq0=False,
        disable_q=False,
        output_root=output_root,
        random_shapes=False,
        vcd=False,
        debug_mode=1,
        fp_16=False,
    )


@pytest.mark.dma
@pytest.mark.parametrize("operator", [BroadcastOpType.DIV_UINT16])
@pytest.mark.parametrize("shape_index", [BROADCAST_SHAPES.index(BROADCAST_SHAPES_L3_DIV[0])])
def test_broadcast_div_qdq(operator: BroadcastOpType, shape_index: int, target: BuildTarget, hwtest: bool, clean: bool, output_root: str,) -> None:
    """Build DIV on the L3 dataflow for the dedicated DIV broadcast shape.

    NOTE(review): despite the test name, QDQ stays *enabled* here
    (disable_dq0=False, disable_q=False). Confirm whether the disabled path was
    meant to be exercised.

    Every main() option is passed explicitly. main() is a Typer command, so any
    omitted argument would receive its (truthy) ``typer.Option`` object instead
    of the declared default value.
    """
    main(
        target=target,
        operator=operator,
        a_on_wgt=False,
        b_on_wgt=False,
        shape_index=shape_index,
        dataflow_type=DataflowType.L3,
        hwtest=hwtest,
        clean=clean,
        disable_dq0=False,
        disable_q=False,
        output_root=output_root,
        random_shapes=False,
        vcd=False,
        debug_mode=1,
        fp_16=False,
    )


def main(
    target: BuildTarget = typer.Option(default=BuildTarget.DATAFLOW, help="Build target for the operator"),
    operator: BroadcastOpType = typer.Option(default=BroadcastOpType.ADD_UINT8, help="Binary operator to test"),
    shape_index: Optional[int] = typer.Option(default=None, help="Index of the shape to test"),
    dataflow_type: DataflowType = typer.Option(default=DataflowType.L2, help="Dataflow type to use"),
    disable_dq0: bool = typer.Option(default=bool(extra_table["attributes"]["disable_dq0"][0]), help="Disable DQ0", is_flag=True),
    disable_q: bool = typer.Option(default=bool(extra_table["attributes"]["disable_q"][0]), help="Disable Q", is_flag=True),
    a_on_wgt: bool = typer.Option(default=bool(False), help="Put IFM A on weight buffer", is_flag=True),
    b_on_wgt: bool = typer.Option(default=bool(False), help="Put IFM B on weight buffer", is_flag=True),
    debug_mode: int = typer.Option(default=extra_table["debug_mode"], help="Debug mode level (0=off)"),
    hwtest: bool = typer.Option(default=False, help="Run HW_test after builds", is_flag=True),
    clean: bool = typer.Option(default=False, help="Clean output directory before running", is_flag=True),
    output_root: str = typer.Option(default=os.path.join(CURRDIR, "..", "Output"), help="Root directory for output"),
    random_shapes: bool = typer.Option(default=False, help="Test dataflow for 1000 random shapes", is_flag=True),
    vcd: bool = typer.Option(default=False, help="Dump VCD for waveform analysis", is_flag=True),
    fp_16: bool = typer.Option(default=False, help="Is QDQ FP16 Datatype?", is_flag=True),
) -> None:
    """Build the selected broadcast operator for one or more shapes.

    Shape selection:
    - ``random_shapes`` builds 1000 random shapes and forces the DATAFLOW target.
    - Otherwise ``shape_index`` picks one row of ``BROADCAST_SHAPES``.
    - ``None`` builds every row.

    For each shape, the module-level ``extra_table`` is updated (dtypes, QDQ
    attributes, debug mode, weight-buffer placement and an ``"L3"`` address
    map). It is then merged with the shape and passed to ``compile_operator``.
    Builds run inside ``change_dir("../")``.

    What happens after each build depends on ``target``:
    - SIM: parse each build's ``AIESimulator.log``, then write
      ``broadcast_aiesim_results.csv`` next to this file.
    - CERT: call ``create_hw_package`` on each build directory.

    With ``hwtest``, ``run_hw_validation`` is called on ``output_root``. That
    requires matching input and output dtypes.

    16-bit operators switch their input dtype (``disable_dq0``) and/or output
    dtype (``disable_q``) to bfloat16, or to float16 when ``fp_16`` is set.
    8-bit operators keep their dtype; the flags are only recorded in the
    attributes.

    NOTE: when called directly from Python rather than through Typer, pass
    every option explicitly. Omitted ones default to ``typer.Option`` objects.
    """
    operator_name = operator.onnx_name
    assert not (a_on_wgt and b_on_wgt), "Cannot put both IFM A and IFM B on weight buffer"

    if random_shapes:
        selected_shapes = random_broadcast_shapes(1000)
        target = BuildTarget.DATAFLOW
    elif shape_index is not None:
        selected_shapes = [BROADCAST_SHAPES[shape_index]]
    else:
        selected_shapes = BROADCAST_SHAPES
    # Copy each row, trimmed to its 12 dims, so that appending ifm_bytes below
    # never mutates the module-level tables. Without the copy, repeated calls
    # (e.g. parametrized pytest cases) keep growing the same rows, and the CSV
    # "ifm_bytes" column reports a stale value.
    shape_table = [row[:12] for row in selected_shapes]

    log(f"Using shape table: {shape_table}")
    is_cert_backend = target == BuildTarget.CERT
    # A standalone PDI is generated only when pdi_counter() returns 1. Presumably
    # that is the first CERT call in this process; see common.Counter.
    gen_pdi = is_cert_backend and pdi_counter() == 1
    # Initialize parameters
    results_list = [""] * len(shape_table)
    simtime_list = [0.0] * len(shape_table)

    # The dtypes depend only on the operator and the QDQ flags, not on the shape.
    in_dtype = out_dtype = operator.dtype
    if operator.dtype in ("uint16", "int16"):
        float_dtype = "float16" if fp_16 else "bfloat16"
        if disable_dq0:
            in_dtype = float_dtype
        if disable_q:
            out_dtype = float_dtype
    ifm_bytes = dtype_info(in_dtype)[0] // 8

    # Resolve once. Builds run inside change_dir("../"), while cleaning and HW
    # validation run from the caller's directory, so a relative path would
    # otherwise refer to two different places.
    output_root = os.path.abspath(str(output_root))
    print(output_root)
    clean_output_dir(output_root, clean)

    with change_dir("../"):
        for shape_idx, shape in enumerate(shape_table):
            shape.append(ifm_bytes)  # 13th CSV column ("ifm_bytes")
            extra_table["in_dtype_A"] = in_dtype
            extra_table["in_dtype_B"] = in_dtype
            extra_table["out_dtype_C"] = out_dtype
            extra_table["attributes"]["disable_dq0"] = [int(disable_dq0)]
            extra_table["attributes"]["disable_q"] = [int(disable_q)]
            extra_table["debug_mode"] = int(debug_mode)
            extra_table["a_on_wgt"] = [int(a_on_wgt)]  # for the name
            extra_table["b_on_wgt"] = [int(b_on_wgt)]  # for the name
            extra_table["input_types"] = {"A": "const" if a_on_wgt else "act", "B": "const" if b_on_wgt else "act"}
            # Bit 1: input dequantize enabled; bit 0: output quantize enabled.
            extra_table["qdq_mode"] = [int((not disable_dq0) * 2 + (not disable_q))]
            if b_on_wgt or a_on_wgt:
                # We adjust L3 addresses to put IFM1 on weight buffer in build_broadcast.py
                extra_table["L3"] = {"ifm": [1, 0], "ofm": [0, 0], "wgt": [2, 0]}
            else:
                # Both IFMs go in L3 entry 1. IFM B starts right after IFM A,
                # rounded up with iceil(..., 4). Entries are presumably
                # [buffer id, byte offset]; confirm in the build scripts.
                # IFM A's channel count is rounded up to a multiple of 64,
                # except for a C == 1 input feeding a C > 1 output.
                pad_all_channels = shape[11] == 1
                cia = shape[3]
                ifm_a_channels = iceil(cia, 64) if pad_all_channels or cia != 1 else 1
                ifm_a_bytes = shape[0] * shape[1] * shape[2] * ifm_a_channels * ifm_bytes
                extra_table["L3"] = {"ifm0": [1, 0], "ifm1": [1, iceil(ifm_a_bytes, 4)], "ofm": [0, 0], "wgt": [2, 0]}

            # With IFM A on the weight buffer, A is keyed "input0" and B "input".
            a_key = "input0" if a_on_wgt else "input"
            b_key = "input" if a_on_wgt else "input1"

            merged_shape = {
                a_key: shape[:4],
                b_key: shape[4:8],
                "output": shape[8:12],
                "dataflow_type": [int(dataflow_type == DataflowType.L3)],
                "enable_L2_fusion": dataflow_type == DataflowType.L2,
                "op": operator_name,
                **extra_table,
            }
            log(f"Building {operator.friendly_name} with shape: {shape}")
            set_datatype(fp_16)
            if debug_mode:
                os.environ["LOG_ENABLED"] = "true"
            else:
                os.environ.pop("LOG_ENABLED", None)
            compile_operator(
                operator_name,
                merged_shape,
                target,
                output_root,
                gen_standalone_pdi=gen_pdi,
                gen_op_elf=is_cert_backend,
                dump_vcd=vcd,
            )
            build_dir = os.path.join(output_root, f"op_{operator_name}_shape_{out_dir_name_from_dict(merged_shape)}")
            if target == BuildTarget.SIM:
                sim_log = os.path.join(build_dir, "AIESimulator.log")
                process_simulation_results(
                    sim_log, shape_idx, results_list, simtime_list
                )
            elif target == BuildTarget.CERT:
                assert os.path.exists(build_dir), f"build dir: {build_dir}"
                create_hw_package(build_dir)
            gen_pdi = False  # only first build generates PDI
            print(f"Completed build for shape index {shape_idx} in {build_dir}")

    if hwtest:
        assert in_dtype and in_dtype == out_dtype, "HW test only supports same input/output dtypes"
        run_hw_validation(output_root, in_dtype, host="10.228.202.104")

    # Write results to CSV. Each row is the 12 shape dims followed by ifm_bytes.
    if target == BuildTarget.SIM:
        fieldnames = [
            "Nia",
            "Yia",
            "Xia",
            "Cia",
            "Nib",
            "Yib",
            "Xib",
            "Cib",
            "No",
            "Yo",
            "Xo",
            "Co",
            "ifm_bytes",
        ]
        csv_file_name = os.path.join(CURRDIR, "broadcast_aiesim_results.csv")
        write_csv(
            shape_table,
            results_list,
            simtime_list,
            csv_file_name,
            fieldnames,
            default_row_mapper(fieldnames),
        )


if __name__ == "__main__":
    typer.run(main)
