# pylint: skip-file
import numpy as np
class layer_norm_col_split(layer_norm_base):
    """Layer-norm schedule that distributes work across AIE columns.

    Extends ``layer_norm_base`` with a column-split bias/parameter layout.
    The bias stream type is ``StreamType.UC``. Per-core bias buffers hold
    gamma followed by beta. Memtile buffers are laid out back to back as
    ``[params | bias | in ping | in pong | out ping | out pong | qdq params]``
    (see ``get_memtile_transfers``).
    """

    def __init__(
            self,
            AieInst,
            AieCols,
            AieRows,
            split_type,
            KernelName,
            Kgran,
            Mlrn,
            Nlrn,
            TdimLayer,
            InOutBytePerElem,
            CoreTilings,
            MemTilings_ifm_s2mm,
            MemTilings_ifm_mm2s,
            ShimTilings
            ):
        """Build the column-split schedule and validate its attributes.

        All arguments are forwarded unchanged, in order, to
        ``layer_norm_base.__init__``; see that class for their meaning.
        """
        super().__init__(
                        AieInst,
                        AieCols,
                        AieRows,
                        split_type,
                        KernelName,
                        Kgran,
                        Mlrn,
                        Nlrn,
                        TdimLayer,
                        InOutBytePerElem,
                        CoreTilings,
                        MemTilings_ifm_s2mm,
                        MemTilings_ifm_mm2s,
                        ShimTilings
                        )
        print('Invoking col split schedule ........')

        self.BiasStreamType = StreamType.UC
        # Per-core bias buffer size in bytes: 8 * 2 * Nsubv * ParamElemBytes.
        # The factor 2 covers gamma and beta, each 8 * Nsubv elements wide;
        # this matches beta_offset in get_core_instrs.
        # A plain vectorised multiply yields the same values as the former
        # np.vectorize(lambda) but also works for an empty Nsubv (np.vectorize
        # raises on size-0 input unless ``otypes`` is given).
        self.CoreBiasSize = (
            np.array(self.Nsubv) * (8 * 2) * self.ParamElemBytes
        ).tolist()
        self.BiasColStep = 1
        self.validate_attributes()

    def get_core_instrs(self, col, row):
        """Return the core instructions for the AIE core at (``col``, ``row``).

        ``col`` is a global column index that wraps every ``AieCols`` columns
        into a new instance. ``row`` is an absolute tile row; the core row
        index used for ``Nsubv`` lookups is ``row - 2``.

        In the core bias buffer, gamma starts at byte 0 and beta starts right
        after ``8 * Nsubv`` gamma elements of ``ParamElemBytes`` bytes each.
        """
        r = row - 2  # NOTE(review): presumably rows 0-1 are non-core tiles — confirm
        inst_idx = col // self.AieCols
        col_idx = col % self.AieCols  # wraps for every instance
        nbytes = self.ParamElemBytes
        gamma_offset = 0
        beta_offset = (self.Nsubv[inst_idx][col_idx][r] * nbytes * 8) + gamma_offset
        return self.layernorm_core_instrs(inst_idx, col_idx, r, gamma_offset, beta_offset)

    def get_memtile_transfers(self):
        """Assign memtile buffer addresses and return every memtile transfer.

        Side effects: sets ``MemtileActSizeMax``, ``MemtileOutSizeMax``,
        ``MemtileBiasSize`` and the ``MemtilePrmPingAddr`` ...
        ``MemtileQdqPrmAddr`` address attributes.

        Returns:
            The concatenation of parameter, qdq, bias and data transfers,
            in that order.
        """
        AieRows = self.AieRows
        Nlrn = self.Nlrn
        Msubv = self.Msubv
        CorePrmSize = self.CorePrmSize

        # Each ping/pong buffer must fit the largest transfer. Compare the
        # S2MM sizes against the MM2S sizes summed over their last axis.
        act_size_max = np.max((self.MemtileActSizeIfmS2MM,
                               np.sum(self.MemtileActSizeIfmMM2S, -1))).item()
        out_size_max = np.max((np.sum(self.MemtileOutSizeOfmS2MM, -1),
                               self.MemtileOutSizeOfmMM2S)).item()

        if self.split_type == SplitType.RowSplit:
            # NOTE(review): if Msubv is a nested list (as Nsubv is indexed
            # [inst][col][row] in get_core_instrs), the products below would
            # list-repeat rather than multiply — confirm Msubv is a scalar here.
            MemtileMinIfmSize = Msubv * Nlrn * AieRows * self.InBytes
            MemtileMinOfmSize = Msubv * Nlrn * AieRows * self.OutBytes
            act_size_max = max(act_size_max, MemtileMinIfmSize)
            out_size_max = max(out_size_max, MemtileMinOfmSize)

        self.MemtileActSizeMax = act_size_max
        self.MemtileOutSizeMax = out_size_max

        self.MemtileBiasSize = self.ShimBiasSize * self.ShimBiasRep

        # Buffers are placed back to back starting at byte 0:
        # params (CorePrmSize per core row) | bias | in ping | in pong |
        # out ping | out pong | qdq params.
        self.MemtilePrmPingAddr  = 0
        self.MemtileBiasPingAddr = self.MemtilePrmPingAddr  + CorePrmSize * AieRows
        self.MemtileInPingAddr   = self.MemtileBiasPingAddr + self.MemtileBiasSize
        self.MemtileInPongAddr   = self.MemtileInPingAddr   + self.MemtileActSizeMax
        self.MemtileOutPingAddr  = self.MemtileInPongAddr   + self.MemtileActSizeMax
        self.MemtileOutPongAddr  = self.MemtileOutPingAddr  + self.MemtileOutSizeMax
        self.MemtileQdqPrmAddr   = self.MemtileOutPongAddr  + self.MemtileOutSizeMax

        memtile_transfers = (
            self.get_memtile_param_transfers(1, self.MemtilePrmPingAddr)
            + self.get_memtile_qdq_transfers(1, self.MemtileQdqPrmAddr)
            + self.get_memtile_bias_transfers([self.MemtileBiasPingAddr])
            + self.get_memtile_data_transfers()
        )
        return memtile_transfers

    def get_shim_transfers(self):
        """Return every shim transfer.

        The result concatenates, in order: parameter, qdq, bias, IFM MM2S
        data and OFM S2MM data transfers.
        """
        # NOTE(review): the qdq shim transfers receive ShimBiasSize —
        # presumably the qdq params sit after the bias in the shim buffer;
        # confirm against get_shim_qdq_transfers.
        return (
            self.get_shim_param_transfers()
            + self.get_shim_qdq_transfers(self.ShimBiasSize)
            + self.get_shim_bias_transfers()
            + self.get_shim_data_transfers(NormTransferType.IfmMM2S)
            + self.get_shim_data_transfers(NormTransferType.OfmS2MM)
        )

class layer_norm(layer_norm_col_split):
    """Public layer-norm schedule name; inherits everything from ``layer_norm_col_split`` unchanged."""

