"""
This script defines the schema for L2 fusion tiling JSON
"""

from __future__ import annotations

import re
from abc import ABC, abstractmethod
from enum import Enum
from functools import lru_cache
from typing import Any, Dict, List, Sequence, Set, Tuple
from pathlib import Path

import onnx
from onnx import AttributeProto, ModelProto
from pydantic import BaseModel

# pylint: disable=import-error,redefined-outer-name,no-name-in-module
from graph.allocation_types import AllocationResult, MemoryConfig, TensorAllocation
from graph.dag import descoped_id, scoped_id
from graph.tensor_types import (
    Operation,
    Tensor,
    TensorLifetime,
    TensorLocation,
    get_padded_shape_rev,
)
from graph.utilities import subclass_where, logger
from graph.runtime_ops import RuntimeOpType
from buildscripts.common import OperatorsRegistry

# One allocation record: the allocation outcome paired with the tensor's allocation.
Alloc = Tuple[AllocationResult, TensorAllocation]
# Allocation records keyed by string; BaseOp.get_allocs indexes this mapping with
# the entries of an operation's ``inputs`` and ``outputs``.
DictAlloc = Dict[str, Alloc]


class NpuMetadata(BaseModel):
    """Operator metadata: a pydantic model holding a single compilability flag."""
    is_compilable: bool  # Is the operator compilable?


class OpSupport(ABC):
    """Static helpers that classify operator type names (supported, noop, runtime)."""

    @staticmethod
    @lru_cache(maxsize=1)
    def get_supported_op_schedules():
        """Return the set of operator names held by ``OperatorsRegistry``.

        The result is cached after the first call.
        """
        registered = OperatorsRegistry.get_operators()
        return set(registered.keys())

    @staticmethod
    def is_op_noop_or_runtime(op_type: str) -> bool:
        """Return True when ``op_type`` has a ``_noop`` or ``_runtime`` marker.

        The marker must follow a non-empty prefix and be followed by ``_`` or
        the end of the string.
        """
        found = re.search(r".+_(noop|runtime)(?:_|$)", op_type)
        return found is not None

    @staticmethod
    def is_op_noop(op_type: str) -> bool:
        """Return True when ``op_type`` has a ``_noop`` marker (same rules as above)."""
        found = re.search(r".+_(noop)(?:_|$)", op_type)
        return found is not None

    @staticmethod
    def is_op_runtime(op_type: str) -> bool:
        """Return whatever ``RuntimeOpType.is_runtime_op`` reports for ``op_type``."""
        return RuntimeOpType.is_runtime_op(op_type)

    @staticmethod
    def is_op_schedule_supported(op_type: str) -> bool:
        """Return True for registered schedules and for noop/runtime operator names."""
        # The registry lookup comes first; the name-pattern check is the fallback.
        if op_type in OpSupport.get_supported_op_schedules():
            return True
        return OpSupport.is_op_noop_or_runtime(op_type)


class BaseIO(BaseModel):
    """Common parameters for an ONNX operator tiled on L2"""

    # These fields are shared by the operator schemas in this module; each
    # operator's ``build`` factory fills them from the BaseOp helper functions.
    enable_L2_fusion: bool  # should we enable L2 fusion for this op?
    load_input_from_ddr: bool  # should we load the input tensor from DDR?
    name: str  # name of this operator
    output_addr: Dict[int, int]  # output tensor address per AIE4-column
    output_name: str  # name of the output tensor
    prm_addr: List[Tuple[int, int]]  # layer parameter address per AIE4-column
    store_output_to_ddr: bool  # should we store the output tensor to DDR?
    wgt_addr: List[Tuple[int, Tuple[int, int]]]  # weight address per AIE4-column


class BaseOp(ABC):
    """Base class that defines properties common to all operators.

    Besides the abstract ``build`` factory, it provides static helpers that turn
    allocation records into the shapes, names, addresses and DDR/fusion flags
    stored by the operator schemas.
    """

    @classmethod
    @abstractmethod
    def build(
        cls,
        memory_config: MemoryConfig,
        model_ios: Tuple[Set[str], Set[str]],
        allocs: DictAlloc,
        name: str,
        attributes: Sequence[AttributeProto],
        op: Operation,
    ) -> BaseOp:
        """Factory function to build an instance of self"""

    @staticmethod
    def _split_memtile_addr(start: int) -> Tuple[int, int]:
        """Split a flat address into (memtile index, offset inside that memtile)."""
        memtile_size = 3 * (2**20)
        return start // memtile_size, start % memtile_size

    @staticmethod
    def get_attributes(
        attributes: Sequence[AttributeProto],
    ) -> Dict[str, AttributeProto]:
        """Index the ONNX attributes by name; a repeated name keeps the last one."""
        return {attr.name: attr for attr in attributes}

    @staticmethod
    def get_allocs(
        model_ios: Tuple[Set[str], Set[str]], allocs: DictAlloc, op: Operation,
        strict: bool = True
    ) -> Tuple[List[Alloc], Alloc]:
        """Get input and output tensor allocations for this operator.

        Args:
            model_ios: ``(model_inputs, model_outputs)`` sets of tensor ids.
            allocs: allocation records keyed by tensor id.
            op: operation whose ``inputs`` / ``outputs`` are looked up.
            strict: when False, ids missing from ``allocs`` are skipped and an
                unresolved output yields a DEALLOCATED placeholder.

        Returns:
            ``(input_allocs, output_alloc)``. Zero-sized tensors are dropped
            unless they are model inputs/outputs. If several outputs qualify,
            the last one is returned.

        Raises:
            KeyError: in strict mode, when an id is missing from ``allocs`` or
                no output allocation can be resolved.
        """
        model_inputs, model_outputs = model_ios
        input_allocs: List[Alloc] = []
        output_alloc: Any = ()

        for inp in op.inputs:
            if inp not in allocs and not strict:
                continue
            alloc = allocs[inp]
            if alloc[1].tensor.size != 0 or alloc[1].tensor.id in model_inputs:
                input_allocs.append(alloc)

        for output in op.outputs:
            if output not in allocs and not strict:
                continue
            alloc = allocs[output]
            if alloc[1].tensor.size != 0 or alloc[1].tensor.id in model_outputs:
                output_alloc = alloc

        if output_alloc == ():
            # No output qualified above: fall back to the single output when its
            # id starts with the id of one of the model outputs.
            assert len(op.outputs) == 1
            out_id = op.outputs[0]
            is_model_output = any(out_id.startswith(prefix) for prefix in model_outputs)
            if is_model_output and out_id in allocs:
                output_alloc = allocs[out_id]
            elif not strict:
                return input_allocs, (AllocationResult.DEALLOCATED, TensorAllocation.empty())
            else:
                # Previously this surfaced as an uninformative ``KeyError: None``.
                raise KeyError(
                    f"No output allocation found for output '{out_id}' of operation"
                )

        return input_allocs, output_alloc

    @staticmethod
    def should_enable_fusion(input_allocs: List[Alloc], output_alloc: Alloc) -> bool:
        """Should we enable L2 fusion for this operator?

        True when no input is SPILLED and the output is absent, empty or not SPILLED.
        """
        inputs_resident = all(
            alloc[0] != AllocationResult.SPILLED for alloc in input_allocs
        )
        return inputs_resident and (
            output_alloc is None
            or len(output_alloc) == 0
            or output_alloc[0] != AllocationResult.SPILLED
        )

    @staticmethod
    def get_shape(alloc: Alloc, ndim: int | None = None) -> Tuple[int, ...]:
        """Get shape of the tensor.

        Args:
            alloc: allocation record; an empty/None record yields ``()``.
            ndim: when None, the first dimension is dropped. Otherwise the shape
                is left-padded with 1s or stripped of leading dimensions to
                reach exactly ``ndim`` entries.

        Raises:
            AssertionError: if a leading dimension that must be stripped is not 1.
        """
        if not alloc:
            return ()
        if ndim is None:
            return alloc[1].tensor.shape[1:]
        shape = tuple(alloc[1].tensor.shape)
        extra = len(shape) - ndim
        if extra < 0:
            shape = (1,) * (-extra) + shape
        elif extra > 0:
            assert all(s == 1 for s in shape[:extra])
            shape = shape[extra:]
        return shape

    @staticmethod
    def should_load_input_from_ddr(input_allocs: List[Alloc]) -> bool:
        """Should we load inputs from DDR for this operator? True if any input is SPILLED."""
        return any(alloc[0] == AllocationResult.SPILLED for alloc in input_allocs)

    @staticmethod
    def get_io_name(alloc: Alloc) -> str:
        """Get name of the tensor (its id passed through ``descoped_id``)."""
        return descoped_id(alloc[1].tensor.id)

    @staticmethod
    def should_store_output_to_ddr(output_alloc: Alloc) -> bool:
        """Should we store outputs to DDR for this operator? True if the output is SPILLED."""
        return output_alloc[0] == AllocationResult.SPILLED

    @staticmethod
    def get_addr(alloc: Alloc) -> Dict[int, int]:
        """Get address of the memory where this tensor was allocated.

        Returns ``{memtile index: offset inside the memtile}``. A missing or
        empty record, or a record without a block, maps to ``{0: 0}``.
        """
        # ``not alloc`` also covers the empty tuple, which get_shape and
        # should_enable_fusion already accept as "no allocation".
        if not alloc or alloc[1].block is None:
            return {0: 0}
        tile, local_offset = BaseOp._split_memtile_addr(alloc[1].block.start)
        return {tile: local_offset}

    @staticmethod
    def get_param_weight_addr(
        memory_config: MemoryConfig,
    ) -> Tuple[List[Tuple[int, int]], List[Tuple[int, Tuple[int, int]]]]:
        """Get address of the layer parameters and weights.

        Returns:
            ``(prm_addr, wgt_addr)``: parameters as ``(tile, offset)`` pairs and
            weights as ``(tile, (ping_offset, pong_offset))`` pairs.

        Raises:
            AssertionError: if a weight's ping and pong buffers are on different tiles.
        """
        prm_addr = [BaseOp._split_memtile_addr(p.start) for p in memory_config.params]
        wgt_addr = []
        for w in memory_config.weights:
            tile_ping, local_offset_ping = BaseOp._split_memtile_addr(w.ping.start)
            tile_pong, local_offset_pong = BaseOp._split_memtile_addr(w.pong.start)
            assert tile_ping == tile_pong, "WGT has to be on the same tile"
            wgt_addr.append((tile_ping, (local_offset_ping, local_offset_pong)))
        return prm_addr, wgt_addr


