import yaml
import os

import numpy as np

from OGOAT.src.Tiler.tiler import Tiler
from cost_model import CostModel
from OGOAT.src.Tiler.overlay import Overlay

# Directory three levels above this file; overlays.yaml is looked up
# under its 'Collaterals' sub-directory.
_module_path = os.path.abspath(__file__)
_module_dir = os.path.dirname(_module_path)
parent_dir = os.path.dirname(os.path.dirname(_module_dir))

class GroupNormTilingOpt:
    """Assemble the tiling parameter dictionary for a group-norm layer.

    One ``Tiler``/``CostModel`` pair is created for every mode listed for the
    layer's op type under the selected overlay in
    ``Collaterals/overlays.yaml``.  Only the first mode's tiler is consulted
    when the parameters are assembled in :meth:`find_optimal_tiling`.
    """

    def __init__(self, layer, device, overlay_name, kernel):
        """Load the overlay modes and build a tiler per mode.

        Args:
            layer: Layer description; ``layer.orig_op_type`` selects the
                entry inside the overlay definition.
            device: Forwarded unchanged to every ``Tiler``.
            overlay_name: Top-level key into ``overlays.yaml``.
            kernel: Forwarded unchanged to every ``Tiler``.
        """
        with open(os.path.join(parent_dir,'Collaterals/overlays.yaml')) as f:
            all_overlays = yaml.safe_load(f)
        self.modes = list(all_overlays[overlay_name][layer.orig_op_type].keys())
        self.overlay = overlay_name

        self.tilers = []
        self.cost_models = []
        for mode in self.modes:
            overlay = Overlay(overlay_name, layer.orig_op_type, mode)
            tiler = Tiler(layer, device, overlay, kernel)
            self.tilers.append(tiler)
            self.cost_models.append(CostModel(tiler))

        # Initialised empty; not used anywhere inside this class.
        self.indexes = []

    def calculate_tiling_cycles(self):
        """Run the tiling computation and validation steps on every tiler.

        Cost-model cycle estimation is not performed here;
        :meth:`find_optimal_tiling` uses an empirical estimate instead.
        """
        for tiler in self.tilers:
            tiler.calculate_array_tilings()
            tiler.check_core_constraints()
            tiler.calculate_memtile_tilings()
            tiler.check_valid_memtile_tilings()

    @staticmethod
    def _padding_spec(dims, dis_dq, dis_q, channels=None):
        """Build the ``[{"input0": ...}, {"output0": ...}]`` padding list.

        The fill value is ``0`` when the matching (de)quantisation stage is
        disabled, otherwise the symbolic zero-point name.  ``channels`` is an
        optional ``(input_channels, output_channels)`` pair that is attached
        under the ``"channels"`` key of each entry.
        """
        in_fill = 0 if dis_dq else "zp_i0"
        out_fill = 0 if dis_q else "zp_o0"
        in_entry = {"dims": dims, "values": [in_fill] * len(dims)}
        out_entry = {"dims": dims, "values": [out_fill] * len(dims)}
        if channels is not None:
            in_entry["channels"], out_entry["channels"] = channels
        return [{"input0": in_entry}, {"output0": out_entry}]

    def find_optimal_tiling(self):
        """Compute tilings and return the full tiling parameter dictionary.

        The layer object held by the tiler is left untouched: the
        ``layer_info`` entry is built from a copy of its attributes, so this
        method can be called repeatedly and the layer shared by all tilers
        keeps its ``in_bytes``/``out_bytes``/``in_datatype``/``out_datatype``
        fields.

        Side effects:
            ``tilers[0].vars_dict`` is updated in place with ``Msubv``,
            ``Nsubv`` and ``biaRepetition`` (and ``eval`` adds
            ``__builtins__`` to it).

        Returns:
            dict with the keys ``core_tile_params``, ``mem_tile_params``,
            ``shim_tile_params``, ``scheduling``, ``original_dimensions``,
            ``host_layer_padding``, ``dma_layer_padding``, ``kernel_info``,
            ``overlay_info``, ``layer_info``, ``testbench_args`` and
            ``cycle_counts``.
        """
        self.calculate_tiling_cycles()

        # Only the first mode's tiler is used to assemble the result.
        tiler = self.tilers[0]
        layer = tiler.layer
        kernel = tiler.kernel

        core_tiling = {
            'ifm': tiler.valid_core_subvols['ifm'],
            'ofm': tiler.valid_core_subvols['ofm'],
        }
        core_iters = {
            'ifm': tiler.core_iters['ifm'],
            'ofm': tiler.core_iters['ofm'],
        }
        memtile_tiling = {
            'ifm': tiler.valid_memtile_subvols['ifm'],
            'ofm': tiler.valid_memtile_subvols['ofm'],
        }
        ifm_shape = tiler.tensorshape['ifm']
        tdim_layer = [ifm_shape[0].item(), ifm_shape[1].item()]
        # The padded ifm shape is used for both the input and the output.
        tensor_padded = tiler.padded_shape['ifm']
        memtile_iters = tiler.memtile_iters

        # Expose the core sub-volume dims to the placement expressions.
        vars_dict = tiler.vars_dict
        vars_dict['Msubv'] = core_tiling['ifm'][0]
        vars_dict['Nsubv'] = core_tiling['ifm'][1]
        vars_dict['biaRepetition'] = core_tiling['ifm'][0]
        # NOTE(security): placement constraints are evaluated with eval();
        # they must only ever come from trusted kernel metadata.
        placement_dict = {
            buff: {k: int(eval(str(v), vars_dict)) for k, v in bankdict.items()}
            for buff, bankdict in kernel.placement_constraints.items()
        }

        memtile_sizes = {
            'ifm': int(np.prod(np.array(memtile_tiling['ifm'])) * layer.in_bytes),
            'ofm': int(np.prod(np.array(memtile_tiling['ofm'])) * layer.out_bytes),
        }

        schedule_dict = {
            'ifm': 'stream',
            'ofm': 'stream',
            'ifm_ping_pong': True,
            'ofm_ping_pong': True
        }

        # Shim-level tiles are the memtile tiles scaled by the memtile iterations.
        shim_tilings = {
            k: (np.array(memtile_tiling[k]) * np.array(memtile_iters[k])).astype(int).tolist()
            for k in memtile_iters
        }
        shim_sizes = {
            k: (memtile_sizes[k] * np.prod(memtile_iters[k])).astype(int).item()
            for k in memtile_iters
        }

        # Work on a copy: vars() returns the layer's own __dict__, and the
        # renames below would otherwise strip attributes off the layer
        # object that every tiler shares.
        layerdict = dict(vars(layer))
        layerdict['in_ifm_shape'] = tensor_padded
        layerdict['out_ofm_shape'] = tensor_padded

        layerdict['in_ifm_datatype'] = layerdict.pop('in_datatype')
        layerdict['in_wgt_datatype'] = 'bfloat16'
        layerdict['in_wgt1_datatype'] = 'bfloat16'
        layerdict['out_ofm_datatype'] = layerdict.pop('out_datatype')

        layerdict['in_ifm_bytes'] = layerdict.pop('in_bytes')
        layerdict['in_wgt_bytes'] = 2
        layerdict['in_wgt1_bytes'] = 2
        layerdict['out_ofm_bytes'] = layerdict.pop('out_bytes')

        dis_dq = layerdict["attributes"]["disable_dq0"][0]
        dis_q = layerdict["attributes"]["disable_q"][0]

        original_dimensions = [
            {"input0": {"dims": layerdict.get('in_act_shape')}},
            {"output0": {"dims": layerdict.get('out_act_shape')}},
        ]
        host_layer_padding = self._padding_spec(tensor_padded, dis_dq, dis_q)
        dma_layer_padding = self._padding_spec(
            tensor_padded, dis_dq, dis_q, channels=([0, 1, 2, 3], [5])
        )

        test_cpp_name = kernel.testbench_args['HostName']
        cflags = kernel.testbench_args['CFLAGS']
        tb_cflags = {
            cflags[0]: 0,                          # 1->LRN; 0->GPN
            cflags[1]: tensor_padded[0],           # padded shape, dim 0
            cflags[2]: tensor_padded[1],           # padded shape, dim 1
            cflags[3]: tdim_layer[0],              # ifm tensor shape, dim 0
            cflags[4]: tdim_layer[1],              # ifm tensor shape, dim 1
            cflags[5]: 0 if dis_dq else 1,         # Enable DQ
            cflags[6]: 0 if dis_q else 1,          # Enable Q
            cflags[7]: layerdict['in_ifm_bytes'],
            cflags[8]: layerdict['out_ofm_bytes'],
        }

        # Using an empirical estimate as the cost model is inaccurate.
        opt_cycles = np.prod(layerdict['out_ofm_shape'], dtype=np.int64) / 2.5

        tiling_params = {
            'core_tile_params': {'subvols': core_tiling, 'iters': core_iters},
            'mem_tile_params': {'subvols': memtile_tiling, 'iters': memtile_iters, 'sizes': memtile_sizes},
            'shim_tile_params': {'subvols': shim_tilings, 'sizes': shim_sizes},
            'scheduling': schedule_dict,
            'original_dimensions': original_dimensions,
            'host_layer_padding': host_layer_padding,
            'dma_layer_padding': dma_layer_padding,
            'kernel_info': {'placement_constraints': placement_dict, 'tdim_layer': tdim_layer},
            'overlay_info': {
                'overlay': self.overlay,
                'mode': self.modes[0],
                'shape': {'col': tiler.overlay.cols, 'row': tiler.overlay.rows},
            },
            'layer_info': layerdict,
            'testbench_args': {'HOST_NAME': test_cpp_name, 'COMPILE_FLAGS': tb_cflags},
            'cycle_counts': {'layer_cycles': opt_cycles}
        }

        return tiling_params
    
if __name__=='__main__':
    # Manual smoke run: load the 'tst' group-norm layer description, force
    # both activation residencies to L3 and build the tiling parameters.
    import json
    from layer import Layer

    with open('OGOAT/src/Tiler/tst_groupnorm.json') as json_file:
        test_cases = json.load(json_file)
    layer_cfg = test_cases['tst']
    for residency_key in ('in_act_residency', 'out_act_residency'):
        layer_cfg[residency_key] = 'L3'

    tst_layer = Layer(layer_cfg)

    from kernel import Kernel
    from device import Device

    strix_device = Device('strix')
    tst_kernel = Kernel(tst_layer)

    tiling_opt = GroupNormTilingOpt(tst_layer, strix_device, '4x4', tst_kernel)
    tiling_result = tiling_opt.find_optimal_tiling()
