import numpy as np
import warnings
import math
import re

from OGOAT.src.Scheduling_Engine.infra import custom_dict
from OGOAT.src.Scheduling_Engine.infra import const
from OGOAT.src.Scheduling_Engine.schedules import scheduler as sch
from OGOAT.src.Scheduling_Engine.code_gen.template_base_gemm import (
        params_funcs
)

from OGOAT.src.Scheduling_Engine.schedules.buffer_allocator import BufferAllocator

from OGOAT.src.Tiler.utils import process_overheads

# Silence every Python warning process-wide as soon as this module is imported
# (no category/module filter is given). NOTE(review): presumably meant to hide
# numpy RuntimeWarnings from the vectorized cost formulas (e.g. log2 of 0,
# division by zero) — consider scoping with np.errstate or
# warnings.catch_warnings instead of a global filter.
warnings.filterwarnings('ignore')

class MatMulCostModel:

    def __init__(self, tilings):
        """Bind the tiling candidates and create empty per-schedule result tables.

        Args:
            tilings: tiling-candidate container. Only its ``schedule_list`` is
                read here; the ``calculate_*`` methods read the rest.
        """
        self.tilings = tilings

        def _empty_per_schedule():
            # A fresh, independent {} for every schedule id; the calculate_*
            # methods later replace entries with numpy result arrays.
            return {sched: {} for sched in self.tilings.schedule_list}

        self.kernel_cycles = _empty_per_schedule()
        self.array_cycles = _empty_per_schedule()
        self.layer_cycles = _empty_per_schedule()
        self.total_layer_cycles = _empty_per_schedule()
        self.control_overheads = _empty_per_schedule()

    def calculate_kernel_cycles(self):
        """Fill ``self.kernel_cycles[sched]`` with per-candidate kernel cycles.

        For every schedule with at least one valid core tiling, the core
        sub-volume dims (M, K from the IFM; N from the weights), the kernel
        granularities and the loop iteration counts are passed to
        ``process_overheads``. The summed overhead cycles divided by the
        returned ``'loops'`` count give the cycles of one kernel invocation,
        which is multiplied by ``ceil(B_subv / Bgran)``.

        Does nothing when the tilings expose no ``num_combinations`` or it is 0.
        """
        tilings = self.tilings
        if getattr(tilings, 'num_combinations', 0) == 0:
            return

        kernel = tilings.kernel
        batch_gran = kernel.Bgran
        granules = {'Mgran': kernel.Mgran, 'Kgran': kernel.Kgran, 'Ngran': kernel.Ngran}

        layer = tilings.layer
        device = tilings.device

        def signless(dtype):
            # macs_per_cycle is keyed by the type name without a leading 'u'.
            return dtype[1:] if dtype[0] == 'u' else dtype

        mac_key = f"{signless(layer.in_datatype)}x{signless(layer.wgt_datatype)}"
        macs_per_cycle = device.macs_per_cycle[mac_key]

        for sched in tilings.schedule_list:
            indices = tilings.valid_core_indices[sched]
            if indices is None or len(indices) == 0:
                continue

            core_subvols = tilings.valid_core_subvols[sched]
            ifm_subv = core_subvols['ifm']
            wgt_subv = core_subvols['wgt']
            loop_iters = tilings.valid_loop_iters[sched]

            args = {
                "overheads": kernel.overheads,
                "op_type": layer.op_type,
                "macs_per_cycle": macs_per_cycle,
                **granules,
                "M": ifm_subv[:, -2],
                "K": ifm_subv[:, -1],
                "N": wgt_subv[:, -1],
            }
            for loop_name in ("Tb", "Tm", "Tk", "Tn"):
                args[loop_name] = loop_iters[loop_name]

            if 'actxact' in layer.op_type:
                args["transpose_A"] = layer.rev_permA[2] == 1  # M and K dims swapped
                args["transpose_B"] = layer.rev_permB[2] == 1  # K and N dims swapped

            stage_cycles, var_dict, _details = process_overheads(args)
            cycles_per_call = sum(stage_cycles.values()) / var_dict['loops']
            self.kernel_cycles[sched] = np.ceil(ifm_subv[:, 0] / batch_gran) * cycles_per_call

    def calculate_array_cycles(self):
        """Fill ``self.array_cycles[sched]`` with per-candidate array-level cycles.

        Each row is ``[ifm_stream, wgt_stream, kernel, ofm_stream]``. The three
        stream entries are the core sub-volume byte counts divided by the
        device's ``max_stream_bw``; the kernel entry is taken from
        ``self.kernel_cycles[sched]``, so ``calculate_kernel_cycles`` must run
        first. Does nothing when the tilings report no combinations.
        """
        tilings = self.tilings
        if getattr(tilings, 'num_combinations', 0) == 0:
            return

        layer = tilings.layer
        stream_bw = tilings.device.max_stream_bw
        operand_bytes = (('ifm', layer.in_bytes), ('wgt', layer.wgt_bytes), ('ofm', layer.out_bytes))

        for sched in tilings.schedule_list:
            indices = tilings.valid_core_indices[sched]
            if indices is None or len(indices) == 0:
                continue

            subvols = tilings.valid_core_subvols[sched]
            stream = {
                operand: np.prod(subvols[operand], axis=1) * nbytes / stream_bw
                for operand, nbytes in operand_bytes
            }
            self.array_cycles[sched] = np.column_stack(
                [stream['ifm'], stream['wgt'], self.kernel_cycles[sched], stream['ofm']]
            )

    def calculate_layer_cycles(self):
        """Estimate total layer cycles for every valid tiling of each schedule.

        Combines the per-tile stream/kernel cycles in ``self.array_cycles`` with
        DRAM transfer cycles derived from the memtile sub-volumes, the
        re-enqueue counts in ``self.control_overheads`` and a schedule-specific
        pipeline model (schedules 1, 2 and 5). Must run after
        ``calculate_array_cycles`` and after ``self.control_overheads[sched]``
        holds an ``(n, 2)`` array of [single-buffer, double-buffer] counts.

        Side effects, for every schedule with at least one valid core tiling:
            ``self.layer_cycles[sched]``: ``(n, 27)`` array of stage cycles,
                re-enqueue penalties and core/memtile iteration counts.
            ``self.total_layer_cycles[sched]``: ``(n, 2)`` array of total cycles
                for [single IFM buffer, double IFM buffer]; ``inf`` marks a
                variant that is disabled or unsupported.

        Raises:
            ValueError: if the overlay's unicast operand is neither ``'act'``
                nor ``'wgt'``, or a schedule other than 1, 2 or 5 has valid
                tilings.
        """
        tilings = self.tilings
        if not hasattr(tilings, 'num_combinations') or tilings.num_combinations == 0:
            return

        layer = tilings.layer
        overlay = tilings.overlay
        device = tilings.device

        shim_bandwidth_unicast = overlay.cols * device.max_stream_bw
        shim_bandwidth_broadcast = overlay.rows * device.max_stream_bw

        # A unicast operand's overall DRAM bandwidth scales with the number of
        # columns/shim tiles (and drops with column disable); a broadcast
        # operand's scales with the number of rows (and drops with row disable).
        # TODO: the 8/7 factor is a heuristic aligning DRAM bandwidth with HW
        # latency; the constant in device.yaml should be updated instead.
        dram_read_bandwidth_unicast = min(shim_bandwidth_unicast, device.dram_read_bandwidth * 8 / 7)
        dram_read_bandwidth_broadcast = min(shim_bandwidth_broadcast, device.dram_read_bandwidth * 8 / 7)
        dram_write_bandwidth = min(shim_bandwidth_unicast, device.dram_write_bandwidth * 8 / 7)  # OFM is always unicast

        def derate(stream_step_bytes):
            # Bandwidth derate for short stream steps: 1 for >= 256 bytes,
            # otherwise 2**floor(log2(bytes / 256)).
            return 2**np.minimum(np.floor(np.log2(stream_step_bytes / 256)), 0)

        for sched in tilings.schedule_list:
            if tilings.valid_core_indices[sched] is None or len(tilings.valid_core_indices[sched]) == 0:
                continue

            # Per-schedule tables, one row per valid candidate.
            array_cycles = self.array_cycles[sched]
            valid_core_subvols = tilings.valid_core_subvols[sched]
            valid_core_iters = tilings.valid_core_iters[sched]
            valid_memtile_subvols = tilings.valid_memtile_subvols[sched]
            valid_memtile_iters = tilings.valid_memtile_iters[sched]
            valid_dma_padding = tilings.valid_dma_padding[sched]
            valid_loop_iters = tilings.valid_loop_iters[sched]

            # Iteration counts per operand: ifm [B, M, K], wgt [B, K, N], ofm [B, M, N].
            Tb_core_ifm, Tm_core_ifm, Tk_core_ifm = valid_core_iters['ifm'].T
            Tb_core_wgt, Tk_core_wgt, Tn_core_wgt = valid_core_iters['wgt'].T
            Tb_core_ofm, Tm_core_ofm, Tn_core_ofm = valid_core_iters['ofm'].T

            Tb_mem_ifm, Tm_mem_ifm, Tk_mem_ifm = valid_memtile_iters['ifm'].T
            Tb_mem_wgt, Tk_mem_wgt, Tn_mem_wgt = valid_memtile_iters['wgt'].T
            Tb_mem_ofm, Tm_mem_ofm, Tn_mem_ofm = valid_memtile_iters['ofm'].T

            Tm = valid_loop_iters['Tm']
            Tk = valid_loop_iters['Tk']
            Tn = valid_loop_iters['Tn']
            Tm_Tk_Tn = Tm * Tk * Tn

            # DRAM transfer sizes. IFM: DMA padding (split across memtiles) is
            # removed in every dim whose memtile iteration count is 1.
            ifm_sizes = np.prod(valid_memtile_subvols['ifm'] - (valid_dma_padding['ifm'] / overlay.mem_splits['ifm']) *
                                (valid_memtile_iters['ifm'] == 1), axis=1) * layer.in_bytes
            wgt_sizes = np.prod(valid_memtile_subvols['wgt'], axis=1) * layer.wgt_bytes

            # Non act x act layers also move bias bytes with every weight core sub-volume.
            if 'actxact' not in layer.op_type:
                num_wgt_subvols = np.prod((valid_memtile_subvols['wgt'] // valid_core_subvols['wgt']), axis=1)
                wgt_sizes += num_wgt_subvols * layer.bias_bytes * valid_core_subvols['wgt'][:, -1]

            # OFM: DMA padding is removed in all dims when Tm_mem_ofm == 1.
            ofm_sizes = np.prod(valid_memtile_subvols['ofm'] - (valid_dma_padding['ofm'] / overlay.mem_splits['ofm']) *
                                (Tm_mem_ofm == 1).reshape(-1, 1), axis=1) * layer.out_bytes

            # Bytes per stream step, used to derate the DRAM bandwidth.
            ifm_stream_steps = (valid_memtile_subvols['ifm'][:, -1] * layer.in_bytes *
                                np.where(Tk_mem_ifm == 1, valid_memtile_subvols['ifm'][:, -2], 1))

            if 'actxact' in layer.op_type:
                wgt_stream_steps = (valid_memtile_subvols['wgt'][:, -1] * layer.wgt_bytes *
                                    np.where((Tn_mem_wgt == 1) & (overlay.mem_splits['wgt'][-1] == 1),
                                             valid_memtile_subvols['wgt'][:, -2],
                                             np.where((Tn_mem_wgt == 1), overlay.mem_splits['wgt'][-1], 1)))
            else:
                wgt_stream_steps = np.prod(valid_memtile_subvols['wgt'], axis=1) * layer.wgt_bytes

            if overlay.unicast == 'act':
                ifm_cycles = ifm_sizes / dram_read_bandwidth_unicast / derate(ifm_stream_steps)
                wgt_cycles = wgt_sizes / dram_read_bandwidth_broadcast / derate(wgt_stream_steps)
            elif overlay.unicast == 'wgt':
                ifm_cycles = ifm_sizes / dram_read_bandwidth_broadcast / derate(ifm_stream_steps)
                wgt_cycles = wgt_sizes / dram_read_bandwidth_unicast / derate(wgt_stream_steps)
            else:
                raise ValueError("unicast operand not correctly defined in overlay")

            ofm_stream_steps = (valid_memtile_subvols['ofm'][:, -1] * layer.out_bytes *
                                np.where((Tn_mem_ofm == 1) & (overlay.mem_splits['ofm'][-1] == 1),
                                         valid_memtile_subvols['ofm'][:, -2], 1))
            ofm_cycles = ofm_sizes / dram_write_bandwidth / derate(ofm_stream_steps)

            t_ifm_stream, t_wgt_stream, t_kernel, t_ofm_stream = array_cycles.T
            t_ifm_dram, t_wgt_dram, t_ofm_dram = ifm_cycles, wgt_cycles, ofm_cycles

            memtile_pingpong = tilings.valid_memtile_pingpong[sched]
            disable_pingpong_for_mpadding = tilings.valid_disable_ifm_pingpong[sched]

            # Re-enqueue penalty: enqueue counts times the rounded per-enqueue
            # cost (re_enq_cost_us * core_clock_freq).
            # TODO: refine the re-enqueue penalty model.
            re_enq_penalty = round(float(device.re_enq_cost_us) * float(device.core_clock_freq))
            reenq_penalty_single_buff = self.control_overheads[sched][:, 0] * re_enq_penalty
            reenq_penalty_double_buff = self.control_overheads[sched][:, 1] * re_enq_penalty

            # TODO: batch phasing is handled in a very simplified way; make it robust.
            batch_phasing_single_buff = ((Tb_mem_ifm > 1) & (self.control_overheads[sched][:, 0] % Tb_mem_ifm == 0) & (Tm_Tk_Tn > 1))
            batch_phasing_double_buff = ((Tb_mem_ifm > 1) & (self.control_overheads[sched][:, 1] % Tb_mem_ifm == 0) & (Tm_Tk_Tn > 1))
            phasing_factor_single_buff = np.where(batch_phasing_single_buff, Tb_mem_ifm, 1)
            phasing_factor_double_buff = np.where(batch_phasing_double_buff, Tb_mem_ifm, 1)

            self.layer_cycles[sched] = np.column_stack([
                t_ifm_stream, t_wgt_stream, t_kernel, t_ofm_stream,
                t_ifm_dram, t_wgt_dram, t_ofm_dram,
                reenq_penalty_single_buff, reenq_penalty_double_buff,
                Tb_core_ifm, Tm_core_ifm, Tk_core_ifm,
                Tb_core_wgt, Tk_core_wgt, Tn_core_wgt,
                Tb_core_ofm, Tm_core_ofm, Tn_core_ofm,
                Tb_mem_ifm, Tm_mem_ifm, Tk_mem_ifm,
                Tb_mem_wgt, Tk_mem_wgt, Tn_mem_wgt,
                Tb_mem_ofm, Tm_mem_ofm, Tn_mem_ofm
            ])

            input_stream_cycles = np.maximum(t_ifm_stream, t_wgt_stream)

            # DRAM cycles are scaled by a per-operand reader/writer count:
            # 8 for the unicast input operand and the OFM, 4 for the broadcast one.
            ifm_readers = 8
            wgt_readers = 4
            ofm_writers = 8
            if overlay.unicast == 'wgt':
                ifm_readers = 4
                wgt_readers = 8

            eff_ifm_dram = t_ifm_dram * ifm_readers
            eff_wgt_dram = t_wgt_dram * wgt_readers
            eff_ofm_dram = t_ofm_dram * ofm_writers

            # Joint IFM+WGT fetch: the common part min(ifm, wgt) is weighted by 12
            # (presumably 8 + 4 readers); the excess uses the operand's own count.
            min_ifm_wgt_dram = np.minimum(t_ifm_dram, t_wgt_dram)
            eff_ifm_wgt_dram = np.maximum(min_ifm_wgt_dram * 12 + (t_ifm_dram - min_ifm_wgt_dram) * ifm_readers + t_ifm_stream,
                                          min_ifm_wgt_dram * 12 + (t_wgt_dram - min_ifm_wgt_dram) * wgt_readers + t_wgt_stream)

            if sched == 1:
                pipeline1 = ((Tk_core_wgt - 1) * np.maximum(t_kernel, input_stream_cycles) +
                             np.maximum(t_kernel + t_ofm_stream * (Tk_core_wgt == 1), input_stream_cycles)) * (Tn_core_wgt > 1)

                pipeline2 = ((Tn_core_wgt - 2) * np.maximum(pipeline1, t_ofm_stream * (Tk_core_wgt > 1) + eff_ofm_dram + t_ofm_stream * (Tk_core_wgt == 1))) * (Tn_core_wgt > 2)

                pipeline3 = np.maximum(((Tk_core_wgt - 1) * np.maximum(t_kernel, input_stream_cycles) + t_kernel),
                                       (t_ofm_stream * (Tk_core_wgt > 1) + eff_ofm_dram) * (Tn_core_wgt > 1))

                # np.maximum takes two operands; a third positional argument is the
                # `out` buffer, so the three-way max has to be nested.
                pipeline4 = np.maximum(((((Tk_core_wgt - 2) * np.maximum(t_kernel, input_stream_cycles)) +
                                         np.maximum(np.maximum(t_kernel * 2 + t_ofm_stream, t_ifm_stream * 2 + eff_ifm_dram),
                                                    t_wgt_stream * 2)) * (Tk_core_wgt > 1) +
                                        np.maximum(t_kernel + t_ofm_stream, np.maximum(eff_ifm_dram + t_ifm_stream, t_wgt_stream)) * (Tk_core_wgt == 1)),
                                       t_ofm_stream * (Tk_core_wgt > 1) + eff_ofm_dram + t_ofm_stream * (Tk_core_wgt == 1))

                pipeline5 = (pipeline1 + pipeline2 + pipeline3)
                pipeline6 = np.maximum(pipeline3, pipeline4 * ((Tm_mem_ifm > 1) | ((Tb_mem_ifm / phasing_factor_single_buff) > 1)))
                pipeline7 = (np.maximum(pipeline1, (t_ofm_stream * (Tk_core_wgt > 1) + eff_ofm_dram + t_ofm_stream * (Tk_core_wgt == 1)) * (Tn_core_wgt > 1)) +
                             pipeline2 + pipeline6) * (Tm_mem_ifm * (Tb_mem_ifm / phasing_factor_single_buff) - 1)

                # Single IFM buffer: the pipeline stalls after each horizontal shard is generated.
                total_cycles_single_buffer = (eff_ifm_wgt_dram + pipeline7 + pipeline5 + t_ofm_stream + eff_ofm_dram) * phasing_factor_single_buff + reenq_penalty_single_buff

                # Double IFM buffer is currently disabled for this schedule. Previous
                # model: layer_startup + max(pipeline5, eff_ifm_dram) * (Tm_mem_ifm *
                # Tb_mem_ifm / phasing - 1) + pipeline5 + t_ofm_dram + reenq_double.
                total_cycles_double_buffer = np.full_like(total_cycles_single_buffer, np.inf)

            elif sched == 2:
                pipeline1 = ((Tk_mem_wgt - 1) * np.maximum(t_kernel, np.maximum(input_stream_cycles, eff_wgt_dram)) +
                             np.maximum(t_kernel + t_ofm_stream * (Tk_mem_wgt == 1), np.maximum(input_stream_cycles, eff_wgt_dram))) * (Tn_mem_wgt > 1)

                # Assumes ofm_stream <<<< ofm_dram and ofm_stream < kernel.
                pipeline2 = ((Tn_mem_wgt - 2) * np.maximum(pipeline1, t_ofm_stream * (Tk_mem_wgt > 1) + eff_ofm_dram + t_ofm_stream * (Tk_mem_wgt == 1))) * (Tn_mem_wgt > 2)

                pipeline3 = np.maximum(((Tk_mem_wgt - 1) * np.maximum(t_kernel, np.maximum(input_stream_cycles, eff_wgt_dram)) + t_kernel),
                                       (t_ofm_stream * (Tk_mem_wgt > 1) + eff_ofm_dram) * (Tn_mem_wgt > 1))

                pipeline4 = np.maximum(((((Tk_mem_wgt - 2) * np.maximum(t_kernel, np.maximum(input_stream_cycles, eff_wgt_dram))) +
                                         np.maximum(t_kernel * 2 + t_ofm_stream, t_ifm_stream + eff_ifm_wgt_dram)) * (Tk_mem_wgt > 1) +
                                        np.maximum(t_kernel + t_ofm_stream, eff_ifm_wgt_dram) * (Tk_mem_wgt == 1)),
                                       t_ofm_stream * (Tk_mem_wgt > 1) + eff_ofm_dram + t_ofm_stream * (Tk_mem_wgt == 1))

                pipeline5 = (pipeline1 + pipeline2 + pipeline3)
                pipeline6 = np.maximum(pipeline3, pipeline4 * ((Tm_mem_ifm > 1) | ((Tb_mem_ifm / phasing_factor_single_buff) > 1)))
                # NOTE(review): this OFM-drain term is gated on the *core* wgt iterations
                # (Tk_core_wgt / Tn_core_wgt) while every other term of this schedule uses
                # the memtile ones; possibly copied from schedule 1 — confirm.
                pipeline7 = (np.maximum(pipeline1, (t_ofm_stream * (Tk_core_wgt > 1) + eff_ofm_dram + t_ofm_stream * (Tk_core_wgt == 1)) * (Tn_core_wgt > 1)) +
                             pipeline2 + pipeline6) * (Tm_mem_ifm * (Tb_mem_ifm / phasing_factor_single_buff) - 1)

                # Single IFM buffer stalls after each horizontal shard is generated.
                total_cycles_single_buffer = (eff_ifm_wgt_dram + pipeline7 + pipeline5 + t_ofm_stream + eff_ofm_dram) * phasing_factor_single_buff + reenq_penalty_single_buff

                # Double IFM buffer is currently disabled. Previous model:
                # eff_ifm_wgt_dram + max(pipeline5, eff_ifm_wgt_dram) * (Tm_mem_ifm *
                # Tb_mem_ifm / phasing - 1) + pipeline5 + eff_ofm_dram + reenq_double.
                total_cycles_double_buffer = np.full_like(total_cycles_single_buffer, np.inf)

            elif sched == 5:
                Tb_Tm_Tn = Tn_mem_wgt * Tm_mem_ifm * (Tb_mem_ifm / phasing_factor_double_buff)

                pipeline1 = ((Tk_mem_wgt - 1) * np.maximum(t_kernel, eff_ifm_wgt_dram) +
                             np.maximum(t_kernel + t_ofm_stream * (Tk_mem_wgt == 1), eff_ifm_wgt_dram)) * (Tb_Tm_Tn > 1)

                # The OFM-stream indicators must be the complementary pair on the same
                # count (Tk_mem_wgt > 1 / Tk_mem_wgt == 1), as in pipeline1.
                pipeline2 = (np.maximum(pipeline1, t_ofm_stream * (Tk_mem_wgt > 1) + eff_ofm_dram + t_ofm_stream * (Tk_mem_wgt == 1)) * (Tb_Tm_Tn - 2)) * (Tb_Tm_Tn > 2)

                pipeline3 = np.maximum((Tk_mem_wgt - 1) * np.maximum(t_kernel, eff_ifm_wgt_dram) + t_kernel,
                                       (t_ofm_stream * (Tk_mem_wgt > 1) + eff_ofm_dram) * (Tb_Tm_Tn > 1))

                total_cycles_double_buffer = (eff_ifm_wgt_dram + pipeline1 + pipeline2 + pipeline3 + t_ofm_stream + eff_ofm_dram) * phasing_factor_double_buff + reenq_penalty_double_buff
                # Schedule 5 has no single-IFM-buffer variant.
                total_cycles_single_buffer = np.full_like(total_cycles_double_buffer, np.inf)

            else:
                raise ValueError(f"no layer-cycle model for schedule {sched}")

            # Apply ping-pong settings to both buffering variants.
            cycles_with_pingpong = np.column_stack([total_cycles_single_buffer, total_cycles_double_buffer])
            self.total_layer_cycles[sched] = cycles_with_pingpong / memtile_pingpong.astype(float) / disable_pingpong_for_mpadding.astype(float)

            Bsplit = overlay.core_splits['ifm'][0]
            # TODO: fall back to head per array, since Tm*Tk*Tn > 1 needs re-enqueue in both head/subarray and head/array.
            self.total_layer_cycles[sched][(Bsplit > 4) & (Tm_Tk_Tn > 1)] = [np.inf, np.inf]
            # Disabled for now: the tiler chose an odd subv with Tk == 1 for 577x64x577 (subv: 8x64x88).
            # NOTE(review): the mask actually rejects every candidate with Tm != 1 and Tn != 1.
            self.total_layer_cycles[sched][(Bsplit == 4) & ~((Tm == 1) | (Tn == 1))] = [np.inf, np.inf]
            # Avoid a K-padding refactor issue (better solution needed): the tiler chose
            # ksubv == 88, which has a refactor issue in dmac for 577x577x64.
            self.total_layer_cycles[sched][(Bsplit == 4) & ((Tm_Tk_Tn > 1) & (valid_core_subvols['ifm'][:, -1] > 64))] = [np.inf, np.inf]

    def calculate_control_overheads(self, tiling_param_dict):
        """Count weighted DMA re-enqueues for every tiling candidate.

        Each candidate's tiling parameters are run through ``BufferAllocator``
        (with ``tiler_pass=True``) and ``params_funcs`` to derive the shim and
        memtile repeat lists. Every list contributes its number of non-zero
        repeat entries (skipping the first element of each list/sub-list),
        multiplied by the channels that replicate it:
          * shim lists: by the column count (unicast) or row count (broadcast);
          * memtile lists: by ``cols * (1 + rows)`` (unicast) or ``rows * 2`` (broadcast).

        Args:
            tiling_param_dict: mapping of schedule id to a list of tiling
                parameter dicts (keys such as ``'kernel_info'``,
                ``'overlay_info'``, ``'mem_tile_params'``, ...).

        Side effects:
            ``self.control_overheads[sched]`` becomes
            ``np.array([[single_buffer, double_buffer], ...])``, one row per
            candidate. For schedule 5 the single-buffer count is ``math.inf``.
        """
        overlay = self.tilings.overlay
        mode = overlay.mode
        aie_rows = overlay.rows
        aie_cols = overlay.cols

        # BD channels per memtile re-enqueue: unicast uses 1 + rows, broadcast uses 2.
        unicast_memtile_weight = aie_cols * (1 + aie_rows)
        broadcast_memtile_weight = aie_rows * 2

        artifact_sources = (
            ('kernel_info_obj', 'kernel_info'),
            ('overlay_info_obj', 'overlay_info'),
            ('layer_info_obj', 'layer_info'),
            ('scheduling_obj', 'scheduling'),
            ('mem_tile_params_obj', 'mem_tile_params'),
            ('core_tile_params_obj', 'core_tile_params'),
            ('layer_padding_obj', 'dma_layer_padding'),
            ('host_padding_obj', 'host_layer_padding'),
            ('shim_tile_params_obj', 'shim_tile_params'),
            ('dram_params_obj', 'dram_params'),
        )

        def count_valid_entries(enqueue_list):
            # Non-zero repeat entries, skipping the first element of the list
            # (or of each sub-list when the list is nested).
            if enqueue_list is None:
                return 0
            if any(isinstance(item, list) for item in enqueue_list):
                entries = [value for item in enqueue_list for value in item[1:]]
            else:
                entries = enqueue_list[1:]
            return len(entries) - entries.count(0)

        def total_enqueues(params):
            # Weighted re-enqueue total for the repeat lists derived from params.
            rep_helper = params_funcs()
            rep_helper.calc_rep_params(params)

            pong = params['MemtileActPongRepeat']
            if pong:
                params['MemtileActRepeat'] = [sum(pair) for pair in zip(params['MemtileActPingRepeat'], pong)]
            else:
                params['MemtileActRepeat'] = params['MemtileActPingRepeat']

            act_unicast = overlay.unicast == "act"
            wgt_unicast = overlay.unicast == "wgt"
            weighted_lists = (
                ('ShimParamRepeat', aie_cols),
                ('MemtileParamRepeat', unicast_memtile_weight),
                ('ShimQdqPrmRepeat', aie_rows),
                ('MemtileQdqPrmRepeat', broadcast_memtile_weight),
                ('ShimActRepeat', aie_cols if act_unicast else aie_rows),
                ('MemtileActRepeat', unicast_memtile_weight if act_unicast else broadcast_memtile_weight),
                ('ShimWgtRepeat', aie_cols if wgt_unicast else aie_rows),
                ('MemtileWgtRepeat', unicast_memtile_weight if wgt_unicast else broadcast_memtile_weight),
                ('ShimOutRepeat', aie_cols),
                ('MemtileOutRepeat', unicast_memtile_weight),
            )
            return sum(count_valid_entries(params[key]) * weight for key, weight in weighted_lists)

        def enqueue_counts(tiling_params, sched):
            # [single_buffer_enqueues, double_buffer_enqueues] for one candidate.
            artifacts = {
                obj_name: custom_dict.ReadOnlyDict(tiling_params[key])
                for obj_name, key in artifact_sources
            }
            artifacts['program_arg_obj'] = {}

            allocator = BufferAllocator(artifacts, tiler_pass=True)
            shared = sch.SharedResource()
            allocator.execute(shared)
            buff = shared.info.get('BuffAllocator')
            buff_prm = buff[const.BufAllocator_Idx.BUFF_ALLOC_PARAM_IDX.value]

            param_builder = params_funcs()
            params = param_builder.initialize_params(buff, buff_prm)
            if params['dims'].wgt_subv_bytes > 4096 and not params['actxact'] and mode in ["M1K1N32", "M1K1N16"]:
                params['shim_wgt_reengueue'] = True
            params['ifm_channel'] = 'unicast' if overlay.unicast == "act" else 'broadcast'
            params['wgt_channel'] = 'unicast' if overlay.unicast == "wgt" else 'broadcast'

            double_buffer = total_enqueues(params)
            # Re-evaluate with IFM ping-pong disabled to get the single-buffer count.
            params['ping_pong_enable']['ifm'] = False
            params['ActPingPong'] = params['ping_pong_enable']['ifm']
            single_buffer = math.inf if sched == 5 else total_enqueues(params)
            return [single_buffer, double_buffer]

        for sched, candidates in tiling_param_dict.items():
            self.control_overheads[sched] = np.array(
                [enqueue_counts(candidate, sched) for candidate in candidates]
            )

    def estimate_control_overheads(self, tiling_param_dict):
        def est_total_enq (tiling_param_dict, sched):
            core_ifm_itr = tiling_param_dict['core_tile_params']['iters']['ifm']
            core_wgt_itr = tiling_param_dict['core_tile_params']['iters']['wgt']
            mem_ifm_itr  = tiling_param_dict['mem_tile_params']['iters']['ifm']
            mem_wgt_itr  = tiling_param_dict['mem_tile_params']['iters']['wgt']
            dma_padding  = tiling_param_dict['dma_layer_padding']
            K_padding    = [dma_padding[0]['ifm']['dims'][-1], dma_padding[1]['wgt']['dims'][-2]]
            N_padding    = [dma_padding[1]['wgt']['dims'][-1], dma_padding[2]['ofm']['dims'][-1]]
            M_padding    = [dma_padding[0]['ifm']['dims'][-2], dma_padding[2]['ofm']['dims'][-2]]
            Batch        = tiling_param_dict['scheduling']['Tbatch']
            Tm           = core_ifm_itr[-2] * mem_ifm_itr[-2]
            Tk           = core_wgt_itr[-2] * mem_wgt_itr[-2]
            Tn           = core_wgt_itr[-1] * mem_wgt_itr[-1]
            wgtsubvbytes = list(tiling_param_dict['kernel_info']['placement_constraints']['wgt'].values())[0]
            split        = [int(x) for x in re.findall(r'\d+', tiling_param_dict['overlay_info']['mode'])]
            actxwgt      = 'actxact' not in tiling_param_dict['layer_info']['op_type']

            if K_padding == [24, 24] and N_padding == [0, 0] and M_padding == [0, 0] and \
                Tm == 9 and Tk == 15 and Tn == 1 and sched == 2:
                pass

            shim_ifm_enq = 0
            shim_wgt_enq = 0
            shim_ofm_enq = 0

            if len(split)==4 and split[0] > 1: #row wise bmm
                if Tm==Tk==Tn==1: #iter mode
                    mem_ifm_enq  = 1                               - 1
                    mem_wgt_enq  = math.ceil(Batch / split[0])     - 1
                    mem_ofm_enq  = 1                               - 1
                    shim_ifm_enq = 1                               - 1
                    shim_wgt_enq = math.ceil(Batch / split[0])     - 1
                    shim_ofm_enq = 1                               - 1
                elif Tk > 1 and sum(K_padding) > 0:
                    mem_ifm_enq  = math.ceil(Batch / split[0])     - 1
                    mem_wgt_enq  = math.ceil(2 * Batch / split[0]) - 1
                    mem_ofm_enq  = math.ceil(Batch / split[0])     - 0
                    shim_ifm_enq = math.ceil(Batch / split[0])     - 1
                    shim_wgt_enq = math.ceil(Batch / split[0])     - 1
                    shim_ofm_enq = math.ceil(Batch / split[0])     - 0
                elif Tn > 1 and sum(N_padding) > 0:
                    mem_ifm_enq  = math.ceil(Batch / split[0])     - 1
                    mem_wgt_enq  = math.ceil(2 * Batch / split[0]) - 1
                    mem_ofm_enq  = math.ceil(2 * Batch / split[0]) - 1
                    shim_ifm_enq = math.ceil(Batch / split[0])     - 1
                    shim_wgt_enq = math.ceil(2 * Batch / split[0]) - 1
                    shim_ofm_enq = math.ceil(2 * Batch / split[0]) - 1
                else:
                    mem_ifm_enq  = math.ceil(Batch / split[0])     - 1
                    mem_wgt_enq  = math.ceil(Batch / split[0])     - 1
                    mem_ofm_enq  = math.ceil(Batch / split[0])     - 1
                    shim_ifm_enq = math.ceil(Batch / split[0])     - 1
                    shim_wgt_enq = math.ceil(Batch / split[0])     - 1
                    shim_ofm_enq = math.ceil(Batch / split[0])     - 1
                shim_prm_enq = 1         - 1
                mem_prm_enq  = 1         - 1
                shim_qdq_enq = math.ceil(Batch / split[0])     - 1
                mem_qdq_enq  = math.ceil(Batch / split[0])     - 1
            else:
                if Tm==Tk==Tn==1: #iter mode
                    mem_ifm_enq  = 1         - 1
                    mem_wgt_enq  = Batch     - 1
                    mem_ofm_enq  = 1         - 1
                    shim_ifm_enq = 1         - 1
                    shim_wgt_enq = Batch     - 1
                    shim_ofm_enq = 1         - 1
                elif Tk > 1 and sum(K_padding) > 0:
                    #pin full
                    if sched == 1:
                        mem_ifm_enq  = Tm      * Batch - 1
                        mem_wgt_enq  = 1       * Batch - 1 if actxwgt else 2*Tm*Tn * Batch - 1
                        mem_ofm_enq  = Tm*Tn   * Batch - 0 #first phase is 0
                        shim_ifm_enq = Tm      * Batch - 1
                        shim_wgt_enq = 1       * Batch - 1
                        shim_ofm_enq = Tm*Tn   * Batch - 0 #first phase is 0
                    elif sched == 2:
                        mem_ifm_enq  = Tm      * Batch - 1
                        mem_wgt_enq  = 2*Tm*Tn * Batch - 1 if actxwgt else 2*Tm*Tn * Batch - 1
                        mem_ofm_enq  = Tm*Tn   * Batch - 0 #first phase is 0
                        shim_ifm_enq = Tm      * Batch - 1
                        shim_wgt_enq = Tm*Tn   * Batch - 1 if actxwgt else 2*Tm*Tn * Batch - 1
                        shim_ofm_enq = Tm*Tn   * Batch - 0 #first phase is 0
                    elif sched == 5:
                        mem_ifm_enq  = 2*Tn*Tm * Batch - 1
                        mem_wgt_enq  = 2*Tn*Tm * Batch - 1 if actxwgt else 2*Tm*Tn * Batch - 1
                        mem_ofm_enq  = Tm*Tn   * Batch - 0 #first phase is 0
                        shim_ifm_enq = 2*Tn*Tm * Batch - 1
                        shim_wgt_enq = Tn*Tm   * Batch - 1 if actxwgt else 2*Tm*Tn * Batch - 1
                        shim_ofm_enq = Tm*Tn   * Batch - 0 #first phase is 0
                elif Tn > 1 and sum(N_padding) > 0:
                    if sched == 1:
                        mem_ifm_enq  = Tm      * Batch - 1
                        mem_wgt_enq  = 1       * Batch - 1
                        mem_ofm_enq  = 2*Tm    * Batch - 1
                        shim_ifm_enq = Tm      * Batch - 1
                        shim_wgt_enq = 1       * Batch - 1
                        shim_ofm_enq = 2*Tm    * Batch - 1
                    elif sched == 2:
                        mem_ifm_enq  = Tm      * Batch - 1
                        mem_wgt_enq  = 2*Tm    * Batch - 1 
                        mem_ofm_enq  = 2*Tm    * Batch - 1
                        shim_ifm_enq = Tm      * Batch - 1
                        shim_wgt_enq = 2*Tm    * Batch - 1
                        shim_ofm_enq = 2*Tm    * Batch - 1
                    elif sched == 5:
                        mem_ifm_enq  = 2*Tm    * Batch - 1
                        mem_wgt_enq  = 2*Tm    * Batch - 1
                        mem_ofm_enq  = 2*Tm    * Batch - 1
                        shim_ifm_enq = 2*Tm    * Batch - 1
                        shim_wgt_enq = 2*Tm    * Batch - 1
                        shim_ofm_enq = 2*Tm    * Batch - 1
                elif Tm > 1 and sum(M_padding) > 0:
                    if sched == 1:
                        mem_ifm_enq  = 2       * Batch - 1
                        mem_wgt_enq  = 1       * Batch - 1
                        mem_ofm_enq  = 2       * Batch - 1
                        shim_ifm_enq = 2       * Batch - 1
                        shim_wgt_enq = 1       * Batch - 1
                        shim_ofm_enq = 2       * Batch - 1
                    elif sched == 2:
                        mem_ifm_enq  = 2       * Batch - 1
                        mem_wgt_enq  = 2       * Batch - 1
                        mem_ofm_enq  = 2       * Batch - 1
                        shim_ifm_enq = 2       * Batch - 1
                        shim_wgt_enq = 2       * Batch - 1
                        shim_ofm_enq = 2       * Batch - 1
                    elif sched == 5: #ifm reenqueue
                        mem_ifm_enq  = Tm      * Batch - 1
                        mem_wgt_enq  = Tm      * Batch - 1
                        mem_ofm_enq  = Tm      * Batch - 1
                        shim_ifm_enq = Tm      * Batch - 1
                        shim_wgt_enq = Tm      * Batch - 1
                        shim_ofm_enq = Tm      * Batch - 1
                else:
                    if sched in [1, 2]:
                        mem_ifm_enq  = 1       * Batch - 1
                        mem_wgt_enq  = 1       * Batch - 1
                        mem_ofm_enq  = 1       * Batch - 1
                        shim_ifm_enq = 1       * Batch - 1
                        shim_wgt_enq = 1       * Batch - 1
                        shim_ofm_enq = 1       * Batch - 1
                    elif sched == 5:
                        mem_ifm_enq  = Tm      * Batch - 1
                        mem_wgt_enq  = Tm      * Batch - 1
                        mem_ofm_enq  = Tm      * Batch - 1
                        shim_ifm_enq = Tm      * Batch - 1
                        shim_wgt_enq = Tm      * Batch - 1
                        shim_ofm_enq = Tm      * Batch - 1

                if split[-1] == 32 and wgtsubvbytes >= 4096 and Tn > 1 and sched in [2, 5]:
                    #shim wgt bd chaining
                    shim_wgt_enq = Tn    * Batch - 1

                shim_prm_enq = 1         - 1
                mem_prm_enq  = 1         - 1
                shim_qdq_enq = Batch     - 1
                mem_qdq_enq  = Batch     - 1
            

            
            SHIM_REPEAT_MAX  = 64
            MEM_REPEAT_MAX = 768
            if sched == 5 and Tn > SHIM_REPEAT_MAX:
                m_factor = Tn
                mem_ifm_enq  = max(m_factor * Batch - 1, mem_ifm_enq)
                mem_wgt_enq  = max(m_factor * Batch - 1, mem_ifm_enq)
                mem_ofm_enq  = max(m_factor * Batch - 1, mem_ifm_enq)
                shim_ifm_enq = max(m_factor * Batch - 1, mem_ifm_enq)
                shim_wgt_enq = 1       * Batch - 1
                shim_ofm_enq = 1       * Batch - 1
            elif sched in [2, 5] and Tm*Tk*Tn > MEM_REPEAT_MAX:
                m_factor = Tm if Tk*Tn <= MEM_REPEAT_MAX else Tm*Tn
                mem_ifm_enq  = max(m_factor - 1, mem_ifm_enq)
                mem_wgt_enq  = max(m_factor - 1, mem_wgt_enq)
                mem_ofm_enq  = max(m_factor - 1, mem_ofm_enq)
            
            total_enq = (
                (shim_prm_enq) *   aie_cols + 
                (mem_prm_enq ) *  (aie_cols * mem_reenq_channels_unicast) + 
                (shim_qdq_enq) *   aie_rows + 
                (mem_qdq_enq ) *  (aie_rows * mem_reenq_channels_broadcast) + 
                (shim_ifm_enq) *  (aie_cols                               if self.tilings.overlay.unicast == "act" else aie_rows) +
                (mem_ifm_enq ) * ((aie_cols * mem_reenq_channels_unicast) if self.tilings.overlay.unicast == "act" else (aie_rows * mem_reenq_channels_broadcast)) +
                (shim_wgt_enq) *  (aie_cols                               if self.tilings.overlay.unicast == "wgt" else aie_rows) + 
                (mem_wgt_enq ) * ((aie_cols * mem_reenq_channels_unicast) if self.tilings.overlay.unicast == "wgt" else (aie_rows * mem_reenq_channels_broadcast)) +
                (shim_ofm_enq) *   aie_cols + 
                (mem_ofm_enq ) *   (aie_cols  * mem_reenq_channels_unicast) 
                         ) 
            return total_enq

        def est_control_overhead(tiling_params, sched):
            """Return ``[single_buffer_enq, double_buffer_enq]`` for one tiling.

            The double-buffer count is whatever ``est_total_enq`` reports for
            ``tiling_params`` under ``sched``. For every schedule except 5 the
            single-buffer count is the same number. For schedule 5 it is
            ``math.inf`` (presumably to mark single buffering as unusable for
            that schedule; TODO confirm with the consumer of ``control_overheads``).
            """
            double_buf = est_total_enq(tiling_params, sched)
            single_buf = math.inf if sched == 5 else double_buf
            return [single_buf, double_buf]
            
        #Initializing the common variables of Tiler:
        mode     = self.tilings.overlay.mode
        aie_rows = self.tilings.overlay.rows
        aie_cols = self.tilings.overlay.cols

        #Info of BD count used for unicast & broadcast
        mem_reenq_channels_unicast   = 1 + aie_rows    #Mem -> 5
        mem_reenq_channels_broadcast = 2               #Mem -> 2

        for sched, tiling_params_list in tiling_param_dict.items():
            control_overhead_list = []
            for idx in range (len(tiling_params_list)):
                if sched == 1 and tiling_param_dict[sched][0]['overlay_info']['mode'] == 'M8K1N4':
                    pass
                control_overhead_list.append(est_control_overhead(tiling_params_list[idx], sched))
            #Storing Enqueue counts of each sched type.
            self.control_overheads[sched] = np.array (control_overhead_list)


if __name__ == '__main__':
    # Ad-hoc smoke test: load the 'tst' layer description, build a MatMul
    # tiler for an 8x4 'M8K1N4' overlay on a 'strix' device, and run the
    # cost-model stages on it.
    # NOTE(review): the bare-module imports below (overlay, layer, kernel,
    # device, matmul_tiler) need the Tiler directory on sys.path, while the
    # JSON path is relative to the current working directory. Confirm how
    # this script is meant to be launched.
    import json

    from overlay import Overlay
    from layer import Layer
    from kernel import Kernel
    from device import Device
    from matmul_tiler import MatMulTiler

    ov = Overlay('8x4', 'MatMul', 'M8K1N4')

    # JSON text is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open('OGOAT/src/Tiler/tst_layer.json', encoding='utf-8') as f:
        mdict = json.load(f)
    ld = mdict['tst']
    # Override both activation residencies to 'L3' for this run.
    ld['in_act_residency'] = 'L3'
    ld['out_act_residency'] = 'L3'

    tst_layer = Layer(ld)
    d = Device('strix')
    k = Kernel(tst_layer)

    t = MatMulTiler(tst_layer, d, ov, k)
    t.calculate_memtile_tilings()
    # t.check_valid_memtile_tilings()
    t.calculate_array_tilings()
    t.check_core_constraints()

    c = MatMulCostModel(t)
    c.calculate_kernel_cycles()
    c.calculate_array_cycles()
    c.calculate_layer_cycles()
