import yaml
import os

import numpy as np

from RoPE_tiler import RoPETiler
from RoPE_cost_model import RoPECostModel
from overlay import Overlay

# Directory two levels above the folder containing this file; the overlay
# description is opened relative to it ('Collaterals/overlays.yaml').
_module_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(os.path.dirname(_module_dir))

class RoPETilingOpt:
    """Choose the cheapest RoPE tiling across all modes of one overlay.

    One ``RoPETiler``/``RoPECostModel`` pair is created for every mode listed
    for the layer's ``orig_op_type`` under ``overlay_name`` in
    ``Collaterals/overlays.yaml``.  ``find_optimal_tiling`` evaluates all of
    them and returns the parameters of the selected candidate.
    """

    # Buffer names used as keys of every per-buffer dictionary in the result.
    _BUFFERS = ('ifm', 'sin', 'cos', 'ofm')

    def __init__(self, layer, device, overlay_name, kernel):
        """Create a tiler and a cost model for every mode of the overlay.

        Args:
            layer: layer description; ``orig_op_type`` and ``op_type`` are
                read here, the object itself is handed to ``RoPETiler``.
            device: forwarded unchanged to ``RoPETiler``.
            overlay_name: top-level key in ``overlays.yaml`` (e.g. ``'8x4'``).
            kernel: forwarded unchanged to ``RoPETiler``.
        """
        with open(os.path.join(parent_dir,'Collaterals/overlays.yaml')) as f:
            all_overlays = yaml.safe_load(f)
        self.modes = list(all_overlays[overlay_name][layer.orig_op_type].keys())
        self.overlay = overlay_name

        self.tilers = []
        self.cost_models = []
        for mode in self.modes:
            overlay = Overlay(overlay_name, layer.orig_op_type, mode, layer.op_type.split('_')[0])
            tiler = RoPETiler(layer, device, overlay, kernel)
            self.tilers.append(tiler)
            self.cost_models.append(RoPECostModel(tiler))

        # Filled by find_optimal_tiling with one
        # [mode, schedule, memtile tiling, core tiling] row per candidate.
        self.indexes = []

    def calculate_tiling_cycles(self):
        """Run every tiler and its cost model, in the order they depend on."""
        for tiler, cost_model in zip(self.tilers, self.cost_models):
            tiler.calculate_memtile_tilings()
            tiler.check_valid_memtile_tilings()
            tiler.calculate_array_tilings()
            tiler.check_core_constraints()

            cost_model.calculate_kernel_cycles()
            cost_model.calculate_array_cycles()
            cost_model.calculate_layer_cycles()

    def _collect_candidates(self):
        """Stack the cycle tables of all cost models with matching index rows.

        Returns:
            ``(all_cycles, indexes)``; row ``i`` of ``indexes`` holds
            ``[mode index, schedule, memtile tiling, core tiling]`` for row
            ``i`` of ``all_cycles``.

        Raises:
            ValueError: if no cost model reported any candidate.
        """
        all_cycles = []
        indexes = []
        for midx, cost_model in enumerate(self.cost_models):
            for sched, sched_cycles in cost_model.total_layer_cycles.items():
                for mem_sub_id, cycles in sched_cycles.items():
                    n_core = len(cycles)
                    all_cycles.append(cycles)
                    indexes.append(np.hstack((
                        np.tile([midx, sched, mem_sub_id], (n_core, 1)),
                        np.arange(n_core).reshape((n_core, 1)),
                    )))

        if not all_cycles:
            raise ValueError('RoPETilingOpt: no candidate tilings were produced')
        return np.vstack(all_cycles), np.vstack(indexes)

    def _select_candidate(self, all_cycles, indexes):
        """Pick the near-optimal candidate that needs the fewest iterations.

        Every entry strictly below 1.02x the smallest non-NaN cycle count is
        kept; among those rows the one minimising
        ``core iterations * prod(memtile iterations)`` wins.

        Returns:
            ``(mode index, schedule, memtile tiling)`` as Python ints.

        Raises:
            ValueError: if no entry passes the cycle threshold.
        """
        near_best_rows, _ = np.where(all_cycles < np.nanmin(all_cycles)*1.02)
        if near_best_rows.size == 0:
            raise ValueError('RoPETilingOpt: no tiling passed the cycle threshold')

        total_iters = []
        for midx, sched, memtiling, coretiling in indexes[near_best_rows]:
            tiler = self.tilers[midx]
            core_iter = tiler.valid_core_iters[sched][memtiling][coretiling].astype(int).tolist()
            mem_iter = tiler.valid_memtile_iters[sched][memtiling]
            total_iters.append(core_iter*np.prod(mem_iter))

        best_row = near_best_rows[np.argmin(total_iters)]
        return tuple(int(v) for v in indexes[best_row][:3])

    def find_optimal_tiling(self):
        """Evaluate all candidates and describe the selected tiling.

        Side effects: stores the candidate index table in ``self.indexes``
        and writes ``'Nsubv'`` into the selected tiler's ``vars_dict``.  The
        layer object is left untouched, so the method can be called again.

        Returns:
            dict with the keys ``core_tile_params``, ``mem_tile_params``,
            ``shim_tile_params``, ``scheduling``, ``host_layer_padding``,
            ``dma_layer_padding``, ``kernel_info``, ``overlay_info``,
            ``layer_info`` and ``testbench_args``.

        Raises:
            ValueError: if there is no candidate tiling to choose from.
        """
        self.calculate_tiling_cycles()

        all_cycles, indexes = self._collect_candidates()
        self.indexes = indexes
        opt_midx, opt_sched, opt_memtiling = self._select_candidate(all_cycles, indexes)

        # Everything below describes the selected mode, so read it from the
        # tiler of that mode rather than from the first one.
        tiler = self.tilers[opt_midx]
        layer = tiler.layer
        kernel = tiler.kernel

        # NOTE(review): core subvolumes and the returned core iterations are
        # not indexed by the core-tiling index used during selection; kept
        # as in the original -- confirm this is intended.
        core_subvols = tiler.valid_core_subvols[opt_sched][opt_memtiling]
        core_tiling = {b: core_subvols[b].astype(int).tolist() for b in self._BUFFERS}
        tn = tiler.valid_core_iters[opt_sched][opt_memtiling].astype(int).tolist()
        core_iters = {b: tn for b in self._BUFFERS}

        memtile_subvols = tiler.valid_memtile_subvols[opt_sched]
        memtile_tiling = {
            b: memtile_subvols[b][opt_memtiling].astype(int).tolist()
            for b in self._BUFFERS
        }

        # Entry 0 is applied to the three inputs, entry 1 to the output.
        mem_iter = tiler.valid_memtile_iters[opt_sched][opt_memtiling]
        memtile_iters = {
            'ifm': [int(mem_iter[0])],
            'sin': [int(mem_iter[0])],
            'cos': [int(mem_iter[0])],
            'ofm': [int(mem_iter[1])]
        }

        # Drop a stale '__builtins__' entry if one is present (eval() puts a
        # fresh one back); pop() tolerates its absence.
        vars_dict = tiler.vars_dict
        vars_dict.pop('__builtins__', None)
        vars_dict['Nsubv'] = core_tiling['ifm'][0]
        # SECURITY: the constraint expressions are evaluated with eval();
        # they must come from trusted kernel descriptions only.
        placement_dict = {
            buff: {k: int(eval(str(v), vars_dict)) for k, v in bankdict.items()}
            for buff, bankdict in kernel.placement_constraints.items()
        }

        # NOTE(review): sin/cos are sized with in_bytes although the layer
        # also carries wgt_bytes/wgt1_bytes; kept as in the original.
        memtile_sizes = {
            'ifm': int(np.prod(memtile_tiling['ifm']) * layer.in_bytes),
            'sin': int(np.prod(memtile_tiling['sin']) * layer.in_bytes),
            'cos': int(np.prod(memtile_tiling['cos']) * layer.in_bytes),
            'ofm': int(np.prod(memtile_tiling['ofm']) * layer.out_bytes),
        }

        schedule_dict = {}
        if opt_sched == 5:
            schedule_dict = {
                'ifm': 'stream',
                'sin': 'stream',
                'cos': 'stream',
                'ofm': 'stream',
                'ifm_ping_pong': True,
                'sin_ping_pong': True,
                'cos_ping_pong': True,
                'ofm_ping_pong': True
            }

        host_layer_padding = [
            {b: {"dims": tiler.padded_shapes[b].astype(int).tolist(), "value": [0]}}
            for b in self._BUFFERS
        ]
        dma_layer_padding = [
            {b: {"dims": [0], "value": [None]}}
            for b in self._BUFFERS
        ]

        shim_tilings = {
            k: (np.array(memtile_tiling[k]) * np.array(memtile_iters[k])).astype(int).tolist()
            for k in memtile_iters
        }
        shim_sizes = {
            k: int(memtile_sizes[k]*np.prod(memtile_iters[k]))
            for k in memtile_iters
        }

        # Work on a copy: vars() returns the layer's live __dict__, and
        # popping from it would strip attributes off the shared layer object.
        layerdict = dict(vars(layer))
        orig_shape = layerdict.pop('orig_act_shape')
        for stale_key in ('in_act_shape', 'in_wgt_shape', 'out_act_shape'):
            layerdict.pop(stale_key)

        layerdict['in_ifm_shape'] = [int(tiler.padded_shapes['ifm'])]
        layerdict['in_sin_shape'] = [int(tiler.padded_shapes['sin'])]
        layerdict['in_cos_shape'] = [int(tiler.padded_shapes['cos'])]
        layerdict['out_ofm_shape'] = [int(tiler.padded_shapes['ofm'])]

        layerdict['in_ifm_datatype'] = layerdict.pop('in_datatype')
        layerdict['in_sin_datatype'] = layerdict.pop('wgt_datatype')
        layerdict['in_cos_datatype'] = layerdict.pop('wgt1_datatype')
        layerdict['out_ofm_datatype'] = layerdict.pop('out_datatype')

        layerdict['in_ifm_bytes'] = layerdict.pop('in_bytes')
        layerdict['in_sin_bytes'] = layerdict.pop('wgt_bytes')
        layerdict['in_cos_bytes'] = layerdict.pop('wgt1_bytes')
        layerdict['out_ofm_bytes'] = layerdict.pop('out_bytes')

        test_cpp_name = kernel.testbench_args['HostName']
        # Copy so the layer's own shape list is not rewritten below.
        mod_orig_shape = list(orig_shape)
        if layerdict['in_ifm_shape'] != [np.prod(orig_shape)]:
            # Padding changed the element count: recompute the first
            # dimension from the padded size, keeping the second one.
            mod_orig_shape[0] = layerdict['in_ifm_shape'][0] // orig_shape[1]

        # Flag names come from the kernel description, in this fixed order
        # (labelled M, N, MSUBV, NSUBV, DO_NEG by the original comments).
        cflag_names = kernel.testbench_args['CFLAGS']
        tb_cflags = {
            cflag_names[0]: mod_orig_shape[0],
            cflag_names[1]: mod_orig_shape[1],
            cflag_names[2]: core_tiling['sin'][0],
            cflag_names[3]: 1,  # always 1
            cflag_names[4]: 1 if 'actxact' in layerdict['op_type'] else 0,
        }

        tiling_params = {
            'core_tile_params': {'subvols': core_tiling, 'iters': core_iters},
            'mem_tile_params': {'subvols': memtile_tiling, 'iters': memtile_iters, 'sizes': memtile_sizes},
            'shim_tile_params': {'subvols': shim_tilings, 'sizes': shim_sizes},
            'scheduling': schedule_dict,
            'host_layer_padding': host_layer_padding,
            'dma_layer_padding': dma_layer_padding,
            'kernel_info': {'placement_constraints': placement_dict},
            'overlay_info': {
                'overlay': self.overlay,
                'mode': self.modes[opt_midx],
                'shape': {'row': tiler.overlay.rows, 'col': tiler.overlay.cols},
            },
            'layer_info': layerdict,
            'testbench_args': {'HOST_NAME': test_cpp_name, 'COMPILE_FLAGS': tb_cflags},
        }
        return tiling_params

if __name__=='__main__':
    # Manual run: load the 'tst' layer entry from the test JSON, force both
    # activation residencies to 'L3', and search tilings on the '8x4' overlay.
    import json
    from layer import Layer
    with open('OGOAT/src/Tiler/tst_elemwise.json') as cfg_file:
        layer_cfg = json.load(cfg_file)['tst']
    layer_cfg['in_act_residency'] = 'L3'
    layer_cfg['out_act_residency'] = 'L3'

    test_layer = Layer(layer_cfg)

    # Imported only after the layer was built, as in the original sequence.
    from kernel import Kernel
    from device import Device
    strix_device = Device('strix')
    test_kernel = Kernel(test_layer)

    tiling_opt = RoPETilingOpt(test_layer, strix_device, '8x4', test_kernel)
    result = tiling_opt.find_optimal_tiling()
