import yaml
import os
import copy
import warnings
warnings.simplefilter('always')
import numpy as np

from matmul_tiler import MatMulTiler
from matmul_cost_model import MatMulCostModel
from overlay import Overlay

from enum import Enum
import pdb

# Absolute path of this file with its last three path components removed.
# MatMulTilingOpt.__init__ resolves 'Collaterals/overlays.yaml' relative to it.
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

class MatMulTilingOpt:
    def __init__(self, layer, device, overlay_name, kernel):
        """Build one tiler and one cost model per usable overlay mode.

        Candidate modes are read from ``Collaterals/overlays.yaml`` (under
        ``overlay_name`` / ``layer.orig_op_type``), except for op types that
        contain ``'RoPE'``, which always use the single mode ``'M32K1N1'``.
        If ``layer.debug_info['midx']`` is set, only that candidate is kept.

        ``self.modes``, ``self.tilers`` and ``self.cost_models`` are kept
        index-aligned: a mode that is skipped below is not recorded in
        ``self.modes`` either, so ``self.modes[i]`` always names the mode of
        ``self.tilers[i]``.

        Args:
            layer: Layer description; ``op_type``, ``orig_op_type`` and
                ``in_act_shape`` are read here, and ``debug_info`` if present.
            device: Passed through to ``MatMulTiler``.
            overlay_name: Key into the overlays YAML; also stored as
                ``self.overlay``.
            kernel: Passed through to ``MatMulTiler``.
        """
        with open(os.path.join(parent_dir,'Collaterals/overlays.yaml'), encoding='utf-8') as f:
            all_overlays = yaml.safe_load(f)

        if 'RoPE' in layer.op_type:
            candidate_modes = ['M32K1N1']
        else:
            candidate_modes = list(all_overlays[overlay_name][layer.orig_op_type].keys())

        self.layer = layer
        # NOTE(review): 'midx' indexes the unfiltered candidate list here,
        # i.e. before the subarray skip below is applied -- confirm that this
        # matches how callers produce the value.
        if (hasattr(self.layer, 'debug_info') and
            self.layer.debug_info.get('midx') is not None):
            candidate_modes = [candidate_modes[int(self.layer.debug_info['midx'])]]

        self.overlay = overlay_name
        n_batch = layer.in_act_shape[0]

        self.modes = []
        self.tilers = []
        self.cost_models = []

        # Cache for frequently accessed data
        self._cached_op_type_flags = {}
        self._cached_string_operations = {}

        for mode in candidate_modes:
            overlay = Overlay(overlay_name,layer.orig_op_type,mode)
            # Skip subarray mapping for batch=1 or actxwgt case.
            if (overlay.core_splits['ifm'][0] > 1 and
                (n_batch == 1 or 'actxact' not in layer.op_type)):
                continue

            tiler = MatMulTiler(layer, device, overlay, kernel)
            # Record the mode only when a tiler exists for it, so the three
            # lists can be indexed with the same mode index.
            self.modes.append(mode)
            self.tilers.append(tiler)
            self.cost_models.append(MatMulCostModel(tiler))

    def _convert_to_int_list(self, array):
        """Helper method to convert numpy array to int list efficiently."""
        return array.astype(np.int32).tolist()

    def _get_op_type_flags(self, op_type):
        """Cache operation type flags to avoid repeated string operations."""
        if op_type not in self._cached_op_type_flags:
            self._cached_op_type_flags[op_type] = {
                'has_rope': 'RoPE' in op_type,
                'has_add': 'Add' in op_type,
                'has_actxact': 'actxact' in op_type,
                'has_pwla': 'pwla' in op_type,
                'is_uint_wgt': 'uint' in op_type.split('_')[-1].split('x')[1] if '_' in op_type and 'x' in op_type else False
            }
        return self._cached_op_type_flags[op_type]

    def _create_efficient_index_array(self, midx, sched, num_cycles):
        """Create index arrays more efficiently."""
        midx_sched_array = np.full((num_cycles, 2), [midx, sched], dtype=np.int32)
        cycle_indices = np.arange(num_cycles, dtype=np.int32).reshape(-1, 1)
        return np.concatenate([midx_sched_array, cycle_indices], axis=1)

    def _extract_tiling_data(self, tiler, opt_sched, opt_valid_idx):
        """Extract all tiling-related data in one pass to reduce repeated array access."""
        # Cache all valid arrays for this schedule
        valid_arrays = {
            'core_subvols': tiler.valid_core_subvols[opt_sched],
            'core_iters': tiler.valid_core_iters[opt_sched],
            'memtile_subvols': tiler.valid_memtile_subvols[opt_sched],
            'memtile_iters': tiler.valid_memtile_iters[opt_sched],
            'host_shapes': tiler.valid_host_shapes[opt_sched],
            'dma_padding': tiler.valid_dma_padding[opt_sched],
            'host_padding': tiler.valid_host_padding[opt_sched]
        }

        # Extract all data at once
        extracted_data = {}
        
        # Core tiling data
        extracted_data['core_tiling'] = {
            'ifm': self._convert_to_int_list(valid_arrays['core_subvols']['ifm'][opt_valid_idx]),
            'wgt': self._convert_to_int_list(valid_arrays['core_subvols']['wgt'][opt_valid_idx]),
            'ofm': self._convert_to_int_list(valid_arrays['core_subvols']['ofm'][opt_valid_idx])
        }

        # Core iterations
        ifm_core_iters = valid_arrays['core_iters']['ifm'][opt_valid_idx]
        wgt_core_iters = valid_arrays['core_iters']['wgt'][opt_valid_idx]
        
        Tb_core = int(ifm_core_iters[0])
        Tm_core = int(ifm_core_iters[-2])
        Tk_core = int(wgt_core_iters[-2])
        Tn_core = int(wgt_core_iters[-1])

        extracted_data['core_iters'] = {
            'ifm': [Tb_core, Tm_core, Tk_core],
            'wgt': [Tb_core, Tk_core, Tn_core],
            'ofm': [Tb_core, Tm_core, Tn_core]
        }

        # Memtile data
        extracted_data['memtile_tiling'] = {
            'ifm': self._convert_to_int_list(valid_arrays['memtile_subvols']['ifm'][opt_valid_idx]),
            'wgt': self._convert_to_int_list(valid_arrays['memtile_subvols']['wgt'][opt_valid_idx]),
            'ofm': self._convert_to_int_list(valid_arrays['memtile_subvols']['ofm'][opt_valid_idx])
        }

        extracted_data['memtile_iters'] = {
            'ifm': self._convert_to_int_list(valid_arrays['memtile_iters']['ifm'][opt_valid_idx]),
            'wgt': self._convert_to_int_list(valid_arrays['memtile_iters']['wgt'][opt_valid_idx]),
            'ofm': self._convert_to_int_list(valid_arrays['memtile_iters']['ofm'][opt_valid_idx])
        }

        # Host shapes
        extracted_data['host_shapes'] = {
            'ifm': self._convert_to_int_list(valid_arrays['host_shapes']['ifm'][opt_valid_idx]),
            'wgt': self._convert_to_int_list(valid_arrays['host_shapes']['wgt'][opt_valid_idx]),
            'ofm': self._convert_to_int_list(valid_arrays['host_shapes']['ofm'][opt_valid_idx])
        }

        # DMA and host padding
        extracted_data['dma_padding'] = {
            'ifm': self._convert_to_int_list(valid_arrays['dma_padding']['ifm'][opt_valid_idx]),
            'wgt': self._convert_to_int_list(valid_arrays['dma_padding']['wgt'][opt_valid_idx]),
            'ofm': self._convert_to_int_list(valid_arrays['dma_padding']['ofm'][opt_valid_idx])
        }

        extracted_data['host_padding_values'] = {
            'ifm': int(valid_arrays['host_padding']['ifm'][opt_valid_idx][0]),
            'wgt': int(valid_arrays['host_padding']['wgt'][opt_valid_idx][0]),
            'ofm': int(valid_arrays['host_padding']['ofm'][opt_valid_idx][0])
        }

        return extracted_data

    def _create_padding_structures(self, extracted_data, tiler, op_flags):
        """Create padding structures efficiently with minimal list creation."""
        host_shapes = extracted_data['host_shapes']
        dma_padding = extracted_data['dma_padding']
        
        # Pre-calculate padding value lists
        zp_values = {
            'i0': ["zp_i0"] * len(host_shapes['ifm']),
            'i1': ["zp_i1"] * len(host_shapes['wgt']),
            'o0': ["zp_o0"] * len(host_shapes['ofm']),
            'i0_dma': ["zp_i0"] * len(dma_padding['ifm']),
            'i1_dma': ["zp_i1"] * len(dma_padding['wgt']),
            'o0_dma': ["zp_o0"] * len(dma_padding['ofm'])
        }

        # Host layer padding key
        host_layer_padding_key = "ifm" if op_flags['has_actxact'] else "wgt"

        # Create base padding structures
        host_layer_padding = [
            {"ifm": {"dims": host_shapes['ifm'], "value": zp_values['i0']}},
            {host_layer_padding_key: {"dims": host_shapes['wgt'], "value": zp_values['i1']}},
            {"ofm": {"dims": host_shapes['ofm'], "value": zp_values['o0']}}
        ]

        dma_layer_padding = [
            {"ifm": {
                "dims": dma_padding['ifm'],
                "value": zp_values['i0_dma'],
                "channels": tiler.overlay.memtile_unicast_channels if tiler.overlay.unicast == "act" else tiler.overlay.memtile_broadcast_channels
            }},
            {"wgt": {
                "dims": dma_padding['wgt'],
                "value": zp_values['i1_dma'],
                "channels": tiler.overlay.memtile_unicast_channels if tiler.overlay.unicast == "wgt" else tiler.overlay.memtile_broadcast_channels
            }},
            {"ofm": {
                "dims": dma_padding['ofm'],
                "value": zp_values['o0_dma'],
                "channels": tiler.overlay.memtile_ofm_channels
            }}
        ]

        return host_layer_padding, dma_layer_padding

    def _compute_derived_arrays(self, extracted_data, layer, tiler):
        """Compute all derived arrays (shim, dram) efficiently."""
        memtile_tiling = extracted_data['memtile_tiling']
        memtile_iters = extracted_data['memtile_iters']
        
        # Convert to numpy arrays once for all computations
        memtile_tiling_np = {k: np.array(v) for k, v in memtile_tiling.items()}
        memtile_iters_np = {k: np.array(v) for k, v in memtile_iters.items()}
        
        # Calculate memtile sizes
        core_tiling = extracted_data['core_tiling']
        memtile_sizes = {
            'ifm': int(np.prod(memtile_tiling_np['ifm']) * layer.in_bytes),
            'wgt': int(np.prod(memtile_tiling_np['wgt']) * layer.wgt_bytes +
                      core_tiling['wgt'][-1] * tiler.vars_dict['bias_bytes'] *
                      np.prod(memtile_tiling_np['wgt']/np.array(core_tiling['wgt']))),
            'ofm': int(np.prod(memtile_tiling_np['ofm']) * layer.out_bytes)
        }

        # Compute shim arrays
        shim_arrays = {}
        shim_sizes = {}
        for k in memtile_tiling.keys():
            shim_array = memtile_tiling_np[k] * memtile_iters_np[k]
            shim_arrays[k] = self._convert_to_int_list(shim_array)
            shim_sizes[k] = int(memtile_sizes[k] * np.prod(memtile_iters_np[k]))

        # Compute DRAM arrays
        overlay_mem_splits = tiler.overlay.mem_splits
        dram_shapes = {}
        dram_sizes = {}
        for k in shim_sizes:
            dram_shape_array = np.array(shim_arrays[k]) * overlay_mem_splits[k]
            dram_shapes[k] = self._convert_to_int_list(dram_shape_array)
            dram_sizes[k] = int(shim_sizes[k] * np.prod(overlay_mem_splits[k]))

        return {
            'memtile_sizes': memtile_sizes,
            'shim_tilings': shim_arrays,
            'shim_sizes': shim_sizes,
            'dram_shapes': dram_shapes,
            'dram_sizes': dram_sizes
        }

    def _update_layer_dict_efficiently(self, layerdict, extracted_data):
        """Update layer dictionary with minimal operations."""
        # Batch dictionary updates
        shape_updates = {
            'in_ifm_shape': extracted_data['host_shapes']['ifm'],
            'in_wgt_shape': extracted_data['host_shapes']['wgt'],
            'out_ofm_shape': extracted_data['host_shapes']['ofm']
        }
        
        # Remove old keys and add new ones
        old_keys_to_remove = ['in_act_shape', 'out_act_shape']
        for key in old_keys_to_remove:
            layerdict.pop(key, None)
        
        layerdict.update(shape_updates)

        # Batch datatype and bytes updates
        field_mappings = {
            'in_datatype': 'in_ifm_datatype',
            'wgt_datatype': 'in_wgt_datatype', 
            'wgt1_datatype': 'in_wgt1_datatype',
            'out_datatype': 'out_ofm_datatype',
            'in_bytes': 'in_ifm_bytes',
            'wgt_bytes': 'in_wgt_bytes',
            'wgt1_bytes': 'in_wgt1_bytes', 
            'out_bytes': 'out_ofm_bytes'
        }
        
        for old_key, new_key in field_mappings.items():
            if old_key in layerdict:
                layerdict[new_key] = layerdict.pop(old_key)

    def _handle_special_operations(self, op_flags, extracted_data, host_layer_padding, dma_layer_padding, all_attr_dicts):
        """Handle special operations (RoPE, Add) efficiently."""
        if op_flags['has_rope']:
            rope_padding = {'dims': extracted_data['host_shapes']['ofm'], 'value': [None, None, 0]}
            rope_dma_padding = {'dims': [0, 0, 0], 'value': [0, 0, None]}
            
            host_layer_padding.extend([
                {'sin': rope_padding.copy()},
                {'cos': rope_padding.copy()}
            ])
            dma_layer_padding.extend([
                {'sin': rope_dma_padding.copy()},
                {'cos': rope_dma_padding.copy()}
            ])
            
            # Use references for all attribute dictionaries
            ofm_ref = all_attr_dicts[0]['ofm']  # Reference to ofm in first dict
            for attr_dict in all_attr_dicts:
                attr_dict['sin'] = attr_dict['cos'] = ofm_ref
                
        elif op_flags['has_add']:
            host_layer_padding.append({'ifmB': {'dims': extracted_data['host_shapes']['ofm'], 'value': [None, None, 0]}})
            dma_layer_padding.append({'ifmB': {'dims': [0, 0, 0], 'value': [0, 0, None]}})
            
            # Use references for all attribute dictionaries  
            ofm_ref = all_attr_dicts[0]['ofm']  # Reference to ofm in first dict
            for attr_dict in all_attr_dicts:
                attr_dict['ifmB'] = ofm_ref

    def _create_schedule_dict(self, opt_sched, pingpong, n_batch, tiler, opt_valid_idx):
        """Create scheduling dictionary based on schedule type."""
        base_dict = {'Tbatch': n_batch}
        
        if opt_sched == 1:  # ifm pin, wgt full, ofm stream
            base_dict.update({
                'ifm': 'pin',
                'wgt': 'full',
                'ifm_ping_pong': bool(pingpong == 1),
                'wgt_ping_pong': False,
                'ofm_ping_pong': False
            })
        elif opt_sched == 2:  # ifm pin, wgt stream, ofm stream
            base_dict.update({
                'ifm': 'pin',
                'wgt': 'stream',
                'ifm_ping_pong': bool(pingpong == 1),
                'wgt_ping_pong': True,
                'ofm_ping_pong': False
            })
        elif opt_sched == 5:  # ifm stream, wgt stream, ofm stream
            Mifm_dma_padding = tiler.valid_dma_padding[opt_sched]['ifm'][opt_valid_idx][0]
            base_dict.update({
                'ifm': 'stream',
                'wgt': 'stream',
                'ifm_ping_pong': bool(pingpong == 1),
                'wgt_ping_pong': bool(Mifm_dma_padding == 0),
                'ofm_ping_pong': False
            })
        
        return base_dict

    def _create_testbench_flags(self, tiler, layerdict, core_tiling, derived_arrays, 
                              GEMM_VEC_COEFFS, n_batch, extracted_data, opt_midx, op_flags):
        """Create testbench flags efficiently."""
        testbench_args = tiler.kernel.testbench_args
        cflags = testbench_args['CFLAGS']
        
        # Base flags
        tb_cflags = {
            cflags[0]: layerdict['in_ifm_shape'][-2],  # M_GEMM
            cflags[1]: layerdict['in_ifm_shape'][-1],  # K_GEMM
            cflags[2]: layerdict['out_ofm_shape'][-1],  # N_GEMM
            cflags[3]: core_tiling['ifm'][-2],  # M_GEMM_SUBV
            cflags[4]: core_tiling['ifm'][-1],  # K_GEMM_SUBV
            cflags[5]: core_tiling['wgt'][-1],  # N_GEMM_SUBV
            cflags[6]: derived_arrays['dram_sizes']['ifm'],  # DRAM_IFM_SIZE
            cflags[7]: derived_arrays['dram_sizes']['wgt'],  # DRAM_WGT_SIZE
            cflags[8]: derived_arrays['dram_sizes']['ofm'],  # DRAM_OFM_SIZE
            cflags[9]: GEMM_VEC_COEFFS,  # GEMM_VEC_COEFFS
            cflags[10]: int(op_flags['has_pwla']),
            cflags[11]: int(not op_flags['is_uint_wgt']),
            cflags[18]: n_batch  # Set batch dimension
        }

        # Handle specific test cases
        test_cpp_name = testbench_args['HostName']
        if "gemm_int16x8_unit_test" in test_cpp_name:
            tb_cflags.update({
                cflags[19]: -1,
                cflags[20]: int(op_flags['has_rope']),
                cflags[21]: int(layerdict.get("attributes") is not None and 
                              "sin_cos_const" in layerdict["attributes"] and 
                              not layerdict["attributes"]["sin_cos_const"][0]),
                cflags[23]: int(op_flags['has_add'])
            })

        # Handle attributes
        if layerdict.get("attributes") is not None:
            attrs = layerdict["attributes"]
            tb_cflags.update({
                cflags[12]: attrs.get("disable_q", [0])[0],
                cflags[13]: int(attrs.get('InTransposeA', [0])[0] == 1),
                cflags[14]: int(attrs.get('InTransposeB', [0])[0] == 1),
                cflags[15]: ','.join(map(str, tiler.layer.permA)),
                cflags[16]: ','.join(map(str, tiler.layer.permB)),
                cflags[17]: ','.join(map(str, tiler.layer.permY))
            })
        else:
            default_perm = ','.join(map(str, [0, 1, 2]))
            tb_cflags.update({
                cflags[12]: 0, cflags[13]: 0, cflags[14]: 0,
                cflags[15]: default_perm, cflags[16]: default_perm, cflags[17]: default_perm
            })

        # Shape-related flags
        host_padding_values = extracted_data['host_padding_values']
        tb_cflags.update({
            cflags[19]: layerdict['out_ofm_shape'][-2],  # M_GEMM_OFM
            cflags[22]: ','.join(map(str, [host_padding_values['ifm'], host_padding_values['wgt'], host_padding_values['ofm']]))
        })

        if op_flags['has_actxact']:
            tb_cflags.update({
                cflags[20]: layerdict['in_wgt_shape'][-2],  # K_GEMM_WGT_A16W16 (for actxact)
                cflags[21]: layerdict['in_wgt_shape'][-1],  # N_GEMM_WGT_A16W16 (for actxact)
                cflags[23]: int(layerdict['in_ifm_bytes'] == 2 and layerdict['in_wgt_bytes'] == 1), #A16A8
                cflags[24]: int(layerdict['in_ifm_bytes'] == 1 and layerdict['in_wgt_bytes'] == 1)  #A8A8
            })
        else:
            tb_cflags.update({
                cflags[24]: layerdict['in_wgt_shape'][-2],
                cflags[25]: layerdict['in_wgt_shape'][-1]
            })

        return tb_cflags

    def _create_debug_info(self, opt_midx, opt_sched, opt_valid_idx, pingpong):
        """Create debug information structure."""
        cost_model = self.cost_models[opt_midx]
        layer_cycles_data = cost_model.layer_cycles[opt_sched][opt_valid_idx]
        total_cycles = cost_model.total_layer_cycles[opt_sched][opt_valid_idx]

        debug_info = {
            "midx": int(opt_midx),
            "sched": int(opt_sched),
            "valid_idx": int(opt_valid_idx),
            "pingpong": int(pingpong),
            "single_buffer_cycles": float(total_cycles[0]),
            "double_buffer_cycles": float(total_cycles[1])
        }

        # Add detailed cycle information
        cycle_labels = [
            "ifm_stream_cycles", "wgt_stream_cycles", "kernel_cycles", "ofm_stream_cycles",
            "ifm_dram_cycles", "wgt_dram_cycles", "ofm_dram_cycles", 
            "reenq_penalty_single_buff", "reenq_penalty_double_buff",
            "Tb_core_ifm", "Tm_core_ifm", "Tk_core_ifm",
            "Tb_core_wgt", "Tk_core_wgt", "Tn_core_wgt",
            "Tb_core_ofm", "Tm_core_ofm", "Tn_core_ofm",
            "Tb_mem_ifm", "Tm_mem_ifm", "Tk_mem_ifm",
            "Tb_mem_wgt", "Tk_mem_wgt", "Tn_mem_wgt",
            "Tb_mem_ofm", "Tm_mem_ofm", "Tn_mem_ofm"
        ]

        for i, label in enumerate(cycle_labels):
            debug_info[label] = float(layer_cycles_data[i]) if i < 9 else int(layer_cycles_data[i])

        # Override with debug_info if available
        if (hasattr(self.layer, 'debug_info') and self.layer.debug_info.get('midx') is not None):
            debug_overrides = ['midx', 'sched', 'valid_idx', 'pingpong']
            for key in debug_overrides:
                if key in self.layer.debug_info:
                    debug_info[key] = int(self.layer.debug_info[key])

        return debug_info

    def get_tiling_params (self, mode_idx, tiler):
        tiling_param_dict = {}
        for sched, core_subvols in tiler.valid_core_subvols.items():
            tiling_params_list = []
            if core_subvols != {}:
                for idx in range (len(core_subvols['ifm'])):
                    tiling_params_list.append(self.format_tiling_output(mode_idx, sched, idx, True, scheduler_pass=False))

            tiling_param_dict[sched]=tiling_params_list
        return tiling_param_dict

    def calculate_tiling_cycles(self):
        for mode_idx, (tiler, cost_model) in enumerate(zip(self.tilers, self.cost_models)):
            tiler.calculate_array_tilings()
            tiler.calculate_memtile_tilings()
            tiler.check_core_constraints()
            tiling_param_dict = self.get_tiling_params(mode_idx, tiler)

            #cost_model.calculate_control_overheads(tiling_param_dict)
            cost_model.estimate_control_overheads(tiling_param_dict)
            cost_model.calculate_kernel_cycles()
            cost_model.calculate_array_cycles()
            cost_model.calculate_layer_cycles()

        # Pre-allocate lists with estimated capacity to reduce reallocations
        all_cycles = []
        indexes = []
        iters = []
        
        for midx, cost_model in enumerate(self.cost_models):
            tiler = self.tilers[midx]

            for sched in tiler.schedule_list:
                if tiler.valid_core_indices[sched] is None or len(tiler.valid_core_indices[sched]) == 0:
                    continue

                # Get all valid cycles for this schedule
                cycles = cost_model.total_layer_cycles[sched]
                if cycles is None or len(cycles) == 0:
                    continue

                # Record cycles for both pingpong options
                all_cycles.append(cycles)

                # Use helper method for efficient index array creation
                num_cycles = len(cycles)
                indexes.append(self._create_efficient_index_array(midx, sched, num_cycles))

                # Record iteration values more efficiently
                loop_iters = tiler.valid_loop_iters[sched]
                iters.append(np.column_stack([
                    loop_iters['Tb'],
                    loop_iters['Tm'], 
                    loop_iters['Tk'],
                    loop_iters['Tn']
                ]))

        if all_cycles:
            self.all_cycles = np.vstack(all_cycles)
            self.indexes = np.vstack(indexes)
            self.iters = np.vstack(iters)
        else:
            # Initialize empty arrays with proper shape
            self.all_cycles = np.array([]).reshape(0, 2)
            self.indexes = np.array([]).reshape(0, 3)
            self.iters = np.array([]).reshape(0, 4)

    def calculate_min_cycles_and_macs(self, opt_midx):
        layer = self.tilers[opt_midx].layer
        overlay = self.tilers[opt_midx].overlay
        device = self.tilers[opt_midx].device

        in_dtype = layer.in_datatype[1:] if layer.in_datatype[0]=='u' else layer.in_datatype
        wgt_dtype = layer.wgt_datatype[1:] if layer.wgt_datatype[0]=='u' else layer.wgt_datatype

        ifm_shape = layer.in_act_shape
        wgt_shape = layer.in_wgt_shape

        macs = int(np.prod(ifm_shape)) * wgt_shape[-1]
        macs_per_cycle = device.macs_per_cycle[in_dtype+'x'+wgt_dtype]
        num_cores = overlay.rows * overlay.cols

        min_cycles = macs/macs_per_cycle/num_cores

        return min_cycles, macs

    def format_tiling_output(self, opt_midx, opt_sched, opt_valid_idx, pingpong, debug=False, scheduler_pass: bool=True):
        """Assemble the full tiling-parameter dictionary for one tiling choice.

        Args:
            opt_midx: Index into ``self.tilers`` / ``self.cost_models`` (also
                used to index ``self.modes``).
            opt_sched: Schedule key into the tiler's ``valid_*`` tables.
            opt_valid_idx: Row index of the tiling within that schedule.
            pingpong: Ping-pong option; passed to ``_create_schedule_dict``
                and, when ``scheduler_pass`` is true, used to index the
                cost model's ``total_layer_cycles`` row.
            debug: When true, a ``"debug_info"`` entry is added.
            scheduler_pass: When false, the testbench compile flags stay
                empty and the cycle counts / MACs stay 0.

        Returns:
            dict with keys ``core_tile_params``, ``mem_tile_params``,
            ``shim_tile_params``, ``dram_params``, ``scheduling``,
            ``original_dimensions``, ``host_layer_padding``,
            ``dma_layer_padding``, ``kernel_info``, ``overlay_info``,
            ``layer_info``, ``testbench_args``, ``cycle_counts`` and
            optionally ``debug_info``.

        Side effects:
            ``tiler.vars_dict`` is modified: ``'__builtins__'`` is removed if
            present and ``Msubv`` / ``Ksubv`` / ``Nsubv`` / ``single`` are set.
        """
        tiler = self.tilers[opt_midx]
        layer = tiler.layer
        
        # Extract all tiling data efficiently
        extracted_data = self._extract_tiling_data(tiler, opt_sched, opt_valid_idx)
        
        # Get operation type flags
        op_flags = self._get_op_type_flags(layer.op_type)

        # Clean up vars_dict
        # ('__builtins__' is inserted by eval() when vars_dict is used as its
        # globals, as done for the placement expressions below.)
        if '__builtins__' in tiler.vars_dict:
            del tiler.vars_dict['__builtins__']

        # Set up placement info
        core_tiling = extracted_data['core_tiling']
        tiler.vars_dict.update({
            'Msubv': core_tiling['ifm'][-2],
            'Ksubv': core_tiling['ifm'][-1],
            'Nsubv': core_tiling['wgt'][-1],
            'single': int(tiler.valid_buffer_placement_choices[opt_sched][opt_valid_idx] == 'single_single')
        })

        # Generate placement constraints
        # NOTE(review): each placement value is evaluated with eval() against
        # vars_dict; this presumes the kernel description is trusted input.
        placement_dict = {}
        for buff, bankdict in tiler.kernel.placement_outputs.items():
            placement_dict[buff] = {k: int(eval(str(v), tiler.vars_dict)) for k, v in bankdict.items()}

        # Use unpadded M dimensions for M1N* splits
        # (this edits the lists inside extracted_data, so it must happen
        # before the derived arrays are computed from them)
        memtile_tiling = extracted_data['memtile_tiling']
        msubv = memtile_tiling['ifm'][-2]
        if (tiler.overlay.core_splits['ifm'][0] == 1 and layer.in_act_shape[-2] <= msubv):
            memtile_tiling['ifm'][-2] = int(layer.in_act_shape[-2])
            memtile_tiling['ofm'][-2] = int(layer.out_act_shape[-2])

        # Compute derived arrays
        derived_arrays = self._compute_derived_arrays(extracted_data, layer, tiler)

        n_batch = int(tiler.original_shapes['ifm'][0])
        
        # Create scheduling dictionary
        schedule_dict = self._create_schedule_dict(opt_sched, pingpong, n_batch, tiler, opt_valid_idx)
        
        # Adjust core_iters for schedule 1
        core_iters = extracted_data['core_iters']
        if opt_sched == 1:
            core_iters['ofm'][-1] = 1

        # Create layer dictionary for the output
        # (deep copy, so the renaming below does not touch the layer object)
        layerdict = copy.deepcopy(vars(layer))
        
        # Create padding structures
        host_layer_padding, dma_layer_padding = self._create_padding_structures(extracted_data, tiler, op_flags)

        # Prepare all attribute dictionaries for special operations
        # (these are the same dict objects placed into tiling_params below,
        # so entries added by _handle_special_operations show up in the output)
        all_attr_dicts = [
            core_tiling, core_iters, memtile_tiling, extracted_data['memtile_iters'],
            derived_arrays['memtile_sizes'], derived_arrays['shim_tilings'], 
            derived_arrays['shim_sizes'], derived_arrays['dram_shapes'], derived_arrays['dram_sizes']
        ]

        # Handle special operations
        self._handle_special_operations(op_flags, extracted_data, host_layer_padding, dma_layer_padding, all_attr_dicts)

        # Update layer dictionary efficiently
        self._update_layer_dict_efficiently(layerdict, extracted_data)

        # Create final structures
        # (the three '*_orig' keys must exist on the layer; a missing one
        # raises KeyError here)
        original_dimensions = [
            {"input0": layerdict['in_act_shape_orig']},
            {"input1": layerdict['in_wgt_shape_orig']},
            {"output0": layerdict['out_act_shape_orig']}
        ]

        # 0 without a 'coeff_shape'; otherwise its first entry divided
        # (integer division) by the last original weight dimension.
        GEMM_VEC_COEFFS = 0 if 'coeff_shape' not in layerdict else layerdict['coeff_shape'][0] // layerdict['in_wgt_shape_orig'][-1]

        # Set up test bench arguments
        test_cpp_name = tiler.kernel.testbench_args['HostName']
        tb_cflags = {}
        min_cycles = 0
        opt_cycles = 0
        macs = 0
        # Compile flags and cycle figures are only filled in on a scheduler pass.
        if scheduler_pass:
            tb_cflags = self._create_testbench_flags(tiler, layerdict, core_tiling, derived_arrays, 
                                                   GEMM_VEC_COEFFS, n_batch, extracted_data, 
                                                   opt_midx, op_flags)
            min_cycles, macs = self.calculate_min_cycles_and_macs(opt_midx)
            opt_cycles = self.cost_models[opt_midx].total_layer_cycles[opt_sched][opt_valid_idx][pingpong]

        # Create final tiling parameters structure
        # NOTE(review): 'mode' assumes self.modes is index-aligned with
        # self.tilers -- confirm against __init__.
        tiling_params = {
            'core_tile_params': {'subvols': core_tiling, 'iters': core_iters},
            'mem_tile_params': {'subvols': memtile_tiling, 'iters': extracted_data['memtile_iters'], 'sizes': derived_arrays['memtile_sizes']},
            'shim_tile_params': {'subvols': derived_arrays['shim_tilings'], 'sizes': derived_arrays['shim_sizes']},
            'dram_params': {'shapes': derived_arrays['dram_shapes'], 'sizes': derived_arrays['dram_sizes']},
            'scheduling': schedule_dict,
            'original_dimensions': original_dimensions,
            'host_layer_padding': host_layer_padding,
            'dma_layer_padding': dma_layer_padding,
            'kernel_info': {'placement_constraints': placement_dict},
            'overlay_info': {
                'overlay': self.overlay,
                'mode': self.modes[opt_midx],
                'shape': {'row': tiler.overlay.rows, 'col': tiler.overlay.cols}
            },
            'layer_info': layerdict,
            'testbench_args': {'HOST_NAME': test_cpp_name, 'COMPILE_FLAGS': tb_cflags},
            'cycle_counts': {'macs':macs, 'min_cycles': min_cycles, 'layer_cycles': opt_cycles}
        }

        # Add debug information if requested
        if debug:
            tiling_params["debug_info"] = self._create_debug_info(opt_midx, opt_sched, opt_valid_idx, pingpong)

        return tiling_params

    def _filter_first_pass_tilings(self, idx, pp):
        """
        Filter solutions with first pass based on scheduler constraints.
        """
        scheduler_coverage_mask = np.zeros_like(idx, dtype=bool)
        selected_indices = self.indexes[idx]
        midx_sched_arr = np.unique(selected_indices[:, 0:2], axis=0)

        start_filtered_indices = 0
        for (midx, sched) in midx_sched_arr:
            tiler = self.tilers[midx]

            # Create mask for current midx and sched
            current_mask = (selected_indices[:, 0] == midx) & (selected_indices[:, 1] == sched)
            filtered_indices = selected_indices[current_mask, 2]
            end_filtered_indices = start_filtered_indices + len(filtered_indices)

            # Get valid filtered data
            valid_dma_padding = tiler.valid_dma_padding[sched]
            valid_loop_iters = tiler.valid_loop_iters[sched]

            ifm_dma_pad = valid_dma_padding['ifm'][filtered_indices]
            wgt_dma_pad = valid_dma_padding['wgt'][filtered_indices]
            ofm_dma_depad = valid_dma_padding['ofm'][filtered_indices]

            loop_iters_slice = {
                'Tb': valid_loop_iters['Tb'][filtered_indices],
                'Tm': valid_loop_iters['Tm'][filtered_indices],
                'Tk': valid_loop_iters['Tk'][filtered_indices],
                'Tn': valid_loop_iters['Tn'][filtered_indices]
            }

            scheduler_coverage_dict = {
                "sched": sched,
                "ifm_dma_pad": ifm_dma_pad,
                "wgt_dma_pad": wgt_dma_pad,
                "ofm_dma_depad": ofm_dma_depad,
                "loop_iters": loop_iters_slice,
                "perm": {
                    "A": tiler.layer.rev_permA,
                    "B": tiler.layer.rev_permB,
                    "Y": tiler.layer.permY
                },
                "ifm_ping_pong": pp[start_filtered_indices:end_filtered_indices],
                "unicast": tiler.overlay.unicast,
                "mode": tiler.overlay.mode
            }

            # Evaluate constraints more efficiently - avoid creating intermediate list
            constraint_results = []
            for constraint_expr in tiler.kernel.scheduler_coverage_constraints.values():
                constraint_results.append(eval(constraint_expr, scheduler_coverage_dict))
            
            # Use np.all with axis parameter instead of converting to list
            scheduler_coverage_mask[start_filtered_indices:end_filtered_indices] = np.all(
                np.column_stack(constraint_results), axis=1
            )

            start_filtered_indices = end_filtered_indices

        return scheduler_coverage_mask

    def _filter_by_iterations(self, idx, pp):
        """
        Filter solutions based on iteration criteria: minimize k, n, m, then total iterations.
        
        Args:
            idx: Array of indices into self.indexes
            pp: Array of pingpong options
            
        Returns:
            tuple: (best_idx, best_pp) or (None, None) if no valid solutions
        """
        if len(idx) == 0:
            return None, None
        
        # Extract iters information
        b_iters, m_iters, k_iters, n_iters = self.iters.T
        total_iters = b_iters * m_iters * k_iters * n_iters

        # Define priority filters and apply them sequentially
        priority_filters = [
            k_iters,
            n_iters,
            m_iters,
            b_iters,
            total_iters
        ]
        
        filtered_idx = idx
        filtered_pp = pp
        
        # Apply filters in priority order
        for iters_arr in priority_filters:
            if len(filtered_idx) <= 1:
                break
                
            min_val = np.min(iters_arr[filtered_idx])
            min_mask = iters_arr[filtered_idx] == min_val
            filtered_idx = filtered_idx[min_mask]
            filtered_pp = filtered_pp[min_mask]
            
            if len(filtered_idx) == 0:
                return None, None
        
        # Return the first (and possibly only) solution
        return filtered_idx[0], filtered_pp[0]

    def _find_best_tiling(self, candidate_idx, candidate_pp):
        """
        Find the best tiling from the given candidate indices and pingpong options.

        Args:
            candidate_idx: Array of indices into self.indexes
            candidate_pp: Array of pingpong options for each index

        Returns:
            tuple: (best_idx, best_pp) indices for the optimal tiling
        """
        if len(candidate_idx) == 0:
            return None, None
            
        # Find the minimum cycle count
        min_cycle = np.min(self.all_cycles[candidate_idx, candidate_pp])

        # Filter all equal cycles
        cycle_mask = self.all_cycles[candidate_idx, candidate_pp] == min_cycle
        filtered_idx = candidate_idx[cycle_mask]
        filtered_pp = candidate_pp[cycle_mask]

        # Find optimal solution from filtered solutions
        return self._filter_by_iterations(filtered_idx, filtered_pp)

    def _filter_by_mode_schedule(self, outmode, outsched):
        """
        Filter tiling configurations by mode and schedule if specified.
        
        Args:
            outmode: Output mode index to filter by (or None)
            outsched: Output schedule index to filter by (or None)
            
        Raises:
            ValueError: If no valid tiling found for given mode and schedule
        """
        if outmode is None or outsched is None:
            return
            
        mode_sched_match = (self.indexes[:, 0] == outmode) & (self.indexes[:, 1] == outsched)
        if np.any(mode_sched_match):
            self.all_cycles = self.all_cycles[mode_sched_match]
            self.indexes = self.indexes[mode_sched_match]
            self.iters = self.iters[mode_sched_match]
        else:
            raise ValueError(f"No valid tiling found for mode {outmode} and schedule {outsched}")

    def find_optimal_tiling(self, outmode=None, outsched=None, debug=False):
        """
        Select the best tiling configuration and return its formatted output.

        Steps: compute cycle counts, optionally restrict to one
        (mode, schedule) pair, discard entries with infinite cycles, prefer
        the candidates flagged by _filter_first_pass_tilings when any exist
        (otherwise consider every finite candidate), then take the
        minimum-cycle candidate with ties broken by iteration counts.
        NOTE(review): the exact padding criterion behind
        _filter_first_pass_tilings is defined elsewhere -- confirm there.

        If self.layer has a debug_info whose 'pingpong' entry is not None,
        that value replaces the selected pingpong option.

        Args:
            outmode: Optional mode index to filter by
            outsched: Optional schedule index to filter by
            debug: Forwarded unchanged to format_tiling_output

        Returns:
            The value returned by format_tiling_output for the chosen tiling

        Raises:
            ValueError: If there are no tilings at all, none match the
                requested mode/schedule, all have infinite cycles, or the
                selection yields no candidate
        """
        self.calculate_tiling_cycles()

        if len(self.all_cycles) == 0:
            raise ValueError("No valid tiling configurations found")

        # No-op unless both outmode and outsched were supplied.
        self._filter_by_mode_schedule(outmode, outsched)

        # (row, pingpong) coordinates of every entry that is not infinite.
        finite_idx, finite_pp = np.where(~np.isinf(self.all_cycles))
        if len(finite_idx) == 0:
            raise ValueError("No valid tilings found (all have infinite cycles)")

        # Prefer the first-pass subset; fall back to all finite candidates
        # when that subset is empty.
        first_pass_mask = self._filter_first_pass_tilings(finite_idx, finite_pp)
        if np.any(first_pass_mask):
            pool_idx = finite_idx[first_pass_mask]
            pool_pp = finite_pp[first_pass_mask]
        else:
            pool_idx, pool_pp = finite_idx, finite_pp

        best_idx, best_pp = self._find_best_tiling(pool_idx, pool_pp)
        if best_idx is None:
            raise ValueError("Failed to find optimal tiling")

        # Convert the selected row of self.indexes to builtin ints.
        chosen = self.indexes[best_idx]
        opt_midx = int(chosen[0])
        opt_sched = int(chosen[1])
        opt_valid_idx = int(chosen[2])

        # A pingpong value supplied through debug_info wins over the search.
        if hasattr(self.layer, 'debug_info'):
            forced_pp = self.layer.debug_info.get('pingpong')
            if forced_pp is not None:
                best_pp = int(forced_pp)

        return self.format_tiling_output(opt_midx, opt_sched, opt_valid_idx, best_pp, debug)

    def dump_all_tilings(self, out_path='all_tiling.json'):
        """
        Format every tiling with a non-infinite cycle count and dump them to JSON.

        Cycle counts are (re)computed first. Each row of self.indexes is
        then visited with pingpong options 0 and 1, and every combination
        whose entry in self.all_cycles is not np.inf is formatted with
        format_tiling_output and stored under the key
        "<row index>_<pingpong>".

        Args:
            out_path: File the JSON dump is written to. Defaults to
                'all_tiling.json' in the current working directory.

        Returns:
            dict: Mapping from "<row index>_<pingpong>" to the formatted
                tiling output (the same object that was written to out_path)
        """
        import json

        self.calculate_tiling_cycles()

        all_tilings = {}
        for tileid, (midx, sched, core_idx) in enumerate(self.indexes):
            # Pass builtin ints, exactly as find_optimal_tiling does: numpy
            # integer scalars are rejected by json.dump should
            # format_tiling_output echo any of them into its result.
            midx, sched, core_idx = int(midx), int(sched), int(core_idx)
            for pingpong in (0, 1):
                if self.all_cycles[tileid][pingpong] != np.inf:
                    all_tilings[f'{tileid}_{pingpong}'] = self.format_tiling_output(
                        midx, sched, core_idx, pingpong
                    )

        with open(out_path, "w") as out_file:
            json.dump(all_tilings, out_file, indent=2)
        return all_tilings

