import os
import sys

CURRDIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
HOSTDIR = os.path.join(CURRDIR, '..' , '..', 'host')
WORKDIR = os.path.join(CURRDIR, 'Work')

import subprocess
from typing import List

from conv_dataflow import (
    ConvDims,
    compile_conv_dataflow,
    )
from conv_helpers import (ActMode, ActFmt)
from kerneltest.overlay_1x1 import (overlay_stack_size, overlay_heap_size)
from dmacompiler import BackEnd

def aiecompiler_args(
    dims: ConvDims,
    shift_res: int, shift_bias: int,
    sign_act: int, sign_wgt: int, sign_out: int,
    act_mode: ActMode,
    run_mode: str,
) -> List[str]:
    """Assemble the ``aiecompiler`` command line for one conv layer.

    The host graph source is ``HOSTDIR/conv.cpp``. Layer geometry and the
    quantization / activation settings are forwarded as ``-D`` preprocessor
    macros through ``--Xpreproc`` options.

    Args:
        dims: Layer dimensions. Its Ci/Yi/Xi, Co/Yo/Xo, Cis/Yis/Xis,
            Cos/Yos/Xos, Ky/Kx, Sy/Sx and Py/Px fields become the
            C_IN ... PAD_X macros.
        shift_res: Value of the OUT_SHIFT macro.
        shift_bias: Value of the BIAS_SHIFT macro.
        sign_act, sign_wgt, sign_out: Values of SIGN_ACT / SIGN_WGT / SIGN_OUT.
        act_mode: Value of the ACT_MODE macro (rendered with f-string
            formatting).
        run_mode: ``"cert"`` sets ASM_MODE=1; any other value sets ASM_MODE=0.

    Returns:
        The argument list. The ``--Xpreproc`` entries embed literal double
        quotes, so the list is meant to be space-joined and run through a
        shell, which strips those quotes.
    """
    repo_root = os.path.join(CURRDIR, "..", "..")
    include_dirs = [
        CURRDIR,
        HOSTDIR,
        os.path.join(repo_root, "kernel"),
        os.path.join(repo_root, "kernel/common"),
        os.path.join(repo_root, "kernel/conv"),
    ]

    # NOTE(review): ACT_MODE is rendered via format(act_mode, ''). If ActMode
    # is a plain Enum (not an IntEnum) this yields "ActMode.AC_RELU" rather
    # than an integer -- confirm against conv_helpers.
    macros = [
        ('C_IN', dims.Ci), ('Y_IN', dims.Yi), ('X_IN', dims.Xi),
        ('C_OUT', dims.Co), ('Y_OUT', dims.Yo), ('X_OUT', dims.Xo),
        ('CIS', dims.Cis), ('YIS', dims.Yis), ('XIS', dims.Xis),
        ('COS', dims.Cos), ('YOS', dims.Yos), ('XOS', dims.Xos),
        ('KERNEL_Y', dims.Ky), ('KERNEL_X', dims.Kx),
        ('STRIDE_Y', dims.Sy), ('STRIDE_X', dims.Sx),
        ('PAD_Y', dims.Py), ('PAD_X', dims.Px),
        ('OUT_SHIFT', shift_res), ('BIAS_SHIFT', shift_bias),
        ('SIGN_ACT', sign_act), ('SIGN_WGT', sign_wgt), ('SIGN_OUT', sign_out),
        ('ACT_MODE', act_mode),
        ('ASM_MODE', 1 if run_mode == "cert" else 0),
        ('CONV_NOQDQ_A8W8', 1),
        ('C_OUT_SPLIT', 1),
    ]

    cmd = [
        'aiecompiler',
        HOSTDIR + '/conv.cpp',
        '-v',
        '--disable-multirate-analysis',
        '--part=xc10MDS1',
        '--adf-api-log-level=5',
        '-log-level=5',
        '--disable-dma-autostart=true',
        '--enable-core-processor-bus=true',
        f'--workdir={WORKDIR}',
        f'--stacksize={overlay_stack_size()}',
        f'--heapsize={overlay_heap_size()}',
    ]
    cmd.extend(f'--include={path}' for path in include_dirs)
    cmd.extend(f'--Xpreproc="-D{name}={value}"' for name, value in macros)
    return cmd

