import sys
sys.path.append('.')
sys.path.append('..')
sys.path.append('../../dmacompiler/')
sys.path.append('../../..')
sys.path.append('../../../..')
from dmacompiler import generate_shim_data_transfer, generate_transfer_params, TransferParams, shim_dma, memtile_dma, DmaDir
from typing import Union

'''
shim access pattern :

    Input Tensor on DDR is 2 Dimensional : (num_rows, num_cols)
'''
def access_shim_rm_vert_shard(
    num_rows: int,
    num_cols: int,
    sbv_rows: Union[int, None],
    sbv_cols: int,
    elem_bytes: int,
    numHeads: int,
    shim_dmaChannel,
    bo_offset: int,
    id: str,
    aie_col=0,
    Num4x4=1,
    refmode=0
) -> Union[tuple, "TransferParams"]:
    """Build the shim DMA transfer for one AIE column's share of the rows.

    The buffer is described to ``generate_transfer_params`` with the memory
    format ``H:{numHeads} {id}Y:{num_rows} {id}X:{sbv_cols}``.  The Y extent
    is divided into four equal groups (``num_rows // 4``) and the tiling
    format selects group ``aie_col`` for every head, across the whole
    ``sbv_cols`` extent of X.

    Args:
        num_rows: Extent of the Y axis in the memory format.
        num_cols: Only used by the divisibility assertion and by the
            reference ``step1`` value; the memory format uses ``sbv_cols``.
        sbv_rows: Not used by this function.
        sbv_cols: Extent of the X axis in both memory and tiling formats.
        elem_bytes: Bytes per element; passed on as ``bits_per_block`` after
            multiplying by 8.
        numHeads: Extent of the H axis.
        shim_dmaChannel: DMA handle forwarded unchanged to
            ``generate_transfer_params`` and ``TransferParams``.
        bo_offset: Added to the offset computed by
            ``generate_transfer_params``.  NOTE(review): must be in the same
            unit as ``tfparams._offset`` (presumably 32-bit words) - confirm.
        id: Prefix for the Y/X axis names.  The value ``"V"`` additionally
            turns on the debug dump below.
        aie_col: Index of the row group to select.
        Num4x4: Not used by this function.
        refmode: When equal to 1, return the hand-computed reference values
            instead of a ``TransferParams``.

    Returns:
        ``(step0, step1, wrap0, iter_step)`` when ``refmode == 1``; otherwise
        a ``TransferParams`` carrying the generated length/step/wrap/iter_step
        with ``bo_offset`` added to the generated offset.  The generator is
        invoked in both modes.

    Raises:
        AssertionError: If ``num_cols`` is not a multiple of ``sbv_cols``.
    """
    assert num_cols % sbv_cols == 0

    # Hand-derived reference values; they are only handed back in refmode 1.
    # The division by 4 presumably converts bytes to 32-bit words.
    step0     = 1
    step1     = (num_cols * elem_bytes) // 4
    wrap0     = (sbv_cols * elem_bytes) // 4
    iter_step = (sbv_cols * elem_bytes) // 4

    H = "H"
    Y = id + "Y"
    X = id + "X"

    # Rows are split evenly over the AIE compute columns (NumAieCompCols).
    num_aie_comp_cols = 4
    num_rows_per_col = num_rows // num_aie_comp_cols
    row_begin = aie_col * num_rows_per_col
    row_end = (aie_col + 1) * num_rows_per_col

    param_s2mm = generate_transfer_params(
        shim_dmaChannel,
        memory_format=f"{H}:{numHeads} {Y}:{num_rows} {X}:{sbv_cols}",
        tiling_format=f"{H}:0:{numHeads} {Y}:{row_begin}:{row_end} {X}:0:{sbv_cols}",
        bits_per_block=elem_bytes * 8,
        use_iter_step=True,
    )

    # The generator hands back either a bare params object or a
    # (repeat count, [params, ...]) tuple; normalise to one params object.
    if isinstance(param_s2mm, tuple):
        repeatCnt = param_s2mm[0]
        tfparams = param_s2mm[1][0]
    else:
        repeatCnt = 1
        tfparams = param_s2mm

    # Debug dump, emitted for the "V" tensor only.
    if id == "V":
        print(id + ":")
        print("-----------------------------")
        print("shim_dmaChannel:", shim_dmaChannel)
        print("tfparams._length:", tfparams._length)
        print("bo_offset:", bo_offset)
        print("tfparams._offset:", tfparams._offset)
        print("offset: ", bo_offset+tfparams._offset)
        print("tfparams._step:", tfparams._step)
        print("tfparams._wrap:", tfparams._wrap)
        print("tfparams._iter_step:", tfparams._iter_step)
        print("repeatCnt:", repeatCnt)

    if refmode == 1:
        return (step0, step1, wrap0, iter_step)

    return TransferParams(
        shim_dmaChannel,
        tfparams._length,
        offset=bo_offset + tfparams._offset,
        step=tfparams._step,
        wrap=tfparams._wrap,
        iter_step=tfparams._iter_step,
    )


