import os
import sys
import math
import struct
from typing import Dict, Tuple, List
import os
import sys
from typing import List
# Absolute path of the directory that contains this file; the search paths
# appended below are expressed relative to it.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Append the parent, grandparent and great-grandparent directories to the
# module search path.
# NOTE(review): presumably this is what lets the project-local modules
# imported below (dmacompiler, dataflow_common, gather_common) resolve --
# confirm against the repository layout.
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', '..'))

from dmacompiler import \
    OverlayShape, BackEnd, \
    DataTransfer, SyncStrategy, \
    AieTile, TileType, \
    AieDma, DmaDir, TransferParams, core_dma, memtile_dma, shim_dma, DmaChannel,\
    CoreInstr, ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop, \
    compute_buffer_size, \
    generate_core_buffer_config, \
    generate_transfer_params, \
    generate_shim_data_transfer, \
    run_layer_compilation, \
    set_dev_gen, DevGen, config, DmaConnection, pack_reconfig_transfers

from dataflow_common import \
    overlay_8x4_dma_connections, \
    overlay_stack_addr,\
    prm_shim_mm2s,  \
    prm_memtile_s2mm, \
    prm_memtile_mm2s, \
    prm_shim_memory, \
    prm_memtile_memory \
    
from gather_common import Dims

# Module-level configuration; runs once, when this file is imported.

# Select DevGen.Aie2p as the device generation used by dmacompiler.
set_dev_gen(DevGen.Aie2p)
# NOTE(review): presumably switches the generated code to busy-polling for
# synchronisation -- the exact effect is defined inside dmacompiler; confirm.
config.ENABLE_BUSY_POLL = True



# NOTE(review): this is the same expression as the CURRDIR assignment made
# next to the imports at the top of the file, so this re-assignment looks
# redundant.
CURRDIR = os.path.dirname(os.path.abspath(__file__))


def access_linear_buffer(
    dma: AieDma,
    buffer_bytes: int,
    offset_bytes: int = 0,
) -> TransferParams:
    '''Describe a linear (contiguous) buffer access for the given DMA.

    Both byte quantities are converted to 4-byte words before being handed to
    ``TransferParams``, so each of them must be a multiple of 4.

    Args:
        dma: DMA channel the transfer parameters are created for.
        buffer_bytes: Length of the access in bytes.
        offset_bytes: Start of the access within the buffer, in bytes.

    Returns:
        ``TransferParams(dma, <length in words>, offset=<offset in words>)``.

    Raises:
        AssertionError: If ``buffer_bytes`` or ``offset_bytes`` is not a
            multiple of 4 (the length is checked first).
    '''
    word_bytes = 4
    length_words, length_rest = divmod(buffer_bytes, word_bytes)
    assert length_rest == 0
    start_word, start_rest = divmod(offset_bytes, word_bytes)
    assert start_rest == 0
    return TransferParams(dma, length_words, offset=start_word)


def linear_single_channel_memory(input_shape: List[int]):
    '''Return the memory-format string ``C:<n>`` for one gathered row.

    ``<n>`` is ``input_shape[3]``, i.e. the size of the last of the four
    input dimensions.
    '''
    row_length = input_shape[3]
    return 'C:{}'.format(row_length)

def linear_single_channel_tiling(input_shape: List[int]):
    '''Return the tiling-format string ``C:0:<n>`` for one gathered row.

    ``<n>`` is ``input_shape[3]``.
    NOTE(review): presumably ``0:<n>`` is a start:stop range over the C
    dimension as understood by dmacompiler -- confirm.
    '''
    row_length = input_shape[3]
    return 'C:0:{}'.format(row_length)

