import numpy as np

from OGOAT.src.Tiler.utils import sigmoid

class RoPECostModel:
    """Cycle-count cost model for the tilings enumerated by a RoPE tiler.

    The model reads candidate tilings from a ``tilings`` object and fills
    four nested dictionaries, each keyed first by schedule id and then by
    ``mem_sub_id`` (the keys of ``tilings.valid_core_subvols[sched]``):

    * ``kernel_cycles``      -- array returned by the kernel-loop estimate.
    * ``array_cycles``       -- ``(1, 4)`` array:
      ``[ifm stream cycles, kernel cycles, ofm stream cycles, core iters]``.
    * ``layer_cycles``       -- ``array_cycles`` with the memtile IFM/OFM
      cycles, ``valid_memtile_iters`` and ``fits_in_memtile`` appended as
      extra columns.
    * ``total_layer_cycles`` -- two-column array ``[inf, total]``; only
      filled for the schedule id ``TOTAL_CYCLES_SCHED``.

    The three ``calculate_*`` methods must be called in the order
    kernel -> array -> layer, because each one reads the previous result.
    """

    # Only this schedule id receives a total-cycle estimate.
    # NOTE(review): other schedule ids are silently skipped -- confirm that
    # this is intentional for RoPE.
    TOTAL_CYCLES_SCHED = 5

    # Fixed number of cycles added on top of the slowest pipeline step.
    STEP_OVERHEAD_CYCLES = 300

    def __init__(self, tilings):
        """Store ``tilings`` and create one empty result dict per schedule.

        Args:
            tilings: tiler object exposing ``valid_core_subvols`` (a dict
                keyed by schedule id) plus the ``kernel``, ``layer``,
                ``device`` and ``overlay`` attributes read by the
                ``calculate_*`` methods.
        """
        self.tilings = tilings
        scheds = list(tilings.valid_core_subvols)
        self.kernel_cycles = {sched: {} for sched in scheds}
        self.array_cycles = {sched: {} for sched in scheds}
        self.layer_cycles = {sched: {} for sched in scheds}
        self.total_layer_cycles = {sched: {} for sched in scheds}

    def calculate_kernel_cycles(self):
        """Fill ``kernel_cycles`` for every (schedule, mem_sub_id) pair.

        The estimate is ``(ifm / Ngran) * inner_loop_cycles + main_overhead``
        and is evaluated element-wise on ``core_subvol['ifm']``, which is
        assumed to be a numpy array (an empty array yields an empty result).
        """
        kernel = self.tilings.kernel

        for sched, core_subvols in self.tilings.valid_core_subvols.items():
            for mem_sub_id, core_subvol in core_subvols.items():
                loop_count = core_subvol['ifm'] / kernel.Ngran
                self.kernel_cycles[sched][mem_sub_id] = (
                    loop_count * kernel.inner_loop_cycles + kernel.main_overhead
                )

    def calculate_array_cycles(self):
        """Fill ``array_cycles`` for every (schedule, mem_sub_id) pair.

        Requires ``calculate_kernel_cycles`` to have been called first.
        Entries whose core IFM sub-volume is empty are stored as a ``(1, 4)``
        array of ``inf``.
        """
        layer = self.tilings.layer
        stream_bw = self.tilings.device.max_stream_bw

        for sched, core_subvols in self.tilings.valid_core_subvols.items():
            for mem_sub_id, core_subvol in core_subvols.items():
                if len(core_subvol['ifm']) == 0:
                    # No core sub-volume for this entry: mark it with inf.
                    self.array_cycles[sched][mem_sub_id] = np.full(shape=(1, 4), fill_value=np.inf)
                    continue

                # The IFM traffic is counted twice.
                # NOTE(review): presumably because two input tensors are
                # streamed in -- confirm.
                ifm_cycles = 2 * np.prod(core_subvol['ifm']) * layer.in_bytes / stream_bw
                ofm_cycles = np.prod(core_subvol['ofm']) * layer.out_bytes / stream_bw

                # Only the first element of each per-entry array is used.
                kernel_cycles = self.kernel_cycles[sched][mem_sub_id][0]
                core_iters = self.tilings.valid_core_iters[sched][mem_sub_id][0]

                self.array_cycles[sched][mem_sub_id] = np.array(
                    (ifm_cycles, kernel_cycles, ofm_cycles, core_iters)
                ).reshape((1, 4))

    def calculate_layer_cycles(self):
        """Fill ``layer_cycles`` and, for one schedule, ``total_layer_cycles``.

        Requires ``calculate_array_cycles`` to have been called first.

        Column layout of each ``layer_cycles`` entry:
        ``0`` core IFM cycles, ``1`` kernel cycles, ``2`` core OFM cycles,
        ``3`` core iters, ``4`` memtile IFM cycles, ``5`` memtile OFM cycles,
        ``6...`` ``valid_memtile_iters`` followed by ``fits_in_memtile``.

        For schedule ``TOTAL_CYCLES_SCHED`` the total is
        ``memtile_ifm + slowest_step * (memtile_iters - 1) + memtile_ofm``,
        where ``slowest_step`` is the maximum of columns 0, 1, 2, 4, 5 plus
        ``STEP_OVERHEAD_CYCLES``. NaN and +inf totals are both reported as
        ``inf`` so an infeasible tiling never looks like a finite cost.
        """
        tilings = self.tilings
        cols = tilings.overlay.cols

        # Per-column DRAM bandwidth, capped by what the shim streams can carry.
        shim_bandwidth = cols * tilings.device.max_stream_bw
        dram_read_bandwidth_percol = min(shim_bandwidth, tilings.device.dram_read_bandwidth) / cols
        dram_write_bandwidth_percol = min(shim_bandwidth, tilings.device.dram_write_bandwidth) / cols

        for sched, memtile_subvols in tilings.valid_memtile_subvols.items():
            # Dividing by sigmoid(size) derates the DRAM bandwidth for small
            # transfers. The factor 2 on the IFM side presumably reflects the
            # read bandwidth being shared between two inputs -- TODO confirm.
            ifm_size = np.prod(memtile_subvols['ifm'], 1)
            ifm_cycles = 2 * ifm_size * tilings.layer.in_bytes / dram_read_bandwidth_percol / sigmoid(ifm_size)
            ofm_size = np.prod(memtile_subvols['ofm'], 1)
            ofm_cycles = ofm_size * tilings.layer.out_bytes / dram_write_bandwidth_percol / sigmoid(ofm_size)

            # One row per memtile sub-volume: [ifm_cycles, ofm_cycles].
            memtile_cycles = np.vstack((ifm_cycles, ofm_cycles)).T

            for mem_sub_id, array_cycles in self.array_cycles[sched].items():
                # Repeat the per-memtile values once per array_cycles row so
                # that everything can be stacked side by side.
                reps = (len(array_cycles), 1)
                valid_memtile_cycles = np.tile(memtile_cycles[mem_sub_id], reps)
                valid_memtile_iters = np.tile(tilings.valid_memtile_iters[sched][mem_sub_id], reps)
                memtile_pingpong = np.tile(tilings.fits_in_memtile[sched][mem_sub_id], reps)

                layer_cycles_info = np.hstack(
                    (array_cycles, valid_memtile_cycles, valid_memtile_iters, memtile_pingpong)
                )
                self.layer_cycles[sched][mem_sub_id] = layer_cycles_info

                if sched != self.TOTAL_CYCLES_SCHED:
                    continue

                slowest_step_cycles = (
                    np.max(layer_cycles_info[:, [0, 1, 2, 4, 5]], axis=1, keepdims=True)
                    + self.STEP_OVERHEAD_CYCLES
                )

                # inf * 0 yields NaN when there is a single memtile iteration;
                # posinf is set explicitly because nan_to_num would otherwise
                # turn +inf into the largest finite float.
                total_cycles_double_buffer = np.nan_to_num(
                    layer_cycles_info[:, [4]]
                    + slowest_step_cycles * (layer_cycles_info[:, [6]] - 1)
                    + layer_cycles_info[:, [5]],
                    nan=np.inf,
                    posinf=np.inf,
                )

                # First column is always inf.
                # NOTE(review): presumably a placeholder for a
                # non-double-buffered variant -- confirm.
                self.total_layer_cycles[sched][mem_sub_id] = np.hstack(
                    (np.full(total_cycles_double_buffer.shape, np.inf), total_cycles_double_buffer)
                )


