import os
import sys
import json
import shutil
import argparse
from typing import List, Optional

# Absolute directory of this script. build_slice_qdq also uses it as the
# source directory when moving generated artifacts.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Extend the module search path (import-time side effect) with this directory,
# its parent, its grandparent, and the grandparent's "dmacompiler" and
# "kernels" subdirectories. Presumably this is what lets the project imports
# further down resolve without the tree being installed as a package.
sys.path.append(CURRDIR)
sys.path.append(os.path.join(CURRDIR, ".."))
sys.path.append(os.path.join(CURRDIR, "..", ".."))
sys.path.append(os.path.join(CURRDIR, "..", "..", "dmacompiler"))
sys.path.append(os.path.join(CURRDIR, "..", "..", "kernels"))

from OGOAT.src.L1_fusion.kernel_func_list import kernel_func_list
from dmacompiler import BackEnd, set_dev_gen, DevGen, config
from dataflow_common import clean_overlay, build_sim_overlay, elem_size, sizeof
from slice_common import SliceDims, slice_preproc_directives, make_slice_dict
from slice_tiler import run_tiler
import slice_dataflow

# Module-level configuration: both statements run as soon as this file is
# imported or executed, before any function below is called.
# Select the Aie2p device generation in dmacompiler.
set_dev_gen(DevGen.Aie2p)
# Switch on dmacompiler's ENABLE_BUSY_POLL option.
# NOTE(review): the exact effect is defined inside dmacompiler.config and is
# not visible from this file.
config.ENABLE_BUSY_POLL = True


def build_slice_qdq(
    back_end: BackEnd,
    kernel_names: dict,
    kernel_includes: List[str],
    aie_cols: int,
    aie_rows: int,
    input_shape: list,
    slice_dict: dict,
    axis: int,
    ifm_bits: int,
    ofm_bits: int,
    fixed_point_bits: int,
    out_start: int,
    out_stop: int,
    out_step: int,
    qdq_mode: int,
    frontend_only: bool = False,
    out_folder: Optional[str] = None,
    get_kernel_mode: Optional[bool] = False,
    save_bins: Optional[bool] = False,
):
    """Tile a slice(+QDQ) operator, schedule it, and optionally collect artifacts.

    The function works in three stages:

    1. ``run_tiler`` is called with the overlay size, shapes, bit widths and
       slice bounds to obtain the tiling description (``dims``).
    2. ``run_scheduler`` compiles the dataflow for ``dims`` (and builds the
       simulation overlay unless ``frontend_only`` is set).
    3. When ``save_bins`` is truthy, the back end is ``BackEnd.TxnHostPatch``
       and ``out_folder`` is given, a fixed list of generated files is moved
       from this script's directory (``CURRDIR``) into ``out_folder``.

    Args:
        back_end: Back end handed to the scheduler; also gates stage 3.
        kernel_names: Kernel name mapping forwarded to the scheduler.
        kernel_includes: Kernel include list forwarded to the scheduler.
        aie_cols, aie_rows: Overlay dimensions forwarded to the tiler.
        input_shape: Input tensor shape forwarded to the tiler.
        slice_dict: Slice description forwarded to the tiler.
        axis: Axis being sliced.
        ifm_bits, ofm_bits, fixed_point_bits: Bit widths forwarded to the tiler.
        out_start, out_stop, out_step: Slice bounds forwarded to the tiler.
        qdq_mode: QDQ mode; the value 3 makes the tiler's ``is_qdq`` flag False,
            any other value makes it True.
        frontend_only: Forwarded to ``run_scheduler``.
        out_folder: Destination directory for stage 3; created if missing.
        get_kernel_mode: When truthy, return right after tiling (stages 2 and 3
            are skipped).
        save_bins: Enables stage 3.

    Returns:
        ``(dims.kernel_names, dims.kernel_includes)`` when ``get_kernel_mode``
        is truthy, otherwise ``None``.
    """
    # Only mode 3 reports "no QDQ" to the tiler.
    is_qdq = qdq_mode != 3

    dims = run_tiler(
        aie_cols, aie_rows, input_shape, slice_dict, axis,
        ifm_bits, ofm_bits, fixed_point_bits,
        out_start, out_stop, out_step,
        is_qdq, qdq_mode,
    )

    # Kernel-query mode: the caller only wants what the tiler selected.
    if get_kernel_mode:
        return dims.kernel_names, dims.kernel_includes

    run_scheduler(dims, back_end, kernel_names, kernel_includes, frontend_only)

    collect_artifacts = (
        save_bins and back_end == BackEnd.TxnHostPatch and out_folder is not None
    )
    if not collect_artifacts:
        return None

    if not os.path.exists(out_folder):
        os.makedirs(out_folder)

    # Files expected in CURRDIR after the build; each one is relocated.
    artifact_names = (
        "ifm.bin",
        "wgt.bin",
        "ofm.bin",
        "dma.hpp",
        "tiling.json",
        "txn.bin",
        "param.bin",
        "ctrl.bin",
        "patch.json",
    )
    for artifact in artifact_names:
        shutil.move(
            os.path.join(CURRDIR, artifact),
            os.path.join(out_folder, artifact),
        )
    return None