def split_offsets(offsets: List[int], gather_dims: Dims):
    '''Split the gather indices among the memtiles of the AIE columns.

    The indices are dealt out in order, in contiguous chunks of
    ``ceil(len(offsets) / gather_dims.aie_cols)`` entries per column. Every
    column gets a chunk of exactly that length: once ``offsets`` runs out,
    the remaining slots are filled by repeating the last index.

    Args:
        offsets: Indices to gather, in output order.
        gather_dims: Gather dimensions; only ``aie_cols`` is read here.

    Returns:
        One list of indices per column. All inner lists are empty when
        ``offsets`` is empty.

    Raises:
        ZeroDivisionError: If ``gather_dims.aie_cols`` is 0.
    '''
    num_cols = gather_dims.aie_cols
    chunk_len = math.ceil(len(offsets) / num_cols)
    last_pos = len(offsets) - 1
    # Slot `slot` of column `col` maps to flat position col * chunk_len + slot;
    # positions past the end are clamped to the last index (padding).
    return [
        [offsets[min(col * chunk_len + slot, last_pos)] for slot in range(chunk_len)]
        for col in range(num_cols)
    ]

def return_memtile_sizes(
    input_shape: List[int],
    offsets: List[int],
    input_bytes: int,
    num_idxs: int,
    gather_dims: Dims,
):
    '''Return the size in bytes of the gathered data stored in each memtile.

    Columns are filled in order with up to
    ``ceil(len(offsets) / gather_dims.aie_cols)`` rows each, where one row is
    ``input_shape[3] * input_bytes`` bytes. Columns that are reached after
    all rows have been assigned get size 0.

    Args:
        input_shape: Input tensor shape; only ``input_shape[3]`` is read.
        offsets: Indices to gather; only its length is used.
        input_bytes: Size of one input element in bytes.
        num_idxs: Expected number of gathered rows, used for the sanity check.
        gather_dims: Gather dimensions; only ``aie_cols`` is read here.

    Returns:
        A list with one byte count per column.

    Raises:
        AssertionError: If the per-column sizes do not add up to
            ``num_idxs * input_shape[3] * input_bytes``.
        ZeroDivisionError: If ``gather_dims.aie_cols`` is 0.
    '''
    total_rows = len(offsets)
    num_cols = gather_dims.aie_cols
    rows_per_col = math.ceil(total_rows / num_cols)
    row_bytes = input_bytes * input_shape[3]

    # Column `col` starts after col * rows_per_col rows; it holds whatever is
    # left from there, clamped to the range [0, rows_per_col].
    sizes = [
        min(rows_per_col, max(total_rows - col * rows_per_col, 0)) * row_bytes
        for col in range(num_cols)
    ]

    expected_bytes = num_idxs * input_shape[3] * input_bytes
    assert sum(sizes) == expected_bytes
    return sizes

def shim_ofm_offsets(MemtileDataSize: List[int]):
    '''Return, per column, the byte offset at which its shim should write.

    The result is the exclusive running sum of ``MemtileDataSize``: column 0
    starts at 0 and every later column starts where the previous one ends.

    Args:
        MemtileDataSize: Number of bytes produced by each column, in column
            order.

    Returns:
        A list of start offsets with the same length as ``MemtileDataSize``.
    '''
    # boundaries[i] is the start of column i; the final entry is the total
    # size and is not a start offset, so it is dropped on return.
    boundaries = [0]
    for column_bytes in MemtileDataSize:
        boundaries.append(boundaries[-1] + column_bytes)
    return boundaries[:-1]

