# pylint: skip-file
class OverrideChecker(type):
    """Metaclass that reports, at class-creation time, what a class overrides.

    For every direct base of the class being created it prints one line per
    callable attribute of the base that is redefined in the new class body,
    and one line per ``self.<name>`` attribute that is assigned in both the
    base's and the new class's ``__init__``.

    The checks are purely diagnostic: they never raise, and the class is
    created unchanged.
    """

    def __init__(cls, name, bases, dct):
        super().__init__(name, bases, dct)
        cls._check_overrides(bases, dct)

    def _check_overrides(cls, bases, dct):
        """Print one message per overridden method / instance variable.

        Args:
            bases: Direct base classes of the class being created.
            dct: Namespace (class body) of the class being created.
        """
        for base in bases:
            # Check for method overrides
            for attr, value in base.__dict__.items():
                if callable(value) and attr in dct:
                    print(f"Method '{attr}' is overridden in {cls.__name__}")

            # Check for instance variable overrides
            base_init = base.__dict__.get('__init__')
            cls_init = dct.get('__init__')
            if base_init and cls_init:
                shared = (cls._get_instance_vars(base_init)
                          & cls._get_instance_vars(cls_init))
                # Sorted so the report order does not depend on set ordering.
                for var in sorted(shared):
                    print(f"Instance variable '{var}' is overridden in {cls.__name__}")

    @staticmethod
    def _get_instance_vars(init_method):
        """Return the names assigned as ``self.<name> = ...`` in a method.

        This is a line-based scan of the method's source text, not a real
        parse: only statements that start a line with ``self.<name>``
        followed by a plain, annotated or augmented assignment are
        recognised.  Comparisons (``self.x == y``), calls with keyword
        arguments (``self.f(k=1)``) and nested targets (``self.a.b = 1``)
        are ignored.

        Args:
            init_method: Function object to inspect; a falsy value yields
                an empty set.

        Returns:
            Set of attribute names.  Empty when the source cannot be
            retrieved (built-in/slot-wrapper methods, code created with
            ``exec``, interactive sessions), so that a diagnostic helper
            never prevents a class from being created.
        """
        import inspect
        import re

        instance_vars = set()
        if not init_method:
            return instance_vars
        try:
            source = inspect.getsource(init_method)
        except (OSError, TypeError):
            # OSError: source not available; TypeError: built-in object.
            return instance_vars

        # self.<name> [: annotation] [augmented-operator]= but not '=='.
        assign_pattern = (
            r'self\.(\w+)\s*(?::[^=]+)?'
            r'(?:\*\*|//|<<|>>|[-+*/%@&|^])?=(?!=)'
        )
        for line in source.splitlines():
            match = re.match(assign_pattern, line.strip())
            if match:
                instance_vars.add(match.group(1))
        return instance_vars
