import os
import sys
import ast
import json
import shutil
import logging
from typing import List, Optional

# Absolute directory containing this file; anchor for the search paths below.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Extend the module search path with this directory, its parent, and two
# directories located two levels up (dmacompiler and OGOAT/src/L1_fusion).
# NOTE(review): presumably required so the project-local imports that follow
# resolve -- confirm against the repository layout.
sys.path.append(CURRDIR)
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'OGOAT', 'src', 'L1_fusion'))

from resize_tiler import run_tiler
from dmacompiler import (
    BackEnd,
    set_dev_gen, DevGen, config
)
from dataflow_common import clean_overlay, build_sim_overlay, sizeof
from resize_common import ResizeDims, resize_preproc_directives
import resize_dataflow
from OGOAT.src.L1_fusion.kernel_func_list import kernel_func_list

# Import-time side effects: merely importing this module selects the Aie2p
# device generation and switches on dmacompiler's ENABLE_BUSY_POLL option.
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True

def resize_kernel_selection_logic(inputs: List[dict], outputs: List[dict], ifm_bytes: int, attributes: dict):
    """Choose the kernels and kernel include files for a resize layer.

    Only ``attributes['mode'][0]`` is consulted; ``inputs``, ``outputs`` and
    ``ifm_bytes`` are accepted but not read by this function.

    Args:
        inputs: Layer input descriptors (unused here).
        outputs: Layer output descriptors (unused here).
        ifm_bytes: Input element size in bytes (unused here).
        attributes: Layer attributes; must provide ``'mode'`` as an indexable
            value whose first element names the resize mode.

    Returns:
        A ``(kernel_names, kernel_includes)`` tuple. ``kernel_names`` maps each
        selected kernel name to its position in ``kernel_func_list``;
        ``kernel_includes`` always starts with ``"super.hh"``.
    """
    selected = {}
    includes = ["super.hh"]

    # Any mode other than "linear" selects no kernel at all.
    if attributes['mode'][0] != "linear":
        return selected, includes

    bilinear = "run_bilinear_resize_bf16"
    selected[bilinear] = kernel_func_list.index(bilinear)
    includes.append("bilinear_pixel_resize_bf16/bilinear_pixel_resize_bf16_wrapper.cc")
    return selected, includes