def aiesimulator_args() -> List[str]:
    """Return the ``aiesimulator`` command line for the package in WORKDIR.

    Passes ``--dump-vcd=trace`` and ``--profile``, points ``--pkg-dir`` at
    WORKDIR, and disables the MT model.
    """
    cmd = ['aiesimulator']
    cmd += ['--dump-vcd=trace', '--profile']
    cmd.append(f'--pkg-dir={WORKDIR}')
    # --mt-model=false is a workaround for
    # https://jira.xilinx.com/browse/CR-1235452
    cmd.append('--mt-model=false')
    return cmd

# Conv layer configurations exercised by this script. The optional second CLI
# argument of main() is an index into this list. main() unpacks each entry as:
#   (Yi, Xi, Ci)          input Y, X, C extents (forwarded as Y_IN/X_IN/C_IN)
#   (Yo, Xo, Co)          output Y, X, C extents (Y_OUT/X_OUT/C_OUT)
#   (Yis, Xis, Cis)       YIS/XIS/CIS -- presumably the per-core input
#                         subvolume; equal to the full input in every entry
#   (Yos, Xos, Cos)       YOS/XOS/COS -- presumably the per-core output
#                         subvolume; equal to the full output in every entry
#   (Ky, Kx)              kernel size (KERNEL_Y/KERNEL_X)
#   (Py, Px)              padding (PAD_Y/PAD_X); 0 in every entry
#   (Sy, Sx)              stride (STRIDE_Y/STRIDE_X)
#   shift_res, shift_bias OUT_SHIFT / BIAS_SHIFT macro values
#   sign_act, sign_wgt, sign_out
#                         SIGN_ACT / SIGN_WGT / SIGN_OUT macro values
#   act_mode              ACT_MODE macro value
# Every entry satisfies Yo == (Yi - Ky) // Sy + 1 and Xo == (Xi - Kx) // Sx + 1.
shape_table = [
    # [0] 1x1, stride 1: 8x8x128 -> 8x8x64
    (
        (8, 8, 128),     # Yi, Xi, Ci
        (8, 8, 64),     # Yo, Xo, Co
        (8, 8, 128),     # Yis, Xis, Cis
        (8, 8, 64),     # Yos, Xos, Cos
        (1, 1),         # Ky, Kx
        (0, 0),         # Py, Px
        (1, 1),         # Sy, Sx
        12, 1,          # Shift_res, shift_bias
        0, 0, 0,        # sign of act, wgt, out
        ActMode.AC_RELU, # activation mode
    ),
    # [1] 1x1, stride 1: 4x16x256 -> 4x16x64
    (
        (4, 16, 256),     # Yi, Xi, Ci
        (4, 16, 64),     # Yo, Xo, Co
        (4, 16, 256),     # Yis, Xis, Cis
        (4, 16, 64),     # Yos, Xos, Cos
        (1, 1),         # Ky, Kx
        (0, 0),         # Py, Px
        (1, 1),         # Sy, Sx
        12, 1,          # Shift_res, shift_bias
        0, 0, 0,        # sign of act, wgt, out
        ActMode.AC_RELU, # activation mode
    ),
    # [2] 7x7, stride 2: 21x21x8 -> 8x8x64 (Ci < 64)
    (
        (21, 21, 8),     # Yi, Xi, Ci
        (8, 8, 64),     # Yo, Xo, Co
        (21, 21, 8),     # Yis, Xis, Cis
        (8, 8, 64),     # Yos, Xos, Cos
        (7, 7),         # Ky, Kx
        (0, 0),         # Py, Px
        (2, 2),         # Sy, Sx
        12, 1,          # Shift_res, shift_bias
        0, 0, 0,        # sign of act, wgt, out
        ActMode.AC_RELU, # activation mode
    ),
    # [3] 1x1, stride 1: 8x8x64 -> 8x8x64, signed act/wgt
    (
        (8, 8, 64),     # Yi, Xi, Ci
        (8, 8, 64),     # Yo, Xo, Co
        (8, 8, 64),     # Yis, Xis, Cis
        (8, 8, 64),     # Yos, Xos, Cos
        (1, 1),         # Ky, Kx
        (0, 0),         # Py, Px
        (1, 1),         # Sy, Sx
        6, 2,          # Shift_res, shift_bias
        1, 1, 0,        # sign of act, wgt, out
        ActMode.AC_RELU, # activation mode
    ),
    # [4] 1x1, stride 1: 8x8x128 -> 8x8x64, signed act/wgt
    (
        (8, 8, 128),     # Yi, Xi, Ci
        (8, 8, 64),     # Yo, Xo, Co
        (8, 8, 128),     # Yis, Xis, Cis
        (8, 8, 64),     # Yos, Xos, Cos
        (1, 1),         # Ky, Kx
        (0, 0),         # Py, Px
        (1, 1),         # Sy, Sx
        6, 1,          # Shift_res, shift_bias
        1, 1, 0,        # sign of act, wgt, out
        ActMode.AC_RELU, # activation mode
    ),
    # [5] 3x3, stride 1: 10x10x64 -> 8x8x64, signed act/wgt
    (
        (10, 10, 64),   # Yi, Xi, Ci
        (8, 8, 64),     # Yo, Xo, Co
        (10, 10, 64),   # Yis, Xis, Cis
        (8, 8, 64),     # Yos, Xos, Cos
        (3, 3),         # Ky, Kx
        (0, 0),         # Py, Px
        (1, 1),         # Sy, Sx
        6, 0,          # Shift_res, shift_bias
        1, 1, 0,        # sign of act, wgt, out
        ActMode.AC_RELU, # activation mode
    ),
    # [6] 5x5, stride 1: 12x12x64 -> 8x8x64, signed act/wgt
    (
        (12, 12, 64),   # Yi, Xi, Ci
        (8, 8, 64),     # Yo, Xo, Co
        (12, 12, 64),   # Yis, Xis, Cis
        (8, 8, 64),     # Yos, Xos, Cos
        (5, 5),         # Ky, Kx
        (0, 0),         # Py, Px
        (1, 1),         # Sy, Sx
        8, 0,          # Shift_res, shift_bias
        1, 1, 0,        # sign of act, wgt, out
        ActMode.AC_RELU, # activation mode
    ),
    # [7] 1x1, stride 1: 1x64x64 -> 1x64x64, signed act/wgt/out
    (
        (1, 64, 64),    # Yi, Xi, Ci
        (1, 64, 64),    # Yo, Xo, Co
        (1, 64, 64),    # Yis, Xis, Cis
        (1, 64, 64),    # Yos, Xos, Cos
        (1, 1),         # Ky, Kx
        (0, 0),         # Py, Px
        (1, 1),         # Sy, Sx
        10, 0,          # Shift_res, shift_bias
        1, 1, 1,        # sign of act, wgt, out
        ActMode.AC_RELU, # activation mode
    ),
    # [8] 3x3, stride 1: 3x66x64 -> 1x64x64, signed act/wgt/out
    (
        (3, 66, 64),    # Yi, Xi, Ci
        (1, 64, 64),    # Yo, Xo, Co
        (3, 66, 64),    # Yis, Xis, Cis
        (1, 64, 64),    # Yos, Xos, Cos
        (3, 3),         # Ky, Kx
        (0, 0),         # Py, Px
        (1, 1),         # Sy, Sx
        10, 0,          # Shift_res, shift_bias
        1, 1, 1,        # sign of act, wgt, out
        ActMode.AC_RELU, # activation mode
    ),
    # [9] 3x3, stride 2: 17x17x64 -> 8x8x64, signed act/wgt
    (
        (17, 17, 64),    # Yi, Xi, Ci
        (8, 8, 64),    # Yo, Xo, Co
        (17, 17, 64),    # Yis, Xis, Cis
        (8, 8, 64),    # Yos, Xos, Cos
        (3, 3),         # Ky, Kx
        (0, 0),         # Py, Px
        (2, 2),         # Sy, Sx
        8, 0,          # Shift_res, shift_bias
        1, 1, 0,        # sign of act, wgt, out
        ActMode.AC_RELU, # activation mode
    ),
]