class ConvTypes(str, Enum):
    """Defines types of convolution operators supported by the backend"""

    # The str mixin makes each member compare equal to its raw string value.
    DWC = "dwc"  # presumably depth-wise convolution — TODO confirm
    CONV_NOQDQ_A8W8 = "conv_noqdq_a8w8"


class Conv(BaseOp, BaseIO):
    """L2 tiling entry describing one convolution operator."""

    act_type: int  # value of the "is_relu" attribute, 0 when it is absent
    dilation: Tuple[int, int]  # (Dy, Dx)
    input: Tuple[int, int, int]  # (Ci, Yi, Xi)
    input_addr: Dict[int, int]  # input tensor address per AIE4-column
    input_name: str  # name of the input tensor
    kernel: Tuple[int, int]  # (Ky, Kx)
    op: str  # type of this operator
    output: Tuple[int, int, int]  # (Co, Yo, Xo)
    pad: Tuple[int, int]  # (Py, Px)
    stride: Tuple[int, int]  # (Sy, Sx)
    padded_input: Tuple[int, int, int]  # round-up input C-dim to nearest 64 multiple
    padded_output: Tuple[int, int, int]  # round-up output C-dim to nearest 64 multiple

    @classmethod
    def build(
        cls,
        memory_config: MemoryConfig,
        model_ios: Tuple[Set[str], Set[str]],
        allocs: DictAlloc,
        name: str,
        attributes: Sequence[AttributeProto],
        op: Operation,
    ) -> Conv:
        """Create a Conv entry from the ONNX attributes and the allocations of ``op``.

        Absent attributes default to is_relu=0, dilations=(1, 1),
        kernel_shape=(1, 1), pads=(0, 0) and strides=(1, 1).
        """
        attr_by_name = BaseOp.get_attributes(attributes)
        param_addrs, weight_addrs = BaseOp.get_param_weight_addr(memory_config)
        src_allocs, dst_alloc = BaseOp.get_allocs(model_ios, allocs, op)
        src_alloc = src_allocs[0]  # only the first input allocation is recorded

        relu = attr_by_name.get("is_relu")
        dilations = attr_by_name.get("dilations")
        kernel_shape = attr_by_name.get("kernel_shape")
        pads = attr_by_name.get("pads")
        strides = attr_by_name.get("strides")

        entry = {
            "act_type": 0 if relu is None else relu.i,
            "dilation": (1, 1) if dilations is None else dilations.ints,
            "enable_L2_fusion": BaseOp.should_enable_fusion(src_allocs, dst_alloc),
            "input": BaseOp.get_shape(src_alloc),
            "input_addr": BaseOp.get_addr(src_alloc),
            "input_name": BaseOp.get_io_name(src_alloc),
            # Only the first two entries of kernel_shape / pads are kept.
            "kernel": (1, 1) if kernel_shape is None else kernel_shape.ints[:2],
            "load_input_from_ddr": BaseOp.should_load_input_from_ddr(src_allocs),
            "name": name,
            "op": op.type,
            "output": BaseOp.get_shape(dst_alloc),
            "output_addr": BaseOp.get_addr(dst_alloc),
            "output_name": BaseOp.get_io_name(dst_alloc),
            "pad": (0, 0) if pads is None else pads.ints[:2],
            "prm_addr": param_addrs,
            "store_output_to_ddr": BaseOp.should_store_output_to_ddr(dst_alloc),
            "stride": (1, 1) if strides is None else strides.ints,
            "wgt_addr": weight_addrs,
            "padded_input": get_padded_shape_rev(BaseOp.get_shape(src_alloc)),
            "padded_output": get_padded_shape_rev(BaseOp.get_shape(dst_alloc)),
        }
        return Conv(**entry)


class ConvAddTypes(str, Enum):
    """Defines types of convolution+add operators supported by the backend"""

    # The str mixin makes each member compare equal to its raw string value.
    DWC = "dwc"  # same value as ConvTypes.DWC
    CONVADD_NOQDQ_A8W8 = "convadd_noqdq_a8w8"


