# pylint: skip-file
import numpy as np
class layer_norm_dataflow(layer_norm_base):
    def __init__(
            self,
            AieInst,
            AieCols,
            AieRows,
            split_type,
            KernelName,
            Kgran,
            Mlrn,
            Nlrn,
            TdimLayer,
            InOutBytePerElem,
            CoreTilings,
            MemTilings_ifm_s2mm,
            MemTilings_ifm_mm2s,
            ShimTilings
            ):
        """Set up the split-schedule layer-norm dataflow.

        Every argument is forwarded unchanged, in the same order, to the
        base-class constructor; their meaning is defined there.

        On top of the base-class state this constructor derives:

        * ``MemtileRepeatCount`` / ``is_residual_mem`` -- the pair returned by
          ``getMemTileRepeatCount`` for a budget of
          ``(MAX_REPEAT_COUNT * MAX_TASK_QUEUE_SIZE) // 2``.
        * ``CoreBiasSize`` -- nested list shaped like ``Nsubv`` holding
          ``8 * 2 * Nsubv * ParamElemBytes`` for every entry.
        * ``BiasStreamType`` / ``BiasColStep`` -- ``StreamType.UC`` with step 1
          for ``SplitType.ColSplit``; otherwise ``StreamType.BC`` with step 2
          when ``AieCols == 8`` and step 1 for any other column count.

        ``validate_attributes`` is called last, once all of the above is set.
        """
        super().__init__(
            AieInst,
            AieCols,
            AieRows,
            split_type,
            KernelName,
            Kgran,
            Mlrn,
            Nlrn,
            TdimLayer,
            InOutBytePerElem,
            CoreTilings,
            MemTilings_ifm_s2mm,
            MemTilings_ifm_mm2s,
            ShimTilings,
        )
        print('Invoking split schedule ........')

        # Explicit parentheses: the product is formed first, then halved.
        repeat_budget = (config.MAX_REPEAT_COUNT * config.MAX_TASK_QUEUE_SIZE) // 2
        self.MemtileRepeatCount, self.is_residual_mem = self.getMemTileRepeatCount(repeat_budget)
        if self.NORM_DEBUG:
            print('MemtileRepeatCount', self.MemtileRepeatCount)

        # Plain array arithmetic instead of np.vectorize over a lambda: same
        # values, same nested-list result, without a per-element Python call.
        self.CoreBiasSize = (8 * 2 * np.asarray(self.Nsubv) * self.ParamElemBytes).tolist()

        if self.split_type == SplitType.ColSplit:
            self.BiasStreamType = StreamType.UC
            self.BiasColStep = 1
        else:
            self.BiasStreamType = StreamType.BC
            self.BiasColStep = 2 if self.AieCols == 8 else 1

        self.validate_attributes()
    
    def get_core_instrs(self, col, row):
        """Return the core instruction list for the tile at ``(col, row)``.

        ``col`` is split into an instance index (``col // AieCols``) and a
        column index within that instance (``col % AieCols``); ``row`` is
        shifted down by two before use.  The gamma block starts at offset 0
        and the beta offset is placed ``length * ParamElemBytes * 8`` after
        it, where ``length`` is this core's ``Nsubv`` entry for
        ``SplitType.ColSplit`` and ``Nlrn`` for every other split type.

        The result is whatever ``layernorm_core_instrs`` returns for those
        indices and offsets.
        """
        full_width = self.Nlrn
        per_core_width = self.Nsubv
        elem_bytes = self.ParamElemBytes

        # Column index wraps for every instance.
        inst_idx, col_idx = divmod(col, self.AieCols)
        # NOTE(review): presumably the first two array rows are not compute
        # rows, hence the shift by 2 -- confirm against the array layout.
        core_row = row - 2

        gamma_offset = 0
        if self.split_type == SplitType.ColSplit:
            gamma_len = per_core_width[inst_idx][col_idx][core_row]
        else:
            gamma_len = full_width
        beta_offset = (gamma_len * elem_bytes * 8) + gamma_offset

        return self.layernorm_core_instrs(inst_idx, col_idx, core_row, gamma_offset, beta_offset)
    
    def get_memtile_transfers(self):
        """Size and place the memtile buffers, then return all memtile transfers.

        Side effects (attributes written on ``self``):

        * ``MemtileActSizeMax`` -- maximum over ``MemtileActSizeIfmS2MM`` and
          the last-axis sum of ``MemtileActSizeIfmMM2S``.
        * ``MemtileOutSizeMax`` -- maximum over the last-axis sum of
          ``MemtileOutSizeOfmS2MM`` and ``MemtileOutSizeOfmMM2S``.
        * For ``SplitType.RowSplit`` both maxima are additionally raised to at
          least ``Msubv * Nlrn * AieRows`` times ``InBytes`` / ``OutBytes``.
        * ``MemtileBiasSizeMax`` -- maximum of ``MemtileBiasSize``.
        * The seven ``Memtile*Addr`` attributes, as returned by
          ``set_mem_addr``.

        Returns:
            The concatenation, in this order, of the memtile parameter, QDQ,
            bias and data transfers.
        """
        act_size_max = np.max(
            (self.MemtileActSizeIfmS2MM, np.sum(self.MemtileActSizeIfmMM2S, -1))
        ).item()
        out_size_max = np.max(
            (np.sum(self.MemtileOutSizeOfmS2MM, -1), self.MemtileOutSizeOfmMM2S)
        ).item()

        if self.split_type == SplitType.RowSplit:
            # Row split enforces a floor of one Msubv x Nlrn block per AIE row.
            block_elems = self.Msubv * self.Nlrn * self.AieRows
            act_size_max = max(act_size_max, block_elems * self.InBytes)
            out_size_max = max(out_size_max, block_elems * self.OutBytes)

        self.MemtileActSizeMax = act_size_max
        self.MemtileOutSizeMax = out_size_max
        self.MemtileBiasSizeMax = np.max(self.MemtileBiasSize).item()

        (
            self.MemtilePrmPingAddr,
            self.MemtileInPingAddr,
            self.MemtileInPongAddr,
            self.MemtileOutPingAddr,
            self.MemtileOutPongAddr,
            self.MemtileBiasPingAddr,
            self.MemtileQdqPrmAddr,
        ) = set_mem_addr(
            self.MemtileActSizeMax,
            self.MemtileBiasSizeMax,
            self.CoreQdqPrmSize,
            self.MemtileOutSizeMax,
            self.CorePrmSize * self.AieRows,
            config.MAX_MEMTILE_ADDR,
            self.NORM_DEBUG,
        )

        memtile_transfers = (
            self.get_memtile_param_transfers(1, self.MemtilePrmPingAddr)
            + self.get_memtile_qdq_transfers(1, self.MemtileQdqPrmAddr)
            + self.get_memtile_bias_transfers([self.MemtileBiasPingAddr])
            + self.get_memtile_data_transfers()
        )
        return memtile_transfers
    
    def get_shim_transfers(self):
        """Return every shim transfer for this layer as one concatenated sequence.

        Order of the pieces: parameter, QDQ, bias, IFM (``IfmMM2S``) data and
        finally OFM (``OfmS2MM``) data transfers.
        """
        shim_bias_size = self.ShimBiasSize

        param_part = self.get_shim_param_transfers()
        # NOTE(review): the QDQ helper is handed the shim *bias* size -- confirm
        # this is the intended argument.
        qdq_part = self.get_shim_qdq_transfers(shim_bias_size)
        bias_part = self.get_shim_bias_transfers()
        ifm_part = self.get_shim_data_transfers(NormTransferType.IfmMM2S)
        ofm_part = self.get_shim_data_transfers(NormTransferType.OfmS2MM)

        return param_part + qdq_part + bias_part + ifm_part + ofm_part


    def core_inst_inPlace(self, k_name, UC_chan, BC_chan, Bias_chan, Nlayer, Nsubv_core, split_type_value,
            col, row, Touter, inst, gamma_offset, beta_offset,
            CoreInPingAddr, CoreInPongAddr, CoreOutPingAddr, CoreOutPongAddr, CoreBiasPingAddr, CoreQdqPrmAddr,
            CoreQdqPrmSize, CoreBiasSize, CoreActSize, CoreOutSize):
        """Build the core instruction list for the in-place buffer layout.

        Schedule produced:

        1. Configure / acquire / release one S2MM buffer on ``BC_chan`` for the
           QDQ parameters, then the same on ``Bias_chan`` for the bias block.
        2. An outer loop that configures the MM2S output buffer once per pass
           and runs an inner loop whose body acquires the output buffer,
           (re)configures and acquires the S2MM input buffer, calls the
           kernel, and releases input then output.
        3. When the iteration count does not divide evenly, a second outer
           loop of ``remainder`` passes with a single-iteration inner loop.

        The iteration count is ``Touter[inst][col][row]``.  It is split into
        chunks of ``MAX_TASK_QUEUE_SIZE * MAX_REPEAT_COUNT``; a count below
        one chunk becomes a single outer pass of exactly that many inner
        iterations.

        Returns:
            A new list of instruction objects.
        """
        total_iters = Touter[inst][col][row]
        inner_count = config.MAX_TASK_QUEUE_SIZE * config.MAX_REPEAT_COUNT
        outer_count, tail_count = divmod(total_iters, inner_count)
        if total_iters < inner_count:
            # Fits in one chunk: one outer pass, no tail.
            inner_count, outer_count, tail_count = total_iters, 1, 0

        # Each factory returns freshly constructed objects so no instruction
        # instance is shared between the main loop and the tail loop.
        def load_once(chan, addr, nbytes):
            return [
                ConfigBuffer(DmaChannel(DmaDir.S2MM, chan), addr, None, nbytes),
                AcqBuffer(DmaChannel(DmaDir.S2MM, chan)),
                RelBuffer(DmaChannel(DmaDir.S2MM, chan)),
            ]

        def output_config():
            return ConfigBuffer(DmaChannel(DmaDir.MM2S, UC_chan), CoreOutPingAddr, CoreOutPongAddr,
                                CoreOutSize[inst][col][row])

        def iteration_body():
            return [
                AcqBuffer(DmaChannel(DmaDir.MM2S, UC_chan)),
                ConfigBuffer(DmaChannel(DmaDir.S2MM, UC_chan), CoreInPingAddr, CoreInPongAddr,
                             CoreActSize[inst][col][row]),
                AcqBuffer(DmaChannel(DmaDir.S2MM, UC_chan)),
                CallKernel(k_name, kernel_params=lrn_kernel_params(
                    Nlayer, Nsubv_core, split_type_value, col, CoreBiasPingAddr,
                    gamma_offset, beta_offset, CoreQdqPrmAddr)),
                RelBuffer(DmaChannel(DmaDir.S2MM, UC_chan)),
                RelBuffer(DmaChannel(DmaDir.MM2S, UC_chan)),
            ]

        instrs = load_once(BC_chan, CoreQdqPrmAddr, CoreQdqPrmSize)
        instrs.extend(load_once(Bias_chan, CoreBiasPingAddr, CoreBiasSize[inst][col][row]))
        instrs.append(Loop(outer_count, [output_config(), Loop(inner_count, iteration_body())]))
        if tail_count:
            instrs.append(Loop(tail_count, [output_config(), Loop(1, iteration_body())]))

        return instrs

    def core_inst(self, k_name, UC_chan, BC_chan, Bias_chan, Nlayer, Nsubv_core, split_type_value,
            col, row, Touter, inst, gamma_offset, beta_offset,
            CoreInPingAddr, CoreInPongAddr, CoreOutPingAddr, CoreOutPongAddr, CoreBiasPingAddr, CoreQdqPrmAddr,
            CoreQdqPrmSize, CoreBiasSize, CoreActSize, CoreOutSize):
        """Build the core instruction list for the separate in/out buffer layout.

        Schedule produced:

        1. Configure / acquire / release one S2MM buffer on ``BC_chan`` for the
           QDQ parameters, then the same on ``Bias_chan`` for the bias block.
        2. An outer loop that configures the MM2S output buffer and the S2MM
           input buffer once per pass, then runs an inner loop whose body
           acquires output and input, calls the kernel, and releases input
           then output.
        3. When the iteration count does not divide evenly, a second outer
           loop of ``remainder`` passes with a single-iteration inner loop.

        The iteration count is ``Touter[inst][col][row]``.  It is split into
        chunks of ``MAX_TASK_QUEUE_SIZE * MAX_REPEAT_COUNT``; a count below
        one chunk becomes a single outer pass of exactly that many inner
        iterations.

        Returns:
            A new list of instruction objects.
        """
        total_iters = Touter[inst][col][row]
        inner_count = config.MAX_TASK_QUEUE_SIZE * config.MAX_REPEAT_COUNT
        outer_count, tail_count = divmod(total_iters, inner_count)
        if total_iters < inner_count:
            # Fits in one chunk: one outer pass, no tail.
            inner_count, outer_count, tail_count = total_iters, 1, 0

        # Each factory returns freshly constructed objects so no instruction
        # instance is shared between the main loop and the tail loop.
        def load_once(chan, addr, nbytes):
            return [
                ConfigBuffer(DmaChannel(DmaDir.S2MM, chan), addr, None, nbytes),
                AcqBuffer(DmaChannel(DmaDir.S2MM, chan)),
                RelBuffer(DmaChannel(DmaDir.S2MM, chan)),
            ]

        def buffer_configs():
            # Output (MM2S) first, then input (S2MM).
            return [
                ConfigBuffer(DmaChannel(DmaDir.MM2S, UC_chan), CoreOutPingAddr, CoreOutPongAddr,
                             CoreOutSize[inst][col][row]),
                ConfigBuffer(DmaChannel(DmaDir.S2MM, UC_chan), CoreInPingAddr, CoreInPongAddr,
                             CoreActSize[inst][col][row]),
            ]

        def iteration_body():
            return [
                AcqBuffer(DmaChannel(DmaDir.MM2S, UC_chan)),
                AcqBuffer(DmaChannel(DmaDir.S2MM, UC_chan)),
                CallKernel(k_name, kernel_params=lrn_kernel_params(
                    Nlayer, Nsubv_core, split_type_value, col, CoreBiasPingAddr,
                    gamma_offset, beta_offset, CoreQdqPrmAddr)),
                RelBuffer(DmaChannel(DmaDir.S2MM, UC_chan)),
                RelBuffer(DmaChannel(DmaDir.MM2S, UC_chan)),
            ]

        instrs = load_once(BC_chan, CoreQdqPrmAddr, CoreQdqPrmSize)
        instrs.extend(load_once(Bias_chan, CoreBiasPingAddr, CoreBiasSize[inst][col][row]))
        instrs.append(Loop(outer_count, buffer_configs() + [Loop(inner_count, iteration_body())]))
        if tail_count:
            instrs.append(Loop(tail_count, buffer_configs() + [Loop(1, iteration_body())]))

        return instrs

    def layernorm_core_instrs(
            self,
            inst: int,
            col: int,
            row: int,
            gamma_offset: int,
            beta_offset: int):
        """Return the instruction list for one core of the layer-norm schedule.

        Args:
            inst: Instance index used to look up per-core tables.
            col: Column index within the instance.
            row: Row index used to look up per-core tables.
            gamma_offset: Gamma offset forwarded to the kernel parameters.
            beta_offset: Beta offset forwarded to the kernel parameters.

        Returns:
            For a core whose ``CoreTouter[inst][col][row]`` is non-zero, the
            list built by ``core_inst_inPlace`` (when ``set_core_addr`` reports
            an in-place layout) or by ``core_inst``.  For a core with a zero
            count, only the QDQ-parameter and bias receive sequences
            (configure / acquire / release on S2MM), with no kernel call.
        """
        CoreBiasSize = self.CoreBiasSize
        CoreQdqPrmSize = self.CoreQdqPrmSize

        UC_chan = CoreStreamChan.UC_channel.value
        BC_chan = CoreStreamChan.BC_channel.value
        Bias_chan = UC_chan if (self.BiasStreamType == StreamType.UC) else BC_chan

        # The output size handed to the allocator is doubled for 1-byte outputs.
        size_factor = 2 if self.OutBytes == 1 else 1

        # Buffer allocation accounts for the max core sub-volume in case of
        # non-uniform tiling.
        # NOTE(review): 8192 is presumably the core parameter-region size (the
        # memtile allocator receives CorePrmSize * AieRows in the same slot)
        # -- confirm and consider replacing the literal.
        (
            CoreInPingAddr,
            CoreInPongAddr,
            CoreOutPingAddr,
            CoreOutPongAddr,
            CoreBiasPingAddr,
            CoreQdqPrmAddr,
            is_inPlace,
        ) = set_core_addr(
            np.max(self.CoreActSize).item(),
            np.max(self.CoreBiasSize).item(),
            CoreQdqPrmSize,
            size_factor * np.max(self.CoreOutSize).item(),
            8192,
            config.MAX_CORE_ADDR,
            self.NORM_DEBUG,
        )

        Touter = self.CoreTouter.tolist()

        if not Touter[inst][col][row]:
            # Inactive core: it still has to receive the QDQ parameters and
            # the bias block, but it never runs the kernel.
            return [
                ConfigBuffer(DmaChannel(DmaDir.S2MM, BC_chan), CoreQdqPrmAddr, None, CoreQdqPrmSize),
                AcqBuffer(DmaChannel(DmaDir.S2MM, BC_chan)),
                RelBuffer(DmaChannel(DmaDir.S2MM, BC_chan)),
                ConfigBuffer(DmaChannel(DmaDir.S2MM, Bias_chan), CoreBiasPingAddr, None, CoreBiasSize[inst][col][row]),
                AcqBuffer(DmaChannel(DmaDir.S2MM, Bias_chan)),
                RelBuffer(DmaChannel(DmaDir.S2MM, Bias_chan)),
            ]

        # Active core: the kernel name and the list forms of the size tables
        # are only needed on this path, so they are computed here.
        k_name = next(iter(self.KernelName.keys()))
        build_instrs = self.core_inst_inPlace if is_inPlace else self.core_inst
        return build_instrs(
            k_name, UC_chan, BC_chan, Bias_chan, self.Nlayer, self.Nsubv[inst][col][row], self.split_type.value,
            col, row, Touter, inst, gamma_offset, beta_offset,
            CoreInPingAddr, CoreInPongAddr, CoreOutPingAddr, CoreOutPongAddr, CoreBiasPingAddr, CoreQdqPrmAddr,
            CoreQdqPrmSize, CoreBiasSize, self.CoreActSize.tolist(), self.CoreOutSize.tolist())

class layer_norm(layer_norm_dataflow):
    """``layer_norm_dataflow`` under a shorter name; adds and overrides nothing."""