class layer_norm_base(metaclass=OverrideChecker):

    def __init__(
            self,
            AieInst,
            AieCols,
            AieRows,
            split_type,
            KernelName,
            Kgran,
            Mlrn,
            Nlrn,
            TdimLayer,
            InOutBytePerElem,
            CoreTilings,
            MemTilings_ifm_s2mm,
            MemTilings_ifm_mm2s,
            ShimTilings
            ):
        """Store the layer configuration and derive per-tile buffer sizes.

        The tiling arguments are converted with ``listit`` (defined elsewhere;
        presumably turns nested tuples into nested lists -- confirm) and
        wrapped in NumPy arrays.  The shapes quoted below are the ones
        ``validate_attributes`` asserts; nothing is checked here.

        Args:
            AieInst: Number of instances (first array axis).
            AieCols: Number of columns per instance (second array axis).
            AieRows: Number of rows per column (third array axis).
            split_type: Compared against ``SplitType.RowSplit`` /
                ``SplitType.ColSplit`` by the transfer builders.
            KernelName: Stored as-is.
            Kgran: Two-element sequence; ``[0]`` -> ``Mgran``, ``[1]`` -> ``Ngran``.
            Mlrn, Nlrn: Stored as-is (printed as 'TdimPadded' in debug output).
            TdimLayer: Two-element sequence; ``[0]`` -> ``Mlayer``, ``[1]`` -> ``Nlayer``.
            InOutBytePerElem: Two-element sequence; input and output bytes
                per element.
            CoreTilings: Per-core tilings, expected (AieInst, AieCols, AieRows, 2).
            MemTilings_ifm_s2mm: Memtile tilings, expected (AieInst, AieCols, 2);
                also reused for the OFM MM2S side.
            MemTilings_ifm_mm2s: Memtile tilings, expected
                (AieInst, AieCols, AieRows, 2); also reused for the OFM S2MM side.
            ShimTilings: Shim tilings, expected (AieInst, AieCols, 2).
        """
        self.NORM_DEBUG = 1 # Debug flag

        self.AieInst        = AieInst
        self.AieCols        = AieCols
        self.AieRows        = AieRows
        self.KernelName     = KernelName
        self.Mgran          = Kgran[0]
        self.Ngran          = Kgran[1]
        self.InBytes        = InOutBytePerElem[0] #2 # 1 - int8; 2 - int16
        self.OutBytes       = InOutBytePerElem[1]
        # Fixed sizes in bytes (the transfer builders divide them by 4 to get words).
        self.ParamElemBytes = 2
        self.CorePrmSize    = 1024
        self.CoreQdqPrmSize = 64

        self.split_type     = split_type

        self.Mlrn           = Mlrn
        self.Nlrn           = Nlrn
        self.Mlayer         = listit(TdimLayer)[0]
        self.Nlayer         = listit(TdimLayer)[1]
        self.Msubv          = self.Mgran

        ######## Shim tilings
        self.ShimTilings    = np.array(listit(ShimTilings))

        # Bytes per shim tile: product of the tiling dims (axis 2) times element size.
        self.ShimActSize    = np.prod(self.ShimTilings, axis=2)*self.InBytes
        self.ShimOutSize    = np.prod(self.ShimTilings, axis=2)*self.OutBytes
        ####### Core Tilings
        self.CoreTilings    = np.array(listit(CoreTilings)) # 2x4x4x2
        # Second component of every core tiling, as nested Python lists.
        self.Nsubv          = self.CoreTilings[... , 1].tolist()

        self.CoreActSize    = np.prod(self.CoreTilings, axis = -1)*self.InBytes
        self.CoreOutSize    = np.prod(self.CoreTilings, axis = -1)*self.OutBytes # To accommodate in-place compute in case of AW8
        ###### Memtile Tilings IFM S2MM and OFM MM2S
        # The OFM MM2S tiling is taken to be identical to the IFM S2MM tiling.
        MemTilings_ofm_mm2s         = MemTilings_ifm_s2mm
        self.MemTilingsIfmS2MM      = np.array(listit(MemTilings_ifm_s2mm))
        self.MemTilingsOfmMM2S      = np.array(listit(MemTilings_ofm_mm2s))
        self.MemtileActSizeIfmS2MM  = np.prod(self.MemTilingsIfmS2MM, axis=2)*self.InBytes
        self.MemtileOutSizeOfmMM2S  = np.prod(self.MemTilingsOfmMM2S, axis=2)*self.OutBytes
        ##### Memtile Tilings IFM MM2S and OFM S2MM
        # The OFM S2MM tiling is taken to be identical to the IFM MM2S tiling.
        MemTilings_ofm_s2mm         = MemTilings_ifm_mm2s
        self.MemTilingsIfmMM2S      =  np.array(listit(MemTilings_ifm_mm2s))
        self.MemTilingsOfmS2MM      =  np.array(listit(MemTilings_ofm_s2mm))
        self.MemtileActSizeIfmMM2S  = np.prod(self.MemTilingsIfmMM2S, axis=-1)*self.InBytes
        self.MemtileOutSizeOfmS2MM  = np.prod(self.MemTilingsOfmS2MM, axis=-1)*self.OutBytes

        # ceil(shim dim 0 / memtile dim 0) per (instance, column); entries whose
        # memtile dim 0 is zero keep the 0 from the pre-filled `out` array.
        self.MemTouter          = np.ceil(np.divide(self.ShimTilings[:,:,0], self.MemTilingsIfmS2MM[:,:,0], out=np.zeros_like(self.ShimTilings[:,:,0].astype(float)), where=self.MemTilingsIfmS2MM[:,:,0]!=0)).astype(int)

        # 1 for every core with a non-zero activation size, 0 otherwise.
        mem2core_mask           = np.where(self.CoreActSize>0, 1, 0) # 2x4x4
        self.MemIfmMM2Smask     = mem2core_mask.copy ()
        self.MemOfmS2MMmask     = mem2core_mask.copy ()
        # 2x4x1 #TODO: Change when the MemMsubv is defined as an array
        # Broadcast the per-column MemTouter over the row axis, zeroed for masked-out cores.
        self.CoreTouter         = self.MemIfmMM2Smask*self.MemTouter[:, :, np.newaxis] #self.MemTouter # 2x4 the value in axis 1 is for each column in an instance

        self.ShimBiasRep  = 8 # only valid for layer norm
        self.ShimBiasSize = (1*2 * self.Nlrn) * self.ParamElemBytes
        self.MemtileBiasSize  = self.ShimTilings[...,-1]*self.ShimBiasRep*self.ParamElemBytes*2
        # Dump every derived array when the debug flag is set.
        if self.NORM_DEBUG:
            print('TdimPadded:', [Mlrn, Nlrn])
            print('TdimLayer:', [self.Mlayer, self.Nlayer])
            print('shim tilings: ', self.ShimTilings.shape, self.ShimTilings)
            print('core tilings: ', self.CoreTilings.shape, self.CoreTilings)
            print('core act size: ', self.CoreActSize.shape, self.CoreActSize)
            print('memtile tilings IFM S2MM: ', self.MemTilingsIfmS2MM.shape, self.MemTilingsIfmS2MM)
            print('memtile act size IfmS2MM: ', self.MemtileActSizeIfmS2MM.shape, self.MemtileActSizeIfmS2MM)
            print('memtile tilings IFM MM2S: ', self.MemTilingsIfmMM2S.shape, self.MemTilingsIfmMM2S)
            print('memtile act size IfmMM2S: ', self.MemtileActSizeIfmMM2S.shape, self.MemtileActSizeIfmMM2S)
            print('memtile bias size: ', self.MemtileBiasSize.shape, self.MemtileBiasSize)
            print('mem2core_mask: ', self.MemIfmMM2Smask.shape, self.MemIfmMM2Smask)
            print('Mem Touter:',self.MemTouter.shape, self.MemTouter)
            print('Core Touter:',self.CoreTouter.shape, self.CoreTouter)
    
    def validate_attributes(self): 
        assert(self.ShimTilings.shape       == (self.AieInst, self.AieCols, 2)) 
        assert(self.ShimActSize.shape       == (self.AieInst, self.AieCols)) 
        assert(self.ShimOutSize.shape       == (self.AieInst, self.AieCols)) 
        assert(self.CoreTilings.shape       == (self.AieInst, self.AieCols, self.AieRows, 2)) 
        assert(self.CoreActSize.shape       == (self.AieInst, self.AieCols, self.AieRows)) 
        assert(self.CoreOutSize.shape       == (self.AieInst, self.AieCols, self.AieRows)) 
        assert(self.MemTilingsIfmS2MM.shape == (self.AieInst, self.AieCols, 2)) 
        assert(self.MemTilingsOfmMM2S.shape == (self.AieInst, self.AieCols, 2)) 
        assert(self.MemTilingsIfmMM2S.shape == (self.AieInst, self.AieCols, self.AieRows, 2)) 
        assert(self.MemTilingsOfmS2MM.shape == (self.AieInst, self.AieCols, self.AieRows, 2)) 
        #assert(self.MemtileRepeatCount.shape in [(self.AieInst, self.AieCols, 1), (self.AieInst, self.AieCols, 2)]) 
        assert(self.CoreTouter.shape        == (self.AieInst, self.AieCols, self.AieRows)) 
    
    def getMemTileRepeatCount(self, max_rep=(config.MAX_REPEAT_COUNT * config.MAX_TASK_QUEUE_SIZE)):
        """Compute per-column memtile repeat counts and a residual mask.

        The first dimension of each shim tiling is divided by the first
        dimension of the matching memtile IFM S2MM tiling.  Columns whose
        memtile dimension is not positive get a quotient and remainder of 0.

        Args:
            max_rep: Limit forwarded to ``re_enqueue_mem``.

        Returns:
            ``(rep_cnt_enq, is_residual)``.  ``rep_cnt_enq`` is the result
            of ``re_enqueue_mem`` on the whole-number quotients; when at
            least one column is not evenly divisible, an extra trailing
            entry (1 for columns with a remainder, 0 otherwise) is appended
            on the last axis.  ``is_residual`` is True where a remainder
            exists.
        """
        shim_dim0 = self.ShimTilings[:, :, 0]
        mem_dim0 = self.MemTilingsIfmS2MM[:, :, 0]
        usable = mem_dim0 > 0

        remainder = np.mod(shim_dim0, mem_dim0, out=np.zeros_like(shim_dim0), where=usable)
        is_divisible = remainder == 0

        # The whole-tile repeat count is needed whether or not a residual exists.
        whole_tiles = np.floor_divide(shim_dim0, mem_dim0, out=np.zeros_like(shim_dim0), where=usable)
        rep_cnt_enq = re_enqueue_mem(whole_tiles, max_rep)

        if not np.all(is_divisible):
            # One extra repeat slot for the residual transfer, shape (inst, col, 1).
            residual_slot = np.where(is_divisible, 0, 1)[..., np.newaxis]
            rep_cnt_enq = np.concatenate([rep_cnt_enq, residual_slot], axis=-1)

        return rep_cnt_enq, ~is_divisible
    
    def get_core_instrs(self):
        """Placeholder hook; the base class does nothing and returns ``None``.

        Declared with ``self`` so that it can be called on an instance (or
        through ``super()`` from an overriding subclass).
        """
        return None
    


