"""
This script defines the schema for L3 fusion tiling JSON
"""

from __future__ import annotations
from typing import Dict, Sequence, Tuple, Set, Type, TypeVar
from onnx import AttributeProto, ModelProto
from pydantic import BaseModel

# pylint: disable=import-error,redefined-outer-name
from graph.L2_fusion_tiling import Tiler as L2Tiler, NpuMetadata, OpSupport
from graph.L2_fusion_tiling import BaseOp as L2BaseOp
from graph.allocation_types import Alloc, AllocList
from graph.tensor_types import Operation, TensorLocation

# Address of one tensor allocation, in the order Builder.get_addr builds it:
# (tensor bin value, allocation block start, tensor size).
Addr = Tuple[int, int, int]


# Type variable bound to pydantic models, so that as_base() is typed as
# returning an instance of the base class it was asked for.
TypeOfBaseModel = TypeVar('TypeOfBaseModel', bound=BaseModel)


def as_base(obj: TypeOfBaseModel, base_cls: Type[TypeOfBaseModel]) -> TypeOfBaseModel:
    """Re-create ``obj`` as a plain instance of its base model class.

    The object's data is dumped with ``model_dump`` and re-validated through
    ``base_cls.model_validate``, so the result is a new ``base_cls`` instance
    rather than ``obj`` itself.

    Raises:
        TypeError: if ``obj`` is not an instance of ``base_cls``.
    """
    if isinstance(obj, base_cls):
        payload = obj.model_dump()
        return base_cls.model_validate(payload)
    raise TypeError(f"{obj} is not a {base_cls}")


class GenericOperator(NpuMetadata):
    """Per-operator metadata record carrying input/output tensor addresses.

    Extends ``NpuMetadata`` with the operator's name, type and address maps,
    and can render itself into the L3 fusion-tiling layout.
    """

    name: str  # operator name
    op_type: str  # operator type string
    input_addr: Dict[str, Addr]  # addresses of the operator's inputs
    output_addr: Dict[str, Addr]  # addresses of the operator's outputs

    @staticmethod
    def format_addr(addr: Dict[str, Addr], prefix: str) -> Dict[str, Addr]:
        """Re-key an address map for the L3 schema.

        The incoming keys are discarded. A single address is keyed by
        ``prefix`` alone; several addresses are keyed ``<prefix>0``,
        ``<prefix>1``, ... in the map's iteration order; an empty map gives
        an empty dict.
        """
        entries = list(addr.values())
        if len(entries) == 1:
            return {prefix: entries[0]}
        # Covers both the multi-entry case and the empty case (-> {}).
        return {f"{prefix}{idx}": entry for idx, entry in enumerate(entries)}

    def to_l3(self) -> Dict:
        """Render this operator in the L3 fusion-tiling layout.

        Returns a dict whose ``"L3"`` entry holds the formatted input
        (``ifm*``) and output (``ofm*``) addresses, followed by the fields of
        the ``NpuMetadata`` base model as produced by ``model_dump``.
        """
        l3_section = GenericOperator.format_addr(self.input_addr, "ifm")
        l3_section.update(GenericOperator.format_addr(self.output_addr, "ofm"))
        rendered = {"L3": l3_section}
        rendered.update(as_base(self, NpuMetadata).model_dump())
        return rendered

    @staticmethod
    def update_l3(op: Operation, allocs: Dict[str, Alloc]) -> Dict:
        """Compute the L3 address entries (``ifm*`` / ``ofm*``) for ``op``.

        Allocations are looked up with empty model-IO sets and addressed by
        position (``Builder.get_enumerated_addr``), so inputs that would share
        an IO name still produce separate entries.
        """
        no_model_ios = (set(), set())
        in_allocs, out_alloc = L2BaseOp.get_allocs(no_model_ios, allocs, op, strict=False)
        in_addr = Builder.get_enumerated_addr(in_allocs)
        out_addr = Builder.get_enumerated_addr([out_alloc])
        return {
            **GenericOperator.format_addr(in_addr, "ifm"),
            **GenericOperator.format_addr(out_addr, "ofm"),
        }