'''
Generate the actual data transfers and super kernel
'''
def compile_dataflow(
    gather_dims: Dims,
    idxs_list: List[int],
    kernel_names: List[str],
    kernel_includes: List[str],
    hw_run : bool = False,
) -> None:
    '''Build the core, memtile and shim descriptions of a gather and compile them.

    The indices in ``idxs_list`` are distributed over the AIE columns with
    ``split_offsets``. For every column the function describes:

    * a parameter path (shim MM2S channel 0 -> memtile S2MM channel 0 ->
      memtile MM2S channels ``0 .. aie_rows - 1``),
    * an input path with one entry per index assigned to the column: shim
      MM2S channel 1 reads ``input_shape[3]`` elements at byte offset
      ``idx * input_shape[3] * input_bytes`` of the source buffer, and
      memtile S2MM channel 1 stores them at consecutive row slots of the
      memtile buffer,
    * an output path: memtile MM2S channel 5 sends ``MemtileDataSize[col]``
      bytes and shim S2MM channel 0 writes them at byte offset
      ``per_col_shim_ofm_offsets[col]`` of the destination buffer.

    Everything is then passed to ``run_layer_compilation``; its result is not
    returned.

    NOTE(review): how the shim and memtile channels are wired together is
    decided by ``overlay_8x4_dma_connections()``, which is not visible here.

    NOTE(review): when ``len(idxs_list)`` is not a multiple of
    ``gather_dims.aie_cols``, ``split_offsets`` pads the short columns by
    repeating the last index, so those columns get more phase entries than
    the rows counted in ``MemtileDataSize[col]``. Confirm that callers avoid
    this case or that it is handled downstream.

    Args:
        gather_dims: Gather dimensions. Fields read here: ``input_shape``,
            ``input_bits``, ``param_subv_size``, ``aie_rows``, ``aie_cols``.
        idxs_list: Indices of the rows to gather.
        kernel_names: Forwarded unchanged to ``run_layer_compilation``.
        kernel_includes: Forwarded unchanged to ``run_layer_compilation``.
        hw_run: ``False`` selects ``BackEnd.Adf``; any other value selects
            ``BackEnd.TxnHostPatch``.

    Raises:
        AssertionError: From ``return_memtile_sizes`` if the per-column sizes
            do not add up to the output size, or from ``access_linear_buffer``
            if a per-column size or output offset is not a multiple of 4 bytes.
    '''
    # Dimensions of the input shape:
    input_shape = gather_dims.input_shape


    # Size of each element in the input tensor in bytes
    input_bytes = gather_dims.input_bits // 8


    # Size of the input data in bytes (product of all four dimensions)
    input_data_size =  input_shape[0] * input_shape[1] * input_shape[2] * input_shape[3] * input_bytes

    # Backend of the dmacompiler: Adf for simulation and TxnHostPatch for hardware run
    back_end = BackEnd.Adf if hw_run == False else BackEnd.TxnHostPatch

    # number of indices to gather
    num_idxs = len(idxs_list)

    # output size of the gather operation in bytes: one row of
    # input_shape[3] elements per gathered index
    output_size = num_idxs * input_shape[3] * input_bytes


    # indices to gather per column in the AIE array (one inner list per column)
    per_col_idxs = split_offsets(offsets=idxs_list, gather_dims=gather_dims)


    #
    # Buffer Sizes
    #

    # ------------------------------ Param ------------------------------------------

    # One parameter sub-volume per AIE row of the column
    ParamSize = gather_dims.param_subv_size * gather_dims.aie_rows


    # ------------------------------ Shim -------------------------------------------

    ShimDataSize = input_data_size
    ShimOfmDataSize = output_size

    # ----------------------------  Memtile -----------------------------------------

    # Bytes of gathered data held by each column's memtile
    MemtileDataSize = return_memtile_sizes(
                        input_shape=input_shape, 
                        offsets=idxs_list,
                        input_bytes=input_bytes,
                        num_idxs=num_idxs,
                        gather_dims=gather_dims
                    )

    #
    # Buffer Addrs
    #

    # ----------------------------------- Memtile ----------------------------------

    # The parameters sit at the start of the memtile buffer space and the
    # gathered data follows directly after them.
    MemtilePrmPingAddr = 0
    MemtileSrcPingAddr = MemtilePrmPingAddr + ParamSize
    # Per column: byte offset of each phase's row inside the memtile data
    # buffer (row 0, row 1, ... in the order the indices were assigned).
    per_col_memtile_phase_offsets = [[input_shape[3] * input_bytes * idx for idx in range(len(col))] for col in per_col_idxs]



    # ------------------------------------- Shim -----------------------------------

    # Buffer indices used for the shim destination, source and parameter
    # buffers.
    # NOTE(review): the meaning of the values 0, 1 and 3 is defined outside
    # this file -- confirm against the host/runtime buffer layout.
    ShimDstBufferIdx = 0
    ShimSrcBufferIdx = 1
    ShimPrmBufferIdx = 3


    # Per column: byte offset of each gathered row inside the source buffer
    # (gather index times the size of one row).
    per_col_shim_phase_offsets = [[input_shape[3] * input_bytes * idx for idx in col] for col in per_col_idxs]
    # Per column: byte offset inside the destination buffer where the
    # column's output starts.
    per_col_shim_ofm_offsets = shim_ofm_offsets(MemtileDataSize)




    # ----------------------------- Dimension of AIE Array --------------------------

    overlay_shape = OverlayShape(gather_dims.aie_cols, gather_dims.aie_rows)


    # ------------------------------- Core Instructions -----------------------------

    # The cores only configure, acquire and release a buffer on S2MM channel 0;
    # no kernel is called here.
    # NOTE(review): presumably this just consumes the parameter stream
    # (param_channel_id=0 below) -- confirm.
    core_instrs = [

        ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), 0, 0, 0),
        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),

    ]

    #------------------------------- Memtile Transfers ------------------------------

    # Parameters: received on memtile S2MM channel 0 and sent out on one MM2S
    # channel per AIE row. The first list is 1 for the first phase and 0 for
    # all remaining phases of the column.
    memtile_param_transfers = [
        DataTransfer(
            [1] + [0] * (len(per_col_memtile_phase_offsets[col]) - 1),
            AieTile(TileType.Memtile, col),
            [MemtilePrmPingAddr],
            ParamSize,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 0),
                prm_memtile_memory(gather_dims),
                prm_memtile_s2mm(),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, row),
                prm_memtile_memory(gather_dims),
                prm_memtile_mm2s(row),
            ) for row in range(gather_dims.aie_rows)],
        ) for col in range(gather_dims.aie_cols)
    ]



    # The input (ifm) transfers into the memtile are phased, but the output
    # (ofm) transfer is not, in order to reduce the number of reconfigs.
    # As a result, a dummy transfer is introduced alongside the ofm transfer.
    #
    # Task queue for each of the channels:
    #
    #          |     DT1     |    DT2     |
    #        ---------------------------------
    # s2mm1:   |     ifm     |   dummy    |
    # ---------------------------------------
    # mm2s5:   |             |    ofm     |
    #
    # Without the dummy transfer there is no guarantee that the ofm transfer
    # of the last phase (the only phase in which it has a repeat count of 1)
    # reads the data after the ifm transfer has finished reading from the
    # stream and writing to the buffer in the memtile. Within a single
    # DataTransfer object, the BD that writes to the stream (mm2s) is
    # guaranteed to run after the BD that reads from the stream (s2mm) has
    # completed. Introducing a dummy transfer after the ifm s2mm BD in the
    # task queue of the last phase therefore guarantees that the mm2s ofm
    # transfer occurs after the s2mm ifm transfer.
    #
    # NOTE(review): in the code below the zero-length transfer is the MM2S
    # channel 5 entry of memtile_ifm_transfer, while memtile_ofm_transfer has
    # an empty S2MM list; the diagram above shows the dummy on s2mm1 instead.
    # Confirm which description is the intended one.


    # Input rows: S2MM channel 1 is reconfigured per phase so that each row
    # lands at its own offset in the memtile data buffer; the MM2S channel 5
    # entry has length 0 (the dummy transfer discussed above).
    memtile_ifm_transfer = [
        DataTransfer(
            [1] * len(per_col_memtile_phase_offsets[col]),
            AieTile(TileType.Memtile, col),
            [MemtileSrcPingAddr],
            MemtileDataSize[col],
            [
                pack_reconfig_transfers( 
                    memtile_dma(col, DmaDir.S2MM, 1),
                    [linear_single_channel_memory(input_shape)] * len(per_col_memtile_phase_offsets[col]),
                    [linear_single_channel_tiling(input_shape)] * len(per_col_memtile_phase_offsets[col]),
                    bits_per_elem=input_bytes * 8,
                    buffer_offset = per_col_memtile_phase_offsets[col]
                )
            ],
            [   
                access_linear_buffer(memtile_dma(col, DmaDir.MM2S, 5), 0,
                                  offset_bytes=0)
            ]

        ) for col in range(gather_dims.aie_cols)
    ]


    # Output: enabled only in the last phase, sends the column's whole data
    # buffer out on MM2S channel 5.
    memtile_ofm_transfer = [
        DataTransfer(
            [0] * (len(per_col_memtile_phase_offsets[col]) - 1) + [1],
            AieTile(TileType.Memtile, col), [MemtileSrcPingAddr], MemtileDataSize[col],
            [],
            [access_linear_buffer(memtile_dma(col, DmaDir.MM2S, 5), MemtileDataSize[col],
                                  offset_bytes=0)],
        ) for col in range(gather_dims.aie_cols)    
    ]


    memtile_transfers = memtile_param_transfers + memtile_ifm_transfer + memtile_ofm_transfer



    # ---------------------------------- Shim Transfers -------------------------------


    # Parameters: sent once (first phase only) on shim MM2S channel 0.
    shim_param_transfers = [
        generate_shim_data_transfer(
            [1] + [0] * (len(per_col_shim_phase_offsets[col]) - 1),
            shim_dma(col, DmaDir.MM2S, 0),
            ShimPrmBufferIdx,
            prm_shim_memory(gather_dims),
            prm_shim_mm2s(col),
        ) for col in range(gather_dims.aie_cols)
    ]


    # Input rows: shim MM2S channel 1 is reconfigured per phase so that each
    # phase reads the row selected by the corresponding gather index.
    shim_ifm_transfer = [
        DataTransfer (
            [1] * len(per_col_shim_phase_offsets[col]),
            AieTile(TileType.Shim, col),
            [ShimSrcBufferIdx],
            ShimDataSize,
            [],
            [
                pack_reconfig_transfers(
                    shim_dma(col, DmaDir.MM2S, 1),
                    [linear_single_channel_memory(input_shape)] * len(per_col_shim_phase_offsets[col]),
                    [linear_single_channel_tiling(input_shape)] * len(per_col_shim_phase_offsets[col]),
                    bits_per_elem=input_bytes * 8,
                    buffer_offset= per_col_shim_phase_offsets[col]
                )
            ]
        ) for col in range(gather_dims.aie_cols)
    ]


    # Output: shim S2MM channel 0 writes the column's data at the column's
    # offset in the destination buffer. The first list is 1 for the first
    # phase and 0 afterwards.
    shim_ofm_transfer = [
        DataTransfer(
            [1] + [0] * (len(per_col_shim_phase_offsets[col]) - 1),
            AieTile(TileType.Shim, col), 
            [ShimDstBufferIdx], 
            ShimOfmDataSize,
            [access_linear_buffer(shim_dma(col, DmaDir.S2MM, 0), MemtileDataSize[col],
                                offset_bytes=per_col_shim_ofm_offsets[col])],
            [],
        ) for col in range(gather_dims.aie_cols)
    ]


    shim_transfers = shim_param_transfers + shim_ifm_transfer + shim_ofm_transfer

    # --------------------------------------------------------------------------------
    # Hand everything to dmacompiler; the return value (if any) is discarded.
    run_layer_compilation(
        overlay_shape,
        kernel_names,
        kernel_includes,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        overlay_8x4_dma_connections(),
        back_end,
        core_stack_addr=overlay_stack_addr(),
        param_channel_id=0,
    )
