import sys
sys.path.append('.')
sys.path.append('..')
sys.path.append('../../dmacompiler/')
sys.path.append('../../..')
sys.path.append('../../../..')
from dmacompiler import generate_transfer_params, TransferParams, memtile_dma, DmaDir


def write_L2_rm_to_w8_subvolumes(
    memt_buf_rows: int,
    memt_buf_cols: int,
    elem_bytes: int,
    memtile_dmaChannel,
    memtile_buffer_offset: int,
    id: str,
    Num4x4=1,
    refmode=0,
    AieRows=4
) -> "tuple | TransferParams":
    """Build memtile DMA transfer parameters for a ``Q``/``K``/``V`` buffer.

    A memory-format / tiling-format string pair is derived from the buffer
    shape, handed to ``generate_transfer_params`` and the resulting
    length/step/wrap/iter_step are repackaged into a ``TransferParams``
    located at ``memtile_buffer_offset``.

    NOTE(review): going by the function name this describes a row-major
    ("rm") to 8-wide-block ("w8") conversion; the exact meaning of the
    format strings is defined by dmacompiler -- confirm there.

    Args:
        memt_buf_rows: Row count used in the format strings.
        memt_buf_cols: Column count used in the format strings; must be a
            multiple of 8.
        elem_bytes: Size of one element in bytes (``bits_per_block`` is
            ``elem_bytes * 8``).
        memtile_dmaChannel: Channel object forwarded unchanged to
            ``generate_transfer_params`` and ``TransferParams``.
        memtile_buffer_offset: Value passed as ``offset=`` to the returned
            ``TransferParams``.
        id: Prefix of the dimension names (``id + "Y"`` / ``id + "X"``).
            ``"Q"`` selects the Q formats, any other value takes the K/V
            path, and ``"K"`` additionally enables a debug dump on stdout.
        Num4x4: Only read when ``id == "Q"``; must be 1 or 2.
        refmode: When 1, return the hand-computed reference tuple instead
            of a ``TransferParams``.
        AieRows: Only read on the K/V path, where ``memt_buf_rows //
            AieRows`` becomes the inner Y extent of the memory format.

    Returns:
        ``(step0, step1, step2, step3, wrap0, wrap1, wrap2)`` when
        ``refmode == 1``; these values depend only on the buffer shape and
        ``elem_bytes``, not on ``id``, ``Num4x4`` or ``AieRows``.
        Otherwise a ``TransferParams`` built from the generated parameters.

    Raises:
        AssertionError: If ``memt_buf_cols`` is not a multiple of 8, or if
            ``id == "Q"`` and ``Num4x4`` is neither 1 nor 2.
    """
    assert memt_buf_cols % 8 == 0, "memt_buf_cols must be a multiple of 8"

    # Hand-computed reference values, expressed in units of 4 bytes
    # (byte counts are divided by 4).
    subv_words = (memt_buf_rows * memt_buf_cols * elem_bytes) // 4
    step0 = 1
    step1 = (memt_buf_rows * 8 * elem_bytes) // 4
    step2 = (8 * elem_bytes) // 4
    step3 = subv_words
    wrap0 = (8 * elem_bytes) // 4
    wrap1 = memt_buf_cols // 8
    wrap2 = memt_buf_rows

    Y = id + "Y"
    X = id + "X"

    if id == "Q":
        if Num4x4 == 1:
            memory_fmt_string = f"{X}:{memt_buf_cols} {Y}:{memt_buf_rows}  {X}:8"
        else:
            assert Num4x4 == 2, "Num4x4 must be 1 or 2 for id 'Q'"
            memory_fmt_string = f"{Y}:{2*memt_buf_rows} {X}:{memt_buf_cols} {Y}:{memt_buf_rows}  {X}:8"
        tiling_fmt_string = f"{Y}:0:{Num4x4*memt_buf_rows} {X}:0:{memt_buf_cols}"
    else:  # K or V
        memory_fmt_string = f"{Y}:{memt_buf_rows} {X}:{memt_buf_cols} {Y}:{memt_buf_rows//AieRows}  {X}:8"
        tiling_fmt_string = f"{Y}:0:{memt_buf_rows} {X}:0:{memt_buf_cols}"

    param_s2mm = generate_transfer_params(
        memtile_dmaChannel,
        memory_format=memory_fmt_string,
        tiling_format=tiling_fmt_string,
        bits_per_block=elem_bytes * 8,
        use_iter_step=False,
    )

    # The generator may hand back a tuple, in which case the transfer
    # parameters are taken from the first entry of its second element;
    # otherwise the returned object is used directly.
    if isinstance(param_s2mm, tuple):
        tfparams = param_s2mm[1][0]
    else:
        tfparams = param_s2mm

    # NOTE: the generated step/wrap/length are not cross-checked against the
    # hand-computed reference values in this function (the original checks
    # were disabled).

    if id == "K":
        # Debug dump; prints the format strings that were actually passed to
        # generate_transfer_params so the log matches the generated values.
        print("memory format:", memory_fmt_string)
        print("tiling_format:", tiling_fmt_string)
        print("memtile_dmaChannel:", memtile_dmaChannel)
        print("tfparams._length:", tfparams._length)
        print("memtile_buffer_offset:", memtile_buffer_offset)
        print("tfparams._step:", tfparams._step)
        print("tfparams._wrap:", tfparams._wrap)
        print("tfparams._iter_step:", tfparams._iter_step)

    if refmode == 1:
        return (step0, step1, step2, step3, wrap0, wrap1, wrap2)

    return TransferParams(
        memtile_dmaChannel,
        tfparams._length,
        offset=memtile_buffer_offset,
        step=tfparams._step,
        wrap=tfparams._wrap,
        iter_step=tfparams._iter_step,
    )
    

