from typing import Optional

from OGOAT.src.Tiler.concat_tiling_opt import ConcatTilingOpt
from OGOAT.src.Tiler.layer import Layer
from OGOAT.src.utils.context import Context
from matmul_tiling_opt import MatMulTilingOpt
from conv_tiling_opt import ConvTilingOpt
from mha_tiling_opt import MHATilingOpt
from LUT_tiling_opt import LUTTilingOpt
from Elem_wise_ops_tiling_opt import ElemWiseOpsTilingOpt
from broadcast_tiling_opt import BroadcastTilingOpt
from layernorm_tiling_opt import LayerNormTilingOpt
from groupnorm_tiling_opt import GroupNormTilingOpt
from lpnorm_tiling_opt import LpNormTilingOpt
from softmax_tiling_opt import SoftMaxTilingOpt
from dataflow_tiling_opt import DataflowTilingOpt
from RoPE_tiling_opt import RoPETilingOpt

class TilingOpt:
    """Factory that picks the tiling optimizer matching a layer's operator.

    Calling ``TilingOpt(...)`` does not yield a ``TilingOpt`` object:
    ``__new__`` inspects ``layer.orig_op_type`` and returns an instance of
    the operator-specific ``*TilingOpt`` class instead.
    """

    # Operators that are all handled by DataflowTilingOpt.
    _DATAFLOW_OPS = (
        "Concat", "Resize", "Slice", "Transpose",
        "DepthToSpace", "Quant", "Dequant",
        "Gatherelements", "Gather", "MaxPool",
    )

    # Operators whose optimizer additionally depends on ``layer.op_type``.
    _BINARY_OPS = ("Add", "Mul")

    def __new__(
            cls, layer: Layer, device, overlay_name, kernel, layer_id: str, context: Context
    ):
        """Build the tiling optimizer for ``layer``.

        Args:
            layer: Layer to tile. ``orig_op_type`` selects the optimizer;
                for ``Add``/``Mul`` layers ``op_type`` is consulted as well.
            device: Forwarded unchanged to the selected optimizer.
            overlay_name: Forwarded unchanged to the selected optimizer.
            kernel: Forwarded to every optimizer except ``DataflowTilingOpt``.
            layer_id: Forwarded to ``MHATilingOpt`` only.
            context: Forwarded to ``MHATilingOpt`` only.

        Returns:
            An instance of the optimizer class chosen for the layer.

        Raises:
            ValueError: If ``orig_op_type`` is not handled here, or if an
                ``Add``/``Mul`` layer's ``op_type`` contains neither
                ``"EleWise"`` nor ``"BroadCast"``.
        """
        op = layer.orig_op_type

        if op in cls._DATAFLOW_OPS:
            # The dataflow optimizer is constructed without a kernel.
            return DataflowTilingOpt(layer, device, overlay_name)
        if op == "MatMul":
            return MatMulTilingOpt(layer, device, overlay_name, kernel)
        if op == "Conv":
            return ConvTilingOpt(layer, device, overlay_name, kernel)
        if op == "MHA":
            return MHATilingOpt(layer, device, overlay_name, kernel, layer_id, context)
        if op == "PWLA":
            return LUTTilingOpt(layer, device, overlay_name, kernel)
        if op in cls._BINARY_OPS:
            # NOTE: Only works for QDQ ops because op_type[1] is qdq
            if "EleWise" in layer.op_type:
                return ElemWiseOpsTilingOpt(layer, device, overlay_name, kernel)
            if "BroadCast" in layer.op_type:
                return BroadcastTilingOpt(layer, device, overlay_name, kernel)
            # The operator itself is recognised here; report op_type too,
            # since that is the value that could not be matched.
            raise ValueError(
                f"Unknown layer type. OP: {layer.orig_op_type}, "
                f"op_type: {layer.op_type}"
            )
        if op == "LayerNormalization":
            return LayerNormTilingOpt(layer, device, overlay_name, kernel)
        if op == "GroupNormalization":
            return GroupNormTilingOpt(layer, device, overlay_name, kernel)
        if op == "LpNormalization":
            return LpNormTilingOpt(layer, device, overlay_name, kernel)
        if op == "Softmax":
            return SoftMaxTilingOpt(layer, device, overlay_name, kernel)
        if op == "RoPE":
            return RoPETilingOpt(layer, device, overlay_name, kernel)

        raise ValueError(f"Unknown layer type. OP: {layer.orig_op_type}")
