import os
import sys

CURRDIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
HOSTDIR = os.path.join(CURRDIR, '..' , '..', 'host')
WORKDIR = os.path.join(CURRDIR, 'Work')

import subprocess
from typing import List

from conv_dataflow import (
    ConvDims,
    compile_conv_dataflow,
)

from scheduler.common import (
    LinearOpType,
)

from conv_helpers import (ActMode, ActFmt)
from kerneltest.overlay_1x1 import (overlay_stack_size, overlay_heap_size)
from dmacompiler import BackEnd
from buildscripts.common import save_cfg_json

def conv_preproc_directives(
    dims: ConvDims,
    back_end: BackEnd,
    out_shift: int = 0,
    bias_shift: int = 0,
    sign_act: int = 0,
    sign_wgt: int = 0,
    sign_out: int = 0,
    vec_coeff: int = 0,
    act_mode: ActMode = ActMode.AC_RELU,
    dtype: LinearOpType = LinearOpType(1),
    read_bins: "list[int] | None" = None,
) -> None:
    '''Write the convolution testbench configuration to WORKDIR/conv_cfg.json.

    The layer geometry from ``dims`` and the quantization/sign settings are
    collected into a flat dict. The dict is saved with ``save_cfg_json``.
    Nothing is returned.

    Args:
        dims: Convolution dimensions. Its Ci/Yi/Xi, Co/Yo/Xo, Ky/Kx, Sy/Sx,
            Py/Px, Cis/Yis/Xis and Cos/Yos/Xos fields are copied into the
            config.
        back_end: Back end the dataflow was compiled for. ``ASM_MODE`` is 0
            for ``BackEnd.Adf`` and 1 for any other value.
        out_shift: Written as ``SHIFT_OUT``.
        bias_shift: Written as ``BIAS_SHIFT``.
        sign_act: Written as ``SIGN_ACT``.
        sign_wgt: Written as ``SIGN_WGT``.
        sign_out: Written as ``SIGN_OUT``.
        vec_coeff: Written as ``COEFF_VECTOR``.
        act_mode: Written as ``ACT_MODE`` exactly as given.
        dtype: ``QDQ`` is 0 when this is ``LinearOpType.conv_A8W8_noqdq``,
            otherwise 1.
        read_bins: Two values written as ``READ_IFM`` and ``READ_WGT``.
            ``None`` (the default) means ``[0, 0]``.
    '''
    # A None default avoids sharing one mutable list object between calls.
    if read_bins is None:
        read_bins = [0, 0]
    asm_mode = int(back_end != BackEnd.Adf)
    qdq = 0 if dtype == LinearOpType.conv_A8W8_noqdq else 1
    # Key order is kept stable so the generated JSON is stable across runs.
    cfg = {
        "C_IN": dims.Ci,
        "Y_IN": dims.Yi,
        "X_IN": dims.Xi,
        "C_OUT": dims.Co,
        "Y_OUT": dims.Yo,
        "X_OUT": dims.Xo,
        "KERNEL_Y": dims.Ky,
        "KERNEL_X": dims.Kx,
        "STRIDE_Y": dims.Sy,
        "STRIDE_X": dims.Sx,
        "CIS": dims.Cis,
        "YIS": dims.Yis,
        "XIS": dims.Xis,
        "COS": dims.Cos,
        "YOS": dims.Yos,
        "XOS": dims.Xos,
        "PAD_Y": dims.Py,
        "PAD_X": dims.Px,
        "C_OUT_SPLIT": 1,  # output channels are never split by this script
        "SHIFT_OUT": out_shift,
        "BIAS_SHIFT": bias_shift,
        "COEFF_VECTOR": vec_coeff,
        "SIGN_ACT": sign_act,
        "SIGN_WGT": sign_wgt,
        "SIGN_OUT": sign_out,
        "ACT_MODE": act_mode,
        "QDQ": qdq,
        "ASM_MODE": asm_mode,
        "READ_IFM": read_bins[0],
        "READ_WGT": read_bins[1],
        # NOTE(review): the DTYPE_* entries are fixed strings and do not follow
        # sign_act/sign_wgt/sign_out or dtype -- confirm this is intended.
        "DTYPE_ACT": 'uint16',
        "DTYPE_WGT": 'uint8',
        "DTYPE_OFM": 'uint16',
        "DTYPE_BIAS": 'float',
    }
    save_cfg_json(cfg, os.path.join(WORKDIR, "conv_cfg.json"))
    

