import yaml
import os
import ast

import numpy as np

from broadcast_tiler import BroadcastTiler
from broadcast_cost_model import BroadcastCostModel
from overlay import Overlay

import pdb
parent_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

class BroadcastTilingOpt:
    """Produce the tiling-parameter dictionary for a two-input broadcast layer.

    For every mode listed for the layer in ``Collaterals/overlays.yaml`` the
    constructor builds one ``BroadcastTiler`` and one ``BroadcastCostModel``.
    ``find_optimal_tiling`` runs all of them and packages the result of a
    fixed (mode, schedule, memtile-tiling) choice into a plain dictionary.
    """

    # Tensor names, in the order used by every per-tensor dictionary below.
    _TENSORS = ('ifmA', 'ifmB', 'ofm')

    def __init__(self, layer, device, overlay_name, kernel):
        """Load the overlay modes for ``layer`` and create a tiler/cost model per mode.

        Args:
            layer: layer description; ``orig_op_type`` and the third
                ``_``-separated field of ``op_type`` select the overlay entry.
            device: forwarded unchanged to ``BroadcastTiler``.
            overlay_name: top-level key in ``overlays.yaml`` (e.g. ``'8x4'``).
            kernel: forwarded unchanged to ``BroadcastTiler``.
        """
        with open(os.path.join(parent_dir, 'Collaterals/overlays.yaml')) as f:
            all_overlays = yaml.safe_load(f)

        op_variant = layer.op_type.split('_')[2]
        self.modes = list(all_overlays[overlay_name][layer.orig_op_type][op_variant].keys())
        self.overlay = overlay_name
        self.layer = layer
        self.tilers = []
        self.cost_models = []
        for mode in self.modes:
            overlay = Overlay(overlay_name, layer.orig_op_type, mode, op_variant)
            tiler = BroadcastTiler(layer, device, overlay, kernel)
            self.tilers.append(tiler)
            self.cost_models.append(BroadcastCostModel(tiler))

        # Not read anywhere in this class; kept for external users of the attribute.
        self.indexes = []

    def calculate_tiling_cycles(self):
        """Run the tiling enumeration and the cycle estimation for every mode."""
        for tiler, cost_model in zip(self.tilers, self.cost_models):
            tiler.calculate_memtile_tilings()
            tiler.check_valid_memtile_tilings()
            tiler.calculate_array_tilings()
            tiler.check_core_constraints()

            cost_model.calculate_kernel_cycles()
            cost_model.calculate_array_cycles()
            cost_model.calculate_layer_cycles()

    @staticmethod
    def _pad_fill(count, disabled, zero_point_name):
        """Return ``count`` pad values: 0 if ``disabled`` is truthy, else ``zero_point_name``."""
        return [0] * count if disabled else [zero_point_name] * count

    @staticmethod
    def _split_volume(shape, split):
        """Flatten ``shape`` to one element count and floor-divide it by ``split``.

        The count is wrapped in an array first so the division also works when
        ``split`` is a plain Python integer rather than a NumPy value.
        """
        total = np.prod(shape).tolist()
        return (np.asarray([total]) // split).astype(np.int32).tolist()

    @staticmethod
    def _make_serializable(obj):
        """Recursively convert NumPy scalars inside dicts/lists/tuples to Python scalars.

        Anything that is not a dict, list, tuple or ``np.generic`` (including
        ``np.ndarray``) is returned unchanged.
        """
        convert = BroadcastTilingOpt._make_serializable
        if isinstance(obj, dict):
            return {k: convert(v) for k, v in obj.items()}
        if isinstance(obj, list):
            return [convert(i) for i in obj]
        if isinstance(obj, tuple):
            return tuple(convert(i) for i in obj)
        if isinstance(obj, np.generic):
            return obj.item()
        return obj

    def find_optimal_tiling(self):
        """Compute all tilings and return the tiling parameters of the selected one.

        The mode index, schedule and memtile-tiling index are fixed constants;
        the cost-model results are computed but not consulted for the choice.
        The layer object held by the tiler is only read, never modified.

        Returns:
            dict with the keys ``core_tile_params``, ``mem_tile_params``,
            ``shim_tile_params``, ``scheduling``, ``original_dimensions``,
            ``host_layer_padding``, ``dma_layer_padding``, ``kernel_info``,
            ``overlay_info``, ``layer_info``, ``testbench_args`` and
            ``cycle_counts``; NumPy scalars are converted to Python scalars.
        """
        self.calculate_tiling_cycles()

        opt_midx = 0
        opt_sched = 2
        opt_memtiling = 0

        tensors = self._TENSORS
        tiler = self.tilers[opt_midx]
        layer = tiler.layer
        kernel = tiler.kernel
        tiler_overlay = tiler.overlay

        # ---- core tile ----------------------------------------------------
        core_subvols = tiler.valid_core_subvols[opt_sched][opt_memtiling]
        core_tiling = {name: core_subvols[name].astype(int).tolist() for name in tensors}
        tn = list(tiler.valid_core_iters[opt_sched][opt_memtiling])
        core_iters = {
            'ifmA': tn[0],
            'ifmB': tn[1],
            'ofm': tn[2]
        }

        # ---- memory tile --------------------------------------------------
        memtile_tiling = {}
        for name in tensors:
            subvol = tiler.valid_memtile_subvols[opt_sched][name][opt_memtiling].astype(int).tolist()
            # tolist() yields a bare scalar for 0-d input; always expose a list
            memtile_tiling[name] = subvol if isinstance(subvol, list) else [subvol]

        mt_iters = tiler.valid_memtile_iters[opt_sched][opt_memtiling]
        memtile_iters = {
            name: [mt_iters[idx].astype(int).tolist()] for idx, name in enumerate(tensors)
        }

        # ---- kernel buffer placement --------------------------------------
        vars_dict = tiler.vars_dict
        # Drop the builtins entry if present (eval() below re-inserts it);
        # pop() instead of del so a dict without the key is accepted too.
        vars_dict.pop('__builtins__', None)
        vars_dict['Nsubv'] = core_tiling['ifmA'][0]
        vars_dict['Msubv'] = core_tiling['ifmB'][0] // core_tiling['ifmA'][0]

        # NOTE(security): constraint expressions are evaluated with eval();
        # they must only ever come from trusted kernel descriptions.
        placement_dict = {
            buff: {k: int(eval(str(v), vars_dict)) for k, v in bankdict.items()}
            for buff, bankdict in kernel.placement_constraints.items()
        }

        # ifmA is sized with the weight byte width, ifmB with the activation one.
        elem_bytes = {'ifmA': layer.wgt_bytes, 'ifmB': layer.in_bytes, 'ofm': layer.out_bytes}
        memtile_sizes = {
            name: int(np.prod(memtile_tiling[name]) * elem_bytes[name]) for name in tensors
        }

        # ---- layer info (work on a copy: vars() returns the live __dict__,
        # and the renames/pops below must not strip attributes off the layer)
        layerdict = dict(vars(layer))

        # Last const input with a non-empty shape decides which operand is const.
        const_param_name = None
        for entry in ast.literal_eval(layerdict['inputs']):
            if entry.get('type') == 'const' and entry.get('shape'):
                const_param_name = entry.get('param_name')

        if tiler.ifm_mode == 'N32':
            ifm_modes = 'stream'
        elif tiler.ifm_mode == 'N8':
            ifm_modes = 'broadcast'
        else:
            ifm_modes = 'pin'

        ifmA_const = 1 if const_param_name == 'B' else 0  # A and B swapped
        ifmB_const = 1 if const_param_name == 'A' else 0  # A and B swapped

        assert not (ifmA_const == 1 and ifmB_const == 1), 'Both inputs cannot be const'

        # ---- scheduling ---------------------------------------------------
        schedule_dict = {}
        if opt_sched in tiler.schedule_list:
            schedule_dict = {
                'ifmA': ifm_modes,
                'ifmB': 'stream',
                'ofm': 'stream',
                'ifmA_ping_pong': tiler.ping_pong_ifmA,
                'ifmB_ping_pong': tiler.ping_pong_ifmB,
                'ofm_ping_pong': False,
                'ifmA_param_type': 'const' if ifmA_const else 'act',
                'ifmB_param_type': 'const' if ifmB_const else 'act',
                'multi_ch_batch_bcast': tiler.bcast_multi_channel,
                'inner_most_dim_is_1': tiler.innermost_broadcast,
                'ifmA_scale_factor': vars_dict['ifmA_scale_factor'],
                'ifmB_scale_factor': vars_dict['ifmB_scale_factor'],
                'ofm_scale_factor': vars_dict['ofm_scale_factor']
            }

        original_dimensions = [
            {"input0": {"dims": self.layer.orig_wgt_shape}},   # shape [n,h,w,c]
            {"input1": {"dims": self.layer.orig_act_shape}},   # shape [n,h,w,c]
            {"output0": {"dims": self.layer.orig_ofm_shape}},  # shape [n,h,w,c]
        ]

        # ---- padding ------------------------------------------------------
        if self.layer.attributes is not None:
            dis_dq0 = self.layer.attributes['disable_dq0'][0]
            dis_dq1 = self.layer.attributes['disable_dq1'][0]
            dis_q = self.layer.attributes['disable_q'][0]
        else:
            dis_dq0 = 0
            dis_dq1 = 0
            dis_q = 0

        pad_source = {
            'ifmA': (dis_dq0, 'zp_i0'),
            'ifmB': (dis_dq1, 'zp_i1'),
            'ofm': (dis_q, 'zp_o0'),
        }

        def pad_values(name):
            disabled, zero_point_name = pad_source[name]
            return self._pad_fill(len(tiler.host_padding[name]), disabled, zero_point_name)

        # NOTE(review): ifmA uses the key "values" while ifmB/ofm use "value".
        # Kept exactly as-is because the consumer's expectation is not visible
        # here — confirm which spelling is intended.
        host_layer_padding = [
            {"ifmA": {"dims": tiler.host_padding['ifmA'], "values": pad_values('ifmA')}},
            {"ifmB": {"dims": tiler.host_padding['ifmB'], "value": pad_values('ifmB')}},
            {"ofm": {"dims": tiler.host_padding['ofm'], "value": pad_values('ofm')}},
        ]

        dma_channels = {'ifmA': [0, 1, 2, 3], 'ifmB': [0, 1, 2, 3], 'ofm': [5]}
        dma_layer_padding = [
            {name: {
                "dims": [np.prod(tiler.padding[name]).astype(np.int32).tolist()],
                "values": pad_values(name),
                "channels": dma_channels[name],
            }}
            for name in tensors
        ]

        # ---- shim tile ----------------------------------------------------
        # ifmA is transferred whole for innermost broadcast or when pinned;
        # otherwise every tensor is divided by its memory split.
        if tiler.innermost_broadcast or ifm_modes == 'pin':
            shim_ifmA = [np.prod(tiler.padded_shapes['ifmA']).tolist()]
        else:
            shim_ifmA = self._split_volume(tiler.padded_shapes['ifmA'], tiler_overlay.mem_splits['ifmA'])
        shim_tilings = {
            'ifmA': shim_ifmA,
            'ifmB': self._split_volume(tiler.padded_shapes['ifmB'], tiler_overlay.mem_splits['ifmB']),
            'ofm': self._split_volume(tiler.padded_shapes['ofm'], tiler_overlay.mem_splits['ofm']),
        }
        shim_sizes = {
            name: int(np.prod(shim_tilings[name]) * elem_bytes[name]) for name in tensors
        }

        # ---- layer info / testbench flags ---------------------------------
        cflags_shapes_ifmA = np.array(tiler.cflags['ifmA']).astype(np.int32).tolist()
        cflags_shapes_ifmB = np.array(tiler.cflags['ifmB']).astype(np.int32).tolist()

        layerdict['in_ifmA_shape'] = [np.prod(layer.in_wgt_shape).astype(int).tolist()]
        layerdict['in_ifmB_shape'] = [np.prod(layer.in_act_shape).astype(int).tolist()]
        # NOTE(review): the output volume is taken from in_act_shape, not
        # out_act_shape — kept as-is, confirm this is intended.
        layerdict['out_ofm_shape'] = [np.prod(layer.in_act_shape).astype(int).tolist()]

        # Rename weight/activation keys to the ifmA/ifmB naming used in the output.
        for old_key, new_key in (
            ('wgt_datatype', 'in_ifmA_datatype'),
            ('in_datatype', 'in_ifmB_datatype'),
            ('wgt1_datatype', 'in_wgt1_datatype'),
            ('out_datatype', 'out_ofm_datatype'),
            ('wgt_bytes', 'in_ifmA_bytes'),
            ('in_bytes', 'in_ifmB_bytes'),
            ('wgt1_bytes', 'in_wgt1_bytes'),
            ('out_bytes', 'out_ofm_bytes'),
        ):
            layerdict[new_key] = layerdict.pop(old_key)

        test_cpp_name = kernel.testbench_args['HostName']
        flag_names = kernel.testbench_args['CFLAGS']
        flag_values = [
            ",".join(map(str, cflags_shapes_ifmA)),                      # DIMS_IFMA
            ",".join(map(str, cflags_shapes_ifmB)),                      # DIMS_IFMB
            1 if layerdict['orig_op_type'] in {'Add', 'Mul'} else 0,     # ADD or MUL
            int(layerdict['in_ifmA_bytes']),
            int(layerdict['in_ifmB_bytes']),
            int(layerdict['out_ofm_bytes']),
            ifmA_const,
            ifmB_const,
            1 if tiler.innermost_broadcast else 0,
            1 if (dis_dq0 == 0) else 0,
            1 if (dis_dq1 == 0) else 0,
            1 if (dis_q == 0) else 0,
        ]
        # Explicit indexing so a too-short CFLAGS list still raises IndexError.
        tb_cflags = {flag_names[idx]: value for idx, value in enumerate(flag_values)}

        for shape_key in ('orig_act_shape', 'orig_wgt_shape', 'orig_ofm_shape',
                          'in_act_shape', 'in_wgt_shape', 'out_act_shape'):
            layerdict.pop(shape_key)

        # TODO improve and correct the cycle estimation
        # assume IO dominated cycle count, takes max of ifmA read, ifmB read, and ofm write and assume 80% efficiency
        opt_cycles = max(
            np.prod(layerdict['in_ifmA_shape']) * layerdict['in_ifmA_bytes'],
            np.prod(layerdict['in_ifmB_shape']) * layerdict['in_ifmB_bytes'],
            np.prod(layerdict['out_ofm_shape']) * layerdict['out_ofm_bytes'],
        ) / 32 / 0.8

        tiling_params = {
            'core_tile_params': {'subvols': core_tiling, 'iters': core_iters},
            'mem_tile_params': {'subvols': memtile_tiling, 'iters': memtile_iters, 'sizes': memtile_sizes},
            'shim_tile_params': {'subvols': shim_tilings, 'sizes': shim_sizes},
            'scheduling': schedule_dict,
            'original_dimensions': original_dimensions,
            'host_layer_padding': host_layer_padding,
            'dma_layer_padding': dma_layer_padding,
            'kernel_info': {'placement_constraints': placement_dict},
            'overlay_info': {
                'overlay': self.overlay,
                'mode': self.modes[0],
                'shape': {'row': tiler_overlay.rows, 'col': tiler_overlay.cols},
            },
            'layer_info': layerdict,
            'testbench_args': {'HOST_NAME': test_cpp_name, 'COMPILE_FLAGS': tb_cflags},
            'cycle_counts': {'layer_cycles': opt_cycles},
        }

        return self._make_serializable(tiling_params)
    
if __name__ == '__main__':
    # Manual smoke run: load the 'tst' layer from tst_conv_layer.json, force
    # both activation residencies to L3 and compute the tiling on the
    # 'strix' device with the '8x4' overlay.
    import json
    from layer import Layer
    from kernel import Kernel
    from device import Device

    with open('tst_conv_layer.json') as cfg_file:
        layer_cfg = json.load(cfg_file)['tst']
    layer_cfg.update(in_act_residency='L3', out_act_residency='L3')

    tst_layer = Layer(layer_cfg)
    tst_device = Device('strix')
    tst_kernel = Kernel(tst_layer)

    tiling_opt = BroadcastTilingOpt(tst_layer, tst_device, '8x4', tst_kernel)
    tiling_result = tiling_opt.find_optimal_tiling()
