import os
from typing import List, Optional, Dict, Union
import shutil
import json
import subprocess
import sys
import ast 

from dmacompiler import BackEnd
from dmacompiler import set_dev_gen, DevGen, config

from dataflow.xclbin import dummy_dataflow

from dataflow.mha.mini_mha.overlay import generate_dataflow_mini_mha

from dataflow.mha.mha_2p1.overlay import generate_dataflow_2p1_mha

from dataflow.mha.mha_general_dflow.src.generate_dataflow_4x8 import (
    generate_dataflow_mha_3p0,
    Mha3p0Parameters,
    generate_standalone_buffer_allocations,
)
from dataflow_common import (
    clean_overlay,
    build_sim_overlay,
    build_txn_aiert,
    overlay_stack_size,
    overlay_heap_size,
)

from OGOAT.src.Scheduling_Engine.schedules.BufferAllocatorResult import (
    BufferAllocations,
)
from OGOAT.src.L1_fusion.py_match.adv.attention import MHAMode

# Absolute directory of this script; build include paths and artifact locations below are resolved against it.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Append CURRDIR/../dmacompiler to the module search path. NOTE: this runs after the
# `dmacompiler` imports at the top of the file, so it can only affect later imports.
sys.path.append(os.path.join(CURRDIR, '..', 'dmacompiler'))
# Root of the Vitis AIE tools (None when the variable is unset); used for the g++ include/lib paths.
XILINX_VITIS_AIETOOLS = os.environ.get("XILINX_VITIS_AIETOOLS")
# Kernel name registry; the runners below store each kernel name's index in this list in `kernel_names`.
from OGOAT.src.L1_fusion.kernel_func_list import kernel_func_list
# Base folder for model test data (None when TEST_DIR is unset); model_data_* paths are built under it.
test_dir = os.environ.get("TEST_DIR")

def mha_preprocessor_directives(
    H: int,
    M: int,
    K: int,
    N: int,
    L: int,
    M_subv: int,
    back_end: BackEnd,
    model_data_folder: Optional[str] = None,
    attn_mask_exist = False,
    Bias: Optional[int] = None,
    G: Optional[int] = None,
):
    """Build the preprocessor/compiler flags that configure the MHA test bench.

    Every value is emitted as ``-DNAME=value``. For ``BackEnd.Adf`` the define
    is wrapped as ``--Xpreproc="-DNAME=value"`` and the aiesimulator options
    (chess backend switch, API workaround define, ``xlopt=0`` and the
    stack/heap sizes) are appended.

    Args:
        H: value for ``H_IN``.
        M: value for ``SQ_IN``.
        K: value for ``DH_IN``.
        N: value for ``SK_IN``.
        L: value for ``DV_IN``.
        M_subv: value for ``SQ_IN_SUBV``.
        back_end: selects the flag syntax and whether the simulator options
            are appended.
        model_data_folder: value for ``TEST_BENCH_DIR`` (formatted as-is, so
            ``None`` becomes the text ``None``).
        attn_mask_exist: ``ATTN_MASK_EXIST`` is 1 when truthy, else 0.
        Bias: value for ``BIAS_EXIST`` (0 when ``None``).
        G: value for ``G_IN`` (``H`` when ``None``).

    Returns:
        List of flag strings.
    """
    def directive(ident: str, val: object) -> str:
        if back_end == BackEnd.Adf:
            return f'--Xpreproc="-D{ident}={val}"'
        return f"-D{ident}={val}"

    base_directives = [
        directive('SQ_IN', M),
        directive('DH_IN', K),
        directive('SK_IN', N),
        directive('DV_IN', L),
        directive('H_IN', H),
        directive('SQ_IN_SUBV', M_subv),
        directive('ATTN_MASK_EXIST', 1 if attn_mask_exist else 0),
        directive('BIAS_EXIST', Bias if Bias is not None else 0),
        directive('TEST_BENCH_DIR', model_data_folder),
        directive('G_IN', G if G is not None else H),
    ]
    if back_end != BackEnd.Adf:
        return base_directives

    aiesim_directives = [
        '--Xchess="main:backend.mist2.pnll=off"',
        "--Xpreproc=-D__AIE_API_WORKAROUND_CR_1223259__=1",
        # This comma was previously missing, so implicit string concatenation
        # produced the malformed single flag "xlopt=0stacksize=...".
        "xlopt=0",
        f'stacksize={overlay_stack_size()} heapsize={overlay_heap_size()}',
    ]
    return base_directives + aiesim_directives


