from typing import Optional

from OGOAT.src.Tiler.concat_tiler import ConcatTiler
from OGOAT.src.utils.context import Logger
from matmul_tiler import MatMulTiler
from conv_tiler import ConvTiler
from mha_tiler import MHATiler
from LUT_tiler import LUTTiler
from Elem_wise_ops_tiler import ElemWiseOpsTiler
from broadcast_tiler import BroadcastTiler
from layernorm_tiler import LayerNormTiler
from groupnorm_tiler import GroupNormTiler
from lpnorm_tiler import LpNormTiler
from softmax_tiler import SoftMaxTiler
from dataflow_tiler import DataflowTiler
from OGOAT.src.Tiler.overlay import Overlay


class Tiler:
    """Factory that dispatches a layer to its op-specific tiler.

    Calling ``Tiler(layer, device, overlay, kernel, logger)`` selects a
    concrete tiler class from ``layer.orig_op_type`` (and, for ``Add``/``Mul``,
    from ``layer.op_type``) and returns an instance of that class.
    """

    # orig_op_type values routed to the generic DataflowTiler.
    # NOTE(review): "Concat" can never reach DataflowTiler, because every
    # orig_op_type starting with "Concat" is dispatched to ConcatTiler first.
    # Kept to preserve the original list; confirm which tiler plain "Concat"
    # is meant to use.
    _DATAFLOW_OP_TYPES = frozenset(
        {
            "Concat",
            "Resize",
            "Slice",
            "Transpose",
            "DepthToSpace",
            "Quant",
            "Dequant",
            "Gatherelements",
            "Gather",
            "MaxPool",
        }
    )

    # Position of the Add/Mul variant tag ("EleWise" / "BroadCast") within the
    # "_"-separated op_type. NOTE: only works for QDQ ops, because field 1 of
    # the op_type is the qdq marker.
    _ADD_MUL_VARIANT_FIELD = 2

    def __new__(
        cls, layer, device, overlay, kernel=None, logger: Optional[Logger] = None
    ):
        """Return the concrete tiler instance for ``layer``.

        Args:
            layer: Layer description. ``orig_op_type`` selects the tiler; for
                Add/Mul, ``op_type`` additionally selects element-wise vs
                broadcast.
            device: Forwarded to the selected tiler.
            overlay: Forwarded to the selected tiler.
            kernel: Forwarded to every tiler except ConcatTiler and
                DataflowTiler, which are constructed without it.
            logger: Required when ``orig_op_type`` is "MHA"; unused otherwise.

        Returns:
            An instance of the selected tiler class.

        Raises:
            AssertionError: If the layer is MHA and ``logger`` is None.
            ValueError: If no tiler handles the layer's op type.
        """
        op = layer.orig_op_type
        if op == "MatMul":
            return MatMulTiler(layer, device, overlay, kernel)
        if op == "Conv":
            return ConvTiler(layer, device, overlay, kernel)
        if op == "MHA":
            assert logger is not None, "Logger should be provided for MHA tiler"
            return MHATiler(layer, device, overlay, kernel, logger)
        # Prefix match; must stay ahead of the dataflow lookup below, which
        # also lists "Concat".
        if op.startswith("Concat"):
            return ConcatTiler(layer, device, overlay)
        if op == "PWLA":
            return LUTTiler(layer, device, overlay, kernel)
        if op in ("Add", "Mul"):
            return cls._add_mul_tiler(layer, device, overlay, kernel)
        if op == "LayerNormalization":
            return LayerNormTiler(layer, device, overlay, kernel)
        if op == "GroupNormalization":
            return GroupNormTiler(layer, device, overlay, kernel)
        if op == "LpNormalization":
            return LpNormTiler(layer, device, overlay, kernel)
        if op == "Softmax":
            return SoftMaxTiler(layer, device, overlay, kernel)
        if op in cls._DATAFLOW_OP_TYPES:
            return DataflowTiler(layer, device, overlay)
        raise ValueError(f"Unknown layer type. OP: {op}")

    @classmethod
    def _add_mul_tiler(cls, layer, device, overlay, kernel):
        """Pick the element-wise or broadcast tiler for an Add/Mul layer.

        Raises:
            ValueError: If ``layer.op_type`` has no variant field, or the field
                is neither "EleWise" nor "BroadCast".
        """
        fields = layer.op_type.split("_")
        idx = cls._ADD_MUL_VARIANT_FIELD
        # Guard the index so a short op_type yields ValueError, not IndexError.
        variant = fields[idx] if len(fields) > idx else None
        if variant == "EleWise":
            return ElemWiseOpsTiler(layer, device, overlay, kernel)
        if variant == "BroadCast":
            return BroadcastTiler(layer, device, overlay, kernel)
        raise ValueError(
            f"Unknown layer type. OP: {layer.orig_op_type}, op_type: {layer.op_type}"
        )


if __name__ == "__main__":
    # Ad-hoc smoke test: load layer "tst" from tst.json, build a tiler for it
    # and run the tiler's tiling / constraint-check methods in sequence.
    import json

    from device import Device
    from kernel import Kernel
    from layer import Layer

    ov = Overlay("4x4", "MatMul", "M4K1N4")

    # Read JSON as UTF-8 explicitly instead of relying on the locale encoding.
    with open("tst.json", encoding="utf-8") as f:
        mdict = json.load(f)
    ld = mdict["tst"]
    # Override both activation residencies to "L3" for this run.
    ld["in_act_residency"] = "L3"
    ld["out_act_residency"] = "L3"

    layer = Layer(ld)
    d = Device("strix")
    k = Kernel(layer)

    t = Tiler(layer, d, ov, k)
    t.calculate_memtile_tilings()
    t.check_valid_memtile_tilings()
    t.calculate_array_tilings()
    t.check_core_constraints()