class ConvAdd(BaseOp, BaseIO):
    """L2 tiling entry describing one convolution+add operator with two inputs."""

    act_type: int  # value of the "is_relu" attribute, 0 when it is absent
    dilation: Tuple[int, int]  # (Dy, Dx)
    input0: Tuple[int, int, int]  # (Ci, Yi, Xi)
    input1: Tuple[int, int, int]  # (Cb, Yb, Xb)
    input0_addr: Dict[int, int]  # first input tensor address per AIE4-column
    input0_name: str  # name of the first input tensor
    input1_addr: Dict[int, int]  # second input tensor address per AIE4-column
    input1_name: str  # name of the second input tensor
    kernel: Tuple[int, int]  # (Ky, Kx)
    op: str  # type of this operator
    output: Tuple[int, int, int]  # (Co, Yo, Xo)
    pad: Tuple[int, int]  # (Py, Px)
    stride: Tuple[int, int]  # (Sy, Sx)
    padded_input0: Tuple[int, int, int]  # round-up input0 C-dim to nearest 64 multiple
    padded_input1: Tuple[int, int, int]  # round-up input1 C-dim to nearest 64 multiple
    padded_output: Tuple[int, int, int]  # round-up output C-dim to nearest 64 multiple

    @classmethod
    def build(
        cls,
        memory_config: MemoryConfig,
        model_ios: Tuple[Set[str], Set[str]],
        allocs: DictAlloc,
        name: str,
        attributes: Sequence[AttributeProto],
        op: Operation,
    ) -> ConvAdd:
        """Create a ConvAdd entry from the ONNX attributes and the allocations of ``op``.

        Exactly two input allocations are expected; unpacking fails otherwise.
        Absent attributes default to is_relu=0, dilations=(1, 1),
        kernel_shape=(1, 1), pads=(0, 0) and strides=(1, 1).
        """
        attr_by_name = BaseOp.get_attributes(attributes)
        param_addrs, weight_addrs = BaseOp.get_param_weight_addr(memory_config)
        src_allocs, dst_alloc = BaseOp.get_allocs(model_ios, allocs, op)
        first, second = src_allocs

        relu = attr_by_name.get("is_relu")
        dilations = attr_by_name.get("dilations")
        kernel_shape = attr_by_name.get("kernel_shape")
        pads = attr_by_name.get("pads")
        strides = attr_by_name.get("strides")

        entry = {
            "act_type": 0 if relu is None else relu.i,
            "dilation": (1, 1) if dilations is None else dilations.ints,
            "enable_L2_fusion": BaseOp.should_enable_fusion(src_allocs, dst_alloc),
            "input0": BaseOp.get_shape(first),
            "input1": BaseOp.get_shape(second),
            "input0_addr": BaseOp.get_addr(first),
            "input1_addr": BaseOp.get_addr(second),
            "input0_name": BaseOp.get_io_name(first),
            "input1_name": BaseOp.get_io_name(second),
            # Only the first two entries of kernel_shape / pads are kept.
            "kernel": (1, 1) if kernel_shape is None else kernel_shape.ints[:2],
            "load_input_from_ddr": BaseOp.should_load_input_from_ddr(src_allocs),
            "name": name,
            "op": op.type,
            "output": BaseOp.get_shape(dst_alloc),
            "output_addr": BaseOp.get_addr(dst_alloc),
            "output_name": BaseOp.get_io_name(dst_alloc),
            "pad": (0, 0) if pads is None else pads.ints[:2],
            "prm_addr": param_addrs,
            "store_output_to_ddr": BaseOp.should_store_output_to_ddr(dst_alloc),
            "stride": (1, 1) if strides is None else strides.ints,
            "wgt_addr": weight_addrs,
            "padded_input0": get_padded_shape_rev(BaseOp.get_shape(first)),
            "padded_input1": get_padded_shape_rev(BaseOp.get_shape(second)),
            "padded_output": get_padded_shape_rev(BaseOp.get_shape(dst_alloc)),
        }
        return ConvAdd(**entry)


class MaxPoolTypes(str, Enum):
    """Defines types of max-pool operators supported by the backend"""

    # The str mixin makes the member compare equal to its raw string value.
    MAXPOOL_NOQDQ_A8 = "maxpool_noqdq_a8"


class MaxPool(BaseOp, BaseIO):
    """Represents a max-pool operation"""

    ceil_mode: int  # Whether to use ceil (1) or floor (0) to compute the output shape
    input: Tuple[int, int, int]  # (Ci, Yi, Xi)
    input_addr: Dict[int, int]  # input tensor address per AIE4-column
    input_name: str  # name of the input tensor
    kernel: Tuple[int, int]  # (Ky, Kx)
    op: str  # type of this operator
    output: Tuple[int, int, int]  # (Co, Yo, Xo)
    pad: Tuple[int, int]  # (Py, Px)
    stride: Tuple[int, int]  # (Sy, Sx)
    padded_input: Tuple[int, int, int]  # round-up input C-dim to nearest 64 multiple
    padded_output: Tuple[int, int, int]  # round-up output C-dim to nearest 64 multiple

    @classmethod
    def build(
        cls,
        memory_config: MemoryConfig,
        model_ios: Tuple[Set[str], Set[str]],
        allocs: DictAlloc,
        name: str,
        attributes: Sequence[AttributeProto],
        op: Operation,
    ) -> MaxPool:
        """Create a MaxPool entry from the ONNX attributes and the allocations of ``op``.

        ``kernel_shape`` is mandatory. ``ceil_mode``, ``pads`` and ``strides``
        are optional and default to 0, (0, 0) and (1, 1), matching the defaults
        Conv.build uses for the same attributes.

        Raises:
            KeyError: if the ``kernel_shape`` attribute is missing.
        """
        prm_addr, wgt_addr = BaseOp.get_param_weight_addr(memory_config)
        attrs = BaseOp.get_attributes(attributes)
        input_allocs, output_alloc = BaseOp.get_allocs(model_ios, allocs, op)
        input_alloc = input_allocs[0]

        return MaxPool(
            ceil_mode=(attrs["ceil_mode"].i if "ceil_mode" in attrs else 0),
            enable_L2_fusion=BaseOp.should_enable_fusion(input_allocs, output_alloc),
            input=BaseOp.get_shape(input_alloc),
            input_addr=BaseOp.get_addr(input_alloc),
            input_name=BaseOp.get_io_name(input_alloc),
            kernel=attrs["kernel_shape"].ints[:2],  # type: ignore
            load_input_from_ddr=BaseOp.should_load_input_from_ddr(input_allocs),
            name=name,
            op=op.type,
            output=BaseOp.get_shape(output_alloc),
            output_addr=BaseOp.get_addr(output_alloc),
            output_name=BaseOp.get_io_name(output_alloc),
            # pads/strides may be omitted from the node; do not fail on their absence.
            pad=(attrs["pads"].ints[:2] if "pads" in attrs else (0, 0)),  # type: ignore
            prm_addr=prm_addr,
            store_output_to_ddr=BaseOp.should_store_output_to_ddr(output_alloc),
            stride=(attrs["strides"].ints if "strides" in attrs else (1, 1)),  # type: ignore
            wgt_addr=wgt_addr,
            padded_input=get_padded_shape_rev(BaseOp.get_shape(input_alloc)),
            padded_output=get_padded_shape_rev(BaseOp.get_shape(output_alloc)),
        )


class AddTypes(str, Enum):
    """Defines types of add operators supported by the backend"""

    # The str mixin makes the member compare equal to its raw string value.
    ADD_NOQDQ_A8 = "add_noqdq_a8"