def verbose_run(command: str, check: bool = False) -> subprocess.CompletedProcess:
    """Echo *command* and run it through the shell.

    The child environment exports a bash function named ``make``
    (``BASH_FUNC_make%%`` is bash's encoding for exported functions). Bash
    shells started under the command therefore run
    ``patch_aiecompiler_make.py`` before ``/usr/bin/make``. ``/bin/sh`` itself
    is not necessarily bash.

    Args:
        command: shell command line to execute.
        check: when True, raise ``subprocess.CalledProcessError`` on a
            non-zero exit status. The default False keeps the historical
            fire-and-forget behaviour.

    Returns:
        The ``CompletedProcess``, so callers can inspect ``returncode``.
    """
    print(command)
    env = os.environ.copy()
    env['BASH_FUNC_make%%'] = '() { patch_aiecompiler_make.py $*; /usr/bin/make $@; }'
    return subprocess.run(command, shell=True, env=env, check=check)


def build_txn_bins(
    host_filename: str,
    compile_flags: List[str]
):
    """Compile *host_filename* with aiecompiler for hardware, then simulate it.

    Runs these steps in the current working directory via ``verbose_run``:

    1. aiecompiler with the fixed target/partition/stack options plus
       *compile_flags*;
    2. a ``sed`` that adds ``-ladf_rt_ctrl_api`` to
       ``Work/ps/c_rts/systemC/Makefile`` (presumably produced by step 1);
    3. ``make`` in that directory;
    4. ``aiesimulator --profile`` (best effort).

    Args:
        host_filename: source file handed to aiecompiler.
        compile_flags: extra aiecompiler arguments appended after the defaults.
    """
    compile_args = [
        'aiecompiler',
        host_filename,
        '--target=hw',
        '--part=xc10AIE2P_ML-die-0x-e-S-es1',
        f'--include={os.path.join(CURRDIR, "..", "..","kernels")}',
        f'--include={os.path.join(CURRDIR, "..", "..","kernels", "qdq")}',
        f'--include={os.path.join(CURRDIR, "..", "..","kernels", "include")}',
        f'--include={os.path.join(CURRDIR, "..", "..","kernels", "common")}',
        '--adf-api-log-level=3',
        '--enable-partition=0:8 ',
        '--aie2ipu-base-addr=0',
        f'--stacksize={overlay_stack_size()}',
        f'--heapsize={overlay_heap_size()}',
        '--enable-core-processor-bus=true',
        '--disable-dma-autostart=true',
        '--Xchess="main:backend.mist2.pnll=off"',
        "--Xpreproc=-D__AIE_API_WORKAROUND_CR_1223259__=1",
        '--Xchess="main:backend.amnesia.rls=on"',
        '--Xpreproc="-D__AIE_API_WORKAROUND_CR_1223259__=1"',
        '--Xpreproc="-D_main_init=_waic_main_init"',
        '--xlopt=0',
        '--enable-light-cdo'
    ] + compile_flags

    sim_args = [
        'aiesimulator',
        '--profile',
    ]
    compile_command = ' '.join(compile_args)
    systemC_sed_command = "sed -i 's/-ladf_api/-ladf_rt_ctrl_api -ladf_api/g' Work/ps/c_rts/systemC/Makefile"
    systemC_make_command = 'make -C Work/ps/c_rts/systemC/ all'
    sim_command = ' '.join(sim_args)
    verbose_run(compile_command)
    verbose_run(systemC_sed_command)
    verbose_run(systemC_make_command)
    try:
        verbose_run(sim_command)
    except Exception as exc:
        # Best effort, as before: a failing simulator launch must not abort
        # the build, but it should no longer be silent.
        print(f"aiesimulator run failed: {exc}")


