import numpy as np

from OGOAT.src.Tiler.utils import (
    compute_inverted_placement, 
    factors
)

class ElemWiseOpsTiler:
    """Tile an element-wise / broadcast binary op across memtiles and cores.

    Operand naming used throughout:

    * ``ifmA`` -- the layer's weight input (``in_wgt_shape``, ``wgt_bytes``);
      for broadcast ops this is assumed to be the broadcast tensor.
    * ``ifmB`` -- the layer's activation input (``in_act_shape``, ``in_bytes``).
    * ``ofm``  -- the output (``out_act_shape``, ``out_bytes``).

    Every operand is handled as a flat element count (shapes are reduced with
    ``np.prod``). The stages are meant to be called in order::

        tiler.calculate_memtile_tilings()
        tiler.check_valid_memtile_tilings()
        tiler.calculate_array_tilings()
        tiler.check_core_constraints()
    """

    def __init__(self, layer, device, overlay, kernel):
        """Store the inputs, compute padded operand sizes and init result dicts.

        Args:
            layer: Layer description (shapes, byte widths, ``attributes``,
                ``op_type``).
            device: Device description (memtile / core memory sizes).
            overlay: Overlay description providing per-operand
                ``core_splits`` and ``mem_splits``.
            kernel: Kernel description providing ``Ngran`` and
                ``placement_constraints``.
        """
        self.layer = layer
        self.overlay = overlay
        self.device = device
        self.kernel = kernel

        # Per-operand kernel granularity (in elements); padded sizes below are
        # rounded up to multiples of granularity * core_splits.
        self.kernel_granularity = {
            'ifmA': np.array([kernel.Ngran]),
            'ifmB': np.array([kernel.Ngran]),
            'ofm': np.array([kernel.Ngran]),
        }

        # Schedule ids. Schedule k uses the k-th memtile tiling candidate.
        #   Schedule 2: ifmA pinned, ifmB/ofm streamed.
        #   Schedule 5: ifmA/ifmB/ofm streamed.
        self.schedule_list = [0, 1, 2, 3, 4, 5]

        # Per-schedule results, keyed by schedule id, filled by the stages:
        #   calculate_memtile_tilings   -> memtile_subvols/sublayers/iters
        #   check_valid_memtile_tilings -> fits_in_memtile, valid_memtile_*
        #   calculate_array_tilings     -> core_subvols, core_iters
        #   check_core_constraints      -> core_validity_checks, valid_core_*
        self.memtile_subvols = {k: {} for k in self.schedule_list}
        self.memtile_sublayers = {k: {} for k in self.schedule_list}
        self.memtile_iters = {k: {} for k in self.schedule_list}

        self.fits_in_memtile = {k: {} for k in self.schedule_list}
        self.valid_fits_in_memtile = {k: {} for k in self.schedule_list}

        self.valid_memtile_subvols = {k: {} for k in self.schedule_list}
        self.valid_memtile_sublayers = {k: {} for k in self.schedule_list}
        self.valid_memtile_iters = {k: {} for k in self.schedule_list}

        self.core_subvols = {k: {} for k in self.schedule_list}
        self.core_iters = {k: {} for k in self.schedule_list}

        self.core_validity_checks = {k: {} for k in self.schedule_list}
        self.valid_core_subvols = {k: {} for k in self.schedule_list}
        self.valid_core_iters = {k: {} for k in self.schedule_list}

        self.numBatches = None
        if layer.attributes is not None:
            self.numBatches = layer.attributes.get('num_batches', None)

        # Use identity checks: `== None` would misbehave (ambiguous truth value)
        # if num_batches were ever a numpy array.
        if self.numBatches is None:
            assert self.layer.orig_act_shape == self.layer.orig_wgt_shape == self.layer.orig_ofm_shape, f"Mismatch in the dim, all shapes should be exactly same, {self.layer.orig_act_shape}, {self.layer.orig_wgt_shape}, {self.layer.orig_ofm_shape}"

        in_act_shape = self.layer.in_act_shape.copy()
        if self.numBatches is not None:
            # Reduce the leading activation dim to a per-batch size.
            in_act_shape[0] = self.layer.in_act_shape[0] // self.numBatches[0]

        # NOTE: ifmA -> wgt (for broadcast ops the broadcast tensor is assumed
        # to be ifmA; no impact for plain element-wise ops).
        #       ifmB -> act (the other operand for both op kinds).
        ifmA_padded = self._pad_to_core_granularity(np.prod(self.layer.in_wgt_shape), 'ifmA')
        ifmB_padded = self._pad_to_core_granularity(np.prod(in_act_shape), 'ifmB')
        # ifmB and ofm are additionally padded to a multiple of the padded
        # ifmA size.
        # NOTE(review): ofm is derived from the (per-batch) ifmB size, not from
        # out_act_shape, so padding['ofm'] below may be negative when
        # num_batches is set -- confirm this is intended.
        self.padded_shapes = {
            'ifmA': ifmA_padded,
            'ifmB': self.pad_to_make_perfect_divisor(ifmB_padded, ifmA_padded),
            'ofm': self.pad_to_make_perfect_divisor(ifmB_padded, ifmA_padded),
        }

        # Elements of padding added to each operand's flattened size.
        self.padding = {
            'ifmA': self.padded_shapes['ifmA'] - np.prod(self.layer.in_wgt_shape),
            'ifmB': self.padded_shapes['ifmB'] - np.prod(in_act_shape),
            'ofm': self.padded_shapes['ofm'] - np.prod(self.layer.out_act_shape),
        }

        # Unpadded host-side operand shapes as int lists (ifmB uses the full,
        # non-per-batch in_act_shape).
        self.host_padding = {
            'ifmA': np.array(self.layer.in_wgt_shape).astype(np.int32).tolist(),
            'ifmB': np.array(self.layer.in_act_shape).astype(np.int32).tolist(),
            'ofm': np.array(self.layer.out_act_shape).astype(np.int32).tolist(),
        }

        # Variables visible to the kernel's placement formulas (see
        # check_core_constraints).
        self.vars_dict = {
            'ifmA_bytes': layer.wgt_bytes,
            'ifmB_bytes': layer.in_bytes,
            'ofm_bytes': layer.out_bytes,
            'Ngran': self.kernel.Ngran,
        }

        # Bank -> space-required formula, derived from the kernel constraints.
        self.inverted_placement = compute_inverted_placement(self.kernel.placement_constraints)

    def _pad_to_core_granularity(self, numel, operand):
        """Round ``numel`` up to a multiple of granularity * core_splits.

        Args:
            numel: Flattened element count of the operand.
            operand: One of ``'ifmA'``, ``'ifmB'``, ``'ofm'``.

        Returns:
            The padded element count as a single-element float array.
        """
        granularity = self.kernel_granularity[operand]
        splits = self.overlay.core_splits[operand]
        return np.ceil(numel / granularity / splits) * granularity * splits

    def pad_to_make_perfect_divisor(self, input_shape, aligned_to_shape):
        """Round ``input_shape`` up to the next multiple of ``aligned_to_shape``.

        Used so that the padded ifmB/ofm sizes are integer multiples of the
        padded ifmA size, i.e. ifmA divides them evenly.

        Both arguments are expected to be scalars or single-element arrays
        (the remainder is tested in a boolean context).

        Returns:
            ``input_shape`` unchanged if already aligned, else the padded value.
        """
        unaligned = input_shape % aligned_to_shape
        if unaligned == 0:
            return input_shape
        padding_needed = aligned_to_shape - unaligned
        return input_shape + padding_needed

    def calculate_memtile_tilings(self):
        """Enumerate candidate memtile sub-volumes, one per schedule id.

        For each operand a memtile holds at most ``padded // mem_splits``
        elements and at least ``granularity * core_splits / mem_splits``
        (one kernel granule per core sub-volume stored in the memtile).
        Candidate ifmA / ifmB sizes come from ``factors(max, min)`` (presumably
        the factors of ``max`` that are >= ``min`` -- see Tiler.utils). The
        k-th ifmB candidate is paired with an ifmA candidate (wrapping around
        the ifmA list) and stored under schedule id ``k`` when ``k`` is in
        ``schedule_list``. ifmB and ofm use the larger of the two candidates.

        Populates ``memtile_subvols``, ``memtile_sublayers`` and
        ``memtile_iters``.
        """
        memtile_max_shapes = {}
        memtile_min_shapes = {}
        for operand, mem_split in self.overlay.mem_splits.items():
            memtile_max_shapes[operand] = self.padded_shapes[operand] // mem_split
            # Number of core sub-volumes stored in each memtile; may be non-integer.
            split_ratio = self.overlay.core_splits[operand] / mem_split
            memtile_min_shapes[operand] = self.kernel_granularity[operand] * split_ratio

        # Column vectors: one candidate sub-volume size per row.
        ifmA_candidates = np.array(
            factors(memtile_max_shapes['ifmA'][0], memtile_min_shapes['ifmA'][0]))[:, np.newaxis]
        ifmB_candidates = np.array(
            factors(memtile_max_shapes['ifmB'][0], memtile_min_shapes['ifmB'][0]))[:, np.newaxis]

        for sched, tmp_ifmB in enumerate(ifmB_candidates):
            if sched not in self.schedule_list:
                continue
            # Wrap around so every ifmB candidate gets an ifmA partner.
            tmp_ifmA = ifmA_candidates[sched % len(ifmA_candidates)]
            # ifmB/ofm sub-volume is never smaller than the ifmA one.
            ifmB_subvol = max(tmp_ifmB, tmp_ifmA)

            ifmA_iter = memtile_max_shapes['ifmA'] // tmp_ifmA
            ifmA_iter = ifmA_iter if ifmA_iter > 1 else 1

            self.memtile_iters[sched] = np.hstack([
                ifmA_iter,
                memtile_max_shapes['ifmB'] // ifmB_subvol,
                memtile_max_shapes['ofm'] // ifmB_subvol,
            ])
            self.memtile_subvols[sched] = {
                'ifmA': tmp_ifmA,
                'ifmB': ifmB_subvol,
                'ofm': ifmB_subvol,
            }
            self.memtile_sublayers[sched] = {
                'ifmA': tmp_ifmA * self.overlay.mem_splits['ifmA'],
                'ifmB': ifmB_subvol * self.overlay.mem_splits['ifmB'],
                'ofm': ifmB_subvol * self.overlay.mem_splits['ofm'],
            }

    def check_valid_memtile_tilings(self):
        """Keep only the memtile tilings whose buffers fit in memtile memory.

        Required space per schedule is::

            ifmA * ifmA_bufs + ifmB * ifmB_bufs + ofm * 1   (bytes)

        where ``ifmA_bufs`` is 1 for broadcast ops (``"BroadCast"`` in
        ``op_type``) and 2 otherwise, and ``ifmB_bufs`` is 2 (ping-pong) for
        schedules 2 and 5 and 1 otherwise. The fit test is strict (``<``).

        Results replace ``fits_in_memtile``, ``valid_fits_in_memtile``,
        ``valid_memtile_subvols``, ``valid_memtile_sublayers`` and
        ``valid_memtile_iters``; schedules without a tiling are omitted.
        Valid entries gain a leading candidate axis of length 1 (fits) or 0.
        """
        # Total capacity across memtile rows, in bytes
        # (assumes device.memtile_capacity is in KiB -- TODO confirm).
        memtile_capacity = self.device.memtile_capacity * self.device.memtile_rows * 1024
        # Broadcast ops keep a single ifmA buffer; other ops double-buffer it.
        ifmA_num_buf = 1 if "BroadCast" in self.layer.op_type else 2
        # Schedule 2: ifmA pin, ifmB/ofm stream; schedule 5: ifmA/B stream.
        # Both ping-pong ifmB; all other schedules use a single ifmB buffer.
        ifmB_pingpong_scheds = (2, 5)

        fits_within_memtiles = {}
        valid_subvolumes = {}
        valid_memtile_sublayers = {}
        valid_fits_within_memtiles = {}
        valid_memtile_iters = {}

        for sched, memtile_subvols in self.memtile_subvols.items():
            if not memtile_subvols:
                continue

            ifmA_size = np.prod(memtile_subvols['ifmA']) * self.layer.wgt_bytes
            ifmB_size = np.prod(memtile_subvols['ifmB']) * self.layer.in_bytes
            ofm_size = np.prod(memtile_subvols['ofm']) * self.layer.out_bytes
            ifmB_num_buf = 2 if sched in ifmB_pingpong_scheds else 1
            total_space_required = ifmA_size * ifmA_num_buf + ifmB_size * ifmB_num_buf + ofm_size * 1

            fits = total_space_required < memtile_capacity
            # Indexing with a scalar boolean adds a leading axis of length 1
            # (True) or 0 (False), i.e. a list of 0/1 surviving candidates.
            keep = np.any(fits)
            fits_within_memtiles[sched] = fits
            valid_fits_within_memtiles[sched] = fits[keep]
            valid_subvolumes[sched] = {
                operand: subvol[keep] for operand, subvol in memtile_subvols.items()
            }
            valid_memtile_sublayers[sched] = {
                operand: tiling * self.overlay.mem_splits[operand]
                for operand, tiling in valid_subvolumes[sched].items()
            }
            valid_memtile_iters[sched] = self.memtile_iters[sched][keep]

        self.fits_in_memtile = fits_within_memtiles
        self.valid_fits_in_memtile = valid_fits_within_memtiles

        self.valid_memtile_subvols = valid_subvolumes
        self.valid_memtile_sublayers = valid_memtile_sublayers
        self.valid_memtile_iters = valid_memtile_iters

    def calculate_array_tilings(self):
        """Derive per-core sub-volumes from each valid memtile tiling.

        Each valid memtile sublayer is divided by ``core_splits``; the whole
        per-core share is used as one core sub-volume, so the ifmA and ifmB
        core iteration counts are 1. ofm reuses the ifmB sub-volume.
        Populates ``core_subvols`` and ``core_iters``, keyed by schedule id
        and then by memtile candidate index.
        """
        for sched, sublayers in self.valid_memtile_sublayers.items():
            core_max_shapes = {
                operand: sublayer // self.overlay.core_splits[operand]
                for operand, sublayer in sublayers.items()
            }

            core_subvols = {}
            core_iters = {}
            for k in range(len(core_max_shapes['ofm'])):
                max_ifmA = core_max_shapes['ifmA'][k][0]
                max_ifmB = core_max_shapes['ifmB'][k][0]
                max_ofm = core_max_shapes['ofm'][k][0]
                tmp_ifmA = np.array([max_ifmA])
                tmp_ifmB = np.array([max_ifmB])

                core_iters[k] = [max_ifmA // tmp_ifmA, max_ifmB // tmp_ifmB, max_ofm // tmp_ifmB]
                core_subvols[k] = {
                    'ifmA': tmp_ifmA,
                    'ifmB': tmp_ifmB,
                    'ofm': tmp_ifmB,
                }
            self.core_subvols[sched] = core_subvols
            self.core_iters[sched] = core_iters

    def check_core_constraints(self):
        """Check every core sub-volume against the kernel's bank placement.

        For each core tiling, each per-bank formula in
        ``self.inverted_placement`` is evaluated with ``self.vars_dict``
        (``Nsubv`` = ifmB core sub-volume) and must not exceed one core
        bank's capacity. Results go to ``core_validity_checks``,
        ``valid_core_subvols`` and ``valid_core_iters``.
        """
        # Capacity of a single core memory bank, in bytes
        # (assumes core_data_memory is in KiB -- TODO confirm).
        core_bank_capacity = 1024 * (self.device.core_data_memory // self.device.core_num_banks)

        # Byte-width dependent scale factors used by the placement formulas;
        # they do not depend on the sub-volume, so set them once.
        self.vars_dict['ifmA_scale_factor'] = 3 if self.layer.wgt_bytes == 1 else 1
        self.vars_dict['ifmB_scale_factor'] = 3 if self.layer.in_bytes == 1 else 1
        self.vars_dict['ofm_scale_factor'] = 2 if self.layer.out_bytes == 1 else 1

        for sched, core_subvol_dict in self.core_subvols.items():
            for mem_sub_id, subvols in core_subvol_dict.items():
                self.vars_dict['Nsubv'] = subvols['ifmB'][0]

                # SECURITY NOTE: eval() runs the kernel's placement formulas
                # verbatim (and adds __builtins__ to vars_dict); the formulas
                # must come from trusted configuration only.
                valid_placement = True
                for formula in self.inverted_placement.values():
                    space_required = eval(formula, self.vars_dict)
                    valid_placement = valid_placement & (space_required <= core_bank_capacity)

                validity_checks = {'buffer_placement': valid_placement}

                # AND across all checks -> boolean mask over sub-volumes.
                valid_subvols = np.all(np.vstack(list(validity_checks.values())).T, 1)
                if sched in self.schedule_list:
                    self.core_validity_checks[sched][mem_sub_id] = validity_checks
                    self.valid_core_subvols[sched][mem_sub_id] = {
                        'ifmA': subvols['ifmA'][valid_subvols],
                        'ifmB': subvols['ifmB'][valid_subvols],
                        'ofm': subvols['ofm'][valid_subvols],
                    }
                    self.valid_core_iters[sched][mem_sub_id] = self.core_iters[sched][mem_sub_id]

if __name__ == '__main__':
    # Smoke-test driver: build a tiler for layer 'tst' from
    # src/Tiler/tst_elemwise.json on an '8x4' overlay and a 'strix' device,
    # then run every tiling stage in order.
    import json

    from overlay import Overlay
    from layer import Layer
    from kernel import Kernel
    from device import Device

    overlay_cfg = Overlay('8x4', 'Silu', 'N32')

    with open('src/Tiler/tst_elemwise.json') as json_file:
        layer_desc = json.load(json_file)['tst']

    # Force both activation residencies to 'L3' for this run.
    layer_desc['in_act_residency'] = 'L3'
    layer_desc['out_act_residency'] = 'L3'

    sample_layer = Layer(layer_desc)
    strix = Device('strix')
    sample_kernel = Kernel(sample_layer)

    tiler = ElemWiseOpsTiler(sample_layer, strix, overlay_cfg, sample_kernel)
    for stage in (tiler.calculate_memtile_tilings,
                  tiler.check_valid_memtile_tilings,
                  tiler.calculate_array_tilings,
                  tiler.check_core_constraints):
        stage()