def extract_fields(file_name):
    """Load and return the parsed contents of a JSON file.

    Args:
        file_name: Path of the JSON file to read.

    Returns:
        The deserialized JSON document, exactly as ``json.load`` yields it.

    Raises:
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    # JSON interchange text is UTF-8 (RFC 8259); decode explicitly instead of
    # relying on the platform's locale-dependent default encoding.
    with open(file_name, 'r', encoding='utf-8') as f:
        return json.load(f)

def build_resize(
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
    aie_cols: int,
    aie_rows: int,
    h_in: int, w_in: int, c_in: int,
    num_interpolations: int,
    ifm_bits: int,
    int_16: Optional[int] = 1, bfloat_16: Optional[int] = 0,
    frontend_only: bool = False,
    out_folder: Optional[str] = None,
):
    """Tile a resize problem and hand the result to the scheduler.

    Args:
        back_end: Target back end, forwarded to ``run_scheduler``.
        kernel_names: Kernel names forwarded to ``run_scheduler``.
        kernel_includes: Kernel include files forwarded to ``run_scheduler``.
        aie_cols: Column count passed to ``run_tiler``.
        aie_rows: Row count passed to ``run_tiler``.
        h_in: Input height passed to ``run_tiler``.
        w_in: Input width passed to ``run_tiler``.
        c_in: Input channel count passed to ``run_tiler``.
        num_interpolations: Interpolation factor passed to ``run_tiler``.
        ifm_bits: Input element width passed to ``run_tiler``.
        int_16: Flag passed to ``run_tiler``.
        bfloat_16: Flag passed to ``run_tiler``.
        frontend_only: Forwarded to ``run_scheduler``.
        out_folder: Only validated here (must be ``None`` for the Adf back
            end); it is not otherwise used by this function.

    Raises:
        AssertionError: If ``out_folder`` is given together with
            ``BackEnd.Adf``.
    """
    # An output folder is only accepted for non-Adf back ends.
    assert (out_folder is None) or (back_end != BackEnd.Adf)

    # Positional order expected by run_tiler: rows come before columns, and
    # the literal 1 sits between the grid shape and the tensor shape.
    # NOTE(review): meaning of the literal 1 is not visible here -- confirm
    # against run_tiler's signature.
    tiler_args = (
        aie_rows, aie_cols,
        1,
        h_in, w_in, c_in,
        num_interpolations,
        ifm_bits,
        int_16, bfloat_16,
    )
    tiling = run_tiler(*tiler_args)

    run_scheduler(tiling, back_end, kernel_names, kernel_includes, frontend_only)

   

def run_scheduler(dims: ResizeDims, back_end: BackEnd, kernel_names: List[str], kernel_includes: List[str], frontend_only: bool):
    """Compile the resize dataflow and, unless front-end only, build the overlay.

    Args:
        dims: Tiling result describing the resize problem.
        back_end: Target back end.
        kernel_names: Kernel names handed to the dataflow compiler.
        kernel_includes: Kernel include files handed to the dataflow compiler.
        frontend_only: When true, stop after dataflow compilation and skip
            the ``build_sim_overlay`` step.
    """
    clean_overlay()
    resize_dataflow.compile_dataflow(dims, back_end, kernel_names, kernel_includes)

    if frontend_only:
        return

    # The host source is looked up in the current working directory.
    host_source = os.path.join(os.getcwd(), 'resize_main.cpp')
    directives = resize_preproc_directives(dims, back_end)
    build_sim_overlay(back_end, host_source, directives)


def run_resize_op(json_file, path, txn_mode, kernel_d, frontend_only):
    """Build the resize operator described by a layer JSON file.

    Changes the working directory to ``path``, reads the layer description
    from ``json_file``, derives the build parameters from it and calls
    ``build_resize``.

    Args:
        json_file: Path of the JSON layer description. It is opened after the
            ``os.chdir(path)`` call, so a relative path resolves against
            ``path``.
        path: Directory to switch into before building.
        txn_mode: ``0`` selects ``BackEnd.Adf``; any other value selects
            ``BackEnd.TxnHostPatch``.
        kernel_d: Mapping providing ``'kernel_list'`` and ``'kernel_include'``;
            when falsy, no kernels and only ``'super.hh'`` are used.
        frontend_only: Forwarded to ``build_resize``.

    Raises:
        AssertionError: If the n or c scale is not 1, or if the h and w
            scales differ.
        IndexError: If no input entry has a non-empty ``'type'``.
    """
    os.chdir(path)   # the build system only works on the current dir (subprocess)
    _data = extract_fields(json_file)
    data = {}
    data['back_end'] = BackEnd.Adf if txn_mode == 0 else BackEnd.TxnHostPatch
    if not kernel_d:
        data['kernel_names'] = []
        data['kernel_includes'] = ['super.hh']
    else:
        data['kernel_names'] = kernel_d['kernel_list']
        data['kernel_includes'] = kernel_d['kernel_include']
    data['aie_rows'] = _data['overlay_info']['shape']['row']
    data['aie_cols'] = _data['overlay_info']['shape']['col']

    layer_info = _data['layer_info']
    data['ifm_bits'] = sizeof(layer_info['in_datatype'])
    # 'inputs' is stored as the string form of a Python literal.
    data['inputs'] = ast.literal_eval(layer_info['inputs'])

    # The first input entry with a non-empty 'type' is used as the activation.
    typed_inputs = [entry for entry in data['inputs'] if len(entry.get('type')) != 0]
    if not typed_inputs:
        raise IndexError("No input with a non-empty 'type' found in layer_info['inputs']")
    data['act_inputs'] = typed_inputs[0]

    # Shape indices 1..3 are read as height, width and channels.
    # NOTE(review): assumes an NHWC shape -- confirm with the JSON producer.
    act_shape = data['act_inputs']['shape']
    data['h_in'] = act_shape[1]
    data['w_in'] = act_shape[2]
    data['c_in'] = act_shape[3]

    # NOTE(review): int() truncates, so a fractional scale such as 2.5 is
    # silently treated as 2 -- confirm scales are always integral.
    attrs = layer_info['attributes']
    data['interpolation_n'] = int(attrs['scales_1'][0])
    data['interpolation_h'] = int(attrs['scales_2'][0])
    data['interpolation_w'] = int(attrs['scales_3'][0])
    data['interpolation_c'] = int(attrs['scales_4'][0])

    # Raise explicitly instead of `assert False`: assert statements are removed
    # under `python -O`, which would let unsupported scales through unchecked.
    if data['interpolation_n'] != 1:
        raise AssertionError("Interpolation cannot be performed in the n dimension")
    if data['interpolation_c'] != 1:
        raise AssertionError("Interpolation cannot be performed in the c dimension")
    if data['interpolation_h'] != data['interpolation_w']:
        raise AssertionError("Interpolation scale has to equal in the h and w dimension")
    data['num_interpolations'] = data['interpolation_h']

    data['int_16'] = 1 if layer_info['in_datatype'] == 'uint16' else 0
    data['bfloat_16'] = 1 if layer_info['in_datatype'] == 'bfloat16' else 0

    # build_resize rejects an output folder for the Adf back end.
    output_dir = os.path.dirname(os.path.realpath(json_file)) if data['back_end'] != BackEnd.Adf else None
    logging.info(f" NNI input args: {data}")
    build_resize(data['back_end'],
                 data['kernel_names'], data['kernel_includes'],
                 data['aie_cols'], data['aie_rows'],
                 data['h_in'], data['w_in'], data['c_in'],
                 data['num_interpolations'],
                 data['ifm_bits'],
                 data['int_16'], data['bfloat_16'],
                 frontend_only,
                 output_dir)

def main():
    """Build a fixed sample resize configuration on the Adf back end.

    The sample uses an 8x4 (cols x rows) array, a 16x16x1280 input, an
    interpolation factor of 2 and 16-bit integer data, with no kernels
    selected.
    """
    build_resize(
        back_end=BackEnd.Adf,
        kernel_names=[],
        kernel_includes=['super.hh'],
        aie_cols=8,
        aie_rows=4,
        h_in=16,
        w_in=16,
        c_in=1280,
        num_interpolations=2,
        ifm_bits=16,
        int_16=1,
        bfloat_16=0,
        frontend_only=False,
    )

# Run the hard-coded sample build only when executed as a script, not on import.
if __name__ == '__main__':
    main()