def run_scheduler(
    dims: SliceDims,
    back_end: BackEnd,
    kernel_names: dict,
    kernel_includes: List[str],
    frontend_only: bool = False,
):
    """Compile the slice dataflow and, unless disabled, build the sim overlay.

    Calls ``clean_overlay()`` first, then
    ``slice_dataflow.compile_dataflow`` with the given arguments. When
    ``frontend_only`` is falsy it additionally calls ``build_sim_overlay``
    for ``"slice_main.cpp"`` with the directives produced by
    ``slice_preproc_directives(dims, back_end)``.
    """
    clean_overlay()
    slice_dataflow.compile_dataflow(dims, back_end, kernel_names, kernel_includes)

    if frontend_only:
        return

    directives = slice_preproc_directives(dims, back_end)
    build_sim_overlay(back_end, "slice_main.cpp", directives)


def extract_fields(file_name):
    """Return the parsed JSON content of ``file_name`` (read as UTF-8 text)."""
    with open(file_name, mode="r", encoding="utf-8") as handle:
        return json.load(handle)


def update_len_to_4(data):
    """Left-pad ``data`` with 1s, in place, until it has at least 4 entries.

    The same object that was passed in is returned; inputs that already have
    four or more entries are returned untouched.
    """
    # range() of a non-positive count is empty, so long inputs are a no-op.
    for _ in range(4 - len(data)):
        data.insert(0, 1)
    return data


def slice_axis(updated_input: List[int], updated_output: List[int]) -> int:
    """Return the first of the four axes on which the output is smaller.

    Args:
        updated_input: Input shape; indices 0..3 are read.
        updated_output: Output shape; indices 0..3 are read.

    Returns:
        The lowest axis index ``a`` in ``0..3`` with
        ``updated_output[a] < updated_input[a]``.

    Raises:
        AssertionError: If no axis of the output is smaller than the input
            (the slice would be a no-op).
    """
    for axis in range(4):
        if updated_output[axis] < updated_input[axis]:
            return axis
    # Raised explicitly rather than via ``assert False``: assert statements are
    # removed under ``python -O``, which would make this fall through and
    # silently return None.
    raise AssertionError("This is a no-op")


def get_nested_value(d, keys, default=None):
    """Follow ``keys`` through nested dicts starting at ``d``.

    ``default`` is returned when a non-dict is encountered before all keys
    are consumed, or when the value finally reached is ``None`` (which also
    covers a missing last key).
    """
    node = d
    for key in keys:
        if not isinstance(node, dict):
            return default
        node = node.get(key)
    return default if node is None else node


def _first_attribute(data, name, fallback):
    """Return element 0 of ``data["layer_info"]["attributes"][name]``.

    ``fallback`` is returned when the attribute is missing, is not a list,
    or is an empty list.
    """
    values = get_nested_value(data, ["layer_info", "attributes", name], [])
    if isinstance(values, list) and values:
        return values[0]
    return fallback


