'''
Regression test for the GAP (global average pooling) operator
'''
import os
from typing import Optional
import typer
from common import change_dir, BuildTarget
from build_aie4 import compile_operator
from graph.utilities import config_logger_from_env

# Absolute path of the directory holding this script; default paths below are built from it.
CURRDIR = os.path.abspath(os.path.dirname(__file__))
# Import-time side effect: let graph.utilities set up logging
# (presumably from environment variables, per the helper's name).
config_logger_from_env()


# GAP test shapes. Each entry holds 20 ints; field order, using the names
# main() gives them:
#   Yi, Xi, Ci, Yo, Xo, Co,
#   act_bits, out_bits, sign_act, sign_out, prm_bits, bits_per_byte,
#   aie_rows, aie_cols, Yis, Xis, Cis, Yos, Xos, Cos
RESNET50_SHAPES = [
    [
        7, 7, 2048,     # Yi, Xi, Ci
        1, 1, 2048,     # Yo, Xo, Co
        8, 8,           # act_bits, out_bits
        1, 1,           # sign_act, sign_out
        8, 8,           # prm_bits, bits_per_byte
        4, 3,           # aie_rows, aie_cols
        7, 7, 256,      # Yis, Xis, Cis
        1, 1, 256,      # Yos, Xos, Cos
    ],
]


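# NOTE: The GAP pytest is temporarily disabled until it is refactored for the new build interface.
# The commented-out call below still passes an argv-style list, which does not match the
# keyword parameters of the typer-based main().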
# @pytest.mark.dma
# @pytest.mark.parametrize(
#     "shape_index",
#     range(len(RESNET50_SHAPES)),
#     ids=[f"shape_{i}_{RESNET50_SHAPES[i]}" for i in range(len(RESNET50_SHAPES))],
# )
# def test_gap(shape_index: int) -> None:
#     '''Run all the tests'''
#     main(['_', 'dataflow', str(shape_index)])


def main(
    target: BuildTarget = typer.Option(default=BuildTarget.DATAFLOW, help="Build target for the operator"),
    shape_index: Optional[int] = typer.Option(default=None, help="Index of the shape to test"),
    output_root: str = typer.Option(default=os.path.join(CURRDIR, "..", "Output"), help="Root directory for output")
) -> None:
    '''Run GAP regression builds for the shapes in RESNET50_SHAPES.

    Each selected shape is passed to ``compile_operator("gap", ...)`` while the
    working directory is the parent of this script's directory.

    Args:
        target: Build target forwarded to ``compile_operator``. A standalone
            PDI is requested only when this is ``BuildTarget.CERT``.
        shape_index: Index into RESNET50_SHAPES (Python indexing, so negative
            values count from the end). ``None`` builds every shape.
        output_root: Root directory for build output. A relative path is
            resolved against the directory the command was started from.

    Raises:
        typer.BadParameter: ``shape_index`` does not select an entry.
        ValueError: a shape entry does not have exactly 20 fields.
    '''
    # Fields per shape entry, in order: Yi, Xi, Ci, Yo, Xo, Co, act_bits,
    # out_bits, sign_act, sign_out, prm_bits, bits_per_byte, aie_rows,
    # aie_cols, Yis, Xis, Cis, Yos, Xos, Cos.
    num_fields = 20

    gen_pdi = target == BuildTarget.CERT
    # Resolve before entering the build directory below; otherwise a relative
    # --output-root would silently be interpreted relative to that directory.
    output_root = os.path.abspath(str(output_root))

    if shape_index is None:
        shape_table = RESNET50_SHAPES
    else:
        try:
            shape_table = [RESNET50_SHAPES[shape_index]]
        except IndexError as err:
            raise typer.BadParameter(
                f"--shape-index {shape_index} is out of range for "
                f"{len(RESNET50_SHAPES)} shape(s)"
            ) from err

    # Loop-invariant: enable logging once rather than on every iteration.
    os.environ["LOG_ENABLED"] = "true"

    # Anchor on the script location (same base as the output_root default)
    # instead of the caller's cwd, so launching from elsewhere still works.
    with change_dir(os.path.join(CURRDIR, "..")):
        for shape in shape_table:
            if len(shape) != num_fields:
                raise ValueError(
                    f"GAP shape must have {num_fields} fields, got {len(shape)}: {shape}"
                )
            compile_operator("gap",
                             list(shape),  # pass a copy so RESNET50_SHAPES cannot be altered
                             target,
                             output_root,
                             gen_standalone_pdi=gen_pdi
                             )


# Script entry point: typer builds the command-line interface from main()'s
# typer.Option parameters and invokes it.
if __name__ == '__main__':
    typer.run(main)