if __name__ == '__main__':
    # Manual smoke run: load the 'tst' layer description, tile it for the
    # 'strix' device with the '8x4' overlay, and write the result to disk.
    import json
    from layer import Layer
    from kernel import Kernel
    from device import Device

    # NOTE: this path is relative to the current working directory.
    with open('OGOAT/src/Tiler/tst_layer.json') as layer_file:
        layer_specs = json.load(layer_file)

    # Override both activation residencies to L3 for this run.
    test_spec = layer_specs['tst']
    test_spec['in_act_residency'] = 'L3'
    test_spec['out_act_residency'] = 'L3'

    test_layer = Layer(test_spec)
    test_device = Device('strix')
    test_kernel = Kernel(test_layer)

    tiling_opt = MatMulTilingOpt(test_layer, test_device, '8x4', test_kernel)
    best_tiling = tiling_opt.find_optimal_tiling()

    with open('tst_tiling.json', 'w') as result_file:
        json.dump(best_tiling, result_file, indent=2)
    # # import json
    # import argparse
    # from layer import Layer

    # dtype_bytes = {
    #     'uint16': 2,
    #     'int16': 2,
    #     'uint8': 1,
    #     'int8':1,
    #     'uint4': 0.5,
    #     'int4': 0.5
    # }

    # # parser = argparse.ArgumentParser(description="MatMul Tiling Optimization")
    # # parser.add_argument("--shape", required=True, help="Shape in format MxKxN")
    # # parser.add_argument("--datatypes", required=True, help="Datatypes in format in_dtypexwgt_dtypexout_dtype")
    # # parser.add_argument("--modifier", required=False, help="Modifier")

    # # args = parser.parse_args()

    # # shape_input = args.shape
    # # datatype_input = args.datatypes
    # # if args.modifier:
    # #     modifier_input = args.modifier
    # # else:
    # #     modifier_input = ''

    # mlist, klist, nlist = [[4096],[4096],[512]]#map(int, shape_input.split('x'))
    # in_dtype, wgt_dtype, out_dtype = ['uint16','uint8','uint16']#datatype_input.split('x')
    # modifier_input = ''





    # ldict_list = []
    # for M in mlist:
    #     for K in klist:
    #         for N in nlist:
    #             tmp = {
    #                 "op_type": "MatMul_qdq_"+modifier_input+in_dtype+'x'+wgt_dtype+'x'+out_dtype,
    #                 "in_act_shape": [M,  K],
    #                 "in_wgt_shape": [K, N],
    #                 "in_wgt1_shape": [],
    #                 "out_act_shape": [M, N],
    #                 "in_datatype": in_dtype,
    #                 "wgt_datatype": wgt_dtype,
    #                 "wgt1_datatype": "float32",
    #                 "out_datatype": out_dtype,
    #                 "in_bytes": dtype_bytes[in_dtype],
    #                 "wgt_bytes": dtype_bytes[wgt_dtype],
    #                 "wgt1_bytes": 4,
    #                 "out_bytes": dtype_bytes[out_dtype],
    #                 "attributes": "None",
    #                 "qdq_symmetry": 0,
    #                 "coeff_shape": [
    #                     2*N if dtype_bytes[wgt_dtype]<1 else N
    #                 ],
    #                 "in_act_residency": "L3",
    #                 "out_act_residency": "L3",
    #             }
    #             # # tmp = copy.deepcopy(ld)
    #             # tmp['in_act_shape'] = [m,k]
    #             # tmp['in_wgt_shape'] = [k,n]
    #             # tmp['out_act_shape'] = [m,n]
    #             # tmp['coeff_shape'] = [n]

    #             ldict_list.append(tmp)

    #             # a=t.dump_all_tilings()


    #             l = Layer(tmp)

    #             d = Device('strix')
    #             ker = Kernel(l)

    #             t = MatMulTilingOpt(l,d,'8x4',ker)
    #             r=t.find_optimal_tiling()

    #             mpa = r['dma_layer_padding'][0]['ifm']['dims'][0] + r['host_layer_padding'][0]['ifm']['dims'][0]
    #             kpa, npa = r['host_layer_padding'][1]['wgt']['dims']
    #             print(M,K,N,
    #                 mpa-M, kpa-K, npa-N,
    #                 r['overlay_info']['mode'], r['core_tile_params'])