class Add(BaseOp, BaseIO):
    """L2 tiling entry describing one two-input add operator."""

    input0: Tuple[int, ...]
    input1: Tuple[int, ...]
    input0_addr: Dict[int, int]  # first input tensor address per AIE4-column
    input1_addr: Dict[int, int]  # second input tensor address per AIE4-column
    input0_name: str  # name of the first input tensor
    input1_name: str  # name of the second input tensor
    op: str  # type of this operator
    output: Tuple[int, ...]
    padded_input0: Tuple[int, ...]  # round-up input0 C-dim to nearest 64 multiple
    padded_input1: Tuple[int, ...]  # round-up input1 C-dim to nearest 64 multiple
    padded_output: Tuple[int, ...]  # round-up output C-dim to nearest 64 multiple

    @classmethod
    def build(
        cls,
        memory_config: MemoryConfig,
        model_ios: Tuple[Set[str], Set[str]],
        allocs: DictAlloc,
        name: str,
        attributes: Sequence[AttributeProto],
        op: Operation,
    ) -> Add | GenericUnaryOp:
        """Create an Add entry, or a unary entry when ``op`` lacks two input allocations.

        When BaseOp.get_allocs does not return exactly two inputs, the work is
        delegated to a GenericUnaryOp subclass created on the fly and named "Add".
        """
        param_addrs, weight_addrs = BaseOp.get_param_weight_addr(memory_config)
        src_allocs, dst_alloc = BaseOp.get_allocs(model_ios, allocs, op)
        if len(src_allocs) != 2:
            unary_cls = type("Add", (GenericUnaryOp, ), {})
            return unary_cls.build(memory_config, model_ios, allocs, name, attributes, op)

        lhs, rhs = src_allocs
        lhs_tensor, rhs_tensor, out_tensor = lhs[1].tensor, rhs[1].tensor, dst_alloc[1].tensor

        # Full tensor shapes are stored here (no leading dimension is dropped).
        entry = {
            "enable_L2_fusion": BaseOp.should_enable_fusion(src_allocs, dst_alloc),
            "input0": lhs_tensor.shape,
            "input0_addr": BaseOp.get_addr(lhs),
            "input1": rhs_tensor.shape,
            "input1_addr": BaseOp.get_addr(rhs),
            "input0_name": BaseOp.get_io_name(lhs),
            "input1_name": BaseOp.get_io_name(rhs),
            "load_input_from_ddr": BaseOp.should_load_input_from_ddr(src_allocs),
            "name": name,
            "op": op.type,
            "output": out_tensor.shape,
            "output_addr": BaseOp.get_addr(dst_alloc),
            "output_name": BaseOp.get_io_name(dst_alloc),
            "prm_addr": param_addrs,
            "store_output_to_ddr": BaseOp.should_store_output_to_ddr(dst_alloc),
            "wgt_addr": weight_addrs,
            "padded_input0": get_padded_shape_rev(lhs_tensor.shape),
            "padded_input1": get_padded_shape_rev(rhs_tensor.shape),
            "padded_output": get_padded_shape_rev(out_tensor.shape),
        }
        return Add(**entry)


class GemmTypes(str, Enum):
    """Defines types of GEMM operators supported by the backend"""

    # The str mixin makes the member compare equal to its raw string value.
    GEMM = "gemm"


class Gemm(BaseOp, BaseIO):
    """Represents a general matrix multiplication operation"""

    input: Tuple[int, int, int]  # (Ci, Yi, Xi)
    input_addr: Dict[int, int]  # input tensor address per AIE4-column
    input_name: str  # name of the input tensor
    op: str  # type of this operator
    output: Tuple[int, int, int]  # (Co, Yo, Xo)
    padded_input: Tuple[int, int, int]  # round-up input C-dim to nearest 64 multiple
    padded_output: Tuple[int, int, int]  # round-up output C-dim to nearest 64 multiple

    @classmethod
    def build(
        cls,
        memory_config: MemoryConfig,
        model_ios: Tuple[Set[str], Set[str]],
        allocs: DictAlloc,
        name: str,
        attributes: Sequence[AttributeProto],
        op: Operation,
    ) -> Gemm:
        """Create a Gemm entry from the allocations of ``op``.

        Input and output shapes are both normalised to three dimensions with
        ``BaseOp.get_shape(alloc, 3)``: shorter shapes are left-padded with 1s
        and longer ones lose their leading dimensions. The schema fields are
        3-tuples, so an output of any other rank would not fit them.

        Raises:
            AssertionError: if a shape has more than three dimensions and one
                of the leading dimensions to strip is not 1.
        """
        prm_addr, wgt_addr = BaseOp.get_param_weight_addr(memory_config)
        input_allocs, output_alloc = BaseOp.get_allocs(model_ios, allocs, op)
        input_alloc = input_allocs[0]

        return Gemm(
            enable_L2_fusion=BaseOp.should_enable_fusion(input_allocs, output_alloc),
            input=BaseOp.get_shape(input_alloc, 3),
            input_addr=BaseOp.get_addr(input_alloc),
            input_name=BaseOp.get_io_name(input_alloc),
            load_input_from_ddr=BaseOp.should_load_input_from_ddr(input_allocs),
            name=name,
            op=op.type,
            # Same 3-D normalisation as the input; a 3-D shape passes through unchanged.
            output=BaseOp.get_shape(output_alloc, 3),
            output_addr=BaseOp.get_addr(output_alloc),
            output_name=BaseOp.get_io_name(output_alloc),
            prm_addr=prm_addr,
            store_output_to_ddr=BaseOp.should_store_output_to_ddr(output_alloc),
            wgt_addr=wgt_addr,
            padded_input=get_padded_shape_rev(BaseOp.get_shape(input_alloc, 3)),
            padded_output=get_padded_shape_rev(BaseOp.get_shape(output_alloc, 3)),
        )


class GlobalAveragePoolTypes(str, Enum):
    """Defines types of global average pooling operators supported by the backend"""

    # The str mixin makes the member compare equal to its raw string value.
    GAP = "gap"


class GlobalAveragePool(BaseOp, BaseIO):
    """L2 tiling entry describing one global-average-pool operator."""

    input: Tuple[int, int, int]  # (Ci, Yi, Xi)
    input_addr: Dict[int, int]  # input tensor address per AIE4-column
    input_name: str  # name of the input tensor
    op: str  # type of this operator
    output: Tuple[int, int, int]  # (Co, Yo, Xo)
    padded_input: Tuple[int, int, int]  # round-up input C-dim to nearest 64 multiple
    padded_output: Tuple[int, int, int]  # round-up output C-dim to nearest 64 multiple

    @classmethod
    def build(
        cls,
        memory_config: MemoryConfig,
        model_ios: Tuple[Set[str], Set[str]],
        allocs: DictAlloc,
        name: str,
        attributes: Sequence[AttributeProto],
        op: Operation,
    ) -> GlobalAveragePool:
        """Create a GlobalAveragePool entry from the allocations of ``op``.

        ``attributes`` is accepted for interface parity but not read here.
        """
        param_addrs, weight_addrs = BaseOp.get_param_weight_addr(memory_config)
        src_allocs, dst_alloc = BaseOp.get_allocs(model_ios, allocs, op)
        src_alloc = src_allocs[0]  # only the first input allocation is recorded

        entry = {
            "enable_L2_fusion": BaseOp.should_enable_fusion(src_allocs, dst_alloc),
            "input": BaseOp.get_shape(src_alloc),
            "input_addr": BaseOp.get_addr(src_alloc),
            "input_name": BaseOp.get_io_name(src_alloc),
            "load_input_from_ddr": BaseOp.should_load_input_from_ddr(src_allocs),
            "name": name,
            "op": op.type,
            "output": BaseOp.get_shape(dst_alloc),
            "output_addr": BaseOp.get_addr(dst_alloc),
            "output_name": BaseOp.get_io_name(dst_alloc),
            "prm_addr": param_addrs,
            "store_output_to_ddr": BaseOp.should_store_output_to_ddr(dst_alloc),
            "wgt_addr": weight_addrs,
            "padded_input": get_padded_shape_rev(BaseOp.get_shape(src_alloc)),
            "padded_output": get_padded_shape_rev(BaseOp.get_shape(dst_alloc)),
        }
        return GlobalAveragePool(**entry)


# elementwise operators


class MulTypes(str, Enum):
    """Defines types of multiplication operators supported by the backend"""

    # The str mixin makes the member compare equal to its raw string value.
    MUL = "mul"


