import os
import sys
import ast
import json
import shutil
import logging
import argparse
from typing import Tuple, List, Optional

CURRDIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(CURRDIR)
from OGOAT.src.L1_fusion.kernel_func_list import kernel_func_list

from dmacompiler import (
    BackEnd,
    set_dev_gen, DevGen, config
)
from dataflow_common import clean_overlay, build_sim_overlay, sizeof, tiling_json_gen
from concat_common import concat_preproc_directives
from concat_run_tiler import run_tiler
import concat_8x4_dataflow, concat_8x4_any_input_nums_dataflow
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True


def build_concat(
    back_end: BackEnd,
    kernel_names: dict,
    kernel_includes: List[str],
    aie_cols: int,
    aie_rows: int,
    num_inputs: int,
    concat_mode: int, # 0: ch-wise; 1: col-wise
    input_rows: list,
    input_cols: list,
    input_chs: list,
    # ifm_bits: int,
    is_int16: bool,
    qdq_mode: int,
    is_signed: bool,
    frontend_only: bool = False,
    out_folder: Optional[str] = None,
    get_kernel_mode: bool = False,
    save_bins: Optional[bool] = False,
    input_types: List = ['act'],
):

    is_qdq = False if qdq_mode == 3 else True
    dims = run_tiler(
        aie_cols, aie_rows,
        num_inputs, concat_mode,
        input_rows, input_cols, input_chs,
        # ifm_bits, ofm_bits,
        is_int16,
        is_qdq, qdq_mode, is_signed,
        input_types,
    )

    if get_kernel_mode:
        return dims.kernel_names, dims.kernel_includes

    tiling = {}
    tiling["op_type"] = "concat"
    tiling["orig_input"] = []
    for i in range(len(input_rows)):
        tiling["orig_input"].append([1, input_rows[i], input_cols[i], input_chs[i]])

    if concat_mode == 0:
        tiling["orig_output"] = [1, input_rows[0], input_cols[0], sum(input_chs)]
    elif concat_mode == 1:
        tiling["orig_output"] = [1, input_rows[0], sum(input_cols), input_chs[0]]
    else:
        assert False, "We only support channel and column dimension concat"
    tiling_json_gen(tiling, os.path.join(os.getcwd(), 'tiling.json'))

    run_scheduler(dims, back_end, kernel_names, kernel_includes, frontend_only)

    if save_bins and back_end == BackEnd.TxnHostPatch and out_folder is not None:
        if not os.path.exists(out_folder):
            os.makedirs(out_folder)
        in_folder = CURRDIR
        files = ('ifm.bin', 'wgt.bin', 'ofm.bin', 'dma.hpp',
                'tiling.json', 'txn.bin', 'param.bin', 'ctrl.bin', 'patch.json')
        for file in files:
            src = os.path.join(in_folder, file)
            dst = os.path.join(out_folder, file)
            shutil.move(src, dst)

def run_scheduler(dims, back_end, kernel_names, kernel_includes, frontend_only):
    clean_overlay()
    if dims.num_inputs_exception:
        concat_8x4_any_input_nums_dataflow.compile_dataflow(dims, back_end, kernel_names, kernel_includes)
    else:
        concat_8x4_dataflow.compile_dataflow(dims, back_end, kernel_names, kernel_includes)
    if not frontend_only:
        build_sim_overlay(back_end, 'concat_main.cpp', concat_preproc_directives(dims, back_end))

def extract_fields(file_name):
	with open(file_name, 'r') as f:
		data = json.load(f)
	return data


def update_len_to_4(data):
    while len(data) < 4:
        data.insert(0, 1)
    return data

def concat_axis(updated_input: int, updated_output: int):
    for axis in range(4):
        if updated_output[axis] > updated_input[axis]:
            return axis
    assert False, "This is a no-op"

def get_nested_value(d, keys, default=None):
    for key in keys:
        if isinstance(d, dict):
            d = d.get(key)
        else:
            return default
    return d if d is not None else default

def is_valid_param(inp):
    name = inp.get("param_name", "")
    return not (name.endswith("_scale") or name.endswith("_zero_point"))

