import os
import sys
import json
import shutil
import argparse
from typing import List, Optional

CURRDIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(CURRDIR)
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'kernels'))

from OGOAT.src.L1_fusion.kernel_func_list import kernel_func_list
from dmacompiler import (
    BackEnd,
    set_dev_gen, DevGen, config
)
from dataflow_common import clean_overlay, build_sim_overlay, elem_size, sizeof
from pad_common import PadDims, pad_preproc_directives
from pad_tiler import run_tiler
import pad_dataflow

# Module-level configuration applied at import time, before any builder below runs.
# Select DevGen.Aie2p as the device generation used by dmacompiler.
set_dev_gen(DevGen.Aie2p)
# Switch on dmacompiler's ENABLE_BUSY_POLL option.
# NOTE(review): the exact effect of this flag is defined inside dmacompiler — confirm there.
config.ENABLE_BUSY_POLL = True

def build_pad_qdq(
    back_end: BackEnd,
    kernel_names: dict,
    kernel_includes: List[str],
    aie_cols: int,
    aie_rows: int,
    input_shape: list,
    output_shape: list,
    ifm_bits: int,
    ofm_bits: int,
    fix_point_bits: int,
    is_signed: bool,
    qdq_mode: int,
    frontend_only: bool = False,
    out_folder: Optional[str] = None,
    get_kernel_mode: Optional[bool] = False,
    save_bins: Optional[bool] = False,
):
    """Tile one pad(+QDQ) layer, compile its dataflow and optionally collect the artifacts.

    Args:
        back_end: Backend the dataflow is compiled for.
        kernel_names: Kernel-name mapping forwarded to the scheduler.
        kernel_includes: Kernel include files forwarded to the scheduler.
        aie_cols: Overlay column count handed to the tiler.
        aie_rows: Overlay row count handed to the tiler.
        input_shape: Input activation shape handed to the tiler.
        output_shape: Output activation shape handed to the tiler.
        ifm_bits: Input element width in bits.
        ofm_bits: Output element width in bits.
        fix_point_bits: Fixed-point width in bits handed to the tiler.
        is_signed: Signedness flag handed to the tiler.
        qdq_mode: QDQ mode handed to the tiler.
        frontend_only: Forwarded to ``run_scheduler``; when true the simulation
            overlay build is skipped there.
        out_folder: Destination directory for the generated files. Only used
            when ``save_bins`` is true and the backend is ``BackEnd.TxnHostPatch``.
        get_kernel_mode: When true, only the tiler runs and its kernel
            information is returned; nothing is compiled.
        save_bins: When true, move the generated files into ``out_folder``.

    Returns:
        ``(dims.kernel_names, dims.kernel_includes)`` when ``get_kernel_mode``
        is true, otherwise ``None``.

    Raises:
        OSError: If ``out_folder`` cannot be created or one of the expected
            generated files cannot be moved (e.g. it does not exist).
    """
    dims = run_tiler(
        aie_cols, aie_rows,
        input_shape, output_shape,
        ifm_bits, ofm_bits,
        qdq_mode,
        fix_point_bits, is_signed
    )

    if get_kernel_mode:
        # Kernel query only: report what the tiler selected and stop.
        return dims.kernel_names, dims.kernel_includes

    run_scheduler(dims, back_end, kernel_names, kernel_includes, frontend_only)

    if save_bins and back_end == BackEnd.TxnHostPatch and out_folder is not None:
        # exist_ok avoids the check-then-create race of a separate exists() test.
        os.makedirs(out_folder, exist_ok=True)
        # The generated files are picked up from this script's directory (CURRDIR),
        # not from the current working directory.
        generated = ('ifm.bin', 'wgt.bin', 'ofm.bin', 'dma.hpp',
                     'tiling.json', 'txn.bin', 'param.bin', 'ctrl.bin', 'patch.json')
        for file_name in generated:
            shutil.move(os.path.join(CURRDIR, file_name),
                        os.path.join(out_folder, file_name))
    return None


def run_scheduler(dims: PadDims,
                  back_end: BackEnd,
                  kernel_names: dict,
                  kernel_includes: List[str],
                  frontend_only: bool = False
                  ):
    """Compile the pad dataflow and, unless ``frontend_only``, build the sim overlay.

    Args:
        dims: Tiling result passed to the dataflow compiler and to
            ``pad_preproc_directives``.
        back_end: Backend to compile and build for.
        kernel_names: Kernel-name mapping passed to ``compile_dataflow``.
        kernel_includes: Kernel include files passed to ``compile_dataflow``.
        frontend_only: When true, return right after dataflow compilation.
    """
    # clean_overlay() is always called first, ahead of compilation.
    clean_overlay()
    pad_dataflow.compile_dataflow(dims, back_end, kernel_names, kernel_includes)

    if frontend_only:
        return

    directives = pad_preproc_directives(dims, back_end)
    build_sim_overlay(back_end, 'pad_main.cpp', directives)


def extract_fields(file_name):
    """Parse the UTF-8 JSON file at ``file_name`` and return the decoded object."""
    with open(file_name, mode='r', encoding="utf-8") as handle:
        return json.load(handle)