class Mul(BaseOp, BaseIO):
    """L2 tiling entry describing one two-input element-wise multiplication."""

    input: Tuple[int, ...]  # shape of the first input tensor
    input0_addr: Dict[int, int]  # first input tensor address per AIE4-column
    input1_addr: Dict[int, int]  # second input tensor address per AIE4-column
    input0_name: str  # name of the first input tensor
    input1_name: str  # name of the second input tensor
    op: str  # type of this operator
    output: Tuple[int, ...]
    padded_input: Tuple[int, ...]  # round-up input0 C-dim to nearest 64 multiple
    padded_output: Tuple[int, ...]  # round-up output C-dim to nearest 64 multiple

    @classmethod
    def build(
        cls,
        memory_config: MemoryConfig,
        model_ios: Tuple[Set[str], Set[str]],
        allocs: DictAlloc,
        name: str,
        attributes: Sequence[AttributeProto],
        op: Operation,
    ) -> Mul | GenericUnaryOp:
        """Create a Mul entry, or a unary entry when ``op`` lacks two input allocations.

        When BaseOp.get_allocs does not return exactly two inputs, the work is
        delegated to a GenericUnaryOp subclass created on the fly and named "Mul".
        Only the first input's shape is stored; the second contributes its
        address and name.
        """
        param_addrs, weight_addrs = BaseOp.get_param_weight_addr(memory_config)
        src_allocs, dst_alloc = BaseOp.get_allocs(model_ios, allocs, op)
        if len(src_allocs) != 2:
            unary_cls = type("Mul", (GenericUnaryOp, ), {})
            return unary_cls.build(memory_config, model_ios, allocs, name, attributes, op)

        lhs, rhs = src_allocs
        lhs_tensor, out_tensor = lhs[1].tensor, dst_alloc[1].tensor

        entry = {
            "enable_L2_fusion": BaseOp.should_enable_fusion(src_allocs, dst_alloc),
            "input": lhs_tensor.shape,
            "input0_addr": BaseOp.get_addr(lhs),
            "input1_addr": BaseOp.get_addr(rhs),
            "input0_name": BaseOp.get_io_name(lhs),
            "input1_name": BaseOp.get_io_name(rhs),
            "load_input_from_ddr": BaseOp.should_load_input_from_ddr(src_allocs),
            "name": name,
            "op": op.type,
            "output": out_tensor.shape,
            "output_addr": BaseOp.get_addr(dst_alloc),
            "output_name": BaseOp.get_io_name(dst_alloc),
            "prm_addr": param_addrs,
            "store_output_to_ddr": BaseOp.should_store_output_to_ddr(dst_alloc),
            "wgt_addr": weight_addrs,
            "padded_input": get_padded_shape_rev(lhs_tensor.shape),
            "padded_output": get_padded_shape_rev(out_tensor.shape),
        }
        return Mul(**entry)


class ConcatTypes(str, Enum):
    """Defines types of concatenation operators supported by the backend"""

    # The str mixin makes the member compare equal to its raw string value.
    CONCAT = "concat"


class Concat(BaseOp, BaseIO):
    """L2 tiling entry describing one two-input concatenation."""

    input: Tuple[int, ...]  # shape of the first input tensor
    input0_addr: Dict[int, int]  # first input tensor address per AIE4-column
    input1_addr: Dict[int, int]  # second input tensor address per AIE4-column
    input0_name: str  # name of the first input tensor
    input1_name: str  # name of the second input tensor
    op: str  # type of this operator
    output: Tuple[int, ...]
    padded_input: Tuple[int, ...]  # round-up input0 C-dim to nearest 64 multiple
    padded_output: Tuple[int, ...]  # round-up output C-dim to nearest 64 multiple

    @classmethod
    def build(
        cls,
        memory_config: MemoryConfig,
        model_ios: Tuple[Set[str], Set[str]],
        allocs: DictAlloc,
        name: str,
        attributes: Sequence[AttributeProto],
        op: Operation,
    ) -> Concat:
        """Create a Concat entry from the allocations of ``op``.

        Exactly two input allocations are expected; unpacking fails otherwise.
        Only the first input's shape is stored.
        """
        param_addrs, weight_addrs = BaseOp.get_param_weight_addr(memory_config)
        src_allocs, dst_alloc = BaseOp.get_allocs(model_ios, allocs, op)
        first, second = src_allocs
        first_tensor, out_tensor = first[1].tensor, dst_alloc[1].tensor

        entry = {
            "enable_L2_fusion": BaseOp.should_enable_fusion(src_allocs, dst_alloc),
            "input": first_tensor.shape,
            "input0_addr": BaseOp.get_addr(first),
            "input1_addr": BaseOp.get_addr(second),
            "input0_name": BaseOp.get_io_name(first),
            "input1_name": BaseOp.get_io_name(second),
            "load_input_from_ddr": BaseOp.should_load_input_from_ddr(src_allocs),
            "name": name,
            "op": op.type,
            "output": out_tensor.shape,
            "output_addr": BaseOp.get_addr(dst_alloc),
            "output_name": BaseOp.get_io_name(dst_alloc),
            "prm_addr": param_addrs,
            "store_output_to_ddr": BaseOp.should_store_output_to_ddr(dst_alloc),
            "wgt_addr": weight_addrs,
            "padded_input": get_padded_shape_rev(first_tensor.shape),
            "padded_output": get_padded_shape_rev(out_tensor.shape),
        }
        return Concat(**entry)


# streaming ops


class ResizeTypes(str, Enum):
    """Enum to define different types of resize operators"""

    # The str mixin makes the member compare equal to its raw string value.
    RESIZE = "resize"


class Resize(BaseOp, BaseIO):
    """L2 tiling entry describing one resize operator."""

    input: Tuple[int, ...]
    input_addr: Dict[int, int]  # input tensor address per AIE4-column
    input_name: str  # name of the input tensor
    op: str  # type of this operator
    output: Tuple[int, ...]
    padded_input: Tuple[int, ...]  # round-up input C-dim to nearest 64 multiple
    padded_output: Tuple[int, ...]  # round-up output C-dim to nearest 64 multiple

    @classmethod
    def build(
        cls,
        memory_config: MemoryConfig,
        model_ios: Tuple[Set[str], Set[str]],
        allocs: DictAlloc,
        name: str,
        attributes: Sequence[AttributeProto],
        op: Operation,
    ) -> Resize:
        """Create a Resize entry from the allocations of ``op``.

        ``attributes`` is accepted for interface parity but not read here.
        """
        param_addrs, weight_addrs = BaseOp.get_param_weight_addr(memory_config)
        src_allocs, dst_alloc = BaseOp.get_allocs(model_ios, allocs, op)
        src_alloc = src_allocs[0]  # only the first input allocation is recorded
        src_tensor, dst_tensor = src_alloc[1].tensor, dst_alloc[1].tensor

        entry = {
            "enable_L2_fusion": BaseOp.should_enable_fusion(src_allocs, dst_alloc),
            "input": src_tensor.shape,
            "input_addr": BaseOp.get_addr(src_alloc),
            "input_name": BaseOp.get_io_name(src_alloc),
            "load_input_from_ddr": BaseOp.should_load_input_from_ddr(src_allocs),
            "name": name,
            "op": op.type,
            "output": dst_tensor.shape,
            "output_addr": BaseOp.get_addr(dst_alloc),
            "output_name": BaseOp.get_io_name(dst_alloc),
            "prm_addr": param_addrs,
            "store_output_to_ddr": BaseOp.should_store_output_to_ddr(dst_alloc),
            "wgt_addr": weight_addrs,
            "padded_input": get_padded_shape_rev(src_tensor.shape),
            "padded_output": get_padded_shape_rev(dst_tensor.shape),
        }
        return Resize(**entry)


class SliceTypes(str, Enum):
    """Defines types of slice operators supported by the backend"""

    # The str mixin makes the member compare equal to its raw string value.
    SLICE = "slice"


