# pylint: skip-file
#######################################################################################
from math import ceil, gcd
from dmacompiler import CascDir
from dataflow.dataflow_common import overlay_4x4_core_stream_bdcast, overlay_8x4_core_stream_bdcast

def largest_factor_pair(n):
    """Return the most balanced factor pair ``(d, n // d)`` of *n*.

    Divisors are searched downward from ``isqrt(n)``, so for ``n >= 1`` the
    result satisfies ``d <= n // d``. Returns ``(1, n)`` when *n* has no
    divisor in ``[2, isqrt(n)]`` (e.g. *n* prime). In this file it is used to
    split a 1-D DMA length into two wrap dimensions.

    Args:
        n: non-negative integer.
    """
    # Exact integer square root; imported locally because the module only
    # imports ceil/gcd from math (floor/sqrt were never in scope).
    from math import isqrt

    for d in range(isqrt(n), 1, -1):
        if n % d == 0:
            return (d, n // d)
    return (1, n)
def lcm(a, b):
    """Return the non-negative least common multiple of integers *a* and *b*.

    Returns 0 when either argument is 0 (same convention as ``math.lcm``),
    instead of dividing by ``gcd(0, 0) == 0``.
    """
    if a == 0 or b == 0:
        return 0
    # Divide before multiplying to keep the intermediate small; the division
    # is exact because gcd(a, b) divides a.
    return abs(a // gcd(a, b) * b)

def split_value(value, max_val):
    """Split *value* into chunks of at most *max_val*, full chunks first.

    The chunks sum to *value*. Every chunk except the last equals *max_val*;
    the last holds what remains (``max_val`` itself when *value* is an exact
    multiple). A *value* that is ``<= max_val`` (including 0) comes back as a
    single chunk. Callers here use it to cap loop / repeat counts at 512 or
    1024.

    Raises:
        ValueError: if *max_val* is not positive (the split would never end).
    """
    if max_val <= 0:
        raise ValueError(f"max_val must be positive, got {max_val}")
    chunks = []
    while value > max_val:
        chunks.append(max_val)
        value -= max_val
    chunks.append(value)
    return chunks

def get_bias_rep(T):
    """Pick a bias repeat factor for ``T`` iterations (repeat-count workaround).

    Returns 1 when ``ceil(T / 63) == 1``. Otherwise returns the smallest even
    divisor of ``T`` in ``[2, 62]`` that is at least ``ceil(T / 63)``, or
    ``None`` when no such divisor exists.
    """
    min_factor = ceil(T / 63)
    if min_factor == 1:
        return 1
    return next(
        (cand for cand in range(2, 63)
         if T % cand == 0 and cand >= min_factor and cand % 2 == 0),
        None,
    )

class group_norm:
    """Tiling, buffer-size and padding-mask parameters for the group-norm dataflow.

    The constructor only derives attributes (printing a few for debugging);
    ``get_core_instrs``, ``get_shim_transfers`` and ``get_memtile_transfers``
    consume them. Attributes ending in ``Size`` are byte counts.
    """

    def __init__(
            self,
            AieInst,
            AieRows,
            AieCols,
            KernelName,
            Mlrn,
            Nlrn,
            TdimLayer,
            InOutBytesPerElem,
            CoreMsubv,
            CoreNsubv,
            CoreTsubv,
            CoreMsubvNorm,
            CoreNsubvNorm,
            CoreTsubvNorm,
            MemMsubv,
            MemNsubv
            ):
        """Derive all dataflow parameters.

        Args:
            AieInst: number of instances sharing the 32 groups
                (``NGPerInst = 32 // AieInst``).
            AieRows, AieCols: core-array rows and columns per instance.
            KernelName: mapping; its first key is used as the kernel name.
            Mlrn, Nlrn: M and N of the tensor as processed. ``Mlrn`` differs
                from the true row count when the host padded the input.
            TdimLayer: sequence whose first element is the true M (``Mlayer``).
            InOutBytesPerElem: ``(input_bytes, output_bytes)``; 1 = int8, 2 = int16.
            CoreMsubv, CoreNsubv, CoreTsubv: per-core sub-volume and iteration
                count for the mean/variance pass.
            CoreMsubvNorm, CoreNsubvNorm, CoreTsubvNorm: the same for the
                normalize/affine pass.
            MemMsubv, MemNsubv: memtile sub-volume.
        """
        self.AieInst        = AieInst
        self.AieCols        = AieCols
        self.AieRows        = AieRows
        self.KernelName     = KernelName
        self.ParamElemBytes = 2
        self.InBytes        = InOutBytesPerElem[0]  # 1 - int8; 2 - int16
        self.OutBytes       = InOutBytesPerElem[1]
        self.Mlrn = Mlrn
        self.Nlrn = Nlrn
        self.Mlayer = TdimLayer[0]

        self.LRN_EN = 0           # 1 Layer Norm; 0 Group Norm
        self.combined_xclbin = 0
        self.NG = 32
        assert (self.NG == 32)  # kernel assumes 32 groups for PM optimization
        self.NGPerInst = self.NG // self.AieInst
        self.ColPerGrp = self.Nlrn // self.NG
        self.CoreVecLen = 16
        self.RowPerGrp = CoreMsubv
        self.Msubv = self.RowPerGrp
        self.Nsubv = CoreNsubv
        # How many norm-pass iterations make up one mean/var-pass iteration.
        self.IterFactNorm = CoreTsubvNorm // CoreTsubv

        print('IterFactNorm:', self.IterFactNorm)

        assert (self.Mlrn * self.Nlrn) % (self.Msubv*self.Nsubv*self.AieInst*self.AieRows * self.AieCols) == 0
        assert self.Mlrn % (self.AieRows * self.AieCols) == 0
        self.biasRepetition = self.RowPerGrp
        self.ElmPerGrpToCore = self.ColPerGrp*self.RowPerGrp
        self.CorePrmSize  = 1024
        self.CoreBiasSize = 2 * self.Nlrn * self.ParamElemBytes * self.biasRepetition // self.IterFactNorm
        self.CoreActSize  = (self.Msubv * self.Nsubv) * self.InBytes
        self.CoreOutSize  = (self.Msubv * self.Nsubv) * self.OutBytes
        self.CoreActSizeNormAff  = (CoreMsubvNorm * CoreNsubvNorm) * self.InBytes
        self.CoreOutSizeNormAff  = (CoreMsubvNorm * CoreNsubvNorm) * self.OutBytes
        self.CoreQdqPrmSize = 64
        self.MemtileActSize = MemMsubv * MemNsubv * self.InBytes
        self.MemtileOutSize = MemMsubv * MemNsubv * self.OutBytes

        # Bias buffer holds two Nlrn-length parameter vectors.
        self.ShimBiasSize = 2 * self.Nlrn * self.ParamElemBytes
        self.ShimActSize  = ((self.Mlrn * self.Nlrn) * self.InBytes) // self.AieCols // self.AieInst
        self.ShimOutSize  = ((self.Mlrn * self.Nlrn) * self.OutBytes) // self.AieCols // self.AieInst

        self.Tsubv = CoreTsubv
        self.TsubvNorm = CoreTsubvNorm

        # Memtile repeat counts are capped at 512 per entry; the list is
        # duplicated so the input is streamed once for each pass.
        memtile_act_rep = split_value(self.Tsubv, 512)
        self.MemtileActRepeat = memtile_act_rep + memtile_act_rep
        print('MemtileActRepeat:', self.MemtileActRepeat)

        self.is_host_padded = (self.Mlrn != self.Mlayer)
        # Mask related params
        if self.is_host_padded:
            # Valid (unpadded) rows that fall in the last Tsubv iteration.
            self.Mlayer_residual = self.Mlayer - (self.Msubv * self.AieRows * self.AieCols * (self.Tsubv-1))
        else:
            self.Mlayer_residual = 0

        assert (self.Mlayer_residual >= 0)
        # Valid rows in the first core that is not completely valid
        # (0 when the valid/padded boundary is core-aligned).
        self.Msubv_residual = self.Mlayer_residual % self.Msubv

        # Index (within the last iteration) of the first core that may hold padding.
        self.Mask_core_idx = self.Mlayer_residual // self.Msubv
        print('Mlayer:', self.Mlayer, 'Mlayer_residual:', self.Mlayer_residual, 'Msubv_residual:', self.Msubv_residual, 'Mask_core_idx:', self.Mask_core_idx)
        # Cores are ordered column-major (index = col * AieRows + row): the
        # shim gives each column Msubv * AieRows consecutive rows, the memtile
        # hands them out by core row, and get_core_instrs compares (col, row)
        # lexicographically. The split is therefore by AieRows; dividing by
        # AieCols only happened to work on square arrays.
        self.Mask_start_col, self.Mask_start_row = divmod(self.Mask_core_idx, self.AieRows)
        print('Mask_start_col:', self.Mask_start_col, 'Mask_start_row:', self.Mask_start_row)
def lrn_params( Nsubv: int,
                Mparam_msb: int,
                Mparam_lsb: int,
                ColPerGrp : int,
                CoreScratchAddr: int,
                IS_LAST_ITER: int,
                OP_SEL: int,
                num_groups: int,
                ElmPerGrpToCore: int,
                BiasOffset: int,
                mask_enb: int,
                col_id: int,
                inst: int,
                numInst: int,
                msubv_residual: int,
                true_elem_row: int
                ) -> bytes:
    """Pack the group-norm kernel runtime parameters into a byte string.

    Fields are written in argument order as unsigned little-endian integers.
    ``IS_LAST_ITER`` and ``OP_SEL`` take 1 byte each and every other field
    takes 2 bytes, for 30 bytes in total. A value that does not fit its width
    raises ``OverflowError`` (from ``int.to_bytes``).
    """
    # (value, byte width) in the exact order the kernel reads them.
    layout = (
        (Nsubv, 2),
        (Mparam_msb, 2),
        (Mparam_lsb, 2),
        (ColPerGrp, 2),
        (CoreScratchAddr, 2),
        (IS_LAST_ITER, 1),
        (OP_SEL, 1),
        (num_groups, 2),
        (ElmPerGrpToCore, 2),
        (BiasOffset, 2),
        (mask_enb, 2),
        (col_id, 2),
        (inst, 2),
        (numInst, 2),
        (msubv_residual, 2),
        (true_elem_row, 2),
    )
    return b"".join(
        field.to_bytes(length=width, byteorder='little', signed=False)
        for field, width in layout
    )

def get_core_instrs(
    params: group_norm,
    col: int,
    row: int):
    """Build the instruction list for one compute core.

    Emitted sequence:
      1. Configure the QDQ-parameter buffer (S2MM 1) and a zero-size MM2S 0
         ping/pong config, then acquire the QDQ parameters once.
      2. Mean/variance pass (OP_SEL=0): ``params.Tsubv`` kernel calls on
         ``Msubv * Nsubv`` elements, emitted as loops of at most 1024 trips;
         the final call is emitted on its own with IS_LAST_ITER=1.
      3. Release the QDQ parameters.
      4. Normalize/affine pass (OP_SEL=1): ``params.TsubvNorm`` calls that
         also consume a bias buffer on S2MM 1, chunked the same way, with a
         separate final call carrying IS_LAST_ITER=1.

    Args:
        params: derived dataflow parameters.
        col: absolute core column; the instance is ``col // params.AieCols``.
        row: absolute tile row; 2 is subtracted to get the core row index.

    Returns:
        list of DMA/kernel instructions for this core.
    """

    row = row - 2 # adjust for shim rows
    # Common IO buffer
    CoreInPingAddr   = 0
    CoreInPongAddr   = 2*16*1024
    CoreOutPingAddr  = 16*1024 #CoreInPingAddr + CoreActSize
    CoreOutPongAddr  = 3*16*1024 #CoreInPingAddr + CoreActSize
    # Only used for Norm and Affine 
    k_name              = list(params.KernelName.keys())[0]
    # 8-bit tensors get twice the buffer room (presumably the kernel widens
    # int8 data to 16 bit in place -- TODO confirm).
    in_size_factor      = 2 if params.InBytes == 1 else 1
    out_size_factor     = 2 if params.OutBytes == 1 else 1
    
    # Bias follows the norm-pass input pong buffer, kernel scratch follows the
    # norm-pass output ping buffer, and QDQ params follow the bias.
    CoreBiasPingAddr = CoreInPongAddr + in_size_factor*params.CoreActSizeNormAff
    CoreScratchAddr  = CoreOutPingAddr + out_size_factor*params.CoreOutSizeNormAff
    CoreQdqPrmAddr   = CoreBiasPingAddr + params.CoreBiasSize # always at CoreInPongAddr + 7680*2
    
    Nsubv           = params.Nsubv
    Msubv           = params.Msubv
    Mlrn            = params.Mlrn
    Mlayer          = params.Mlayer
    #assert(Mlrn%32 == 0)
    ColPerGrp       = params.ColPerGrp      
    NGPerInst       = params.NGPerInst             
    ElmPerGrpToCore = params.ElmPerGrpToCore
    IterFactNorm    = params.IterFactNorm
    CoreTsubv = params.Tsubv  
    CoreTsubvNormAff = params.TsubvNorm
    AieInst = params.AieInst
    inst = col // params.AieCols
    col_idx = col % params.AieCols # wraps for every instance

    Mask_start_col = params.Mask_start_col
    Mask_start_row = params.Mask_start_row

    # Only the mask-start core receives a non-zero row residual.
    is_mask_start = (row == Mask_start_row and col_idx == Mask_start_col)
    
    msubv_residual = params.Msubv_residual if is_mask_start else 0
    # Valid elements of one group in this core: a count of 16-element vectors
    # (rounded up) plus the number of valid elements in the last vector
    # (0 when the count is a multiple of 16).
    msubv_core_residual = ceil((msubv_residual * ColPerGrp) / 16)
    true_elem_in_last_residual_row = (msubv_residual * ColPerGrp) % 16
    # Only used in the debug print below.
    is_partial_mask = (is_mask_start) and (msubv_residual < params.Msubv)
    # Determine mask enable
    # True for cores at or after (Mask_start_col, Mask_start_row) in
    # column-major order.
    is_beyond_mask_start = (
        col_idx > Mask_start_col or
        (col_idx == Mask_start_col and row >= Mask_start_row)
    )

    mask_enb = int((is_beyond_mask_start and params.Msubv_residual != 0) or (is_beyond_mask_start and params.is_host_padded))

    print('col_idx:', col_idx, 
          'row:', row, 
          'is_mask_start:', is_mask_start, 
          'is_partial_mask:', is_partial_mask, 
          'mask_enb:', mask_enb, 
          'msubv_residual:', msubv_residual,
          'msubv_core_residual:', msubv_core_residual,
          'true_elem_in_last_residual_row:', true_elem_in_last_residual_row)

    
    # Mlayer is passed as two 16-bit halves because lrn_params packs every
    # field (other than IS_LAST_ITER / OP_SEL) into 2 bytes.
    Mlayer_lsb = Mlayer & 0xFFFF          # Lower 16 bits
    Mlayer_msb = (Mlayer >> 16) & 0xFFFF  # Upper 16 bits

    print('Mlayer_msb:', Mlayer_msb, 'Mlayer_lsb:', Mlayer_lsb)

    # Mean/variance pass: Tsubv non-final kernel calls (OP_SEL=0). MM2S 0 is
    # acquired/released every trip but was configured with size 0 (the
    # "null transfer" below), so no output is produced in this pass.
    def get_mean_var_core_inst(Tsubv):
        return [

        # loop to compute mean and variance
                ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreInPingAddr, CoreInPongAddr, params.CoreActSize),
                Loop(Tsubv, [
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                    CallKernel(k_name, kernel_params=lrn_params(Msubv*Nsubv, 
                                                                Mlayer_msb,
                                                                Mlayer_lsb, 
                                                                ColPerGrp, 
                                                                CoreScratchAddr, 
                                                                IS_LAST_ITER=0, 
                                                                OP_SEL=0, 
                                                                num_groups=NGPerInst, 
                                                                ElmPerGrpToCore=ElmPerGrpToCore, 
                                                                BiasOffset=0, 
                                                                mask_enb = mask_enb,  
                                                                col_id=col_idx, 
                                                                inst = inst, 
                                                                numInst = AieInst, 
                                                                msubv_residual = msubv_core_residual,
                                                                true_elem_row = true_elem_in_last_residual_row)),
                    RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                    RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                ]),
                
                ]
    # Single final mean/variance call (IS_LAST_ITER=1), outside any Loop; it
    # reuses the S2MM 0 configuration from the preceding ConfigBuffer.
    def get_mean_var_core_inst_last():
        return [
                AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                CallKernel(k_name, kernel_params=lrn_params(Msubv*Nsubv, 
                                                            Mlayer_msb,
                                                            Mlayer_lsb,  
                                                            ColPerGrp, 
                                                            CoreScratchAddr, 
                                                            IS_LAST_ITER=1, 
                                                            OP_SEL=0, 
                                                            num_groups=NGPerInst, 
                                                            ElmPerGrpToCore=ElmPerGrpToCore, 
                                                            BiasOffset=0, 
                                                            mask_enb = mask_enb,  
                                                            col_id=col_idx, 
                                                            inst=inst, 
                                                            numInst = AieInst,
                                                            msubv_residual=msubv_core_residual,
                                                            true_elem_row = true_elem_in_last_residual_row)),
                RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
                RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                ]
    
    # Normalize/affine pass (OP_SEL=1): reconfigures bias (single buffer on
    # S2MM 1), output and input ping/pong, then runs TsubvNormAff kernel calls
    # on Msubv*Nsubv // IterFactNorm elements each.
    def get_norm_core_inst(TsubvNormAff, is_last_itr):
        return [
        # loop for normalization and affine transformation
            # Bias
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), CoreBiasPingAddr, None, params.CoreBiasSize),
            ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOutPingAddr, CoreOutPongAddr, params.CoreOutSizeNormAff),
            ConfigBuffer(DmaChannel(DmaDir.S2MM, 0), CoreInPingAddr, CoreInPongAddr, params.CoreActSizeNormAff),
            Loop(TsubvNormAff, [
                AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
                AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
                AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                CallKernel(k_name, kernel_params=lrn_params(Msubv*Nsubv // IterFactNorm, 
                                                            Mlayer_msb,
                                                            Mlayer_lsb,  
                                                            ColPerGrp, 
                                                            CoreScratchAddr, 
                                                            IS_LAST_ITER=is_last_itr, 
                                                            OP_SEL=1, 
                                                            num_groups=NGPerInst, 
                                                            ElmPerGrpToCore=ElmPerGrpToCore, 
                                                            BiasOffset=Msubv*Nsubv // IterFactNorm, 
                                                            mask_enb = mask_enb, 
                                                            col_id=col_idx, 
                                                            inst=inst, 
                                                            numInst = AieInst,
                                                            msubv_residual=msubv_core_residual,
                                                            true_elem_row = true_elem_in_last_residual_row)),
                RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
                RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
                RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
            ]),
        ]

    # Loop trip counts are split into chunks of at most 1024.
    core_tsubv_enq = split_value(CoreTsubv, 1024)
    core_tsubv_norm_enq = split_value(CoreTsubvNormAff, 1024)
    
    core_inst = []
    # QDQ aquire happens only once
    core_inst += [
                    ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), CoreQdqPrmAddr, None, params.CoreQdqPrmSize),
                    ConfigBuffer(DmaChannel(DmaDir.MM2S, 0), CoreOutPingAddr, CoreOutPongAddr, 0), # null transfer to get the ping-pong addresses used in a8 kernel
                    AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                 ]
    for tsubv in core_tsubv_enq[:-1] :
        core_inst += get_mean_var_core_inst(tsubv)
    # The last chunk runs one trip short; that trip is the IS_LAST_ITER=1 call.
    # NOTE(review): when the last chunk is 1 this emits Loop(0, ...); confirm
    # the DMA compiler accepts a zero-trip loop.
    core_inst += get_mean_var_core_inst(core_tsubv_enq[-1]-1)
    core_inst += get_mean_var_core_inst_last()
    # QDQ release 
    core_inst += [RelBuffer(DmaChannel(DmaDir.S2MM, 1))]

    #for tsubv_norm in core_tsubv_norm_enq[:-1] : 
    #    core_inst += get_norm_core_inst(tsubv_norm, is_last_itr=0)
    #core_inst += get_norm_core_inst(core_tsubv_norm_enq[0]-1, is_last_itr = 0)
    #core_inst += get_norm_core_inst(1, is_last_itr = 1)
    
    for idx, tsubv_norm in enumerate(core_tsubv_norm_enq):
        if idx == len(core_tsubv_norm_enq) - 1:
            # For the last value, call with tsubv_norm - 1 times using is_last_itr = 0
            # (same zero-trip caveat as above when tsubv_norm == 1)
            core_inst += get_norm_core_inst(tsubv_norm - 1, is_last_itr=0)
            # Final iteration: call with 1 and is_last_itr = 1
            core_inst += get_norm_core_inst(1, is_last_itr=1)
        else:
            core_inst += get_norm_core_inst(tsubv_norm, is_last_itr=0)

    print("core_tsubv_enq:", core_tsubv_enq, "core_tsubv_norm_enq:", core_tsubv_norm_enq)
    return core_inst 


