import numpy as np
from math import floor, ceil
from pprint import pprint
import warnings
from OGOAT.src.Tiler.overlay import Overlay

class LayerNormTiler:
    """Tiler for LayerNorm: maps a 2-D [M, N] activation onto an AIE overlay.

    The input activation is reshaped with ``layer.reshape_shape`` and split
    over the overlay either by rows (``ROW_SPLIT``) or by columns
    (``COL_SPLIT``); ``ROW_SPLIT_K_GPN`` is recognised but has no tiling
    implementation in this class. :meth:`calculate_array_tilings` computes the
    shim, memtile and core sub-volumes / iteration counts and stores them on
    the instance.
    """

    def __init__(self, layer, device, overlay, kernel):
        """Capture the layer/kernel description and pick the split strategy.

        Args:
            layer: layer description; ``in_act_shape``, ``reshape_shape``,
                ``in_bytes`` and ``out_bytes`` are read.
            device: stored as ``self.device``; not otherwise used here.
            overlay: overlay description with ``rows``/``cols``. NOTE: its
                ``aieinst`` attribute is overwritten by this constructor.
            kernel: kernel description providing ``kernel_list_path``,
                ``SplitType``, ``Mgran``/``Ngran`` (``Mgran_KGPN``/
                ``Ngran_KGPN`` for ``ROW_SPLIT_K_GPN``) and
                ``other_constraints``.
        """
        self.layer = layer
        self.device = device
        self.kernel = kernel
        self.kernel_list_path = self.kernel.kernel_list_path
        self.overlay = overlay

        # Per-schedule placeholder containers; several of them are replaced
        # by calculate_array_tilings() and the check_* methods.
        schedule_list = [1]

        self.memtile_subvols = {k: {} for k in schedule_list}
        self.memtile_sublayers = {k: {} for k in schedule_list}
        self.memtile_iters = {k: {} for k in schedule_list}

        self.fits_in_memtile = {k: {} for k in schedule_list}
        self.valid_fits_in_memtile = {k: {} for k in schedule_list}

        self.valid_memtile_subvols = {k: {} for k in schedule_list}
        self.valid_memtile_sublayers = {k: {} for k in schedule_list}
        self.valid_memtile_iters = {k: {} for k in schedule_list}

        self.core_subvols = {k: {} for k in schedule_list}
        self.core_iters = {k: {} for k in schedule_list}

        self.core_validity_checks = {k: {} for k in schedule_list}
        self.valid_core_subvols = {k: {} for k in schedule_list}
        self.valid_core_iters = {k: {} for k in schedule_list}
        in_act_shape = self.layer.in_act_shape

        # Index 0 is used as M (rows) and index 1 as N (inner dim) below.
        t_dim = layer.reshape_shape(in_act_shape)

        self.tensorshape = {
            'ifm': np.array(t_dim),
            'ofm': np.array(t_dim),
        }

        self.max_ifm_rs_kgpn = 2560  # elements
        self.max_ifm_rs = 4096
        self.max_ifm_cs = 4096
        self.split_type = self.get_split_type(kernel)
        self.enable_mpad = 0  # 1 -> Enable Mpadding in the host

        if self.split_type == 'ROW_SPLIT_K_GPN':
            self.overlay.aieinst = 1  # 4x4->1x4x4  # 4x8->1x4x8
            Mgran = kernel.Mgran_KGPN
            Ngran = kernel.Ngran_KGPN
        elif self.split_type == 'ROW_SPLIT':
            self.overlay.aieinst = 1  # 4x4->1x4x4  # 4x8->1x4x8
            Mgran = kernel.Mgran
            Ngran = kernel.Ngran
        elif self.split_type == 'COL_SPLIT':
            self.overlay.aieinst = overlay.cols // overlay.rows  # 4x4->1x4x4  # 4x8->2x4x4
            Mgran = kernel.Mgran
            Ngran = kernel.Ngran
        else:
            assert False, (f"Invalid split type {self.split_type!r}. Should be one of "
                           "'ROW_SPLIT', 'ROW_SPLIT_K_GPN' or 'COL_SPLIT'")

        self.kernel_granularity = {
            'ifm': np.array([Mgran, Ngran]),
            'ofm': np.array([Mgran, Ngran]),
        }

        self.vars_dict = {
            'ifm_bytes': layer.in_bytes,
            'ofm_bytes': layer.out_bytes,
            ## check other constraints
            'Mgran': self.kernel_granularity['ifm'][0].item(),
            'Ngran': self.kernel_granularity['ifm'][1].item(),
        }

    def get_split_type(self, kernel):
        """Return the split strategy to use for this layer.

        When ``kernel.SplitType`` is ``'OPTIMAL'``, ``'ROW_SPLIT'`` is chosen
        if N fits in ``self.max_ifm_rs // kernel.Mgran`` elements, otherwise
        ``'COL_SPLIT'``. Any other ``SplitType`` is returned unchanged.
        Requires ``self.tensorshape`` and ``self.max_ifm_rs`` to be set.
        """
        if kernel.SplitType != 'OPTIMAL':
            return kernel.SplitType
        N = self.tensorshape['ifm'][1]
        max_n_rs = self.max_ifm_rs // kernel.Mgran
        return 'ROW_SPLIT' if N <= max_n_rs else 'COL_SPLIT'

    def get_nlrn(self, is_Nlayer_word_bound):
        """Return the (possibly padded) inner dimension N used for tiling.

        The effective granule and minimum are the max of this kernel's Ngran
        and, when ``other_constraints['en_prev_op_consistancy']`` is set, the
        previous op's ``Ngran`` / ``Nmin``. N is raised to at least that
        minimum. It is additionally rounded up to a multiple of the granule,
        unless the consistency constraint is off and the row is already
        word (4-byte) aligned.

        Args:
            is_Nlayer_word_bound: truthy when ``N * ifm_bytes`` is a multiple of 4.
        """
        Nlrn_layer = self.tensorshape['ifm'][1].item()
        Ngran = self.kernel_granularity['ifm'][1].item()
        constraints = self.kernel.other_constraints
        en_prev_op_consis = constraints['en_prev_op_consistancy']
        if en_prev_op_consis:
            ngran_prev = constraints['prev_op']['Ngran']
            nmin_prev = constraints['prev_op']['Nmin']
        else:
            ngran_prev = Ngran
            nmin_prev = Ngran

        ngran = max(ngran_prev, Ngran)
        nmin = max(nmin_prev, Ngran)

        if Nlrn_layer < nmin:
            return nmin
        if en_prev_op_consis == 0 and is_Nlayer_word_bound:
            # Word-aligned rows need no padding beyond the minimum.
            return Nlrn_layer
        # Pad N up to the next multiple of the granule.
        if Nlrn_layer % ngran:
            return Nlrn_layer + ngran - (Nlrn_layer % ngran)
        return Nlrn_layer

    def calculate_memtile_tilings(self):
        """No-op: memtile tilings are produced by calculate_array_tilings()."""
        pass

    def check_valid_memtile_tilings(self):
        """Accept every memtile tiling as valid (no filtering is performed)."""
        self.valid_memtile_subvols = self.memtile_subvols
        self.valid_memtile_sublayers = self.memtile_sublayers
        self.valid_memtile_iters = self.memtile_iters

    def check_valid_shimtile_tilings(self):
        """Accept the shim tiling as valid; requires calculate_array_tilings() first."""
        self.valid_shimtile_subvols = self.shimtile_subvols

    def max_factor(self, num, gran, maxlim):
        """Return the largest multiple of ``gran`` that divides ``num`` and is <= ``maxlim``.

        Emits a warning and returns ``None`` when no such value exists.
        """
        num = int(num)
        candidates = [i for i in range(gran, num + 1, gran) if num % i == 0 and i <= maxlim]
        if not candidates:
            warnings.warn("Dimension is too small")
            return None
        return candidates[-1]

    @staticmethod
    def _outer_iterations(shim_tilings, mem_tilings):
        """Return ceil(shim M / memtile M) per (instance, column), 0 where shim M is 0.

        ``out`` is pre-zeroed because ``np.divide(..., where=mask)`` without
        ``out`` leaves the masked-out entries uninitialised.
        """
        shim_m = shim_tilings[:, :, 0]
        mem_m = mem_tilings[:, :, 0]
        ratio = np.divide(shim_m, mem_m,
                          out=np.zeros(shim_m.shape, dtype=float),
                          where=shim_m != 0)
        return np.ceil(ratio).astype(int)

    def row_split_array_tilings(self, max_ifm_samp: int):
        """Tile by rows: each AIE column gets a slab of rows with the full N.

        Args:
            max_ifm_samp: upper bound on the layer's N (asserted).

        Returns:
            dict with keys 'core_sv', 'core_itr', 'mem_sv_ifm', 'mem_sv_ofm',
            'mem_itr', 'shim_sv'. The arrays are also stored on ``self``
            (ShimTilings, CoreTilings, MemTilings*, MemTouter, CoreTouter).
        """
        Mgran = self.kernel_granularity['ifm'][0].item()
        Ngran = self.kernel_granularity['ifm'][1].item()
        Mlrn_layer = self.tensorshape['ifm'][0].item()
        Nlrn_layer = self.tensorshape['ifm'][1].item()
        ifm_bytes = self.vars_dict['ifm_bytes']
        aie_inst = self.overlay.aieinst
        aie_cols = self.overlay.cols // aie_inst
        aie_rows = self.overlay.rows
        # Rows one column handles per memtile iteration (memtile M is capped to it).
        m_col_gran = Mgran * aie_rows

        assert(Nlrn_layer <= max_ifm_samp)
        is_Nlayer_word_bound = 0 if ((Nlrn_layer * ifm_bytes) % 4) else 1

        Nlrn = self.get_nlrn(is_Nlayer_word_bound)

        # Optional host-side M padding (only with enable_mpad): round M up to
        # the next multiple of Mgran*aie_rows. An exact multiple is left as
        # is; the old formula added a whole extra granule in that case.
        Mlrn_min = Mgran * aie_cols * aie_rows
        if Mlrn_layer % Mlrn_min and Mlrn_layer > Mlrn_min and self.enable_mpad:
            Mlrn = Mlrn_layer + (-Mlrn_layer) % m_col_gran
        else:
            Mlrn = Mlrn_layer

        # ---- Shim tilings: rows spread over columns in units of m_col_gran;
        # every column carries the full (padded) N.
        mshim_per_col = np.array(
            self.split_number_with_gran(Mlrn, aie_cols, granularity=m_col_gran, min_value=m_col_gran),
            ndmin=2)
        nshim_per_col = np.array([Nlrn] * aie_cols, ndmin=2)
        shim_sv_pcol = np.concatenate((mshim_per_col, nshim_per_col), axis=0).T
        shim_sv_pinst_pcol = shim_sv_pcol.reshape((aie_inst, aie_cols, 2))
        self.ShimTilings = shim_sv_pinst_pcol  # (aie_inst, aie_cols, [M, N])

        # ---- Core tilings: N padded to a multiple of Ngran.
        if Nlrn % Ngran:
            Nlrn_core = Nlrn + Ngran - (Nlrn % Ngran)
        else:
            Nlrn_core = Nlrn
        assert(Nlrn_core % Ngran == 0)

        # Cores receiving rows process Mgran rows per iteration; others get 0.
        # Instance 0 is indexed directly (row split uses a single instance) so
        # a one-column overlay still yields a 1-D vector of column totals.
        core_tiling_Msubv = np.zeros((1, aie_cols, aie_rows))
        core_tiling_Msubv_tmp = self.split_number_with_gran(
            self.ShimTilings[0, :, 0], aie_rows, Mgran, Mgran, force_granularity=True)
        core_tiling_Msubv[0, :, :] = np.where(core_tiling_Msubv_tmp, Mgran, core_tiling_Msubv_tmp)
        core_tiling_Msubv = core_tiling_Msubv.reshape(
            (aie_inst, aie_cols, aie_rows, -1), order='F').astype(int)
        core_tiling_Nsubv = np.array([Nlrn_core]).repeat(aie_inst * aie_cols * aie_rows, axis=-1)
        core_tiling_Nsubv = core_tiling_Nsubv.reshape((aie_inst, aie_cols, aie_rows, -1), order='F')
        self.CoreTilings = np.concatenate((core_tiling_Msubv, core_tiling_Nsubv), axis=-1)

        # ---- Memtile IFM S2MM / OFM MM2S: shim slab with M capped to m_col_gran;
        # N is the core N when the row is not word aligned.
        mem_tilings_tmp = self.ShimTilings.copy()
        mem_tilings_tmp[:, :, 0] = np.minimum(mem_tilings_tmp[:, :, 0], m_col_gran)
        if not is_Nlayer_word_bound:
            mem_tilings_tmp[:, :, 1] = Nlrn_core
        self.MemTilingsIfmS2MM = mem_tilings_tmp.copy()
        self.MemTilingsOfmMM2S = mem_tilings_tmp.copy()

        # ---- Memtile IFM MM2S / OFM S2MM mirror the core tilings.
        self.MemTilingsIfmMM2S = self.CoreTilings.copy()
        self.MemTilingsOfmS2MM = self.CoreTilings.copy()

        self.MemTouter = self._outer_iterations(self.ShimTilings, self.MemTilingsIfmS2MM)
        self.CoreTouter = self.MemTouter  # (aie_inst, aie_cols)

        return {
            'core_sv': self.CoreTilings,
            'core_itr': self.CoreTouter,
            'mem_sv_ifm': {'s2mm': self.MemTilingsIfmS2MM, 'mm2s': self.MemTilingsIfmMM2S},
            'mem_sv_ofm': {'s2mm': self.MemTilingsOfmS2MM, 'mm2s': self.MemTilingsOfmMM2S},
            'mem_itr': self.MemTouter,
            'shim_sv': shim_sv_pinst_pcol,
        }

    def col_split_array_tilings(self):
        """Tile by columns: M is split across instances, N across the columns.

        Returns:
            dict with the same keys as :meth:`row_split_array_tilings`; the
            arrays are also stored on ``self``.
        """
        Mgran = self.kernel_granularity['ifm'][0].item()
        Ngran = self.kernel_granularity['ifm'][1].item()
        Mlrn_layer = self.tensorshape['ifm'][0].item()
        Nlrn_layer = self.tensorshape['ifm'][1].item()
        ifm_bytes = self.vars_dict['ifm_bytes']
        aie_inst = self.overlay.aieinst
        aie_cols = self.overlay.cols // aie_inst
        aie_rows = self.overlay.rows
        # Optional host-side M padding to a multiple of Mgran.
        if Mlrn_layer % Mgran and Mlrn_layer > Mgran and self.enable_mpad:
            Mlrn = Mlrn_layer + Mgran - (Mlrn_layer % Mgran)
        else:
            Mlrn = Mlrn_layer
        is_Nlayer_word_bound = 0 if ((Nlrn_layer * ifm_bytes) % 4) else 1
        Nlrn = self.get_nlrn(is_Nlayer_word_bound)

        # ---- Shim tilings
        if Mlrn_layer < Mgran:
            # TODO: Replace with the best distribution algorithm
            mshim_per_inst = np.array([Mlrn_layer] + [0] * (aie_inst - 1), ndmin=2)
        else:
            mshim_per_inst = np.array(
                self.split_number_with_gran(Mlrn, aie_inst, granularity=Mgran, min_value=Mgran),
                ndmin=2)
        Mlrn_shim_tmp = mshim_per_inst.repeat(aie_cols, 1)
        # Small min_value keeps the N distribution as uniform as possible.
        nshim_per_col = self.split_number_with_gran(Nlrn, aie_cols, granularity=Ngran, min_value=2)
        Nlrn_shim_tmp = np.array(nshim_per_col * aie_inst, ndmin=2)  # same N split per instance

        shim_sv_pcol = np.concatenate((Mlrn_shim_tmp, Nlrn_shim_tmp), axis=0).T
        shim_sv_pinst_pcol = shim_sv_pcol.reshape((aie_inst, aie_cols, 2))
        self.ShimTilings = shim_sv_pinst_pcol  # (aie_inst, aie_cols, [M, N])

        # ---- Core tilings: each column's N split over its rows; M is Mgran
        # for instances that received rows, 0 otherwise.
        Nsubv = self.split_number_with_gran(self.ShimTilings[0, :, 1], aie_rows, Ngran, force_granularity=True)
        core_tiling_Msubv = np.where(self.ShimTilings[:, :, [0]] > 0, Mgran, 0).repeat(aie_rows, axis=-1)
        core_tiling_Msubv = core_tiling_Msubv.reshape((aie_inst, aie_cols, aie_rows, -1), order='F')
        core_tiling_Nsubv = np.array([Nsubv]).repeat(aie_inst, axis=0)
        core_tiling_Nsubv = core_tiling_Nsubv.reshape((aie_inst, aie_cols, aie_rows, -1), order='F')
        self.CoreTilings = np.concatenate((core_tiling_Msubv, core_tiling_Nsubv), axis=-1)

        # ---- Memtile IFM S2MM / OFM MM2S: shim slab with M capped to Mgran.
        mem_tilings_tmp = self.ShimTilings.copy()
        mem_tilings_tmp[:, :, 0] = np.minimum(mem_tilings_tmp[:, :, 0], Mgran)
        self.MemTilingsIfmS2MM = mem_tilings_tmp.copy()
        self.MemTilingsOfmMM2S = mem_tilings_tmp.copy()

        # ---- Memtile IFM MM2S / OFM S2MM mirror the core tilings.
        self.MemTilingsIfmMM2S = self.CoreTilings.copy()
        self.MemTilingsOfmS2MM = self.CoreTilings.copy()

        self.MemTouter = self._outer_iterations(self.ShimTilings, self.MemTilingsIfmS2MM)
        self.CoreTouter = self.MemTouter  # (aie_inst, aie_cols)

        return {
            'core_sv': self.CoreTilings,
            'core_itr': self.CoreTouter,
            'mem_sv_ifm': {'s2mm': self.MemTilingsIfmS2MM, 'mm2s': self.MemTilingsIfmMM2S},
            'mem_sv_ofm': {'s2mm': self.MemTilingsOfmS2MM, 'mm2s': self.MemTilingsOfmMM2S},
            'mem_itr': self.MemTouter,
            'shim_sv': shim_sv_pinst_pcol,
        }

    def split_number_with_gran(self, N, M, granularity=8, min_value = 8, force_granularity=False):
        """Split each total in ``N`` into ``M`` non-negative parts.

        At most ``ceil(total / min_value)`` parts are non-zero (all ``M`` when
        ``min_value`` is 0). Each of them starts at
        ``max(floor_to_granularity(total / parts), min_value)``. The remainder
        is then handed out in steps of ``granularity`` from the first part
        onward; a negative remainder is taken from the first part. The parts
        add up to the total. With ``force_granularity`` the non-zero parts are
        raised to ``min_value`` and rounded up to a multiple of
        ``granularity``, so the sum may then exceed the total. Every result is
        sorted in descending order.

        Args:
            N: an int (Python or NumPy integer), a list of ints, or a 1-D
                integer ndarray of totals.
            M: number of parts per total (> 0).

        Returns:
            A list for a scalar ``N``, a list of lists for a list, and a 2-D
            ndarray for an ndarray input.

        Raises:
            ValueError: on an unsupported ``N`` type or an invalid ``M``,
                ``granularity`` or ``min_value``.
        """
        if isinstance(N, (int, np.integer)):
            N_list = [int(N)]
            input_type = 'int'
        elif isinstance(N, list):
            N_list = N
            input_type = 'list'
        elif isinstance(N, np.ndarray):
            N_list = N.tolist()
            input_type = 'numpy'
        else:
            raise ValueError("N must be either an integer, a list of integers, or a numpy array of integers")

        if M <= 0:
            raise ValueError("Number of elements (M) must be greater than 0")
        if granularity <= 0:
            raise ValueError("Granularity must be greater than 0")
        if min_value < 0:
            raise ValueError("Minimum value must be non-negative")

        results = []

        for total in N_list:
            # Number of parts that can each receive at least min_value.
            if min_value == 0:
                max_elements = M  # every part trivially meets a zero minimum
            else:
                max_elements = min(ceil(total / min_value), M)

            if max_elements == 0:
                results.append([0] * M)
                continue

            base_value = max((total // max_elements // granularity) * granularity, min_value)
            remainder = total - base_value * max_elements

            result = [base_value] * max_elements + [0] * (M - max_elements)

            for i in range(max_elements):
                if remainder >= granularity:
                    result[i] += granularity
                    remainder -= granularity
                else:
                    result[i] += remainder
                    remainder = 0

            if force_granularity:
                for i in range(max_elements):
                    if result[i] < min_value:
                        result[i] = min_value
                    if result[i] % granularity != 0:
                        result[i] = ((result[i] // granularity) + 1) * granularity
            results.append(sorted(result, reverse=True))

        if input_type == 'int':
            return results[0]
        elif input_type == 'list':
            return results
        elif input_type == 'numpy':
            return np.array(results)

    def calculate_array_tilings(self):
        """Run the tiler for ``self.split_type`` and publish the results.

        Sets ``ktype`` and the per-tensor dicts ``core_subvols``,
        ``core_iters``, ``memtile_subvols``, ``memtile_iters``,
        ``memtile_sublayers`` and ``shimtile_subvols``. 'ifm' and 'ofm' share
        the same arrays.

        Raises:
            NotImplementedError: for ``ROW_SPLIT_K_GPN`` when no
                ``row_split_kgpn_array_tilings`` method is available.
        """
        if self.split_type == 'ROW_SPLIT_K_GPN':
            kgpn_tiler = getattr(self, 'row_split_kgpn_array_tilings', None)
            if kgpn_tiler is None:
                raise NotImplementedError(
                    "ROW_SPLIT_K_GPN tiling (row_split_kgpn_array_tilings) is not implemented in LayerNormTiler")
            tiling = kgpn_tiler(self.max_ifm_rs_kgpn)
            self.ktype = 'K_GPN'
        elif self.split_type == 'ROW_SPLIT':
            tiling = self.row_split_array_tilings(self.max_ifm_rs)
            self.ktype = 'K_LRN'
        else:
            tiling = self.col_split_array_tilings()
            self.ktype = 'K_LRN'

        self.core_subvols = {'ifm': tiling['core_sv'], 'ofm': tiling['core_sv']}
        self.core_iters = {'ifm': tiling['core_itr'], 'ofm': tiling['core_itr']}
        self.memtile_subvols = {'ifm': tiling['mem_sv_ifm'], 'ofm': tiling['mem_sv_ofm']}
        self.memtile_iters = {'ifm': tiling['mem_itr'], 'ofm': tiling['mem_itr']}
        self.memtile_sublayers = {
            'ifm': self.memtile_subvols['ifm'],
            'ofm': self.memtile_subvols['ofm'],
        }
        self.shimtile_subvols = {'ifm': tiling['shim_sv'], 'ofm': tiling['shim_sv']}

    def check_core_constraints(self):
        """Warn if any core sub-volume is below the kernel granularity, then accept it.

        NOTE(review): cores that receive no rows have M == 0 and therefore
        also trigger this warning; confirm whether idle cores should be
        excluded.
        """
        diff_sv = self.core_subvols['ifm'] - self.kernel_granularity['ifm']
        if np.any(diff_sv < 0):
            warnings.warn("Core Sub volume is voilating the minimum granularity.")
        self.valid_core_subvols = self.core_subvols
        self.valid_core_iters = self.core_iters

if __name__ == '__main__':
    # Standalone smoke run: build a LayerNormTiler for the 'tst' layer in a
    # local JSON fixture and run the tiling pipeline end to end.
    overlay = Overlay('4x4', 'LayerNormalization', 'N16')

    import json
    from layer import Layer

    with open('src/Tiler/tst_layernorm.json') as fixture:
        fixture_dict = json.load(fixture)
    layer_desc = fixture_dict['tst']
    # Force both activations to L3 residency for this standalone run.
    layer_desc['in_act_residency'] = 'L3'
    layer_desc['out_act_residency'] = 'L3'

    layer_obj = Layer(layer_desc)

    from kernel import Kernel
    from device import Device

    device_obj = Device('strix')
    kernel_obj = Kernel(layer_obj)

    tiler = LayerNormTiler(layer_obj, device_obj, overlay, kernel_obj)
    for step in (tiler.calculate_memtile_tilings,
                 tiler.check_valid_memtile_tilings,
                 tiler.calculate_array_tilings,
                 tiler.check_core_constraints):
        step()