import math
import os
import random
import shutil
import sys
from typing import List
import json
CURRDIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', '..'))
from dmacompiler import BackEnd, config
from dataflow_common import clean_overlay, \
                            build_sim_overlay, \
                            sizeof, \
                            tiling_json_gen, \
                            count_dims
from gather_dataflow import compile_dataflow
from gather_common import \
    directive, \
    write_bin_file, \
    get_random_idxs, \
    Dims


def extract_fields(file_name):
    """Parse the UTF-8 encoded JSON file at ``file_name`` and return its content."""
    with open(file_name, mode='r', encoding="utf-8") as handle:
        return json.loads(handle.read())

def get_nested_value(d: dict, keys: List[str], default=None):
    """Look up a value in a nested dictionary by following ``keys`` in order.

    Args:
        d: Outermost dictionary to start from.
        keys: Keys to follow, one per nesting level.
        default: Value returned when the lookup cannot be completed.

    Returns:
        The value found at the end of the key path. ``default`` is returned
        when a non-dict object is met while keys remain, when a key is
        missing, or when the value found is ``None``.
    """
    node = d
    for name in keys:
        # Cannot descend any further once we are no longer looking at a dict.
        if not isinstance(node, dict):
            return default
        node = node.get(name)
    if node is None:
        return default
    return node


def update_len_to_4(data: List[int], axis: int):
    """Pad a shape with leading 1s up to four dimensions (N, H, W, C).

    The axis being gathered is shifted by the number of dimensions that were
    prepended so that it keeps referring to the same dimension of the
    original shape.

    Args:
        data: Shape to pad. It is not modified; a new list is returned.
        axis: Axis being gathered, relative to ``data``. A negative axis is
            interpreted as counting from the last dimension and is converted
            to its non-negative equivalent before the shift is applied.

    Returns:
        A ``(shape, axis)`` tuple with the padded shape (a new list) and the
        axis expressed relative to that padded shape. Shapes that already
        have four or more dimensions are returned without padding.
    """
    # Work on a copy so the caller's list (e.g. a parsed JSON field) is untouched.
    shape = list(data)

    # Resolve a negative axis against the original rank; otherwise the shift
    # below would leave it pointing at one of the padded dimensions.
    if isinstance(axis, int) and axis < 0:
        axis += len(shape)

    missing = 4 - len(shape)
    if missing > 0:
        shape = [1] * missing + shape
        axis += missing

    return shape, axis