def run_slice_qdq_op(json_file, path, txn_mode, kernel_d, frontend_only):
    """Build a slice(+QDQ) operator from a JSON layer description.

    Args:
        json_file: Path of the JSON description. It is opened *after* the
            working directory has been changed to ``path``, so a relative
            path is resolved against ``path``.
        path: Directory the process changes into (``os.chdir``) before
            anything else happens. This is a process-wide side effect.
        txn_mode: ``0`` selects ``BackEnd.Adf``; any other value selects
            ``BackEnd.TxnHostPatch``.
        kernel_d: Optional mapping with ``"kernel_list"`` and
            ``"kernel_include"`` entries. When falsy, the kernel names are
            looked up in ``kernel_func_list`` according to the bit width.
        frontend_only: Forwarded to ``build_slice_qdq``.

    Raises:
        AssertionError: If the input and output activation shapes are equal,
            or if no axis of the output is smaller than the input.
        KeyError: If a required field is absent from the JSON description.
    """
    os.chdir(path)
    _data = extract_fields(json_file)
    layer_info = _data["layer_info"]
    in_shape = layer_info["in_act_shape"]
    out_shape = layer_info["out_act_shape"]
    if in_shape == out_shape:
        # Explicit raise: an ``assert False`` here would be stripped under
        # ``python -O`` and let a no-op slice through.
        raise AssertionError("Input == Output. This is a no-op.")
    back_end = BackEnd.Adf if txn_mode == 0 else BackEnd.TxnHostPatch
    overlay_shape = _data["overlay_info"]["shape"]
    aie_rows = overlay_shape["row"]
    aie_cols = overlay_shape["col"]
    # The results of these two calls were never used. The calls themselves are
    # kept in case sizeof() rejects unknown datatype names.
    # NOTE(review): confirm against dataflow_common.sizeof and drop if not.
    sizeof(layer_info["in_datatype"])
    sizeof(layer_info["out_datatype"])
    op_type = layer_info["op_type"]

    fixed_point_bits = 8 if op_type == "Slice_qdq_uint8xuint8" else 16

    # NOTE: the attribute fallbacks below are temporary changes until CR
    # [JIRA]/AIESW-3214 is fixed.
    out_start = _first_attribute(_data, "start", 1)
    out_stop = _first_attribute(_data, "end", 64)
    out_step = _first_attribute(_data, "step", 1)
    disable_q = _first_attribute(_data, "disable_q", 1)
    disable_dq = _first_attribute(_data, "disable_dq0", 1)

    if disable_q and disable_dq:
        qdq_mode = 3  # no QDQ
    elif disable_q:
        qdq_mode = 0  # no Q, only DQ
    elif disable_dq:
        qdq_mode = 1  # no DQ, only Q
    else:
        qdq_mode = 2  # both

    # The side that is not fixed-point is 16 bits wide; with both or neither
    # of Q/DQ active, input and output are both fixed-point width.
    if qdq_mode == 0:
        ifm_bits, ofm_bits = fixed_point_bits, 16
    elif qdq_mode == 1:
        ifm_bits, ofm_bits = 16, fixed_point_bits
    else:
        ifm_bits, ofm_bits = fixed_point_bits, fixed_point_bits

    # update_len_to_4 pads the shape lists in place with leading 1s.
    updated_input = update_len_to_4(in_shape)
    updated_output = update_len_to_4(out_shape)
    updated_slice_axis = slice_axis(updated_input, updated_output)
    slice_dict = make_slice_dict(updated_input, updated_slice_axis, out_start, out_stop)

    if not kernel_d:
        if fixed_point_bits == 8:
            selected = ("run_combined_qdq_a8", "run_slice_a8")
        else:
            selected = ("run_combined_qdq", "run_slice")
        kernel_names = {name: kernel_func_list.index(name) for name in selected}
        kernel_includes = ["super.hh", "qdq/wrapper_qdq.cc", "slice/wrapper_slice.cc"]
    else:
        kernel_names = kernel_d["kernel_list"]
        kernel_includes = kernel_d["kernel_include"]

    # Artifacts (if any are moved) go next to the JSON file.
    # NOTE(review): save_bins is not passed below and defaults to False, so
    # build_slice_qdq does not currently move anything into output_dir.
    output_dir = (
        os.path.dirname(os.path.realpath(json_file))
        if back_end != BackEnd.Adf
        else None
    )
    build_slice_qdq(
        back_end=back_end,
        kernel_names=kernel_names,
        kernel_includes=kernel_includes,
        aie_cols=aie_cols,
        aie_rows=aie_rows,
        input_shape=updated_input,
        slice_dict=slice_dict,
        axis=updated_slice_axis,
        ifm_bits=ifm_bits,
        ofm_bits=ofm_bits,
        fixed_point_bits=fixed_point_bits,
        out_start=out_start,
        out_stop=out_stop,
        out_step=out_step,
        qdq_mode=qdq_mode,
        frontend_only=frontend_only,
        out_folder=output_dir,
    )