class Slice(BaseOp, BaseIO):
    """L2 tiling entry describing one slice operator."""

    input: Tuple[int, ...]
    input_addr: Dict[int, int]  # input tensor address per AIE4-column
    input_name: str  # name of the input tensor
    op: str  # type of this operator
    output: Tuple[int, ...]
    padded_input: Tuple[int, ...]  # round-up input C-dim to nearest 64 multiple
    padded_output: Tuple[int, ...]  # round-up output C-dim to nearest 64 multiple

    @classmethod
    def build(
        cls,
        memory_config: MemoryConfig,
        model_ios: Tuple[Set[str], Set[str]],
        allocs: DictAlloc,
        name: str,
        attributes: Sequence[AttributeProto],
        op: Operation,
    ) -> Slice:
        """Create a Slice entry from the allocations of ``op``.

        ``attributes`` is accepted for interface parity but not read here.
        """
        param_addrs, weight_addrs = BaseOp.get_param_weight_addr(memory_config)
        sources, sink = BaseOp.get_allocs(model_ios, allocs, op)
        source = sources[0]  # only the first input allocation is recorded
        source_tensor, sink_tensor = source[1].tensor, sink[1].tensor

        entry = {
            "enable_L2_fusion": BaseOp.should_enable_fusion(sources, sink),
            "input": source_tensor.shape,
            "input_addr": BaseOp.get_addr(source),
            "input_name": BaseOp.get_io_name(source),
            "load_input_from_ddr": BaseOp.should_load_input_from_ddr(sources),
            "name": name,
            "op": op.type,
            "output": sink_tensor.shape,
            "output_addr": BaseOp.get_addr(sink),
            "output_name": BaseOp.get_io_name(sink),
            "prm_addr": param_addrs,
            "store_output_to_ddr": BaseOp.should_store_output_to_ddr(sink),
            "wgt_addr": weight_addrs,
            "padded_input": get_padded_shape_rev(source_tensor.shape),
            "padded_output": get_padded_shape_rev(sink_tensor.shape),
        }
        return Slice(**entry)


class GatherTypes(str, Enum):
    """Defines types of gather operators supported by the backend"""

    # The str mixin makes the member compare equal to its raw string value.
    GATHER = "gather"


class Gather(BaseOp, BaseIO):
    """Represents a gather operation"""

    input: Tuple[int, ...]  # shape of the first input tensor
    input_addr: Dict[int, int]  # address of the first input tensor per AIE4-column
    input_name: str  # name of the first input tensor
    op: str  # operator type, copied from `op.type`
    output: Tuple[int, ...]  # shape of the output tensor
    padded_input: Tuple[int, ...]  # input shape, C-dim rounded up to the nearest multiple of 64
    padded_output: Tuple[int, ...]  # output shape, C-dim rounded up to the nearest multiple of 64

    @classmethod
    def build(
        cls,
        memory_config: MemoryConfig,
        model_ios: Tuple[Set[str], Set[str]],
        allocs: DictAlloc,
        name: str,
        attributes: Sequence[AttributeProto],
        op: Operation,
    ) -> Gather:
        """Build the `Gather` tiling entry for `op` from its tensor allocations.

        Only the first input allocation supplies the input shape, address and
        name. `attributes` is accepted but not consulted.
        """
        prm_addr, wgt_addr = BaseOp.get_param_weight_addr(memory_config)
        input_allocs, output_alloc = BaseOp.get_allocs(model_ios, allocs, op)
        first_in = input_allocs[0]  # the schema describes the first input only

        # Resolve every derived value up front, then assemble the entry.
        fusion_enabled = BaseOp.should_enable_fusion(input_allocs, output_alloc)
        in_shape = first_in[1].tensor.shape
        in_addr = BaseOp.get_addr(first_in)
        in_name = BaseOp.get_io_name(first_in)
        load_from_ddr = BaseOp.should_load_input_from_ddr(input_allocs)
        op_type = op.type
        out_shape = output_alloc[1].tensor.shape
        out_addr = BaseOp.get_addr(output_alloc)
        out_name = BaseOp.get_io_name(output_alloc)
        store_to_ddr = BaseOp.should_store_output_to_ddr(output_alloc)

        return Gather(
            name=name,
            op=op_type,
            enable_L2_fusion=fusion_enabled,
            prm_addr=prm_addr,
            wgt_addr=wgt_addr,
            input=in_shape,
            input_addr=in_addr,
            input_name=in_name,
            load_input_from_ddr=load_from_ddr,
            output=out_shape,
            output_addr=out_addr,
            output_name=out_name,
            store_output_to_ddr=store_to_ddr,
            padded_input=get_padded_shape_rev(in_shape),
            padded_output=get_padded_shape_rev(out_shape),
        )


class GenericUnaryOp(BaseOp, BaseIO):
    """Represents a generic operation"""

    input: Tuple[int, ...]  # shape of the first input tensor
    input_addr: Dict[int, int]  # address of the first input tensor per AIE4-column
    input_name: str  # name of the first input tensor
    op: str  # operator type, copied from `op.type`
    output: Tuple[int, ...]  # shape of the output tensor
    padded_input: Tuple[int, ...]  # input shape, C-dim rounded up to the nearest multiple of 64
    padded_output: Tuple[int, ...]  # output shape, C-dim rounded up to the nearest multiple of 64

    @classmethod
    def build(
        cls,
        memory_config: MemoryConfig,
        model_ios: Tuple[Set[str], Set[str]],
        allocs: DictAlloc,
        name: str,
        attributes: Sequence[AttributeProto],
        op: Operation,
    ) -> GenericUnaryOp:
        """Build a `GenericUnaryOp` tiling entry for `op` from its tensor allocations.

        Only the first input allocation supplies the input shape, address and
        name. `attributes` is accepted but not consulted. The result is always
        a `GenericUnaryOp`, even when called through a subclass, because `cls`
        is not used to construct it.
        """
        prm_addr, wgt_addr = BaseOp.get_param_weight_addr(memory_config)
        input_allocs, output_alloc = BaseOp.get_allocs(model_ios, allocs, op)
        operand = input_allocs[0]  # the schema describes the first input only

        # Resolve every derived value up front, then assemble the entry.
        fusion_flag = BaseOp.should_enable_fusion(input_allocs, output_alloc)
        operand_shape = operand[1].tensor.shape
        operand_addr = BaseOp.get_addr(operand)
        operand_name = BaseOp.get_io_name(operand)
        load_flag = BaseOp.should_load_input_from_ddr(input_allocs)
        op_type = op.type
        result_shape = output_alloc[1].tensor.shape
        result_addr = BaseOp.get_addr(output_alloc)
        result_name = BaseOp.get_io_name(output_alloc)
        store_flag = BaseOp.should_store_output_to_ddr(output_alloc)

        return GenericUnaryOp(
            name=name,
            op=op_type,
            enable_L2_fusion=fusion_flag,
            prm_addr=prm_addr,
            wgt_addr=wgt_addr,
            input=operand_shape,
            input_addr=operand_addr,
            input_name=operand_name,
            load_input_from_ddr=load_flag,
            output=result_shape,
            output_addr=result_addr,
            output_name=result_name,
            store_output_to_ddr=store_flag,
            padded_input=get_padded_shape_rev(operand_shape),
            padded_output=get_padded_shape_rev(result_shape),
        )


# The operator classes below declare no fields or methods of their own; each one
# takes its schema and `build` from its base class. NOTE: `GenericUnaryOp.build`
# constructs `GenericUnaryOp` directly, so building through one of its subclasses
# yields a `GenericUnaryOp` instance rather than an instance of the subclass.
class Reshape(GenericUnaryOp):
    """Represents a reshape operation"""


class Quant(GenericUnaryOp):
    """Represents a quantization operation"""


class Sigmoid(GenericUnaryOp):
    """Represents a sigmoid operation"""


class Unsqueeze(GenericUnaryOp):
    """Represents an unsqueeze operation"""


class Transpose(GenericUnaryOp):
    """Represents a transpose operation"""


class GroupNormalization(GenericUnaryOp):
    """Represents a group normalization operation"""