def _parse_args(argv: List[str]) -> tuple:
    """Validate the command line and return ``(run_mode, shapes_to_run)``.

    Usage: ``<script> {dataflow|sim|cert} [shape_index]``. Without
    ``shape_index`` every entry of ``shape_table`` is run; with it, only
    ``shape_table[shape_index]``. Prints a message and exits with status 1
    on invalid input.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if len(argv) not in (2, 3):
        print(f"Usage: {os.path.basename(__file__)} <dataflow|sim|cert> [shape_index]")
        sys.exit(1)
    run_mode = argv[1]
    # check if mode is "dataflow" or "sim" or "cert"
    if run_mode not in ('dataflow', 'sim', 'cert'):
        print(f"Invalid mode: {run_mode}. Use 'dataflow', 'sim', or 'cert'.")
        sys.exit(1)
    if len(argv) == 2:
        return run_mode, shape_table
    try:
        shapes = [shape_table[int(argv[2])]]
    except (ValueError, IndexError):
        print(f"Invalid shape index: {argv[2]}. "
              f"Use an integer between 0 and {len(shape_table) - 1}.")
        sys.exit(1)
    return run_mode, shapes


def _run_aie_flow(compile_args: List[str]) -> None:
    """Recreate WORKDIR, compile the graph, build the host model and simulate.

    Raises:
        subprocess.CalledProcessError: if aiecompiler, the Makefile patch or
            make fails (the later steps would be meaningless).
    """
    import shutil

    if os.path.isdir(WORKDIR):
        print("Deleting existing Work folder")
        # Delete by absolute path. A cwd-relative ``rm -r Work`` misses
        # WORKDIR unless the cwd is CURRDIR, and os.makedirs() below then
        # fails with FileExistsError.
        shutil.rmtree(WORKDIR)
    os.makedirs(WORKDIR)
    # Tools run with cwd=WORKDIR rather than via os.chdir(), so the process
    # cwd does not leak into the next shape's compile_conv_dataflow() call.
    # The --Xpreproc="-D..." options rely on the shell stripping the quotes,
    # so the compile command is deliberately run as one shell string.
    subprocess.run(' '.join(compile_args), shell=True, cwd=WORKDIR, check=True)
    systemc_dir = f'{WORKDIR}/ps/c_rts/systemC'
    # Prepend -ladf_rt_ctrl_api to every -ladf_api in the generated Makefile.
    subprocess.run(
        ['sed', '-i', 's/-ladf_api/-ladf_rt_ctrl_api -ladf_api/g', f'{systemc_dir}/Makefile'],
        cwd=WORKDIR,
        check=True,
    )
    subprocess.run(['make', '-C', f'{systemc_dir}/', 'all'], cwd=WORKDIR, check=True)
    subprocess.run(aiesimulator_args(), cwd=WORKDIR)


def main():
    """Compile (and optionally simulate) the conv shapes in ``shape_table``.

    Command line: ``<script> {dataflow|sim|cert} [shape_index]``.

    * ``dataflow``: only run ``compile_conv_dataflow`` with the Adf backend.
    * ``sim``: additionally compile with aiecompiler, build the SystemC host
      model and run aiesimulator.
    * ``cert``: like ``sim``, but the dataflow uses the TxnHostPatch backend
      and the kernels are built with ASM_MODE=1.
    """
    run_mode, shapes = _parse_args(sys.argv)
    backend = BackEnd.TxnHostPatch if run_mode == 'cert' else BackEnd.Adf
    for (
        (Yi, Xi, Ci),
        (Yo, Xo, Co),
        (Yis, Xis, Cis),
        (Yos, Xos, Cos),
        (Ky, Kx),
        (Py, Px),
        (Sy, Sx),
        shift_res, shift_bias,
        sign_act, sign_wgt, sign_out,
        act_mode,
        ) in shapes:
        # NOTE: ACT mode is not packed in layer params
        N = 1  # Batch dimension
        aie_rows = 1
        aie_cols = 1
        act_bits = 8
        wgt_bits = 8
        bias_bits = 16
        out_bits = 8
        param_bits = 8
        # Layers with fewer than 64 input channels use an 8-channel granularity.
        Ci_gran = 8 if Ci < 64 else 64
        act_fmt = ActFmt.CYXC
        dims = ConvDims(
            N,
            Yi, Xi, Ci, Yo, Xo, Co, Yis, Xis, Cis, Yos, Xos, Cos, Ky, Kx, Py, Px, Sy, Sx,
            aie_rows, aie_cols, act_bits, wgt_bits, bias_bits, out_bits, param_bits,
            Ci_gran, act_fmt,
        )
        print(f"Compiling for shape: {dims}")
        compile_conv_dataflow(dims, backend)
        if run_mode == 'dataflow':
            continue
        _run_aie_flow(aiecompiler_args(
            dims, shift_res, shift_bias, sign_act, sign_wgt, sign_out, act_mode, run_mode,
        ))


if __name__ == '__main__':
    main()
