import numpy as np
from math import floor, sqrt, gcd
import warnings
from OGOAT.src.Tiler.overlay import Overlay
from dataflow.conv.conv_common import iceil

class GroupNormTiler:
    """Derive per-core and per-memtile tilings for a group-normalization layer.

    The input and output activations are treated as 2-D ``[M, N]`` volumes,
    where ``M`` is the product of the third- and second-to-last activation
    dimensions and ``N`` is the last dimension.

    Typical call order (as used by the ``__main__`` driver of this file)::

        tiler.calculate_memtile_tilings()
        tiler.check_valid_memtile_tilings()
        tiler.calculate_array_tilings()
        tiler.check_core_constraints()

    Results are published as attributes (``core_subvols``, ``core_iters``,
    ``memtile_subvols``, ``memtile_iters``, ``memtile_sublayers``,
    ``padded_shape``, ``gpn_overlay`` and their ``valid_*`` counterparts).
    """

    def __init__(self, layer, device, overlay, kernel):
        """Store the collaborators and initialise empty result containers.

        Args:
            layer: Object exposing ``in_act_shape``, ``out_act_shape``
                (indexable, at least 3 dims), ``in_bytes`` and ``out_bytes``.
            device: Object exposing ``bytes_per_word``.
            overlay: Object exposing ``cols`` and ``rows`` (read later by
                ``calculate_array_tilings``).
            kernel: Object exposing ``Mgran`` and ``Ngran``.
        """
        self.layer = layer
        self.overlay = overlay
        self.device = device
        self.kernel = kernel

        # Minimum [M, N] granularity the kernel reports; the same for ifm/ofm.
        self.kernel_granularity = {
            'ifm': np.array([kernel.Mgran, kernel.Ngran]),
            'ofm': np.array([kernel.Mgran, kernel.Ngran]),
        }

        # Every result container starts as {schedule_id: {}}; only schedule 1
        # exists here. Each attribute gets its own independent dicts.
        schedule_list = [1]

        self.memtile_subvols = {k: {} for k in schedule_list}
        self.memtile_sublayers = {k: {} for k in schedule_list}
        self.memtile_iters = {k: {} for k in schedule_list}

        self.fits_in_memtile = {k: {} for k in schedule_list}
        self.valid_fits_in_memtile = {k: {} for k in schedule_list}

        self.valid_memtile_subvols = {k: {} for k in schedule_list}
        self.valid_memtile_sublayers = {k: {} for k in schedule_list}
        self.valid_memtile_iters = {k: {} for k in schedule_list}

        self.core_subvols = {k: {} for k in schedule_list}
        self.core_iters = {k: {} for k in schedule_list}

        self.core_validity_checks = {k: {} for k in schedule_list}
        self.valid_core_subvols = {k: {} for k in schedule_list}
        self.valid_core_iters = {k: {} for k in schedule_list}

        # Collapse the activations to 2-D: [dim[-3] * dim[-2], dim[-1]].
        in_shape = self.layer.in_act_shape
        out_shape = self.layer.out_act_shape
        self.tensorshape = {
            'ifm': np.array([in_shape[-3] * in_shape[-2], in_shape[-1]]),
            'ofm': np.array([out_shape[-3] * out_shape[-2], out_shape[-1]]),
        }

        self.vars_dict = {
            'ifm_bytes': layer.in_bytes,
            'ofm_bytes': layer.out_bytes,
            ## check other constraints
            'Mgran': self.kernel.Mgran,
            'Ngran': self.kernel.Ngran,
        }

        self.bytes_per_word = device.bytes_per_word

    def lcm(self, a, b):
        """Return the least common multiple of integers ``a`` and ``b``.

        The result is non-negative. If either argument is zero the result is
        0 (this also covers ``lcm(0, 0)``, where ``gcd`` is 0 and a plain
        division would raise ``ZeroDivisionError``).
        """
        if a == 0 or b == 0:
            return 0
        return abs(a * b) // gcd(a, b)

    def is_power_of_two(self, n):
        """Return True when integer ``n`` is a power of two; 0 is not."""
        if n == 0:
            return False
        # A power of two has a single set bit, so clearing the lowest set bit
        # leaves zero.
        return (n & (n - 1)) == 0

    def calculate_memtile_tilings(self):
        """No-op: memtile tilings are produced by ``calculate_array_tilings``."""
        pass

    def check_valid_memtile_tilings(self):
        """Publish the memtile tilings as valid without further filtering.

        The ``valid_*`` attributes are bound to the very same objects as their
        unvalidated counterparts (no copy is made).
        """
        self.valid_memtile_subvols = self.memtile_subvols
        self.valid_memtile_sublayers = self.memtile_sublayers
        self.valid_memtile_iters = self.memtile_iters

    def calculate_array_tilings(self):
        """Compute core/memtile sub-volumes, iteration counts and padding.

        Reads ``tensorshape['ifm']``, ``vars_dict``, ``bytes_per_word`` and
        ``overlay.cols`` / ``overlay.rows``. Sets ``padded_shape``,
        ``gpn_overlay``, ``core_subvols``, ``core_iters``,
        ``memtile_subvols``, ``memtile_iters`` and ``memtile_sublayers``.
        Also prints the computed M granularity and padded M to stdout.

        Raises:
            AssertionError: if N is not a multiple of the group count, if a
                group's row slice is not word aligned, or if the derived
                per-core tile violates the vector-length / padding checks.
            ZeroDivisionError: for degenerate configurations, e.g. when
                ``overlay.cols < overlay.rows`` (``aie_inst`` becomes 0).
        """
        NG = 32            # number of groups N is divided into
        core_vec_len = 16  # per-core samples must be a multiple of this
        Mgpn = self.tensorshape['ifm'][0].item()
        Ngpn = self.tensorshape['ifm'][1].item()

        assert Ngpn % NG == 0, (
            "N (%d) must be a multiple of the group count (%d)" % (Ngpn, NG)
        )
        col_per_grp = Ngpn // NG
        in_bytes = self.vars_dict['ifm_bytes']
        out_bytes = self.vars_dict['ofm_bytes']
        assert (col_per_grp * in_bytes) % self.bytes_per_word == 0, (
            "ifm bytes per group row (%d) must be a multiple of "
            "bytes_per_word (%d)"
            % (col_per_grp * in_bytes, self.bytes_per_word)
        )
        assert (col_per_grp * out_bytes) % self.bytes_per_word == 0, (
            "ofm bytes per group row (%d) must be a multiple of "
            "bytes_per_word (%d)"
            % (col_per_grp * out_bytes, self.bytes_per_word)
        )

        # The overlay is split into `aie_inst` instances of
        # `aie_cols` x `aie_rows` each.
        # NOTE: aie_inst is 0 when overlay.cols < overlay.rows, which makes
        # the next division raise ZeroDivisionError.
        aie_inst = self.overlay.cols // self.overlay.rows
        aie_cols = self.overlay.cols // aie_inst
        aie_rows = self.overlay.rows

        if self.is_power_of_two(Ngpn):
            max_ifm_samp_norm = 2048
            # NOTE: this is 0 when N > 2048, which leads to a division by
            # zero further down.
            Mgpn_core = max_ifm_samp_norm // Ngpn
        else:
            # Smallest row count whose per-group sample count
            # (rows * col_per_grp) is a multiple of core_vec_len.
            Mgpn_core = self.lcm(col_per_grp, core_vec_len) // col_per_grp
            max_ifm_samp_norm = 2560

        # M has to be divisible across every core of one instance.
        Mgpn_min_gran = Mgpn_core * aie_cols * aie_rows
        # NOTE(review): iceil presumably rounds Mgpn up to a multiple of
        # Mgpn_min_gran -- confirm against dataflow.conv.conv_common.
        Mgpn_padd = iceil(Mgpn, Mgpn_min_gran)
        print("Mgpn_min_gran:-------------------", Mgpn_min_gran)
        print("Mgpn_padd:-------------------", Mgpn_padd)
        # Handle padding
        self.padded_shape = {
            'ifm': [Mgpn_padd, Ngpn],
            'ofm': [Mgpn_padd, Ngpn],
        }

        assert (Mgpn_core * col_per_grp) % core_vec_len == 0, (
            "per-core samples per group (%d) must be a multiple of the "
            "core vector length (%d)"
            % (Mgpn_core * col_per_grp, core_vec_len)
        )
        assert Mgpn_core * aie_cols * aie_rows <= Mgpn_padd, (
            "padded M (%d) is smaller than the minimum M granularity (%d)"
            % (Mgpn_padd, Mgpn_core * aie_cols * aie_rows)
        )
        Ngpn_core = Ngpn // aie_inst

        # When a core tile holds more samples than max_ifm_samp_norm, the
        # ofm side is processed in three equal parts instead of one.
        samples_per_core = Mgpn_core * Ngpn_core
        iter_fact_norm = 3 if samples_per_core > max_ifm_samp_norm else 1
        samp_core_norm = samples_per_core // iter_fact_norm

        Mlrn_mem = Mgpn_core * aie_rows
        Nlrn_mem = Ngpn_core
        Mlrn_shim = Mgpn_padd // aie_cols

        self.gpn_overlay = [aie_inst, aie_cols, aie_rows]
        self.core_subvols = {
            'ifm': [Mgpn_core, Ngpn_core],
            'ofm': [1, samp_core_norm],
        }

        tsubv = Mlrn_shim // Mlrn_mem
        tsubv_norm = tsubv * iter_fact_norm
        self.core_iters = {
            'ifm': [tsubv],
            'ofm': [tsubv_norm],
        }

        self.memtile_subvols = {
            'ifm': [Mlrn_mem, Nlrn_mem],
            'ofm': [Mlrn_mem, Nlrn_mem],
        }
        self.memtile_iters = {
            'ifm': [tsubv, 1],
            'ofm': [tsubv, 1],
        }
        # The sublayer entries share the list objects of memtile_subvols.
        self.memtile_sublayers = {
            'ifm': self.memtile_subvols['ifm'],
            'ofm': self.memtile_subvols['ofm'],
        }

    def check_core_constraints(self):
        """Warn if the ifm core sub-volume is below the kernel granularity.

        The check is advisory only: the core sub-volumes and iteration counts
        are published as valid regardless of the outcome. Expects
        ``core_subvols`` to contain an ``'ifm'`` entry, i.e.
        ``calculate_array_tilings`` must have run first.
        """
        diff_sv = (
            np.asarray(self.core_subvols['ifm']) - self.kernel_granularity['ifm']
        )
        if np.any(diff_sv < 0):
            warnings.warn("Core Sub volume is voilating the minimum granularity")
        self.valid_core_subvols = self.core_subvols
        self.valid_core_iters = self.core_iters

if __name__ == '__main__':
    # Manual smoke run: build a tiler for the 'tst' entry of the test JSON
    # and execute the four tiling stages in order.
    overlay = Overlay('4x4', 'GroupNormalization', 'N16')

    import json
    from layer import Layer

    with open('src/Tiler/tst_groupnorm.json') as cfg_file:
        layer_cfg = json.load(cfg_file)['tst']
    # Force both activations to the 'L3' residency for this run.
    for residency_key in ('in_act_residency', 'out_act_residency'):
        layer_cfg[residency_key] = 'L3'

    layer = Layer(layer_cfg)

    from kernel import Kernel
    from device import Device

    device = Device('strix')
    kernel = Kernel(layer)

    tiler = GroupNormTiler(layer, device, overlay, kernel)
    for stage in (
        tiler.calculate_memtile_tilings,
        tiler.check_valid_memtile_tilings,
        tiler.calculate_array_tilings,
        tiler.check_core_constraints,
    ):
        stage()
