import os
import sys
import json
import shutil
import logging
from typing import List, Optional

CURRDIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(CURRDIR)

from dmacompiler import (
    BackEnd,
    set_dev_gen, DevGen, config
)
from dataflow_common import clean_overlay, build_sim_overlay
from depthtospace_common import DepthToSpace_dims, depthtospace_preproc_directives
import depthtospace_dataflow
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True
from OGOAT.src.L1_fusion.kernel_func_list import kernel_func_list

def build_depth2space(
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
    aie_cols: int,
    aie_rows: int,
    in_shape: List[int],
    blockSize: int,
    perm_mode: str,
    act_bits: int,
    frontend_only: bool = False,
    out_folder: Optional[str] = None,
):
    """Compile the DepthToSpace dataflow and optionally build the sim overlay.

    Steps, in order:
      1. Write ``tiling.json`` (input shape, permutation mode, activation
         bit width) into the directory containing this script.
      2. Build a ``DepthToSpace_dims`` description from the arguments.
      3. Reset the overlay and compile the dataflow for ``back_end``.
      4. Unless ``frontend_only`` is set, build the simulation overlay
         from ``depthtospace_main.cpp``.

    Args:
        back_end: Target backend. ``out_folder`` must be ``None`` when this
            is ``BackEnd.Adf``.
        kernel_names: Kernel identifiers forwarded to ``compile_dataflow``.
            NOTE(review): ``run_depthtospace_op`` passes a dict mapping kernel
            name -> index here, so the ``List[str]`` annotation looks too
            narrow -- confirm what ``compile_dataflow`` accepts.
        kernel_includes: Kernel source/header files forwarded to
            ``compile_dataflow``.
        aie_cols: AIE column count, forwarded to ``DepthToSpace_dims``.
        aie_rows: AIE row count, forwarded to ``DepthToSpace_dims``.
        in_shape: Input activation shape.
        blockSize: DepthToSpace block size. It is NOT written to
            ``tiling.json`` -- confirm whether that is intended.
        perm_mode: Permutation mode string (``main`` uses ``"CRD"``).
        act_bits: Activation bit width, recorded as ``ifm_bits``.
        frontend_only: When True, stop after compiling the dataflow.
        out_folder: Only validated by the precondition below; this function
            does not otherwise use it.
    """
    # Precondition: the ADF backend must not be given an output folder.
    assert out_folder is None or back_end != BackEnd.Adf

    tiling_path = os.path.join(CURRDIR, 'tiling.json')
    tiling_params = {
        'ifm_bits': act_bits,
        'in_shape': in_shape,
        'perm_mode': perm_mode,
    }
    with open(tiling_path, 'w') as tiling_file:
        # sort_keys makes the file contents independent of dict insertion order.
        tiling_file.write(json.dumps(tiling_params, sort_keys=True, indent=4))

    d2s_dims = DepthToSpace_dims(
        aie_rows, aie_cols,
        in_shape, blockSize, perm_mode,
        act_bits,
    )

    clean_overlay()
    depthtospace_dataflow.compile_dataflow(d2s_dims, back_end, kernel_names, kernel_includes)
    if frontend_only:
        return

    directives = depthtospace_preproc_directives(d2s_dims, back_end)
    build_sim_overlay(back_end, 'depthtospace_main.cpp', directives)


def extract_fields(file_name):
    """Load and return the decoded contents of the JSON file *file_name*.

    Args:
        file_name: Path to a JSON file. A relative path is resolved against
            the current working directory at call time.

    Returns:
        The decoded JSON value. ``run_depthtospace_op`` indexes it as a dict.
    """
    # JSON text is UTF-8 by specification (RFC 8259). Decode it explicitly
    # instead of relying on the platform's locale-dependent default encoding.
    with open(file_name, 'r', encoding='utf-8') as f:
        return json.load(f)