def run_concat_op(json_file, path, txn_mode, kernel_d, frontend_only):
    os.chdir(path)   #because build system only work on current dir (subprocess)
    _data = extract_fields(json_file)
    data = {}
    data['op_type'] = _data['layer_info']['op_type']
    data['back_end'] = BackEnd.Adf if txn_mode == 0 else BackEnd.TxnHostPatch
    data['aie_rows'] = _data['overlay_info']['shape']['row']
    data['aie_cols'] = _data['overlay_info']['shape']['col']
    data['num_inputs'] = _data['layer_info']['attributes']['num_inputs'][0]
    data['input'] = _data['layer_info']['in_act_shape']
    data['output'] = _data['layer_info']['out_act_shape']
    input_shape = update_len_to_4(data['input'])
    output_shape = update_len_to_4(data['output'])
    data['axis'] = concat_axis(input_shape, output_shape) # 3: channel-wise; 2: col/w -wise
    if data['axis'] == 3:
        data['concat_mode'] = 0
    elif data['axis'] == 2:
        data['concat_mode'] = 1
    else:
        assert False, "Axis not supported"
    data['ifm_bits'] = sizeof(_data['layer_info']['in_datatype'])
    data['inputs'] = ast.literal_eval(_data['layer_info']['inputs'])

    act_inputs = [
        inp for inp in data["inputs"]
        if inp.get("type") == "act"
        and len(inp.get("shape", [])) != 0
        and is_valid_param(inp)
    ]

    valid_const_inputs = [
        inp for inp in data["inputs"]
        if inp.get("type") == "const"
        and len(inp.get("shape", [])) != 0
        and is_valid_param(inp)
    ]

    # 2) If valid const inputs > 1, raise
    if len(valid_const_inputs) > 1:
        raise RuntimeError(f"Expected at most 1 const input, found {len(valid_const_inputs)} const inputs.")

    data["act_inputs"] = act_inputs + valid_const_inputs
    data["is_const_input"] = bool(valid_const_inputs)

    if (len(data['act_inputs']) + len(data['const_inputs'])) != data['num_inputs']:
        assert False, f"Concat input shape list {data['inputs']} length doesn't match number of inputs: {data['num_inputs']} in the layer info attributes"
    if not all(len(d['shape']) == len(data['act_inputs'][0]['shape']) for d in data['act_inputs']):
        assert False, f"Concat input shape list {data['act_inputs']} each individual shape length should match"
    data['rows'] = []
    data['cols'] = []
    data['chs'] = []
    for i in range(data['num_inputs']):
        curr_shape = update_len_to_4(data['act_inputs'][i]['shape'])
        data['rows'].append(curr_shape[1])
        data['cols'].append(curr_shape[2])
        data['chs'].append(curr_shape[3])
    if data['num_inputs'] > 2 and data['axis'] == 3:
        print("CONCAT KERNEL DOES NOT SUPPORT CONCAT OF MORE THAN 2 INPUTS")
        sys.exit(1)
    # disable_qdq / is_not_qdq
    disable_qdq_list = get_nested_value(_data, ['layer_info', 'attributes', 'disable_qdq'], [])
    is_not_qdq = disable_qdq_list[0] if isinstance(disable_qdq_list, list) and disable_qdq_list else 1
    is_qdq = 1 - is_not_qdq
    qdq_mode = 2 if is_qdq else 3 # NOTE: we need qdq mode from json
    is_int16 = True if data['op_type'] == "Concat_qdq_uint16" else False

    if not kernel_d:
        if is_int16:
            data['kernel_names'] = {
                                        "run_combined_qdq": kernel_func_list.index("run_combined_qdq"),
                                        "run_concat": kernel_func_list.index("run_concat"),
                                    }
        else:
            data['kernel_names'] = {
                            "run_combined_qdq_a8": kernel_func_list.index("run_combined_qdq_a8"),
                            "run_concat_a8": kernel_func_list.index("run_concat_a8"),
                        }
        data['kernel_includes'] = ['super.hh', 'qdq/wrapper_qdq.cc', 'concat/wrapper_concat.cc']
    else:
        data['kernel_names'] = kernel_d['kernel_list']
        data['kernel_includes'] = kernel_d['kernel_include']

    output_dir = os.path.dirname(os.path.realpath(json_file)) if data['back_end'] != BackEnd.Adf else None
    logging.info(f" Concat input args: {data}")
    build_concat(data['back_end'],
                data['kernel_names'], data['kernel_includes'],
                data['aie_cols'], data['aie_rows'],
                data['num_inputs'], data['concat_mode'],
                data['rows'], data['cols'], data['chs'],
                # data['ifm_bits'],
                is_int16,
                qdq_mode,
                frontend_only,
                output_dir,
                is_const_input=data["is_const_input"])


