import sys
import yaml
import os

import numpy as np

from OGOAT.src.Tiler.tiler import Tiler
from conv_cost_model import ConvCostModel
from overlay import Overlay
from dataclasses import dataclass, field

# Directory three levels above this file's absolute path. It is the base for
# the sys.path entry below and for the 'Collaterals/overlays.yaml' lookup in
# ConvTilingOpt.__init__.
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# Make <parent_dir>/../dataflow/conv importable before importing conv_common
# (presumably that is where conv_common lives -- confirm against the repo layout).
sys.path.append(os.path.join(parent_dir, '..', 'dataflow', 'conv'))
import conv_common


@dataclass
class KernelParams:
    """Flat view of the kernel constraints and byte sizes used for tiling.

    Values are copied from a kernel description and from the layer's byte
    sizes so that downstream code reads one object.

    NOTE(review): the class declares no annotated fields, so ``@dataclass``
    only contributes a field-less ``__repr__``/``__eq__``; the decorator is
    kept so the class interface is unchanged.
    """

    def __init__(self, kernel, layer):
        """Collect kernel constraints and data sizes.

        Args:
            kernel: Object exposing ``subvol_constraints`` (a mapping whose
                four values are unpacked, in order, as the Y, X, Ci and Co
                granularities), ``additional_constraints`` (a mapping) and
                the scalar attributes copied below.
            layer: Object exposing ``in_bytes``, ``wgt_bytes`` and
                ``out_bytes``.

        Raises:
            ValueError: If ``inner_loop_min``, ``outer_loop_min`` or
                ``mem_align`` is missing (or ``None``) in
                ``kernel.additional_constraints``.
        """
        # Relies on the mapping's iteration order being (Y, X, Ci, Co).
        self.Y_gran, self.X_gran, self.Ci_gran, self.Co_gran = kernel.subvol_constraints.values()

        constraints = kernel.additional_constraints
        self.inner_loop_min = constraints.get('inner_loop_min')
        self.outer_loop_min = constraints.get('outer_loop_min')
        self.mem_align = constraints.get('mem_align')
        # The flag is set when the kernel actually specifies 'Xis_Cis_max'.
        # The previous expression, `not (... is not None)`, reported the
        # constraint as active only when no limit was configured.
        self.Xis_Cis_constraint = constraints.get('Xis_Cis_max') is not None

        if self.inner_loop_min is None:
            raise ValueError("Minimum inner loop must be specified.")
        if self.outer_loop_min is None:
            raise ValueError("Minimum outer loop must be specified.")
        if self.mem_align is None:
            raise ValueError("Memory alignment must be specified.")

        self.ifm_bytes = layer.in_bytes
        self.wgt_bytes = layer.wgt_bytes
        self.ofm_bytes = layer.out_bytes
        self.bias_bytes = kernel.bias_bytes
        self.tdm_bytes = kernel.tdm_bytes
        self.qdq_param_size = kernel.qdq_param_size
        self.conv_kernel_param_size = kernel.conv_kernel_param_size
        self.is_xint8 = kernel.is_xint8
        # 2-byte activations combined with 1-byte weights.
        self.is_a16w8 = (self.ifm_bytes, self.wgt_bytes) == (2, 1)
        self.is_a8w8 = kernel.is_a8w8

        self.enable_add = kernel.enable_add
        # Optional attributes are mirrored only when the kernel defines them,
        # so they may be absent on this object as well.
        for optional_attr in ('Cos_min', 'overheads', 'other_constraints'):
            if hasattr(kernel, optional_attr):
                setattr(self, optional_attr, getattr(kernel, optional_attr))