def aiecompiler_args(
    dims: ConvDims,
    shift_res: int, shift_bias: int,
    sign_act: int, sign_wgt: int, sign_out: int,
    vec_coeff: int,
    act_mode: ActMode,
    run_mode: str,
) -> List[str]:
    '''Write the testbench config and return the aiecompiler command line.

    This function has a side effect: it calls ``conv_preproc_directives``,
    which saves ``conv_cfg.json`` into WORKDIR with the op type fixed to
    ``LinearOpType(5)``.

    Args:
        dims: Convolution dimensions, forwarded to ``conv_preproc_directives``.
        shift_res: Forwarded as ``out_shift``.
        shift_bias: Forwarded as ``bias_shift``.
        sign_act, sign_wgt, sign_out: Signedness flags, forwarded unchanged.
        vec_coeff: Forwarded unchanged.
        act_mode: Activation mode, forwarded unchanged.
        run_mode: CLI run mode. ``'cert'`` selects ``BackEnd.TxnHostPatch``;
            any other value selects ``BackEnd.Adf``. This is the same mapping
            ``main`` uses when it compiles the dataflow.

    Returns:
        The aiecompiler argument vector, with the program name first.
    '''
    # conv_preproc_directives expects a BackEnd, not the CLI mode string.
    # Comparing a str against BackEnd.Adf would always be unequal, which
    # would force ASM_MODE=1 even for the Adf ('sim') flow.
    back_end = BackEnd.TxnHostPatch if run_mode == 'cert' else BackEnd.Adf
    conv_preproc_directives(
        dims,
        back_end,
        out_shift=shift_res,
        bias_shift=shift_bias,
        sign_act=sign_act,
        sign_wgt=sign_wgt,
        sign_out=sign_out,
        vec_coeff=vec_coeff,
        act_mode=act_mode,
        dtype=LinearOpType(5),
    )
    host_filename = os.path.join(HOSTDIR, 'conv.cpp')
    root_dir = os.path.join(CURRDIR, '..', '..')
    return [
        'aiecompiler',
        host_filename,
        '-v',
        '--disable-multirate-analysis',
        '--part=xc10MDS1',
        '--adf-api-log-level=5',
        '-log-level=5',
        '--disable-dma-autostart=true',
        '--enable-core-processor-bus=true',
        f'--workdir={WORKDIR}',
        f'--stacksize={overlay_stack_size()}',
        f'--heapsize={overlay_heap_size()}',
        f'--include={CURRDIR}',
        f'--include={HOSTDIR}',
        f'--include={os.path.join(root_dir, "kernel")}',
        f'--include={os.path.join(root_dir, "kernel/common")}',
        f'--include={os.path.join(root_dir, "kernel/conv")}',
        f'--include={os.path.join(root_dir, "kerneltest/conv_int16x8")}',
    ]

def aiesimulator_args() -> List[str]:
    '''Return the aiesimulator argument vector for the package in WORKDIR.

    The command enables ``--dump-vcd=trace`` and ``--profile``, points
    ``--pkg-dir`` at WORKDIR, and disables the multi-threaded model.
    '''
    sim_cmd = ['aiesimulator', '--dump-vcd=trace', '--profile']
    sim_cmd.append(f'--pkg-dir={WORKDIR}')
    # Work-around for https://jira.xilinx.com/browse/CR-1235452
    sim_cmd.append('--mt-model=false')
    return sim_cmd

# Test cases consumed by main(). Each entry unpacks positionally as:
#   (Yi, Xi, Ci), (Yo, Xo, Co), (Yis, Xis, Cis), (Yos, Xos, Cos),
#   (Ky, Kx), (Py, Px), (Sy, Sx),
#   shift_res, shift_bias,
#   sign_act, sign_wgt, sign_out,
#   vec_coeff
# Both cases share one 3x3, stride-1, unpadded geometry. They differ only in
# signedness: entry 0 is all-unsigned, and entry 1 is all-signed.
shape_table = [
    (
        (6, 10, 64),        # Yi, Xi, Ci
        (4, 8, 64),         # Yo, Xo, Co
        (6, 10, 64),        # Yis, Xis, Cis
        (4, 8, 64),         # Yos, Xos, Cos
        (3, 3),             # Ky, Kx
        (0, 0),             # Py, Px
        (1, 1),             # Sy, Sx
        10, 0,              # shift_res, shift_bias
        sign, sign, sign,   # sign of act, wgt, out
        0,                  # vec_coeff
    )
    for sign in (0, 1)
]

