import yaml
import os
import ast

import ast
import numpy as np

from Elem_wise_ops_tiler import ElemWiseOpsTiler
from Elem_wise_ops_cost_model import ElemWiseOpsCostModel
from overlay import Overlay

# Directory two levels above the one that contains this module; the overlay
# YAML is read from 'Collaterals/overlays.yaml' relative to it.
parent_dir = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir, os.pardir)
)

class ElemWiseOpsTilingOpt:
    """Tiling search for element-wise ops across the modes of an overlay.

    One ``ElemWiseOpsTiler`` / ``ElemWiseOpsCostModel`` pair is built per
    overlay mode. ``find_optimal_tiling`` evaluates all of them and returns a
    dict describing the chosen core/memtile/shim tiling, scheduling, padding,
    kernel placement, testbench flags and an IO-based cycle estimate.
    """

    def __init__(self, layer, device, overlay_name, kernel):
        """Load the overlay modes for ``layer`` and build one tiler/cost model per mode.

        Args:
            layer: layer description; must expose ``op_type``, ``orig_op_type``,
                ``in_wgt_shape`` and ``attributes`` (may be None).
            device: device description exposing ``core_data_memory`` and
                ``core_num_banks``.
            overlay_name: top-level key into ``Collaterals/overlays.yaml``.
            kernel: kernel description handed to every ``ElemWiseOpsTiler``.

        Raises:
            KeyError: if the overlay / op type / variant is missing from the YAML.
        """
        overlays_path = os.path.join(parent_dir, 'Collaterals/overlays.yaml')
        with open(overlays_path, encoding='utf-8') as f:
            all_overlays = yaml.safe_load(f)

        # The third '_'-separated token of op_type selects the entry inside the overlay.
        op_variant = layer.op_type.split('_')[2]
        self.modes = list(all_overlays[overlay_name][layer.orig_op_type][op_variant].keys())

        core_bank_capacity = 1024 * (device.core_data_memory // device.core_num_banks)  # bytes; presumably core_data_memory is in KiB
        if "BroadCast" in layer.op_type:
            # Broadcast layers use a single mode: N32_pin when ifmA (in_wgt_shape)
            # fits within one core bank's capacity, otherwise N32_stream.
            # NOTE(review): this compares an element count with a byte capacity;
            # confirm whether the ifmA element size should be factored in.
            self.modes = ['N32_stream'] if np.prod(layer.in_wgt_shape) > core_bank_capacity else ['N32_pin']
        self.overlay = overlay_name

        self.layer = layer

        # Optional 'num_batches' attribute; None means the layer is not batched.
        self.numBatches = None
        if layer.attributes is not None:
            self.numBatches = layer.attributes.get('num_batches', None)

        self.tilers = []
        self.cost_models = []
        for mode in self.modes:
            overlay = Overlay(overlay_name, layer.orig_op_type, mode, op_variant)
            tiler = ElemWiseOpsTiler(layer, device, overlay, kernel)
            self.tilers.append(tiler)
            self.cost_models.append(ElemWiseOpsCostModel(tiler))

        # Candidate index table; filled in by find_optimal_tiling().
        self.indexes = []

    def calculate_tiling_cycles(self):
        """Run the tiling passes of every tiler and the cycle passes of its cost model.

        find_optimal_tiling() calls this before reading ``total_layer_cycles``.
        """
        for tiler, cost_model in zip(self.tilers, self.cost_models):
            tiler.calculate_memtile_tilings()
            tiler.check_valid_memtile_tilings()
            tiler.calculate_array_tilings()
            tiler.check_core_constraints()

            cost_model.calculate_kernel_cycles()
            cost_model.calculate_array_cycles()
            cost_model.calculate_layer_cycles()

    def find_optimal_tiling(self):
        """Evaluate every candidate tiling and return the selected configuration.

        Runs ``calculate_tiling_cycles``, keeps candidates whose total layer
        cycles are below 1.02x the minimum (NaNs ignored), picks one via
        ``_pick_candidate`` and assembles the tiling description from the
        chosen tiler. The candidate index table is stored in ``self.indexes``.

        The layer object is left untouched. Its attributes are copied into the
        returned ``'layer_info'`` dict before fields are renamed or dropped, so
        the method can be called more than once.

        Returns:
            dict with keys 'core_tile_params', 'mem_tile_params',
            'shim_tile_params', 'scheduling', 'original_dimensions',
            'host_layer_padding', 'dma_layer_padding', 'kernel_info',
            'overlay_info', 'layer_info', 'testbench_args' and 'cycle_counts'.

        Raises:
            ValueError: if no candidate falls under the cycle threshold (for
                example when every cycle count is NaN).
        """
        self.calculate_tiling_cycles()

        tiler0 = self.tilers[0]
        layer0 = tiler0.layer
        is_bcast = "BroadCast" in vars(layer0)['op_type']
        bcast_pin = is_bcast and self.modes[0] == 'N32_pin'
        bcast_stream = is_bcast and self.modes[0] == 'N32_stream'
        # When num_batches is set, the ifmA tiling/iteration/size fields are reported as 0 / [0].
        no_batches = self.numBatches is None

        all_cycles, indexes = self._collect_candidates()
        self.indexes = indexes
        nidx = self._pick_candidate(all_cycles, indexes)

        opt_midx = indexes[nidx][0]
        # Broadcast layers in N32_stream mode always use schedule 5.
        opt_sched = 5 if bcast_stream else indexes[nidx][1]
        opt_memtiling = indexes[nidx][2]
        # In N32_pin broadcast mode the core ifmA subvolume is divided by 4.
        core_split = 4 if bcast_pin else 1
        opt_tiler = self.tilers[opt_midx]
        opt_layer = opt_tiler.layer

        # ---- core tile --------------------------------------------------------
        core_subvols = opt_tiler.valid_core_subvols[opt_sched][opt_memtiling]
        core_tiling = {
            'ifmA': (core_subvols['ifmA'] // core_split).astype(int).tolist() if no_batches else [0],
            'ifmB': core_subvols['ifmB'].astype(int).tolist(),
            'ofm': core_subvols['ofm'].astype(int).tolist(),
        }
        # Flatten the iteration arrays stored for the chosen memtile tiling;
        # the first three entries are reported as ifmA / ifmB / ofm core iterations.
        tn = [item for arr in opt_tiler.valid_core_iters[opt_sched][opt_memtiling] for item in arr.tolist()]
        core_iters = {
            'ifmA': [tn[0]] if no_batches else [0],
            'ifmB': [tn[1]],
            'ofm': [tn[2]],
        }

        # ---- memtile ----------------------------------------------------------
        memtile_subvols = opt_tiler.valid_memtile_subvols[opt_sched]
        memtile_tiling = {
            'ifmA': memtile_subvols['ifmA'][opt_memtiling].astype(int).tolist() if no_batches else [0],
            'ifmB': memtile_subvols['ifmB'][opt_memtiling].astype(int).tolist(),
            'ofm': memtile_subvols['ofm'][opt_memtiling].astype(int).tolist(),
        }
        mt_iters = opt_tiler.valid_memtile_iters[opt_sched][opt_memtiling]
        memtile_iters = {
            'ifmA': [int(mt_iters[0])] if no_batches else [0],
            'ifmB': [int(mt_iters[1])],
            'ofm': [int(mt_iters[2])],
        }
        if bcast_pin:
            # Memtile ifmA is rebuilt from the split core subvolume.
            memtile_tiling['ifmA'] = [core_tiling['ifmA'][0] * core_split] if no_batches else [0]
            memtile_iters['ifmA'] = core_iters['ifmA']

        # ---- kernel placement -------------------------------------------------
        vars_dict = tiler0.vars_dict
        # Strip '__builtins__' from the tiler's eval namespace. pop() instead of
        # del: eval() below re-inserts the key, and it may be absent on entry.
        vars_dict.pop('__builtins__', None)
        vars_dict['Nsubv'] = core_tiling['ifmB'][0]
        vars_dict['ifmA_len'] = core_tiling['ifmA'][0]
        # Placement constraints are expressions evaluated against vars_dict.
        # NOTE: eval() -- kernel metadata must come from a trusted source.
        placement_dict = {
            buff: {k: int(eval(str(v), vars_dict)) for k, v in bankdict.items()}
            for buff, bankdict in tiler0.kernel.placement_constraints.items()
        }

        memtile_sizes = {
            'ifmA': int(np.prod(memtile_tiling['ifmA']) * opt_layer.wgt_bytes) if no_batches else 0,
            'ifmB': int(np.prod(memtile_tiling['ifmB']) * opt_layer.in_bytes),
            'ofm': int(np.prod(memtile_tiling['ofm']) * opt_layer.out_bytes),
        }

        # ---- scheduling -------------------------------------------------------
        # A const input is any input entry typed 'const' with a non-empty shape.
        contains_const_tensor = any(
            entry.get('type') == 'const' and entry.get('shape')
            for entry in ast.literal_eval(vars(layer0)['inputs'])
        )
        ifmA_const = 1 if contains_const_tensor else 0  # the const operand is always ifmA in elwadd

        schedule_dict = {}
        if opt_sched in tiler0.schedule_list:
            schedule_dict = {
                'ifmA': 'pin' if bcast_pin else 'stream',
                'ifmB': 'stream',
                'ofm': 'stream',
                'ifmA_ping_pong': not is_bcast,
                'ifmB_ping_pong': True,
                'ofm_ping_pong': True,
                'ifmA_param_type': 'const' if ifmA_const else 'act',
                'ifmB_param_type': 'act',
                'ifmA_scale_factor': vars_dict['ifmA_scale_factor'],
                'ifmB_scale_factor': vars_dict['ifmB_scale_factor'],
                'ofm_scale_factor': vars_dict['ofm_scale_factor'],
            }

        dis_dq0, dis_dq1, dis_q = self._quant_disable_flags()

        original_dimensions = [
            {"input0": {"dims": self.layer.orig_wgt_shape}},  # shape [n,h,w,c]
            {"input1": {"dims": self.layer.orig_act_shape}},  # shape [n,h,w,c]
            {"output0": {"dims": self.layer.orig_ofm_shape}},  # shape [n,h,w,c]
        ]

        # ---- padding ----------------------------------------------------------
        host_padding = tiler0.host_padding

        def pad_values(buf, disabled, zp_name):
            # One value per host padding dim: 0 when the disable flag is set,
            # otherwise the 'zp_*' placeholder string.
            count = len(host_padding[buf])
            return [0] * count if disabled else [zp_name] * count

        def dma_dims(buf):
            # Product of the chosen tiler's padding for this buffer, as a 1-element list.
            return [np.prod(opt_tiler.padding[buf]).astype(np.int32).tolist()]

        # NOTE(review): ifmA uses the key "values" while ifmB/ofm use "value"
        # (dma_layer_padding uses "values" everywhere). Kept unchanged because
        # downstream consumers may read these exact keys; confirm and unify.
        host_layer_padding = [
            {"ifmA": {"dims": host_padding['ifmA'], "values": pad_values('ifmA', dis_dq0, "zp_i0")}},
            {"ifmB": {"dims": host_padding['ifmB'], "value": pad_values('ifmB', dis_dq1, "zp_i1")}},
            {"ofm": {"dims": host_padding['ofm'], "value": pad_values('ofm', dis_q, "zp_o0")}},
        ]

        dma_layer_padding = [
            {"ifmA": {"dims": dma_dims('ifmA'), "values": pad_values('ifmA', dis_dq0, "zp_i0"), "channels": [0, 1, 2, 3]}},
            {"ifmB": {"dims": dma_dims('ifmB'), "values": pad_values('ifmB', dis_dq1, "zp_i1"), "channels": [0, 1, 2, 3]}},
            {"ofm": {"dims": dma_dims('ofm'), "values": pad_values('ofm', dis_q, "zp_o0"), "channels": [5]}},
        ]

        # ---- shim -------------------------------------------------------------
        # Wrapping the padded shape in a list makes the floor-division result a
        # nested (2-D) list; presumably mem_splits entries are NumPy arrays.
        mem_splits = tiler0.overlay.mem_splits
        padded = opt_tiler.padded_shapes
        shim_tilings = {
            'ifmA': ([padded['ifmA'].astype(int).tolist()] // mem_splits['ifmA']).tolist() if no_batches else [0],
            'ifmB': ([padded['ifmB'].astype(int).tolist()] // mem_splits['ifmB']).tolist(),
            'ofm': ([padded['ofm'].astype(int).tolist()] // mem_splits['ofm']).tolist(),
        }
        shim_sizes = {
            'ifmA': int(np.prod(shim_tilings['ifmA']) * opt_layer.wgt_bytes) if no_batches else 0,
            'ifmB': int(np.prod(shim_tilings['ifmB']) * opt_layer.in_bytes),
            'ofm': int(np.prod(shim_tilings['ofm']) * opt_layer.out_bytes),
        }

        # ---- layer info -------------------------------------------------------
        # Shallow copy: renaming/popping below must not strip fields from the
        # layer object itself (it is shared with every tiler and self.layer).
        layerdict = dict(vars(layer0))
        cflags_shapes_ifmB = layerdict['in_act_shape']

        layerdict['in_ifmA_shape'] = [int(np.prod(layerdict['in_wgt_shape']).tolist())] if no_batches else [0]
        layerdict['in_ifmB_shape'] = [int(np.prod(layerdict['in_act_shape']).tolist())]
        layerdict['out_ofm_shape'] = [int(np.prod(layerdict['out_act_shape']).tolist())]

        for old_key, new_key in (
            ('wgt_datatype', 'in_ifmA_datatype'),
            ('in_datatype', 'in_ifmB_datatype'),
            ('wgt1_datatype', 'in_wgt1_datatype'),
            ('out_datatype', 'out_ofm_datatype'),
            ('wgt_bytes', 'in_ifmA_bytes'),
            ('in_bytes', 'in_ifmB_bytes'),
            ('wgt1_bytes', 'in_wgt1_bytes'),
            ('out_bytes', 'out_ofm_bytes'),
        ):
            layerdict[new_key] = layerdict.pop(old_key)

        test_cpp_name = tiler0.kernel.testbench_args['HostName']
        cflag_names = tiler0.kernel.testbench_args['CFLAGS']
        tb_cflags = {
            cflag_names[0]: ",".join(map(str, cflags_shapes_ifmB)),  # DIMS
            cflag_names[1]: 1 if layerdict['orig_op_type'] in {'Add', 'Mul'} else 0,  # IS_ADD; TODO: confirm whether Mul needs its own flag
            cflag_names[2]: int(layerdict['in_ifmA_bytes']),
            cflag_names[3]: int(layerdict['in_ifmB_bytes']),
            cflag_names[4]: int(layerdict['out_ofm_bytes']),
            cflag_names[5]: 1 if contains_const_tensor and no_batches else 0,  # ACT_CONST_ADD
            cflag_names[6]: self.numBatches[0] if not no_batches else 1,  # CASCADE
            cflag_names[7]: 0 if dis_dq0 else 1,
            cflag_names[8]: 0 if dis_dq1 else 1,
            cflag_names[9]: 0 if dis_q else 1,
        }

        for key in ('orig_act_shape', 'orig_wgt_shape', 'orig_ofm_shape',
                    'in_act_shape', 'in_wgt_shape', 'out_act_shape'):
            layerdict.pop(key)

        # TODO improve and correct the cycle estimation.
        # Assumes an IO-dominated cycle count: max of ifmA read, ifmB read and
        # ofm write bytes, divided by 32 (presumably bytes/cycle), at 80% efficiency.
        opt_cycles = max(np.prod(layerdict['in_ifmA_shape']) * layerdict['in_ifmA_bytes'],
                         np.prod(layerdict['in_ifmB_shape']) * layerdict['in_ifmB_bytes'],
                         np.prod(layerdict['out_ofm_shape']) * layerdict['out_ofm_bytes']) / 32 / 0.8

        return {
            'core_tile_params': {'subvols': core_tiling, 'iters': core_iters},
            'mem_tile_params': {'subvols': memtile_tiling, 'iters': memtile_iters, 'sizes': memtile_sizes},
            'shim_tile_params': {'subvols': shim_tilings, 'sizes': shim_sizes},
            'scheduling': schedule_dict,
            'original_dimensions': original_dimensions,
            'host_layer_padding': host_layer_padding,
            'dma_layer_padding': dma_layer_padding,
            'kernel_info': {'placement_constraints': placement_dict},
            'overlay_info': {'overlay': self.overlay, 'mode': self.modes[0],
                             'shape': {'row': tiler0.overlay.rows, 'col': tiler0.overlay.cols}},
            'layer_info': layerdict,
            'testbench_args': {'HOST_NAME': test_cpp_name, 'COMPILE_FLAGS': tb_cflags},
            'cycle_counts': {'layer_cycles': opt_cycles},
        }

    def _collect_candidates(self):
        """Stack every candidate's cycle counts with its (mode, schedule, memtile, core) index.

        Returns:
            ``(all_cycles, indexes)``. ``all_cycles`` is the row-wise vstack of
            every ``total_layer_cycles[sched][mem_sub_id]`` array. Row ``r`` of
            ``indexes`` is ``[mode_idx, sched, mem_sub_id, core_tiling_idx]``
            for row ``r`` of ``all_cycles``.
        """
        cycle_blocks = []
        index_blocks = []
        for midx, cost_model in enumerate(self.cost_models):
            for sched, sched_cycles in cost_model.total_layer_cycles.items():
                for mem_sub_id, cycles in sched_cycles.items():
                    n_rows = len(cycles)
                    cycle_blocks.append(cycles)
                    index_blocks.append(np.hstack((
                        np.tile([midx, sched, mem_sub_id], (n_rows, 1)),
                        np.arange(n_rows).reshape((n_rows, 1)),
                    )))
        return np.vstack(cycle_blocks), np.vstack(index_blocks)

    def _pick_candidate(self, all_cycles, indexes):
        """Return the row of ``indexes`` selected among the near-optimal candidates.

        Every cell of ``all_cycles`` below 1.02x its NaN-ignoring minimum is a
        candidate. Each candidate's "total iterations" list is built from its
        core and memtile iterations. The shortest list wins, ties are broken on
        the smallest element, and the first such candidate is kept.

        Raises:
            ValueError: if no cell is under the threshold.
        """
        idx, _ = np.where(all_cycles < np.nanmin(all_cycles) * 1.02)
        if idx.size == 0:
            raise ValueError("ElemWiseOpsTilingOpt: no candidate tiling with a finite cycle count")

        total_iters = []
        for midx, sched, memtiling, coretiling in indexes[idx]:
            tiler = self.tilers[midx]
            tn = tiler.valid_core_iters[sched][memtiling][coretiling].astype(int).tolist()
            mt = tiler.valid_memtile_iters[sched][memtiling]
            # NOTE(review): tn is a Python list and np.prod(...) a NumPy integer.
            # With current NumPy this is list repetition (length grows with the
            # memtile iteration product), not an elementwise product, so the
            # min-length filter below favors the fewest memtile iterations.
            # Kept as-is; verify intent. TODO: maybe filter on ifmB, which is
            # expected to be the same size as or bigger than ifmA.
            total_iters.append(tn * np.prod(mt.astype(int)))

        lengths = np.array([len(entry) for entry in total_iters])
        shortest = np.flatnonzero(lengths == lengths.min())
        best = min(shortest, key=lambda i: min(total_iters[i]))
        return idx[best]

    def _quant_disable_flags(self):
        """Return ``(disable_dq0, disable_dq1, disable_q)`` from the layer attributes.

        Each value is the first element of the matching attribute. All three
        are 0 when the layer has no attributes.
        """
        attrs = self.layer.attributes
        if attrs is None:
            return 0, 0, 0
        return attrs['disable_dq0'][0], attrs['disable_dq1'][0], attrs['disable_q'][0]

if __name__ == '__main__':
    # Manual smoke run: build a Layer from the JSON fixture's 'tst' entry and run
    # the tiling search on the '8x4' overlay for the 'strix' device.
    import json
    from layer import Layer

    with open('OGOAT/src/Tiler/tst_elemwise.json') as fixture_file:
        layer_desc = json.load(fixture_file)['tst']
    # Override the fixture's activation residency.
    layer_desc['in_act_residency'] = 'L3'
    layer_desc['out_act_residency'] = 'L3'

    layer_obj = Layer(layer_desc)

    from kernel import Kernel
    from device import Device

    device = Device('strix')
    kernel = Kernel(layer_obj)

    optimizer = ElemWiseOpsTilingOpt(layer_obj, device, '8x4', kernel)
    result = optimizer.find_optimal_tiling()
