import yaml
import os
import numpy as np

from dataflow.dataflow_common import overlay_stack_addr
# from collections import namedtuple


class Overlay:
    """Overlay description loaded from ``Collaterals/overlays.yaml``.

    The overlay is looked up by name (e.g. ``"4x4"``) in the YAML file located
    two directories above this module. Its array geometry (``cols``/``rows``)
    and memtile channel counts are always loaded. When ``mode`` is given, the
    per-operator parameters for ``op_type`` are loaded too:
    ``core_splits`` / ``mem_splits`` (dicts of ``np.ndarray``, or None) and,
    only when present in the YAML entry, ``unicast`` / ``subarray``.
    """

    # Operators whose YAML entry is additionally keyed by ``op_mode``.
    _SPECIAL_HANDLING_OPS = ('Add', 'Mul')
    _VALID_OP_MODES = ('EleWise', 'BroadCast')

    def __init__(self, overlay, op_type, mode, op_mode='Invalid', device=None, overlay_type: str = 'B'):
        """Load overlay ``overlay`` and, if ``mode`` is set, its ``op_type`` parameters.

        Args:
            overlay: Overlay name; must be a top-level key of overlays.yaml.
            op_type: Operator name selecting the per-operator entry.
            mode: Mode key under ``op_type``. If None, only the generic overlay
                attributes are loaded and ``cols``/``rows`` are re-parsed from
                the overlay name as ``"<cols>x<rows>"``.
            op_mode: Required for Add/Mul; must be 'EleWise' or 'BroadCast'.
            device: Optional device; when truthy, capacity attributes are derived
                (``memtile_capacity_bytes`` scales ``device.memtile_capacity_bytes``
                by ``memtile_group_size``).
            overlay_type: Stored as ``self.type``; not used for the lookup.

        Raises:
            ValueError: ``overlay`` is not defined in overlays.yaml.
            AssertionError: ``op_type`` is Add/Mul and ``op_mode`` is invalid
                (checked with ``assert``).
            KeyError: a required key (geometry, channels, ``op_type``,
                ``op_mode`` or ``mode``) is missing for this overlay.
        """
        self.overlay_name = overlay
        self.mode = mode
        self.type = overlay_type
        self.device = device

        all_overlays = self._load_overlays()
        if overlay not in all_overlays:
            raise ValueError(f"Overlay '{overlay}' with type '{overlay_type}' does not exist in overlays.yaml")
        entry = all_overlays[overlay]

        self.cols = entry['cols']
        self.rows = entry['rows']
        self.broadcast_cols = (self.cols // 2) if self.cols > 4 else self.cols

        # memtile_group_size is optional in overlays.yaml and defaults to 1.
        self.memtile_group_size = entry.get('memtile_group_size', 1)
        self.memtile_broadcast_channels = entry['memtile_broadcast_channels']
        self.memtile_unicast_channels = entry['memtile_unicast_channels']
        self.memtile_ofm_channels = entry['memtile_ofm_channels']
        self.num_memtile_subregions = self.cols // self.memtile_group_size

        if self.device:
            # TODO: May need better way to calculate core_data_mem_size
            # as function of device core_mem_size and reserved space sizes.
            self.coretile_capacity_bytes = overlay_stack_addr()
            self.memtile_capacity_bytes = self.device.memtile_capacity_bytes * self.memtile_group_size

        if mode is None:
            # No operator mode: take the geometry from the overlay name itself.
            # NOTE(review): broadcast_cols / num_memtile_subregions above were
            # derived from the YAML cols; confirm they match the name-parsed cols.
            self.cols, self.rows = map(int, overlay.split("x"))
            return

        overlay_params = self._lookup_params(entry, overlay, op_type, op_mode, mode)

        # Always define the split tables, including when the YAML value for this
        # mode is empty (null), so later attribute access cannot fail.
        self.core_splits = None
        self.mem_splits = None
        if overlay_params is None:
            return
        if 'core_splits' in overlay_params:
            self.core_splits = self._as_array_dict(overlay_params['core_splits'])
        if 'mem_splits' in overlay_params:
            self.mem_splits = self._as_array_dict(overlay_params['mem_splits'])
        # unicast / subarray are set only when the YAML entry provides them.
        if 'unicast' in overlay_params:
            self.unicast = overlay_params['unicast']
        if 'subarray' in overlay_params:
            self.subarray = overlay_params['subarray']

    @staticmethod
    def _load_overlays():
        """Parse Collaterals/overlays.yaml, located two directories above this file."""
        here = os.path.dirname(os.path.abspath(__file__))
        path = os.path.join(here, '..', '..', 'Collaterals', 'overlays.yaml')
        with open(path, encoding='utf-8') as f:
            return yaml.safe_load(f)

    @classmethod
    def _lookup_params(cls, entry, overlay, op_type, op_mode, mode):
        """Return ``entry[op_type][(op_mode)][mode]``; the value may be None."""
        if op_type in cls._SPECIAL_HANDLING_OPS:
            assert op_mode in cls._VALID_OP_MODES, (
                f"Invalid mode. It should be either EleWise or BroadCast. OP: {op_type}, Mode: {op_mode}"
            )
            keys = (op_type, op_mode, mode)
        else:
            keys = (op_type, mode)

        node = entry
        for key in keys:
            try:
                node = node[key]
            except KeyError as err:
                path = ' -> '.join(map(str, keys))
                raise KeyError(f"'{key}' not found for overlay '{overlay}' (lookup path: {path})") from err
        return node

    @staticmethod
    def _as_array_dict(mapping):
        """Return a copy of ``mapping`` with every value converted by ``np.array``."""
        return {key: np.array(value) for key, value in mapping.items()}


if __name__ == "__main__":
    # Manual smoke check: build the 4x4 overlay with the MatMul parameters for
    # mode "M4K1N4"; this reads Collaterals/overlays.yaml, so that file must exist.
    tstov = Overlay(overlay="4x4", op_type="MatMul", mode="M4K1N4")