def concat_kernel_selection_logic(inputs: List[dict], outputs: List[dict], ifm_bytes: int, attributes: dict):
    aie_cols = 8
    aie_rows = 4
    input_shapes = [input['shape'] for input in inputs if len(input.get('shape')) != 0]
    output_shapes = [output['shape'] for output in outputs if len(output.get('shape')) != 0]
    num_inputs = len(input_shapes)
    input_0 = update_len_to_4(input_shapes[0])
    output = update_len_to_4(output_shapes[0])

    axis = concat_axis(input_0, output)

    if axis == 3:
        concat_mode = 0
    elif axis == 2:
        concat_mode = 1
    else:
        assert False, "Axis not supported"
    input_rows = []
    input_cols = []
    input_chs = []
    for i in range(num_inputs):
        curr_shape = update_len_to_4(input_shapes[i])
        input_rows.append(curr_shape[1])
        input_cols.append(curr_shape[2])
        input_chs.append(curr_shape[3])

    back_end = BackEnd.Adf
    kernel_names = {}
    kernel_includes = []
    frontend_only = False
    out_folder = None
    get_kernel_mode = True
    ifm_bits = ifm_bytes << 3
    disable_qdq = attributes['disable_qdq'][0]
    is_qdq = False if disable_qdq else True
    is_int16 = True if ifm_bits == 16 else False
    qdq_mode = 2 if is_qdq else 3
    return build_concat(
                        back_end, \
                        kernel_names, \
                        kernel_includes, \
                        aie_cols, \
                        aie_rows, \
                        num_inputs, \
                        concat_mode, \
                        input_rows, \
                        input_cols, \
                        input_chs, \
                        # ifm_bits, \
                        is_int16, \
                        qdq_mode, \
                        frontend_only, \
                        out_folder, \
                        get_kernel_mode
        )

def kernel_selection_logic_basic_test():
    inputs = [
        {
            "inputs": [{'param_name': 'input_0', 'type': 'act', 'shape': [1, 115, 199, 2], 'dtype': 'uint16', 'dtype_bytes': 2}, {'param_name': 'input_0_scale', 'type': 'const', 'shape': [], 'dtype': 'float32', 'dtype_bytes': 4}, {'param_name': 'input_0_zeropoint', 'type': 'const', 'shape': [], 'dtype': 'uint16', 'dtype_bytes': 2}, {'param_name': 'input_1', 'type': 'act', 'shape': [1, 115, 199, 4], 'dtype': 'uint16', 'dtype_bytes': 2}, {'param_name': 'input_1_scale', 'type': 'const', 'shape': [], 'dtype': 'float32', 'dtype_bytes': 4}, {'param_name': 'input_1_zeropoint', 'type': 'const', 'shape': [], 'dtype': 'uint16', 'dtype_bytes': 2}, {'param_name': 'concat_result_y_scale', 'type': 'const', 'shape': [], 'dtype': 'float32', 'dtype_bytes': 4}, {'param_name': 'concat_result_y_zero_point', 'type': 'const', 'shape': [], 'dtype': 'uint16', 'dtype_bytes': 2}],
            "outputs" : [{'param_name': 'concat_result', 'type': 'act', 'shape': [1, 115, 199, 6], 'dtype': 'uint16', 'dtype_bytes': 2}],
            "ifm_bytes" : 2,
            "attributes": {
                "axis": [
                    -1
                ],
                "num_inputs": [
                    2
                ],
                "disable_qdq": [
                    0
                ],
                "pm_id": [
                    0
                ]
            }
        }
    ]

    for input in inputs:
        print(concat_kernel_selection_logic(input['inputs'], input['outputs'], input['ifm_bytes'], input['attributes']))