def build_txn_bins_gcc(
    host_filename: str,
    compile_flags: List[str]
):
    """Compile *host_filename* natively with g++ into ``./txn_dma`` and run it.

    The binary is built with ``__AIECONTROLCODE__`` and ``__TXNRT__`` defined.
    It links against ``xaiengine`` from ``$XILINX_VITIS_AIETOOLS/lib/lnx64.o``.
    Both the compile and the run go through ``verbose_run`` in the current
    working directory.

    Args:
        host_filename: C++ source to compile.
        compile_flags: extra g++ arguments appended at the end.
    """
    kernels_dir = os.path.join(CURRDIR, "..", "kernels")
    include_dirs = [
        CURRDIR,
        os.path.join(CURRDIR, "..", ".."),
        kernels_dir,
        os.path.join(kernels_dir, "include"),
        os.path.join(kernels_dir, "common"),
    ]
    tools_lib_dir = f'{XILINX_VITIS_AIETOOLS}/lib/lnx64.o'

    leading = ['g++', '-Wall -Wextra', '-D__AIECONTROLCODE__', '-D__TXNRT__',
               '-o txn_dma', '-w', f'{host_filename}']
    includes = [f'-I{inc}' for inc in include_dirs]
    linkage = [
        f'-I{XILINX_VITIS_AIETOOLS}/include/drivers/aiengine',
        f'-L{tools_lib_dir}',
        '-lxaiengine',
        f'-Wl,-rpath,{tools_lib_dir}',
        '-lstdc++',
    ]

    verbose_run(' '.join(leading + includes + linkage + compile_flags))
    verbose_run('./txn_dma')


def build_mha_standalone_xclbin(kernel_names: List[str], kernel_includes: List[str], output_dir: str, overlay: str):
    """Build ``out.xclbin`` in CURRDIR and copy it into *output_dir*.

    Side effect: the process working directory is changed to CURRDIR, where
    the dataflow is compiled and ``make`` is run.

    Args:
        kernel_names: forwarded to ``dummy_dataflow.compile_dataflow``.
        kernel_includes: forwarded to ``dummy_dataflow.compile_dataflow``.
        output_dir: destination directory, created if missing. A relative path
            is interpreted relative to the caller's working directory.
        overlay: forwarded to ``dummy_dataflow.compile_dataflow``.
    """
    # Resolve before chdir(): previously a relative output_dir was created
    # relative to the caller's cwd but the copy target was resolved relative
    # to CURRDIR, so the two could point at different directories.
    output_dir = os.path.abspath(output_dir)
    os.makedirs(output_dir, exist_ok=True)

    build_dir = CURRDIR
    os.chdir(build_dir)

    dummy_dataflow.compile_dataflow(kernel_names, kernel_includes, overlay)

    stack_size = overlay_stack_size()
    heap_size = overlay_heap_size()
    xclbin_name = 'out.xclbin'
    verbose_run(
        f'make STACKSIZE={stack_size} HEAPSIZE={heap_size} FILENAME={xclbin_name}')

    src_file = os.path.join(build_dir, xclbin_name)
    dst_file = os.path.join(output_dir, xclbin_name)
    print("src_file:", src_file)
    print("dst_file:", dst_file)
    shutil.copy(src_file, dst_file)