def update_len_to_4(data):
    """Left-pad ``data`` in place with 1s until it has at least four entries.

    The same object is returned; inputs that already hold four or more
    entries are left untouched.
    """
    # range() of a non-positive count is empty, so long inputs skip the loop.
    for _ in range(4 - len(data)):
        data.insert(0, 1)
    return data


def pad_axis(updated_input: List[int], updated_output: List[int]):
    """Return the first of the four axes whose output extent is below the input extent.

    Args:
        updated_input: Input shape; indices 0..3 are read.
        updated_output: Output shape; indices 0..3 are read.

    Returns:
        The lowest axis index ``a`` in ``0..3`` with
        ``updated_output[a] < updated_input[a]``.

    Raises:
        AssertionError: If no such axis exists.
    """
    # NOTE(review): this detects an axis that *shrinks*. For a pad operator the
    # output is normally larger than the input, so this comparison may have been
    # carried over from a slice script — confirm against the callers before
    # relying on it.
    for axis in range(4):
        if updated_output[axis] < updated_input[axis]:
            return axis
    # Raised explicitly: an `assert False` would be stripped under `python -O`
    # and the function would silently return None.
    raise AssertionError("This is a no-op")


def get_nested_value(d, keys, default=None):
    """Follow ``keys`` through nested dicts starting at ``d``.

    Returns ``default`` as soon as a non-dict is met while keys remain, or
    when the value finally reached is ``None``; otherwise returns that value.
    """
    node = d
    for key in keys:
        if not isinstance(node, dict):
            return default
        node = node.get(key)
    if node is None:
        return default
    return node


def run_pad_qdq_op(json_file, path, txn_mode, kernel_d, frontend_only):
    """Build the pad(+QDQ) operator described by a layer JSON file.

    The process working directory is changed to ``path`` first, so a relative
    ``json_file`` is resolved against ``path``.

    Args:
        json_file: JSON file with ``layer_info`` (``in_act_shape``,
            ``out_act_shape``, ``in_datatype``, ``out_datatype``, ``op_type``
            and optional ``attributes``) and ``overlay_info.shape``
            (``row`` / ``col``).
        path: Directory to ``os.chdir`` into before anything else happens.
        txn_mode: ``0`` selects ``BackEnd.Adf``; any other value selects
            ``BackEnd.TxnHostPatch``.
        kernel_d: Optional mapping with ``kernel_list`` and ``kernel_include``
            entries. When falsy, a combined-QDQ kernel is looked up in
            ``kernel_func_list`` instead.
        frontend_only: Forwarded to ``build_pad_qdq``.

    Raises:
        AssertionError: If the input and output activation shapes are equal.
        KeyError: If a required JSON field is missing.
    """
    os.chdir(path)
    _data = extract_fields(json_file)
    layer_info = _data['layer_info']
    in_shape = layer_info['in_act_shape']
    out_shape = layer_info['out_act_shape']
    if in_shape == out_shape:
        # Raised explicitly so the guard survives `python -O`, which strips asserts.
        raise AssertionError("Input == Output. This is a no-op.")
    back_end = BackEnd.Adf if txn_mode == 0 else BackEnd.TxnHostPatch
    overlay_shape = _data['overlay_info']['shape']
    aie_rows = overlay_shape['row']
    aie_cols = overlay_shape['col']
    # The element sizes are not used below; the lookups are kept so a missing
    # field or a datatype that sizeof() rejects still fails here.
    # NOTE(review): confirm sizeof() is only a lookup, then these can go.
    sizeof(layer_info['in_datatype'])
    sizeof(layer_info['out_datatype'])
    op_type = layer_info['op_type']

    fix_point_bits = 8 if op_type == "pad_qdq_uint8xuint8" else 16
    is_signed = False

    # Each flag is the first entry of an attribute list; a missing, empty or
    # non-list attribute counts as 1 (stage disabled).
    disable_q_list = get_nested_value(_data, ['layer_info', 'attributes', 'disable_q'], [])
    disable_q = disable_q_list[0] if isinstance(disable_q_list, list) and disable_q_list else 1

    disable_dq_list = get_nested_value(_data, ['layer_info', 'attributes', 'disable_dq0'], [])
    disable_dq = disable_dq_list[0] if isinstance(disable_dq_list, list) and disable_dq_list else 1

    if disable_q and disable_dq:
        qdq_mode = 3  # no QDQ
    elif disable_q:
        qdq_mode = 0  # no Q, only DQ
    elif disable_dq:
        qdq_mode = 1  # no DQ, only Q
    else:
        qdq_mode = 2  # both

    # Dequant-only emits 16-bit output, quant-only consumes 16-bit input;
    # every other mode keeps the fixed-point width on both sides.
    if qdq_mode == 0:
        ifm_bits, ofm_bits = fix_point_bits, 16
    elif qdq_mode == 1:
        ifm_bits, ofm_bits = 16, fix_point_bits
    else:
        ifm_bits, ofm_bits = fix_point_bits, fix_point_bits

    if not kernel_d:
        kernel = 'run_combined_qdq_a8' if fix_point_bits == 8 else 'run_combined_qdq'
        kernel_names = {kernel: kernel_func_list.index(kernel)}
        kernel_includes = ['super.hh', 'qdq/wrapper_qdq.cc']
    else:
        kernel_names = kernel_d['kernel_list']
        kernel_includes = kernel_d['kernel_include']

    # NOTE(review): save_bins is not passed, so build_pad_qdq uses its default
    # and output_dir has no effect unless that default changes — confirm intent.
    output_dir = os.path.dirname(os.path.realpath(json_file)) if back_end != BackEnd.Adf else None
    build_pad_qdq(back_end,
                  kernel_names, kernel_includes,
                  aie_cols, aie_rows,
                  in_shape, out_shape,
                  ifm_bits,
                  ofm_bits,
                  fix_point_bits,
                  is_signed,
                  qdq_mode,
                  frontend_only=frontend_only,
                  out_folder=output_dir)

