import os
import shutil
import sys
# Absolute directory of this script; every other path below is derived from it.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Scratch directory passed to aiecompiler (--workdir) and aiesimulator (--pkg-dir).
WORKDIR = os.path.join(CURRDIR, 'Work')
# The host and kernel source trees live two directory levels above this script.
_UP_TWO = ('..', '..')
HOSTDIR = os.path.join(CURRDIR, *_UP_TWO, 'host')
KERNELDIR = os.path.join(CURRDIR, *_UP_TWO, 'kernel')

import subprocess
from typing import List

from gemm_dataflow import generate_gemm_dataflow
from kerneltest.overlay_1x1 import overlay_stack_size, overlay_heap_size
from dmacompiler import BackEnd

def aiecompiler_args(
    M: int, K: int, N: int,
    Msubv: int, Ksubv: int, Nsubv: int,
    run_mode: str,
    sign_A: int, sign_W: int, sign_O: int,
    shift_out: int, vector_coeff: int,
    read_ifm: int, read_wgt: int,
    is_int4: int
) -> List[str]:
    """Build the ``aiecompiler`` argument vector for one GEMM configuration.

    Every configuration value is forwarded to the host/kernel sources as a
    preprocessor macro through ``--Xpreproc``.

    Args:
        M, K, N: GEMM dimensions, forwarded as ``{M,K,N}_GEMM_A16W8``.
        Msubv, Ksubv, Nsubv: Subvolume dimensions, forwarded as
            ``{M,K,N}_SUBV_A16W8``.
        run_mode: ``"cert"`` sets ``ASM_MODE=1``; any other value sets ``0``.
        sign_A, sign_W, sign_O: Forwarded as ``SIGN_ACT``, ``SIGN_WGT``,
            ``SIGN_OUT``.
        shift_out: Forwarded as ``SHIFT_OUT``.
        vector_coeff: Forwarded as ``COEFF_VECTOR``.
        read_ifm, read_wgt: Forwarded as ``READ_IFM`` / ``READ_WGT``.
        is_int4: Forwarded as ``IS_INT4_WGT``.

    Returns:
        The argument list, starting with the ``aiecompiler`` executable name.
    """
    host_filename = os.path.join(HOSTDIR, 'gemm.cpp')

    # Header search paths, in the order they are passed to the compiler.
    include_dirs = [
        CURRDIR,
        HOSTDIR,
        KERNELDIR,
        os.path.join(KERNELDIR, 'common'),
        os.path.join(KERNELDIR, 'gemm_qdq_int16x4'),
    ]

    # (macro name, value) pairs, in the order they are passed to the compiler.
    defines = [
        ('M_GEMM_A16W8', M),
        ('K_GEMM_A16W8', K),
        ('N_GEMM_A16W8', N),
        ('M_SUBV_A16W8', Msubv),
        ('K_SUBV_A16W8', Ksubv),
        ('N_SUBV_A16W8', Nsubv),
        ('SIGN_ACT', sign_A),
        ('SIGN_WGT', sign_W),
        ('SIGN_OUT', sign_O),
        ('SHIFT_OUT', shift_out),
        ('COEFF_VECTOR', vector_coeff),
        ('READ_IFM', read_ifm),
        ('READ_WGT', read_wgt),
        ('IS_INT4_WGT', is_int4),
        ('ASM_MODE', 1 if run_mode == 'cert' else 0),
    ]

    return [
        'aiecompiler',
        host_filename,
        '-v',
        '--disable-multirate-analysis',
        '--part=xc10MDS1',
        '--adf-api-log-level=5',
        '-log-level=5',
        '--disable-dma-autostart=true',
        '--enable-core-processor-bus=true',
        f'--workdir={WORKDIR}',
        f'--stacksize={overlay_stack_size()}',
        f'--heapsize={overlay_heap_size()}',
        *(f'--include={path}' for path in include_dirs),
        # NOTE(review): when this list is run without a shell, the double
        # quotes reach aiecompiler literally; presumably it strips them —
        # confirm before changing the quoting.
        *(f'--Xpreproc="-D{name}={value}"' for name, value in defines),
    ]

def aiesimulator_args() -> List[str]:
    """Return the ``aiesimulator`` argument vector for the package in WORKDIR.

    Requests a VCD dump named ``trace`` and enables profiling.
    """
    options = ['--dump-vcd=trace', '--profile', f'--pkg-dir={WORKDIR}']
    return ['aiesimulator'] + options