############------- Memtile Transfers ----------##################

    def get_memtile_param_transfers(self, param_repeat=1, MemtilePrmPingAddr=0):
        """Build one unicast core-parameter transfer per memtile column.

        Each memtile receives ``CorePrmSize * AieRows`` bytes on S2MM
        channel 0 and forwards one ``CorePrmSize`` slice on each of MM2S
        channels 0-3, slice ``k`` starting at word offset
        ``k * CorePrmWords``.

        Args:
            param_repeat: First entry of the repeat-count list; remaining
                entries (one per extra column of ``MemtileRepeatCount``'s
                last axis) are 0.
            MemtilePrmPingAddr: Memtile buffer address for the parameters.

        Returns:
            List of ``DataTransfer``, one per column across all instances.
        """
        rows = self.AieRows
        prm_bytes = self.CorePrmSize
        prm_words = prm_bytes // 4
        total_cols = self.AieCols * self.AieInst
        rep_cnt = [param_repeat] + [0]*(self.MemtileRepeatCount.shape[-1] - 1)

        def memtile_channel(col, direction, channel):
            # Fresh tile/DMA objects for every use, exactly one per call.
            return AieDma(AieTile(TileType.Memtile, col, 0), DmaChannel(direction, channel))

        transfers = []
        for col in range(total_cols):
            # NOTE(review): the fan-out is fixed at 4 MM2S channels although
            # the buffer size scales with AieRows -- confirm AieRows == 4.
            transfers.append(
                DataTransfer(
                    rep_cnt,
                    AieTile(TileType.Memtile, col, 0), [MemtilePrmPingAddr], prm_bytes*rows,
                    [TransferParams(memtile_channel(col, DmaDir.S2MM, 0), prm_words*rows)],
                    [TransferParams(memtile_channel(col, DmaDir.MM2S, ch), prm_words, offset = ch * prm_words)
                     for ch in range(4)]
                )
            )
        return transfers
    
    def get_memtile_qdq_transfers(self, qdq_repeat, MemtileQdqPrmAddr):
        """Build the memtile transfers that forward the QDQ parameter block.

        Each selected memtile receives ``CoreQdqPrmSize`` bytes on S2MM
        channel 0 and sends them out on MM2S channel 4.  Every column is
        used, except that only every second column is used when there are
        exactly 8 columns in total.

        Args:
            qdq_repeat: First entry of the repeat-count list; remaining
                entries are 0.
            MemtileQdqPrmAddr: Memtile buffer address for the QDQ parameters.

        Returns:
            List of ``DataTransfer``, one per selected column.
        """
        total_cols = self.AieCols * self.AieInst
        col_stride = 2 if self.AieInst*self.AieCols==8 else 1
        qdq_bytes = self.CoreQdqPrmSize
        qdq_words = qdq_bytes // 4
        rep_cnt = [qdq_repeat] + [0]*(self.MemtileRepeatCount.shape[-1] - 1)

        transfers = []
        for col in range(0, total_cols, col_stride):
            transfers.append(
                DataTransfer(
                    rep_cnt,
                    AieTile(TileType.Memtile, col, 0), [MemtileQdqPrmAddr], qdq_bytes,
                    [TransferParams(AieDma(AieTile(TileType.Memtile, col, 0), DmaChannel(DmaDir.S2MM, 0)), qdq_words)],
                    [TransferParams(AieDma(AieTile(TileType.Memtile, col, 0), DmaChannel(DmaDir.MM2S, 4)), qdq_words)]
                )
            )
        return transfers
    
    def get_memtile_bias_transfers(self,  MemtileBiasAddr: List[int], bias_repeat = 1) -> List[DataTransfer] :
        """Build the memtile transfers that receive and redistribute the bias.

        One ``DataTransfer`` is produced for every ``BiasColStep``-th column.
        Each has a single S2MM access (channel 0) and either one MM2S access
        on channel 4 (when ``BiasStreamType`` is ``StreamType.BC``) or one
        MM2S access per row on channels ``0..AieRows-1``.

        Reads ``self.BiasColStep``, ``self.CoreBiasSize`` and
        ``self.BiasStreamType``, none of which are assigned in this class's
        ``__init__`` as visible here -- presumably set by a subclass (TODO confirm).

        Args:
            MemtileBiasAddr: List of memtile buffer addresses, passed through
                unchanged to every ``DataTransfer``.
            bias_repeat: First entry of the repeat-count list; remaining
                entries are 0.

        Returns:
            List of ``DataTransfer`` objects.
        """

        AieInst         = self.AieInst
        AieCols         = self.AieCols
        AieRows         = self.AieRows
        Nlrn            = self.Nlrn
        ParamElemBytes       = self.ParamElemBytes
        BiasColStep     = self.BiasColStep
        Nsubv           = self.Nsubv
        CoreBiasSize    = self.CoreBiasSize
        ShimBiasRep     = self.ShimBiasRep
        ShimTilings = self.ShimTilings.tolist()
        # A single buffer size (the largest over all columns) is used for every memtile.
        MemtileBiasSize = np.max(self.MemtileBiasSize).item()

        # Element-wise bytes -> 32-bit words, returned as nested Python lists.
        bias_words_lambda = lambda x: x//4
        CoreBiasWords = np.vectorize(bias_words_lambda)(np.array(CoreBiasSize)).tolist()
        rep_cnt = [bias_repeat] + [0]*(self.MemtileRepeatCount.shape[-1] - 1)

        def access_memtile_bias_s2mm(
                dma: AieDma,
                col: int
        ) -> TransferParams:
            """Return the S2MM access for absolute column ``col`` (4 steps, 3 wraps)."""
            inst_idx = col // AieCols
            col_idx  = col % AieCols
            buffer_bytes = ShimTilings[inst_idx][col_idx][1]*2*ShimBiasRep*ParamElemBytes
            buffer_words = buffer_bytes//4
            # All step/wrap values are in 32-bit words.
            s0 = 1
            w0 = 8*ParamElemBytes // 4
            s1 = (ShimBiasRep* 2 * 8 * ParamElemBytes) // 4
            w1 = (ShimTilings[inst_idx][col_idx][1]) // 8
            s2 = (ShimBiasRep* 8 * ParamElemBytes) // 4
            w2 = 2
            s3 = 8*ParamElemBytes // 4
            offset_words = 0
            return TransferParams(dma, buffer_words, offset=offset_words, step=[s0, s1, s2, s3], wrap=[w0, w1, w2])

        def access_memtile_bias_mm2s(
                dma: AieDma,
                stream: StreamType,
                col: int,
                row: int
        ) -> TransferParams:
            """Return the MM2S access for ``col`` (and ``row`` when not broadcasting)."""

            inst_idx = col// self.AieCols
            col_idx = col % self.AieCols # wraps for every instance
            if stream == StreamType.BC :
                # Broadcast: row-0 word count, full Nlrn width, no offset; `row` is unused.
                buffer_words = CoreBiasWords[inst_idx][col_idx][0]
                s0 = 1
                w0 = 8*8*ParamElemBytes // 4
                s1 = 8*8*2*ParamElemBytes // 4
                w1 = Nlrn // 8
                s2 = 8*8*ParamElemBytes // 4
                offset_words = 0
                TP = TransferParams(dma, buffer_words, offset=offset_words, step=[s0, s1, s2], wrap=[w0, w1])
            else:
                # Per-row: this row's word count and width, offset by the words of all earlier rows.
                buffer_words = CoreBiasWords[inst_idx][col_idx][row]
                s0 = 1
                w0 = 8*8*ParamElemBytes // 4
                s1 = 8*8*2*ParamElemBytes // 4
                w1 = Nsubv[inst_idx][col_idx][row] // 8
                s2 = 8*8*ParamElemBytes // 4
                offset_words = sum(CoreBiasWords[inst_idx][col_idx][:row])
                TP = TransferParams(dma, buffer_words, offset=offset_words, step=[s0, s1, s2], wrap=[w0, w1])
            return TP


        access_s2mm = [access_memtile_bias_s2mm(memtile_dma(col, DmaDir.S2MM, 0), col) for col in range(0, AieCols*AieInst, BiasColStep)]
        if (self.BiasStreamType == StreamType.BC):
            access_mm2s = [[access_memtile_bias_mm2s(memtile_dma(col, DmaDir.MM2S, 4), self.BiasStreamType, col, None)] for col in range(0, AieCols*AieInst, BiasColStep)]
        else:
            # NOTE(review): this list has one entry per column (step 1), but it is
            # indexed below with col//BiasColStep; the two only line up when
            # BiasColStep == 1 -- confirm that always holds for non-BC streams.
            access_mm2s = [[access_memtile_bias_mm2s(memtile_dma(col, DmaDir.MM2S, row), self.BiasStreamType, col, row) for row in range(AieRows)] for col in range(0, AieCols*AieInst)]
        DT = [
               DataTransfer(
                   rep_cnt,
                   AieTile(TileType.Memtile, col, 0), MemtileBiasAddr, MemtileBiasSize,
                   [access_s2mm[col//BiasColStep]],
                   [*access_mm2s[col//BiasColStep]], # unpack list of access patterns
               ) for col in range(0, AieCols*AieInst, BiasColStep)
             ]

        return DT
    
    def get_memtile_data_transfers(self):
        """Build the memtile activation (IFM) and output (OFM) data transfers.

        A column takes part only when its ``MemTouter`` entry is non-zero.
        For each such column:

        * IFM: one linear S2MM access on channel 1, fanned out to one W8
          MM2S access per row enabled in ``MemIfmMM2Smask`` (channel = row),
          synchronised with ``SyncStrategy.Parallel_1_to_N``.
        * OFM: one W8 S2MM access per row enabled in ``MemOfmS2MMmask``
          (channel = 2 + row), drained by one linear MM2S access on channel 5.

        The ping address is always used; the pong address is added when it
        is not ``None``.

        Returns:
            List of ``DataTransfer``: all IFM transfers followed by all OFM
            transfers.
        """
        AieCols = self.AieCols
        AieRows = self.AieRows

        ifm_mask = self.MemIfmMM2Smask.tolist()
        ofm_mask = self.MemOfmS2MMmask.tolist()
        act_repeat = self.MemTouter.tolist()
        repeat_count = self.MemtileRepeatCount.tolist()
        # NOTE(review): the OFM transfers below are sized with
        # MemtileActSizeMax as well, exactly as before; self.MemtileOutSizeMax
        # exists but was never used here -- confirm this is intended.
        buffer_size = self.MemtileActSizeMax

        ifm_addrs = [self.MemtileInPingAddr]
        if self.MemtileInPongAddr is not None:
            ifm_addrs.append(self.MemtileInPongAddr)

        ofm_addrs = [self.MemtileOutPingAddr]
        if self.MemtileOutPongAddr is not None:
            ofm_addrs.append(self.MemtileOutPongAddr)

        # (absolute column, instance index, column within the instance) for
        # every column that actually moves data.
        active_cols = []
        for col in range(AieCols * self.AieInst):
            inst, col_idx = divmod(col, AieCols)
            if act_repeat[inst][col_idx]:
                active_cols.append((col, inst, col_idx))

        # NOTE: The memtile activation access pattern re-arranges a
        #       row major shard to w8 shard on the input side,
        #       and re-arranges a w8 shard to a row-major shard
        #       on the output side.
        ifm_transfers = [
            DataTransfer(
                repeat_count[inst][col_idx],
                AieTile(TileType.Memtile, col, 0), ifm_addrs, buffer_size,
                [self.access_memtile_ifm_s2mm(memtile_dma(col, DmaDir.S2MM, 1), col, AccessFormat.Linear)],
                [self.access_memtile_ifm_mm2s(memtile_dma(col, DmaDir.MM2S, row), col, row, AccessFormat.W8)
                 for row in range(AieRows) if ifm_mask[inst][col_idx][row]],
                sync_strategy=SyncStrategy.Parallel_1_to_N
            )
            for col, inst, col_idx in active_cols
        ]

        ofm_transfers = [
            DataTransfer(
                repeat_count[inst][col_idx],
                AieTile(TileType.Memtile, col, 0), ofm_addrs, buffer_size,
                [self.access_memtile_ofm_s2mm(memtile_dma(col, DmaDir.S2MM, 2+row), col, row, AccessFormat.W8)
                 for row in range(AieRows) if ofm_mask[inst][col_idx][row]],
                [self.access_memtile_ofm_mm2s(memtile_dma(col, DmaDir.MM2S, 5), col, AccessFormat.Linear)],
            )
            for col, inst, col_idx in active_cols
        ]

        return ifm_transfers + ofm_transfers

    def get_memtile_buffer_words(self, rc, is_residual, memsv_words, shimsv_words):
        buffer_words      = [memsv_words*(x!=0) for x in rc[:-1]]
        if(is_residual):
            residual_words    = shimsv_words - sum([x*y for x,y in zip(buffer_words, rc[:-1])]) 
            buffer_words.append(residual_words)
        else:
            buffer_words.append(memsv_words)
        return buffer_words

    def access_memtile_ifm_s2mm(
            self,
            dma: AieDma,
            col: int,
            fmt: AccessFormat = AccessFormat.W8,
    ) -> TransferParams:
        """Return the memtile S2MM access pattern for the input activations.

        Buffer lengths come from ``get_memtile_buffer_words`` using this
        column's repeat counts and residual flag; the offset is always 0.

        * RowSplit + W8: pattern from ``access_w8_subvolume`` over
          ``Msubv * AieRows`` by ``Nlrn``.
        * RowSplit + any other format: two-step linear pattern.
        * Any other split type: pattern from ``access_w8_subvolume`` over
          ``Msubv`` by the column's second S2MM tiling dimension; ``fmt`` is
          ignored.

        Args:
            dma: DMA channel the access is bound to.
            col: Absolute column index across all instances.
            fmt: Requested access format.
        """
        inst, col_idx = divmod(col, self.AieCols)
        elem_bytes = self.InBytes
        s2mm_tiling = self.MemTilingsIfmS2MM.tolist()[inst][col_idx]

        buffer_words = self.get_memtile_buffer_words(
            self.MemtileRepeatCount.tolist()[inst][col_idx][:],
            self.is_residual_mem[inst][col_idx],
            self.MemtileActSizeIfmS2MM.tolist()[inst][col_idx] // 4,
            self.ShimActSize[inst][col_idx] // 4,
        )

        row_split = self.split_type == SplitType.RowSplit
        if row_split and not (fmt == AccessFormat.W8):
            # Linear write: wrap after one S2MM row, step by one MM2S row.
            row_words = s2mm_tiling[1] *elem_bytes // 4
            row_stride = self.MemTilingsIfmMM2S.tolist()[inst][col_idx][0][1] *elem_bytes // 4
            return TransferParams(dma, buffer_words, offset=0, step=[1, row_stride], wrap=[row_words])

        if row_split:
            sv_rows, sv_cols = self.Msubv*self.AieRows, self.Nlrn
        else:
            sv_rows, sv_cols = self.Msubv, s2mm_tiling[1]

        (s0, s1, s2, s3, w0, w1, w2) = access_w8_subvolume(sv_rows, sv_cols, elem_bytes)
        # The helper reports an unused fourth dimension as None.
        if s3 is not None and w2 is not None:
            steps, wraps = [s0, s1, s2, s3], [w0, w1, w2]
        else:
            steps, wraps = [s0, s1, s2], [w0, w1]
        return TransferParams(dma, buffer_words, offset=0, step=steps, wrap=wraps)
    
    def access_memtile_ifm_mm2s(
            self,
            dma: AieDma,
            col: int,
            row: int,
            fmt: AccessFormat = AccessFormat.W8,
    ) -> TransferParams:
        """Return the memtile MM2S access pattern feeding one core.

        The buffer length is the core's activation size in words.

        * RowSplit + W8: pattern from ``access_w8_rd`` on the core's MM2S
          tiling, offset by ``row`` sub-volumes.
        * RowSplit + any other format: two-step pattern over data that is
          already in W8 layout.
        * Any other split type: plain contiguous read, offset by the words
          of all earlier rows in the column; ``fmt`` is ignored.

        Args:
            dma: DMA channel the access is bound to.
            col: Absolute column index across all instances.
            row: Row (core) index within the column.
            fmt: Requested access format.
        """
        inst, col_idx = divmod(col, self.AieCols)
        column_words = (self.CoreActSize // 4).tolist()[inst][col_idx]
        buffer_words = column_words[row]

        core_tiling = self.MemTilingsIfmMM2S.tolist()[inst][col_idx][row]
        subv_rows = core_tiling[0]
        subv_cols = core_tiling[1]

        if not (self.split_type == SplitType.RowSplit):
            return TransferParams(dma, buffer_words, offset=sum(column_words[:row]))

        elem_bytes = self.InBytes
        if fmt == AccessFormat.W8:  # read in w8 format
            (s0, s1, s2, w0, w1) = access_w8_rd(subv_rows, subv_cols, elem_bytes)
            start = row*subv_rows*subv_cols*elem_bytes // 4
            return TransferParams(dma, buffer_words, offset=start, step = [s0, s1, s2], wrap = [w0, w1])

        # read the data that is already in w8 format
        msubv = self.Msubv
        block_words = msubv*8*elem_bytes // 4
        block_stride = msubv*self.AieRows*8*elem_bytes // 4
        start = row*msubv*8*elem_bytes // 4
        return TransferParams(dma, buffer_words, offset=start, step = [1, block_stride], wrap = [block_words])

    def access_memtile_ofm_s2mm(
            self,
            dma: AieDma,
            col: int,
            row: int,
            fmt: AccessFormat = AccessFormat.W8,
    ) -> TransferParams:
        """Return the memtile S2MM access pattern receiving one core's output.

        The buffer length is the core's output size in words.

        * RowSplit + W8: pattern from ``access_w8_rd`` on the core's OFM
          S2MM tiling, offset by ``row`` sub-volumes.
        * RowSplit + any other format: two-step linear write.
        * Any other split type: plain contiguous write, offset by the words
          of all earlier rows in the column; ``fmt`` is ignored.

        Args:
            dma: DMA channel the access is bound to.
            col: Absolute column index across all instances.
            row: Row (core) index within the column.
            fmt: Requested access format.
        """
        inst, col_idx = divmod(col, self.AieCols)
        column_words = (self.CoreOutSize // 4).tolist()[inst][col_idx]
        buffer_words = column_words[row]

        if not (self.split_type == SplitType.RowSplit):
            return TransferParams(dma, buffer_words, offset=sum(column_words[:row]))

        elem_bytes = self.OutBytes
        core_tiling = self.MemTilingsOfmS2MM.tolist()[inst][col_idx][row]
        subv_rows = core_tiling[0]
        subv_cols = core_tiling[1]

        if fmt == AccessFormat.W8:  # undo the w8 format
            (s0, s1, s2, w0, w1) = access_w8_rd(subv_rows, subv_cols, elem_bytes)
            start = row*subv_rows*subv_cols*elem_bytes // 4
            return TransferParams(dma, buffer_words, offset=start, step = [s0, s1, s2], wrap = [w0, w1])

        # write linearly
        msubv = self.Msubv
        block_words = msubv*8*elem_bytes // 4
        block_stride = msubv*self.AieRows*8*elem_bytes // 4
        start = row*msubv*8*elem_bytes // 4
        return TransferParams(dma, buffer_words, offset=start, step = [1, block_stride], wrap = [block_words])
    
    def access_memtile_ofm_mm2s(
            self,
            dma: AieDma,
            col: int,
            fmt: AccessFormat = AccessFormat.W8,
    ) -> TransferParams:
        """Return the memtile MM2S access pattern draining the output.

        Buffer lengths come from ``get_memtile_buffer_words`` using this
        column's repeat counts and residual flag; the offset is always 0.

        * RowSplit + W8: pattern from ``access_w8_subvolume`` over
          ``Msubv * AieRows`` by ``Nlrn``.
        * RowSplit + any other format: two-step linear pattern.
        * Any other split type: pattern from ``access_w8_subvolume`` over
          ``Msubv`` by the column's second OFM MM2S tiling dimension;
          ``fmt`` is ignored.

        Args:
            dma: DMA channel the access is bound to.
            col: Absolute column index across all instances.
            fmt: Requested access format.
        """
        inst, col_idx = divmod(col, self.AieCols)
        elem_bytes = self.OutBytes
        mm2s_tiling = self.MemTilingsOfmMM2S.tolist()[inst][col_idx]

        # Single source of truth for the per-slot buffer lengths (the earlier
        # hand-computed [out_words, out_words_residual] value was always
        # overwritten by this call and has been dropped).
        buffer_words = self.get_memtile_buffer_words(
            self.MemtileRepeatCount.tolist()[inst][col_idx][:],
            self.is_residual_mem[inst][col_idx],
            self.MemtileOutSizeOfmMM2S.tolist()[inst][col_idx] // 4,
            self.ShimOutSize[inst][col_idx] // 4,
        )

        row_split = self.split_type == SplitType.RowSplit
        if row_split and not (fmt == AccessFormat.W8):
            # Linear read: wrap after one MM2S row, step by one S2MM row.
            row_words = mm2s_tiling[1] *elem_bytes // 4
            row_stride = self.MemTilingsOfmS2MM.tolist()[inst][col_idx][0][1] *elem_bytes // 4
            return TransferParams(dma, buffer_words, offset=0, step=[1, row_stride], wrap=[row_words])

        if row_split:
            # undo the w8 format over the whole column
            sv_rows, sv_cols = self.Msubv*self.AieRows, self.Nlrn
        else:
            sv_rows, sv_cols = self.Msubv, mm2s_tiling[1]

        (s0, s1, s2, s3, w0, w1, w2) = access_w8_subvolume(sv_rows, sv_cols, elem_bytes)
        # The helper reports an unused fourth dimension as None.
        if s3 is not None and w2 is not None:
            steps, wraps = [s0, s1, s2, s3], [w0, w1, w2]
        else:
            steps, wraps = [s0, s1, s2], [w0, w1]
        return TransferParams(dma, buffer_words, offset=0, step=steps, wrap=wraps)


############------- Shim Transfers ----------##################
    def get_shim_param_transfers(self, param_repeat=1, param_addr = 3):
        """Build one shim MM2S transfer per column for the core parameters.

        Column ``col`` sends ``CorePrmSize * AieRows`` bytes on MM2S
        channel 0, starting at word offset ``col * CorePrmWords * AieRows``.

        Args:
            param_repeat: First entry of the repeat-count list; remaining
                entries are 0.
            param_addr: Value placed in each transfer's address list.

        Returns:
            List of ``DataTransfer``, one per column across all instances.
        """
        rows = self.AieRows
        prm_bytes = self.CorePrmSize
        prm_words = prm_bytes//4
        total_cols = self.AieCols * self.AieInst
        rep_cnt = [param_repeat] + [0]*(self.MemtileRepeatCount.shape[-1] - 1)

        transfers = []
        for col in range(total_cols):
            transfers.append(
                DataTransfer(
                    rep_cnt,
                    AieTile(TileType.Shim, col, 0), [param_addr], prm_bytes * rows,
                    [],
                    [TransferParams(AieDma(AieTile(TileType.Shim, col, 0), DmaChannel(DmaDir.MM2S, 0)), prm_words * rows, offset = col * (prm_words * rows))]
                )
            )
        return transfers
    def get_shim_qdq_transfers(self, offset_size, qdq_repeat=1, qdq_addr = 2):
        """Build the shim MM2S transfers that send the QDQ parameter block.

        Every column is used, except that only every second column is used
        when there are exactly 8 columns in total.  Each transfer sends
        ``CoreQdqPrmSize`` bytes on MM2S channel 0.

        Args:
            offset_size: Byte offset of the QDQ block; converted to words.
            qdq_repeat: First entry of the repeat-count list; remaining
                entries are 0.
            qdq_addr: Value placed in each transfer's address list.

        Returns:
            List of ``DataTransfer``, one per selected column.
        """
        total_cols = self.AieCols * self.AieInst
        col_stride = 2 if self.AieInst*self.AieCols==8 else 1
        qdq_bytes = self.CoreQdqPrmSize
        qdq_words = qdq_bytes // 4
        start_word = offset_size // 4
        rep_cnt = [qdq_repeat] + [0]*(self.MemtileRepeatCount.shape[-1] - 1)

        transfers = []
        for col in range(0, total_cols, col_stride):
            transfers.append(
                DataTransfer(
                    rep_cnt,
                    AieTile(TileType.Shim, col, 0), [qdq_addr], qdq_bytes,
                    [],
                    [TransferParams(AieDma(AieTile(TileType.Shim, col, 0), DmaChannel(DmaDir.MM2S, 0)), qdq_words,
                     offset = start_word)]
                )
            )
        return transfers
    def get_shim_bias_transfers(self, bias_repeat=8, bias_addr = 2):
        """Build the shim MM2S transfers that send the bias.

        One transfer is produced for every ``BiasColStep``-th column, on
        MM2S channel 0.  Step/wrap values come from ``shim_step_wrap``.  For
        ``SplitType.ColSplit`` the read starts after the bias of all earlier
        columns of the same instance; otherwise it starts at offset 0.

        Args:
            bias_repeat: First entry of the repeat-count list; remaining
                entries are 0.
            bias_addr: Value placed in each transfer's address list.

        Returns:
            List of ``DataTransfer``, one per selected column.
        """
        cols_per_inst = self.AieCols
        total_cols = cols_per_inst * self.AieInst
        col_stride = self.BiasColStep
        shim_bias_bytes = self.ShimBiasSize
        rep_cnt = [bias_repeat] + [0]*(self.MemtileRepeatCount.shape[-1] - 1)
        split_kind = self.split_type
        padded_n = self.Nlrn
        elem_bytes = self.ParamElemBytes
        tilings = self.ShimTilings.tolist()

        def bias_access(dma: AieDma, col) -> TransferParams:
            inst_idx, col_idx = divmod(col, cols_per_inst)
            col_width = tilings[inst_idx][col_idx][1]

            length_bytes = col_width*2*elem_bytes
            steps, wraps = shim_step_wrap(col_width, elem_bytes, padded_n)

            start_bytes = 0
            if split_kind == SplitType.ColSplit:
                # Skip the bias belonging to the columns to the left.
                start_bytes = np.sum(self.ShimTilings[inst_idx][:col_idx, 1]).item() * elem_bytes

            start_words = bytes_to_words(start_bytes)
            length_words = bytes_to_words(length_bytes)
            return TransferParams(dma, length_words, offset=start_words, step=steps, wrap=wraps)

        transfers = []
        for col in range(0, total_cols, col_stride):
            transfers.append(
                DataTransfer(
                    rep_cnt,
                    AieTile(TileType.Shim, col, 0), [bias_addr], shim_bias_bytes,
                    [],
                    [bias_access(shim_dma(col, DmaDir.MM2S, 0), col)]
                )
            )
        return transfers
    
    def get_shim_data_transfers(self, TT: NormTransferType):
        """Build the per-column shim transfers for activations or outputs.

        For ``NormTransferType.IfmMM2S`` every shim column gets a transfer on
        MM2S channel 1 sized from ``ShimActSize`` / ``InBytes``; for any other
        value of ``TT`` every column gets a transfer on S2MM channel 0 sized
        from ``ShimOutSize`` / ``OutBytes``.  Columns whose size is not
        greater than 0 receive a repeat count of 0.

        Args:
            TT: Selects the direction (see above).

        Returns:
            A list with one ``DataTransfer`` per column in
            ``range(AieCols * AieInst)``.
        """
        AieCols = self.AieCols
        total_cols = AieCols * self.AieInst
        Nlrn = self.Nlrn
        split_type = self.split_type

        # ndarray view for slice sums, nested list for plain-scalar lookups.
        tilings_np = self.ShimTilings
        tilings = tilings_np.tolist()

        # Both directions share one access routine; only the size table and
        # the element width differ.
        is_ifm = TT == NormTransferType.IfmMM2S
        if is_ifm:
            sizes_np, elem_bytes = self.ShimActSize, self.InBytes
        else:
            sizes_np, elem_bytes = self.ShimOutSize, self.OutBytes
        sizes = sizes_np.tolist()
        # 1 where the column has data to move, 0 where its size is not positive.
        repeats = np.where(sizes_np > 0, 1, 0).tolist()

        # Zero padding appended after the first repeat-count entry.
        repeat_tail = [0] * (self.MemtileRepeatCount.shape[-1] - 1)

        def access_shim(
                dma: AieDma,
                col
        ) -> TransferParams:
            """Return the access pattern of flat shim column ``col``."""
            # Flat column index -> (instance, column inside that instance).
            inst_idx, col_idx = divmod(col, AieCols)
            buffer_bytes = sizes[inst_idx][col_idx]

            if split_type == SplitType.ColSplit:
                step_list, wrap_list = shim_step_wrap(tilings[inst_idx][col_idx][1], elem_bytes, Nlrn)

                # Skip what the preceding columns of this instance consume.
                offset_bytes = np.sum(tilings_np[inst_idx, :col_idx, 1]).item() * elem_bytes
                if col >= AieCols:
                    # Columns beyond the first instance start after the whole
                    # buffer of instance 0.
                    # NOTE(review): only instance 0's total is added, whatever
                    # inst_idx is -- confirm this is intended for AieInst > 2.
                    offset_bytes = offset_bytes + np.sum(sizes_np[0, :]).item()

                offset_words = bytes_to_words(offset_bytes)
                buffer_words = bytes_to_words(buffer_bytes)
                return TransferParams(dma, buffer_words, offset=offset_words, step=step_list, wrap=wrap_list)

            # Otherwise the columns of an instance are laid out back to back.
            offset_bytes = sum(sizes[inst_idx][:col_idx])
            offset_words = bytes_to_words(offset_bytes)
            buffer_words = bytes_to_words(buffer_bytes)
            return TransferParams(dma, buffer_words, offset=offset_words)

        if is_ifm:
            # Input side: third argument [1], MM2S channel 1, params in the
            # sixth DataTransfer argument.
            return [
                DataTransfer(
                    [repeats[col // AieCols][col % AieCols]] + repeat_tail,
                    AieTile(TileType.Shim, col, 0), [1], sizes[col // AieCols][col % AieCols],
                    [],
                    [access_shim(shim_dma(col, DmaDir.MM2S, 1), col)],
                ) for col in range(total_cols)
            ]

        # Output side: third argument [0], S2MM channel 0, params in the
        # fifth DataTransfer argument.
        return [
            DataTransfer(
                [repeats[col // AieCols][col % AieCols]] + repeat_tail,
                AieTile(TileType.Shim, col, 0), [0], sizes[col // AieCols][col % AieCols],
                [access_shim(shim_dma(col, DmaDir.S2MM, 0), col)],
                []
            ) for col in range(total_cols)
        ]



############################################################

