import numpy as np
from OGOAT.src.Tiler.overlay import Overlay

class LpNormCostModel:
    """Cost model skeleton for LpNormalization tilings.

    Keeps four lookup tables named for cycle counts (kernel, array, layer and
    total layer). Each table maps every key of
    ``tilings.valid_core_subvols`` to its own, initially empty, dict.

    The ``calculate_*`` methods are placeholders: they currently do nothing
    and return ``None``.
    """

    def __init__(self, tilings):
        """Store *tilings* and create one empty table per valid sub-volume key.

        Args:
            tilings: Tiler object. Only its ``valid_core_subvols`` attribute is
                read here, via ``.keys()``.
        """
        self.tilings = tilings

        def _blank_table():
            # Re-read the keys on every call and build a brand-new inner dict
            # per key, so no two tables (and no two entries) share state.
            return {subvol: {} for subvol in self.tilings.valid_core_subvols.keys()}

        self.kernel_cycles = _blank_table()
        self.array_cycles = _blank_table()
        self.layer_cycles = _blank_table()
        self.total_layer_cycles = _blank_table()

    def calculate_kernel_cycles(self):
        """Placeholder; does nothing yet and returns ``None``.

        NOTE(review): presumably meant to populate ``self.kernel_cycles``.
        """
        return None

    def calculate_array_cycles(self):
        """Placeholder; does nothing yet and returns ``None``.

        NOTE(review): presumably meant to populate ``self.array_cycles``.
        """
        return None

    def calculate_layer_cycles(self):
        """Placeholder; does nothing yet and returns ``None``.

        NOTE(review): presumably meant to populate ``self.layer_cycles`` /
        ``self.total_layer_cycles``.
        """
        return None


if __name__ == '__main__':
    # Smoke-test driver: load a layer description from a JSON fixture, run the
    # tiler pipeline on it, and feed the resulting tiler into LpNormCostModel.
    # The fixture path is relative, so it resolves against the current working
    # directory (presumably the repository root -- TODO confirm).
    import json

    from layer import Layer
    from kernel import Kernel
    from device import Device
    from tiler import LayerNormTiler

    overlay = Overlay('4x4', 'LpNormalization', 'N16')

    # JSON text is UTF-8 (RFC 8259); pass the encoding explicitly so the read
    # does not depend on the platform's default locale encoding.
    with open('OGOAT/src/Tiler/tst_layernorm.json', encoding='utf-8') as f:
        model_dict = json.load(f)
    layer_desc = model_dict['tst']
    # Override the fixture so both input and output activations use 'L3'
    # residency for this run.
    layer_desc['in_act_residency'] = 'L3'
    layer_desc['out_act_residency'] = 'L3'

    layer = Layer(layer_desc)
    device = Device('strix')
    kernel = Kernel(layer)

    # NOTE(review): the LpNormalization overlay is tiled with LayerNormTiler
    # and a layernorm fixture -- presumably intentional reuse; confirm.
    tiler = LayerNormTiler(layer, device, overlay, kernel)
    tiler.calculate_memtile_tilings()
    tiler.check_valid_memtile_tilings()
    tiler.calculate_array_tilings()
    tiler.check_core_constraints()

    cost_model = LpNormCostModel(tiler)
    cost_model.calculate_kernel_cycles()
    cost_model.calculate_array_cycles()
    cost_model.calculate_layer_cycles()