@dataclass
class LayerParams:
    """Convolution layer attributes gathered into one object for tiling."""

    def __init__(self, layer):
        """Copy shapes, kernel geometry and residency settings from ``layer``.

        The aligned activation shapes must each contain exactly four entries;
        the leading entry of each is discarded.
        """
        _lead_in, rows_in, cols_in, chans_in = layer.aligned_in_act_shape
        self.Yi, self.Xi, self.Ci = rows_in, cols_in, chans_in

        _lead_out, rows_out, cols_out, chans_out = layer.aligned_out_act_shape
        self.Yo, self.Xo_orig, self.Co = rows_out, cols_out, chans_out

        kernel_rows, kernel_cols = layer.kernel_shape
        self.Ky, self.Kx = kernel_rows, kernel_cols

        stride_rows, stride_cols = layer.strides
        self.Sy, self.Sx = stride_rows, stride_cols

        # Attributes mirrored verbatim from the layer.
        for attr in ('in_act_residency', 'out_act_residency', 'is_standalone_dwc'):
            setattr(self, attr, getattr(layer, attr))

        # The operator type is fixed here rather than read from the layer.
        self.orig_op_type = "Conv"

        for attr in ('in_datatype', 'wgt_datatype'):
            setattr(self, attr, getattr(layer, attr))

        # Optional on the layer; treated as disabled when absent.
        self.disable_scheduler_constraints = getattr(layer, 'disable_scheduler_constraints', False)