class Dequant(GenericUnaryOp):
    """Represents a dequantization operation"""


class Softmax(GenericUnaryOp):
    """Represents a softmax operation"""


class InstanceNormalization(GenericUnaryOp):
    """Represents an instance normalization operation"""


class LayerNormalization(GenericUnaryOp):
    """Represents a layer normalization operation"""


class LpNormalization(GenericUnaryOp):
    """Represents an lp-normalization operation"""


class Squeeze(GenericUnaryOp):
    """Represents a squeeze operation"""


class GroupQueryAttention(GenericUnaryOp):
    """Represents a group query attention operation"""


class PWLA(GenericUnaryOp):
    """Represents a PWLA operation"""


class Gelu(GenericUnaryOp):
    """Represents a GELU (Gaussian Error Linear Unit) activation operation"""


class Expand(GenericUnaryOp):
    """Represents an Expand operation"""


class ReduceSum(GenericUnaryOp):
    """Represents a ReduceSum reduction operation"""


class Tanh(GenericUnaryOp):
    """Represents a Tanh activation operation"""


class Sqrt(GenericUnaryOp):
    """Represents an element-wise Sqrt operation"""


# Unlike its neighbours, Concat2 derives from Concat rather than GenericUnaryOp.
class Concat2(Concat):
    """Represents a concatenation of two inputs"""


class Flatten(GenericUnaryOp):
    """Represents a flatten operation which flattens the input tensor into a 2D matrix"""


class Identity(GenericUnaryOp):
    """Represents an Identity operation"""


class QuickGelu(GenericUnaryOp):
    """Represents a QuickGelu operation"""


class GenericBinaryOp(BaseOp, BaseIO):
    """Represents a generic binary operation e.g., MatMul"""

    input0: Tuple[int, ...]  # shape of the first input tensor
    input1: Tuple[int, ...]  # shape of the second input tensor
    input0_addr: Dict[int, int]  # address of the first input tensor per AIE4-column
    input1_addr: Dict[int, int]  # address of the second input tensor per AIE4-column
    input0_name: str  # name of the first input tensor
    input1_name: str  # name of the second input tensor
    op: str  # operator type, copied from `op.type`
    output: Tuple[int, ...]  # shape of the output tensor
    padded_input0: Tuple[int, ...]  # input0 shape, C-dim rounded up to the nearest multiple of 64
    padded_input1: Tuple[int, ...]  # input1 shape, C-dim rounded up to the nearest multiple of 64
    padded_output: Tuple[int, ...]  # output shape, C-dim rounded up to the nearest multiple of 64

    @classmethod
    def build(
        cls,
        memory_config: MemoryConfig,
        model_ios: Tuple[Set[str], Set[str]],
        allocs: DictAlloc,
        name: str,
        attributes: Sequence[AttributeProto],
        op: Operation,
    ) -> GenericBinaryOp:
        """Build a `GenericBinaryOp` tiling entry for `op` from its tensor allocations.

        The input allocations are unpacked into exactly two operands, so any
        other count raises on unpacking. `attributes` is accepted but not
        consulted. The result is always a `GenericBinaryOp` because `cls` is
        not used to construct it.
        """
        prm_addr, wgt_addr = BaseOp.get_param_weight_addr(memory_config)
        input_allocs, output_alloc = BaseOp.get_allocs(model_ios, allocs, op)
        lhs, rhs = input_allocs

        # Resolve every derived value up front, then assemble the entry.
        fusion_flag = BaseOp.should_enable_fusion(input_allocs, output_alloc)
        lhs_shape = lhs[1].tensor.shape  # type: ignore
        rhs_shape = rhs[1].tensor.shape  # type: ignore
        lhs_addr = BaseOp.get_addr(lhs)
        rhs_addr = BaseOp.get_addr(rhs)
        lhs_name = BaseOp.get_io_name(lhs)
        rhs_name = BaseOp.get_io_name(rhs)
        load_flag = BaseOp.should_load_input_from_ddr(input_allocs)
        op_type = op.type
        result_shape = output_alloc[1].tensor.shape  # type: ignore
        result_addr = BaseOp.get_addr(output_alloc)
        result_name = BaseOp.get_io_name(output_alloc)
        store_flag = BaseOp.should_store_output_to_ddr(output_alloc)

        return GenericBinaryOp(
            name=name,
            op=op_type,
            enable_L2_fusion=fusion_flag,
            prm_addr=prm_addr,
            wgt_addr=wgt_addr,
            input0=lhs_shape,
            input1=rhs_shape,
            input0_addr=lhs_addr,
            input1_addr=rhs_addr,
            input0_name=lhs_name,
            input1_name=rhs_name,
            load_input_from_ddr=load_flag,
            output=result_shape,
            output_addr=result_addr,
            output_name=result_name,
            store_output_to_ddr=store_flag,
            padded_input0=get_padded_shape_rev(lhs_shape),
            padded_input1=get_padded_shape_rev(rhs_shape),
            padded_output=get_padded_shape_rev(result_shape),
        )


class GenericUnaryOrBinaryOp(BaseOp, BaseIO):
    """Represents an operation with one or two inputs"""

    @classmethod
    def build(
        cls,
        memory_config: MemoryConfig,
        model_ios: Tuple[Set[str], Set[str]],
        allocs: DictAlloc,
        name: str,
        attributes: Sequence[AttributeProto],
        op: Operation,
    ) -> GenericUnaryOp | GenericBinaryOp:
        """Dispatch to the unary or binary builder based on the input count.

        Returns:
            A `GenericUnaryOp` when `op` has one input allocation, or a
            `GenericBinaryOp` when it has two.

        Raises:
            ValueError: if `op` has any other number of input allocations.
        """
        input_allocs, _ = BaseOp.get_allocs(model_ios, allocs, op)
        num_inputs = len(input_allocs)

        if num_inputs not in (1, 2):
            raise ValueError(f"Operation {op.id} has unsupported number of inputs: {num_inputs}")

        # Call the builder on the concrete class directly. Both builders
        # construct their own class by name and ignore `cls`, so wrapping them
        # in a throw-away dynamic subclass per call changes nothing about the
        # result; it only costs a class creation and leaves extra same-named
        # subclasses behind.
        op_cls = GenericUnaryOp if num_inputs == 1 else GenericBinaryOp
        return op_cls.build(memory_config, model_ios, allocs, name, attributes, op)


# MatMul, Div and Sub define nothing of their own: the inherited
# `GenericUnaryOrBinaryOp.build` picks the unary or binary builder depending on
# whether the operation has one or two input allocations.
class MatMul(GenericUnaryOrBinaryOp):
    """Represents a matrix multiplication operation"""


class Div(GenericUnaryOrBinaryOp):
    """Represents a binary Div operation"""


class Sub(GenericUnaryOrBinaryOp):
    """Represents a binary Sub operation"""


class GenericUnsupportedOp(BaseOp, NpuMetadata):
    """Represents a generic unsupported operation"""

    enable_L2_fusion: bool  # whether L2 fusion is enabled for this op
    name: str  # operator name
    op: str  # operator type

    @classmethod
    def build(
        cls,
        memory_config: MemoryConfig,
        model_ios: Tuple[Set[str], Set[str]],
        allocs: DictAlloc,
        name: str,
        attributes: Sequence[AttributeProto],
        op: Operation,
    ) -> GenericUnsupportedOp:
        """Build the placeholder entry for an operator that cannot be tiled.

        Only `name` and `op.type` are used; the remaining parameters exist so
        the signature matches the other operator builders. Fusion and
        compilability are both reported as disabled.
        """
        fields = {
            "name": name,
            "op": op.type,
            "enable_L2_fusion": False,
            "is_compilable": False,
        }
        return GenericUnsupportedOp(**fields)