def build_gather(
    input_shape: List[int],
    num_idxs: int,
    txn_mode: int,
    AieRows: int,
    AieCols: int,
    input_bytes: int,
    idxs_list: List[int],
    frontend_only: bool,
    kernel_names: List[str],
    kernel_includes: List[str],
    out_folder: str,
    input_prints: int = 0,
    check_output_prints: int = 0,
    gen_waveform: bool = False,
):
    """Generate the gather data transfers and optionally build the overlay.

    The purpose of this function is to:
    1. Call DMACompiler (through ``compile_dataflow``) to generate the data
       transfers.
    2. Unless ``frontend_only`` is set, call ``build_sim_overlay`` either with
       the Adf backend (simulation) or with the TxnHostPatch backend (to
       generate binaries to be run on hardware).

    Currently, the dataflow here works for a 2D input tensor and only on the
    outermost dimension. The number of input indices divided by the number of
    columns in use in the AIE array, multiplied by the inner dimension, must
    fit inside a memtile; this is enforced by the assert statement.

    For further iterations of gather, the tasks are to:
    1. Allow for a larger number of indices.
    2. Add support for gathering along all four dimensions.

    The actual data for simulation is generated in the testbench.

    Args:
        input_shape: Input shape; indices 0..3 are read and exported as the
            Nin/Hin/Win/Cin compile flags.
        num_idxs: Number of indices being gathered.
        txn_mode: ``1`` selects the hardware flow (TxnHostPatch); any other
            value selects the simulation flow (Adf).
        AieRows: Number of AIE rows in use.
        AieCols: Number of AIE columns in use.
        input_bytes: Size in bytes of one input element.
        idxs_list: Indices to gather; written to ``idxs.bin``.
        frontend_only: When true, only the data transfers are generated and
            the overlay build is skipped.
        kernel_names: Kernel names forwarded to ``compile_dataflow``.
        kernel_includes: Kernel include files forwarded to ``compile_dataflow``.
        out_folder: Currently unused (see the note at the end of the body).
        input_prints: Value of the INPUT_PRINTS compile flag.
        check_output_prints: Value of the CHECK_OUTPUT_PRINTS compile flag.
        gen_waveform: Forwarded to ``build_sim_overlay`` as ``dump_trace``.

    Raises:
        AssertionError: If the per-column slice of the output does not fit in
            the memtile memory left after the core layer parameters.
    """
    # Ensure there is enough memtile memory for the dataflow. Since the
    # indices are split across the columns, a tensor of
    # (num_idxs / AieCols) * C must fit in a memtile.
    available_memtile_size = config.MAX_MEMTILE_ADDR - (config.MAX_CORE_LAYER_PARAM_SIZE * AieRows)
    required_memtile_size = math.ceil(num_idxs / AieCols) * input_shape[3] * input_bytes
    assert required_memtile_size <= available_memtile_size, "This shape is unsupported for this dataflow"

    # Write the indices to a binary file to be read by the testbench.
    write_bin_file(idxs_list, 'idxs.bin')

    # Output shape depends on the number of indices.
    output_shape = [1, 1, num_idxs, input_shape[3]]

    # Whether we are generating binaries to be run on hardware.
    hw_run = txn_mode == 1

    # Bookkeeping for information about the gather operation, named 'Dims'
    # for consistency among dataflow ops.
    element_bits = input_bytes * 8
    gather_dims = Dims(
        input_shape=input_shape,
        output_shape=output_shape,
        is_qdq=False,
        axis=2,
        param_subv_size=config.MAX_CORE_LAYER_PARAM_SIZE,
        aie_cols=AieCols,
        aie_rows=AieRows,
        input_bits=element_bits,
        wgt_bits=32,
        output_bits=element_bits,
    )

    # Compile flags for the testbench.
    directives = [
        directive("TXN_MODE", txn_mode, hw_run),
        directive("Nin", input_shape[0], hw_run),
        directive("Hin", input_shape[1], hw_run),
        directive("Win", input_shape[2], hw_run),
        directive("Cin", input_shape[3], hw_run),
        directive("NUM_INDICES", num_idxs, hw_run),
        directive("INPUT_BYTES", input_bytes, hw_run),
        directive("INPUT_PRINTS", input_prints, hw_run),
        directive("CHECK_OUTPUT_PRINTS", check_output_prints, hw_run),
    ]

    # Generate tiling.json in the current working directory.
    tiling = {
        "op_type": "gather",
        "orig_input": input_shape,
        "orig_output": output_shape,
    }
    tiling_json_gen(tiling, os.path.join(os.getcwd(), 'tiling.json'))

    # Call dmacompiler to generate the data transfers and super kernel.
    compile_dataflow(
        gather_dims=gather_dims,
        idxs_list=idxs_list,
        kernel_names=kernel_names,
        kernel_includes=kernel_includes,
        hw_run=hw_run,
    )

    # Unless only the data transfers were requested, build the overlay:
    # simulation (Adf) when not targeting hardware, otherwise the transaction
    # flow (TxnHostPatch).
    if not frontend_only:
        build_sim_overlay(
            backend=BackEnd.TxnHostPatch if hw_run else BackEnd.Adf,
            host_filename="gather_main.cpp",
            compile_flags=directives,
            dump_trace=gen_waveform,
        )

    # NOTE: `out_folder` is accepted but currently unused. The step that moved
    # the hardware artifacts (ifm.bin, wgt.bin, ofm.bin, tiling.json, txn.bin,
    # param.bin, ctrl.bin, patch.json) from CURRDIR into `out_folder` after a
    # transaction-mode build is disabled.

def run_gather_op(json_file, path, txn_mode, kernel_d, frontend_only):
    """Run the gather op after reading its attributes from a JSON file.

    The working directory is changed to ``path`` before ``json_file`` is
    opened, so a relative ``json_file`` is resolved against ``path``.

    Args:
        json_file: JSON description providing ``layer_info`` and
            ``overlay_info``.
        path: Directory to switch to before doing anything else.
        txn_mode: ``0`` selects the Adf backend, anything else TxnHostPatch;
            also forwarded unchanged to ``build_gather``.
        kernel_d: Optional mapping with ``kernel_list`` and ``kernel_include``
            entries. When falsy, no kernels and the ``super.hh`` include are
            used.
        frontend_only: Forwarded to ``build_gather``.

    Raises:
        KeyError: If a required JSON field is missing.
        IndexError: If the ``axis`` attribute is missing (the lookup then
            falls back to an empty list, which is indexed) or is an empty list.
        AssertionError: If the input has rank greater than 2 or the padded
            axis is not 2.
    """
    os.chdir(path)
    op_desc = extract_fields(json_file)

    act_shape = op_desc['layer_info']['in_act_shape']
    if txn_mode == 0:
        back_end = BackEnd.Adf
    else:
        back_end = BackEnd.TxnHostPatch
    overlay_rows = op_desc['overlay_info']['shape']['row']
    overlay_cols = op_desc['overlay_info']['shape']['col']
    # The result of sizeof() is treated as a width in bits (divided by 8 below).
    ifm_bits = sizeof(op_desc['layer_info']['in_datatype'])

    if kernel_d:
        kernel_names = kernel_d['kernel_list']
        kernel_includes = kernel_d['kernel_include']
    else:
        kernel_names = []
        kernel_includes = ['super.hh']

    # The axis attribute may be stored either as a list or as a bare value.
    raw_axis = get_nested_value(op_desc, ['layer_info', 'attributes', 'axis'], [])
    if isinstance(raw_axis, list):
        axis = raw_axis[0]
    else:
        axis = raw_axis
    indices = get_nested_value(op_desc, ['layer_info', 'attributes', 'indices'], [])

    # Bring the shape to four dimensions and shift the axis accordingly.
    padded_shape, axis = update_len_to_4(act_shape, axis)
    index_count = len(indices)
    rank = count_dims(padded_shape)
    assert rank <= 2, "The current dataflow does not support a tensor of rank greater than 2"
    assert axis == 2, "The current dataflow only supports gather along the outer dimension"

    # An output directory (next to the JSON file) is only set for non-Adf runs.
    output_dir = None
    if back_end != BackEnd.Adf:
        output_dir = os.path.dirname(os.path.realpath(json_file))

    build_gather(
        input_shape=padded_shape,
        num_idxs=index_count,
        txn_mode=txn_mode,
        AieRows=overlay_rows,
        AieCols=overlay_cols,
        input_bytes=ifm_bits // 8,
        idxs_list=indices,
        frontend_only=frontend_only,
        kernel_names=kernel_names,
        kernel_includes=kernel_includes,
        out_folder=output_dir,
    )