def mha_3p0_build_qdq(
    tiler_output_path: str,  # Assume to be path to json output file from tiler
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
    model_data: Optional[str] = None,
    out_folder: Optional[str] = None,
    run_mha_standalone: bool = False
):
    """Generate and build the MHA 3p0 dataflow described by a tiler JSON.

    Args:
        tiler_output_path: tiler output JSON. For TxnHostPatch it is also
            copied into *out_folder*.
        back_end: ``BackEnd.TxnHostPatch`` builds with aiecompiler and
            aiesimulator via ``build_txn_bins``; any other back end builds a
            simulation overlay.
        kernel_names: forwarded to the dataflow generator.
        kernel_includes: forwarded to the dataflow generator.
        model_data: value for the ``TEST_BENCH_DIR`` define.
        out_folder: destination for the generated artifacts; ``None`` skips
            the artifact copy.
        run_mha_standalone: also build a standalone xclbin into *out_folder*
            (TxnHostPatch only).

    Raises:
        ValueError: *run_mha_standalone* is set but *out_folder* is ``None``.
        AssertionError: artifacts requested for a back end other than
            Adf/TxnHostPatch.
    """
    with open(tiler_output_path) as json_data:
        parameters = json.load(json_data)
    # Either "4x4" or "8x4"
    overlay_str = parameters['overlay_info']['overlay']
    # NOTE(review): generate_dataflow_mha_3p0 does not return H. It is read
    # from the tiler JSON the way the other MHA flows in this file do, falling
    # back to a single head when absent. Confirm against the 3p0 tiler schema.
    H = (parameters.get('layer_info', {})
                   .get('attributes', {})
                   .get('num_heads', [1])[0])

    internal_parameters = Mha3p0Parameters.compute_internal_parameters(parameters)
    buffer_alloc = generate_standalone_buffer_allocations(internal_parameters)

    clean_overlay()
    (M, K, N, L, M_subv, K_subv, N_subv, L_subv) = generate_dataflow_mha_3p0(
        internal_parameters, buffer_alloc, back_end, kernel_names, kernel_includes)

    # Keyword arguments: the previous positional call omitted H, so every value
    # (including back_end itself) was bound one parameter off.
    directives = mha_preprocessor_directives(
        H=H, M=M, K=K, N=N, L=L, M_subv=M_subv,
        back_end=back_end, model_data_folder=model_data)

    main_cpp_path = os.path.join(CURRDIR, "mha_general_dflow", "tests", "main.cpp")
    if back_end == BackEnd.TxnHostPatch:
        if out_folder is not None:
            os.makedirs(out_folder, exist_ok=True)

        build_txn_bins(main_cpp_path, directives)
        if run_mha_standalone:
            if out_folder is None:
                raise ValueError("run_mha_standalone requires out_folder")
            # os.path.join keeps absolute paths intact; "./" + "/abs" did not.
            build_mha_standalone_xclbin(kernel_names=kernel_names,
                                        kernel_includes=kernel_includes,
                                        output_dir=os.path.join(".", out_folder),
                                        overlay=overlay_str)
    else:
        build_sim_overlay(back_end, main_cpp_path, directives)

    if out_folder is None or back_end == BackEnd.Adf:
        return
    if back_end != BackEnd.TxnHostPatch:
        # Explicit raise so the check is not stripped under `python -O`.
        raise AssertionError(f"Unsupported back end for artifact export: {back_end}")

    in_folder = CURRDIR
    artifacts = ('ifm.bin', 'wgt.bin', 'ofm.bin',
                 tiler_output_path, 'txn.bin', 'param.bin', 'ctrl.bin', 'patch.json')
    for name in artifacts:
        src = os.path.join(in_folder, name)
        # patch.json is published in out_folder under the name ctrl_meta.json.
        target_name = "ctrl_meta.json" if name == 'patch.json' else name
        dst = os.path.join(out_folder, target_name)
        if src == dst:
            continue
        if name == tiler_output_path:
            shutil.copy(src, dst)  # leave the tiler JSON where it was
        else:
            shutil.move(src, dst)


