# pylint: skip-file
import numpy as np
class lp_norm_row_split(layer_norm_base):
    """Row-split schedule for the LP/layer-norm kernel.

    Specialises ``layer_norm_base`` by supplying:

    * the per-core instruction lookup (``get_core_instrs``),
    * the memtile buffer layout and memtile transfers
      (``get_memtile_transfers``),
    * the shim transfers (``get_shim_transfers``).

    The configuration attributes read here (``AieCols``, ``AieRows``,
    ``Msubv``, the ``Memtile*Size*`` arrays, ...) are not set in this class;
    presumably ``layer_norm_base.__init__`` derives them from the constructor
    arguments — confirm against the base class.
    """

    def __init__(
            self,
            AieInst,
            AieCols,
            AieRows,
            split_type,
            KernelName,
            Kgran,
            Mlrn,
            Nlrn,
            TdimLayer,
            InOutBytePerElem,
            CoreTilings,
            MemTilings_ifm_s2mm,
            MemTilings_ifm_mm2s,
            ShimTilings,
            ):
        """Build the schedule and validate its attributes.

        Every argument is forwarded unchanged, in the same order, to
        ``layer_norm_base.__init__``; see that class for their meaning.
        After the base initialiser runs, a banner is printed to stdout and
        ``validate_attributes()`` (defined outside this class) is called.
        """
        super().__init__(
            AieInst,
            AieCols,
            AieRows,
            split_type,
            KernelName,
            Kgran,
            Mlrn,
            Nlrn,
            TdimLayer,
            InOutBytePerElem,
            CoreTilings,
            MemTilings_ifm_s2mm,
            MemTilings_ifm_mm2s,
            ShimTilings,
        )
        print('Invoking row split schedule ........')
        self.validate_attributes()

    def get_core_instrs(self, col, row):
        """Return the core instructions for the AIE tile at (``col``, ``row``).

        ``col`` is a global column index that wraps once per instance: it is
        split into an instance index (``col // AieCols``) and a column index
        within that instance (``col % AieCols``).

        ``row`` is offset by 2 before use. Presumably the two lowest rows are
        non-core (shim/memtile) rows — TODO confirm against the device layout.
        """
        inst_idx, col_idx = divmod(col, self.AieCols)
        return self.lpnorm_core_instrs(inst_idx, col_idx, row - 2)

    @staticmethod
    def _max_over(*arrays):
        """Return the largest element across all ``arrays`` as a Python scalar.

        Each input is reduced on its own before the per-input maxima are
        combined. The inputs therefore need not share a shape; stacking them
        into one array first, as ``np.max((a, b))`` does, requires that.
        The maxima are combined through ``np.max``, so NumPy dtype promotion
        across the inputs still applies to the result.
        """
        return np.max([np.max(a) for a in arrays]).item()

    def get_memtile_transfers(self):
        """Lay out the memtile buffers and return the memtile transfer list.

        Side effects: sets ``MemtileActSizeMax``, ``MemtileOutSizeMax`` and the
        ``Memtile*Addr`` attributes on ``self``. Regions are packed
        contiguously starting at address 0, in this order:

        * parameters         : ``CorePrmSize * AieRows``
        * input ping / pong  : ``MemtileActSizeMax`` each
        * output ping / pong : ``MemtileOutSizeMax`` each
        * QDQ parameters     : from ``MemtileQdqPrmAddr`` (after output pong)

        ``MemtileActSizeMax`` is the largest value among
        ``MemtileActSizeIfmS2MM`` and ``MemtileActSizeIfmMM2S`` summed over
        its last axis. ``MemtileOutSizeMax`` is the largest value among
        ``MemtileOutSizeOfmS2MM`` summed over its last axis and
        ``MemtileOutSizeOfmMM2S``.

        For ``SplitType.RowSplit``, both sizes are additionally raised to at
        least ``Msubv * Nlrn * AieRows`` times ``InBytes`` (input) or
        ``OutBytes`` (output).

        Returns:
            The param transfers, QDQ transfers and data transfers joined with
            ``+``. Presumably these are lists built by the base class.
        """
        AieRows = self.AieRows
        Nlrn = self.Nlrn
        Msubv = self.Msubv

        act_size_max = self._max_over(
            self.MemtileActSizeIfmS2MM,
            np.sum(self.MemtileActSizeIfmMM2S, -1),
        )
        out_size_max = self._max_over(
            np.sum(self.MemtileOutSizeOfmS2MM, -1),
            self.MemtileOutSizeOfmMM2S,
        )
        if self.split_type == SplitType.RowSplit:
            # Row split: never size a buffer below Msubv * Nlrn elements for
            # each of the AieRows rows. Presumably this is one core
            # subvolume per row — confirm against the kernel's tiling.
            act_size_max = max(act_size_max, Msubv * Nlrn * AieRows * self.InBytes)
            out_size_max = max(out_size_max, Msubv * Nlrn * AieRows * self.OutBytes)
        self.MemtileActSizeMax = act_size_max
        self.MemtileOutSizeMax = out_size_max

        # Contiguous packing; each region starts where the previous one ends.
        self.MemtilePrmPingAddr = 0
        self.MemtileInPingAddr = self.MemtilePrmPingAddr + self.CorePrmSize * AieRows
        self.MemtileInPongAddr = self.MemtileInPingAddr + self.MemtileActSizeMax
        self.MemtileOutPingAddr = self.MemtileInPongAddr + self.MemtileActSizeMax
        self.MemtileOutPongAddr = self.MemtileOutPingAddr + self.MemtileOutSizeMax
        self.MemtileQdqPrmAddr = self.MemtileOutPongAddr + self.MemtileOutSizeMax

        return (
            self.get_memtile_param_transfers(1, self.MemtilePrmPingAddr)
            + self.get_memtile_qdq_transfers(1, self.MemtileQdqPrmAddr)
            + self.get_memtile_data_transfers()
        )

    def get_shim_transfers(self):
        """Return the shim transfers for this schedule.

        They are concatenated in this order: parameter transfers, QDQ
        transfers (requested with argument ``0``; its meaning is defined by
        the base class), IFM MM2S data transfers, then OFM S2MM data
        transfers.
        """
        return (
            self.get_shim_param_transfers()
            + self.get_shim_qdq_transfers(0)
            + self.get_shim_data_transfers(NormTransferType.IfmMM2S)
            + self.get_shim_data_transfers(NormTransferType.OfmS2MM)
        )

class layer_norm(lp_norm_row_split):
    """Public name for the row-split LP-norm schedule.

    Adds nothing of its own; all behaviour is inherited from
    ``lp_norm_row_split``.
    """
