from matmul_cost_model import MatMulCostModel
from conv_cost_model import ConvCostModel
from mha_cost_model import MHACostModel
from LUT_cost_model import LUTCostModel
from Elem_wise_ops_cost_model import ElemWiseOpsCostModel
from broadcast_cost_model import BroadcastCostModel
from layernorm_cost_model import LayerNormCostModel
from groupnorm_cost_model import GroupNormCostModel
from lpnorm_cost_model import LpNormCostModel
from softmax_cost_model import SoftMaxCostModel
from OGOAT.src.Tiler.overlay import Overlay

class CostModel:
    """Factory that selects the operator-specific cost model for a tiled layer.

    ``CostModel(tilings)`` does not produce a ``CostModel`` instance: ``__new__``
    returns an instance of the matching ``*CostModel`` class instead. Because
    that object is not an instance of ``CostModel``, ``CostModel.__init__`` is
    not invoked on it.
    """

    def __new__(cls, tilings):
        """Return the cost-model instance matching ``tilings.layer``.

        Args:
            tilings: Tiler-like object. Only ``tilings.layer.orig_op_type`` and,
                for ``Add``/``Mul``, ``tilings.layer.op_type`` are read here;
                the object is passed unchanged to the selected cost model.

        Returns:
            An instance of the cost-model class selected by
            ``tilings.layer.orig_op_type``.

        Raises:
            ValueError: if ``orig_op_type`` is not supported, or if an
                ``Add``/``Mul`` layer's ``op_type`` does not carry an
                ``EleWise``/``BroadCast`` variant in its third
                ``_``-separated field.
        """
        op = tilings.layer.orig_op_type

        if op == 'MatMul':
            return MatMulCostModel(tilings)
        if op == 'Conv':
            return ConvCostModel(tilings)
        if op == 'MHA':
            return MHACostModel(tilings)
        if op == 'PWLA':
            return LUTCostModel(tilings)
        if op in ('Add', 'Mul'):
            # The elementwise-vs-broadcast variant is encoded as the third
            # '_'-separated field of op_type. Guard the index so a short
            # op_type yields ValueError instead of a bare IndexError.
            # TODO: check opcode for elemwise matmul etc
            op_type = tilings.layer.op_type
            fields = op_type.split('_')
            variant = fields[2] if len(fields) > 2 else None
            if variant == 'EleWise':
                return ElemWiseOpsCostModel(tilings)
            if variant == 'BroadCast':
                return BroadcastCostModel(tilings)
            raise ValueError(f'Unknown layer type. OP: {op} (op_type: {op_type})')
        if op == 'LayerNormalization':
            return LayerNormCostModel(tilings)
        if op == 'GroupNormalization':
            return GroupNormCostModel(tilings)
        if op == 'LpNormalization':
            return LpNormCostModel(tilings)
        if op == 'Softmax':
            return SoftMaxCostModel(tilings)
        raise ValueError(f'Unknown layer type. OP: {op}')
        

if __name__ == '__main__':
    # Ad-hoc smoke test: build a layer from 'tst.json', tile it with a '4x4'
    # 'MatMul' overlay on the 'strix' device, then run the selected cost model.
    # Project imports stay local to this script entry point.
    import json

    from layer import Layer
    from kernel import Kernel
    from device import Device
    from tiler import Tiler

    ov = Overlay('4x4', 'MatMul', 'M4K1N4')

    # JSON text is UTF-8; don't depend on the platform's locale encoding.
    with open('tst.json', encoding='utf-8') as f:
        mdict = json.load(f)
    ld = mdict['tst']
    # Override the input/output activation residency for this run.
    ld['in_act_residency'] = 'L3'
    ld['out_act_residency'] = 'L3'

    layer = Layer(ld)
    d = Device('strix')
    k = Kernel(layer)

    t = Tiler(layer, d, ov, k)
    t.calculate_memtile_tilings()
    t.check_valid_memtile_tilings()
    t.calculate_array_tilings()
    t.check_core_constraints()

    c = CostModel(t)
    c.calculate_kernel_cycles()
    c.calculate_array_cycles()
    c.calculate_layer_cycles()