def run_qkt_sm_op(json_file: str, path: str, txn_mode: bool, kernel_d, model_path, front_end_only):
    """Build the MHA 2p1 flow for the layer described by a tiler JSON.

    Args:
        json_file: tiler output JSON describing the layer.
        path: build directory. The process changes into it because the build
            system only works on the current directory.
        txn_mode: 1/True selects ``BackEnd.TxnHostPatch``, with artifacts
            placed next to *json_file*. Any other value selects ``BackEnd.Adf``.
        kernel_d: not used by this flow.
        model_path: forwarded as ``model_data``.
        front_end_only: forwarded to ``mha_2p1_test_build_qdq``.
    """
    os.chdir(path)  # because build system only work on current dir (subprocess)
    _data = extract_fields(json_file)
    # Decide the back end once so output_dir and the build always agree; the
    # previous "== 0" and "== 1" tests disagreed for other txn_mode values.
    back_end = BackEnd.TxnHostPatch if txn_mode == 1 else BackEnd.Adf
    output_dir = os.path.dirname(os.path.realpath(json_file)) if back_end != BackEnd.Adf else None

    layer_info = _data['layer_info']
    attributes = layer_info['attributes']
    H = attributes['num_heads'][0]
    Sin_q = layer_info['orig_in_q_shape'][0]
    Sin_kv = layer_info['orig_in_k_shape'][1]
    Sin_dh = layer_info['orig_in_q_shape'][1]
    Sq_subv = _data['core_tile_params']['subvols']['q'][0]
    aie_overlay_cols = _data['overlay_info']['shape']['col']
    # Forwarded to the 2p1 generator as enable_mask_vector.
    bias_vector_exist = 'num_bias' in attributes and attributes['num_bias'][0] > 0

    # Map each kernel name to its index in kernel_func_list; missing names are
    # reported and skipped.
    kernel_names = {}
    for k in (
        'run_gemm_qdq_mini',
        'run_softmax_qdq',
        'run_bcast_add_mini',
        'run_mini_mha_preprocess',
        'run_presoftmax_dequant',
    ):
        try:
            kernel_names[k] = kernel_func_list.index(k)
        except ValueError:
            print(f"Error: '{k}' not found in the kernel func list!")

    kernel_includes = [
        'super.hh',
        'mha_qdq/wrapper_mini_mha.cc',
        'norm/softmax.cc'
    ]

    mha_2p1_test_build_qdq(
        H=H,
        Sin_q=Sin_q,
        Sin_kv=Sin_kv,
        Sin_dh=Sin_dh,
        Sq_subv=Sq_subv,
        aie_overlay_cols=aie_overlay_cols,
        back_end=back_end,
        standalone_hw_test=False,
        kernel_names=kernel_names,
        kernel_includes=kernel_includes,
        disable_fast_pm=False,
        front_end_only=front_end_only,
        model_data=model_path,
        out_folder=output_dir,
        bias_vector_exist=bias_vector_exist
    )

def run_qkt_sm_common(tiler_output_path: str, path: str, txn_mode: bool, kernel_d, front_end_only):
    with open(tiler_output_path, "r") as fd:
        tiler_output = json.load(fd)

    S = {
            # PSR self attention shapes : 
            (64, 64, 64, 64),
            (256, 64, 256, 64),
            (1024, 64, 1024, 64),
            (4096, 64, 4096, 64),
            # PSR cross attention shapes : 
            (64, 64, 77, 64),
            (256, 64, 77, 64),
            (1024, 64, 77, 64),
            (4096, 64, 77, 64)
    }   
    
    ## Per head dimension from tiler output : json file
    layer_info = tiler_output['layer_info']
    mha_mode: MHAMode = layer_info['attributes']['mha_mode'][0]
    MTensor = layer_info['orig_in_q_shape'][0]
    KTensor = layer_info['orig_in_q_shape'][1]
    NTensor = layer_info['orig_in_k_shape'][1]
    LTensor = KTensor

    model_path = f"{test_dir}/model_data_{MTensor}x{KTensor}x{NTensor}_{MTensor}x{NTensor}x{LTensor}/"

    if(not((MTensor, KTensor, NTensor, LTensor) in S)):
        print("Non tested MHA Shape!")

    match mha_mode:
        case MHAMode.TWO_P_ONE:
            run_qkt_sm_op(tiler_output_path, path, txn_mode, kernel_d, model_path, front_end_only)
        case MHAMode.THREE_P_ZERO_MINI | MHAMode.TWO_P_ONE_MINI:
            mini_3p0 = mha_mode == MHAMode.THREE_P_ZERO_MINI
            run_mini_qkt_sm_smxv_op(tiler_output_path, path, txn_mode, kernel_d, model_path, front_end_only, mini_3p0=mini_3p0)
        case _:
            assert False, "Unsupported mha mode found."

