import os
import sys
import json
import shutil
import logging
from typing import List, Optional

# Absolute directory of this script; base for the extra import paths below.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Make this directory, its parent, and the sibling dmacompiler / kernels
# trees importable so the project imports that follow can be resolved.
sys.path.append(CURRDIR)
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'kernels'))

from dmacompiler import (
    BackEnd,
    set_dev_gen, DevGen, config
)
from dataflow_common import clean_overlay, build_sim_overlay, sizeof
from slice_neg_common import SliceDims, slice_neg_preproc_directives
import slice_neg_dataflow
# Target the AIE2P device generation and turn on dmacompiler's
# ENABLE_BUSY_POLL option; both are applied when this module is imported.
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True
from OGOAT.src.L1_fusion.kernel_func_list import kernel_func_list

def build_slice_neg(
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
    aie_cols: int,
    aie_rows: int,
    h_in: int, w_in: int,
    w_out_start: int, w_out_stop: int,
    ifm_bits: int,
    out_folder: Optional[str] = None,
):
    """Compile the slice_neg dataflow and build its simulation overlay.

    The tiling parameters are first dumped to ``tiling.json`` in this
    script's directory. A ``SliceDims`` is then built from them and passed
    to ``slice_neg_dataflow.compile_dataflow``, followed by
    ``build_sim_overlay`` on ``slice_neg_main.cpp``.

    Args:
        back_end: dmacompiler back end to target.
        kernel_names: kernels forwarded to ``compile_dataflow``. The
            annotation says list, but ``run_slice_neg_op`` may pass a
            name -> index dict.
        kernel_includes: kernel sources/headers forwarded to
            ``compile_dataflow``.
        aie_cols, aie_rows: overlay column / row counts.
        h_in, w_in: input extents.
        w_out_start, w_out_stop: start / stop of the slice along ``w``.
        ifm_bits: bit width of one input element.
        out_folder: must be ``None`` when ``back_end`` is ``BackEnd.Adf``.

    NOTE(review): ``out_folder`` is only validated here and never used, and
    ``shutil`` is imported but unused in this file. Presumably a step that
    moved generated artifacts into ``out_folder`` was dropped. Confirm which
    files should be relocated.
    """
    # An output folder is only accepted for non-ADF back ends.
    assert (back_end != BackEnd.Adf) or (out_folder is None)

    tiling = dict(
        h_in=h_in,
        w_in=w_in,
        w_out_start=w_out_start,
        w_out_stop=w_out_stop,
        ifm_bits=ifm_bits,
    )
    with open(os.path.join(CURRDIR, 'tiling.json'), 'w') as tiling_file:
        json.dump(tiling, tiling_file, sort_keys=True, indent=4)

    # Rows are passed before cols here, unlike this function's own
    # signature; this matches the original call. Verify against SliceDims.
    slice_dims = SliceDims(
        aie_rows, aie_cols,
        h_in, w_in,
        w_out_start, w_out_stop,
        ifm_bits,
    )

    clean_overlay()
    slice_neg_dataflow.compile_dataflow(
        slice_dims, back_end, kernel_names, kernel_includes,
    )
    directives = slice_neg_preproc_directives(slice_dims, back_end)
    build_sim_overlay(back_end, 'slice_neg_main.cpp', directives)


def extract_fields(file_name):
    """Load and return the parsed contents of the JSON file *file_name*.

    The file is decoded explicitly as UTF-8, the encoding JSON interchange
    requires. Relying on the platform's locale-dependent default could fail
    or mis-decode non-ASCII content.
    """
    with open(file_name, 'r', encoding='utf-8') as f:
        return json.load(f)

def run_slice_neg_op(json_file, path, txn_mode, kernel_d):
    """Build the slice_neg operator described by a layer JSON file.

    Args:
        json_file: layer description JSON. It is opened *after*
            ``os.chdir(path)``, so a relative path resolves against ``path``.
        path: directory to switch into before building. The process working
            directory is left changed on return.
        txn_mode: 0 selects ``BackEnd.Adf``; any other value selects
            ``BackEnd.TxnHostPatch``.
        kernel_d: optional dict with ``'kernel_list'`` and
            ``'kernel_include'`` entries. When falsy, ``run_int16_negative``
            is looked up in ``kernel_func_list`` and default includes are used.

    Raises:
        AssertionError: if ``layer_info.in_act_shape`` does not have exactly
            four entries.
    """
    os.chdir(path)   # because build system only works on current dir (subprocess)
    _data = extract_fields(json_file)
    data = {}
    data['back_end'] = BackEnd.Adf if txn_mode == 0 else BackEnd.TxnHostPatch
    if not kernel_d:
        data['kernel_names'] = {}
        kernel_names = ['run_int16_negative']
        for k in kernel_names:
            try:
                data['kernel_names'][k] = kernel_func_list.index(k)
            except ValueError:
                # Best-effort: report the missing kernel and keep going.
                print(f"Error: '{k}' not found in the kernel func list!")
        data['kernel_includes'] = ['super.hh', 'qdq/wrapper_qdq.cc']
    else:
        data['kernel_names'] = kernel_d['kernel_list']
        data['kernel_includes'] = kernel_d['kernel_include']

    overlay_shape = _data['overlay_info']['shape']
    data['aie_rows'] = overlay_shape['row']
    data['aie_cols'] = overlay_shape['col']

    layer_info = _data['layer_info']
    data['ifm_bits'] = sizeof(layer_info['in_datatype'])
    in_act_shape = layer_info['in_act_shape']
    if len(in_act_shape) != 4:
        # Raise explicitly rather than `assert False` so the check survives
        # `python -O`. AssertionError is kept for existing callers.
        raise AssertionError(
            f"Slice input shape list {in_act_shape} length doesn't match number of inputs: 4"
        )
    # Axes 2 and 3 are folded into h_in; axis 1 becomes w_in.
    # NOTE(review): axis 0 is not used; presumably batch == 1. Confirm.
    data['h_in'] = in_act_shape[2] * in_act_shape[3]
    data['w_in'] = in_act_shape[1]
    attributes = layer_info['attributes']
    data['w_out_start'] = attributes['start'][0]
    data['w_out_stop'] = attributes['end'][0]

    # Only non-ADF back ends receive an output directory (next to json_file).
    output_dir = (
        os.path.dirname(os.path.realpath(json_file))
        if data['back_end'] != BackEnd.Adf else None
    )
    logging.info(" Slice input args: %s", data)
    build_slice_neg(
        data['back_end'],
        data['kernel_names'], data['kernel_includes'],
        data['aie_cols'], data['aie_rows'],
        data['h_in'], data['w_in'],
        data['w_out_start'], data['w_out_stop'],
        data['ifm_bits'],
        output_dir,
    )
    
def main():
    """Build slice_neg for a fixed example configuration on the ADF back end.

    The example uses an 8-column x 4-row overlay, h_in=4096 and w_in=2560
    with 16-bit inputs, and slices w from 640 to 1280 with the
    ``run_int16_negative`` kernel.
    """
    build_slice_neg(
        BackEnd.Adf,
        ['run_int16_negative'],
        ['super.hh', 'qdq/wrapper_qdq.cc'],
        aie_cols=8,
        aie_rows=4,
        h_in=4096,
        w_in=2560,
        w_out_start=640,
        w_out_stop=1280,
        ifm_bits=16,
    )

# Run the built-in example configuration when executed as a script.
if __name__ == '__main__':
    main()