def _parse_cli(argv: List[str]) -> "tuple[str, list]":
    '''Validate the command line and return ``(run_mode, shapes_to_run)``.

    The expected form is ``<script> <dataflow|sim|cert> [shape_index]``.
    Without ``shape_index``, every entry of ``shape_table`` is selected.
    On invalid input it prints a message to stderr and exits with status 1.
    '''
    if len(argv) not in (2, 3):
        print(f"Usage: {argv[0]} <dataflow|sim|cert> [shape_index]", file=sys.stderr)
        sys.exit(1)
    run_mode = argv[1]
    if run_mode not in ('dataflow', 'sim', 'cert'):
        print(f"Invalid mode: {run_mode}. Use 'dataflow', 'sim', or 'cert'.", file=sys.stderr)
        sys.exit(1)
    if len(argv) == 2:
        return run_mode, list(shape_table)
    try:
        shapes = [shape_table[int(argv[2])]]
    except (ValueError, IndexError):
        print(
            f"Invalid shape index: {argv[2]}. shape_table has {len(shape_table)} entries.",
            file=sys.stderr,
        )
        sys.exit(1)
    return run_mode, shapes


def _run_aie_flow(dims, shift_res, shift_bias, sign_act, sign_wgt, sign_out,
                  vec_coeff, act_mode, run_mode) -> None:
    '''Recreate WORKDIR, then compile, build and simulate the AIE graph in it.

    The process working directory is WORKDIR while the tools run, and it is
    restored afterwards, so every shape starts from the same directory. If any
    tool exits non-zero, this raises ``subprocess.CalledProcessError``
    instead of running later steps on missing outputs.
    '''
    import shutil

    if os.path.isdir(WORKDIR):
        print("Deleting existing Work folder")
        # Use the absolute path so the result does not depend on the cwd.
        shutil.rmtree(WORKDIR)
    os.makedirs(WORKDIR)
    prev_cwd = os.getcwd()
    os.chdir(WORKDIR)
    try:
        # Argument lists (no shell) keep paths containing spaces intact.
        subprocess.run(
            aiecompiler_args(dims, shift_res, shift_bias, sign_act, sign_wgt,
                             sign_out, vec_coeff, act_mode, run_mode),
            check=True,
        )
        systemc_dir = os.path.join(WORKDIR, 'ps', 'c_rts', 'systemC')
        # Insert -ladf_rt_ctrl_api in front of every -ladf_api in the
        # generated Makefile.
        subprocess.run(
            ['sed', '-i', 's/-ladf_api/-ladf_rt_ctrl_api -ladf_api/g',
             os.path.join(systemc_dir, 'Makefile')],
            check=True,
        )
        subprocess.run(['make', '-C', systemc_dir, 'all'], check=True)
        subprocess.run(aiesimulator_args(), check=True)
    finally:
        os.chdir(prev_cwd)


def main():
    '''Compile (and optionally build and simulate) the conv test shapes.

    Usage: ``<script> <dataflow|sim|cert> [shape_index]``.

    - ``dataflow`` only runs ``compile_conv_dataflow``.
    - ``sim`` and ``cert`` also run the aiecompiler / make / aiesimulator
      flow inside WORKDIR.
    - ``cert`` compiles the dataflow for ``BackEnd.TxnHostPatch``; the other
      modes use ``BackEnd.Adf``.
    '''
    run_mode, shapes_to_run = _parse_cli(sys.argv)
    back_end = BackEnd.TxnHostPatch if run_mode == 'cert' else BackEnd.Adf

    # Harness parameters shared by every shape in the table.
    N = 1  # Batch dimension
    aie_rows = 1
    aie_cols = 1
    act_bits = 16
    wgt_bits = 8
    bias_bits = 32
    out_bits = 16
    param_bits = 8
    act_fmt = ActFmt.CYXC
    act_mode = ActMode.AC_SRS

    for (
        (Yi, Xi, Ci),
        (Yo, Xo, Co),
        (Yis, Xis, Cis),
        (Yos, Xos, Cos),
        (Ky, Kx),
        (Py, Px),
        (Sy, Sx),
        shift_res, shift_bias,
        sign_act, sign_wgt, sign_out,
        vec_coeff,
    ) in shapes_to_run:
        # Input-channel granularity drops to 8 when fewer than 64 channels.
        Ci_gran = 8 if Ci < 64 else 64
        dims = ConvDims(
            N,
            Yi, Xi, Ci, Yo, Xo, Co, Yis, Xis, Cis, Yos, Xos, Cos, Ky, Kx, Py, Px, Sy, Sx,
            aie_rows, aie_cols, act_bits, wgt_bits, bias_bits, out_bits, param_bits,
            Ci_gran, act_fmt,
        )
        print(f"Compiling for shape: {dims}")
        compile_conv_dataflow(dims, back_end)
        if run_mode != 'dataflow':
            _run_aie_flow(dims, shift_res, shift_bias, sign_act, sign_wgt,
                          sign_out, vec_coeff, act_mode, run_mode)


if __name__ == '__main__':
    main()