def run_qkt_sm_smxv_op(json: str, path: str, txn_mode: int, kernel_d, model_path: str):
    """Build the MHA 3p0 flow for a tiler JSON via ``mha_3p0_build_qdq``.

    Note: the ``json`` parameter (a file path) shadows the ``json`` module
    inside this function. The name is kept so keyword callers keep working.

    Args:
        json: path to the tiler output JSON.
        path: build directory; the process changes into it.
        txn_mode: 1 selects ``BackEnd.TxnHostPatch``, with artifacts placed
            next to the JSON. Any other value selects ``BackEnd.Adf``.
        kernel_d: currently ignored; the kernel list below is fixed.
        model_path: forwarded as ``model_data``.
    """
    os.chdir(path)
    # One decision for both back_end and output_dir. Previously output_dir
    # tested "!= 0" while back_end tested "== 1".
    back_end = BackEnd.TxnHostPatch if txn_mode == 1 else BackEnd.Adf
    output_dir = os.path.dirname(os.path.realpath(json)) if back_end != BackEnd.Adf else None

    kernel_names = ['run_act_K_preprocess', 'run_act_V_preprocess',
                    'run_qkt_gemm_qdq', 'run_sfmx_i16_to_i16', 'run_smxv_gemm_qdq']
    kernel_includes = ['super.hh', 'mha_qdq/wrapper_mha_3p0_4x4_i16i16.cc']
    mha_3p0_build_qdq(
        json,
        back_end,
        kernel_names,
        kernel_includes,
        model_path,
        output_dir
    )
    print("Exit mha_3p0_build_qdq call\n")


def mha_mini_build_qdq(
    H: int,
    Sin_q: int,
    Sin_kv: int,
    Sin_dh: int,
    Sq_subv: int,
    aie_overlay_cols: int,
    back_end: BackEnd,
    standalone_hw_test: bool, 
    kernel_names: Union[List[str], Dict[str, int]],
    kernel_includes: List[str],
    front_end_only: bool, 
    disable_fast_pm = False, 
    model_data: Optional[str] = None,
    out_folder: Optional[str] = None,
    enable_attn_mask = False,
    K_is_tranposed_on_DDR = True,
    G: Optional[int] = None,
    G_seq: Optional[List[int]] = None,
    B: Optional[int] = None,
    B_seq: Optional[List[int]] = None,
    mini_3p0: Optional[int] = 1,
):
    """Generate and build the mini-MHA dataflow.

    Logs the configuration and regenerates the overlay with
    ``generate_dataflow_mini_mha``. It then builds ``mini_mha/main_di_model.cpp``:
    with ``build_txn_aiert`` for ``BackEnd.TxnHostPatch`` (creating
    *out_folder* first if needed), otherwise with ``build_sim_overlay``.

    ``standalone_hw_test`` is only logged and ``front_end_only`` is not used.
    """
    extra_includes = [f'-I {CURRDIR}', f'-I {out_folder}']

    for label, value in (
        ("H", H),
        ("Sin_q", Sin_q),
        ("Sin_kv", Sin_kv),
        ("Sin_dh", Sin_dh),
        ("Sq_subv", Sq_subv),
        ("aie_overlay_cols", aie_overlay_cols),
        ("stdalone_hw_test", standalone_hw_test),
        ("Enable_attn_mask", enable_attn_mask),
        ("kernel_names", kernel_names),
        ("kernel_includes", kernel_includes),
        ("back_end", back_end),
        ("model_data", model_data),
        ("out_folder", out_folder),
        ("G", G),
        ("G_seq", G_seq),
        ("B", B),
        ("B_seq", B_seq),
        ("mini_3p0", mini_3p0),
    ):
        print(label, value)

    clean_overlay()
    generate_dataflow_mini_mha(
        H, Sin_q, Sin_kv, Sin_dh, Sq_subv, aie_overlay_cols, back_end,
        kernel_names, kernel_includes,
        enable_attn_mask=enable_attn_mask,
        disable_fast_pm=disable_fast_pm,
        K_is_tranposed_on_DDR=K_is_tranposed_on_DDR,
        G=G, G_seq=G_seq, B=B, B_seq=B_seq, mini_3p0=mini_3p0,
    )

    main_cpp_path = os.path.join(CURRDIR, "mini_mha", "main_di_model.cpp")
    # NOTE(review): the DV_IN slot (L) receives Sin_kv here, as in the 2p1
    # flow; confirm this is intended rather than Sin_dh.
    build_flags = mha_preprocessor_directives(
        H, Sin_q, Sin_dh, Sin_kv, Sin_kv, Sq_subv, back_end, model_data,
        attn_mask_exist=enable_attn_mask, Bias=B, G=G,
    ) + extra_includes

    if back_end != BackEnd.TxnHostPatch:
        build_sim_overlay(back_end, main_cpp_path, build_flags)
        return

    if not os.path.exists(out_folder):
        os.makedirs(out_folder)
    print("CURDIR  :", CURRDIR)
    build_txn_aiert(main_cpp_path, build_flags)