def gen_Qry_shim_data_transfer(
    num_rows: int,
    num_cols: int,
    sbv_rows: Union[int, None],
    sbv_cols: int,
    elem_bytes: int,
    numHeads: int,
    shim_dmaChannel,
    bo_offset_words: int,
    id: str,
    aie_col=0,
    Num4x4=1):
    """Generate a shim transfer that covers the whole H/Y/X buffer once.

    The memory format and the tiling format describe the same extents
    (``numHeads`` x ``num_rows`` x ``sbv_cols``), and the transfer is issued
    against shim buffer index 1 with a single repeat.

    ``num_cols``, ``sbv_rows``, ``aie_col`` and ``Num4x4`` are accepted but
    not used here.  ``bo_offset_words`` is multiplied by 4 before being
    passed as ``buffer_offset`` (presumably a word-to-byte conversion).

    Returns whatever ``generate_shim_data_transfer`` returns.
    """
    head_axis = "H"
    row_axis = id + "Y"
    col_axis = id + "X"

    layout = f"{head_axis}:{numHeads} {row_axis}:{num_rows} {col_axis}:{sbv_cols}"
    whole_buffer = f"{head_axis}:0:{numHeads} {row_axis}:0:{num_rows} {col_axis}:0:{sbv_cols}"

    transfer_args = dict(
        repeat_counts=[1],
        dma=shim_dmaChannel,
        shim_buffer_idx=1,
        memory_format=layout,
        tiling_format=whole_buffer,
        bits_per_block=elem_bytes * 8,
        buffer_offset=bo_offset_words * 4,
        verbose=True,
    )
    return generate_shim_data_transfer(**transfer_args)

def gen_Out_shim_data_transfer(
    num_rows: int,
    num_cols: int,
    sbv_rows: Union[int, None],
    sbv_cols: int,
    elem_bytes: int,
    numHeads: int,
    shim_dmaChannel,
    bo_offset_words: int,
    id: str,
    aie_col=0,
    Num4x4=1):
    """Generate the shim transfer for the output buffer (shim buffer 0).

    The memory format is ``H:{numHeads} {id}Y:{num_rows} {id}X:{sbv_cols}``.
    The tiling format walks Y from 0 to ``num_rows`` with a third field of
    ``Num4x4 * sbv_rows`` (presumably the stride of that traversal - confirm
    against the dmacompiler format), takes the Y range
    ``0*sbv_rows : 1*sbv_rows`` inside each step, and covers X from 0 to
    ``sbv_cols``.

    Args:
        num_rows: Extent of the Y axis.
        num_cols: Not used by this function.
        sbv_rows: Row count of one Y sub-block.  Although the annotation
            admits ``None``, a value is required here.
        sbv_cols: Extent of the X axis.
        elem_bytes: Bytes per element; forwarded as ``bits_per_block`` after
            multiplying by 8.
        numHeads: Extent of the H axis.
        shim_dmaChannel: DMA handle forwarded as ``dma``.
        bo_offset_words: Multiplied by 4 and forwarded as ``buffer_offset``
            (presumably a word-to-byte conversion).
        id: Prefix for the Y/X axis names.
        aie_col: Not used by this function.
        Num4x4: Multiplier applied to ``sbv_rows`` for the outer Y field.

    Returns:
        Whatever ``generate_shim_data_transfer`` returns.

    Raises:
        TypeError: If ``sbv_rows`` is ``None``.
    """
    # Fail early with a readable message; previously None surfaced as an
    # "unsupported operand" TypeError from inside the f-string below.
    if sbv_rows is None:
        raise TypeError(
            "gen_Out_shim_data_transfer: sbv_rows must be an integer, got None"
        )

    H = "H"
    Y = id + "Y"
    X = id + "X"

    outer_y_field = Num4x4 * sbv_rows
    inner_y_begin = 0 * sbv_rows
    inner_y_end = 1 * sbv_rows

    return generate_shim_data_transfer(
            repeat_counts=[1],
            dma=shim_dmaChannel,
            shim_buffer_idx=0,
            memory_format=f"{H}:{numHeads} {Y}:{num_rows} {X}:{sbv_cols}",
            tiling_format=f"{H}:0:{numHeads} {Y}:0:{num_rows}:{outer_y_field} {Y}:{inner_y_begin}:{inner_y_end} {X}:0:{sbv_cols}",
            bits_per_block=elem_bytes*8,
            buffer_offset=bo_offset_words*4,
            verbose=True)
    
