import numpy as np
import warnings

warnings.filterwarnings('ignore')

from OGOAT.src.Tiler.utils import sigmoid

class MHACostModel:
    """Analytical cycle model for the MHA tilings enumerated by a tiler.

    ``tilings`` is the tiler object (e.g. ``MHATiler``) whose ``valid_*``
    tables are evaluated.  Call the passes in order:

    1. :meth:`calculate_kernel_cycles` -> ``kernel_cycles[sched][mem_sub_id]``
    2. :meth:`calculate_array_cycles`  -> ``array_cycles[sched][mem_sub_id]``
    3. :meth:`calculate_layer_cycles`  -> ``layer_cycles`` and
       ``total_layer_cycles[sched][mem_sub_id]``

    Column layout of a ``layer_cycles`` row, as assembled by the passes:

    * 0-3: core-level ifm / wgt / kernel / ofm cycles
    * 4-6: core iteration counts (``valid_core_iters``)
    * 7-9: memtile<->DRAM ifm / wgt / ofm cycles
    * 10-: memtile iteration counts (``valid_memtile_iters``) followed by the
      memtile ping-pong flags.

    NOTE(review): the fixed column indices used by the schedule formulas
    assume ``valid_core_iters`` has exactly 3 columns -- confirm against the
    tiler.
    """

    def __init__(self, tilings):
        self.tilings = tilings
        # Op types containing '2p1' select the 2p1 variant; everything else is 3p0.
        self.mha_mode = '2p1' if '2p1' in tilings.layer.op_type else '3p0'
        # Result tables keyed [sched][mem_sub_id]; one empty dict per schedule.
        scheds = list(tilings.valid_core_subvols.keys())
        self.kernel_cycles = {k: {} for k in scheds}
        self.array_cycles = {k: {} for k in scheds}
        self.layer_cycles = {k: {} for k in scheds}
        self.total_layer_cycles = {k: {} for k in scheds}

    def calculate_kernel_cycles(self):
        """Estimate per-core kernel cycles for every valid core subvolume.

        With M, K taken from the ifm subvolume and N from the wgt subvolume::

            main_overhead + ceil(M/Mgran) * ceil(N/Ngran)
                          * (outerloop_overhead + ceil(K/Kgran) * inner_loop_cycles)

        Results (one value per subvolume row) go to
        ``self.kernel_cycles[sched][mem_sub_id]``.
        """
        kernel = self.tilings.kernel
        # Index 1 of outerloop_overheads is used: the model assumes TDM always.
        outerloop_overhead = np.array(kernel.outerloop_overheads)[1]
        Mgran, Kgran, Ngran = self.tilings.Mgran, self.tilings.Kgran, self.tilings.Ngran

        for sched, core_subvols in self.tilings.valid_core_subvols.items():
            for mem_sub_id, core_subvol in core_subvols.items():
                M = core_subvol['ifm'][:, 0]
                K = core_subvol['ifm'][:, 1]
                N = core_subvol['wgt'][:, 1]

                innerloop_iters = np.ceil(K / Kgran)
                outerloop_iters = np.ceil(M / Mgran) * np.ceil(N / Ngran)

                self.kernel_cycles[sched][mem_sub_id] = kernel.main_overhead + outerloop_iters * (
                    outerloop_overhead + innerloop_iters * kernel.inner_loop_cycles
                )

    def calculate_array_cycles(self):
        """Build the per-option core-level cycle table for every subvolume.

        Each row of ``self.array_cycles[sched][mem_sub_id]`` is
        ``[ifm, wgt, kernel, ofm, *core_iters]``.  The stream cycles are
        ``elements * bytes_per_element / max_stream_bw``.

        Requires :meth:`calculate_kernel_cycles` to have run first.
        """
        # For a few TDM iterations the ofm latency is well hidden within the
        # TDM iterations; assuming that common case, ofm pipelining is not
        # modelled here -- the raw ofm cycles are just recorded.
        stream_bw = self.tilings.device.max_stream_bw
        layer = self.tilings.layer

        for sched, core_subvols in self.tilings.valid_core_subvols.items():
            for mem_sub_id, core_subvol in core_subvols.items():
                ifm_cycles = np.prod(core_subvol['ifm'], 1) * layer.in_bytes / stream_bw
                wgt_cycles = np.prod(core_subvol['wgt'], 1) * layer.wgt_bytes / stream_bw
                ofm_cycles = np.prod(core_subvol['ofm'], 1) * layer.out_ofm_bytes / stream_bw
                kernel_cycles = self.kernel_cycles[sched][mem_sub_id]
                core_iters = self.tilings.valid_core_iters[sched][mem_sub_id]

                per_option = np.vstack((ifm_cycles, wgt_cycles, kernel_cycles, ofm_cycles)).T
                self.array_cycles[sched][mem_sub_id] = np.hstack((per_option, core_iters))

    def calculate_layer_cycles(self):
        """Add memtile<->DRAM costs and estimate total layer cycles.

        For every schedule / memtile subvolume this stores:

        * ``self.layer_cycles[sched][mem_sub_id]``: the integer table
          ``[array_cycles | DRAM ifm, wgt, ofm cycles | memtile iters | pingpong]``
          (see the class docstring for the column layout);
        * ``self.total_layer_cycles[sched][mem_sub_id]``: an (N, 2) array of
          [single-buffer, double-buffer] totals, for schedules 1, 2, 5 and 6
          only (other schedules get no total).

        Requires :meth:`calculate_array_cycles` to have run first.
        """
        tilings = self.tilings
        cols = tilings.overlay.cols
        shim_bandwidth = cols * tilings.device.max_stream_bw
        # Aggregate DRAM bandwidth, capped at what the shim streams can carry,
        # shared evenly across the overlay columns.
        # NOTE(review): an earlier comment said the DRAM read bandwidth is
        # divided between the two inputs (ifm/wgt), but no such split is
        # applied here -- confirm which is intended.
        dram_read_bw_percol = min(shim_bandwidth, tilings.device.dram_read_bandwidth) / cols
        dram_write_bw_percol = min(shim_bandwidth, tilings.device.dram_write_bandwidth) / cols
        layer = tilings.layer

        for sched, memtile_subvols in tilings.valid_memtile_subvols.items():
            for mem_sub_id, array_cycles in self.array_cycles[sched].items():
                memtile_subvol = memtile_subvols[mem_sub_id]
                ifm_cycles = self._dram_transfer_cycles(memtile_subvol['ifm'], layer.in_bytes, dram_read_bw_percol)
                wgt_cycles = self._dram_transfer_cycles(memtile_subvol['wgt'], layer.wgt_bytes, dram_read_bw_percol)
                ofm_cycles = self._dram_transfer_cycles(memtile_subvol['ofm'], layer.out_ofm_bytes, dram_write_bw_percol)
                memtile_cycles = np.vstack((ifm_cycles, wgt_cycles, ofm_cycles)).T

                # The memtile-level values are shared by every core option.
                n_options = len(array_cycles)
                valid_memtile_cycles = np.tile(memtile_cycles, (n_options, 1))
                valid_memtile_iters = np.tile(tilings.valid_memtile_iters[sched][mem_sub_id], (n_options, 1))
                memtile_pingpong = tilings.valid_fits_in_memtile[sched][mem_sub_id][tilings.valid_core_subvids[sched][mem_sub_id]]

                # astype(int) truncates fractional cycle estimates toward zero.
                layer_cycles_info = np.hstack(
                    (array_cycles, valid_memtile_cycles, valid_memtile_iters, memtile_pingpong)
                ).astype(int)
                self.layer_cycles[sched][mem_sub_id] = layer_cycles_info

                total = self._total_layer_cycles(sched, layer_cycles_info, memtile_pingpong)
                if total is not None:
                    self.total_layer_cycles[sched][mem_sub_id] = total

    @staticmethod
    def _dram_transfer_cycles(subvol_shape, elem_bytes, bandwidth):
        """Return the cycles to move one memtile subvolume to/from DRAM.

        The raw ``elements * elem_bytes / bandwidth`` estimate is divided by
        ``sigmoid(elements)`` to derate DRAM bandwidth for small transfers.
        """
        size = np.prod(subvol_shape)
        return size * elem_bytes / bandwidth / sigmoid(size)

    @staticmethod
    def _col(info, idx):
        """Return column ``idx`` of the 2-D array ``info`` as an (N, 1) array."""
        return info[:, idx:idx + 1]

    @staticmethod
    def _rowmax(info, cols):
        """Return the per-row maximum of ``info[:, cols]`` as an (N, 1) array."""
        return np.max(info[:, cols], axis=1, keepdims=True)

    def _total_layer_cycles(self, sched, info, pingpong):
        """Return the [single-buffer, double-buffer] totals for one table.

        ``info`` is an integer ``layer_cycles`` table (one row per core
        option) and ``pingpong`` the matching memtile ping-pong flags.  The
        totals are divided column-wise by ``pingpong`` cast to int, so a 0
        flag yields ``inf`` and rules that option out.  Schedule 5 has no
        single-buffer variant and reports ``inf`` for it.

        Returns an (N, 2) array, or ``None`` for schedules other than
        1, 2, 5 and 6.
        """
        col, rowmax = self._col, self._rowmax

        if sched in (1, 6):
            # Multiplying by column 5 penalizes large Tk options.
            feed_cycles, iters_col = rowmax(info, [0, 1, 2]), 5
        elif sched == 2:
            # The DRAM wgt read (col 8) also bounds the per-iteration time.
            feed_cycles, iters_col = rowmax(info, [0, 1, 2, 8]), 12
        elif sched == 5:
            # Both DRAM reads (cols 7, 8) also bound the per-iteration time.
            feed_cycles, iters_col = rowmax(info, [0, 1, 2, 7, 8]), 11
        else:
            return None

        # Lock-acquire overhead is charged 3 times per iteration.
        array_cycles_per_output = (feed_cycles + self.tilings.device.lock_acq_overhead * 3) * col(info, iters_col)
        # Core ofm stream plus memtile->DRAM ofm write.
        ofm_cycles_to_dram = col(info, 3) + col(info, 9)
        # Compute and ofm write-back overlap, so the slower one dominates.
        overlap = np.maximum(array_cycles_per_output, ofm_cycles_to_dram)

        if sched == 5:
            double_buffer = overlap * col(info, 14) * col(info, 15)
            single_buffer = np.full(double_buffer.shape, np.inf)
        else:
            startup = rowmax(info, [7, 8])
            ifm_shards = col(info, 14)
            shard_cycles = overlap * col(info, 15)
            # Single buffering additionally exposes the DRAM ifm load (col 7)
            # between consecutive shards.
            single_buffer = startup + (ifm_shards - 1) * col(info, 7) + shard_cycles * ifm_shards
            double_buffer = startup + shard_cycles * ifm_shards

        # Division by a 0 flag is intentional (-> inf); keep it quiet locally
        # instead of relying on a process-wide warnings filter.
        with np.errstate(divide='ignore', invalid='ignore'):
            return np.hstack((single_buffer, double_buffer)) / pingpong.astype(int)


if __name__ == '__main__':
    # Smoke run: tile the 'tst' layer on a 4x4 MHA overlay for the 'strix'
    # device, then push the resulting tilings through every cost-model pass.
    from overlay import Overlay
    overlay = Overlay('4x4', 'MHA', 'M1K1N16')

    import json
    from layer import Layer
    with open('OGOAT/src/Tiler/tst_layer.json') as fh:
        layer_desc = json.load(fh)['tst']
    # Override the activation residency of both input and output to 'L3'.
    layer_desc.update(in_act_residency='L3', out_act_residency='L3')

    layer = Layer(layer_desc)

    from kernel import Kernel
    from device import Device
    device = Device('strix')
    kernel = Kernel(layer)

    from tiler import MHATiler
    tiler = MHATiler(layer, device, overlay, kernel)
    for stage in ('calculate_memtile_tilings', 'check_valid_memtile_tilings',
                  'calculate_array_tilings', 'check_core_constraints'):
        getattr(tiler, stage)()

    cost_model = MHACostModel(tiler)
    for stage in ('calculate_kernel_cycles', 'calculate_array_cycles',
                  'calculate_layer_cycles'):
        getattr(cost_model, stage)()