def run_mini_qkt_sm_smxv_op(json_file: str, path: str, txn_mode: bool, kernel_d, model_path, front_end_only, mini_3p0: int):
    """Build the mini MHA flow for the layer described by a tiler JSON.

    Args:
        json_file: tiler output JSON describing the layer.
        path: build directory. The process changes into it because the build
            system only works on the current directory.
        txn_mode: 1/True selects ``BackEnd.TxnHostPatch``, with artifacts placed
            next to *json_file*. Any other value selects ``BackEnd.Adf``.
        kernel_d: optional dict whose ``"kernel_list"`` and
            ``"kernel_include"`` entries replace the default kernels.
        model_path: not used. The model-data folder is derived from
            ``TEST_DIR`` and the layer shapes instead.
        front_end_only: forwarded to ``mha_mini_build_qdq``.
        mini_3p0: forwarded to ``mha_mini_build_qdq``.

    Raises:
        AssertionError: ``num_groups`` / ``num_mask`` differ from
            ``num_heads`` and the matching ``*_sequence`` attribute is missing.
    """
    os.chdir(path)  # because build system only work on current dir (subprocess)
    _data = extract_fields(json_file)
    # Decide the back end once so output_dir and the build always agree; the
    # previous "== 0" and "== 1" tests disagreed for other txn_mode values.
    back_end = BackEnd.TxnHostPatch if txn_mode == 1 else BackEnd.Adf
    output_dir = os.path.dirname(os.path.realpath(json_file)) if back_end != BackEnd.Adf else None

    layer_info = _data['layer_info']
    attributes = layer_info['attributes']
    H = attributes['num_heads'][0]
    Sin_q = layer_info['orig_in_q_shape'][0]
    Sin_kv = layer_info['in_k_shape'][1]
    Sin_dh = layer_info['in_q_shape'][1]
    Sq_subv = _data['core_tile_params']['subvols']['q'][0]
    aie_overlay_cols = _data['overlay_info']['shape']['col']
    G = attributes['num_groups'][0] if 'num_groups' in attributes else None
    B = attributes['num_mask'][0] if 'num_mask' in attributes else None
    G_seq, B_seq = None, None

    if 'groups_sequence' in attributes:
        G_seq = attributes['groups_sequence']
    elif G is not None and G != H:
        raise AssertionError("G_seq is required when G is not equal to H")
    if 'mask_sequence' in attributes:
        B_seq = attributes['mask_sequence']
    elif B is not None and B != H:
        raise AssertionError("B_seq is required when B is not equal to H")

    # 'inputs' may arrive either as a Python-literal string or as an
    # already-decoded JSON list. ast.literal_eval rejects a list with
    # ValueError, which previously silently disabled mask detection.
    inputs_raw = layer_info.get('inputs', '[]')
    if isinstance(inputs_raw, str):
        try:
            inputs = ast.literal_eval(inputs_raw)
        except (ValueError, SyntaxError):
            inputs = []
    else:
        inputs = inputs_raw
    if not isinstance(inputs, (list, tuple)):
        inputs = []

    enable_attn_mask = any(
        isinstance(inp, dict) and inp.get('param_name') == 'M'
        for inp in inputs
    ) or ('num_bias' in attributes and attributes['num_bias'][0] == 1)
    Sin_q_mod = 64 if Sin_q == 1 else Sin_q

    model_data_folder = f"{test_dir}/model_data_{Sin_q_mod}x{Sin_dh}x{Sin_kv}_{Sin_q_mod}x{Sin_kv}x{Sin_dh}/"

    if kernel_d:
        kernel_names = kernel_d["kernel_list"]
        kernel_includes = kernel_d["kernel_include"]
    else:
        kernel_includes = [
            'super.hh',
            'mha_qdq/wrapper_mini_mha.cc'
        ]
        # Map each kernel name to its index in kernel_func_list; missing
        # names are reported and skipped.
        kernel_names = {}
        for k in (
            'run_gemm_qdq_mini',
            'run_softmax_qdq',
            'run_bcast_add_mini',
            'run_mini_mha_preprocess',
            'run_presoftmax_dequant',
        ):
            try:
                kernel_names[k] = kernel_func_list.index(k)
            except ValueError:
                print(f"Error: '{k}' not found in the kernel func list!")

    mha_mini_build_qdq(
        H=H,
        Sin_q=Sin_q,
        Sin_kv=Sin_kv,
        Sin_dh=Sin_dh,
        Sq_subv=Sq_subv,
        aie_overlay_cols=aie_overlay_cols,
        back_end=back_end,
        standalone_hw_test=False,
        kernel_names=kernel_names,
        kernel_includes=kernel_includes,
        front_end_only=front_end_only,
        disable_fast_pm=False,
        model_data=model_data_folder,
        out_folder=output_dir,
        enable_attn_mask=enable_attn_mask,
        K_is_tranposed_on_DDR=(Sin_q == 1),
        G=G,
        G_seq=G_seq,
        B=B,
        B_seq=B_seq,
        mini_3p0=mini_3p0
    )
    