'''
def access_shim_rm_hori_shard(
    num_rows: int,
    num_cols: int,
    sbv_rows: int,
    sbv_cols: int,
    elem_bytes: int,
    numHeads: int,
    shim_dmaChannel,
    bo_offset: int,
    refmode=0

) -> tuple:
    assert num_rows % sbv_rows == 0
    assert num_cols % sbv_cols == 0
    step0 = 1
    step1 = (num_cols * elem_bytes) // 4
    step2 = (sbv_cols * elem_bytes) // 4
    wrap0 = (sbv_cols * elem_bytes) // 4
    wrap1 = sbv_rows
    iter_step = (sbv_rows * num_cols * elem_bytes) // 4

    param_s2mm = generate_transfer_params( shim_dma(0, DmaDir.S2MM, 1),
                    memory_format=f"Y:{num_rows} X:{num_cols}",   
                    tiling_format=f"Y:0:{num_rows}:{sbv_rows} X:0:{num_cols}:{sbv_cols} Y:0:{sbv_rows} X:0:{sbv_cols}",
                 bits_per_block=elem_bytes*8, use_iter_step=True) 

    repeatCnt = param_s2mm[0] if(type(param_s2mm)==tuple) else 1
    tfparams  = param_s2mm[1][0] if(type(param_s2mm)==tuple) else param_s2mm


    assert(repeatCnt == numHeads) #print("mm2s repeat_cnt:", repeatCnt) #tfparams2._iter_wrap)
    assert(tfparams._step == [step0, step1, step2]) #print("mm2s step:", tfparams._step, " steps:", [step0, step1, step2])
    assert(tfparams._wrap == [wrap0, wrap1])        #print("mm2s wrap:", tfparams._wrap, " wraps:", [wrap0, wrap1])
    assert(tfparams._length == num_rows*num_cols*elem_bytes//4//numHeads) #print("mm2s lenth:", tfparams._length, " len: ", num_rows*num_cols*elem_bytes//4 //H)
    assert(tfparams._iter_step == iter_step) #print("iter_step:", tfparams._iter_step, "iter_step:", iter_step)

    if(refmode==1):
        return (step0, step1, step2, wrap0, wrap1, iter_step)
    else:
        return TransferParams(shim_dmaChannel, tfparams._length, # ShimQryShardWords,
                   offset=bo_offset,
                   step=tfparams._step, wrap=tfparams._wrap, iter_step=tfparams._iter_step)
'''

#if(id=="O"):  ## Jump for every 64 rows within a head
#    if(aie_col < 4):
#        tiling_format_string = f"{H}:0:{numHeads} {X}:0:{num_cols}:{sbv_cols} {Y}:0:{num_rows}:{Num4x4*sbv_rows} {Y}:{0*sbv_rows}:{1*sbv_rows} {X}:0:{sbv_cols}"
#    else:
#        tiling_format_string = f"{H}:0:{numHeads} {X}:0:{num_cols}:{sbv_cols} {Y}:0:{num_rows}:{Num4x4*sbv_rows} {Y}:{1*sbv_rows}:{2*sbv_rows} {X}:0:{sbv_cols}"
#    use_iter_step = False
#param_s2mm = generate_transfer_params( shim_dmaChannel,
#                memory_format=f"{H}:{numHeads} {Y}:{num_rows} {X}:{num_cols}",   
#                tiling_format=tiling_format_string,#f"{X}:0:{num_cols}:{sbv_cols} {Y}:0:{num_rows} {X}:0:{sbv_cols}",
#             bits_per_block=elem_bytes*8, use_iter_step=use_iter_step) 

'''
    if(numHeads >= 2):
        assert(repeatCnt == numHeads) #print("mm2s repeat_cnt:", repeatCnt)#tfparams2._iter_wrap)
        assert(tfparams._step == [step0, step1]) #print("mm2s step:", tfparams._step, " steps:", [step0, step1])
        assert(tfparams._wrap == [wrap0]) #print("mm2s wrap:", tfparams._wrap, " wraps:", [wrap0])
        assert(tfparams._length == num_rows * num_cols * elem_bytes // 4 // numHeads) #print("mm2s lenth:", tfparams._length, " len: ", ShimQryWords//H)
        #ShimKeyShardWords
        #print("tfparams._length=", tfparams._length)
        #print("ShimKeyShardWords=", ShimKeyShardWords)
        #if(shim_dmaChannel==shim_dma(col, DmaDir.MM2S, 1)):
        #    assert(tfparams._length == ShimKeyShardWords)
        assert(tfparams._iter_step == iter_step) #print("iter_step:", tfparams._iter_step, "iter_step:", iter_step)
'''

#if((id=="Q" and bo_offset != 0) or (id == "O")):
#if(((id == "K") or (id == "V"))):
