import copy
import numpy as np

from dataflow.dataflow_common import iceil
from OGOAT.src.Tiler.utils import (
    compute_inverted_placement, 
    check_reuse_chain_validity, 
    check_iteration_chain_length,
    check_special_handle_shapes
)

class MatMulTiler:
    """Enumerate and filter MatMul tilings (core and memtile subvolumes) per schedule.

    Typical usage::

        tiler = MatMulTiler(layer, device, overlay, kernel)
        tiler.calculate_array_tilings()    # core subvol candidates + padded/host shapes
        tiler.calculate_memtile_tilings()  # memtile subvols and iteration counts per schedule
        tiler.check_core_constraints()     # filter candidates; fills the valid_* dicts

    Schedules handled (see ``calculate_memtile_tilings``):
        1: ifm pin, wgt full, ofm stream
        2: ifm pin, wgt stream, ofm stream
        5: ifm stream, wgt stream, ofm stream
    """

    def __init__(self, layer, device, overlay, kernel):
        """Derive rounded tensor shapes and initialise the per-schedule state.

        Args:
            layer: Layer description. Read here: ``in_act_shape``, ``in_wgt_shape``,
                ``out_act_shape``, ``rev_permA``, ``rev_permB``, ``permY``, ``op_type``,
                ``in_bytes``, ``wgt_bytes``, ``out_bytes``, ``coeff_shape``,
                ``wgt1_bytes`` and optionally ``debug_info`` (a dict).
            device: Device description (core/memtile memory sizes, DMA limits).
            overlay: Overlay description (``core_splits``, ``mem_splits``, ``unicast``).
            kernel: Kernel description (granularities, loop factors, ``tdm_bytes``,
                ``placement_constraints`` and ``other_constraints`` formulas).
        """
        self.layer = layer
        self.overlay = overlay
        self.device = device
        self.kernel = kernel

        # Unpadded problem sizes: ifm is [.., M, K], wgt is [.., K, N].
        M = self.layer.in_act_shape[-2]
        K = self.layer.in_act_shape[-1]
        N = self.layer.in_wgt_shape[-1]

        self.original_shapes = {
            'ifm': np.array(self.layer.in_act_shape),
            'wgt': np.array(self.layer.in_wgt_shape),
            'ofm': np.array(self.layer.out_act_shape)
        }

        # Create a dict of perm for each tensor.
        ## Input tensors (ifm/wgt) use the reversed perm since we need to decipher the input from default order.
        ## Output tensor (ofm) uses its perm as is.
        self.perm_order = {
            'ifm': self.layer.rev_permA,
            'wgt': self.layer.rev_permB,
            'ofm': self.layer.permY
        }

        is_actxact = self._is_actxact()

        # Align the innermost dimension (per perm) of every tensor to a multiple of 8.
        self.rounded_shapes = copy.deepcopy(self.original_shapes)
        for tensor, perm in self.perm_order.items():
            self.rounded_shapes[tensor][perm[-1]] = iceil(self.rounded_shapes[tensor][perm[-1]], 8)

            # TODO: Need better handling
            # If the B dimension is outermost, align it to the B core split so it gets host padded.
            # wgt is only aligned this way for 'actxact' op types.
            if perm[0] == 0 and (tensor != 'wgt' or is_actxact):
                self.rounded_shapes[tensor][0] = iceil(self.rounded_shapes[tensor][0],
                                                       self.overlay.core_splits[tensor][0])

        # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #
        # Special Handling (only for the default [0, 1, 2] permutation on all tensors)
        default_perm = np.array([0, 1, 2])
        if (np.all(self.layer.rev_permA == default_perm) and
                np.all(self.layer.rev_permB == default_perm) and
                np.all(self.layer.permY == default_perm)):
            Mrounded_dim = self.rounded_shapes['ifm'][-2]
            Krounded_dim = self.rounded_shapes['ifm'][-1]
            Nrounded_dim = self.rounded_shapes['wgt'][-1]
            # TODO: Update scheduler to handle it using dma padding.
            # NOTE(review): the M value returned by the helper is discarded, and wgt's
            # K dim (wgt[-2]) is not updated to Krounded_dim. Confirm both are intended.
            _, Krounded_dim, Nrounded_dim = check_special_handle_shapes(
                (M, K, N), [Mrounded_dim, Krounded_dim, Nrounded_dim], self.kernel)

            self.rounded_shapes['ifm'][-1] = Krounded_dim
            self.rounded_shapes['wgt'][-1] = Nrounded_dim
            self.rounded_shapes['ofm'][-1] = Nrounded_dim
        # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # # #

        # Schedules to explore; debug_info['sched'] restricts the search to one schedule.
        self.schedule_list = [1, 2, 5]
        if (hasattr(self.layer, 'debug_info') and
                self.layer.debug_info.get('sched') is not None):
            self.schedule_list = [int(self.layer.debug_info['sched'])]

        # Change data structures to use vectorized format
        self.schedule_tilings = {k: None for k in self.schedule_list}

        # Filled by calculate_array_tilings (one row per candidate combination).
        self.core_subvols = {'ifm': None, 'wgt': None, 'ofm': None}
        self.padded_shapes = {'ifm': None, 'wgt': None, 'ofm': None}
        self.host_shapes = {'ifm': None, 'wgt': None, 'ofm': None}
        self.host_padding = {'ifm': None, 'wgt': None, 'ofm': None}
        self.dma_padding = {'ifm': None, 'wgt': None, 'ofm': None}

        # Filled per schedule by calculate_memtile_tilings / check_core_constraints.
        self.core_iters = {k: {} for k in self.schedule_list}
        self.memtile_subvols = {k: {} for k in self.schedule_list}
        self.memtile_iters = {k: {} for k in self.schedule_list}
        self.loop_iters = {k: {} for k in self.schedule_list}

        self.memtile_pingpong = {k: None for k in self.schedule_list}
        self.disable_ifm_pingpong = {k: None for k in self.schedule_list}

        self.buffer_placement_choices = {k: None for k in self.schedule_list}

        # Valid configurations (only the candidates that pass check_core_constraints)
        self.valid_core_subvols = {k: {} for k in self.schedule_list}
        self.valid_core_iters = {k: {} for k in self.schedule_list}
        self.valid_core_indices = {k: None for k in self.schedule_list}
        self.valid_memtile_subvols = {k: {} for k in self.schedule_list}
        self.valid_memtile_iters = {k: {} for k in self.schedule_list}
        self.valid_loop_iters = {k: {} for k in self.schedule_list}
        self.valid_padded_shapes = {k: {} for k in self.schedule_list}
        self.valid_host_shapes = {k: {} for k in self.schedule_list}
        self.valid_host_padding = {k: {} for k in self.schedule_list}
        self.valid_dma_padding = {k: {} for k in self.schedule_list}
        self.valid_memtile_pingpong = {k: None for k in self.schedule_list}
        self.valid_disable_ifm_pingpong = {k: None for k in self.schedule_list}
        self.valid_buffer_placement_choices = {k: None for k in self.schedule_list}

        # Namespace used when eval'ing the kernel's constraint formulas.
        self.vars_dict = {
            'ifm_bytes': layer.in_bytes,
            'wgt_bytes': layer.wgt_bytes,
            'ofm_bytes': layer.out_bytes,
            # fix for bias and coefficient packing in wgt.bin in case of qdq
            # TODO more general solution for different datatype etc.
            'bias_bytes': 8 * (1 if layer.coeff_shape[0]<=layer.in_wgt_shape[-1] else 2) if 'qdq' in layer.op_type else layer.wgt1_bytes,
            'tdm_bytes': kernel.tdm_bytes,
            ## check other constraints
            'Mgran': self.kernel.Mgran,
            'Kgran': self.kernel.Kgran,
            'Ngran': self.kernel.Ngran,
            'np'   : np
        }

        self.placement_options = np.array(['double_double', 'single_single'])
        core_bank_capacity = 1024 * (self.device.core_data_memory // self.device.core_num_banks) #bytes
        self.vars_dict['bank_capacity'] = core_bank_capacity

        # actxact layers carry no bias in the wgt buffer.
        if is_actxact:
            self.vars_dict['bias_bytes'] = 0

        ## calculate bank-wise space formulas
        self.inverted_placement = compute_inverted_placement(self.kernel.placement_constraints)

    def _is_actxact(self):
        """Return True if the layer's op type contains 'actxact' (case-insensitive).

        Every 'actxact' check in this class goes through this helper so they all agree.
        """
        return 'actxact' in self.layer.op_type.lower()

    def _generate_subvol_grid(self, min_size, max_size, gran=1):
        """
        Generate candidate subvolume sizes as an arithmetic progression.

        Args:
            min_size: First (smallest) size, typically the kernel minimum.
            max_size: Upper bound, typically a dimension ceiling or a capacity bound.
            gran: Step size, typically the kernel granularity.
        Returns:
            np.ndarray ``[min_size, min_size + gran, ...]`` with stop ``max_size + gran``
            (exclusive), or an empty array when ``max_size <= 0``.

            - When ``max_size`` is on the grid, it is the last element.
            - Otherwise the last element overshoots ``max_size`` by less than ``gran``.
              This also keeps ``min_size`` whenever ``max_size > min_size - gran``.
            - A float ``max_size`` yields a float array.
        """
        if max_size <= 0:
            return np.array([])

        # Stop at max_size + gran (exclusive) so an on-grid max_size is included.
        return np.arange(min_size, max_size + gran, gran)

    def _generate_core_subvols(self):
        """Generate all core subvolume combinations using vectorization.

        Returns:
            dict with 'ifm' ([B, M, K] rows), 'wgt' ([B, K, N] rows) and 'ofm'
            ([B, M, N] rows). Each is an (n_combinations, 3) array over the
            Cartesian product of the per-dimension grids.
        """
        kernel = self.kernel

        # B has a single candidate, kernel.Bgran (the kernel accepts only one batch).
        Bsubv_max = Bsubv_min = kernel.Bgran
        Msubv_min = kernel.Mgran
        Ksubv_min = kernel.Kgran
        Nsubv_min = kernel.Ngran

        Msubv_max = max(self.rounded_shapes['ifm'][-2], self.rounded_shapes['ofm'][-2])
        Ksubv_max = max(self.rounded_shapes['ifm'][-1], self.rounded_shapes['wgt'][-2])
        Nsubv_max = max(self.rounded_shapes['wgt'][-1], self.rounded_shapes['ofm'][-1])

        # Ensure maximums are ceiling aligned with granularity
        # Also, adjust for overlay core splits
        Msubv_max = int(np.ceil(Msubv_max / kernel.Mgran / self.overlay.core_splits['ifm'][-2]) * kernel.Mgran)
        Ksubv_max = int(np.ceil(Ksubv_max / kernel.Kgran / self.overlay.core_splits['ifm'][-1]) * kernel.Kgran)
        Nsubv_max = int(np.ceil(Nsubv_max / kernel.Ngran / self.overlay.core_splits['wgt'][-1]) * kernel.Ngran)

        # Allow at least one full kernel loop per dimension.
        Msubv_max = max(Msubv_max, kernel.Mgran*kernel.outer_loop)
        Ksubv_max = max(Ksubv_max, kernel.Kgran*kernel.inner_loop)
        Nsubv_max = max(Nsubv_max, kernel.Ngran*kernel.outer_loop)

        # Upper bound for subvolumes based on core bank capacity
        ## Determine M & N upper bounds using following relation
        ## Msubv * Nsubv * tdm_bytes <= core_bank_capacity
        #
        ## Determine K upper bound using following relation
        ## Ksubv * Nsubv * wgt_bytes <= core_bank_capacity
        # NOTE: these bounds are floats; see _generate_subvol_grid for the effect on the grid.
        core_bank_capacity = self.vars_dict['bank_capacity']
        Msubv_max = min(Msubv_max, (core_bank_capacity / (Nsubv_min*self.vars_dict['ofm_bytes']*2)))   # 0.5*TDM buffer
        Ksubv_max = min(Ksubv_max, (core_bank_capacity / (Nsubv_min*self.vars_dict['wgt_bytes'])))     # wgt buffer
        Nsubv_max = min(Nsubv_max, (core_bank_capacity / (Msubv_min*self.vars_dict['ofm_bytes']*2)))   # 0.5*TDM buffer

        # Generate grids for each dimension
        B_grid = self._generate_subvol_grid(Bsubv_min, Bsubv_max, kernel.Bgran)
        M_grid = self._generate_subvol_grid(Msubv_min, Msubv_max, kernel.Mgran)
        K_grid = self._generate_subvol_grid(Ksubv_min, Ksubv_max, kernel.Kgran)
        N_grid = self._generate_subvol_grid(Nsubv_min, Nsubv_max, kernel.Ngran)

        # Create mesh grid for all combinations
        B_mesh, M_mesh, K_mesh, N_mesh = np.meshgrid(B_grid, M_grid, K_grid, N_grid, indexing='ij')

        # Fill in the arrays
        ifm_subvols = np.column_stack([B_mesh.flatten(), M_mesh.flatten(), K_mesh.flatten()])
        wgt_subvols = np.column_stack([B_mesh.flatten(), K_mesh.flatten(), N_mesh.flatten()])
        ofm_subvols = np.column_stack([B_mesh.flatten(), M_mesh.flatten(), N_mesh.flatten()])

        return {'ifm': ifm_subvols, 'wgt': wgt_subvols, 'ofm': ofm_subvols}

    def _calculate_padded_shapes_vectorized(self, core_subvols):
        """Pad each rounded shape up to a whole number of (core subvol x core split) tiles.

        Args:
            core_subvols: dict tensor -> (n, 3) array of core subvolumes.
        Returns:
            dict tensor -> (n, 3) int32 array of padded shapes.
        """
        def padded(tensor_type):
            # Extent covered by one subvol on every core split along each dim.
            step = core_subvols[tensor_type] * self.overlay.core_splits[tensor_type]
            return (np.ceil(self.rounded_shapes[tensor_type] / step) * step).astype(np.int32)

        return {tensor: padded(tensor) for tensor in ('ifm', 'wgt', 'ofm')}

    def _calculate_host_shapes_vectorized(self, padded_shapes):
        """Calculate the host shapes for all subvolume combinations.

        ifm and ofm use the rounded shape for every combination. wgt also uses the
        rounded shape for 'actxact' op types; otherwise it uses the per-combination
        padded shape.

        Args:
            padded_shapes: dict tensor -> (n, 3) padded shapes.
        Returns:
            dict tensor -> (n, 3) array of host shapes.
        """
        def tiled_rounded(tensor_type):
            # One copy of the rounded shape per candidate combination.
            return np.tile(self.rounded_shapes[tensor_type], reps=(self.num_combinations, 1))

        return {
            'ifm': tiled_rounded('ifm'),
            'wgt': tiled_rounded('wgt') if self._is_actxact() else padded_shapes['wgt'],
            'ofm': tiled_rounded('ofm')
        }

    def _calculate_padding_vectorized(self, actual_shapes, padded_shapes):
        """Return ``padded_shapes - actual_shapes`` per tensor (host or DMA padding)."""
        return {tensor: padded_shapes[tensor] - actual_shapes[tensor]
                for tensor in ('ifm', 'wgt', 'ofm')}

    def calculate_array_tilings(self):
        """Enumerate core subvolume candidates and derive their shapes and paddings.

        Sets ``core_subvols``, ``num_combinations``, ``padded_shapes``, ``host_shapes``,
        ``host_padding`` (host - original) and ``dma_padding`` (padded - host).
        """
        # Generate all core subvol combinations as arrays
        self.core_subvols = self._generate_core_subvols()

        # Store number of combinations for later use
        self.num_combinations = self.core_subvols['ifm'].shape[0]

        # Calculate padded and updated host shapes for all combinations at once
        self.padded_shapes = self._calculate_padded_shapes_vectorized(self.core_subvols)
        self.host_shapes = self._calculate_host_shapes_vectorized(self.padded_shapes)

        # Calculate host and DMA padding for all combinations.
        self.host_padding = self._calculate_padding_vectorized(self.original_shapes, self.host_shapes)
        self.dma_padding = self._calculate_padding_vectorized(self.host_shapes, self.padded_shapes)

    def calculate_memtile_tilings(self):
        """Calculate memtile subvols and iteration counts for each schedule.

        For every schedule in ``schedule_list``, fills:

        - ``memtile_subvols[sched]``
        - ``memtile_iters[sched]``: padded shape // (memtile subvol * mem split)
        - ``core_iters[sched]``: memtile subvol // (core subvol * split ratio)

        Does nothing if ``calculate_array_tilings`` has not been run.
        """
        if not hasattr(self, 'core_subvols') or self.core_subvols['ifm'] is None:
            return

        tensors = ('ifm', 'wgt', 'ofm')
        mem_splits = self.overlay.mem_splits

        # Core splits per memtile split, along each dim.
        split_ratio = {t: self.overlay.core_splits[t] // mem_splits[t] for t in tensors}

        for schedule_id in self.schedule_list:
            # Initialize memtile subvols with all streaming (fresh arrays per schedule).
            memtile = {t: self.core_subvols[t] * split_ratio[t] for t in tensors}

            if schedule_id == 1:    # ifm pin, wgt full, ofm stream
                memtile['ifm'][:, -1] = self.padded_shapes['ifm'][:, -1]
                memtile['wgt'][:, -2:] = self.padded_shapes['wgt'][:, -2:] // mem_splits['wgt'][-2:]
            elif schedule_id == 2:  # ifm pin, wgt stream, ofm stream
                memtile['ifm'][:, -1] = self.padded_shapes['ifm'][:, -1]
            # schedule 5 (ifm stream, wgt stream, ofm stream) keeps the streaming initialisation.

            self.memtile_subvols[schedule_id] = memtile
            self.memtile_iters[schedule_id] = {
                t: self.padded_shapes[t] // (memtile[t] * mem_splits[t]) for t in tensors
            }
            self.core_iters[schedule_id] = {
                t: memtile[t] // (self.core_subvols[t] * split_ratio[t]) for t in tensors
            }

    def _filter_valid_data(self, data, valid_indices):
        """
        Filters tensor data dictionary by keeping only elements at specified indices.

        Args:
            data (dict): Dictionary with array values
            valid_indices (array-like): Indices to retain
        Returns:
            dict: Filtered dictionary with same keys
        """
        return {k: v[valid_indices] for k, v in data.items()}

    def check_core_constraints(self):
        """Filter candidate tilings per schedule against core, memtile and DMA/BD limits.

        For every schedule whose memtile tilings have been computed, this method
        builds one boolean mask per check and combines them. It then stores the
        surviving indices in ``valid_core_indices`` and the filtered data in the
        other ``valid_*`` dicts. It also records, for all candidates:

        - ``buffer_placement_choices``
        - ``memtile_pingpong``: columns are fits with a single / double ifm buffer
        - ``disable_ifm_pingpong``: column 1 also cleared where ifm ping-pong is not allowed

        Side effect: writes per-candidate arrays ('tk_1', 'Msubv', 'Ksubv',
        'Nsubv', ...) into ``self.vars_dict`` for the kernel formulas.
        """
        if not hasattr(self, 'core_subvols') or self.core_subvols['ifm'] is None:
            return

        num_combinations = self.num_combinations
        memtile_capacity = self.device.memtile_capacity * self.device.memtile_rows * 1024  # In bytes
        is_actxact = self._is_actxact()

        # Number of ofm-sized memtile buffers. RoPE/Add variants reserve extra ones
        # (presumably for their additional operands).
        if 'RoPE' in self.layer.op_type:
            num_ofm_buf = 3
        elif 'Add' in self.layer.op_type:
            num_ofm_buf = 2
        else:
            num_ofm_buf = 1

        for sched in self.schedule_list:
            # memtile_subvols is pre-seeded with {} per schedule; skip schedules whose
            # memtile tilings have not been computed yet.
            if not self.memtile_subvols.get(sched):
                continue

            # Get key arrays for this schedule
            core_subvols = self.core_subvols
            core_iters = self.core_iters[sched]
            memtile_subvols = self.memtile_subvols[sched]
            memtile_iters = self.memtile_iters[sched]

            # Per-dimension iteration counts (columns of each tensor's iteration array).
            Tb_core_ifm, Tm_core_ifm, Tk_core_ifm = core_iters['ifm'].T
            Tb_core_wgt, Tk_core_wgt, Tn_core_wgt = core_iters['wgt'].T
            Tb_core_ofm, Tm_core_ofm, Tn_core_ofm = core_iters['ofm'].T
            Tm_core = Tm_core_ifm
            Tk_core = Tk_core_wgt
            Tn_core = Tn_core_wgt

            Tb_mem_ifm, Tm_mem_ifm, Tk_mem_ifm = memtile_iters['ifm'].T
            Tb_mem_wgt, Tk_mem_wgt, Tn_mem_wgt = memtile_iters['wgt'].T
            Tb_mem_ofm, Tm_mem_ofm, Tn_mem_ofm = memtile_iters['ofm'].T

            core_validity_checks = {}
            buffer_placements = np.empty(num_combinations, dtype='<U16')

            # Total iterations per dim: the max over the tensors that share that dim.
            Tb = np.maximum.reduce([Tb_core_ifm * Tb_mem_ifm, Tb_core_wgt * Tb_mem_wgt, Tb_core_ofm * Tb_mem_ofm])
            Tm = np.maximum(Tm_core_ifm * Tm_mem_ifm, Tm_core_ofm * Tm_mem_ofm)
            Tk = np.maximum(Tk_core_ifm * Tk_mem_ifm, Tk_core_wgt * Tk_mem_wgt)
            Tn = np.maximum(Tn_core_wgt * Tn_mem_wgt, Tn_core_ofm * Tn_mem_ofm)
            loop_iters = {'Tb': Tb, 'Tm': Tm, 'Tk': Tk, 'Tn': Tn}

            # A single M/K/N iteration uses single buffers; anything else double-buffers.
            single_mask = (Tn == 1) & (Tk == 1) & (Tm == 1)
            buffer_placements[single_mask] = 'single_single'
            buffer_placements[~single_mask] = 'double_double'

            # Store buffer placement choices
            self.buffer_placement_choices[sched] = buffer_placements

            ## check buffer placements against the kernel's bank formulas.
            # NOTE: formulas are eval'd; they must come from trusted kernel configuration.
            self.vars_dict['tk_1'] = single_mask.astype(np.int32)
            self.vars_dict['Msubv'] = core_subvols['ifm'][:, -2]
            self.vars_dict['Ksubv'] = core_subvols['ifm'][:, -1]
            self.vars_dict['Nsubv'] = core_subvols['wgt'][:, -1]
            self.vars_dict['inner_loop'] = self.kernel.inner_loop
            self.vars_dict['outer_loop'] = self.kernel.outer_loop
            bank_dict = self.kernel.placement_constraints
            single_single = np.all([eval(formula, self.vars_dict)
                                    for formula in bank_dict['single_single'].values()], axis=0)
            double_double = np.all([eval(formula, self.vars_dict)
                                    for formula in bank_dict['double_double'].values()], axis=0)
            core_validity_checks['buffer_placement'] = np.where(single_mask, single_single, double_double)

            # loop constraints eval
            for constraint, formula in self.kernel.other_constraints.items():
                core_validity_checks[constraint] = eval(formula, self.vars_dict)

            # All core constraints should be met
            fits_in_coretile = np.all(np.column_stack(list(core_validity_checks.values())), axis=1)

            # Memtile footprint (bytes) of one memtile subvol per tensor.
            ifm_memtile_sizes = np.prod(memtile_subvols['ifm'], axis=1) * self.layer.in_bytes
            wgt_memtile_sizes = np.prod(memtile_subvols['wgt'], axis=1) * self.layer.wgt_bytes
            ofm_memtile_sizes = np.prod(memtile_subvols['ofm'], axis=1) * self.layer.out_bytes

            # Add bias bytes (Nsubv * bias_bytes per wgt core subvol) to weight memtile sizes.
            # Out-of-place add so an int size array can absorb float bias bytes.
            num_wgt_subvols = np.prod((memtile_subvols['wgt'] // core_subvols['wgt']), axis=1)
            bias_bytes_per_subvol = core_subvols['wgt'][:, -1] * self.vars_dict['bias_bytes']
            bias_bytes_in_memtile = num_wgt_subvols * bias_bytes_per_subvol
            wgt_memtile_sizes = wgt_memtile_sizes + bias_bytes_in_memtile

            # Memtile space needed with a single (1buff) or double (2buff) ifm buffer.
            if sched == 1:
                # Schedule 1: ifm pin, wgt full, ofm stream
                space_required_1buff = ifm_memtile_sizes + wgt_memtile_sizes + ofm_memtile_sizes * num_ofm_buf
                space_required_2buff = ifm_memtile_sizes * 2 + wgt_memtile_sizes + ofm_memtile_sizes * num_ofm_buf

            elif sched == 2:
                # Schedule 2: ifm pin, wgt stream, ofm stream
                space_required_1buff = ifm_memtile_sizes + wgt_memtile_sizes * 2 + ofm_memtile_sizes * num_ofm_buf
                space_required_2buff = ifm_memtile_sizes * 2 + wgt_memtile_sizes * 2 + ofm_memtile_sizes * num_ofm_buf

            elif sched == 5:
                # Schedule 5: ifm stream, wgt stream, ofm stream.
                # The single-buffer option is made infeasible with a dummy max value.
                space_required_1buff = np.full(num_combinations, np.iinfo(np.int32).max)
                space_required_2buff = ifm_memtile_sizes * 2 + wgt_memtile_sizes * 2 + ofm_memtile_sizes * num_ofm_buf
            else:
                # Explicit raise (not `assert`) so the check survives `python -O`.
                raise AssertionError(f"Unexpected schedule number {sched}")

            # Column 0: fits with single ifm buffer; column 1: fits with double ifm buffer.
            # 16 KB of the memtile is left unused as headroom.
            memtile_pingpong = np.column_stack((
                space_required_1buff < (memtile_capacity - 16 * 1024),
                space_required_2buff < (memtile_capacity - 16 * 1024)
            ))
            fits_in_memtile = np.any(memtile_pingpong, axis=1)      # At least one memtile configuration is valid
            self.memtile_pingpong[sched] = memtile_pingpong

            ###### DRC checks/filtering for BD programming
            B_split = self.overlay.core_splits['ofm'][0]
            M_split = self.overlay.core_splits['ofm'][-2]
            N_split = self.overlay.core_splits['ofm'][-1]

            ########## HW limits
            # filter for iter wrap limit
            max_chain_length = 4
            MAX_ITER_WRAP = 64
            MAX_IFM_BROADCAST_BD_LEN = 10
            MAX_IFM_UNICAST_BD_LEN = 5

            mpadding_ifm_check = self.dma_padding['ifm'][:, -2] > 0
            kpadding_ifm_check = self.dma_padding['ifm'][:, -1] > 0

            # WGT: iteration step (in device words) must not exceed 2**20.
            wgt_subvol_sizes    = np.prod(core_subvols['wgt'], axis=1) * self.vars_dict['wgt_bytes'] + bias_bytes_per_subvol
            wgt_valid_iter_step = ((wgt_subvol_sizes * B_split * Tk_core * Tk_mem_wgt * N_split / self.device.bytes_per_word) <= 2**20)
            wgt_valid_iter_wrap = True
            if is_actxact and sched in [2, 5]:
                wgt_valid_iter_wrap = check_iteration_chain_length((Tn_core * Tn_mem_ofm), self.device.max_chain_length)

            # IFM
            chain_length        = np.clip(np.ceil((Tm_core * Tm_mem_ifm) / MAX_ITER_WRAP), a_min=1, a_max=max_chain_length)
            ifm_subvol_sizes    = np.prod(core_subvols['ifm'], axis=1) * self.vars_dict['ifm_bytes']
            ifm_valid_iter_step = ((ifm_subvol_sizes * B_split * M_split * Tk_core_ifm * Tk_mem_ifm * chain_length / self.device.bytes_per_word) <= 2**20) | (Tm_core*Tm_mem_ifm==1) | (Tk_core*Tk_mem_wgt==1)

            # With M DMA padding, one fewer M iteration is counted against the chain limit.
            ifm_valid_iter_wrap_mpad  = check_iteration_chain_length((Tm_core * Tm_mem_ifm - 1), self.device.max_chain_length)
            ifm_valid_iter_wrap_nopad = check_iteration_chain_length((Tm_core * Tm_mem_ifm), self.device.max_chain_length)
            ifm_valid_iter_wrap       = np.where(mpadding_ifm_check, ifm_valid_iter_wrap_mpad, ifm_valid_iter_wrap_nopad)

            # BD length limit on streamed ifm (schedule 5 only).
            if self.overlay.unicast == 'wgt':
                ifm_valid_BD_len    = ~((Tm_core * Tm_mem_ifm > MAX_IFM_BROADCAST_BD_LEN) & (sched == 5))
            else:
                ifm_valid_BD_len    = ~((Tm_core * Tm_mem_ifm > MAX_IFM_UNICAST_BD_LEN) & (sched == 5))

            ifm_valid_iter_w8_wrap  = True
            if sched in [1, 2]:
                ifm_valid_iter_w8_wrap  = ~(((memtile_subvols['ifm'][:, 2] // 8) > 1024) & kpadding_ifm_check)

            # OFM
            ofm_subvol_sizes    = np.prod(core_subvols['ofm'], axis=1) * self.vars_dict['ofm_bytes']
            ofm_valid_iter_step = ((ofm_subvol_sizes * B_split * M_split * N_split * Tn_core_ofm * Tn_mem_ofm * chain_length / self.device.bytes_per_word) <= 2**20) | (Tm_core*Tm_mem_ifm==1) | (Tn_core*Tn_mem_wgt==1)

            valid_iter_wrap = ofm_valid_iter_step & wgt_valid_iter_step & wgt_valid_iter_wrap & ifm_valid_iter_step & ifm_valid_iter_wrap & ifm_valid_BD_len & ifm_valid_iter_w8_wrap

            ######### Re-enqueue (reuse) limitation (might go away)
            act_reuse_lim = 2**30
            wgt_reuse_lim = 2**30
            if self.overlay.unicast == 'wgt':
                if sched == 1:
                    act_reuse_lim = self.device.broadcast_reuse_limit
                    wgt_reuse_lim = self.device.unicast_reuse_limit
                if sched == 2:
                    act_reuse_lim = self.device.broadcast_reuse_limit
            elif self.overlay.unicast == 'act':
                if sched == 1:
                    act_reuse_lim = self.device.unicast_reuse_limit
                    wgt_reuse_lim = self.device.broadcast_reuse_limit
                if sched == 2:
                    act_reuse_lim = self.device.unicast_reuse_limit

            reuse_ratio_validity = (Tn_mem_ofm < act_reuse_lim) & (Tm_mem_ofm < wgt_reuse_lim)

            ### reuse chain length filter
            reuse_chain_validity = True
            num_consumers = np.prod(self.overlay.core_splits['ofm']/self.overlay.mem_splits['ofm'])
            act_max_chain_length = self.device.max_chain_length    if self.overlay.unicast == 'wgt' else self.device.max_chain_length//2
            wgt_max_chain_length = self.device.max_chain_length//2 if self.overlay.unicast == 'wgt' else self.device.max_chain_length
            if sched == 1:
                ifm_reuse_chain_validity = check_reuse_chain_validity(Tn_mem_ofm, num_consumers, act_max_chain_length, self.device.max_lock_value)
                wgt_reuse_chain_validity = check_reuse_chain_validity(Tm_mem_ofm, num_consumers, wgt_max_chain_length, self.device.max_lock_value)
                reuse_chain_validity = ifm_reuse_chain_validity & wgt_reuse_chain_validity
            if sched == 2:
                reuse_chain_validity = check_reuse_chain_validity(Tn_mem_ofm, num_consumers, act_max_chain_length, self.device.max_lock_value)

            ### actxact ifm 1 restriction
            ifm1_dim_validity = Tk_core * Tk_mem_wgt > 0
            if is_actxact and sched == 1:
                ifm1_dim_validity = ~((Tk_core*Tk_mem_wgt>1) & (Tn_core*Tn_mem_ofm>1))

            ### M1N32 wgt subv restriction
            m1_split_tn_validity = True
            if self.overlay.core_splits['ofm'][-1] == 32 and sched == 2:
                m1_split_tn_validity = ~((Tn_mem_ofm * Tn_core > 8) & (wgt_subvol_sizes > 4096))
            elif self.overlay.core_splits['ofm'][-1] == 32 and sched == 5:
                m1_split_tn_validity = ~(((Tn_mem_ofm * Tn_core + Tm_mem_ofm * Tm_core) > 8) & (wgt_subvol_sizes > 4096)) & check_iteration_chain_length(Tn_mem_ofm, self.device.max_chain_length)

            ###### padding related DRCs (currently disabled)
            # kpadding_ifm1_validty = True
            # if 'actxact' in self.layer.op_type and sched in [2, 5]:
            #     kpadding_ifm1_validty = ~(((Tk_core*Tk_mem_wgt>1) & (self.dma_padding['wgt'][:,1] > 0)) | \
            #                               ((self.dma_padding['wgt'][:,1] > 0) & (self.dma_padding['wgt'][:,2] > 0)))

            # #disable ofm tn>1 and N depadding
            # npadding_validity = ~(((Tm_mem_ofm * Tm_core) > 1) & ((Tn_core * Tn_mem_ofm) > 1) & ((self.dma_padding['ofm'][:,-1]) > 0))
            # #npadding_validity = ~(((Tn_core * Tn_mem_ofm) > 1) & ((self.dma_padding['ofm'][:,-1]) > 0))

            # if M padding is active, disable ifm ping-pong (column 1) for ifm-pinning schedules
            mpadding_ifm_pingpong = np.copy(memtile_pingpong)
            mpadding_ifm_pingpong[:, 1] = mpadding_ifm_pingpong[:, 1] & (~(mpadding_ifm_check & ((sched == 1) | (sched == 2))))

            # restrict wgt_subvol_size if Tm and Tn > 1 (only where M padding is active)
            wgt_subv_validity = np.copy(valid_iter_wrap)
            if (sched == 2) or (sched == 5):
                wgt_subv_validity = np.where(mpadding_ifm_check,
                                             (~((Tm_core*Tm_mem_ofm>1) &
                                                (Tn_core*Tn_mem_ofm>1) &
                                                (wgt_subvol_sizes > 4096))),
                                             wgt_subv_validity)

            n_batch = self.original_shapes['ifm'][0]
            # disable ifm pingpong for ifm pinning when batch>1
            bmm_ifm_pingpong = np.ones_like(memtile_pingpong, dtype=bool)
            if (n_batch > 1) and ((sched == 1) or (sched == 2)):
                bmm_ifm_pingpong[(Tn_mem_ofm > 1), 1] = False

            # Despite its name, True entries mark configurations that fit AND remain allowed.
            disable_ifm_pingpong = mpadding_ifm_pingpong & bmm_ifm_pingpong
            self.disable_ifm_pingpong[sched] = disable_ifm_pingpong

            # for ifm streaming in BMM we need to ensure that Tm cannot be more than 1 because streaming uses bd chaining for Tm>1 which is not possible for bmm
            if (n_batch > 1) and (sched == 5):
                bmm_stream_chaining_validity = (Tm_mem_ofm*Tm_core == 1)
            else:
                bmm_stream_chaining_validity = (Tm_mem_ofm*Tm_core > 0)

            # for bmm the total temporal iterations cannot be more than 768 (check currently disabled)
            bmm_reenque_validity = (Tm_mem_ofm*Tm_core*Tk_mem_wgt*Tk_core*Tn_mem_ofm*Tn_core > 0)
            # if n_batch>1 and Tm_mem_ofm*Tm_core*Tk_mem_wgt*Tk_core*Tn_mem_ofm*Tn_core>768:
            #     bmm_reenque_validity = False

            # Combine all validity checks (the disabled padding DRCs above are not included).
            valid_indices = np.where(fits_in_coretile & fits_in_memtile &
                                     valid_iter_wrap & reuse_ratio_validity &
                                     reuse_chain_validity & ifm1_dim_validity &
                                     m1_split_tn_validity & wgt_subv_validity &
                                     bmm_stream_chaining_validity & bmm_reenque_validity)[0]

            if (hasattr(self.layer, 'debug_info') and
                    self.layer.debug_info.get('valid_idx') is not None):
                # If debug info is provided, keep only the selected valid index.
                valid_indices = [valid_indices[int(self.layer.debug_info['valid_idx'])]]

            # Store valid indices
            self.valid_core_indices[sched] = valid_indices

            # Store valid configurations
            if len(valid_indices) > 0:
                # Filter and store valid data
                self.valid_core_subvols[sched] = self._filter_valid_data(core_subvols, valid_indices)
                self.valid_core_iters[sched] = self._filter_valid_data(core_iters, valid_indices)
                self.valid_memtile_subvols[sched] = self._filter_valid_data(memtile_subvols, valid_indices)
                self.valid_memtile_iters[sched] = self._filter_valid_data(memtile_iters, valid_indices)
                self.valid_loop_iters[sched] = self._filter_valid_data(loop_iters, valid_indices)
                self.valid_padded_shapes[sched] = self._filter_valid_data(self.padded_shapes, valid_indices)
                self.valid_host_shapes[sched] = self._filter_valid_data(self.host_shapes, valid_indices)
                self.valid_host_padding[sched] = self._filter_valid_data(self.host_padding, valid_indices)
                self.valid_dma_padding[sched] = self._filter_valid_data(self.dma_padding, valid_indices)

                # Store valid pingpong info
                self.valid_memtile_pingpong[sched] = memtile_pingpong[valid_indices]

                # Store valid info for disabling pingpong for Mpadding
                self.valid_disable_ifm_pingpong[sched] = disable_ifm_pingpong[valid_indices]

                # Store valid buffer placement choices
                self.valid_buffer_placement_choices[sched] = buffer_placements[valid_indices]

if __name__ == '__main__':
    # Smoke run: tile the 'tst' layer from tst_layer.json on an 8x4 MatMul overlay.
    from overlay import Overlay
    import json
    from layer import Layer
    from kernel import Kernel
    from device import Device

    overlay = Overlay('8x4', 'MatMul', 'M4K1N8')

    with open('OGOAT/src/Tiler/tst_layer.json') as layer_file:
        layer_dict = json.load(layer_file)['tst']
    # Force input/output activation residency to L3 for this run.
    layer_dict['in_act_residency'] = 'L3'
    layer_dict['out_act_residency'] = 'L3'
    layer = Layer(layer_dict)

    device = Device('strix')
    kernel = Kernel(layer)

    tiler = MatMulTiler(layer, device, overlay, kernel)
    # The three tiling stages must run in this order.
    for stage in (tiler.calculate_array_tilings,
                  tiler.calculate_memtile_tilings,
                  tiler.check_core_constraints):
        stage()