class Builder(OpSupport):
    """Factory that builds ``GenericOperator`` records from tensor allocations.

    Also exposes helpers that turn a list of allocations into ``Addr``
    tuples, either keyed by IO name or by list position.
    """

    @staticmethod
    def _alloc_to_addr(alloc: Alloc) -> Addr:
        """Return ``(tensor bin value, block start, tensor size)`` for one alloc."""
        info = alloc[1]
        return (info.tensor.bin.value, info.block.start, info.tensor.size)

    @staticmethod
    def _has_known_location(alloc: Alloc) -> bool:
        """True unless the allocation's location is ``TensorLocation.UNKNOWN``."""
        return alloc[1].location != TensorLocation.UNKNOWN

    @staticmethod
    def get_addr(allocs: AllocList) -> Dict[str, Addr]:
        """Map each allocation's IO name to its address.

        Allocations whose location is ``TensorLocation.UNKNOWN`` are skipped.
        If two allocations map to the same IO name, the later one overwrites
        the earlier; use ``get_enumerated_addr`` to keep them distinct.
        """
        return {
            L2BaseOp.get_io_name(alloc): Builder._alloc_to_addr(alloc)
            for alloc in allocs
            if Builder._has_known_location(alloc)
        }

    @staticmethod
    def get_enumerated_addr(allocs: AllocList) -> Dict[str, Addr]:
        """Map each allocation's position (as a string) to its address.

        Positions are taken from ``allocs`` *before* filtering, so skipping
        ``TensorLocation.UNKNOWN`` allocations can leave gaps in the keys.
        """
        return {
            str(position): Builder._alloc_to_addr(alloc)
            for position, alloc in enumerate(allocs)
            if Builder._has_known_location(alloc)
        }

    @classmethod
    def build(
        cls,
        model_ios: Tuple[Set[str], Set[str]],
        allocs: Dict[str, Alloc],
        op_name: str,
        op_type: str,
        _op_attributes: Sequence[AttributeProto],
        op: Operation,
    ) -> GenericOperator:
        """Build the ``GenericOperator`` record for ``op``.

        Args:
            model_ios: pair of name sets forwarded to ``L2BaseOp.get_allocs``.
            allocs: known allocations, keyed by tensor id.
            op_name: name stored on the record.
            op_type: type string stored on the record.
            _op_attributes: ONNX attributes; not used by this builder.
            op: the operation whose inputs/outputs are resolved.

        Returns:
            A ``GenericOperator`` with input/output addresses keyed by IO name.
        """
        input_allocs, output_alloc = L2BaseOp.get_allocs(model_ios, allocs, op, strict=False)
        return GenericOperator(
            name=op_name,
            op_type=op_type,
            input_addr=cls.get_addr(input_allocs),
            output_addr=cls.get_addr([output_alloc]),
            # NOTE(review): compilability is decided from ``op.type``, not the
            # ``op_type`` argument stored above — confirm this is intended.
            is_compilable=cls.is_op_schedule_supported(op.type),
        )


# A type to represent the L3 fusion tiling schema: a mapping from a string key
# (presumably the operator name — confirm against callers) to that operator's
# GenericOperator record.
TilingSchema = Dict[str, GenericOperator]


class Tiler(L2Tiler):
    """Collects L3-level metadata for the operators of an ONNX graph."""

    def __init__(self, model: ModelProto):
        # The L2 tiler's second constructor argument is explicitly None here.
        super().__init__(model, None)

    def get_operator_tiling(
        self,
        op: Operation,
        allocs: AllocList,
        is_nonwaic: bool = False,
    ) -> GenericOperator:
        """Record ``allocs`` and build the ``GenericOperator`` for ``op``.

        Args:
            op: operator to describe; ``op.id`` must be a key of
                ``self.operator_metadata``.
            allocs: allocations to record. The first one is passed to
                ``_alloc_outputs``; every one is stored in ``self.allocs``
                under its tensor id (later entries win on duplicate ids).
            is_nonwaic: not used by this implementation.

        Raises:
            ValueError: if ``op.id`` is not in ``self.operator_metadata``.
        """
        if op.id not in self.operator_metadata:
            raise ValueError(f"{op.id} not found in model graph")

        self._alloc_outputs(allocs[0])
        for entry in allocs:
            tensor_id = entry[1].tensor.id
            self.allocs[tensor_id] = entry

        name, raw_type, attributes = self.operator_metadata[op.id]
        resolved_type = self.get_op_type(raw_type, name)
        return Builder.build(
            self.model_ios,
            self.allocs,
            name,
            resolved_type,
            attributes,
            op,
        )