def mha_2p1_test_build_qdq(
    H: int,
    Sin_q: int,
    Sin_kv: int,
    Sin_dh: int,
    Sq_subv: int,
    aie_overlay_cols: int,
    back_end: BackEnd,
    standalone_hw_test: bool, 
    kernel_names: Union[List[str], Dict[str, int]],
    kernel_includes: List[str],
    front_end_only : bool,
    disable_fast_pm: bool,
    bias_vector_exist = False,
    model_data: Optional[str] = None,
    out_folder: Optional[str] = None,
):
    """Generate and build the MHA 2p1 dataflow test.

    Regenerates the overlay with ``generate_dataflow_2p1_mha`` and then builds
    ``CURRDIR/main.cpp``. ``BackEnd.TxnHostPatch`` uses ``build_txn_aiert``,
    creating *out_folder* first if needed; any other back end uses
    ``build_sim_overlay``. This function does not move build artifacts into
    *out_folder*; a previously disabled (``if False:``) copy step was removed.

    ``kernel_names`` may be a list of names or a name-to-index dict; callers in
    this module pass the dict form. ``standalone_hw_test`` and
    ``front_end_only`` are accepted but unused.
    """
    cflags = [f'-I {CURRDIR}', f'-I {out_folder}']
    clean_overlay()
    generate_dataflow_2p1_mha(H, Sin_q, Sin_kv, Sin_dh, Sq_subv, aie_overlay_cols, back_end, kernel_names, kernel_includes, disable_fast_pm=disable_fast_pm, enable_mask_vector=bias_vector_exist)

    main_cpp_path = os.path.join(CURRDIR, "main.cpp")
    # NOTE(review): DV_IN (the L slot) receives Sin_kv, and bias_vector_exist
    # is not forwarded to BIAS_EXIST (which therefore stays 0). Confirm both
    # are intended.
    build_flags = mha_preprocessor_directives(
        H, Sin_q, Sin_dh, Sin_kv, Sin_kv, Sq_subv, back_end, model_data) + cflags

    if back_end == BackEnd.TxnHostPatch:
        if not os.path.exists(out_folder):
            os.makedirs(out_folder)
        build_txn_aiert(main_cpp_path, build_flags)
    else:
        build_sim_overlay(back_end, main_cpp_path, build_flags)
                
def extract_fields(file_name):
    """Load and return the JSON document stored at *file_name*.

    The file is decoded as UTF-8, the JSON standard encoding, instead of the
    locale-dependent default text encoding.
    """
    with open(file_name, 'r', encoding='utf-8') as f:
        return json.load(f)


