import numpy as np
import pdb
import copy

from OGOAT.src.Tiler.utils import (
    compute_inverted_placement, 
    factors
)

from dataflow.dataflow_common import iceil

#Notes: 
# ifmA - broadcast tensor
# ifmB - activation tensor

class BroadcastTiler:

    def __init__(self, layer, device, overlay, kernel):
        """Analyse the two input shapes and precompute padded operand shapes.

        Naming used throughout this class: ``ifmA`` is the broadcast tensor
        (taken from ``layer.in_wgt_shape``), ``ifmB`` is the activation tensor
        (taken from ``layer.in_act_shape``) and ``ofm`` is the output.

        Args:
            layer: Layer description. Reads ``in_wgt_shape``, ``in_act_shape``,
                ``out_act_shape``, ``wgt_bytes``, ``in_bytes`` and ``out_bytes``.
                NOTE: ``in_wgt_shape`` / ``in_act_shape`` are modified on this
                object: the shorter one is left-padded with 1s, and for an
                innermost broadcast the last dim of ``in_wgt_shape`` is set to
                ``bcast_vec_pad_N``. The shapes must support ``.copy()`` and
                ``[1] + shape`` (presumably plain lists - TODO confirm).
            device: Device description; only stored here.
            overlay: Overlay description. Reads ``core_splits['ifmA']`` and, on
                the non-innermost path, ``mem_splits['ifmA']``.
            kernel: Kernel description. Reads ``Ngran`` and
                ``placement_constraints``.

        Raises:
            AssertionError: If the two shapes are not broadcast-compatible.
        """

        self.layer = layer
        self.overlay = overlay
        self.device = device
        self.kernel = kernel

        # Defaults; ifm_mode, itersIfmA/B, innermost_broadcast and single_element
        # are re-evaluated below once the shapes have been analysed.
        # inter_dims_broadcast is not modified anywhere else in this class.
        self.ifm_mode = "N1"
        self.itersIfmA = 1
        self.itersIfmB = 1
        self.inter_dims_broadcast = False
        self.innermost_broadcast = False
        self.single_element = False
        # Width the broadcast vector is padded to on the innermost-broadcast path.
        self.bcast_vec_pad_N = 8

        # Element-count threshold that switches ifm_mode from "N1" to "N32" below.
        self.L1_pin_limit = 2048

        # All three operands use the kernel's N granularity.
        self.kernel_granularity = {
            'ifmA':np.array([kernel.Ngran]),
            'ifmB':np.array([kernel.Ngran]),
            'ofm':np.array([kernel.Ngran])
            }

        #TODO: Need add the other schedule mode
        self.schedule_list = [2]   #IfmA Pin, IfmB/OFM Stream
        
        # Per-schedule result containers, filled in by the calculate_* / check_* methods.
        self.memtile_subvols = {k:{} for k in self.schedule_list}
        self.memtile_sublayers = {k:{} for k in self.schedule_list}
        self.memtile_iters = {k:{} for k in self.schedule_list}

        self.fits_in_memtile = {k:{} for k in self.schedule_list}
        self.valid_fits_in_memtile = {k:{} for k in self.schedule_list}

        self.valid_memtile_subvols = {k:{} for k in self.schedule_list}
        self.valid_memtile_sublayers = {k:{} for k in self.schedule_list}
        self.valid_memtile_iters = {k:{} for k in self.schedule_list}

        self.core_subvols = {k:{} for k in self.schedule_list}
        self.core_iters= {k:{} for k in self.schedule_list}
        
        self.core_validity_checks = {k:{} for k in self.schedule_list}
        self.valid_core_subvols = {k:{} for k in self.schedule_list}
        self.valid_core_iters = {k:{} for k in self.schedule_list}
        
        # For shapeA and shapeB (right-aligned), every pair of dims must either
        # be equal or contain a 1.
        def can_broadcast(shapeA, shapeB):
            """Return True if the two shapes are broadcast-compatible."""
            lenA, lenB = len(shapeA), len(shapeB)
            ndim = max(lenA, lenB)
            # Pad both shapes with 1s on the left
            paddedA = [1] * (ndim - lenA) + list(shapeA)
            paddedB = [1] * (ndim - lenB) + list(shapeB)
            for dimA, dimB in zip(paddedA, paddedB):
                if dimA != dimB and dimA != 1 and dimB != 1:
                    return False
            return True
        
        # Check whether the broadcast sits between the outermost and innermost
        # dimension ([M, broadcast, N]); that case needs to be handled separately.
        def isMultiBCast(shapeA, shapeB):
            """Return ``(is_multi, dim)`` for broadcast tensor shapeA vs shapeB.

            ``dim`` is the index of the first dim where shapeA is 1 and differs
            from shapeB; ``is_multi`` is True when some dim before it in shapeA
            is larger than 1. When the shapes are equal or shapeA's last dim is
            1, returns ``(False, len(shapeA) - 1)``.
            """
            # Left-pad the shorter shape with 1s (rebinds the local names only).
            while len(shapeA) < len(shapeB): shapeA = [1] + shapeA
            while len(shapeB) < len(shapeA): shapeB = [1] + shapeB
            if shapeA == shapeB or shapeA[-1] == 1:
                return False, len(shapeA) - 1
            # NOTE(review): bcastDims is only assigned when some dim has
            # a != b with a == 1. If the shapes differ only where shapeB is 1,
            # the range() below would hit an unbound local - confirm such
            # inputs cannot reach this point.
            for i, (a, b) in enumerate(zip(shapeA, shapeB)):
                if a != b:
                    if a == 1:
                        bcastDims = i
                        break
            for i in range(bcastDims):
                if shapeA[i] > 1:
                    return True, bcastDims
            return False, bcastDims
        
        #Flatten the shape for inner most dims to be Mx1(bcast) and MxN(act)
        # def modify_shape_for_innermost_broadcast(shapeA, shapeB):
        #     shapeA = [np.prod(shapeA[:-1]), np.prod(shapeA[-1])]
        #     shapeB = [np.prod(shapeB[:-1]), np.prod(shapeB[-1])]
        #     return shapeA, shapeB
        
        def modify_shape_for_innermost_broadcast(shapeA, shapeB):
            """Flatten both shapes to 2-D, split at the first dim where they differ.

            Returns the flattened ``(shapeA, shapeB)`` in the same order as the
            arguments. If the shapes never differ, the split index stays 0.
            """
            dim = 0
            for i in range(len(shapeA)):
                if shapeA[i] != shapeB[i]:
                    dim = i
                    break
            # NOTE(review): when dim == 0 the leading slice is empty and
            # np.prod returns the float 1.0 - confirm callers tolerate that.
            shapeA = [np.prod(shapeA[:dim]), np.prod(shapeA[dim:])]
            shapeB = [np.prod(shapeB[:dim]), np.prod(shapeB[dim:])]
            return shapeA, shapeB

        #check if broadcastable
        assert can_broadcast(self.layer.in_wgt_shape, self.layer.in_act_shape), f"The shapes {self.layer.in_wgt_shape} and {self.layer.in_act_shape} cannot be broadcasted"
        multi_channel_or_batch_broadcast, dim = isMultiBCast(self.layer.in_wgt_shape, self.layer.in_act_shape)
        # NOTE(review): these two copies are reassigned below ("Broadcast vector
        # will be 1xN") before they are ever read.
        in_wgt_shape = self.layer.in_wgt_shape.copy()
        in_act_shape = self.layer.in_act_shape.copy()
        
        #Equalize the lengths (this rebinds the shape attributes on the layer object)
        while len(self.layer.in_wgt_shape) < len(self.layer.in_act_shape): self.layer.in_wgt_shape = [1] + self.layer.in_wgt_shape
        while len(self.layer.in_act_shape) < len(self.layer.in_wgt_shape): self.layer.in_act_shape = [1] + self.layer.in_act_shape

        #Checking whether this is an innermost broadcast (last dim of the broadcast tensor is 1)
        if self.layer.in_wgt_shape[-1] == 1:
            self.innermost_broadcast = True
        
        # A broadcast tensor with exactly one element in total.
        if np.prod(self.layer.in_wgt_shape) == 1:
            self.single_element = True

        #For intermediate dims broadcast - Outer dimensions will be temporal / re-enqueues
        self.bcast_multi_channel = multi_channel_or_batch_broadcast
        if multi_channel_or_batch_broadcast:
            # Product of the dims in front of the broadcast dim, per operand.
            self.itersIfmA   = np.prod(self.layer.in_wgt_shape[:dim])
            self.itersIfmB   = np.prod(self.layer.in_act_shape[:dim])
        
        #Broadcast vector will be 1xN
        in_wgt_shape = np.array([1, np.prod(self.layer.in_wgt_shape[dim:])])
        #act vector will be MxN 
        in_act_shape = np.array([np.prod(self.layer.in_act_shape[dim:]) // np.prod(self.layer.in_wgt_shape[dim:]), np.prod(self.layer.in_wgt_shape[dim:])]).astype(np.int32).tolist()

        if self.innermost_broadcast:
            #Flatten the shape for inner most dims to be Mx1(bcast) and MxN(act), output shape becomes 2D
            # The helper returns its results in argument order: (act, wgt).
            in_act_shape, in_wgt_shape = modify_shape_for_innermost_broadcast(self.layer.in_act_shape, self.layer.in_wgt_shape)

            #Pad Mx1 to split across all cores, N = 8(host padding)
            if self.single_element:
                ifmA_padded_shapes = np.array([1, self.bcast_vec_pad_N])
            else:
                ifmA_padded_shapes = np.array([iceil(np.prod(in_wgt_shape), self.overlay.core_splits['ifmA'][0]), self.bcast_vec_pad_N]).astype(int).tolist()
            #ifmB should have same M as ifmA, N has to be padded to kernel granularity
            ifmB_padded_shapes = in_act_shape.copy()

            #TODO: change the check into check full L1 size
            if in_act_shape[-1] > self.L1_pin_limit or self.single_element:
                self.ifm_mode  = "N32"
                #pad ifmB N to split into all cores with kernel_granularity
                ifmB_padded_shapes[-1] = np.int32(iceil(in_act_shape[-1], (self.kernel_granularity['ifmA'][0] * self.overlay.core_splits['ifmA'][0])))
            else:
                self.ifm_mode  = "N1"
                #pad ifmB N to kernel granularity if streaming all subvols into cores
                ifmB_padded_shapes[-1] = np.int32(iceil(in_act_shape[-1], self.kernel_granularity['ifmA'][0]))

            #bcast(MxN) vector will be padded always with N=8
            # NOTE(review): the local in_wgt_shape is not read again after this line.
            in_wgt_shape[1] = self.bcast_vec_pad_N
            #Padding not done in layer.py, broadcast vector N = 8
            # This mutates the layer's own shape list in place.
            self.layer.in_wgt_shape[-1] = self.bcast_vec_pad_N

            #ifmB M has to be same as ifmA
            ifmB_padded_shapes[0] = ifmA_padded_shapes[0]
        else:
            #ifmB should have same N as ifmA, M has to be padded to split into all cores
            ifmB_padded_shapes = in_act_shape.copy()

            # NOTE(review): split_ratio is computed here but not used in this method.
            split_ratio        = self.overlay.core_splits['ifmA'] / self.overlay.mem_splits['ifmA']  ## Can be non-integer

            #TODO: change the check into check full L1 size
            if np.prod(in_wgt_shape) > self.L1_pin_limit:
                self.ifm_mode  = "N32"
                #pad ifmA N to split into all cores with kernel_granularity
                ifmA_padded_shapes = np.array([1, iceil(np.prod(in_wgt_shape), (self.kernel_granularity['ifmA'][0] * self.overlay.core_splits['ifmA'][0]))]).astype(np.int32).tolist()
            else:
                self.ifm_mode = "N1"
                #pad ifmA N to kernel_granularity
                # NOTE(review): the literal 4 multiplies the granularity here; its
                # origin is not visible in this file - confirm what it represents.
                ifmA_padded_shapes = np.array([1, iceil(np.prod(in_wgt_shape), (self.kernel_granularity['ifmA'][0] * 4))]).astype(np.int32).tolist()
                #pad ifmB M to split across all cores
                ifmB_padded_shapes[0] = np.int32(iceil(ifmB_padded_shapes[0], self.overlay.core_splits['ifmA'][0]))
            #ifmB N has to be same as ifmA
            ifmB_padded_shapes[-1] = ifmA_padded_shapes[-1]
        #ofm shape should be same as ifmB (this is the same list object, not a copy)
        ofm_padded_shapes = ifmB_padded_shapes
        
        self.padded_shapes = {
            'ifmA': ifmA_padded_shapes,   #Note broadcast tensor is ifmA
            'ifmB': ifmB_padded_shapes,
            'ofm': ofm_padded_shapes
        }

        # References to the (possibly rewritten) layer shapes, not copies.
        self.cflags = {
            'ifmA' : self.layer.in_wgt_shape,
            'ifmB' : self.layer.in_act_shape
        }

        # Layer shapes converted to plain lists of int32-derived values.
        self.host_padding = {
            'ifmA' : np.array(self.layer.in_wgt_shape).astype(np.int32).tolist(),
            'ifmB' : np.array(self.layer.in_act_shape).astype(np.int32).tolist(),
            'ofm'  : np.array(self.layer.out_act_shape).astype(np.int32).tolist()
        }

        # Padded element count minus the layer's element count divided by the
        # outer iteration count of the corresponding operand.
        self.padding = {
            'ifmA': np.prod(ifmA_padded_shapes) - (np.prod(self.layer.in_wgt_shape) // self.itersIfmA),
            'ifmB': np.prod(ifmB_padded_shapes) - (np.prod(self.layer.in_act_shape) // self.itersIfmB),
            'ofm': np.prod(ofm_padded_shapes)- (np.prod(self.layer.out_act_shape) // self.itersIfmB)
        }

        # Variables made available to the placement formulas evaluated in
        # check_core_constraints().
        self.vars_dict = {
            'ifmA_bytes': layer.wgt_bytes,
            'ifmB_bytes': layer.in_bytes,
            'ofm_bytes': layer.out_bytes,
            ## check other constraints
            'Ngran': self.kernel.Ngran
        }

        ## calculate bank-wise space formulas
        self.inverted_placement = compute_inverted_placement(self.kernel.placement_constraints)

    def calculate_memtile_tilings(self):
        """Pick the memtile sub-volumes and iteration counts for schedule 2.

        Builds per-operand min/max memtile extents, enumerates candidate
        tilings through ``factors`` and selects one using ``check_if_valid``.

        Results are stored in ``self.memtile_iters[2]`` (ifmA, ifmB, ofm
        iteration counts) and ``self.memtile_subvols[2]``. The candidate lists
        are kept in ``self.scheduled_tilings_ifmA`` and, except on the
        innermost-broadcast "N1" path, ``self.scheduled_tilings_ifmB``.

        Raises:
            AssertionError: On the non-innermost path, if the selected
                sub-volumes fail ``check_if_valid``.
        """
        # Placeholders; both names are reassigned from factors() further down.
        schedule_tilings_ifmA = {}
        schedule_tilings_ifmB = {}

        memtile_max_shapes = {}
        memtile_min_shapes = {}
        
        # Schedule id under which the results are stored at the end.
        k = 2
        #create a max and min memtile size
        for operand in self.overlay.mem_splits.keys():
            split_ratio        = self.overlay.core_splits[operand] / self.overlay.mem_splits[operand]  ## Can be non-integer
            if self.innermost_broadcast:
                if operand == 'ifmA':
                    #M gets split spatially across all memtiles, N = 8, so taking only range of M
                    if self.single_element:
                        memtile_max_shapes[operand] = np.array([1])
                    else:
                        memtile_max_shapes[operand] = np.array([(self.padded_shapes[operand][0] // self.overlay.mem_splits[operand][0])]) 
                    memtile_min_shapes[operand] = np.array([1])
                else:
                    #for IfmB and Ofm M follows same as ifmA, ifmB and ofm N gets split spatially across all memtiles, minimum N has to follow kernel granularity
                    #since M is fixed with ifmA, taking range of N
                    memtile_max_shapes[operand] = np.array([(self.padded_shapes[operand][-1] // self.overlay.mem_splits[operand][0])])
                    memtile_min_shapes[operand] = np.array([self.kernel_granularity[operand][0] * split_ratio])
            else:
                if self.ifm_mode == 'N1':
                    # NOTE(review): every operand iteration rewrites both the 'ifmA'
                    # and 'ifmB' entries, so the 'ifmB' value that survives comes
                    # from the last operand in mem_splits - confirm this is intended.
                    memtile_max_shapes['ifmA'] = np.prod(self.padded_shapes["ifmA"])
                    memtile_min_shapes['ifmA'] = np.prod(self.padded_shapes["ifmA"])
                    #for IfmB taking ranges only on the M, N has to be same as IfmA
                    memtile_max_shapes['ifmB'] = np.prod(self.padded_shapes[operand]) // np.prod(memtile_max_shapes['ifmA']) // self.overlay.mem_splits[operand][0]
                    memtile_min_shapes['ifmB'] = 1
                else:
                    #N gets split spatially across all memtiles
                    if operand != 'ifmA':
                        # Reads the 'ifmA' entry, so 'ifmA' must come before the
                        # other operands in mem_splits' iteration order.
                        memtile_max_shapes[operand] = np.prod(self.padded_shapes[operand]) // np.prod(memtile_max_shapes['ifmA']) // self.overlay.mem_splits[operand][0]
                        memtile_min_shapes[operand] = 1
                    else:
                        memtile_max_shapes[operand] = np.prod(self.padded_shapes[operand])//self.overlay.mem_splits[operand][0]
                        memtile_min_shapes[operand] = self.kernel_granularity[operand][0] * split_ratio
        
        # Clamp each minimum so it never exceeds the corresponding maximum.
        # (The comprehension's k is local to the comprehension; the k above stays 2.)
        memtile_min_shapes = {k: min(memtile_max_shapes[k], memtile_min_shapes[k]) for k in memtile_max_shapes}

        #extract range of tilings in ifmA and ifmB, ofm follows same as ifmB
        # NOTE(review): factors() is imported from OGOAT.src.Tiler.utils; presumably
        # it lists candidate tile sizes between the min and max - confirm.
        schedule_tilings_ifmA = np.array(factors(memtile_max_shapes['ifmA'], memtile_min_shapes['ifmA']))
        schedule_tilings_ifmB = np.array(factors(memtile_max_shapes['ifmB'], memtile_min_shapes['ifmB']))
        
        schedules = schedule_tilings_ifmA
        if self.ifm_mode == 'N32' or self.single_element:
            # Every (ifmA, ifmB) candidate pair.
            schedules = [[i,j] for i in schedule_tilings_ifmA for j in schedule_tilings_ifmB]
        #for each sch ifmA and ifmB pairs, check memsize validity
        sch_idx = 0
        sch_idx_ifmB = 0
        if self.innermost_broadcast:
            # No break below: the LAST row that passes the checks wins, and
            # sch_idx stays 0 when no row passes.
            for row in range(len(schedules)):
                if self.ifm_mode == "N1" :
                    #M spatial split, N pin
                    # NOTE(review): the literal 8 presumably mirrors
                    # self.bcast_vec_pad_N - confirm.
                    tmp_ifmA = schedules[row] * 8
                    tmp_ifmB = schedules[row] * self.padded_shapes['ifmB'][1]
                    H_memtile = schedules[row]
                    # ofm is passed as tmp_ifmA here, not tmp_ifmB.
                    if self.check_if_valid(tmp_ifmA, tmp_ifmB, tmp_ifmA) and H_memtile < 8 and H_memtile % 4 == 0:
                        sch_idx = row
                else:
                    #M temporal split, N spatial split
                    tmp_ifmA = schedules[row][0] * 8
                    tmp_ifmB = np.prod(schedules[row])
                    H_memtile = schedules[row][0]
                    if self.check_if_valid(tmp_ifmA, tmp_ifmB, tmp_ifmA) and H_memtile < 8:
                        sch_idx = row
            # Recompute the sub-volumes for the chosen row and derive the iteration counts.
            if self.ifm_mode == "N1":
                #M spatial split, N pin
                tmp_ifmA = schedules[sch_idx] * 8
                tmp_ifmB = schedules[sch_idx] * self.padded_shapes['ifmB'][1]
                curr_iter_ifmA = np.prod(self.padded_shapes['ifmA']) // self.overlay.mem_splits['ifmA'][0] // np.prod(tmp_ifmA)
                curr_iter_ifmB = np.prod(self.padded_shapes['ifmB']) // self.overlay.mem_splits['ifmB'][0] // np.prod(tmp_ifmB)
                idxA = schedule_tilings_ifmA.tolist().index(schedules[sch_idx])
                # Note: scheduled_tilings_ifmB is not set on this path.
                self.scheduled_tilings_ifmA = [schedule_tilings_ifmA[idxA]]
                
            else:
                #M temporal split, N spatial split
                tmp_ifmA = schedules[sch_idx][0] * 8
                tmp_ifmB = np.prod(schedules[sch_idx])
                curr_iter_ifmA = np.prod(self.padded_shapes['ifmA']) // np.prod(tmp_ifmA)
                curr_iter_ifmB = np.prod(self.padded_shapes['ifmB']) // self.overlay.mem_splits['ifmB'][0] // np.prod(tmp_ifmB)
                idxB = schedule_tilings_ifmB.tolist().index(schedules[sch_idx][1])
                idxA = schedule_tilings_ifmA.tolist().index(schedules[sch_idx][0])
                self.scheduled_tilings_ifmB = [schedule_tilings_ifmB[idxB]]
                self.scheduled_tilings_ifmA = [schedule_tilings_ifmA[idxA]]
        else:
            #Take the max of the single subvol in ifmA, to transmit A in 1 iters each enqueues, otherwise we need to do re-enqueus to synchronize data
            if self.ifm_mode == 'N1':
                tmp_ifmA       = np.array([1, max(schedule_tilings_ifmA)]).astype(int)
                curr_iter_ifmA = 1
            else:
                tmp_ifmA       = np.array([1, max(schedule_tilings_ifmA)]).astype(int)
                curr_iter_ifmA = np.prod(self.padded_shapes['ifmA']) // self.overlay.mem_splits['ifmA'][0] // np.prod(tmp_ifmA)
            # Placeholder; reassigned after the search loop below.
            tmp_ifmB = 0
            #Memtile validity check
            # Start from the iteration count of the first ifmB candidate, then look
            # for candidates that do not increase it, stay below 1024 iterations
            # and pass check_if_valid.
            new_iter = np.prod(np.prod(self.padded_shapes["ifmB"]) // self.overlay.mem_splits["ifmB"][0] // (np.prod(tmp_ifmA) * schedule_tilings_ifmB[sch_idx_ifmB])) * self.itersIfmB
            for row in range(len(schedule_tilings_ifmB)):
                curr_iter = np.prod(np.prod(self.padded_shapes["ifmB"]) // self.overlay.mem_splits["ifmB"][0] // (np.prod(tmp_ifmA) * schedule_tilings_ifmB[row])) * self.itersIfmB
                ifmB_sch = np.array([schedule_tilings_ifmB[row], np.prod(tmp_ifmA)]).astype(int)
                if curr_iter <= new_iter and curr_iter < 1024 and self.check_if_valid(tmp_ifmA, ifmB_sch, ifmB_sch):
                    new_iter = curr_iter
                    sch_idx_ifmB = row
                    # Stop at the first accepted candidate unless this is a
                    # multi-channel broadcast with at least 4 outer ifmA iterations.
                    if not self.bcast_multi_channel or self.itersIfmA < 4:
                        break
            tmp_ifmB = np.array([schedule_tilings_ifmB[sch_idx_ifmB], np.prod(tmp_ifmA)]).astype(int)
            curr_iter_ifmB = new_iter // self.itersIfmB
            assert(self.check_if_valid(tmp_ifmA, tmp_ifmB, tmp_ifmB))
            # Keep at least one entry in the slice below. The slice ends BEFORE
            # index sch_idx_ifmB, so for a non-zero index the selected candidate
            # itself is not part of scheduled_tilings_ifmB.
            sch_idx_ifmB = 1 if sch_idx_ifmB == 0 else sch_idx_ifmB
            self.scheduled_tilings_ifmB = schedule_tilings_ifmB[:sch_idx_ifmB]
            self.scheduled_tilings_ifmA = [max(schedule_tilings_ifmA)]

        # Iteration counts in operand order (ifmA, ifmB, ofm); ofm reuses ifmB's count.
        self.memtile_iters[k] = np.hstack([self.itersIfmA * curr_iter_ifmA, self.itersIfmB * curr_iter_ifmB, self.itersIfmB * curr_iter_ifmB]).astype(int)
        self.memtile_subvols[k] = {
            'ifmA' : tmp_ifmA,
            'ifmB' : tmp_ifmB,
            'ofm'  : tmp_ifmB
        }

    def check_if_valid(self, ifmA, ifmB, ofm):
        """Return True if the candidate memtile sub-volumes are usable.

        Args:
            ifmA: Candidate sub-volume of the broadcast tensor (scalar or shape).
            ifmB: Candidate sub-volume of the activation tensor.
            ofm: Candidate sub-volume of the output tensor.

        Returns:
            bool: False if a geometry rule is violated (only checked for a
            non-innermost broadcast in "N1" mode) or if the buffers do not fit
            strictly below the memtile capacity even with single buffering.
        """
        elems_A = np.prod(ifmA)
        elems_B = np.prod(ifmB)
        bytes_A = elems_A * self.layer.wgt_bytes
        bytes_B = elems_B * self.layer.in_bytes
        bytes_O = np.prod(ofm) * self.layer.out_bytes

        if not self.innermost_broadcast and self.ifm_mode == "N1":
            rows = self.overlay.rows
            # ifmB must hold at least `rows` copies of the ifmA sub-volume.
            if rows * elems_A > elems_B:
                return False
            # From a height of 32 on, the height must be a multiple of `rows`.
            height_B = elems_B // elems_A
            if height_B >= 32 and height_B % rows != 0:
                return False

        capacity_bytes = self.device.memtile_capacity * self.device.memtile_rows * 1024  # In bytes

        # Try double-buffered inputs first, then single-buffered; the output
        # always counts once.
        return any(
            bytes_A * copies + bytes_B * copies + bytes_O * 1 < capacity_bytes
            for copies in (2, 1)
        )

    def check_valid_memtile_tilings(self):
        memtile_capacity = self.device.memtile_capacity * self.device.memtile_rows * 1024 # In bytes

        fits_within_memtiles = {}
        valid_subvolumes = {}
        valid_fits_within_memtiles = {}
        valid_memtile_iters = {}

        for sched, memtile_subvols in self.memtile_subvols.items():
            if not memtile_subvols:
                continue

            single_subvolume_sizes = {
                'ifmA': np.prod(memtile_subvols['ifmA']) * self.layer.wgt_bytes,
                'ifmB': np.prod(memtile_subvols['ifmB']) * self.layer.in_bytes,
                'ofm': np.prod(memtile_subvols['ofm']) * self.layer.out_bytes
            }

            total_space_required = []
            opt_num_buf = 0
            for num_buf in range(2, 0, -1):
                space_required_2buff = single_subvolume_sizes['ifmA']*num_buf + single_subvolume_sizes['ifmB']*num_buf + single_subvolume_sizes['ofm'] * 1
                total_space_required = space_required_2buff
                if total_space_required < memtile_capacity:
                    opt_num_buf = num_buf
                    break
            assert opt_num_buf != 0, f"Subvols not fitting in memtile {total_space_required} available space {memtile_capacity}"
            if opt_num_buf == 1:
                self.ping_pong_ifmA = False
                self.ping_pong_ifmB = False
            else: # opt_num_buf == 2:
                self.ping_pong_ifmA = True
                self.ping_pong_ifmB = True        
            fits_within_memtiles[sched] = total_space_required<memtile_capacity
            valid_fits_within_memtiles[sched] = fits_within_memtiles[sched][np.any(fits_within_memtiles[sched])]

            valid_subvolumes[sched] = {operand: memtile_subvols[operand][np.any(fits_within_memtiles[sched])] for operand in memtile_subvols.keys()}

            valid_memtile_iters[sched] = self.memtile_iters[sched][np.any(fits_within_memtiles[sched])]        

        self.fits_in_memtile = fits_within_memtiles
        self.valid_fits_in_memtile = valid_fits_within_memtiles
        self.valid_memtile_subvols = valid_subvolumes
        self.valid_memtile_iters = valid_memtile_iters


    def calculate_array_tilings(self):
        for sched, valid_memtile_subvols in self.valid_memtile_subvols.items():    
            core_max_shapes = {}
            core_subvols = {}
            core_iters = {}
            divisor = 4
            if self.innermost_broadcast:
                if self.ifm_mode == 'N1':
                    splits_H = np.hstack(self.scheduled_tilings_ifmA)
                    for operand in valid_memtile_subvols.keys():
                        if operand == 'ifmA':
                            core_max_shapes[operand] = np.array([np.prod([s_h // divisor, self.bcast_vec_pad_N]) for s_h in splits_H ]).astype(int)
                        else:
                            core_max_shapes[operand] = np.array([np.prod([s_h // divisor, self.padded_shapes['ifmB'][1]]) for s_h in splits_H]).astype(int)
                else:
                    splits_H = np.hstack(self.scheduled_tilings_ifmA)
                    for operand in valid_memtile_subvols.keys():
                        if operand == 'ifmA':
                            core_max_shapes[operand] = np.array([np.prod([s_h, self.bcast_vec_pad_N]) for s_h in splits_H ]).astype(int)
                        else:
                            core_max_shapes[operand] = np.array([np.prod([np.prod(valid_memtile_subvols[operand]) // divisor])]).astype(int)
            else:
                for operand in valid_memtile_subvols.keys():
                    splits = np.hstack(self.scheduled_tilings_ifmB) if operand != 'ifmA' else np.array([1 for _ in range(len(self.scheduled_tilings_ifmB))])
                    core_max_shapes[operand] = np.array([np.prod(valid_memtile_subvols[operand]) // (self.overlay.core_splits[operand][0] // self.overlay.mem_splits[operand][0]) // d for d in splits]).astype(int)
                    if self.ifm_mode == 'N1':
                        divisor = 4 if operand != 'ifmA' else 1
                        core_max_shapes[operand] = np.array([np.prod(valid_memtile_subvols[operand]) // divisor // d for d in splits]).astype(int)
        for k in range(len(core_max_shapes['ofm'])):
            #if sched == 5:
            # ##############
            # Schedule 5: ifm stream, wgt stream (2 buffers of ifm, 2 for wgt, 1 for ofm)
            # ##############
            n_sublist_ifmA = [core_max_shapes['ifmA'][k]]
            n_sublist_ifmB = [core_max_shapes['ifmB'][k]]
            
            split = 4
            tmp_ifmA = np.array(n_sublist_ifmA)
            tmp_ifmB = np.array(n_sublist_ifmB)
            sub_ifmB = np.array(n_sublist_ifmB)

            if self.innermost_broadcast:
                if self.single_element:
                        core_iters[k] = [[1], (np.prod(valid_memtile_subvols['ifmB'][0]) // split // tmp_ifmB).astype(int).tolist(), (np.prod(valid_memtile_subvols['ifmB'][0]) // split // tmp_ifmB).astype(int).tolist()]
                elif self.ifm_mode == "N1":
                    core_iters[k] = [(np.prod(valid_memtile_subvols['ifmA'][0]) // split // tmp_ifmA).astype(int).tolist(), (np.prod(valid_memtile_subvols['ifmB'][0]) // split // tmp_ifmB).astype(int).tolist(), (np.prod(valid_memtile_subvols['ifmB'][0]) // split // tmp_ifmB).astype(int).tolist()]
                else:
                    core_iters[k] = [(np.prod(valid_memtile_subvols['ifmA'][0]) // tmp_ifmA).astype(int).tolist(), (np.prod(valid_memtile_subvols['ifmB'][0]) // split // tmp_ifmB).astype(int).tolist(), (np.prod(valid_memtile_subvols['ifmB'][0]) // split // tmp_ifmB).astype(int).tolist()]
            else:
                core_iters[k] = [[1], (np.prod(valid_memtile_subvols['ifmB'][0]) // split // tmp_ifmB).astype(int).tolist(), (np.prod(valid_memtile_subvols['ifmB'][0]) // split // tmp_ifmB).astype(int).tolist()]
            
            core_subvols[k] = {
                'ifmA': tmp_ifmA,
                'ifmB': sub_ifmB,
                'ofm': sub_ifmB
            }
        self.core_subvols[sched] = core_subvols
        self.core_iters[sched] = core_iters


    def check_core_constraints(self):
        core_bank_capacity = 1024 * (self.device.core_data_memory // self.device.core_num_banks) #bytes
        for sched, core_subvol_dict in self.core_subvols.items():
            #self.core_validity_checks.setdefault(sched,{})
            for mem_sub_id, subvols in core_subvol_dict.items():

                self.vars_dict['Nsubv'] = subvols['ifmA'][0]
                self.vars_dict['Msubv'] = subvols['ifmB'][0] // subvols['ifmA'][0]

                #setting the scale factor based on the qdq params
                self.vars_dict['ifmA_scale_factor'] = 3 if self.layer.wgt_bytes == 1 else 1
                self.vars_dict['ifmB_scale_factor'] = 3 if self.layer.in_bytes == 1 else 1
                self.vars_dict['ofm_scale_factor'] = 2 if self.layer.out_bytes == 1 else 1
                
                ## check buffer placements
                space_required = {}
                valid_placement = True
                for bank, formula in self.inverted_placement.items():
                    space_required[bank] = eval(formula, self.vars_dict)
                    
                    valid_placement = valid_placement & (space_required[bank]<=core_bank_capacity)
                
                validity_checks = {}
                validity_checks['buffer_placement'] = valid_placement
                
                # subvols that satisfy all constraints
                valid_subvols = np.all(np.vstack(list(validity_checks.values())).T,1)
                if sched in self.schedule_list:
                    self.core_validity_checks[sched][mem_sub_id] = validity_checks
                    # self.core_validity_checks

                    # self.valid_core_subvols.setdefault(sched,{})
                    self.valid_core_subvols[sched][mem_sub_id] = {
                        'ifmA': subvols['ifmA'][valid_subvols],
                        'ifmB': subvols['ifmB'][valid_subvols],
                        'ofm': subvols['ofm'][valid_subvols],
                    }

                    self.valid_core_iters[sched][mem_sub_id] = self.core_iters[sched][mem_sub_id]
        
        keys_true = [k for k, v in self.core_validity_checks[sched].items() if v.get('buffer_placement', True)]
        
        self.valid_core_iters[sched][0] = self.valid_core_iters[sched][keys_true[0]]
        self.valid_core_subvols[sched][0] = self.valid_core_subvols[sched][keys_true[0]]