import numpy as np

from OGOAT.src.Tiler.utils import compute_inverted_placement, factors
from dmacompiler import set_dev_gen, DevGen, config

# Module-level side effect: importing this file calls dmacompiler.set_dev_gen
# with DevGen.Aie2p.
# NOTE(review): presumably this selects the device generation that determines
# the `config` limits read further down in this file - confirm in dmacompiler.
set_dev_gen(DevGen.Aie2p)

class LUTTiler:
    """Enumerate and filter memtile-level and core-level tilings for one layer.

    Both operands ('ifm' and 'ofm') are described by a single dimension whose
    granularity is ``kernel.Ngran``.  The methods are meant to be called in
    this order, each one consuming the attributes written by the previous one:

    1. ``calculate_memtile_tilings``
    2. ``check_valid_memtile_tilings``
    3. ``calculate_array_tilings``
    4. ``check_core_constraints``

    All result attributes are dictionaries keyed by schedule id.  Only
    schedule 5 is implemented.
    """

    def __init__(self, layer, device, overlay, kernel):
        """Store the inputs and derive padded shapes and placement formulas.

        Args:
            layer: object providing ``in_act_shape``, ``out_act_shape``,
                ``in_bytes`` and ``out_bytes``.
            device: object providing ``memtile_capacity``, ``memtile_rows``,
                ``core_data_memory`` and ``core_num_banks``.
            overlay: object providing ``core_splits`` and ``mem_splits``,
                both indexable by operand name ('ifm' / 'ofm').
            kernel: object providing ``Ngran`` and ``placement_constraints``.
        """
        self.layer = layer
        self.overlay = overlay
        self.device = device
        self.kernel = kernel

        # Both operands share the same single-dimension granularity.
        self.kernel_granularity = {
            'ifm': np.array([kernel.Ngran]),
            'ofm': np.array([kernel.Ngran]),
        }

        schedule_list = [5]

        # Memtile-level candidates and their validity (filled by
        # calculate_memtile_tilings / check_valid_memtile_tilings).
        self.memtile_subvols = {k: {} for k in schedule_list}
        self.memtile_sublayers = {k: {} for k in schedule_list}
        self.memtile_iters = {k: {} for k in schedule_list}

        self.fits_in_memtile = {k: {} for k in schedule_list}
        self.valid_fits_in_memtile = {k: {} for k in schedule_list}

        self.valid_memtile_subvols = {k: {} for k in schedule_list}
        self.valid_memtile_sublayers = {k: {} for k in schedule_list}
        self.valid_memtile_iters = {k: {} for k in schedule_list}

        # Core-level candidates and their validity (filled by
        # calculate_array_tilings / check_core_constraints).
        self.core_subvols = {k: {} for k in schedule_list}
        self.core_iters = {k: {} for k in schedule_list}

        self.core_validity_checks = {k: {} for k in schedule_list}
        self.valid_core_subvols = {k: {} for k in schedule_list}
        self.valid_core_iters = {k: {} for k in schedule_list}

        # Handle padding: round each activation shape up to the next multiple
        # of (kernel granularity * core split).
        self.padded_shapes = {
            'ifm': np.ceil(self.layer.in_act_shape / self.kernel_granularity['ifm'] / self.overlay.core_splits['ifm'])
            * self.kernel_granularity['ifm'] * self.overlay.core_splits['ifm'],
            'ofm': np.ceil(self.layer.out_act_shape / self.kernel_granularity['ofm'] / self.overlay.core_splits['ofm'])
            * self.kernel_granularity['ofm'] * self.overlay.core_splits['ofm'],
        }

        # Amount added by the rounding above, per operand.
        self.padding = {
            'ifm': self.padded_shapes['ifm'] - self.layer.in_act_shape,
            'ofm': self.padded_shapes['ofm'] - self.layer.out_act_shape,
        }

        # Names made available to the placement formulas evaluated in
        # check_core_constraints ('Nsubv' is added there).
        self.vars_dict = {
            'ifm_bytes': layer.in_bytes,
            'ofm_bytes': layer.out_bytes,
            ## check other constraints
            'Ngran': self.kernel.Ngran,
        }

        ## calculate bank-wise space formulas
        # NOTE(review): check_core_constraints treats this as a mapping of
        # bank -> expression that can be passed to eval(); confirm against
        # compute_inverted_placement.
        self.inverted_placement = compute_inverted_placement(self.kernel.placement_constraints)

    def calculate_memtile_tilings(self):
        """Enumerate candidate memtile subvolume sizes for schedule 5.

        Writes ``self.memtile_subvols[5]``, ``self.memtile_sublayers[5]`` and
        ``self.memtile_iters[5]``.  The candidates come from
        ``factors(Nmax, Nmin)`` applied to the 'ifm' operand and are reused
        unchanged for 'ofm'.
        """
        memtile_max_shapes = {}
        memtile_min_shapes = {}
        for operand in self.overlay.mem_splits.keys():
            memtile_max_shapes[operand] = self.padded_shapes[operand] // self.overlay.mem_splits[operand]
            # captures how many core subvolumes are stored in each memtile
            split_ratio = self.overlay.core_splits[operand] / self.overlay.mem_splits[operand]  ## Can be non-integer
            memtile_min_shapes[operand] = self.kernel_granularity[operand] * split_ratio

        Nmin, Nmax = memtile_min_shapes['ifm'][0], memtile_max_shapes['ifm'][0]

        # One candidate per row (column vector).
        # NOTE(review): assumes factors() returns a flat sequence of sizes
        # between Nmin and Nmax - confirm against OGOAT.src.Tiler.utils.
        schedule_tilings = {5: np.array(factors(Nmax, Nmin))[:, np.newaxis]}

        for k, tmp in schedule_tilings.items():
            # Column 0: ifm iterations, column 1: ofm iterations.
            self.memtile_iters[k] = np.hstack([memtile_max_shapes['ifm'] // tmp, memtile_max_shapes['ofm'] // tmp])

            self.memtile_subvols[k] = {
                'ifm': tmp,
                'ofm': tmp,
            }

            # Sublayer = subvolume scaled by the number of memtile splits.
            self.memtile_sublayers[k] = {
                'ifm': tmp * self.overlay.mem_splits['ifm'],
                'ofm': tmp * self.overlay.mem_splits['ofm'],
            }

    def check_valid_memtile_tilings(self):
        """Keep the memtile candidates that fit in memory and in the iteration limit.

        A candidate is valid when two ifm buffers plus two ofm buffers need
        strictly less than the memtile capacity and every iteration count is
        strictly below ``config.MAX_REPEAT_COUNT * config.MAX_TASK_QUEUE_SIZE``.

        Replaces ``self.fits_in_memtile``, ``self.valid_fits_in_memtile``,
        ``self.valid_memtile_subvols``, ``self.valid_memtile_sublayers`` and
        ``self.valid_memtile_iters``.

        Raises:
            NotImplementedError: if a schedule other than 5 is present.
        """
        # NOTE(review): the factor 1024 assumes device.memtile_capacity is
        # expressed in KiB per row - confirm against Device.
        memtile_capacity = self.device.memtile_capacity * self.device.memtile_rows * 1024  # In bytes
        memtile_iters_max = config.MAX_REPEAT_COUNT * config.MAX_TASK_QUEUE_SIZE

        fits_within_memtiles = {}
        valid_subvolumes = {}
        valid_memtile_sublayers = {}
        valid_fits_within_memtiles = {}
        valid_memtile_iters = {}
        for sched, memtile_subvols in self.memtile_subvols.items():
            if sched != 5:
                raise NotImplementedError(f"LUTTiler does not implement schedule {sched}")

            # ##############
            # Schedule 5: two buffers of ifm and two buffers of ofm
            # ##############
            ifm_subvol_bytes = np.prod(memtile_subvols['ifm'], 1) * self.layer.in_bytes
            ofm_subvol_bytes = np.prod(memtile_subvols['ofm'], 1) * self.layer.out_bytes
            space_required_2buff = ifm_subvol_bytes * 2 + ofm_subvol_bytes * 2
            total_space_required = space_required_2buff.reshape((len(space_required_2buff), 1))

            memtile_iter_overflow_check = np.all((self.memtile_iters[sched] < memtile_iters_max), axis=1, keepdims=True)
            memtile_space_overflow_check = total_space_required < memtile_capacity
            valid_mask = memtile_iter_overflow_check & memtile_space_overflow_check

            # Row selector shared by every filtered output below.
            row_is_valid = np.any(valid_mask, 1)

            fits_within_memtiles[sched] = valid_mask
            valid_fits_within_memtiles[sched] = valid_mask[row_is_valid]
            valid_subvolumes[sched] = {
                operand: tiling[row_is_valid] for operand, tiling in memtile_subvols.items()
            }
            valid_memtile_sublayers[sched] = {
                operand: tiling * self.overlay.mem_splits[operand]
                for operand, tiling in valid_subvolumes[sched].items()
            }
            valid_memtile_iters[sched] = self.memtile_iters[sched][row_is_valid]

        self.fits_in_memtile = fits_within_memtiles
        self.valid_fits_in_memtile = valid_fits_within_memtiles

        self.valid_memtile_subvols = valid_subvolumes
        self.valid_memtile_sublayers = valid_memtile_sublayers
        self.valid_memtile_iters = valid_memtile_iters

    def calculate_array_tilings(self):
        """Derive the core subvolume for every valid memtile sublayer.

        For schedule 5 there is a single candidate per memtile sublayer: the
        ofm sublayer divided by the ofm core split, used for both operands.
        Writes ``self.core_subvols[sched]`` and ``self.core_iters[sched]``,
        each keyed by the row index of the valid memtile sublayer.

        Raises:
            NotImplementedError: if a schedule other than 5 is present.
        """
        for sched, valid_memtile_sublayers in self.valid_memtile_sublayers.items():
            if sched != 5:
                raise NotImplementedError(f"LUTTiler does not implement schedule {sched}")

            # Largest per-core share of each valid memtile sublayer.
            core_max_shapes = {
                operand: sublayer // self.overlay.core_splits[operand]
                for operand, sublayer in valid_memtile_sublayers.items()
            }

            core_subvols = {}
            core_iters = {}

            for k, ofm_max in enumerate(core_max_shapes['ofm']):
                # ##############
                # Schedule 5: the whole per-core ofm share is the only candidate
                # ##############
                tmp = np.array([ofm_max[0]])

                core_iters[k] = core_max_shapes['ifm'][k][0] // tmp
                core_subvols[k] = {
                    'ifm': tmp,
                    'ofm': tmp,
                }

            self.core_subvols[sched] = core_subvols
            self.core_iters[sched] = core_iters

    def check_core_constraints(self):
        """Filter core subvolumes by the per-bank buffer placement formulas.

        For every core subvolume the formulas in ``self.inverted_placement``
        are evaluated with ``self.vars_dict`` (plus ``Nsubv``) and each result
        must not exceed the capacity of one core memory bank.

        Replaces ``self.core_validity_checks[sched]``,
        ``self.valid_core_subvols[sched]`` and ``self.valid_core_iters[sched]``
        for every schedule in ``self.core_subvols``, so results from an
        earlier run do not linger.  ``self.vars_dict['Nsubv']`` is left set to
        the last subvolume that was checked.
        """
        # NOTE(review): the factor 1024 assumes device.core_data_memory is
        # expressed in KiB - confirm against Device.
        core_bank_capacity = 1024 * (self.device.core_data_memory // self.device.core_num_banks)  # bytes

        for sched, core_subvol_dict in self.core_subvols.items():
            checks_by_subvol = {}
            valid_subvols_by_id = {}
            valid_iters_by_id = {}

            for mem_sub_id, subvols in core_subvol_dict.items():

                self.vars_dict['Nsubv'] = subvols['ifm'][0]

                # eval() inserts '__builtins__' into the globals mapping it is
                # given, so hand it a throwaway copy instead of self.vars_dict.
                # SECURITY: the formulas are executed as Python code; they must
                # only ever come from trusted kernel metadata.
                eval_scope = dict(self.vars_dict)

                ## check buffer placements
                space_required = {}
                valid_placement = True
                for bank, formula in self.inverted_placement.items():
                    space_required[bank] = eval(formula, eval_scope)
                    valid_placement = valid_placement & (space_required[bank] <= core_bank_capacity)

                validity_checks = {'buffer_placement': valid_placement}

                # subvols that satisfy all constraints
                valid_subvols = np.all(np.vstack(list(validity_checks.values())).T, 1)

                checks_by_subvol[mem_sub_id] = validity_checks
                valid_subvols_by_id[mem_sub_id] = {
                    'ifm': subvols['ifm'][valid_subvols],
                    'ofm': subvols['ofm'][valid_subvols],
                }
                valid_iters_by_id[mem_sub_id] = self.core_iters[sched][mem_sub_id][valid_subvols]

            self.core_validity_checks[sched] = checks_by_subvol
            self.valid_core_subvols[sched] = valid_subvols_by_id
            self.valid_core_iters[sched] = valid_iters_by_id

if __name__=='__main__':
    # Ad-hoc smoke run: build an overlay, layer, device and kernel, then drive
    # the four tiling stages of LUTTiler in order.
    from overlay import Overlay
    ovl = Overlay('8x4', 'Silu', 'N32')

    # Rescale every split from a base of 8 to `num_cols`.
    # NOTE(review): presumably 8 is the column count of the '8x4' overlay - confirm.
    num_cols = 7
    ovl.cols = num_cols
    for split_table in (ovl.core_splits, ovl.mem_splits):
        for operand in ('ifm', 'ofm'):
            split_table[operand] = split_table[operand] // 8 * num_cols

    import json
    from layer import Layer
    with open('OGOAT/src/Tiler/tst_silugelu.json') as f:
        model_dict = json.load(f)

    # Use the 'tst' entry and force both activations to reside in 'L3'.
    layer_desc = model_dict['tst']
    layer_desc.update(in_act_residency='L3', out_act_residency='L3')

    lyr = Layer(layer_desc)

    from kernel import Kernel
    from device import Device
    dev = Device('strix')
    kern = Kernel(lyr)

    tiler = LUTTiler(lyr, dev, ovl, kern)
    # Each stage consumes the attributes written by the previous one.
    for stage in (
        tiler.calculate_memtile_tilings,
        tiler.check_valid_memtile_tilings,
        tiler.calculate_array_tilings,
        tiler.check_core_constraints,
    ):
        stage()