def run_depthtospace_op(json_file, path, txn_mode, kernel_d, frontend_only):
    """Build a DepthToSpace operator from a JSON operator description.

    Args:
        json_file: Path to the operator-description JSON. It is read *after*
            changing into ``path``, so a relative path is resolved against
            ``path``.
        path: Directory to ``os.chdir`` into before building. The working
            directory is not restored afterwards.
        txn_mode: 0 selects ``BackEnd.Adf``; any other value selects
            ``BackEnd.TxnHostPatch``.
        kernel_d: Optional kernel override with ``'kernel_list'`` and
            ``'kernel_include'`` entries. When falsy, the default
            ``run_int16_permute`` kernel is looked up in ``kernel_func_list``.
        frontend_only: Forwarded to ``build_depth2space``.
    """
    os.chdir(path)  # the build system only works on the current dir (subprocess)
    _data = extract_fields(json_file)
    data = {}
    data['back_end'] = BackEnd.Adf if txn_mode == 0 else BackEnd.TxnHostPatch
    if not kernel_d:
        # Map each default kernel name to its index in kernel_func_list.
        data['kernel_names'] = {}
        for k in ['run_int16_permute']:
            try:
                data['kernel_names'][k] = kernel_func_list.index(k)
            except ValueError:
                # Best effort: report the missing kernel and continue without it.
                print(f"Error: '{k}' not found in the kernel func list!")
        data['kernel_includes'] = ['super.hh', "permute/wrapper_permute.cc"]
    else:
        data['kernel_names'] = kernel_d['kernel_list']
        data['kernel_includes'] = kernel_d['kernel_include']

    overlay_shape = _data['overlay_info']['shape']
    layer_info = _data['layer_info']
    data['aie_rows'] = overlay_shape['row']
    data['aie_cols'] = overlay_shape['col']
    data['act_bits'] = layer_info['in_bytes'] * 8
    data['input'] = layer_info['in_act_shape']
    # Attribute values are stored as lists; only the first element is used.
    attributes = layer_info['attributes']
    data['block_size'] = attributes['blocksize'][0]
    data['perm_mode'] = attributes['mode'][0]

    # Only non-ADF builds get an output folder: the directory holding json_file.
    if data['back_end'] != BackEnd.Adf:
        output_dir = os.path.dirname(os.path.realpath(json_file))
    else:
        output_dir = None
    # Lazy %-formatting: the dict is only rendered when INFO is enabled.
    logging.info(" depth2space input args: %s", data)
    # Keyword arguments guard against swapping same-typed neighbours
    # (aie_cols/aie_rows, blockSize/perm_mode) in this long signature.
    build_depth2space(
        back_end=data['back_end'],
        kernel_names=data['kernel_names'],
        kernel_includes=data['kernel_includes'],
        aie_cols=data['aie_cols'],
        aie_rows=data['aie_rows'],
        in_shape=data['input'],
        blockSize=data['block_size'],
        perm_mode=data['perm_mode'],
        act_bits=data['act_bits'],
        frontend_only=frontend_only,
        out_folder=output_dir,
    )
 
def main():
    """Run a full DepthToSpace build for a fixed reference configuration.

    Configuration: ADF backend, the ``run_int16_permute`` kernel, 8 AIE
    columns x 4 AIE rows, input shape ``[1, 128, 128, 256]``, block size 2,
    ``"CRD"`` permutation mode and 16-bit activations. The simulation
    overlay is built as well (``frontend_only=False``).
    """
    reference_config = dict(
        back_end=BackEnd.Adf,
        kernel_names=["run_int16_permute"],
        kernel_includes=['super.hh', "permute/wrapper_permute.cc"],
        aie_cols=8,
        aie_rows=4,
        in_shape=[1, 128, 128, 256],
        blockSize=2,
        perm_mode="CRD",
        act_bits=16,
        frontend_only=False,
    )
    build_depth2space(**reference_config)

if __name__ == "__main__":
    # Script entry point: build the reference configuration from main().
    main()