# GEMM configurations run by this script. Each entry is:
#   ((M, K, N), (Msubv, Ksubv, Nsubv), (sign_A, sign_W, sign_O),
#    shift_out, vector_coeff, read_ifm, read_wgt, is_int4)
# The entries currently differ only in vector_coeff (0, 1, 2).
_SHAPE_TABLE = [
    (
        (32, 256, 64),      # M, K, N
        (32, 256, 64),      # Msubv, Ksubv, Nsubv
        (0, 0, 0),          # sign_A, sign_W, sign_O
        10,                 # shift_out
        0,                  # vector_coeff
        0, 0,               # read_ifm, read_wgt
        1,                  # is_int4
    ),
    (
        (32, 256, 64),      # M, K, N
        (32, 256, 64),      # Msubv, Ksubv, Nsubv
        (0, 0, 0),          # sign_A, sign_W, sign_O
        10,                 # shift_out
        1,                  # vector_coeff
        0, 0,               # read_ifm, read_wgt
        1,                  # is_int4
    ),
    (
        (32, 256, 64),      # M, K, N
        (32, 256, 64),      # Msubv, Ksubv, Nsubv
        (0, 0, 0),          # sign_A, sign_W, sign_O
        10,                 # shift_out
        2,                  # vector_coeff
        0, 0,               # read_ifm, read_wgt
        1,                  # is_int4
    ),
]

_RUN_MODES = ('dataflow', 'sim', 'cert')


def _parse_args(argv):
    """Return ``(run_mode, shape_index)`` from *argv*; exit with status 1 on bad input.

    Expected form: ``<script> <dataflow|sim|cert> [shape_index]``.
    ``shape_index`` is ``None`` when omitted.
    """
    if len(argv) not in (2, 3):
        print("Usage: <script> <dataflow|sim|cert> [shape_index]")
        sys.exit(1)
    run_mode = argv[1]
    if run_mode not in _RUN_MODES:
        print(f"Invalid mode: {run_mode}. Use 'dataflow', 'sim', or 'cert'.")
        sys.exit(1)
    shape_index = None
    if len(argv) == 3:
        try:
            shape_index = int(argv[2])
        except ValueError:
            print(f"Invalid shape index: {argv[2]}. Expected an integer.")
            sys.exit(1)
    return run_mode, shape_index


def _build_and_simulate(compile_args: List[str]) -> None:
    """Recreate WORKDIR, compile, build the SystemC model, then run aiesimulator.

    Compile, Makefile patch, and make use ``check=True``, so a failing build
    step raises ``subprocess.CalledProcessError`` instead of cascading into
    later steps. The simulator's exit status is not checked. The working
    directory is restored on return, so every configuration starts from the
    same cwd.
    """
    if os.path.isdir(WORKDIR):
        print("Deleting existing Work folder")
        # Delete by absolute path so the result does not depend on the cwd.
        shutil.rmtree(WORKDIR)
    os.makedirs(WORKDIR)

    prev_cwd = os.getcwd()
    os.chdir(WORKDIR)
    try:
        subprocess.run(compile_args, check=True)
        systemc_dir = os.path.join(WORKDIR, 'ps', 'c_rts', 'systemC')
        # Insert -ladf_rt_ctrl_api before every -ladf_api in the generated
        # Makefile (presumably needed for the runtime control API — confirm).
        subprocess.run(
            [
                'sed', '-i',
                's/-ladf_api/-ladf_rt_ctrl_api -ladf_api/g',
                os.path.join(systemc_dir, 'Makefile'),
            ],
            check=True,
        )
        subprocess.run(['make', '-C', systemc_dir, 'all'], check=True)
        subprocess.run(aiesimulator_args())
    finally:
        os.chdir(prev_cwd)


def main():
    """Generate the GEMM dataflow for each configured shape; unless the mode is
    ``dataflow``, also compile and simulate it.

    Command line: ``<script> <dataflow|sim|cert> [shape_index]``. ``cert``
    selects the ``BackEnd.CertAsm`` dataflow backend; the other modes use
    ``BackEnd.Adf``. When ``shape_index`` is given, only that table entry runs
    (Python indexing, so negative indices are accepted).
    """
    run_mode, shape_index = _parse_args(sys.argv)

    shape_table = _SHAPE_TABLE
    if shape_index is not None:
        try:
            shape_table = [_SHAPE_TABLE[shape_index]]
        except IndexError:
            print(f"Invalid shape index: {shape_index}. "
                  f"The shape table has {len(_SHAPE_TABLE)} entries.")
            sys.exit(1)

    backend = BackEnd.CertAsm if run_mode == 'cert' else BackEnd.Adf
    for (
        (M, K, N), (Msubv, Ksubv, Nsubv),
        (sign_A, sign_W, sign_O),
        shift_out, vector_coeff,
        read_ifm, read_wgt,
        is_int4,
    ) in shape_table:
        generate_gemm_dataflow(
            M, K, N, Msubv, Ksubv, Nsubv,
            sign_A, sign_W, sign_O,
            shift_out, vector_coeff, backend,
        )
        if run_mode == 'dataflow':
            continue
        compile_args = aiecompiler_args(
            M, K, N, Msubv, Ksubv, Nsubv, run_mode,
            sign_A, sign_W, sign_O, shift_out, vector_coeff,
            read_ifm, read_wgt, is_int4,
        )
        _build_and_simulate(compile_args)

if __name__ == '__main__':
    main()