def slice_kernel_selection_logic(
    inputs: List[dict], outputs: List[dict], ifm_bytes: int, attributes: dict
):
    """Return the kernels the tiler selects for a slice operator.

    Only tiling is performed: ``build_slice_qdq`` is invoked with
    ``get_kernel_mode=True``, so it returns before scheduling.

    Args:
        inputs: Tensor descriptions; the first one with a non-empty
            ``"shape"`` supplies the input shape.
        outputs: Tensor descriptions; the first one with a non-empty
            ``"shape"`` supplies the output shape.
        ifm_bytes: Input element size in bytes (converted to bits here).
        attributes: Must provide ``"start"``, ``"end"``, ``"step"`` and
            ``"disable_qdq"``, each as a sequence whose first element is used.

    Returns:
        The ``(kernel_names, kernel_includes)`` pair produced by
        ``build_slice_qdq`` in kernel-query mode.

    Raises:
        IndexError: If no input or no output carries a non-empty shape.
        KeyError: If a required attribute is missing.
        AssertionError: If no axis of the output is smaller than the input.
    """
    # Entries without a "shape" key are skipped instead of crashing on len(None).
    input_shapes = [
        tensor["shape"]
        for tensor in inputs
        if tensor.get("shape") is not None and len(tensor["shape"]) != 0
    ]
    output_shapes = [
        tensor["shape"]
        for tensor in outputs
        if tensor.get("shape") is not None and len(tensor["shape"]) != 0
    ]

    # update_len_to_4 pads the shape lists in place with leading 1s.
    updated_input = update_len_to_4(input_shapes[0])
    updated_output = update_len_to_4(output_shapes[0])
    out_start = attributes["start"][0]
    out_stop = attributes["end"][0]
    out_step = attributes["step"][0]

    updated_slice_axis = slice_axis(updated_input, updated_output)
    slice_dict = make_slice_dict(updated_input, updated_slice_axis, out_start, out_stop)

    # build_slice_qdq derives its QDQ flag from qdq_mode (3 means disabled).
    # NOTE(review): "disable_qdq" is a single switch, so the enabled case is
    # mapped to mode 2 (both Q and DQ) - confirm this is the intended mode.
    disable_qdq = attributes["disable_qdq"][0]
    qdq_mode = 3 if disable_qdq else 2

    # For modes 2 and 3 run_slice_qdq_op uses the same width for input,
    # output and fixed point; the same convention is applied here.
    ifm_bits = ifm_bytes << 3

    # Keyword arguments keep this call aligned with build_slice_qdq's
    # signature. The former positional call was shifted: out_start landed in
    # ofm_bits and get_kernel_mode in frontend_only, so kernel-query mode was
    # never actually enabled.
    return build_slice_qdq(
        back_end=BackEnd.Adf,
        kernel_names={},
        kernel_includes=[],
        aie_cols=8,
        aie_rows=4,
        input_shape=updated_input,
        slice_dict=slice_dict,
        axis=updated_slice_axis,
        ifm_bits=ifm_bits,
        ofm_bits=ifm_bits,
        fixed_point_bits=ifm_bits,
        out_start=out_start,
        out_stop=out_stop,
        out_step=out_step,
        qdq_mode=qdq_mode,
        frontend_only=True,
        out_folder="",
        get_kernel_mode=True,
    )


