import yaml
import os

import numpy as np

from OGOAT.src.Tiler.tiler import Tiler
from cost_model import CostModel
from OGOAT.src.Tiler.overlay import Overlay
from layernorm_tiling_opt import LayerNormTilingOpt
# Absolute path with the last three components of this file's path removed,
# i.e. the directory two levels above the one that holds this file.
parent_dir = os.path.abspath(
    os.path.join(__file__, os.pardir, os.pardir, os.pardir)
)

class SoftMaxTilingOpt(LayerNormTilingOpt):
    """Tiling-parameter builder derived from ``LayerNormTilingOpt``.

    Only ``find_optimal_tiling`` is overridden here.  The method relies on
    ``calculate_tiling_cycles``, ``npdict2list``, ``tilers``, ``overlay`` and
    ``modes``, none of which are defined in this class, so they are expected
    to be provided by the base class / instance setup.
    """

    def find_optimal_tiling(self):
        """Assemble the tiling-parameter dictionary from the first tiler.

        Only ``self.tilers[0]`` is consulted.  The method first runs
        ``self.calculate_tiling_cycles()``, then gathers the core / memtile /
        shim sub-volumes, buffer placement, padding descriptors and testbench
        flags into one nested dictionary.

        Side effects:
            * ``Msubv`` and ``Nsubv`` are written into the tiler's
              ``vars_dict`` (maximum of columns 0 and 1 of the core ifm
              sub-volumes).
            * The layer object itself is *not* modified: the ``layer_info``
              entry is built from a shallow copy of its attribute dictionary,
              so the method can be called more than once.

        Returns:
            dict: with the keys ``core_tile_params``, ``mem_tile_params``,
            ``shim_tile_params``, ``scheduling``, ``original_dimensions``,
            ``host_layer_padding``, ``dma_layer_padding``, ``kernel_info``,
            ``overlay_info``, ``layer_info``, ``testbench_args`` and
            ``cycle_counts``.
        """
        self.calculate_tiling_cycles()

        tiler = self.tilers[0]
        layer = tiler.layer
        kernel = tiler.kernel

        in_bytes = layer.in_bytes
        out_bytes = layer.out_bytes

        split_type = tiler.split_type
        kgran = [tiler.vars_dict['Mgran'], tiler.vars_dict['Ngran']]
        ifm_tensorshape = tiler.tensorshape['ifm']
        tdim_layer = [ifm_tensorshape[0].item(), ifm_tensorshape[1].item()]

        core_tiling = {
            'ifm': tiler.valid_core_subvols['ifm'],
            'ofm': tiler.valid_core_subvols['ofm'],
        }
        core_iters = {
            'ifm': tiler.core_iters['ifm'],
            'ofm': tiler.core_iters['ofm'],
        }
        memtile_tiling = {
            'ifm': tiler.valid_memtile_subvols['ifm'],
            'ofm': tiler.valid_memtile_subvols['ofm'],
        }
        shim_tiling = {
            'ifm': tiler.valid_shimtile_subvols['ifm'],
            'ofm': tiler.valid_shimtile_subvols['ofm'],
        }
        memtile_iters = tiler.memtile_iters.copy()

        # Publish the largest core sub-volume extents so the placement
        # expressions below can refer to them by name.
        tiler.vars_dict['Msubv'] = np.max(core_tiling['ifm'][:, :, :, 0]).item()
        tiler.vars_dict['Nsubv'] = np.max(core_tiling['ifm'][:, :, :, 1]).item()

        # NOTE(review): placement constraints are evaluated with eval(); this
        # is only acceptable while the kernel description is trusted input.
        # A copy is used as the namespace because eval() inserts a
        # '__builtins__' entry into the globals dict it is handed, which would
        # otherwise pollute the tiler's vars_dict.
        eval_namespace = dict(tiler.vars_dict)
        placement_dict = {
            buff: {
                k: int(eval(str(v), eval_namespace))
                for k, v in bankdict.items()
            }
            for buff, bankdict in kernel.placement_constraints.items()
        }

        # Memtile buffer sizes in bytes, taken from the s2mm sub-volumes.
        memtile_ifm_s2mm_bytes = np.prod(memtile_tiling['ifm']['s2mm'], axis=2) * in_bytes
        memtile_ofm_s2mm_bytes = np.prod(memtile_tiling['ofm']['s2mm'], axis=-1) * out_bytes
        memtile_sizes = {
            'ifm': int(np.max(memtile_ifm_s2mm_bytes)),
            'ofm': int(np.max(memtile_ofm_s2mm_bytes)),
        }

        schedule_dict = {
            'ifm': 'stream',
            'ofm': 'stream',
            'ifm_ping_pong': True,
            'ofm_ping_pong': True,
        }

        shim_ifm = shim_tiling['ifm']
        shim_ofm = shim_tiling['ofm']
        shim_sizes = {
            'ifm': int(np.sum(shim_ifm[:, 0, 0]) * np.sum(shim_ifm[0, :, 1]) * in_bytes),
            'ofm': int(np.sum(shim_ofm[:, 0, 0]) * np.sum(shim_ofm[0, :, 1]) * out_bytes),
        }

        # The padded 2D extent is read from different slices of the shim ifm
        # sub-volumes depending on the split type.
        if split_type == 'ROW_SPLIT':
            tensor_padded_2D = [
                int(np.sum(shim_ifm[0, :, 0])),
                int(np.sum(shim_ifm[0, 0, 1])),
            ]
        else:
            tensor_padded_2D = [
                int(np.sum(shim_ifm[:, 0, 0])),
                int(np.sum(shim_ifm[0, :, 1])),
            ]

        # Shallow copy: vars() returns the layer's live __dict__, and the
        # renames below would otherwise delete in_bytes / out_bytes /
        # in_datatype / out_datatype from the layer object itself.
        layerdict = dict(vars(layer))
        layerdict['in_ifm_shape'] = tensor_padded_2D
        layerdict['out_ofm_shape'] = tensor_padded_2D

        layerdict['in_ifm_datatype'] = layerdict.pop('in_datatype')
        layerdict['out_ofm_datatype'] = layerdict.pop('out_datatype')

        layerdict['in_ifm_bytes'] = layerdict.pop('in_bytes')
        layerdict['out_ofm_bytes'] = layerdict.pop('out_bytes')

        # Original leading dims with the last one replaced by the padded one.
        tensor_padded = layerdict.get('in_act_shape')[:-1] + [tensor_padded_2D[-1]]
        ndims = len(tensor_padded)
        dis_dq = layerdict["attributes"]["disable_dq0"][0]
        dis_q = layerdict["attributes"]["disable_q"][0]

        # "dims" entries below are annotated as shape [n,h,w,c] in the
        # original source.
        original_dimensions = [
            {"input0": {"dims": layerdict.get('in_act_shape')}},
            {"output0": {"dims": layerdict.get('out_act_shape')}},
        ]

        host_layer_padding = [
            {
                "input0": {
                    "dims": tensor_padded,
                    "value": [None] * ndims if dis_dq else ["zp_i0"] * ndims,
                }
            },
            {
                "output0": {
                    "dims": tensor_padded,
                    "value": [0] * ndims if dis_q else ["zp_o0"] * ndims,
                }
            },
        ]

        dma_layer_padding = [
            {
                "input0": {
                    "dims": tensor_padded,
                    "values": [None] * ndims,
                    "channels": [0, 1, 2, 3],
                }
            },
            {
                "output0": {
                    "dims": tensor_padded,
                    "values": [0] * ndims if dis_q else ["zp_o0"] * ndims,
                    "channels": [5],
                }
            },
        ]

        # Convert numpy containers to plain Python structures for the result.
        shim_tiling_list = {
            'ifm': shim_ifm.tolist(),
            'ofm': shim_ofm.tolist(),
        }
        shim_sizes_list = self.npdict2list(shim_sizes)
        core_tiling_list = {
            'ifm': core_tiling['ifm'].tolist(),
            'ofm': core_tiling['ofm'].tolist(),
        }
        core_iters_list = self.npdict2list(core_iters)
        memtile_tiling_list = self.npdict2list(memtile_tiling)
        memtile_sizes_list = self.npdict2list(memtile_sizes)
        memtile_iters_list = self.npdict2list(memtile_iters)

        test_cpp_name = kernel.testbench_args['HostName']
        cflag_names = kernel.testbench_args['CFLAGS']
        tb_cflags = {
            cflag_names[0]: 1,
            cflag_names[1]: split_type,
            cflag_names[2]: tensor_padded_2D[0],
            cflag_names[3]: tensor_padded_2D[1],
            cflag_names[4]: tdim_layer[1],  # Nlayer
            cflag_names[5]: 0 if dis_dq else 1,  # Enable DQ
            cflag_names[6]: 0 if dis_q else 1,  # Enable Q
            cflag_names[7]: layerdict['in_ifm_bytes'],
            cflag_names[8]: layerdict['out_ofm_bytes'],
        }

        # Using an empirical estimate because the cost model is inaccurate.
        opt_cycles = np.prod(layerdict['out_ofm_shape'], dtype=np.int64) / 6

        tiling_params = {
            'core_tile_params': {
                'subvols': core_tiling_list,
                'iters': core_iters_list,
            },
            'mem_tile_params': {
                'subvols': memtile_tiling_list,
                'iters': memtile_iters_list,
                'sizes': memtile_sizes_list,
            },
            'shim_tile_params': {
                'subvols': shim_tiling_list,
                'sizes': shim_sizes_list,
            },
            'scheduling': schedule_dict,
            'original_dimensions': original_dimensions,
            'host_layer_padding': host_layer_padding,
            'dma_layer_padding': dma_layer_padding,
            'kernel_info': {
                'kernel_gran': kgran,
                'split_type': split_type,
                'placement_constraints': placement_dict,
                'tdim_layer': tdim_layer,
            },
            'overlay_info': {
                'overlay': self.overlay,
                'mode': self.modes[0],
                'aieinst': tiler.overlay.aieinst,
                'shape': {'col': tiler.overlay.cols, 'row': tiler.overlay.rows},
            },
            'layer_info': layerdict,
            'testbench_args': {
                'HOST_NAME': test_cpp_name,
                'COMPILE_FLAGS': tb_cflags,
            },
            'cycle_counts': {'layer_cycles': opt_cycles},
        }

        return tiling_params
    