def get_shim_transfers(params):
    """Build the shim-tile DataTransfer list for the group-norm dataflow.

    For the ``AieCols * AieInst`` shim columns this emits:
      * core parameters on MM2S 0 (shim buffer id 3): ``CorePrmWords * AieRows``
        words per column;
      * QDQ parameters and bias on MM2S 0 (shim buffer id 2), only on every
        ``AieInst``-th column;
      * the activation read on MM2S 1 (shim buffer id 1), with repeat count 2
        (the input is streamed once per pass);
      * the output write on S2MM 0 (shim buffer id 0).

    Args:
        params: a ``group_norm`` instance.

    Returns:
        list of ``DataTransfer`` objects.
    """
    AieInst         = params.AieInst
    AieCols         = params.AieCols
    AieRows         = params.AieRows
    Nlrn            = params.Nlrn
    Msubv           = params.Msubv
    CorePrmSize     = params.CorePrmSize
    CoreQdqPrmSize  = params.CoreQdqPrmSize
    ShimBiasSize    = params.ShimBiasSize
    ShimActSize     = params.ShimActSize
    ShimOutSize     = params.ShimOutSize
    biasRepetition  = params.biasRepetition
    ColPerGrp       = params.ColPerGrp
    ParamElemBytes  = params.ParamElemBytes
    InBytes         = params.InBytes
    OutBytes        = params.OutBytes

    CorePrmWords    = CorePrmSize // 4
    ShimBiasWords   = ShimBiasSize // 4
    ShimActWords    = ShimActSize // 4
    ShimQdqPrmWords = CoreQdqPrmSize // 4
    ShimOutWords    = ShimOutSize // 4

    # Bias access pattern: ColPerGrp params of one group, strided by a full
    # Nlrn row, repeated 2 * biasRepetition times, then the next group.
    shimBiasStep0 = 1
    shimBiasWrap0 = ColPerGrp * ParamElemBytes // 4
    shimBiasStep1 = Nlrn * ParamElemBytes // 4
    shimBiasWrap1 = 2 * biasRepetition
    shimBiasStep2 = ColPerGrp * ParamElemBytes // 4

    # DMA repeat counts
    len_act         = len(params.MemtileActRepeat)
    core_param_rep  = [1] + [0]*(len_act-1)
    qdq_param_rep   = [1] + [0]*(len_act-1)
    bias_rep        = [1] + [0]*(len_act-1)
    act_rep         = [2] + [0]*(len_act-1)
    out_rep         = [1] + [0]*(len_act-1)

    print('shim_core_param_rep:',core_param_rep)
    print('shim_qdq_param_rep:', qdq_param_rep)
    print('shim_bias_rep', bias_rep)
    print('shim_act_rep', act_rep)
    print('shim_out_rep:', out_rep)

    def _act_offset_words(col, elem_bytes):
        # Shim column `col` belongs to instance col // AieCols (same mapping as
        # get_core_instrs). Inside an instance each column owns Msubv * AieRows
        # consecutive rows; instance i owns the N-slice starting at
        # i * (Nlrn // AieInst). The old `col - AieCols` form was only right
        # for instances 0 and 1.
        inst, col_idx = divmod(col, AieCols)
        offset_bytes = (col_idx * Msubv * Nlrn * AieRows * elem_bytes
                        + inst * (Nlrn // AieInst) * elem_bytes)
        assert offset_bytes % 4 == 0
        return offset_bytes // 4

    def _access_shim_act(dma, col, buffer_words, elem_bytes):
        # Shared activation/output access pattern; only the byte width and
        # buffer length differ between the ifm read and the ofm write.
        offset_words = _act_offset_words(col, elem_bytes)
        if AieInst == 1:
            return TransferParams(dma, buffer_words, offset=offset_words)
        # Multi-instance: read this instance's N-slice of each of the
        # Msubv * AieRows rows, then jump to the next Tsubv chunk.
        s0 = 1
        w0 = (Nlrn // AieInst) * elem_bytes // 4
        s1 = Nlrn * elem_bytes // 4
        w1 = Msubv * AieRows
        s2 = Msubv * Nlrn * AieRows * AieCols * elem_bytes // 4
        return TransferParams(dma, buffer_words, offset=offset_words, step=[s0, s1, s2], wrap=[w0, w1])

    def access_shim_ifm_mm2s(
            dma: AieDma,
            col
    ) -> TransferParams:
        return _access_shim_act(dma, col, ShimActWords, InBytes)

    def access_shim_ofm_s2mm(
            dma: AieDma,
            col
    ) -> TransferParams:
        return _access_shim_act(dma, col, ShimOutWords, OutBytes)

    shim_transfers = [
        DataTransfer(
            core_param_rep,
            AieTile(TileType.Shim, col, 0), [3], CorePrmSize,
            [],
            [TransferParams(AieDma(AieTile(TileType.Shim, col, 0), DmaChannel(DmaDir.MM2S, 0)), CorePrmWords * AieRows, offset = col * (CorePrmWords * AieRows))]
        ) for col in range(AieCols*AieInst)
    ] + [
        DataTransfer(
            qdq_param_rep,
            AieTile(TileType.Shim, col, 0), [2], CoreQdqPrmSize,
            [],
            # QDQ params are stored right after the bias in shim buffer 2.
            [TransferParams(AieDma(AieTile(TileType.Shim, col, 0), DmaChannel(DmaDir.MM2S, 0)), ShimQdqPrmWords,
                offset = ShimBiasWords*biasRepetition)]
        ) for col in range(0, AieCols*AieInst, AieInst)
    ] + [
        DataTransfer(
            bias_rep,
            AieTile(TileType.Shim, col, 0), [2], ShimBiasSize*biasRepetition,
            [],
            [TransferParams(AieDma(AieTile(TileType.Shim, col, 0), DmaChannel(DmaDir.MM2S, 0)), ShimBiasWords*biasRepetition,
                step = [shimBiasStep0,shimBiasStep1,shimBiasStep2],
                wrap = [shimBiasWrap0,shimBiasWrap1])]
        ) for col in range(0, AieCols*AieInst, AieInst)
    ] + [
    # Mean and variance ifm transfer
        DataTransfer(
            act_rep,
            AieTile(TileType.Shim, col, 0), [1], ShimActSize,
            [],
            [access_shim_ifm_mm2s(shim_dma(col, DmaDir.MM2S, 1), col)],
        ) for col in range(AieCols*AieInst)
    ] + [
        DataTransfer(
            out_rep,
            AieTile(TileType.Shim, col, 0), [0], ShimOutSize,
            [access_shim_ofm_s2mm(shim_dma(col, DmaDir.S2MM, 0), col)],
            []
        ) for col in range(AieCols*AieInst)
    ]
    return shim_transfers
    



def get_memtile_transfers(params):
    """Build the memtile DataTransfer list for the group-norm dataflow.

    Memtile buffer layout (byte addresses, computed below):
      core params (CorePrmSize * AieRows) | bias (ShimBiasSize * biasRepetition) |
      input ping | input pong (MemtileActSize each) |
      output ping | output pong (MemtileOutSize each) | QDQ params.

    Transfers per memtile column:
      * core params: in on S2MM 0, out on MM2S 0-3 (one CorePrmSize slice each;
        the four channels are hard-coded regardless of AieRows);
      * QDQ params and bias: in on S2MM 0, out on MM2S 4, only on every
        AieInst-th column;
      * activations: in on S2MM 1, out on MM2S 0..AieRows-1 (one CoreActSize
        slice per core row), repeated per ``MemtileActRepeat`` (first half for
        the mean/variance pass, second half for the norm pass);
      * outputs: in on S2MM 2..AieRows+1, out on MM2S 5, only during the
        norm pass.

    Args:
        params: a ``group_norm`` instance.

    Returns:
        list of ``DataTransfer`` objects.
    """
     
    AieInst         = params.AieInst
    AieCols         = params.AieCols
    AieRows         = params.AieRows
    ParamElemBytes  = params.ParamElemBytes
    InBytes         = params.InBytes
    OutBytes        = params.OutBytes
    NG              = params.NG
    NGPerInst       = params.NGPerInst
    Nlrn            = params.Nlrn
    CorePrmSize     = params.CorePrmSize     
    CoreBiasSize    = params.CoreBiasSize    
    CoreActSize     = params.CoreActSize     
    CoreOutSize     = params.CoreOutSize     
    CoreQdqPrmSize  = params.CoreQdqPrmSize  
    
    MemtileActSize  = params.MemtileActSize  
    MemtileOutSize  = params.MemtileOutSize  
    
    ShimBiasSize    = params.ShimBiasSize    
    ShimBiasWords   = ShimBiasSize // 4 
    
    biasRepetition  = params.biasRepetition
    ColPerGrp       = params.ColPerGrp      
    ElmPerGrpToCore = params.ElmPerGrpToCore
    IterFactNorm    = params.IterFactNorm   
    
    CoreVecLen = params.CoreVecLen

    # Only printed; the bias repeat actually used is bias_rep2 below.
    MemtileBiasRep  = params.Tsubv 
    
    print('MemtileBiasRep:', MemtileBiasRep)
    
    # Memtile buffers are laid out back to back in this order.
    MemtilePrmPingAddr  = 0
    MemtileBiasPingAddr = MemtilePrmPingAddr + (CorePrmSize * AieRows)
    MemtileInPingAddr   = MemtileBiasPingAddr + ShimBiasSize*biasRepetition
    MemtileInPongAddr   = MemtileInPingAddr + MemtileActSize
    MemtileOutPingAddr  = MemtileInPongAddr + MemtileActSize
    MemtileOutPongAddr  = MemtileOutPingAddr + MemtileOutSize
    MemtileQdqPrmAddr   = MemtileOutPongAddr + MemtileOutSize
    
    CorePrmWords    = CorePrmSize // 4
    CoreBiasWords   = CoreBiasSize // 4
    CoreActWords    = CoreActSize // 4
    CoreOutWords    = CoreOutSize // 4
    CoreQdqPrmWords = CoreQdqPrmSize // 4
    
    MemtileActWords = MemtileActSize // 4
    MemtileOutWords = MemtileOutSize // 4
    MemtileQdqPrmWords = CoreQdqPrmSize // 4
    
    # DMA repeat counts
    # MemtileActRepeat is [mean/var-pass chunks..., norm-pass chunks...];
    # outputs and bias re-sends only happen in the second (norm) half.
    len_act             = len(params.MemtileActRepeat)
    MemtileActRepeat    = params.MemtileActRepeat 
    act_rep_norm        = MemtileActRepeat[len_act//2:] 
    MemtileOutRepeat    = [0]*(len_act//2) + act_rep_norm  
    core_param_rep      = [1] + [0]*(len_act-1)
    qdq_param_rep       = [1] + [0]*(len_act-1)
    bias_rep1           = [1]+[0]*(len_act-1)
    bias_rep2           = [0]+[0]*((len_act//2)-1) + act_rep_norm[0:]
    print('core_param_rep:',core_param_rep) 
    print('qdq_param_rep:', qdq_param_rep) 
    print('bias_rep1', bias_rep1) 
    print('bias_rep2', bias_rep2) 
    print('MemtileActRepeat:', MemtileActRepeat) 
    print('MemtileOutRepeat:', MemtileOutRepeat) 
     
    # Bias write into the memtile: CoreVecLen-element chunks scattered with a
    # NGPerInst * CoreVecLen stride (per-group vector layout; the exact layout
    # the kernel expects is not visible here -- confirm).
    def access_memtile_wgt_s2mm(
            dma: AieDma,
            col
    ) -> TransferParams:
        buffer_words = ShimBiasWords*biasRepetition
        s0 = 1
        w0 = CoreVecLen * ParamElemBytes // 4
        s1 = NGPerInst * CoreVecLen * ParamElemBytes // 4 #
        w1 = ElmPerGrpToCore * 2 // CoreVecLen 
        s2 = CoreVecLen * ParamElemBytes // 4
        w2 = NGPerInst
        s3 = w0*w1*w2
        offset_words = 0
        return TransferParams(dma, buffer_words, offset=offset_words, step=[s0, s1, s2, s3], wrap=[w0, w1, w2])

    # Bias read out to the cores. The innermost two dims are a balanced
    # factorisation (largest_factor_pair) of one norm-iteration's worth of
    # bias. The outermost dimension is not listed in `wrap` (presumably
    # inferred from buffer_words -- confirm).
    def access_memtile_wgt_mm2s(
            dma: AieDma,
            col
    ) -> TransferParams:
        buffer_words = ShimBiasWords*biasRepetition
        # Linear read in case of IterFactNorm == 1
        # Should be verified for larger shapes for 4x8 overlay
        if AieInst==1 and IterFactNorm == 1 :
            return TransferParams(dma, buffer_words, offset=0)
        elif AieInst==1 and IterFactNorm > 1 :
            W0, W1 = largest_factor_pair(NG * CoreVecLen * ParamElemBytes * (ElmPerGrpToCore // CoreVecLen // IterFactNorm)  // 4)
            s0 = 1
            w0 = W0 
            s1 = W0
            w1 = W1 
            s2 = NG * CoreVecLen * ParamElemBytes * (ElmPerGrpToCore // CoreVecLen) // 4 #
            w2 = 2
            s3 = NG * CoreVecLen * ParamElemBytes * (ElmPerGrpToCore // CoreVecLen // IterFactNorm)  // 4
            offset_words = 0
            return TransferParams(dma, buffer_words, offset=offset_words, step=[s0, s1, s2, s3], wrap=[w0, w1, w2])
        else:
            assert(NGPerInst*ElmPerGrpToCore*ParamElemBytes % IterFactNorm == 0)
            assert(ElmPerGrpToCore//CoreVecLen % IterFactNorm == 0)
            W0, W1 = largest_factor_pair(NGPerInst*ElmPerGrpToCore*ParamElemBytes // IterFactNorm // 4)
            s0 = 1
            w0 = W0
            s1 = W0
            w1 = W1 
            s2 = NGPerInst*ElmPerGrpToCore*ParamElemBytes // 4 
            # NOTE(review): with the group_norm definitions (NG = 32,
            # ElmPerGrpToCore = ColPerGrp * RowPerGrp, biasRepetition = RowPerGrp),
            # w0*w1*w2*w3 == buffer_words only when w2 == 2 * AieInst, i.e. this
            # hard-coded 4 fits AieInst == 2 (the branch above uses 2 for
            # AieInst == 1). The assert below catches other AieInst values.
            w2 = 4
            s3 = w0*w1
            w3 = IterFactNorm
            offset_words = 0
            assert(w0*w1*w2*w3 == buffer_words)
            return TransferParams(dma, buffer_words, offset=offset_words, step=[s0, s1, s2, s3], wrap=[w0, w1, w2])
 
    # Activation write into the memtile: ColPerGrp-element runs regrouped
    # so that each group's ElmPerGrpToCore elements become contiguous.
    def access_memtile_ifm_s2mm(
            dma: AieDma,
            col
    ) -> TransferParams:
        buffer_words = MemtileActWords
        s0 = 1
        w0 = ColPerGrp * InBytes // 4
        s1 = ElmPerGrpToCore * InBytes // 4
        w1 = NGPerInst 
        s2 = ColPerGrp * InBytes // 4
        w2 = ElmPerGrpToCore // ColPerGrp
        s3 = NGPerInst * ElmPerGrpToCore * InBytes // 4
        offset_words = 0
        return TransferParams(dma, buffer_words, offset=offset_words, step=[s0, s1, s2, s3], wrap=[w0, w1, w2])
    

    # Activation read for core `row`: a CoreActWords slice at row * CoreActWords,
    # interleaving CoreVecLen-element vectors across the NGPerInst groups.
    def access_memtile_ifm_mm2s(
            dma: AieDma,
            col,
            row
    ) -> TransferParams:
        buffer_words =CoreActWords
        s0 = 1
        w0 = CoreVecLen * InBytes // 4
        s1 = ElmPerGrpToCore * InBytes // 4
        w1 = NGPerInst 
        s2 = CoreVecLen * InBytes // 4
        w2 = ElmPerGrpToCore // CoreVecLen
        s3 = NGPerInst * ElmPerGrpToCore * InBytes // 4
        
        offset_words = row * CoreActWords
        return TransferParams(dma, buffer_words, offset=offset_words, step=[s0, s1, s2, s3], wrap=[w0, w1, w2])
    
    # Output write from core `row`: the inverse of access_memtile_ifm_mm2s
    # (same vector interleave, OutBytes per element).
    def access_memtile_ofm_s2mm(
            dma: AieDma,
            col,
            row
    ) -> TransferParams:
        buffer_words = CoreOutWords
        s0 = 1
        w0 = CoreVecLen * OutBytes // 4
        s1 = ElmPerGrpToCore * OutBytes // 4
        w1 = NGPerInst
        s2 = CoreVecLen * OutBytes // 4
        w2 = ElmPerGrpToCore // CoreVecLen 
        s3 = NGPerInst * ElmPerGrpToCore * OutBytes // 4
        offset_words = row*CoreOutWords
        return TransferParams(dma, buffer_words, offset=offset_words, step=[s0, s1, s2, s3], wrap=[w0, w1, w2])
    
    # Output read to the shim: the inverse of access_memtile_ifm_s2mm
    # (restores the ColPerGrp-run ordering, OutBytes per element).
    def access_memtile_ofm_mm2s(
            dma: AieDma,
            col
    ) -> TransferParams:
        
        buffer_words = MemtileOutWords
        
        s0 = 1
        w0 = ColPerGrp * OutBytes // 4
        s1 = ElmPerGrpToCore * OutBytes // 4
        w1 = NGPerInst
        s2 = ColPerGrp * OutBytes // 4
        w2 = ElmPerGrpToCore // ColPerGrp
        s3 = NGPerInst * ElmPerGrpToCore * OutBytes // 4
        
        offset_words = 0
        return TransferParams(dma, buffer_words, offset=offset_words, step=[s0, s1, s2, s3], wrap=[w0, w1, w2])
    
    memtile_transfers = [
        DataTransfer(
            core_param_rep,
            AieTile(TileType.Memtile, col, 0), [MemtilePrmPingAddr], CorePrmSize * AieRows,
            [TransferParams(AieDma(AieTile(TileType.Memtile, col, 0), DmaChannel(DmaDir.S2MM, 0)), CorePrmWords * AieRows)],
            [TransferParams(AieDma(AieTile(TileType.Memtile, col, 0), DmaChannel(DmaDir.MM2S, 0)), CorePrmWords, offset = 0 * CorePrmWords),
             TransferParams(AieDma(AieTile(TileType.Memtile, col, 0), DmaChannel(DmaDir.MM2S, 1)), CorePrmWords, offset = 1 * CorePrmWords),
             TransferParams(AieDma(AieTile(TileType.Memtile, col, 0), DmaChannel(DmaDir.MM2S, 2)), CorePrmWords, offset = 2 * CorePrmWords),
             TransferParams(AieDma(AieTile(TileType.Memtile, col, 0), DmaChannel(DmaDir.MM2S, 3)), CorePrmWords, offset = 3 * CorePrmWords)]
        ) for col in range(AieCols*AieInst)
    ] + [
        DataTransfer(
            qdq_param_rep,
            AieTile(TileType.Memtile, col, 0), [MemtileQdqPrmAddr], CoreQdqPrmSize,
            [TransferParams(AieDma(AieTile(TileType.Memtile, col, 0), DmaChannel(DmaDir.S2MM, 0)), MemtileQdqPrmWords)],
            [TransferParams(AieDma(AieTile(TileType.Memtile, col, 0), DmaChannel(DmaDir.MM2S, 4)), MemtileQdqPrmWords)]
        ) for col in range(0, AieCols*AieInst, AieInst)
    ] + [
        # Bias is received once (bias_rep1) ...
        DataTransfer(
            bias_rep1, 
            AieTile(TileType.Memtile, col, 0), [MemtileBiasPingAddr], ShimBiasSize * biasRepetition ,
            [access_memtile_wgt_s2mm(memtile_dma(col, DmaDir.S2MM, 0), col)],
            [],
            ) for col in range(0, AieCols*AieInst, AieInst)
    ] + [
        # ... and re-sent to the cores for every norm-pass repeat (bias_rep2).
        DataTransfer(
            bias_rep2, #[0, MemtileBiasRep] ,
            AieTile(TileType.Memtile, col, 0), [MemtileBiasPingAddr], ShimBiasSize * biasRepetition ,
            [],
            [access_memtile_wgt_mm2s(memtile_dma(col, DmaDir.MM2S, 4), col)],
            ) for col in range(0, AieCols*AieInst, AieInst)
    ] + [
    # memtile ifm transfers for mean and variance
        DataTransfer(
            MemtileActRepeat,#[itr, itr]
            AieTile(TileType.Memtile, col, 0), [MemtileInPingAddr, MemtileInPongAddr], MemtileActSize,
            [access_memtile_ifm_s2mm(memtile_dma(col, DmaDir.S2MM, 1), col)],
            [access_memtile_ifm_mm2s(memtile_dma(col, DmaDir.MM2S, row), col, row) for row in range(AieRows)],
            sync_strategy=SyncStrategy.Parallel_1_to_N
        ) for col in range(AieCols*AieInst)
    ] + [
    # memtile out 
        DataTransfer(
            MemtileOutRepeat,
            AieTile(TileType.Memtile, col, 0), [MemtileOutPingAddr, MemtileOutPongAddr], MemtileOutSize,
            [access_memtile_ofm_s2mm(memtile_dma(col, DmaDir.S2MM, 2+row), col, row) for row in range(AieRows)],
            [access_memtile_ofm_mm2s(memtile_dma(col, DmaDir.MM2S, 5), col)],
        ) for col in range(AieCols*AieInst)
    ]
    return memtile_transfers 