def _str_to_bool(value: str) -> bool:
    """Convert a command-line token to a bool.

    Accepts the usual spellings (case-insensitive): 1/true/t/yes/y/on and
    0/false/f/no/n/off. An empty string counts as False.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    token = value.strip().lower()
    if token in ("1", "true", "t", "yes", "y", "on"):
        return True
    if token in ("", "0", "false", "f", "no", "n", "off"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Parse the command-line options of this script from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``backend`` (int, default 0),
        ``qdq_mode`` (int, default 3), ``shape_index`` (int or None),
        ``save_bins`` (bool, default False) and ``dtype`` (int, default 16).
    """
    parser = argparse.ArgumentParser(
        description="Build Slice for various shapes and permutations."
    )
    parser.add_argument(
        "--backend", type=int, default=0, help="Backend type (default: 0 for Adf)"
    )
    parser.add_argument("--qdq_mode", type=int, default=3, help="QDQ mode (default: 3)")
    parser.add_argument(
        "--shape_index",
        type=int,
        help="Index of the shape from the input set to run (if not provided, runs all)",
    )
    # ``type=bool`` would turn every non-empty string, including "False",
    # into True. A real parser is used instead; a bare ``--save_bins`` (no
    # value) is accepted as True.
    parser.add_argument(
        "--save_bins",
        type=_str_to_bool,
        nargs="?",
        const=True,
        default=False,
        help="Save generated bin files and dma.hpp",
    )
    parser.add_argument("--dtype", type=int, default=16, help="dtype of slice + qdq")
    return parser.parse_args()


"""
Tasks to do in future PRs:
- Remove allocation of scratch buffer and increase allocation in L1 in the following cases:
    - Dequant INT8 -> BF16: double allocation for input data
    - Quant and Dequant INT8 -> BF16 -> INT8: double allocation for input data
"""