'''
S2MM: W8 to RM layout conversion, performed while writing to memory
'''
def read_L2_w8_to_rm_subvolumes(
    memt_buf_rows: int,
    memt_buf_cols: int,
    elem_bytes: int,
    memtile_dmaChannel,
    memtile_buffer_offset: int,
    id: str,
    refmode=0
) -> "tuple | TransferParams":
    """Build memtile DMA transfer parameters for a W8 <-> row-major buffer.

    The memory format is ``<id>Y:rows <id>X:cols`` and the tiling format
    walks ``<id>X`` in steps of 8, then ``<id>Y``, then 8 ``<id>X`` entries.
    The parameters returned by ``generate_transfer_params`` are checked
    against hand-computed reference values before being repackaged into a
    ``TransferParams`` located at ``memtile_buffer_offset``.

    NOTE(review): the transfer direction is determined by
    ``memtile_dmaChannel`` and is not visible here; the local naming says
    "s2mm" while the function name says "read" -- confirm with the caller.

    Args:
        memt_buf_rows: Row count used in the format strings.
        memt_buf_cols: Column count used in the format strings; must be a
            multiple of 8.
        elem_bytes: Size of one element in bytes (``bits_per_block`` is
            ``elem_bytes * 8``).
        memtile_dmaChannel: Channel object forwarded unchanged to
            ``generate_transfer_params`` and ``TransferParams``.
        memtile_buffer_offset: Value passed as ``offset=`` to the returned
            ``TransferParams``.
        id: Prefix of the dimension names (``id + "Y"`` / ``id + "X"``).
            ``"O"`` additionally enables a debug dump on stdout.
        refmode: When 1, return the hand-computed reference tuple instead
            of a ``TransferParams``.

    Returns:
        ``(step0, step1, step2, step3, wrap0, wrap1, wrap2)`` when
        ``refmode == 1``, otherwise a ``TransferParams`` built from the
        generated parameters.

    Raises:
        AssertionError: If ``memt_buf_cols`` is not a multiple of 8, or if
            the generated repeat count, step, wrap or length differ from
            the hand-computed reference values.
    """
    assert memt_buf_cols % 8 == 0, "memt_buf_cols must be a multiple of 8"

    # Hand-computed reference values, expressed in units of 4 bytes
    # (byte counts are divided by 4).
    subv_words = (memt_buf_rows * memt_buf_cols * elem_bytes) // 4
    step0 = 1
    step1 = (memt_buf_cols * elem_bytes) // 4
    step2 = (8 * elem_bytes) // 4
    step3 = subv_words
    wrap0 = (8 * elem_bytes) // 4
    wrap1 = memt_buf_rows
    wrap2 = memt_buf_cols // 8

    Y = id + "Y"
    X = id + "X"

    # Built once so the debug dump below can never drift from what is
    # actually passed to the generator.
    memory_fmt_string = f"{Y}:{memt_buf_rows} {X}:{memt_buf_cols}"
    tiling_fmt_string = f"{X}:0:{memt_buf_cols}:8 {Y}:0:{memt_buf_rows} {X}:0:8"

    param_s2mm = generate_transfer_params(
        memtile_dmaChannel,
        memory_format=memory_fmt_string,
        tiling_format=tiling_fmt_string,
        bits_per_block=elem_bytes * 8,
        use_iter_step=False,
    )

    # The generator may hand back a tuple whose first element is a repeat
    # count and whose second element holds the transfer parameters;
    # otherwise the returned object is used directly with a repeat of 1.
    if isinstance(param_s2mm, tuple):
        repeatCnt = param_s2mm[0]
        tfparams = param_s2mm[1][0]
    else:
        repeatCnt = 1
        tfparams = param_s2mm

    # Cross-check the generated parameters against the reference values.
    assert repeatCnt == 1, f"unexpected repeat count: {repeatCnt}"
    assert tfparams._step == [step0, step1, step2], (
        f"step mismatch: {tfparams._step} != {[step0, step1, step2]}"
    )
    assert tfparams._wrap == [wrap0, wrap1], (
        f"wrap mismatch: {tfparams._wrap} != {[wrap0, wrap1]}"
    )
    assert tfparams._length == subv_words, (
        f"length mismatch: {tfparams._length} != {subv_words}"
    )

    if id == "O":
        # Debug dump of the inputs and the generated parameters.
        print("memory format:", memory_fmt_string)
        print("tiling_format:", tiling_fmt_string)
        print("memtile_dmaChannel:", memtile_dmaChannel)
        print("tfparams._length:", tfparams._length)
        print("memtile_buffer_offset:", memtile_buffer_offset)
        print("tfparams._step:", tfparams._step)
        print("tfparams._wrap:", tfparams._wrap)
        print("tfparams._iter_step:", tfparams._iter_step)

    if refmode == 1:
        return (step0, step1, step2, step3, wrap0, wrap1, wrap2)

    return TransferParams(
        memtile_dmaChannel,
        tfparams._length,
        offset=memtile_buffer_offset,
        step=tfparams._step,
        wrap=tfparams._wrap,
        iter_step=tfparams._iter_step,
    )

