"""Tiler for Resize Op"""
from dataclasses import dataclass, field

from utils.utils_common import BaseShape, BaseDims, ceildiv, DMA_ONLY_SPATIAL_SPLIT_MODES

from dmacompiler import set_dev_gen, DevGen

# Select the AIE4 device generation for dmacompiler. This call runs at import
# time, so merely importing this module performs it (presumably it sets
# process-wide dmacompiler state — confirm before importing alongside other
# device generations).
set_dev_gen(DevGen.Aie4)


class ResizeShape(BaseShape):
    """Define the Shape of a Resize Op.

    Extends ``BaseShape`` with the IFM data width and the per-axis
    interpolation counts. In this file it is constructed with ``ifm`` and
    ``ofm`` keywords as well (presumably fields provided by ``BaseShape``),
    each later unpacked by ``ResizeMapping`` in (N, Y, X, C) order.
    """
    # IFM data width in bytes; ResizeMapping derives ifm_bits and ofm_bits as
    # ifm_bytes * 8, i.e. the output uses the same width as the input.
    ifm_bytes: int
    # Interpolation counts in (Y, X) order; ResizeMapping stores index 0 as
    # scale_Y and index 1 as scale_X.
    num_interpolations: tuple[int, int]


@dataclass(slots=True)
class ResizeMapping(BaseDims):
    """Define the Mapping of a Resize Op.

    All tiling parameters are derived in ``__post_init__`` from ``shape`` and
    ``spatial_split``. These are the full IFM/OFM dims, the subvolume dims,
    the per-dimension split factors and loop counts, the element bit widths
    and the Y/X interpolation scales.

    Attributes:
        shape: IFM/OFM dims (each unpacked as (N, Y, X, C)), IFM data width in
            bytes, and the (Y, X) interpolation counts.
        spatial_split: Split factor per dimension keyed by "N", "Y", "X", "C";
            a missing key is treated as a split of 1. Defaults to the first
            entry of ``DMA_ONLY_SPATIAL_SPLIT_MODES``.
    """
    # Build a fresh default shape per instance (same pattern as
    # ``spatial_split``) instead of sharing one ResizeShape object across
    # every ResizeMapping created without an explicit shape.
    shape: ResizeShape = field(default_factory=lambda: ResizeShape(
        ifm=(0, 0, 0, 0),
        ofm=(0, 0, 0, 0),
        ifm_bytes=2,
        num_interpolations=(0, 0),
    ))
    spatial_split: dict[str, int] = field(default_factory=lambda: {
        "N": DMA_ONLY_SPATIAL_SPLIT_MODES[0][0],
        "Y": DMA_ONLY_SPATIAL_SPLIT_MODES[0][1],
        "X": DMA_ONLY_SPATIAL_SPLIT_MODES[0][2],
        "C": DMA_ONLY_SPATIAL_SPLIT_MODES[0][3],
    })

    Ni: int = 1
    No: int = 1
    Nis: int = 1
    Nos: int = 1
    C_loop: int = 1
    ifm_bits: int = 0
    ifm_bytes: int = 0
    ofm_bits: int = 0
    scale_Y: int = 1
    scale_X: int = 1

    def __post_init__(self) -> None:
        """Derive every tiling attribute from ``shape`` and ``spatial_split``."""
        # Full tensor dims, unpacked in (N, Y, X, C) order.
        self.Ni, self.Yi, self.Xi, self.Ci = self.shape.ifm
        self.No, self.Yo, self.Xo, self.Co = self.shape.ofm

        # Split factors; any dimension absent from spatial_split is unsplit.
        split_N = self.spatial_split.get("N", 1)
        split_Y = self.spatial_split.get("Y", 1)
        split_X = self.spatial_split.get("X", 1)
        split_C = self.spatial_split.get("C", 1)

        # IFM subvolume: the whole N/X/C extent, a single row in Y.
        self.Nis = self.Ni
        self.Yis = 1
        self.Xis = self.Xi
        self.Cis = self.Ci

        # OFM subvolume.
        # NOTE(review): Xos/Cos come from the IFM dims (Xi/Ci), not Xo/Co,
        # while Nos uses No. Left as-is; confirm this is intended (e.g. the
        # interpolation replication is applied outside the subvolume).
        self.Nos = self.No
        self.Yos = 1
        self.Xos = self.Xi
        self.Cos = self.Ci

        self.N_split = split_N
        self.Y_split = split_Y
        self.X_split = split_X
        self.Co_split = split_C

        # Loop count per dim = ceildiv(full dim, subvolume * split). Because
        # the N/X/C subvolumes span the whole extent, those loops are 1 for
        # positive dims and splits; Y iterates over rows (Yis == 1).
        # NOTE(review): a zero-sized dim (as in the default shape) makes the
        # divisor 0 here — presumably ceildiv then raises; confirm callers
        # always pass a real shape.
        self.N_loop = ceildiv(self.Ni, self.Ni * split_N)
        self.Y_loop = ceildiv(self.Yi, split_Y)
        self.X_loop = ceildiv(self.Xi, self.Xi * split_X)
        self.C_loop = ceildiv(self.Ci, self.Ci * split_C)

        # Element widths: the output uses the same data width as the input.
        self.ifm_bytes = self.shape.ifm_bytes
        self.ifm_bits = self.shape.ifm_bytes * 8
        self.ofm_bits = self.shape.ifm_bytes * 8

        # Interpolation counts are stored as (Y, X).
        self.scale_Y = self.shape.num_interpolations[0]
        self.scale_X = self.shape.num_interpolations[1]
