import numpy as np

from conv_params import ceildiv
from OGOAT.src.Tiler.utils import count_leading_bits as clb
from OGOAT.src.Scheduling_Engine.infra.const import MEMTILE_SIZE
from dmacompiler import compute_buffer_size , config

from dataflow.conv.conv_common import conv_input
from typing import Optional
import numpy as np
from OGOAT.src.Tiler.utils import process_overheads

class ConvCostModel:
    """Analytical cycle/latency cost model for convolution tiling candidates.

    For every subvolume in ``tilings.valid_core_subvols`` the model estimates
    core kernel cycles, L1<->L2 stream dataflow cycles and DDR (L3) dataflow
    cycles, then takes the largest contributor as the projected bottleneck.
    All results are stored in dicts keyed by the subvolume's index in
    ``tilings.valid_core_subvols``.

    ``calculate_layer_cycles`` reads the results of ``calculate_kernel_cycles``
    and ``calculate_array_cycles``, so both must be run before it.
    """

    # Fixed setup overhead per AIE column, added to every projected cycle count.
    SETUP_COST_CYCLES_PER_AIE_COL = 2000
    US_CONVERSION_FACTOR = 10 ** 6  # For converting seconds to microseconds

    def __init__(self, tilings, layer=None, kernel=None):
        """Bind the tiling candidates and cache device bandwidths.

        Args:
            tilings: Tiling result; read for ``device``, ``valid_core_subvols``,
                ``kernel``, ``layer``, ``overlay`` and ``aie_cols``/``aie_rows``.
            layer: Layer descriptor. Although it defaults to None, the cycle
                methods dereference it (e.g. ``is_standalone_dwc``).
            kernel: Kernel descriptor (read for ``*_bytes`` and ``is_xint8``).
        """
        self.tilings = tilings
        self.layer = layer
        self.kernel = kernel
        # Per-subvolume results, keyed by index into tilings.valid_core_subvols.
        self.kernel_cycles = {}
        self.kernel_profiling = {}
        self.array_cycles = {}
        self.layer_cycles = {}
        self.total_layer_cycles = {}
        self.params = {}  # NOTE(review): not populated by any method of this class.

        # Use max_stream_bw from device configuration (already in bytes)
        self.stream_bytes_per_cycle = self.tilings.device.max_stream_bw

        # DDR bandwidths from the device config (bytes per cycle)
        self.ddr_read_bw_bytes_per_cycle = self.tilings.device.dram_read_bandwidth
        self.ddr_write_bw_bytes_per_cycle = self.tilings.device.dram_write_bandwidth

    def derate_ddr_bw(self, bytes: int, Cm: int, C: Optional[int] = None, Xm: Optional[int] = None, burst_len: Optional[int] = 128):
        """Return the DDR bandwidth derate factor for a contiguous transfer run.

        The factor is ``2 ** min(floor(log2(run_bytes / burst_len)), 0)``:
        1.0 once a contiguous run spans at least one burst, otherwise the
        largest power of two not exceeding ``run_bytes / burst_len``.

        Args:
            bytes: Bytes per element (name kept for keyword callers, even
                though it shadows the builtin inside this method).
            Cm: Extent of the lowest dimension being transferred, in elements.
            C: Full extent of that dimension. Optional.
            Xm: Extent of the second-lowest dimension. Optional.
            burst_len: Burst length in the same unit as ``Cm * bytes``;
                ``None`` is treated as the default 128.

        Returns:
            numpy float64 power of two in [0, 1] (0 only for an empty run).
        """
        if burst_len is None:
            burst_len = 128
        run_len = Cm
        # When Cm covers the whole C extent, the contiguous run continues
        # across the second-lowest dimension: C * Xm elements.
        if C is not None and Xm is not None and not (Cm < C):
            run_len = C * Xm
        return 2 ** np.minimum(np.floor(np.log2((run_len * bytes) / burst_len)), 0)

    def _get_subvol_parameters(self, subvol):
        """Extract and derive the per-subvolume parameters used by the cost model.

        Args:
            subvol: One entry of ``tilings.valid_core_subvols``; read for its
                ``convsubv``, ``temporalsplits``, ``spatialsplits``,
                ``memtile_params``, ``l1buffers`` and ``loop_constraint``.

        Returns:
            dict of subvolume dimensions, loop counts and flags, plus
            ``l1_l2_transfer_size`` (bytes moved between L1 and L2 per tensor,
            summed over the temporal loops) and ``ddr_transfer_size`` (bytes
            moved to/from DDR per tensor and the derated DDR bandwidth).
        """
        # Core (L1) subvolume dimensions
        Cis = subvol.convsubv.Cis
        Yis = subvol.convsubv.Yis
        Xis = subvol.convsubv.Xis
        Cos = subvol.convsubv.Cos
        Yos = subvol.convsubv.Yos
        Xos = subvol.convsubv.Xos
        # Temporal loop trip counts
        Ci_loop = subvol.temporalsplits.Ci_loop
        Co_loop = subvol.temporalsplits.Co_loop
        Y_loop = subvol.temporalsplits.Y_loop
        X_loop = subvol.temporalsplits.X_loop
        # Memtile (L2) configuration
        mt = subvol.memtile_params
        enable_ifm_streaming = mt.enable_ifm_streaming
        enable_wgt_reuse = mt.enable_wgt_reuse
        pin_ifm_l1 = mt.pin_ifm_l1
        pin_wgt_bias_l1 = mt.pin_wgt_bias_l1
        num_ifm_subv = mt.num_ifm_subv
        l1buffers = subvol.l1buffers
        X_split = subvol.spatialsplits.X_split
        Co_split = subvol.spatialsplits.Co_split

        is_X8_split = (X_split == 8)

        # Derived loop counts
        Y_Co_loop = Y_loop * Co_loop
        Co_Ci_loop = Co_loop * Ci_loop
        X_Y_loop = X_loop * Y_loop
        X_Y_Co_loop = X_loop * Y_Co_loop
        X_Y_Co_Ci_loop = X_loop * (Y_Co_loop * Ci_loop)
        # Channel loop seen by the memtile ifm: standalone depthwise conv walks
        # channels via Co_loop instead of Ci_loop (presumably because Ci == Co).
        Cim_loop = (Co_loop if self.layer.is_standalone_dwc else Ci_loop) / num_ifm_subv
        Cim_Co_loop = Cim_loop * Co_loop

        # ------------------------------------------------------------------------------------
        # L1 <-> L2: a tensor pinned in L1 is transferred only once.
        l1_l2_transfer_size = {
            'ifm': l1buffers.ifm_size * (1 if pin_ifm_l1 else X_Y_Co_Ci_loop),
            'wgt': l1buffers.wgt_size * (1 if pin_wgt_bias_l1 else X_Y_Co_Ci_loop),
            'ofm': l1buffers.ofm_size * (X_Y_Co_loop if Ci_loop > 1 else 1),
        }
        # ------------------------------------------------------------------------------------
        # L2 <-> L3 (DDR): activations only touch DDR when resident in L3.
        ifm_in_ddr = int(self.layer.in_act_residency in ["L3", "L2+L3"])
        ofm_in_ddr = int(self.layer.out_act_residency in ["L3", "L2+L3"])
        # use non-padded value of Xis for is_X8_split
        Xid = conv_input(Xos, self.layer.Kx, self.layer.Sx) if is_X8_split else self.layer.Xi

        ddr_transfer_size = {}
        ddr_transfer_size['ifm_subv'] = self.tilings.aie_cols * (mt.ifm_size * Xid / mt.Xim) * ifm_in_ddr
        ddr_transfer_size['wgt_subv'] = self.tilings.aie_rows * mt.wgt_size
        ddr_transfer_size['ofm_subv'] = self.tilings.aie_cols * mt.ofm_size * ofm_in_ddr

        ddr_transfer_size['ifm'] = ddr_transfer_size['ifm_subv'] * (X_Y_loop * (Cim_Co_loop if enable_ifm_streaming else 1))
        ddr_transfer_size['wgt'] = ddr_transfer_size['wgt_subv'] * (1 if enable_wgt_reuse else (X_Y_Co_Ci_loop / mt.num_pack_wgt_subv))
        ddr_transfer_size['ofm'] = ddr_transfer_size['ofm_subv'] * (X_Y_loop if is_X8_split else Y_Co_loop)

        # ddr bw updated with derate factor based on lowest dim and second-lowest dim
        ddr_transfer_size['ifm_ddr_bw'] = self.ddr_read_bw_bytes_per_cycle * self.derate_ddr_bw(self.kernel.ifm_bytes, mt.Cim, self.layer.Ci, mt.Xim)
        # NOTE(review): l1buffers.wgt_size is divided by stream bytes/cycle in
        # _calculate_l1_l2_cycles (i.e. treated as bytes) but is scaled by
        # wgt_bytes again here; confirm the intended unit.
        ddr_transfer_size['wgt_ddr_bw'] = self.ddr_read_bw_bytes_per_cycle * self.derate_ddr_bw(self.kernel.wgt_bytes, l1buffers.wgt_size)
        ddr_transfer_size['ofm_ddr_bw'] = self.ddr_write_bw_bytes_per_cycle * self.derate_ddr_bw(self.kernel.ofm_bytes, mt.Com * mt.mt_co_pack, self.layer.Co, mt.Xom)
        # ------------------------------------------------------------------------------------

        return {
            'Cis': Cis, 'Yis': Yis, 'Xis': Xis,
            'Cos': Cos, 'Yos': Yos, 'Xos': Xos,
            'Xim': mt.Xim, 'Xom': mt.Xom,
            'Cim': mt.Cim, 'Com': mt.Com,
            'mt_co_pack': mt.mt_co_pack,
            'Ci_loop': Ci_loop, 'Co_loop': Co_loop, 'Y_loop': Y_loop, 'X_loop': X_loop,
            'Co_Ci_loop': Co_Ci_loop,
            'X_Y_loop': X_Y_loop,
            'X_Y_Co_loop': X_Y_Co_loop,
            'enable_ifm_streaming': enable_ifm_streaming,
            'enable_wgt_reuse': enable_wgt_reuse,
            'pin_ifm_l1': pin_ifm_l1,
            'pin_wgt_bias_l1': pin_wgt_bias_l1,
            'l1_buffer_sizes': l1buffers,
            'l1_l2_transfer_size': l1_l2_transfer_size,
            'ddr_transfer_size': ddr_transfer_size,
            'is_X8_split': is_X8_split,
            'Co_split': Co_split,
            'num_ifm_subv': num_ifm_subv,
            'loop_constraint': subvol.loop_constraint,
        }

    def _round_all_values(self, cycles_dict):
        """Return a copy of ``cycles_dict`` with every value passed through ``round``."""
        return {key: round(value) for key, value in cycles_dict.items()}

    def calculate_kernel_cycles(self):
        """Estimate core compute cycles for every valid core subvolume.

        Builds the argument dict consumed by ``process_overheads`` and stores,
        per subvolume index, the rounded total of the returned overhead cycles
        in ``self.kernel_cycles`` and a profiling record (that total plus the
        int-cast ``overhead_details``) in ``self.kernel_profiling``.
        """
        def format_dict(inp_dict):
            # Recursively cast leaf values to int, in place; returns the same dict.
            for k, v in inp_dict.items():
                inp_dict[k] = format_dict(v) if isinstance(v, dict) else int(v)
            return inp_dict

        X_gran = self.tilings.kernel.X_gran
        Ci_gran = self.tilings.kernel.Ci_gran
        Co_gran = self.tilings.kernel.Co_gran
        Y_gran = self.tilings.kernel.Y_gran

        layer = self.tilings.layer
        device = self.tilings.device
        # Strip a leading 'u' (unsigned) before looking up macs_per_cycle.
        in_dtype = layer.in_datatype[1:] if layer.in_datatype.startswith('u') else layer.in_datatype
        wgt_dtype = layer.wgt_datatype[1:] if layer.wgt_datatype.startswith('u') else layer.wgt_datatype
        Sx = layer.Sx
        Sy = layer.Sy
        Kx_g = layer.Kx
        Ky_g = layer.Ky
        macs_per_cycle = device.macs_per_cycle[in_dtype + 'x' + wgt_dtype]

        for i, subvol in enumerate(self.tilings.valid_core_subvols):
            convsubv = subvol.convsubv
            splits = subvol.temporalsplits
            args = {
                "overheads": self.tilings.kernel.overheads,
                "op_type": layer.orig_op_type,
                "macs_per_cycle": macs_per_cycle,
                "Mgran": X_gran, "Kgran": Ci_gran, "Ngran": Co_gran, "Y_gran": Y_gran,
                "Sx": Sx, "Sy": Sy,
                "M": convsubv.Xis, "Yis": convsubv.Yis, "K": convsubv.Cis,
                "Yos": convsubv.Yos, "Xos": convsubv.Xos, "N": convsubv.Cos,
                "is_standalone_dwc": self.layer.is_standalone_dwc,
                "Tb": 1, "Tm": splits.X_loop, "Tk": splits.Ci_loop, "Tn": splits.Co_loop, "Ty": splits.Y_loop,
                "Kx_g": Kx_g, "Ky_g": Ky_g,
                "loop_constraint": subvol.loop_constraint,
                "is_xint8": self.kernel.is_xint8,
            }
            overhead_cycles, _var_dict, overhead_details = process_overheads(args)
            self.kernel_cycles[i] = round(sum(overhead_cycles.values()))
            self.kernel_profiling[i] = {'kernel_cycles': self.kernel_cycles[i], 'overhead_details': format_dict(overhead_details)}

    def _calc_cycles(self, oh_il_ol: tuple[list[dict], int, int]) -> int:
        """Sum loop-nest cycles over a sequence of kernel overhead descriptors.

        Args:
            oh_il_ol: ``(overheads, inner_loop, outer_loop)``. Each overhead is a
                mapping with a numeric ``loop_count`` and the expression strings
                ``cycles_per_inner_loop``, ``outer_loop_OH`` and ``kernel_body_OH``.

        Returns:
            Sum over overheads of
            ``loop_count * (outer_loop * (inner_loop * cpil + outer_OH) + body_OH)``.

        The expression strings are evaluated with ``eval`` in this method's
        scope, so they can reference its local names (``inner_loop``,
        ``outer_loop``, ``overhead``, ...); do not rename locals here.
        """
        # SECURITY NOTE(review): eval runs arbitrary code; the overhead
        # expressions must come from a trusted source.
        cycles = 0
        overheads, inner_loop, outer_loop = oh_il_ol
        for overhead in overheads:
            cycles += overhead['loop_count'] * (
                (outer_loop * ((inner_loop * eval(overhead['cycles_per_inner_loop'])) + eval(overhead['outer_loop_OH'])))
                + eval(overhead['kernel_body_OH'])
            )
        return cycles

    def _calculate_core_kernel_cycles(self, params):
        """Sum ``_calc_cycles`` over all kernel stages described in ``params``."""
        return (  self._calc_cycles(params['conv_oh_il_ol'])
                + self._calc_cycles(params['ifmsum_generic_template_oh_il_ol'])
                + self._calc_cycles(params['ifmsum_sumv_conv_oh_il_ol'])
                + self._calc_cycles(params['qdq_oh_il_ol'])
                + self._calc_cycles(params['add_oh_il_ol'])
                + self._calc_cycles(params['superkernel_oh_il_ol'])
                + self._calc_cycles(params['wgt_unpack_oh_il_ol']) * 2   #wgt unpacking for dwc , multiply by 2 is to account for conv and ifmsum together
                )

    def calculate_array_cycles(self):
        """Compute L1<->L2 (and L2<->L3 placeholder) dataflow cycles per subvolume.

        Results are merged into one dict per subvolume in ``self.array_cycles``.
        """
        for i, subvol in enumerate(self.tilings.valid_core_subvols):
            params = self._get_subvol_parameters(subvol)
            l1_l2_cycles = self._calculate_l1_l2_cycles(params)
            l2_l3_cycles = self._calculate_l2_l3_cycles(params)
            self.array_cycles[i] = {**l1_l2_cycles, **l2_l3_cycles}

    def _calculate_l1_l2_cycles(self, params):
        """Return rounded stream dataflow and lock-stall cycles between L1 and L2.

        Every transfer is costed at ``self.stream_bytes_per_cycle``. The
        ``*_lockstalls_cycles`` entries are the cost of a single subvolume
        transfer, except ofm which is charged once per X/Y iteration (times
        ``mt_co_pack`` when ``Ci_loop == 1``).
        """
        bw = self.stream_bytes_per_cycle
        l1_sizes = params['l1_buffer_sizes']
        totals = params['l1_l2_transfer_size']

        # ofm (L1 -> L2)
        l1_to_l2_ofm_subv_dataflow_cycles = l1_sizes.ofm_size / bw
        l1_to_l2_ofm_dataflow_cycles = totals['ofm'] / bw
        l1_to_l2_ofm_lockstalls_cycles = params['X_Y_loop'] * l1_to_l2_ofm_subv_dataflow_cycles * (1 if params['Ci_loop'] > 1 else params['mt_co_pack'])

        # ifm (L2 -> L1)
        l2_to_l1_ifm_lockstalls_cycles = l1_sizes.ifm_size / bw
        l2_to_l1_ifm_dataflow_cycles = totals['ifm'] / bw

        # wgt (L2 -> L1)
        l2_to_l1_wgt_lockstalls_cycles = l1_sizes.wgt_size / bw
        l2_to_l1_wgt_dataflow_cycles = totals['wgt'] / bw

        return self._round_all_values({
            'l1_to_l2_ofm_dataflow_cycles': l1_to_l2_ofm_dataflow_cycles,
            'l2_to_l1_ifm_dataflow_cycles': l2_to_l1_ifm_dataflow_cycles,
            'l2_to_l1_wgt_dataflow_cycles': l2_to_l1_wgt_dataflow_cycles,
            'l2_to_l1_ifm_lockstalls_cycles': l2_to_l1_ifm_lockstalls_cycles,
            'l2_to_l1_wgt_lockstalls_cycles': l2_to_l1_wgt_lockstalls_cycles,
            'l1_to_l2_ofm_lockstalls_cycles': l1_to_l2_ofm_lockstalls_cycles,
        })

    def _calculate_l2_l3_cycles(self, params):
        """Placeholder: returns an empty dict.

        DDR costs are computed in ``_calculate_ddr_cycles`` instead.
        """
        return {}

    def _calculate_ddr_cycles(self, params, core_compute_cycles):
        """Return rounded DDR dataflow and lock-stall cycles for one subvolume.

        Args:
            params: Output of ``_get_subvol_parameters``.
            core_compute_cycles: Kernel cycles for the subvolume, used to hide
                ifm reload time behind compute.
        """
        ddr = params['ddr_transfer_size']
        ddr_ifm_dataflow_cycles = ddr['ifm'] / ddr['ifm_ddr_bw']
        ddr_wgt_dataflow_cycles = ddr['wgt'] / ddr['wgt_ddr_bw']
        ddr_ofm_dataflow_cycles = ddr['ofm'] / ddr['ofm_ddr_bw']

        aie_cols, aie_rows = self.tilings.overlay.cols, self.tilings.overlay.rows
        # Weight stall scaled by the rows' share of cols + rows (presumably the
        # fraction of DDR bandwidth available to weight streams).
        ddr_wgt_lockstalls_cycles = ddr['wgt_subv'] / ddr['wgt_ddr_bw'] / ((aie_rows) / (aie_cols + aie_rows))

        ddr_ifm_lockstalls_cycles = ddr['ifm_subv'] / ddr['ifm_ddr_bw']
        if params['enable_ifm_streaming'] or params['Co_Ci_loop'] == 1:
            # For the special case of Co_Ci_loop == 1, ifm memtile double
            # buffering enables pipelining of DDR -> memtile ifm transfers with
            # core subvolume compute, so only one subvolume's load stalls.
            ddr_ifm_lockstalls_cycles = ddr_ifm_lockstalls_cycles / ((aie_cols) / (aie_cols + aie_rows))
        else:
            # One ifm load per X/Y iteration, partly hidden behind compute.
            ddr_ifm_lockstalls_cycles = max(0, ddr_ifm_lockstalls_cycles * params['X_Y_loop'] - (core_compute_cycles * 2 / params['Co_Ci_loop']))

        # Total cycles for read and write operations
        ddr_read_dataflow_cycles = ddr_ifm_dataflow_cycles + ddr_wgt_dataflow_cycles
        ddr_write_dataflow_cycles = ddr_ofm_dataflow_cycles

        return self._round_all_values({
            'ddr_read_dataflow_cycles': ddr_read_dataflow_cycles,
            'ddr_write_dataflow_cycles': ddr_write_dataflow_cycles,
            'ddr_ifm_lockstalls_cycles': ddr_ifm_lockstalls_cycles,
            'ddr_wgt_lockstalls_cycles': ddr_wgt_lockstalls_cycles,
            'ddr_ifm_dataflow_cycles': ddr_ifm_dataflow_cycles,
            'ddr_ofm_dataflow_cycles': ddr_ofm_dataflow_cycles,
            'ddr_wgt_dataflow_cycles': ddr_wgt_dataflow_cycles
        })

    def _calculate_bottleneck_cycles(self, params, core_compute_cycles, array_cycles, ddr_cycles):
        """Pick the bottleneck path and return projected cycles and time.

        Returns:
            dict with ``projected_cycles`` (bottleneck cycles plus per-column
            setup cost), ``projected_cycles_bottle_neck`` (name of the winning
            path) and ``total_time_us`` (rounded to 3 decimals).
        """
        # Worst-case input stall across the L3 -> L2 -> L1 ifm and wgt paths.
        l3_to_l2_to_l1_ifm_lockstalls_cycles = ddr_cycles['ddr_ifm_lockstalls_cycles'] + array_cycles['l2_to_l1_ifm_lockstalls_cycles']
        l3_to_l2_to_l1_wgt_lockstalls_cycles = ddr_cycles['ddr_wgt_lockstalls_cycles'] + array_cycles['l2_to_l1_wgt_lockstalls_cycles']
        l3_to_l2_to_l1_lockstalls_cycles = max(l3_to_l2_to_l1_ifm_lockstalls_cycles, l3_to_l2_to_l1_wgt_lockstalls_cycles)

        setup_kernel_cycles = l3_to_l2_to_l1_lockstalls_cycles
        kernel_cycles = setup_kernel_cycles + core_compute_cycles + array_cycles['l1_to_l2_ofm_dataflow_cycles']

        # ofm write-back can only start after the first packed ofm subvolume is computed.
        core_compute_subv_cycles = core_compute_cycles / params['X_Y_Co_loop']
        setup_l2_to_l3_ofm_dataflow_cycles = l3_to_l2_to_l1_lockstalls_cycles + (core_compute_subv_cycles * params['mt_co_pack']) + array_cycles['l1_to_l2_ofm_lockstalls_cycles']
        ddr_write_dataflow_cycles = setup_l2_to_l3_ofm_dataflow_cycles + ddr_cycles['ddr_ofm_dataflow_cycles']

        bottle_necks = {
            'kernel_cycles': kernel_cycles,
            'l2_to_l1_ifm_dataflow_cycles': array_cycles['l2_to_l1_ifm_dataflow_cycles'],
            'l2_to_l1_wgt_dataflow_cycles': array_cycles['l2_to_l1_wgt_dataflow_cycles'],
            'ddr_read_dataflow_cycles': ddr_cycles['ddr_read_dataflow_cycles'],
            'ddr_write_dataflow_cycles': round(ddr_write_dataflow_cycles),
        }

        # Max over (cycles, name) pairs: ties go to the lexicographically larger name.
        projected_cycles, projected_cycles_bottle_neck = max(zip(bottle_necks.values(), bottle_necks.keys()))

        projected_cycles += self.SETUP_COST_CYCLES_PER_AIE_COL * self.tilings.overlay.cols

        total_time = projected_cycles / self.tilings.device.core_freq_hz  # seconds

        return {
            'projected_cycles': projected_cycles,
            'projected_cycles_bottle_neck': projected_cycles_bottle_neck,
            'total_time_us': round(total_time * self.US_CONVERSION_FACTOR, 3)
        }

    def calculate_layer_cycles(self):
        """Compute DDR cycles and the projected bottleneck for every subvolume.

        Requires ``calculate_kernel_cycles`` and ``calculate_array_cycles`` to
        have run. Stores DDR cycles in ``self.layer_cycles`` and the bottleneck
        summary in ``self.total_layer_cycles``.
        """
        for i, subvol in enumerate(self.tilings.valid_core_subvols):
            params = self._get_subvol_parameters(subvol)
            core_compute_cycles = self.kernel_cycles[i]
            ddr_cycles = self._calculate_ddr_cycles(params, core_compute_cycles)
            self.layer_cycles[i] = ddr_cycles
            self.total_layer_cycles[i] = self._calculate_bottleneck_cycles(
                params,
                core_compute_cycles,
                self.array_cycles[i],
                ddr_cycles
            )
            