class ConvTilingOpt:
    """Rank conv tilings across overlay modes and report the best ones.

    One ``Tiler``/``ConvCostModel`` pair is built for every usable overlay
    mode. ``find_optimal_tiling`` evaluates every valid core subvolume of
    every tiler, keeps the ``top_n`` cheapest by projected cycles and returns
    a dictionary describing them.
    """

    def __init__(self, layer, device, overlay_name, kernel, overlay_type: str = "B", l2fusion_pass=False):
        """Load the overlay catalogue and build a tiler per usable mode.

        Args:
            layer: Layer description; wrapped in ``LayerParams`` and also
                kept as ``self.layer``.
            device: Device description forwarded to ``Overlay`` and ``Tiler``.
            overlay_name: Key into ``Collaterals/overlays.yaml``.
            kernel: Kernel description; wrapped in ``KernelParams``.
            overlay_type: Overlay variant forwarded to ``Overlay``.
            l2fusion_pass: Forwarded to ``Tiler.calculate_array_tilings``;
                when true the search stops after the first tiler that yields
                a tiling.
        """
        with open(os.path.join(parent_dir, 'Collaterals/overlays.yaml')) as f:
            self.all_overlays = yaml.safe_load(f)

        self.layer = layer
        Layer = LayerParams(layer)
        Kernel = KernelParams(kernel, layer)

        self.modes = list(self.all_overlays[overlay_name][layer.orig_op_type].keys())
        self.overlay = overlay_name
        self.overlay_type = overlay_type
        self.l2fusion_pass = l2fusion_pass

        self.tilers = []
        self.cost_models = []
        # Mode of each entry in self.tilers. self.modes cannot be indexed by
        # tiler position because skipped modes leave self.tilers shorter.
        self.tiler_modes = []
        # Present from construction so calculate_tiling_cycles can be called
        # without going through find_optimal_tiling first.
        self.top_tilings = []
        for mode in self.modes:
            overlay = Overlay(overlay_name, layer.orig_op_type, mode, device=device, overlay_type=overlay_type)
            _, X_split, _ = overlay.core_splits['ofm']
            # Modes whose OFM X split is 8 are not used for standalone DWC layers.
            if (X_split == 8) and Layer.is_standalone_dwc:
                continue
            tiler = Tiler(Layer, device, overlay, Kernel)
            self.tilers.append(tiler)
            self.cost_models.append(ConvCostModel(tiler, Layer, Kernel))
            self.tiler_modes.append(mode)

    def calculate_tiling_cycles(self, top_n=1):
        """Evaluate every tiler and keep the ``top_n`` cheapest tilings.

        Results are accumulated in ``self.top_tilings``. With
        ``l2fusion_pass`` set, evaluation stops after the first tiler that
        contributes at least one tiling.
        """
        for tiler_idx, (tiler, cost_model) in enumerate(zip(self.tilers, self.cost_models)):
            tiler.calculate_array_tilings(self.l2fusion_pass)

            cost_model.calculate_kernel_cycles()
            cost_model.calculate_array_cycles()
            cost_model.calculate_layer_cycles()

            # Feed this tiler's subvolumes, cheapest first, into the top list.
            for subvol_idx in self._find_best_subvols_for_tiler(tiler, cost_model):
                self._update_top_tilings(tiler_idx, subvol_idx, top_n)

            # Early exit for l2fusion_pass if we found at least one valid tiling
            if self.l2fusion_pass and self.top_tilings:
                return

    def _find_best_subvols_for_tiler(self, tiler, cost_model):
        """Return all costed subvolume indices ordered by (cycles, loops)."""
        # Each entry is (rounded projected cycles, temporal loops, subvol index).
        subvol_ranks = [
            (
                round(total_layer_cycle['projected_cycles']),
                tiler.valid_core_subvols[j].temporalsplits.loops,
                j,
            )
            for j, total_layer_cycle in cost_model.total_layer_cycles.items()
        ]

        # Stable sort on cycles, then loops; ties keep cost-model order.
        subvol_ranks.sort(key=lambda rank: (rank[0], rank[1]))
        return [rank[2] for rank in subvol_ranks]

    @staticmethod
    def _tiling_sort_key(tiling):
        """Ordering key for a top-tiling entry; smaller is better."""
        temporal = tiling['tiling'].temporalsplits
        return (
            tiling['cycles'],
            temporal.loops,
            temporal.X_loop,
            temporal.Y_loop,
            temporal.Co_loop,
            temporal.Ci_loop,
            tiling['tiling'].spatialsplits.X_split,
        )

    def _update_top_tilings(self, tiler_idx, subvol_idx, top_n=1):
        """Insert the given tiling into ``self.top_tilings`` if it ranks in the top ``top_n``."""
        current_cycles = round(self.cost_models[tiler_idx].total_layer_cycles[subvol_idx]['projected_cycles'])
        current_tiling = {
            'cycles': current_cycles,
            'tiling': self.tilers[tiler_idx].valid_core_subvols[subvol_idx],
            'tiler_index': tiler_idx,
            'subvol_index': subvol_idx
        }

        if len(self.top_tilings) < top_n:
            # Still room: always accept.
            self.top_tilings.append(current_tiling)
        elif self.top_tilings and (
                self._tiling_sort_key(current_tiling) < self._tiling_sort_key(self.top_tilings[-1])):
            # List is full: replace the current worst entry.
            self.top_tilings[-1] = current_tiling
        else:
            # Not good enough, or top_n <= 0 with nothing to replace.
            return
        self.top_tilings.sort(key=self._tiling_sort_key)

    def calculate_min_cycles_and_macs(self):
        """Return ``(min_cycles, macs)`` for the layer.

        ``macs`` is the product of all output-activation dims, the first two
        weight dims and the last input-activation dim. ``min_cycles`` divides
        that by the device's MACs-per-cycle for the datatype pair and by the
        number of cores (rows * cols) of the first tiler's overlay.

        NOTE(review): assumes the first two weight dims are the kernel
        height/width and the last input dim is the channel count -- confirm.
        """
        in_act_shape = self.layer.in_act_shape
        in_wgt_shape = self.layer.in_wgt_shape
        out_act_shape = self.layer.out_act_shape

        # A leading 'u' is dropped before the lookup (e.g. 'uint8' -> 'int8').
        in_dtype = self.layer.in_datatype
        if in_dtype.startswith('u'):
            in_dtype = in_dtype[1:]
        wgt_dtype = self.layer.wgt_datatype
        if wgt_dtype.startswith('u'):
            wgt_dtype = wgt_dtype[1:]

        macs = (int(np.prod(out_act_shape, dtype=np.int64))
                * int(np.prod(in_wgt_shape[:2], dtype=np.int64))
                * int(in_act_shape[-1]))
        reference_tiler = self.tilers[0]
        macs_per_cycle = reference_tiler.device.macs_per_cycle[in_dtype + 'x' + wgt_dtype]
        num_cores = reference_tiler.overlay.rows * reference_tiler.overlay.cols
        min_cycles = macs / macs_per_cycle / num_cores
        return min_cycles, macs

    def _mode_for_tiler(self, tiler_idx):
        """Return the overlay mode that produced ``self.tilers[tiler_idx]``."""
        # Fall back to self.modes for instances that never populated tiler_modes.
        modes = getattr(self, 'tiler_modes', None)
        if modes is None:
            modes = self.modes
        return modes[tiler_idx]

    @staticmethod
    def _int_fields(obj, names):
        """Map each name in ``names`` to ``int(getattr(obj, name))``, in order."""
        return {name: int(getattr(obj, name)) for name in names}

    @staticmethod
    def _int_or_none(value):
        """Convert to ``int`` while letting ``None`` pass through."""
        return int(value) if value is not None else value

    def _base_tiling_params(self):
        """Build the result dictionary with everything known before the search."""
        layer = self.layer
        tiling_params = {
            'core_tile_params': {'subvols': [], 'iters': [], 'L1_sizes_addrs': [], 'ping_pong': []},
            'mem_tile_params': {'subvols': [], 'iters': [], 'sizes': [], 'configs': []},
            'shim_tile_params': {'subvols': [], 'sizes': []},
            'scheduling': {},
            'dma_layer_padding': [],
            'original_dimensions': [],
            'host_layer_padding': [],
            'layer_padding': [],
            # vars() exposes the layer's live attribute dict, not a copy.
            'layer_info': vars(layer),
            'overlay_info': {
                'overlay': self.overlay,
                'overlay_type': self.overlay_type,
                'mode': [],
                'shape': {
                    'row': self.all_overlays[self.overlay]['rows'],
                    'col': self.all_overlays[self.overlay]['cols'],
                }
            },
            'kernel_info': {'placement_constraints': {}},
            'performance_metrics': [],
            'kernel_profiling': [],
            'additional_flags': {'dump_waves': False},
            'layer_macs': dict(zip(['min_layer_cycles', 'macs'], self.calculate_min_cycles_and_macs()))
        }
        tiling_params['original_dimensions'] = [
            {"input0": layer.in_act_shape},
            {"input1": layer.in_wgt_shape},
            {"output0": layer.out_act_shape}
        ]
        tiling_params['host_layer_padding'] = [
            {"input0": {"dims": layer.aligned_in_act_shape,
                        "values": ["zp_i0"] * len(layer.in_act_shape)}},
            {"input1": {"dims": layer.aligned_in_wgt_shape,
                        "values": ["zp_i1"] * len(layer.in_wgt_shape)}},
            {"output0": {"dims": layer.aligned_out_act_shape,
                         "values": ["zp_o0"] * len(layer.out_act_shape)}},
        ]
        return tiling_params

    def _append_tiling_details(self, tiling_params, tiling_info):
        """Append one selected tiling's parameters and metrics to ``tiling_params``."""
        tiler_idx = tiling_info['tiler_index']
        best_tiler = self.tilers[tiler_idx]
        best_cost_model = self.cost_models[tiler_idx]
        best_tiling = tiling_info['tiling']
        best_subvol_index = tiling_info['subvol_index']
        layer = self.layer

        subv = best_tiling.convsubv
        temporal = best_tiling.temporalsplits
        spatial = best_tiling.spatialsplits
        mem = best_tiling.memtile_params
        l1 = best_tiling.l1buffers

        loop_names = ('Y_loop', 'Co_loop', 'Ci_loop', 'X_loop')
        mem_flag_names = ('enable_ifm_streaming', 'enable_wgt_reuse', 'pin_ifm_l1', 'pin_wgt_bias_l1')

        # Core subvolume parameters
        tiling_params['core_tile_params']['subvols'].append({
            **self._int_fields(subv, ('Cis', 'Yis', 'Xis', 'Cos', 'Yos', 'Xos')),
            **self._int_fields(temporal, loop_names),
            **self._int_fields(spatial, ('X_split', 'Y_split', 'Ci_split', 'Co_split')),
            **self._int_fields(mem, mem_flag_names),
            **self._int_fields(mem, ('Com', 'Yom', 'Xom', 'Cim', 'Yim', 'Xim')),
            'loop_constraint': bool(best_tiling.loop_constraint),
        })

        # DMA padding extents: per-subvolume size * temporal loops * spatial split.
        Yis_depadded = conv_common.conv_output(int(subv.Yis), int(layer.in_wgt_shape[0]),
                                               int(layer.strides[0]), 0, 0)
        Xis_depadded = conv_common.conv_output(int(subv.Xis), int(layer.in_wgt_shape[1]),
                                               int(layer.strides[1]), 0, 0)
        ifm_dma_pads = [
            0,
            int(Yis_depadded * int(temporal.Y_loop) * int(spatial.Y_split)),
            int(Xis_depadded * int(temporal.X_loop) * int(spatial.X_split)),
            int(int(subv.Cis) * int(temporal.Ci_loop) * int(spatial.Ci_split)),
        ]
        ofm_dma_pads = [
            0,
            int(int(subv.Yos) * int(temporal.Y_loop) * int(spatial.Y_split)),
            int(int(subv.Xos) * int(temporal.X_loop) * int(spatial.X_split)),
            int(int(subv.Cos) * int(temporal.Co_loop) * int(spatial.Co_split)),
        ]
        # Weights reuse the padded input/output channel extents.
        wgt_dma_pads = [0, 0, ifm_dma_pads[3], ofm_dma_pads[3]]
        tiling_params['dma_layer_padding'].append([
            {"input0": {"dims": ifm_dma_pads,
                        "values": ["zp_i0"] * len(layer.in_act_shape),
                        "channels": best_tiler.overlay.memtile_unicast_channels}},
            {"input1": {"dims": wgt_dma_pads,
                        "values": ["zp_i1"] * len(layer.in_wgt_shape),
                        "channels": best_tiler.overlay.memtile_broadcast_channels}},
            {"output0": {"dims": ofm_dma_pads,
                         "values": ["zp_o0"] * len(layer.out_act_shape),
                         "channels": best_tiler.overlay.memtile_ofm_channels}},
        ])

        # Iteration parameters
        tiling_params['core_tile_params']['iters'].append(self._int_fields(temporal, loop_names))

        # L1 buffer sizes and addresses (entries may be None)
        L1_addrs = {
            'ifm_subv_size': self._int_or_none(l1.ifm_size),
            'wgt_subv_size': self._int_or_none(l1.wgt_size),
            'ofm_subv_size': self._int_or_none(l1.ofm_size),
        }
        for name in ('ifm_ping_addr', 'ifm_pong_addr', 'wgt_ping_addr', 'wgt_pong_addr',
                     'ofm_ping_addr', 'ofm_pong_addr', 'tdm_ping_addr', 'tdm_pong_addr',
                     'ifm_sum_addr', 'scratch_buf', 'tmp_buf', 'conv_kernelprm_addr',
                     'add_ifm_addr'):
            L1_addrs[name] = self._int_or_none(getattr(l1, name))
        tiling_params['core_tile_params']['L1_sizes_addrs'].append(L1_addrs)

        # Ping-pong configuration
        tiling_params['core_tile_params']['ping_pong'].append({
            'ifm': str(best_tiling.conv_pp.ifm),
            'ofm': str(best_tiling.conv_pp.ofm),
            'wgt': str(best_tiling.conv_pp.wgt),
            'tdm': str(best_tiling.conv_pp.tdm)
        })

        # Overlay info: the mode is recorded per tiling, while the shape is
        # overwritten each time (the last tiling's overlay wins).
        tiling_params['overlay_info']['mode'].append(self._mode_for_tiler(tiler_idx))
        tiling_params['overlay_info']['shape'] = {
            'row': best_tiler.overlay.rows,
            'col': best_tiler.overlay.cols
        }

        # Performance metrics
        totals = best_cost_model.total_layer_cycles[best_subvol_index]
        metrics = {
            'projected_cycles': totals['projected_cycles'],
            'bottleneck': totals['projected_cycles_bottle_neck'],
            'projected_latency_us': totals['total_time_us'],
            'core_compute_cycles': best_cost_model.kernel_cycles[best_subvol_index],
        }
        # L1-L2 dataflow cycles
        array_cycles = best_cost_model.array_cycles[best_subvol_index]
        for name in ('l1_to_l2_ofm_dataflow_cycles', 'l2_to_l1_ifm_dataflow_cycles',
                     'l2_to_l1_wgt_dataflow_cycles'):
            metrics[name] = array_cycles[name]
        # DDR dataflow cycles
        layer_cycles = best_cost_model.layer_cycles[best_subvol_index]
        for name in ('ddr_read_dataflow_cycles', 'ddr_write_dataflow_cycles',
                     'ddr_ifm_dataflow_cycles', 'ddr_ofm_dataflow_cycles',
                     'ddr_wgt_dataflow_cycles'):
            metrics[name] = layer_cycles[name]
        tiling_params['performance_metrics'].append(metrics)

        tiling_params['kernel_profiling'].append({
            **best_cost_model.kernel_profiling[best_subvol_index],
        })

        # Memory-tile configuration, subvolumes and sizes
        tiling_params['mem_tile_params']['configs'].append(self._int_fields(mem, mem_flag_names))
        tiling_params['mem_tile_params']['subvols'].append(self._int_fields(mem, ('Com', 'Xom', 'Cim')))
        tiling_params['mem_tile_params']['sizes'].append({
            'wgt_memtile_size': int(mem.wgt_size),
            'num_ifm_subv': int(mem.num_ifm_subv),
            'prm_memtile_size': int(mem.prm_size),
            'ifm_memtile_size': int(mem.ifm_size),
            'ofm_memtile_size': int(mem.ofm_size),
            'conv_kernel_param_size': int(mem.conv_kernel_param_size),
            'param_subv_size': int(mem.param_subv_size),
            'mt_co_pack': int(mem.mt_co_pack),
            'num_pack_wgt_subv': int(mem.num_pack_wgt_subv),
        })

    def find_optimal_tiling(self, top_n=1):
        """Run the search and return the tiling-parameter dictionary.

        Args:
            top_n: Number of best tilings to report.

        Returns:
            A dictionary with layer/overlay information plus one appended
            entry per selected tiling in each per-tiling list. When no valid
            tiling exists those lists stay empty and a message is printed.

        Raises:
            AssertionError: For the "4x4" overlay with overlay type "A".
        """
        if self.overlay == "4x4" and self.overlay_type == "A":
            # Raised explicitly so the check also runs under `python -O`.
            raise AssertionError('overlay "4x4" with overlay_type "A" is not supported')
        self.top_tilings = []

        tiling_params = self._base_tiling_params()

        self.calculate_tiling_cycles(top_n)

        if self.top_tilings:
            for tiling_info in self.top_tilings:
                self._append_tiling_details(tiling_params, tiling_info)
        else:
            print("No valid tiling found!")

        return tiling_params


if __name__ == '__main__':
    # Manual smoke run: load the 'tst' layer from tst_conv.json, force both
    # activations to L3 residency and search tilings on the 8x4 overlay.
    import json
    from layer import Layer

    with open('tst_conv.json') as json_file:
        test_cases = json.load(json_file)
    layer_desc = test_cases['tst']
    layer_desc.update(in_act_residency='L3', out_act_residency='L3')

    test_layer = Layer(layer_desc)

    from kernel import Kernel
    from device import Device

    target_device = Device('strix')
    test_kernel = Kernel(test_layer)

    optimizer = ConvTilingOpt(test_layer, target_device, '8x4', test_kernel)
    result = optimizer.find_optimal_tiling()