def _str2bool(value):
    """Convert a command-line token such as ``"true"`` or ``"0"`` to a bool."""
    if isinstance(value, bool):
        return value
    token = value.strip().lower()
    if token in ('1', 'true', 't', 'yes', 'y', 'on'):
        return True
    if token in ('', '0', 'false', 'f', 'no', 'n', 'off'):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options of the pad build script.

    Returns:
        An ``argparse.Namespace`` with ``backend``, ``qdq_mode``,
        ``shape_index``, ``save_bins`` and ``dtype``.
    """
    parser = argparse.ArgumentParser(description="Build Pad for various shapes and permutations.")
    parser.add_argument(
        "--backend", type=int, default=0,
        help="Backend type (default: 0 for Adf)"
    )
    parser.add_argument(
        "--qdq_mode", type=int, default=3,
        help="QDQ mode (default: 3)"
    )
    parser.add_argument(
        "--shape_index", type=int,
        help="Index of the shape from the input set to run (if not provided, runs all)"
    )
    # `type=bool` would turn every non-empty string, including "False", into
    # True. Parse the text instead; a bare `--save_bins` also enables the flag.
    parser.add_argument(
        "--save_bins", type=_str2bool, nargs='?', const=True, default=False,
        help="Save generated bin files and dma.hpp"
    )
    # The default is kept as False (equal to 0) for backward compatibility.
    parser.add_argument(
        "--dtype", type=int, default=False,
        help="dtype of pad + qdq"
    )
    return parser.parse_args()

'''
Tasks to do in future PRs:
- Remove allocation of scratch buffer and increase allocation in L1 in the following cases:
    - Dequant INT8 -> BF16: double allocation for input data
    - Quant and Dequant INT8 -> BF16 -> INT8: double allocation for input data
'''

def main():
    """Build the pad(+QDQ) operator for the built-in shape table.

    Command-line options choose the backend, the QDQ mode, an optional single
    table entry (``--shape_index``) and whether generated files are saved.
    """
    args = parse_args()

    back_end = BackEnd(args.backend)
    qdq_mode = args.qdq_mode
    save_bins = args.save_bins

    kernel_includes = ['super.hh', 'qdq/wrapper_qdq.cc']
    aie_cols, aie_rows = 8, 4
    frontend_only = False
    is_signed = False

    # index -> [input shape, output shape, fixed-point bit width]
    sam2_en_shapes = {
        0: [[1, 32, 32, 768],    [1, 35, 35, 768],   16],
        1: [[1, 32, 32, 768],    [1, 35, 35, 768],    8],
        2: [[1, 64, 64, 384],    [1, 70, 70, 384],   16],
        3: [[1, 64, 64, 384],    [1, 70, 70, 384],    8],
        4: [[1, 512, 512, 128],  [1, 513, 513, 128], 16],
        5: [[1, 512, 512, 128],  [1, 513, 513, 128],  8],
    }

    if args.shape_index is None:
        selected = sam2_en_shapes.values()
    else:
        selected = [sam2_en_shapes[args.shape_index]]

    for input_shape, output_shape, fix_point_bits in selected:
        print("input_shape", input_shape)
        print("output_shape", output_shape)

        ifm_bits, ofm_bits = elem_size(fix_point_bits, qdq_mode)

        # The 8-bit entries use the dedicated "_a8" kernel variant.
        kernel = "run_combined_qdq_a8" if fix_point_bits == 8 else "run_combined_qdq"
        kernel_names = {kernel: kernel_func_list.index(kernel)}

        folder_fields = input_shape + output_shape + [qdq_mode] + [ifm_bits]
        out_folder = f"input_{'_'.join(map(str, folder_fields))}"

        build_pad_qdq(
            back_end,
            kernel_names, kernel_includes,
            aie_cols, aie_rows,
            input_shape, output_shape,
            ifm_bits,
            ofm_bits,
            fix_point_bits,
            is_signed,
            qdq_mode,
            frontend_only,
            out_folder=out_folder,
            save_bins=save_bins,
        )

# Run the built-in shape sweep only when this file is executed as a script,
# not when it is imported for run_pad_qdq_op / build_pad_qdq.
if __name__ == '__main__':
    main()