def parse_args():
    parser = argparse.ArgumentParser(description="Build Concat for various shapes and permutations.")
    parser.add_argument(
        "--backend", type=int, default=0,
        help="Backend type (default: 0 for Adf)"
    )
    parser.add_argument(
        "--qdq_mode", type=int, default=3,
        help="QDQ mode (default: 3)"
    )
    parser.add_argument(
        "--ifm_bits", type=int, default=16,
        help="ifm bits (default: 16)"
    )
    parser.add_argument(
        "--shape_index", type=int,
        help="Index of the shape from the input set to run (if not provided, runs all)"
    )
    parser.add_argument(
        "--save_bins", type=bool, default=False,
        help="Save generated bin files and dma.hpp"
    )
    return parser.parse_args()


def main():
    args = parse_args()

    back_end = BackEnd(args.backend)
    qdq_mode = args.qdq_mode
    save_bins = args.save_bins
    kernel_names = {
                        "run_combined_qdq": kernel_func_list.index("run_combined_qdq"),
                        "run_combined_qdq_a8": kernel_func_list.index("run_combined_qdq_a8"),
                        "run_concat": kernel_func_list.index("run_concat"),
                        "run_concat_a8": kernel_func_list.index("run_concat_a8")
                    }
    kernel_includes = ['super.hh', 'qdq/wrapper_qdq.cc', 'concat/wrapper_concat.cc']
    aie_cols = 8
    aie_rows = 4
    ifm_bits = args.ifm_bits
    is_qdq = False if qdq_mode == 3 else True
    frontend_only = False
    # qdq_mode = 2 if is_qdq else 3
    # qdq_mode = sys.argv[2] if len(sys.argv) >= 3 else 2
    is_int16 = True if ifm_bits == 16 else False

    has_constant_int = False
    is_signed = False

    PSP1_SHAPE = {
        0: [[115, 115], [199, 199], [2, 4],        2,   0],
        1: [[115, 115], [199, 199], [3, 1],        2,   0],
        2: [[115, 115], [199, 199], [48, 54],      2,   0],
        3: [[115, 115], [199, 199], [6, 48],       2,   0],
        4: [[33, 33],   [57, 57],   [48, 48],      2,   0],
        5: [[66, 66],   [114, 114], [48, 48],      2,   0],
        6: [[66, 66],   [114, 114], [48, 96],      2,   0],
        7: [[1, 1],     [2, 2],     [3072, 3072],  2,   0],
        8: [[1, 1],     [64, 64],     [63, 1],     2,   0],
        9: [[1, 1],     [64, 64],     [64, 64],    2,   1],
        10: [[1, 1],     [6, 6],     [128, 128],    2,   0],
        11: [[1, 1],     [6, 6],     [100, 108],    2,   0],
        12: [[1, 1],     [32, 32],     [1, 3],    2,   0],
    }

    PSD5_SHAPE = {
        0: [[16, 16], [16, 16], [1280, 1280], 2, 0],
        1: [[16, 16], [16, 16], [1280,  640], 2, 0],
        2: [[32, 32], [32, 32], [1280,  640], 2, 0],
        3: [[32, 32], [32, 32], [640 ,  320], 2, 0],
        4: [[32, 32], [32, 32], [640 ,  640], 2, 0],
        5: [[64, 64], [64, 64], [320 ,  320], 2, 0],
        7: [[64, 64], [64, 64], [640 ,  320], 2, 0]
    }

    target_shape = PSP1_SHAPE
    shapes = [target_shape[args.shape_index]] if args.shape_index is not None else target_shape.values()

    for input_rows, input_cols, input_chs, num_inputs, concat_mode in shapes:
        # qdq_mode = 2 if is_qdq else 3
        input_types = ["act" ]* (num_inputs-1) + ['const'] if has_constant_int else ["act"] * num_inputs
        build_concat(
            back_end, kernel_names, kernel_includes,
            aie_cols, aie_rows,
            num_inputs,
            concat_mode,
            input_rows, input_cols, input_chs,
            # ifm_bits,
            is_int16,
            int(qdq_mode),
            is_signed,
            frontend_only,
            out_folder=f"input_{'_'.join(map(str, input_rows + input_cols + input_chs + [concat_mode] + [is_qdq] + [qdq_mode]))}",
            save_bins=save_bins,
            input_types=input_types,
        )

if __name__ == '__main__':
    main()