# Names of pointwise operators whose inputs can be overwritten by the output
POINTWISE_OPS = {Add.__name__}

# Non-standard operator names; `Tiler.get_op_type` accepts these as-is,
# alongside the names that are registered ONNX operators
NONSTANDARD_OPS = {Quant.__name__, Dequant.__name__}


# A type to represent the L2 fusion tiling schema: the union of operator classes
# that `Tiler.get_operator_tiling` is annotated to return.
TilingOperator = (
    Conv
    | MaxPool
    | Add
    | Gemm
    | Reshape
    | MatMul
    | Unsqueeze
    | Transpose
    | Concat
    | Quant
    | Sigmoid
    | Gelu
    | GroupNormalization
    | Dequant
    | Softmax
    | InstanceNormalization
    | LayerNormalization
    | LpNormalization
    | Squeeze
    | GroupQueryAttention
    | Resize
    | Slice
    | Gather
    | GenericUnaryOp
    | GenericBinaryOp
    | Mul
    | GlobalAveragePool
    | Concat2
    | Flatten
    | GenericUnsupportedOp
)
# Mapping from a string key to the tiling entry of one operator
# (presumably keyed by operator name/id -- TODO confirm against the producer)
TilingSchema = Dict[str, TilingOperator]
# (node name, node op_type, node attributes), as recorded from the ONNX graph nodes
OperatorMetadata = Tuple[str, str, Sequence[AttributeProto]]


class Tiler:
    """Computes L2 fusion related metadata associated with operators in ONNX graph"""

    model: ModelProto
    scope: str = "root"  # prefix passed to scoped_id() for node and tensor names

    def __init__(self, model: ModelProto, memory_config: MemoryConfig | None):
        """Index the graph nodes and the model inputs/outputs of `model`."""
        self.model = model
        self.memory_config = memory_config
        self.operator_metadata = self._extract_operator_metadata()
        self.model_ios = self._extract_model_ios()
        self.allocs: DictAlloc = {}  # tensor id -> allocation, filled incrementally
        self.outputs_alloc = False  # set once the graph outputs have been registered

    def _extract_operator_metadata(self) -> Dict[str, OperatorMetadata]:
        """Map every node's scoped id to its (name, op_type, attributes) triple."""
        return {
            scoped_id(self.scope, node.name): (node.name, node.op_type, node.attribute)
            for node in self.model.graph.node
        }

    def _extract_model_ios(self) -> Tuple[Set[str], Set[str]]:
        """Return the scoped ids of the graph inputs and of the graph outputs."""
        inputs = {scoped_id(self.scope, vi.name) for vi in self.model.graph.input}
        outputs = {scoped_id(self.scope, vi.name) for vi in self.model.graph.output}
        return (inputs, outputs)

    def _alloc_outputs(self, alloc: Alloc) -> None:
        """Register a SPILLED allocation for every graph output, once per Tiler.

        The dtype of each output tensor is copied from `alloc`; nothing else of
        `alloc` is read. Later calls are no-ops.
        """
        if self.outputs_alloc:
            return

        # NOTE(review): every graph output inherits the dtype of this single
        # allocation -- confirm that this is intended for mixed-dtype models.
        dtype = alloc[1].tensor.dtype
        for output in self.model.graph.output:
            shape = tuple(dim.dim_value for dim in output.type.tensor_type.shape.dim)
            sid = scoped_id(self.scope, output.name)
            self.allocs[sid] = (
                AllocationResult.SPILLED,
                TensorAllocation(
                    tensor=Tensor(id=sid, shape=shape, dtype=dtype),  # type: ignore
                    range=TensorLifetime(start=0, end=0),
                    location=TensorLocation.SPILLED,
                ),
            )

        self.outputs_alloc = True

    # Cached on a staticmethod rather than on an instance method: an lru_cache
    # wrapped around a method keys on `self`, which keeps instances alive and,
    # with maxsize=1, recomputes whenever a different Tiler calls it. The
    # result does not depend on the instance, so one shared entry is enough.
    @staticmethod
    @lru_cache(maxsize=1)
    def get_supported_onnx_ops() -> Set[str]:
        """Get the names of all operators that have a registered ONNX schema"""
        return {schema.name for schema in onnx.defs.get_all_schemas()}

    def is_an_onnx_op(self, operator_name: str) -> bool:
        """Check if an operator is supported by ONNX."""
        return operator_name in self.get_supported_onnx_ops()

    def get_op_type(self, op_type: str, op_name: str) -> str:
        """Get the base operator type by splitting on underscore.

        Both `op_type` and the last path component of `op_name` are cut at
        their first underscore. The shortened `op_type` wins when it is an ONNX
        operator or listed in NONSTANDARD_OPS; otherwise the shortened name is
        used if it is an ONNX operator; otherwise the shortened `op_type` is
        returned unchanged.
        """
        # A node name may be empty, in which case Path(...).parts is an empty
        # tuple; fall back to "" instead of failing on parts[-1].
        name_parts = Path(op_name).parts
        leaf = name_parts[-1] if name_parts else ""
        op_name = leaf.split("_")[0]
        op_type = op_type.split("_")[0]
        if self.is_an_onnx_op(op_type) or op_type in NONSTANDARD_OPS:
            return op_type
        if self.is_an_onnx_op(op_name):
            return op_name
        return op_type

    def get_op_metadata(self, op: Operation) -> OperatorMetadata:
        """Get operator metadata for a given operation

        Raises:
            ValueError: if `op.id` does not correspond to a node of the model graph.
        """
        if op.id not in self.operator_metadata:
            raise ValueError(f"{op.id} not found in model graph")
        return self.operator_metadata[op.id]

    def get_operator_tiling(
        self, op: Operation, allocs: List[Tuple[AllocationResult, TensorAllocation]], is_nonwaic: bool = False
    ) -> TilingOperator:
        """Get tiling for operator

        Args:
            op: operation whose id must match a node of the model graph.
            allocs: allocations to record for this operator; the first one also
                provides the dtype used when registering the graph outputs.
            is_nonwaic: when True, the operator-registry support check is skipped.

        Returns:
            The entry built by the matching operator class, or a
            `GenericUnsupportedOp` when the operator schedule is unsupported or
            building the entry raises.

        Raises:
            ValueError: if `op.id` is not found in the model graph.
            IndexError: if `allocs` is empty.
            RuntimeError: if no operator class matches the resolved type name.
        """
        # Find the name, type, and metadata for this particular operator node
        # (raises ValueError before any state is touched if the node is unknown)
        op_name, op_type, op_attributes = self.get_op_metadata(op)

        self._alloc_outputs(allocs[0])
        for alloc in allocs:
            self.allocs[alloc[1].tensor.id] = alloc

        # If the operator is not present in the operator registry, it's unsupported
        if not is_nonwaic and not OpSupport.is_op_schedule_supported(op_type):
            logger.info("found unsupported op='%s'", op_type)
            return GenericUnsupportedOp(enable_L2_fusion=False, name=op_name, op=op_type, is_compilable=False)

        # If the operator is supported, check if the allocator is aware of it
        op_classname = self.get_op_type(op_type, op_name)
        try:
            op_class = subclass_where(BaseOp, __name__=op_classname)  # type: ignore
        except Exception as e:
            raise RuntimeError(f"Operator type: {op_type} is supported, but the allocator is unaware: {e}") from e

        # Best effort: an operator whose entry cannot be built is reported as
        # unsupported instead of aborting the whole tiling pass.
        try:
            config = op_class.build(self.memory_config, self.model_ios, self.allocs, op_name, op_attributes, op)
        except Exception as e:  # pylint: disable=W0718
            logger.info("found unsupported op='%s', exception='%s'", op_type, e)
            config = GenericUnsupportedOp.build(self.memory_config, self.model_ios, self.allocs, op_name, op_attributes, op)
        return config
