# pylint: disable=E0606
from layernorm_tiler import LayerNormTiler

class LpNormTiler(LayerNormTiler):
    """``LayerNormTiler`` subclass that overrides split-type selection.

    All other tiling behaviour is inherited from ``LayerNormTiler``.
    """

    def get_split_type(self, kernel):
        """Return the split type to use for ``kernel``.

        Args:
            kernel: Kernel descriptor. ``kernel.SplitType`` is either
                ``'OPTIMAL'`` (the tiler decides) or an explicit split type
                that is returned unchanged. ``kernel.Mgran`` divides
                ``self.max_ifm_rs`` to give the largest N a row split allows.

        Returns:
            If ``kernel.SplitType == 'OPTIMAL'``: ``'ROW_SPLIT'`` when
            ``N = self.tensorshape['ifm'][1]`` is at most
            ``self.max_ifm_rs // kernel.Mgran``, otherwise ``'COL_SPLIT'``.
            For any other ``kernel.SplitType``, that value as-is.
        """
        # An explicit split type needs no size analysis.
        if kernel.SplitType != 'OPTIMAL':
            return kernel.SplitType

        n_ifm = self.tensorshape['ifm'][1]
        max_n_rs = self.max_ifm_rs // kernel.Mgran

        # NOTE(review): a column-split capacity
        # ((self.max_ifm_cs // kernel.Mgran) * self.overlay.rows**2) used to be
        # computed here but was never consulted, so COL_SPLIT is chosen for any
        # N above the row capacity without checking that it fits. Confirm
        # whether an upper bound (or an error) is intended for that case.
        return 'ROW_SPLIT' if n_ifm <= max_n_rs else 'COL_SPLIT'