if __name__ == '__main__':
    # Manual smoke run: build a tiler for the 'tst' layer description and
    # push it through the three cost-model stages in order.
    from overlay import Overlay
    import json
    from layer import Layer
    from kernel import Kernel
    from device import Device
    from tiler import RoPETiler

    overlay_cfg = Overlay('4x4', 'Add', 'N16')

    # Load the layer description and force both activations to reside in L3.
    with open('OGOAT/src/Tiler/tst_elemwise.json') as spec_file:
        layer_spec = json.load(spec_file)['tst']
    for residency_key in ('in_act_residency', 'out_act_residency'):
        layer_spec[residency_key] = 'L3'

    layer_obj = Layer(layer_spec)
    device_obj = Device('strix')
    kernel_obj = Kernel(layer_obj)

    # Enumerate and validate the tilings before costing them.
    tiler_obj = RoPETiler(layer_obj, device_obj, overlay_cfg, kernel_obj)
    tiler_obj.calculate_memtile_tilings()
    tiler_obj.check_valid_memtile_tilings()
    tiler_obj.calculate_array_tilings()
    tiler_obj.check_core_constraints()

    # Each stage reads the result of the previous one.
    cost_model = RoPECostModel(tiler_obj)
    cost_model.calculate_kernel_cycles()
    cost_model.calculate_array_cycles()
    cost_model.calculate_layer_cycles()