def main():
    """Command-line entry point: build slice(+QDQ) for a table of test shapes.

    Reads the options from ``parse_args()``, picks the kernel set matching
    ``--dtype``, and calls ``build_slice_qdq`` once per selected shape. Each
    table entry has the form
    ``[input_shape, output_shape, out_start, out_stop, out_step]``.
    With ``--shape_index`` only that entry of the active table is built;
    otherwise every entry is built.

    Raises:
        KeyError: If ``--shape_index`` is not a key of the active shape table.
    """
    args = parse_args()

    back_end = BackEnd(args.backend)
    qdq_mode = args.qdq_mode
    save_bins = args.save_bins

    kernel_includes = ["super.hh", "qdq/wrapper_qdq.cc", "slice/wrapper_slice.cc"]
    aie_cols, aie_rows = 8, 4

    fixed_point_bits = args.dtype

    ifm_bits, ofm_bits = elem_size(fixed_point_bits, qdq_mode)

    frontend_only = False

    # Alternative shape tables. Only the one assigned to ``target_shape``
    # below is built; the others are kept so they can be swapped in by hand.
    # Every psp1 entry is recorded as:
    #   PASS: INT8 & INT16 - QDQ MODES: 0, 1, 2, 3
    psp1_shapes = {
        0: [[1, 115, 115, 99], [1, 115, 115, 4], 95, 99, 1],
        1: [[1, 115, 199, 98], [1, 115, 199, 96], 0, 96, 1],
        2: [[1, 1882, 384], [1, 1, 384], 0, 1, 1],
        3: [[1, 1886, 384], [1, 1881, 384], 5, 1886, 1],
        4: [[1, 1886, 384], [1, 5, 384], 0, 5, 1],
        5: [[1, 4, 460, 796], [1, 3, 460, 796], 0, 3, 1],
        6: [[1, 6, 460, 796], [1, 1, 460, 796], 0, 1, 1],
        7: [[1, 115, 199, 98], [1, 115, 199, 1], 97, 98, 1],
        8: [[1, 1882, 384], [1, 1881, 384], 1, 1882, 1],
        9: [[1, 6, 460, 796], [1, 4, 460, 796], 2, 6, 1],
        10: [[6, 9, 16, 22885], [6, 1, 16, 22885], 8, 9, 1],
        11: [[1, 1, 64, 64], [1, 1, 64, 63], 1, 64, 1],
        12: [[1, 1, 64, 16384], [1, 1, 64, 8192], 8192, 16384, 1],
        13: [[1, 1, 64, 16384], [1, 1, 64, 8192], 0, 8192, 1],
        14: [[1, 9, 16, 64], [1, 1, 16, 64], 8, 9, 1],
        15: [[1, 1, 6, 16384], [1, 1, 6, 8192], 8192, 16384, 1],
        16: [[1, 1, 6, 16384], [1, 1, 6, 8192], 0, 8192, 1],
    }

    debug_shapes = {
        0: [[1, 230, 449, 880], [1, 230, 449, 277], 80, 357, 1],
        1: [[8, 573, 657, 320], [8, 573, 657, 277], 11, 288, 1],
        2: [[8, 282, 921, 608], [8, 282, 921, 367], 224, 591, 1],
        3: [[1, 991, 727, 848], [1, 991, 727, 387], 187, 574, 1],
        4: [[8, 453, 925, 352], [8, 453, 925, 15], 335, 350, 1],
        5: [[8, 740, 250, 896], [8, 740, 250, 277], 588, 865, 1],
        6: [[1, 382, 811, 400], [1, 382, 811, 1], 387, 388, 1],
        7: [[1, 878, 341, 416], [1, 878, 341, 357], 14, 371, 1],
        8: [[1, 115, 200, 98], [1, 115, 200, 95], 0, 95, 1],
        9: [[1, 73, 16, 16], [1, 73, 16, 13], 0, 13, 1],
        10: [[1, 35, 88, 7], [1, 35, 88, 3], 0, 3, 1],
        11: [[1, 30, 47, 127], [1, 30, 47, 123], 0, 123, 1],
        12: [[1, 31, 33, 32], [1, 31, 33, 31], 0, 31, 1],
        13: [[1, 7, 4, 3], [1, 7, 4, 1], 0, 1, 1],
        14: [[1, 8, 9, 54], [1, 8, 9, 33], 0, 33, 1],
        15: [[1, 15, 20, 98], [1, 15, 20, 95], 0, 95, 1],
        16: [[8, 260, 1010, 488], [8, 260, 1010, 3], 260, 263, 1],
    }

    OSS_20B_shapes = {
        0: [[1, 32], [1, 1], 16, 17, 1],
    }

    # Kernel indices always come from kernel_func_list. The former hard-coded
    # {"run_combined_qdq": 4, "run_slice": 24} was overwritten here before it
    # was ever read, so it has been dropped.
    if fixed_point_bits == 8:
        kernel_names = {
            "run_slice_a8": kernel_func_list.index("run_slice_a8"),
            "run_combined_qdq_a8": kernel_func_list.index("run_combined_qdq_a8"),
        }
    else:
        kernel_names = {
            "run_slice": kernel_func_list.index("run_slice"),
            "run_combined_qdq": kernel_func_list.index("run_combined_qdq"),
        }

    target_shape = OSS_20B_shapes
    shapes = (
        [target_shape[args.shape_index]]
        if args.shape_index is not None
        else target_shape.values()
    )

    for input_shape, output_shape, out_start, out_stop, out_step in shapes:
        print("input_shape", input_shape)
        print("output_shape", output_shape)
        # update_len_to_4 pads the table's lists in place with leading 1s.
        updated_input = update_len_to_4(input_shape)
        updated_output = update_len_to_4(output_shape)
        axis = slice_axis(updated_input, updated_output)
        slice_dict = make_slice_dict(updated_input, axis, out_start, out_stop)
        print("slice_dict_before", slice_dict)

        # Output folder name: "input_" followed by the padded shape, axis,
        # start, stop, QDQ mode and bit widths, joined by underscores.
        folder_fields = updated_input + [
            axis,
            out_start,
            out_stop,
            qdq_mode,
            ifm_bits,
            ofm_bits,
        ]
        out_folder = "input_" + "_".join(map(str, folder_fields))

        build_slice_qdq(
            back_end,
            kernel_names,
            kernel_includes,
            aie_cols,
            aie_rows,
            updated_input,
            slice_dict,
            axis,
            ifm_bits,
            ofm_bits,
            fixed_point_bits,
            out_start,
            out_stop,
            out_step,
            qdq_mode,
            frontend_only,
            out_folder=out_folder,
            save_bins=save_bins,
        )
        print("slice_dict_after", slice_dict)


if __name__ == "__main__":
    main()
