'''
Define unit tests for the unary operator (uniop) dataflow.
'''
import os
import re
import subprocess

from buildscripts.build_uniop import compile_uniop_3x4_dataflow
from buildscripts.common import ScheduleInputs
from scheduler.uniop.uniop_common import UnaryMapping, UnaryShape
from scheduler.uniop.uniop_util import SpatialSplitModes
from utils.utils_common import log
from dmacompiler import BackEnd


def parse_split_mode(split_mode_string):
    '''Break a spatial split-mode string into numeric factors and letters.

    Each run of decimal digits in the string becomes one integer factor, and
    each alphabetic character is kept on its own; both lists preserve the
    order of appearance.

    Returns:
        (factors, letters): a list of exactly three ints and a list of
        single-character strings.

    Raises:
        AssertionError: if the string does not contain exactly three runs
        of digits.
    '''
    factors = list(map(int, re.findall(r'\d+', split_mode_string)))
    letters = list(filter(str.isalpha, split_mode_string))

    # Downstream code unpacks the factors into (N, X, C).
    assert len(factors) == 3
    return factors, letters


def _generated_source_status(source_dir):
    '''Return a bit-coded status for the four generated C++ sources in *source_dir*.

    0 means every file exists. Otherwise the result is minus the sum of the
    flags of the missing files: dma.hpp=1, graph.hpp=2, super.cc=4, super.hh=8.
    '''
    flags = (("dma.hpp", 1), ("graph.hpp", 2), ("super.cc", 4), ("super.hh", 8))
    return -sum(flag for name, flag in flags
                if not os.path.exists(os.path.join(source_dir, name)))


def _remove_generated_sources(source_dir):
    '''Best-effort deletion of the generated C++ sources in *source_dir*.

    Missing files are ignored and other OS errors are logged instead of
    raised, so cleanup never fails the test. os.remove is used rather than a
    shell ``rm -rf`` so paths with spaces or shell metacharacters are safe.
    '''
    for name in ("dma.hpp", "graph.hpp", "super.hh", "super.cc"):
        path = os.path.join(source_dir, name)
        try:
            os.remove(path)
        except FileNotFoundError:
            pass
        except OSError as err:
            log("Could not delete", path, err)


def test_deadlock_sweep_shapes():
    '''Sweep every spatial split mode through the dmacompiler stage.

    For each mode in ``SpatialSplitModes().Table``, the tensor shape is the
    subvolume (2, 12, 1024) scaled per axis by the mode's split factors. The
    "copy" dataflow is then compiled with compile_uniop_3x4_dataflow, which
    presumably performs the deadlock check (TODO confirm). Afterwards
    dma.hpp, graph.hpp, super.cc and super.hh are expected in the current
    working directory. They are deleted after a fully successful generation.

    Raises:
        AssertionError: if any split mode did not produce all four sources.
        The message lists (split mode, tensor dim, status code) for each
        failure. Exceptions from the compile step propagate unchanged.
    '''
    all_split_modes = SpatialSplitModes()
    SubVolumeDim = (2, 12, 1024)
    Nsubv, Xsubv, Csubv = SubVolumeDim
    failures = []

    for SpatialSplitMode in all_split_modes.Table:
        # parse_split_mode itself asserts that exactly three factors are present.
        (N_split_factor, X_split_factor, C_split_factor), _ = parse_split_mode(SpatialSplitMode)
        TensorDim = (Nsubv * N_split_factor, Xsubv * X_split_factor, Csubv * C_split_factor)

        # Rebuilt on every iteration in case the compiler mutates them.
        kernel_names = {"run_copy_fp16x16": 25}
        kernel_includes = ["super.hh", "q/q.hpp", "dq/dq.hpp",
                           "softmax_fp16x16/copy_fp16x16_wrapper.cc"]

        dims_shape: UnaryShape = UnaryShape("copy", TensorDim, ifmbytes=2, ofmbytes=2, ifmSign=0, ofmSign=0, SpatialSplitMode=SpatialSplitMode)
        mapping: UnaryMapping = UnaryMapping(TensorDim, TensorDim, SubVolumeDim, SubVolumeDim, (1, 2, TensorDim[2]), SpatialSplitMode)

        schedule_input = ScheduleInputs(shape=None, mapping=None, dataflow_type=1, L2_alloc=None, L3_alloc=None)
        schedule_input.shape = dims_shape
        schedule_input.mapping = mapping
        schedule_input.kernel_includes = kernel_includes
        schedule_input.kernel_names = kernel_names
        schedule_input.backend = BackEnd.Adf
        schedule_input.layer_file_name = "dma.hpp"

        compile_uniop_3x4_dataflow(schedule_input)

        # Check that all four source files were generated.
        currdir = os.getcwd()  # assumed to be where the cpp sources are generated
        error_code = _generated_source_status(currdir)

        os.environ["LOG_ENABLED"] = "true"
        log("Generate source status:", error_code, "<-----", TensorDim, SubVolumeDim, SpatialSplitMode)

        if error_code != 0:
            failures.append((SpatialSplitMode, TensorDim, error_code))
        elif os.path.isdir(currdir):
            # Remove this iteration's outputs so the next split mode starts clean.
            log("Deleting generated cpp sources")
            _remove_generated_sources(currdir)

    # Previously the status was only logged, so a missing file never failed the test.
    assert not failures, (
        "cpp sources not generated for (split mode, tensor dim, status): "
        + repr(failures))


def main():
    '''Script entry point: run the split-mode deadlock sweep outside of pytest.

    Delegates to test_deadlock_sweep_shapes(). That function compiles the
    "copy" uniop dataflow for every spatial split mode and checks whether
    the C++ sources are generated.
    '''
    test_deadlock_sweep_shapes()


if __name__ == '__main__':
    # Allow running the sweep directly as a script, not only via pytest.
    main()