def _prompt_flag(prompt):
    """Ask for a 0/1 flag on stdin; return 1 only when the answer is 1, else 0.

    An empty answer counts as 0. A non-empty answer that is not an integer
    raises ValueError.
    """
    answer = input(prompt).strip()
    if not answer:
        return 0
    return 1 if int(answer) == 1 else 0


def main():
    """Entry point for manual testing of the gather op.

    The transaction mode is read from the first command-line argument
    (default 0). The shape id and the print/waveform flags are read
    interactively from stdin. ``build_gather`` is then called once per
    selected shape.
    """
    # Alternative entry through a JSON description:
    # run_gather_op('gather.json', '.', 0, None, False)

    # Shape of the AIE to be used
    aie_cols = 8
    aie_rows = 4

    # Remove bin files and previous files defining data transfers from a previous run
    clean_overlay()

    # Whether the backend of the dmacompiler will be Adf (for sim run) or
    # TxnHostPatch (for a hardware run). sys.argv entries are strings, so the
    # value must be converted: build_gather compares it against the int 1.
    txn_mode = int(sys.argv[1]) if len(sys.argv) >= 2 else 0

    # Input shapes that have been tested in the format of id: [input_shape, indices_shape]
    INPUT_SHAPES = {
        0: [[1, 1, 49408, 512], [1, 1, 10, 77]],
        1: [[1, 1, 770, 512], [1, 1, 1, 10]],
        2: [[1, 1, 2, 768], [1, 1, 1, 128]],
        3: [[1, 1, 30522, 768], [1, 1, 1, 128]],
        4: [[1, 1, 119547, 768], [1, 1, 1, 128]],
        5: [[1, 1, 2, 768], [1, 1, 1, 128]],  # same shapes as id 2
    }

    # Select the desired input shape through the prompt; if no shape id is
    # entered, all the shapes are selected.
    shape_id = input("SHAPE ID: ").strip()
    if shape_id:
        shapes = [INPUT_SHAPES[int(shape_id)]]
    else:
        shapes = list(INPUT_SHAPES.values())

    # Whether to enable prints regarding the input/output tensor, indices, and errors
    input_prints = _prompt_flag("INPUT_PRINTS: ")

    # Whether to enable prints during data integrity check for output
    check_output_prints = _prompt_flag("CHECK OUTPUT PRINTS: ")

    # Whether to generate a waveform for the simulation
    gen_waveform = bool(_prompt_flag("GENERATE WAVEFORM: "))

    # Relevant kernels and files to be included in the super kernel
    kernel_names = []
    kernel_includes = ['super.hh']

    # Settings shared by every tested shape
    input_bytes = 2
    frontend_only = False

    # For each input_shape, call the build function to call DMACompiler to either
    # simulate through AIESimulator or generate the binaries
    for input_shape, indices_shape in shapes:
        num_idxs = indices_shape[0] * indices_shape[1] * indices_shape[2] * indices_shape[3]

        # The last argument is presumably the axis the indices are drawn for
        # (build_gather gathers along axis 2) -- TODO confirm against gather_common.
        idxs_list = get_random_idxs(input_shape, num_idxs, 2)

        # build_gather writes the same file again before compiling.
        write_bin_file(idxs_list, 'idxs.bin')

        build_gather(
            input_shape=input_shape,
            num_idxs=num_idxs,
            txn_mode=txn_mode,
            AieRows=aie_rows,
            AieCols=aie_cols,
            input_bytes=input_bytes,
            idxs_list=idxs_list,
            frontend_only=frontend_only,
            kernel_names=kernel_names,
            kernel_includes=kernel_includes,
            out_folder=f'gather_{input_shape[2]}_{input_shape[3]}_{num_idxs}',
            input_prints=input_prints,
            check_output_prints=check_output_prints,
            gen_waveform=gen_waveform,
        )


if __name__ == "__main__":
    main()