import yaml
import os

import numpy as np

from LUT_tiler import LUTTiler
from LUT_cost_model import LUTCostModel
from overlay import Overlay
from enum import Enum

# Directory two levels above the one containing this file (three parent steps
# from the absolute file path); overlays.yaml is read relative to it.
parent_dir = os.path.abspath(
    os.path.join(os.path.abspath(__file__), os.pardir, os.pardir, os.pardir)
)

class LUTTilingOpt:
    """Search LUT tilings across overlay modes and subarrays and pick one.

    For every mode listed in ``Collaterals/overlays.yaml`` under
    ``[overlay_name][layer.orig_op_type]``, one ``LUTTiler`` /
    ``LUTCostModel`` pair is built per ``(subcol, subrow)`` entry of that
    mode's ``Overlay.subarray``. ``find_optimal_tiling`` evaluates all pairs
    and returns the selected tiling as a plain dict.
    """

    class QDQMode(Enum):
        """Testbench QDQ mode selected from the layer's disable flags."""

        QDQ0 = 0  # disable_dq0 == 0 and disable_q == 0
        QDQ1 = 1  # disable_dq0 == 0 and disable_q == 1
        QDQ2 = 2  # disable_dq0 == 1 and disable_q == 0
        QDQ3 = 3  # disable_dq0 == 1 and disable_q == 1

    # ((disable_dq0, disable_q), mode) pairs, tested in this order.
    _QDQ_MODES = (
        ((0, 0), QDQMode.QDQ0),
        ((0, 1), QDQMode.QDQ1),
        ((1, 0), QDQMode.QDQ2),
        ((1, 1), QDQMode.QDQ3),
    )

    # Candidates whose cycle estimate is below this multiple of the minimum
    # are re-ranked by total iteration count instead of by cycles.
    _NEAR_OPTIMAL_FACTOR = 1.05

    # Empirical throughput (ofm elements per cycle) behind the reported layer
    # cycle count; the cost-model estimate is considered inaccurate for it.
    _EMPIRICAL_OFM_ELEMS_PER_CYCLE = 9.4

    def __init__(self, layer, device, overlay_name, kernel):
        """Build a tiler / cost-model pair for every (mode, subarray) combination.

        Args:
            layer: layer description; ``layer.orig_op_type`` selects the
                overlay entry. The same object is handed to every ``LUTTiler``.
            device: device description forwarded to ``LUTTiler``.
            overlay_name: key into ``Collaterals/overlays.yaml`` (e.g. ``'8x4'``).
            kernel: kernel description forwarded to ``LUTTiler``.
        """
        with open(os.path.join(parent_dir, 'Collaterals/overlays.yaml'), encoding='utf-8') as f:
            all_overlays = yaml.safe_load(f)
        op_type = layer.orig_op_type
        self.modes = list(all_overlays[overlay_name][op_type].keys())
        self.overlay = overlay_name
        self.overlay_objects = []

        # tilers[m][s] / cost_models[m][s] belong to mode self.modes[m] and to
        # the s-th (subcol, subrow) entry of that mode's overlay.subarray.
        self.tilers = []
        self.cost_models = []
        for mode in self.modes:
            tilers = []
            cost_models = []
            overlay = Overlay(overlay_name, op_type, mode)
            self.overlay_objects.append(overlay)
            for subcol, subrow in overlay.subarray:
                sub_overlay = Overlay(overlay_name, op_type, mode)
                # Rescale the full-overlay splits to the subarray: core splits
                # by subcol*subrow out of cols*rows, mem splits by subcol out of cols.
                for operands in overlay.core_splits.keys():
                    sub_overlay.core_splits[operands] = (overlay.core_splits[operands] // overlay.cols // overlay.rows) * subcol * subrow
                    sub_overlay.mem_splits[operands] = (overlay.mem_splits[operands] // overlay.cols) * subcol
                sub_overlay.cols = subcol
                sub_overlay.rows = subrow

                tiler = LUTTiler(layer, device, sub_overlay, kernel)
                tilers.append(tiler)
                cost_models.append(LUTCostModel(tiler))

            self.tilers.append(tilers)
            self.cost_models.append(cost_models)

        # Candidate index table, filled by find_optimal_tiling().
        self.indexes = []

    def calculate_tiling_cycles(self):
        """Run every tiler's tiling passes, then its paired cost model, in place."""
        for tilers, cost_models in zip(self.tilers, self.cost_models):
            for tiler, cost_model in zip(tilers, cost_models):
                tiler.calculate_memtile_tilings()
                tiler.check_valid_memtile_tilings()
                tiler.calculate_array_tilings()
                tiler.check_core_constraints()

                cost_model.calculate_kernel_cycles()
                cost_model.calculate_array_cycles()
                cost_model.calculate_layer_cycles()

    def _collect_candidates(self):
        """Flatten all cost models' layer-cycle tables into one array.

        Returns:
            ``(all_cycles, indexes)``: ``all_cycles`` row-stacks every
            ``total_layer_cycles[sched][mem_sub_id]`` array; row ``r`` of
            ``indexes`` is ``[mode_idx, subarray_idx, sched, mem_sub_id,
            row_within_that_array]`` for ``all_cycles[r]``.
        """
        all_cycles = []
        indexes = []
        for midx, cost_models in enumerate(self.cost_models):
            for subidx, cost_model in enumerate(cost_models):
                for sched, sched_cycles in cost_model.total_layer_cycles.items():
                    for mem_sub_id, cycles in sched_cycles.items():
                        n_rows = len(cycles)
                        all_cycles.append(cycles)
                        indexes.append(np.hstack((
                            np.tile([midx, subidx, sched, mem_sub_id], (n_rows, 1)),
                            np.arange(n_rows).reshape((n_rows, 1)),
                        )))
        return np.vstack(all_cycles), np.vstack(indexes)

    def _select_optimal_index(self, all_cycles, indexes):
        """Return the ``indexes`` row of the selected candidate.

        Every entry below ``_NEAR_OPTIMAL_FACTOR`` times the minimum cycle
        count is a candidate. Among them, the one with the smallest total
        iteration count (core iters times the product of memtile iters) wins;
        ties go to the last candidate.

        Raises:
            ValueError: if no candidate qualifies or none has a usable
                iteration count.
        """
        near_opt_rows = np.where(all_cycles < np.nanmin(all_cycles) * self._NEAR_OPTIMAL_FACTOR)[0]
        if near_opt_rows.size == 0:
            raise ValueError("No valid LUT tiling candidates found")

        total_iters = []
        for midx, subidx, sched, memtiling, coretiling in indexes[near_opt_rows]:
            try:
                tiler = self.tilers[midx][subidx]
                core_iters = tiler.valid_core_iters[sched][memtiling][coretiling].astype(int).tolist()
                memtile_iters = tiler.valid_memtile_iters[sched][memtiling]
                total_iters.append(core_iters * np.prod(memtile_iters))
            except Exception:
                # Best effort: a candidate whose iteration data cannot be
                # looked up gets NaN, which nanmin below ignores.
                total_iters.append(np.nan)

        total_iters = np.asarray(total_iters)
        best = np.where(total_iters == np.nanmin(total_iters))[0]
        if best.size == 0:
            raise ValueError("No LUT tiling candidate has a valid iteration count")
        return indexes[near_opt_rows[best[-1]]]

    def _placement_constraints(self, core_tiling):
        """Evaluate the kernel's placement-constraint expressions.

        ``Nsubv`` is set to ``core_tiling['ifm'][0]`` in the first tiler's
        ``vars_dict`` (a persistent change), and every constraint value is
        evaluated as an expression over those variables.
        """
        base_tiler = self.tilers[0][0]
        vars_dict = base_tiler.vars_dict
        # Drop '__builtins__' from vars_dict; pop() tolerates it being absent
        # (e.g. on a repeated call).
        vars_dict.pop('__builtins__', None)
        vars_dict['Nsubv'] = core_tiling['ifm'][0]
        # eval() inserts '__builtins__' into its globals dict, so evaluate
        # against a copy to keep vars_dict free of it.
        scope = dict(vars_dict)
        # NOTE(review): eval() runs expressions taken from
        # kernel.placement_constraints; only safe for trusted collateral.
        return {
            buff: {k: int(eval(str(v), scope)) for k, v in bankdict.items()}
            for buff, bankdict in base_tiler.kernel.placement_constraints.items()
        }

    def _layer_info(self):
        """Return a copy of the layer's attributes with output-facing key names.

        ``vars()`` returns the layer's live ``__dict__``; it is copied first so
        the renames below do not strip attributes (e.g. ``in_bytes``) from the
        layer object that ``__init__`` handed to every tiler.
        """
        layerdict = dict(vars(self.tilers[0][0].layer))
        layerdict['in_ifm_shape'] = layerdict.pop('in_act_shape').astype(int).tolist()
        layerdict['out_ofm_shape'] = layerdict.pop('out_act_shape').astype(int).tolist()

        layerdict['in_ifm_datatype'] = layerdict.pop('in_datatype')
        layerdict['in_wgt_datatype'] = layerdict.pop('wgt_datatype')
        layerdict['in_wgt1_datatype'] = layerdict.pop('wgt1_datatype')
        layerdict['out_ofm_datatype'] = layerdict.pop('out_datatype')

        layerdict['in_ifm_bytes'] = layerdict.pop('in_bytes')
        layerdict['in_wgt_bytes'] = layerdict.pop('wgt_bytes')
        layerdict['in_wgt1_bytes'] = layerdict.pop('wgt1_bytes')
        layerdict['out_ofm_bytes'] = layerdict.pop('out_bytes')
        return layerdict

    def _qdq_mode(self, flag_dq, flag_q):
        """Map the ``(disable_dq0, disable_q)`` flags to a ``QDQMode``.

        Raises:
            AssertionError: if the flags are not a 0/1 combination. The error
                is raised explicitly so it is not stripped under ``python -O``.
        """
        for (dq, q), mode in self._QDQ_MODES:
            if (flag_dq == dq) and (flag_q == q):
                return mode
        raise AssertionError("Check IR JSON for incorrect attributes flags")

    def _testbench_args(self, layerdict, core_tiling):
        """Build the ``testbench_args`` entry (host name and compile flags)."""
        testbench_args = self.tilers[0][0].kernel.testbench_args
        cflags = testbench_args['CFLAGS']
        flag_dq = layerdict["attributes"]["disable_dq0"][0]
        flag_q = layerdict["attributes"]["disable_q"][0]

        tb_cflags = {
            cflags[0]: layerdict['in_ifm_shape'][0],  # MGEMM
            cflags[1]: 1,  # NGEMM is always set to 1
            cflags[2]: core_tiling['ifm'][0],  # MSUBV
            cflags[3]: 1,  # NSUBV is always set to 1
            cflags[4]: 1 if layerdict['orig_op_type'] == 'PWLA' else 0,
            cflags[6]: layerdict['in_ifm_bytes'],
            cflags[7]: layerdict['out_ofm_bytes'],
        }
        # The QDQ mode flag (cflags[5]) is inserted last, keeping the flag order.
        tb_cflags[cflags[5]] = self._qdq_mode(flag_dq, flag_q).value
        return {'HOST_NAME': testbench_args['HostName'], 'COMPILE_FLAGS': tb_cflags}

    def find_optimal_tiling(self):
        """Evaluate every tiler and return the selected tiling as a dict.

        Side effects: recomputes all tilings via ``calculate_tiling_cycles``,
        stores the candidate index table in ``self.indexes`` and sets
        ``Nsubv`` in the first tiler's ``vars_dict``. The layer object itself
        is not modified, so the method can be called repeatedly.

        Returns:
            dict with keys ``core_tile_params``, ``mem_tile_params``,
            ``shim_tile_params``, ``scheduling``, ``host_layer_padding``,
            ``dma_layer_padding``, ``kernel_info``, ``overlay_info``,
            ``layer_info``, ``testbench_args`` and ``cycle_counts``.

        Raises:
            ValueError: if no candidate tiling can be selected.
        """
        self.calculate_tiling_cycles()

        all_cycles, indexes = self._collect_candidates()
        self.indexes = indexes
        opt_midx, opt_subidx, opt_sched, opt_memtiling, _ = self._select_optimal_index(all_cycles, indexes)
        opt_tiler = self.tilers[opt_midx][opt_subidx]

        core_subvols = opt_tiler.valid_core_subvols[opt_sched][opt_memtiling]
        core_tiling = {
            'ifm': core_subvols['ifm'].astype(int).tolist(),
            'ofm': core_subvols['ofm'].astype(int).tolist(),
        }
        # NOTE(review): not indexed by the selected core tiling, unlike the
        # ranking in _select_optimal_index -- confirm this is intended.
        tn = opt_tiler.valid_core_iters[opt_sched][opt_memtiling].astype(int).tolist()
        core_iters = {'ifm': tn, 'ofm': tn}

        memtile_subvols = opt_tiler.valid_memtile_subvols[opt_sched]
        memtile_tiling = {
            'ifm': memtile_subvols['ifm'][opt_memtiling].astype(int).tolist(),
            'ofm': memtile_subvols['ofm'][opt_memtiling].astype(int).tolist(),
        }
        mem_iters = opt_tiler.valid_memtile_iters[opt_sched][opt_memtiling]
        memtile_iters = {
            'ifm': [int(mem_iters[0])],
            'ofm': [int(mem_iters[1])],
        }

        placement_dict = self._placement_constraints(core_tiling)

        memtile_sizes = {
            'ifm': int(np.prod(memtile_tiling['ifm']) * opt_tiler.layer.in_bytes),
            'ofm': int(np.prod(memtile_tiling['ofm']) * opt_tiler.layer.out_bytes),
        }
        schedule_dict = {}
        if opt_sched == 5:
            # Schedule id 5: stream ifm and ofm with ping-pong buffers.
            schedule_dict = {
                'ifm': 'stream',
                'ofm': 'stream',
                'ifm_ping_pong': True,
                'ofm_ping_pong': True,
            }

        host_layer_padding = [
            {"ifm": {
                "dims": opt_tiler.padded_shapes['ifm'].astype(int).tolist(),  # shape [M,K]
                "value": [0],
            }},
            {"ofm": {
                "dims": opt_tiler.padded_shapes['ofm'].astype(int).tolist(),  # shape [M,K]
                "value": [0],
            }},
        ]

        dma_layer_padding = [
            {"ifm": {
                "dims": [0],  # shape [M,K]
                "value": [None],
            }},
            {"ofm": {
                "dims": [0],  # shape [M,K]
                "value": [None],
            }},
        ]

        # Shim tile = memtile subvolume scaled by the memtile iteration count.
        shim_tilings = {k: (np.array(memtile_tiling[k]) * np.array(memtile_iters[k])).astype(int).tolist() for k in memtile_iters.keys()}
        shim_sizes = {k: int(memtile_sizes[k] * np.prod(memtile_iters[k])) for k in memtile_iters.keys()}

        layerdict = self._layer_info()
        testbench_args = self._testbench_args(layerdict, core_tiling)

        overlay_dict = {
            'overlay': self.overlay,
            'mode': self.modes[opt_midx],
            'shape': {'row': self.overlay_objects[opt_midx].rows, 'col': self.overlay_objects[opt_midx].cols},
            'enabled': {'row': opt_tiler.overlay.rows, 'col': opt_tiler.overlay.cols},
        }

        # Empirical estimate used instead of the cost-model cycles, which are
        # considered inaccurate for this layer type.
        opt_cycles = np.prod(layerdict['out_ofm_shape'], dtype=np.int64) / self._EMPIRICAL_OFM_ELEMS_PER_CYCLE

        return {
            'core_tile_params': {'subvols': core_tiling, 'iters': core_iters},
            'mem_tile_params': {'subvols': memtile_tiling, 'iters': memtile_iters, 'sizes': memtile_sizes},
            'shim_tile_params': {'subvols': shim_tilings, 'sizes': shim_sizes},
            'scheduling': schedule_dict,
            'host_layer_padding': host_layer_padding,
            'dma_layer_padding': dma_layer_padding,
            'kernel_info': {'placement_constraints': placement_dict},
            'overlay_info': overlay_dict,
            'layer_info': layerdict,
            'testbench_args': testbench_args,
            'cycle_counts': {'layer_cycles': opt_cycles},
        }

if __name__ == '__main__':
    # Manual smoke run: tile the 'tst' layer from tst_silugelu.json on the
    # '8x4' overlay of a 'strix' device. The JSON path is relative to the CWD.
    import json
    from layer import Layer
    from kernel import Kernel
    from device import Device

    with open('OGOAT/src/Tiler/tst_silugelu.json') as fh:
        layer_desc = json.load(fh)['tst']
    # Force both activations to L3 residency for this run.
    layer_desc.update(in_act_residency='L3', out_act_residency='L3')

    layer = Layer(layer_desc)
    device = Device('strix')
    kernel = Kernel(layer)
    r = LUTTilingOpt(layer, device, '8x4', kernel).find_optimal_tiling()
