from typing import Tuple, Dict, List

from dataclasses import dataclass

from dataflow.conv.conv_common import ConvPingPong, PingPong, conv_input, iceil
from dataflow.dataflow_common import ceildiv
from OGOAT.src.Tiler.utils import create_grid, filter_grid
from dmacompiler import compute_buffer_size, set_dev_gen, DevGen, config
from itertools import product
import math

from OGOAT.src.Scheduling_Engine.infra.const import MEMTILE_SIZE, CORE_BANK_SIZE

# Module-level side effect: selects the Aie2p device generation in dmacompiler
# as soon as this module is imported.
set_dev_gen(DevGen.Aie2p)


@dataclass
class TemporalSplits:
    """Temporal iteration counts needed to cover a layer with one sub-volume.

    Each ``*_loop`` attribute is the ceiling of the layer extent divided by
    (sub-volume extent * spatial split) along that dimension. ``X_loop`` is
    additionally rounded up to the next power of two, and ``loops`` is the
    product of all four counts.

    NOTE: the ``@dataclass`` decorator is kept for backward compatibility even
    though no fields are declared; all state is assigned in ``__init__``.

    Args:
        layer: provides ``Yo``, ``Xo``, ``Co``, ``Ci`` and ``is_standalone_dwc``.
        convsubvol: provides ``Yos``, ``Xos``, ``Cos`` and ``Cis``.
        spatial_splits: provides ``Y_split``, ``X_split``, ``Co_split`` and
            ``Ci_split``.
    """

    def __init__(self, layer, convsubvol, spatial_splits):
        self.Y_loop = ceildiv(layer.Yo, convsubvol.Yos * spatial_splits.Y_split)
        self.Co_loop = ceildiv(layer.Co, convsubvol.Cos * spatial_splits.Co_split)

        # Standalone depthwise conv never iterates over input channels.
        if layer.is_standalone_dwc:
            self.Ci_loop = 1
        else:
            self.Ci_loop = ceildiv(layer.Ci, convsubvol.Cis * spatial_splits.Ci_split)

        # Round the X iteration count up to the next power of two using pure
        # integer arithmetic: float log2/ceil can misround for large values
        # and raises on a non-positive count, which is clamped to 1 here.
        x_iters = int(ceildiv(layer.Xo, convsubvol.Xos * spatial_splits.X_split))
        self.X_loop = (1 << (x_iters - 1).bit_length()) if x_iters > 1 else 1

        self.loops = self.Y_loop * self.Co_loop * self.Ci_loop * self.X_loop


@dataclass
class SpatialSplits:
    """Spatial split factors read from an overlay's ``core_splits`` table.

    ``Y_split``, ``X_split`` and ``Co_split`` are the three entries of
    ``core_splits['ofm']``; ``Ci_split`` is entry 2 of ``core_splits['ifm']``.
    """

    def __init__(self, overlay):
        split_table = overlay.core_splits
        ofm_splits = split_table['ofm']
        # The 'ofm' entry must unpack into exactly three values.
        self.Y_split, self.X_split, self.Co_split = ofm_splits
        ifm_splits = split_table['ifm']
        self.Ci_split = ifm_splits[2]


@dataclass
class ConvSubVol:
    """Extents of one convolution sub-volume.

    ``Xis``/``Yis``/``Cis`` are the input-side extents and ``Xos``/``Yos``/
    ``Cos`` the output-side extents along X, Y and channel.

    The fields are declared so that ``@dataclass`` generates a matching
    ``__init__`` (same parameter names and order as before) together with a
    meaningful ``__repr__`` and a field-wise ``__eq__``. Previously the
    decorator saw no fields, so every instance compared equal and printed as
    ``ConvSubVol()``.
    """

    Xis: int
    Yis: int
    Cis: int
    Xos: int
    Yos: int
    Cos: int


@dataclass
class L1Buffers:
    # Buffer sizes and core-memory addresses for one candidate conv tiling.
    # NOTE(review): decorated with @dataclass but no fields are declared; every
    # attribute is assigned inside __init__.
    def __init__(self, convsubv: ConvSubVol, conv_pp: ConvPingPong, kernel, layer, overlay, enable_add):
        """Compute buffer sizes and place them in core memory for one tiling.

        Two address layouts are tried:

        1. Split layout: kernel params, ifm/tdm/ofm ping buffers and the
           scratch buffer grow upward from address 0, with the ping weights
           ending at ``2 * CORE_BANK_SIZE``. The pong-side buffers (ifm pong,
           tdm pong, tmp buffer, optional add-ifm, ifm sum) grow upward from
           ``2 * CORE_BANK_SIZE``, with the pong weights ending at
           ``overlay.coretile_capacity_bytes``.
        2. Fallback layout (used only when layout 1 does not fit): all buffers
           are placed one after another starting at address 0.

        ``self.core_mem_constraint`` is True when the layout finally chosen
        fits, False otherwise.

        Args:
            convsubv: sub-volume extents (``Xis``, ``Yis``, ``Cis``, ``Xos``,
                ``Yos``, ``Cos``).
            conv_pp: per-buffer ping/pong enables (``ifm``, ``ofm``, ``wgt``,
                ``tdm``), each exposing ``.ping`` and ``.pong``.
            kernel: supplies element byte widths, alignment, granularity and
                the ``is_a16w8`` / ``is_a8w8`` / ``is_xint8`` flags.
            layer: supplies kernel size (``Ky``, ``Kx``), strides (``Sy``,
                ``Sx``) and ``is_standalone_dwc``.
            overlay: supplies ``coretile_capacity_bytes``.
            enable_add: when truthy, an extra add-ifm region of one ofm buffer
                is reserved before the ifm-sum buffer.
        """

        # Calculate memory sizes
        self.ifm_size = (convsubv.Cis * convsubv.Yis * convsubv.Xis * kernel.ifm_bytes)
        # Weight buffer = aligned weights + aligned (qdq params + bias).
        # Standalone depthwise conv uses Cis channels instead of Cos * Cis.
        # NOTE(review): iceil presumably rounds its first argument up to a
        # multiple of the second -- confirm in conv_common.
        self.wgt_size = (
                iceil(((layer.Ky * layer.Kx * kernel.wgt_bytes) * (
                    convsubv.Cis if layer.is_standalone_dwc else (convsubv.Cos * convsubv.Cis))), kernel.mem_align) +
                iceil((kernel.qdq_param_size + (
                            kernel.bias_bytes * (convsubv.Cis if layer.is_standalone_dwc else convsubv.Cos))),
                      kernel.mem_align)
        )
        self.ofm_size = (convsubv.Cos * convsubv.Yos * convsubv.Xos * kernel.ofm_bytes)
        # Half of the full Cos * Yos * Xos * tdm_bytes volume.
        self.tdm_size = (convsubv.Cos * convsubv.Yos * convsubv.Xos * kernel.tdm_bytes) // 2

        # Calculate ifm_sum_size
        # Xi_g / Yi_g feed the ifm-sum sizing in the a16w8 and generic branches below.
        Xi_g = ((convsubv.Xos // kernel.X_gran) * layer.Sx) + (layer.Kx > 1)
        Yi_g = ((convsubv.Yos - 1) * layer.Sy) + layer.Ky

        # The bare string below is a no-op expression statement kept from the
        # original; it spells out the nested iceil computation of tmp_buf_size.
        '''
        convsubv_Xis_ceil = iceil(convsubv.Xis, 8)
        convsubv_Yis_ceil = iceil(convsubv_Xis_ceil * convsubv.Yis, 64)
        max_value = max(128, convsubv_Yis_ceil)
        final_ceil = iceil(max_value, convsubv.Yis * 8)
        '''
        self.tmp_buf_size = iceil(
            iceil(max(128, iceil(iceil(convsubv.Xis, 8) * convsubv.Yis, 64)), convsubv.Yis * 8) * kernel.tdm_bytes, 64)

        if kernel.is_a16w8:
            ifm_sum_bytes = kernel.tdm_bytes
            # NOTE:
            # Formulae for subsequent variables are derived from direct_conv_int16x8_generic_template kernel
            ifm_sum_elem_calculated_per_outer_g_iteration = 64
            inner_g_loop_mul_factor = 8
            N_g = 1
            min_outer_g = 2
            outer_g = max(min_outer_g, ceildiv((Xi_g * Yi_g * N_g), inner_g_loop_mul_factor))
            self.ifm_sum_size = ifm_sum_elem_calculated_per_outer_g_iteration * outer_g * ifm_sum_bytes
            # ensure minimal size of ifm_sum_size
            self.ifm_sum_size = max(1024, self.ifm_sum_size)
        elif kernel.is_a8w8:
            # Input extent covered by the output sub-volume (stride and kernel applied).
            Yi = (convsubv.Yos - 1) * layer.Sy + (layer.Ky - 1) + 1
            Xi = (convsubv.Xos - 1) * layer.Sx + (layer.Kx - 1) + 1
            sum_lanes = 64
            sum_block = Xi
            self.ifm_sum_size = max(2 * sum_lanes, iceil(iceil(sum_block, 8) * Yi, sum_lanes)) * kernel.tdm_bytes
        else:
            self.ifm_sum_size = iceil(Xi_g * kernel.X_gran * Yi_g, kernel.mem_align) * kernel.tdm_bytes

        # For CONV_ASYM case assigning 512 bytes of memory space for ifm sum weight unpack
        if kernel.is_a8w8:
            # Same Yi / Xi as in the ifm-sum branch above, recomputed here.
            Yi = (convsubv.Yos - 1) * layer.Sy + (layer.Ky - 1) + 1
            Xi = (convsubv.Xos - 1) * layer.Sx + (layer.Kx - 1) + 1
            sum_lanes = 64
            sum_block = Xi
            self.scratch_buf_size = max(512,
                                        iceil(max(2 * sum_lanes, iceil(iceil(sum_block, 8) * Yi, sum_lanes)), Yi * 8) * 4)
        else:
            self.scratch_buf_size = convsubv.Cos * 8 * layer.Ky * layer.Kx if layer.is_standalone_dwc else 512

        core_bank_size = CORE_BANK_SIZE
        conv_kernel_param_size = kernel.conv_kernel_param_size
        # Upper bound of the usable address range for both layouts.
        stack_addr = overlay.coretile_capacity_bytes

        # Get ping, pong size accordingly from single buffer size and pingpong
        # A disabled ping/pong half contributes a size of 0.
        pp = lambda x, y: (x * int(y.ping), x * int(y.pong))
        ifm_size_ping, ifm_size_pong = pp(self.ifm_size, conv_pp.ifm)
        ofm_size_ping, ofm_size_pong = pp(self.ofm_size, conv_pp.ofm)
        wgt_size_ping, wgt_size_pong = pp(self.wgt_size, conv_pp.wgt)
        tdm_size_ping, tdm_size_pong = pp(self.tdm_size, conv_pp.tdm)
        # --------------------------------------------------------------------------
        # Layout 1, lower half: params | ifm ping | tdm ping | ofm ping | scratch ... wgt ping
        self.conv_kernelprm_addr = 0
        self.ifm_ping_addr = conv_kernel_param_size
        self.wgt_ping_addr = (2 * core_bank_size - wgt_size_ping)

        self.tdm_ping_addr = self.ifm_ping_addr + ifm_size_ping
        self.ofm_ping_addr = self.tdm_ping_addr + tdm_size_ping
        self.scratch_buf = self.ofm_ping_addr + ofm_size_ping

        # The lower half fits if the scratch buffer ends before the ping weights start.
        valid_alloc = ((self.scratch_buf + self.scratch_buf_size) <= self.wgt_ping_addr)

        # Layout 1, upper half: ifm pong | tdm pong | tmp | [add ifm] | ifm sum ... wgt pong
        self.ifm_pong_addr = 2 * core_bank_size
        self.wgt_pong_addr = stack_addr - wgt_size_pong

        self.tdm_pong_addr = self.ifm_pong_addr + ifm_size_pong
        self.tmp_buf = self.tdm_pong_addr + tdm_size_pong
        self.add_ifm_addr = self.tmp_buf + self.tmp_buf_size
        self.ifm_sum_addr = ((self.add_ifm_addr + ofm_size_ping) if enable_add else (self.tmp_buf + self.tmp_buf_size))
        self.ofm_pong_addr = None

        # The upper half fits if the ifm-sum buffer ends before the pong weights start.
        valid_alloc = valid_alloc and ((self.ifm_sum_addr + self.ifm_sum_size) <= self.wgt_pong_addr)

        # Disabled pong buffers are reported as None rather than an address.
        if not conv_pp.ifm.pong:
            self.ifm_pong_addr = None

        if not conv_pp.wgt.pong:
            self.wgt_pong_addr = None

        # --------------------------------------------------------------------------
        # Layout 2 (fallback): place every buffer sequentially from address 0.
        if not valid_alloc:
            self.conv_kernelprm_addr = 0
            self.ifm_ping_addr = conv_kernel_param_size
            self.tdm_ping_addr = self.ifm_ping_addr + ifm_size_ping
            self.wgt_ping_addr = self.tdm_ping_addr + tdm_size_ping
            self.ofm_ping_addr = self.wgt_ping_addr + wgt_size_ping
            self.add_ifm_addr = self.ofm_ping_addr + ofm_size_ping
            self.ifm_sum_addr = (self.add_ifm_addr if enable_add else self.ofm_ping_addr) + ofm_size_ping
            # For xint8 kernels the ifm pong buffer starts at ifm_sum_addr itself,
            # i.e. no separate space is reserved for the ifm-sum buffer.
            self.ifm_pong_addr = self.ifm_sum_addr if kernel.is_xint8 else (self.ifm_sum_addr + self.ifm_sum_size)
            self.tdm_pong_addr = self.ifm_pong_addr + ifm_size_pong
            self.tmp_buf = self.tdm_pong_addr + tdm_size_pong
            self.wgt_pong_addr = self.tmp_buf + self.tmp_buf_size
            self.scratch_buf = self.wgt_pong_addr + wgt_size_pong
            mem_end = self.scratch_buf + self.scratch_buf_size
            self.ofm_pong_addr = None

            if not conv_pp.ifm.pong:
                self.ifm_pong_addr = None

            if not conv_pp.wgt.pong:
                self.wgt_pong_addr = None

            # The fallback fits if the last buffer ends at or below stack_addr.
            valid_alloc = mem_end <= stack_addr
        # -----------------------------------------------
        self.core_mem_constraint = valid_alloc


@dataclass
class L2Buffers:
    def __init__(self, is_X8_split, convsubv, temporalsplits, ofm_pp, wgt_subv_size, layer, kernel, spatial_splits,
                 aie_rows, enable_add=False):
        """Derive memtile buffer shapes, sizes and reuse/streaming decisions.

        Attributes set here:
            Xim, Yim, Cim / Xom, Yom, Com: ifm / ofm extents held in the memtile.
            mt_co_pack: multiplier on the ofm buffer (``Co_loop`` for the X8
                split, otherwise 1).
            prm_size, param_subv_size, conv_kernel_param_size: parameter sizes.
            ifm_size, ofm_size, wgt_size: buffer sizes in bytes.
            pin_wgt_bias_l1, pin_ifm_l1: pinning flags derived from loop counts.
            enable_ifm_streaming, enable_wgt_reuse: mode flags.
            num_ifm_subv, num_pack_wgt_subv: sub-volume grouping counts.

        Args:
            is_X8_split: selects the X8-split variants of the shape and
                reuse logic below.
            convsubv: sub-volume extents.
            temporalsplits: loop counts (``Y_loop``, ``Co_loop``, ``Ci_loop``).
            ofm_pp: converted with ``int()`` to the ofm buffering factor.
            wgt_subv_size: size of one weight sub-volume in bytes.
            layer, kernel, spatial_splits: layer / kernel / split descriptors.
            aie_rows: row count used when sizing the parameter buffer.
            enable_add: when truthy, ifm streaming is disabled.
        """

        self.Xim = convsubv.Xis if is_X8_split else layer.Xi
        self.Yim = convsubv.Yis
        self.Cim = layer.Ci

        self.Xom = convsubv.Xos * (1 if is_X8_split else spatial_splits.X_split)
        self.Yom = convsubv.Yos
        self.mt_co_pack = (temporalsplits.Co_loop if is_X8_split else 1)
        self.Com = (convsubv.Cos * spatial_splits.Co_split)

        self.param_subv_size = config.MAX_CORE_LAYER_PARAM_SIZE
        self.prm_size = compute_buffer_size(f'Row:{aie_rows} Param:{self.param_subv_size}')
        self.conv_kernel_param_size = kernel.conv_kernel_param_size

        # ifm size when the full layer.Ci channel extent is held (no streaming).
        ifm_shard_size = (self.Yim * self.Xim * self.Cim * kernel.ifm_bytes)
        self.ofm_size = self.mt_co_pack * (self.Yom * self.Xom * self.Com * kernel.ofm_bytes)

        # NOTE: For Co_loop == 1 and Ci_loop == 1; wgt is pinned in memtile and coretile
        self.pin_wgt_bias_l1 = (temporalsplits.Co_loop == 1) and (temporalsplits.Ci_loop == 1) and (
            not layer.is_standalone_dwc)

        # NOTE: For Y_loop == 1 and Ci_loop == 1; ifm is pinned in memtile and coretile
        self.pin_ifm_l1 = (temporalsplits.Y_loop == 1) and (temporalsplits.Ci_loop == 1) and (not is_X8_split) and (
            not layer.is_standalone_dwc)

        # Footprint with the whole ifm shard resident and two weight sub-volumes.
        total_ifm_reuse_memtile_size = (
                self.prm_size +
                ifm_shard_size +
                (wgt_subv_size * 2) +
                self.ofm_size +
                self.conv_kernel_param_size
        )

        max_memtile_size = MEMTILE_SIZE
        # Stream the ifm when the resident footprint is too large (and Ci can be
        # split), when Co_loop == 1 without the X8 split, or for standalone
        # depthwise conv -- but never when enable_add is set.
        self.enable_ifm_streaming = (
                                            ((total_ifm_reuse_memtile_size > max_memtile_size) and (
                                                        temporalsplits.Ci_loop > 1)) or  # NOTE: when Ci_loop == 1, won't achieve any split along Ci
                                            ((temporalsplits.Co_loop == 1) and (not is_X8_split)) or
                                            layer.is_standalone_dwc
                                    ) and not enable_add

        self.num_ifm_subv = temporalsplits.Ci_loop
        self.num_pack_wgt_subv = 1

        memtile_ofm_buffering = int(ofm_pp)
        if self.enable_ifm_streaming:
            wgt_shard_size = temporalsplits.Co_loop * temporalsplits.Ci_loop * wgt_subv_size
            # Per-buffer ifm budget: what is left after params, the full weight
            # shard and ofm buffering, halved.
            ifm_size_cutoff = ((max_memtile_size - self.prm_size - wgt_shard_size - (
                        self.ofm_size * memtile_ofm_buffering) - self.conv_kernel_param_size) // 2)
            self.num_ifm_subv = self._split_memtile_ifm_subv(convsubv.Yis, convsubv.Cos, convsubv.Cis,
                                                             spatial_splits.Co_split, temporalsplits.Co_loop,
                                                             temporalsplits.Ci_loop, ifm_size_cutoff, layer, kernel,
                                                             spatial_splits)
            if layer.is_standalone_dwc:
                assert convsubv.Cis == convsubv.Cos
                self.Cim = convsubv.Cos * spatial_splits.Co_split * self.num_ifm_subv
            else:
                self.Cim = min(convsubv.Cis * self.num_ifm_subv, layer.Ci)
            self.ifm_size = (convsubv.Yis * self.Xim * self.Cim * kernel.ifm_bytes)

            # Footprint with the full weight shard resident and two ifm buffers.
            total_wgt_reuse_memtile_size = (
                    self.prm_size +
                    (self.ifm_size * 2) +
                    wgt_shard_size +
                    (self.ofm_size * memtile_ofm_buffering) +
                    self.conv_kernel_param_size
            )
            # Weight reuse needs more than one Y iteration, and either the
            # footprint fits or the weights are pinned.
            self.enable_wgt_reuse = (
                    ((total_wgt_reuse_memtile_size <= max_memtile_size) or
                     self.pin_wgt_bias_l1) and
                    (temporalsplits.Y_loop > 1)
            )
            if self.enable_wgt_reuse:
                self.wgt_size = wgt_shard_size
            else:
                if is_X8_split:
                    wgt_size_cutoff = ((max_memtile_size - self.prm_size - self.ifm_size * 2 - (
                                self.ofm_size * memtile_ofm_buffering) - self.conv_kernel_param_size) // 2)
                    self.num_pack_wgt_subv = self._pack_memtile_wgt_subv(temporalsplits.Ci_loop, temporalsplits.Co_loop,
                                                                         temporalsplits.Y_loop, wgt_subv_size,
                                                                         wgt_size_cutoff, layer)
                    self.wgt_size = wgt_subv_size * self.num_pack_wgt_subv
                else:
                    self.wgt_size = wgt_subv_size
        else:
            # No ifm streaming: the whole ifm shard stays resident.
            self.Cim = layer.Ci
            self.ifm_size = ifm_shard_size

            if is_X8_split:
                wgt_shard_size = temporalsplits.Co_loop * temporalsplits.Ci_loop * wgt_subv_size

                total_wgt_reuse_memtile_size = (
                        self.prm_size +
                        self.ifm_size +
                        wgt_shard_size +
                        (self.ofm_size * memtile_ofm_buffering) +
                        self.conv_kernel_param_size
                )
                # Unlike the streaming branch, reuse here is also capped by
                # config.MAX_LOCK_VALUE.
                self.enable_wgt_reuse = (
                        ((total_wgt_reuse_memtile_size <= max_memtile_size) or
                         self.pin_wgt_bias_l1) and
                        (temporalsplits.Y_loop > 1) and (temporalsplits.Y_loop <= config.MAX_LOCK_VALUE)
                # Y_loop will be reuse ratio of wgt,
                )
                if self.enable_wgt_reuse:
                    self.wgt_size = wgt_shard_size
                    # NOTE(review): this assigns an unused local, not
                    # self.num_pack_wgt_subv; the attribute already holds 1
                    # from above, so the result is unchanged.
                    num_pack_wgt_subv = 1
                else:
                    wgt_size_cutoff = ((max_memtile_size - self.prm_size - self.ifm_size - (
                                self.ofm_size * memtile_ofm_buffering) - self.conv_kernel_param_size) // 2)
                    self.num_pack_wgt_subv = self._pack_memtile_wgt_subv(temporalsplits.Ci_loop, temporalsplits.Co_loop,
                                                                         temporalsplits.Y_loop, wgt_subv_size,
                                                                         wgt_size_cutoff, layer, is_pingpong=True)
                    self.wgt_size = wgt_subv_size * self.num_pack_wgt_subv

            else:
                self.enable_wgt_reuse = False
                self.wgt_size = wgt_subv_size

    def _split_memtile_ifm_subv(
            self,
            Yis, Cos, Cis, Co_split, Co_loop, Ci_loop,
            size_cutoff: int,
            layer,
            kernel,
            spatial_splits,
    ) -> int:
        """Pick how many ifm sub-volumes to group into one memtile buffer.

        Returns the largest divisor of the relevant loop count (``Co_loop`` for
        standalone depthwise conv, otherwise ``Ci_loop``) whose combined size
        does not exceed ``size_cutoff``; falls back to 1 when none qualifies.
        ``Co_split`` is accepted for signature compatibility but the split is
        read from ``spatial_splits``.
        """
        if layer.is_standalone_dwc:
            total_iters = Co_loop
            bytes_per_subv = Yis * layer.Xi * Cos * spatial_splits.Co_split * kernel.ifm_bytes
        else:
            total_iters = Ci_loop
            bytes_per_subv = Yis * layer.Xi * Cis * kernel.ifm_bytes

        # Scan from the largest group size downwards; the first hit wins.
        fitting_divisors = (
            group for group in range(total_iters, 0, -1)
            if total_iters % group == 0 and bytes_per_subv * group <= size_cutoff
        )
        return next(fitting_divisors, 1)

    def _pack_memtile_wgt_subv(
            self,
            Ci_loop, Co_loop, Y_loop, wgt_subv_size,
            size_cutoff: int,
            layer,
            is_pingpong: bool = True,
    ) -> int:
        """Pick how many weight sub-volumes to pack into one memtile buffer.

        Packing is only attempted when the total transfer count exceeds 256;
        otherwise 1 is returned. When packing, the result is the largest
        divisor of the loop count (``Co_loop`` for standalone depthwise conv,
        otherwise ``Ci_loop``) whose buffer footprint fits in ``size_cutoff``.

        Args:
            Ci_loop, Co_loop, Y_loop: temporal loop counts.
            wgt_subv_size: size in bytes of a single weight sub-volume.
            size_cutoff: maximum footprint allowed for the packed buffer(s).
            layer: object exposing ``is_standalone_dwc``.
            is_pingpong: when True the footprint counts two buffers, else one.

        Returns:
            Number of weight sub-volumes per packed buffer (at least 1).
        """
        if layer.is_standalone_dwc:
            loop_count = Co_loop
            total_count = Y_loop * Ci_loop
        else:
            loop_count = Ci_loop
            total_count = Y_loop * Co_loop * Ci_loop
        if total_count <= 256:
            return 1

        # Fix: the original `size * n * 2 if is_pingpong else 1` parsed as
        # `(size * n * 2) if is_pingpong else 1`, so the single-buffer case was
        # treated as a 1-byte footprint. Apply the factor to the size instead.
        buffer_count = 2 if is_pingpong else 1
        for num_wgt_subv in range(loop_count, 0, -1):
            if loop_count % num_wgt_subv:
                continue  # only even partitions of the loop are allowed
            if wgt_subv_size * num_wgt_subv * buffer_count <= size_cutoff:
                return num_wgt_subv
        return 1


class ValidTiling:
    """Plain record bundling everything computed for one accepted tiling."""

    def __init__(self, convsubv: ConvSubVol, temporalsplits: TemporalSplits, conv_pp: ConvPingPong,
                 spatialsplits: SpatialSplits, l1buffers: L1Buffers, memtile_params: L2Buffers, loop_constraint: bool):
        # Sub-volume extents and the temporal / spatial split factors.
        self.convsubv, self.temporalsplits, self.spatialsplits = convsubv, temporalsplits, spatialsplits
        # Ping-pong configuration plus the core-level and memtile-level buffer plans.
        self.conv_pp, self.l1buffers, self.memtile_params = conv_pp, l1buffers, memtile_params
        # Result of the loop-count check computed by the caller.
        self.loop_constraint = loop_constraint


class ConvTiler:

    def __init__(self, layer, device, overlay, kernel):
        """Store the layer/device/overlay/kernel descriptors and reset results.

        Spatial split factors are taken from ``overlay.core_splits``; the list
        of valid tilings starts empty.
        """
        self.overlay, self.device = overlay, device
        self.layer, self.kernel = layer, kernel
        self.spatial_splits = SpatialSplits(overlay)

        # Accumulators filled by calculate_array_tilings.
        self.valid_core_subvols = []
        self.valid_count = 0

        # Kernels without an `enable_add` attribute are treated as add-disabled.
        self.enable_add = getattr(kernel, 'enable_add', False)

        self.aie_cols = overlay.cols
        self.aie_rows = overlay.rows

        # Re-apply the overlay splits onto the freshly built SpatialSplits.
        splits = self.spatial_splits
        splits.Y_split, splits.X_split, splits.Co_split = overlay.core_splits['ofm']
        splits.Ci_split = overlay.core_splits['ifm'][2]

        # Core instruction (layer parameter) size.
        self.param_subv_size = config.MAX_CORE_LAYER_PARAM_SIZE
        self.vars_dict = {}

    def _check_l2fusion_constraint(self, l1_buffer_wgt_size, convsubv, is_X8_split):
        """Return True when ifm, ofm, weights and params fit in the memtile.

        Activation sizes depend on residency: for an "L3" residency the size is
        derived from the sub-volume and spatial splits, otherwise from the full
        layer extent spread over the memtile subregions. The combination of
        any "L2" residency with the X8 split is rejected outright (TODO).
        """
        layer, kernel, overlay = self.layer, self.kernel, self.overlay
        splits = self.spatial_splits

        touches_l2 = ("L2" in layer.in_act_residency) or ("L2" in layer.out_act_residency)
        if touches_l2 and is_X8_split:  # TODO
            return False

        # Output / input rows handled per memtile subregion.
        Yoc = ceildiv(layer.Yo, overlay.num_memtile_subregions)
        Yic = conv_input(Yoc, layer.Ky, layer.Sy)

        subregions = overlay.num_memtile_subregions
        capacity = overlay.memtile_capacity_bytes

        if layer.in_act_residency == "L3":
            ifm_bytes_needed = (kernel.ifm_bytes * (convsubv.Xis * splits.X_split)
                                * ceildiv(convsubv.Yis * splits.Y_split, subregions)
                                * convsubv.Cis)
        else:
            ifm_bytes_needed = kernel.ifm_bytes * layer.Xi * Yic * layer.Ci

        if layer.out_act_residency == "L3":
            ofm_bytes_needed = (kernel.ofm_bytes * (convsubv.Xos * splits.X_split)
                                * ceildiv(convsubv.Yos * splits.Y_split, subregions)
                                * (convsubv.Cos * splits.Co_split))
        else:
            ofm_bytes_needed = kernel.ofm_bytes * layer.Xo * Yoc * layer.Co

        wgt_bytes_needed = l1_buffer_wgt_size * splits.Ci_split * splits.Co_split

        # Kernel params and layer params are counted once per subregion.
        total_needed = (ifm_bytes_needed + ofm_bytes_needed + wgt_bytes_needed
                        + kernel.conv_kernel_param_size * subregions
                        + self.param_subv_size * subregions)
        return total_needed <= capacity

    def calculate_memtile_tilings(self):
        """Placeholder: currently performs no work and returns None."""
        return None

    def check_valid_memtile_tilings(self):
        """Placeholder: currently performs no work and returns None."""
        return None

    def _check_other_constraints(self, convsubv: ConvSubVol) -> bool:
        """Evaluate the kernel's optional ``other_constraints`` formulas.

        Returns True when the kernel defines no such attribute, otherwise True
        only if every formula evaluates truthy. The variables ``H``, ``Kw``,
        ``Kh``, ``Sw`` and ``Sh`` are published into ``self.vars_dict`` before
        evaluation.
        """
        if not hasattr(self.kernel, 'other_constraints'):
            return True

        layer = self.layer
        self.vars_dict.update(
            H=convsubv.Yos,
            Kw=layer.Kx,
            Kh=layer.Ky,
            Sw=layer.Sx,
            Sh=layer.Sy,
        )

        # NOTE(review): formulas are run through eval(); they must come from a
        # trusted source (presumably kernel metadata -- confirm).
        # Every formula is evaluated, with no short-circuit, before combining.
        outcomes = [eval(formula, self.vars_dict) for formula in self.kernel.other_constraints.values()]
        return all(outcomes)

    def _check_kernel_constraints(self, convsubv: ConvSubVol) -> bool:
        """Check the kernel-imposed limits on an output sub-volume.

        Three conditions must all hold: enough outer-loop iterations, each
        output extent at least one granule, and ``Cos`` a multiple of the
        required channel block.
        """
        kernel, layer = self.kernel, self.layer

        # Outer loop range constraint: granule counts along Y, X and Co.
        outer_iters = ((convsubv.Yos // kernel.Y_gran)
                       * (convsubv.Xos // kernel.X_gran)
                       * (convsubv.Cos // kernel.Co_gran))
        meets_outer_loop_min = outer_iters >= kernel.outer_loop_min

        # Channel block multiplier: Sx for a16w8, else 1 (standalone dwc) or 2.
        if kernel.is_a16w8:
            co_block_factor = layer.Sx
        elif layer.is_standalone_dwc:
            co_block_factor = 1
        else:
            co_block_factor = 2
        cos_is_block_multiple = (convsubv.Cos % (kernel.Co_gran * co_block_factor)) == 0

        meets_granularity = (
            convsubv.Xos >= kernel.X_gran
            and convsubv.Yos >= kernel.Y_gran
            and convsubv.Cos >= kernel.Co_gran
        )

        # TODO: the Xis*Cis product constraint (a16w8 only) is disabled via
        # Conv_kernel_metadata.yaml and therefore not checked here.
        return meets_outer_loop_min and meets_granularity and cos_is_block_multiple

    def _check_scheduler_constraints(self,
                                     temporalsplits,
                                     convsubv: ConvSubVol,
                                     wgt_subv_size, is_X8_split: bool = False,
                                     enable_ifm_streaming: bool = False) -> bool:
        """Check the scheduler-side limits for one candidate tiling.

        Combines the repeat-count budget on ``Co_loop * Ci_loop``, the shim
        step limits (non-X8 path only), the X8-split specific limits, and the
        set of ``X_loop`` values allowed for the chosen path.
        """
        co_iters, ci_iters = temporalsplits.Co_loop, temporalsplits.Ci_loop

        # Scheduler max loop constraint; the factor 2 accounts for double buffering.
        within_repeat_budget = co_iters * ci_iters <= (
            config.MAX_REPEAT_COUNT * config.MAX_TASK_QUEUE_SIZE * 2)

        # TODO: the DMA-compiler lock-value constraint (Co_loop * aie_rows vs
        # MAX_LOCK_VALUE) is disabled for now and should be rechecked.
        if is_X8_split:
            y_iters_ok = ceildiv(self.layer.Yo, convsubv.Yos) <= 64 * 8
            co_pack_ok = co_iters <= 4  # the max queue limit
            if enable_ifm_streaming or not (y_iters_ok and co_pack_ok):
                split_ok = False
            else:
                split_ok = not self._X8_exception([convsubv.Xos])[0]
            allowed_x_loops = (1, 2, 4, 8)
            shim_ok = True
        else:
            # TODO: Check if can be moved to shim related constraints later.
            broadcast_cols = self.overlay.broadcast_cols
            wgt_co_loop = co_iters * self.spatial_splits.Co_split
            wgt_co_words = (ci_iters * wgt_subv_size) // 4  # words
            wgt_step = (wgt_co_words * broadcast_cols) if wgt_co_loop > broadcast_cols else 0
            ofm_step = convsubv.Yos * self.layer.Xo * self.layer.Co * self.kernel.ofm_bytes * self.aie_cols // 4
            shim_ok = (wgt_step <= config.MAX_SHIM_STEP) and (ofm_step < config.MAX_SHIM_STEP)
            split_ok = True
            allowed_x_loops = (1,)

        x_loop_ok = temporalsplits.X_loop in allowed_x_loops
        return within_repeat_budget and shim_ok and split_ok and x_loop_ok

    def _X8_exception(self, Xos_grid) -> list[bool]:
        """
        Flag, per candidate Xos, whether the X8 split would exceed the memtile
        Dim0 padding limit.

        Background:
        1. MEM tile DMA limitations:
            1) MEM tile MM2S has 4 dimensions.
            2) Max padding value @Dim0 = 64 (32-bit words); @Dim1 = 32, @Dim2 = 16, no padding @Dim3.
            3) Max wrap value for each dim is 1023.
        2. IFM data at MEM tile:
            1) Memory format: YCXC8.
            2) Transfer from MEM to core: CYCisXisCi8 needs 5 dimensions, but only 4 are available.
            3) The only way to support this is to fuse XisCi8. With XisCi8 = Cib it becomes
               CYCisCib, a 4-dimension transfer.
        3. Limitations of the XisCi8 fusion:
            1) Dim0 wrap must satisfy ((Xis * C_gran * (bytes/elem)) // 4) <= 1023.
            2) Dim0 padding must stay below the maximum pad value:
                ((Xis - Xi_remainder) * C_gran * (bytes/elem) // 4) <= 64.

        Returns a list parallel to ``Xos_grid``; an entry is True when the
        computed padding exceeds ``config.MAX_MEMTILE_D0_PAD``.
        """
        exceeds_pad = [False] * len(Xos_grid)
        if Xos_grid:
            kernel, layer = self.kernel, self.layer
            ci_block_bytes = (kernel.Ci_gran * kernel.ifm_bytes)
            for pos, Xos in enumerate(Xos_grid):
                # Input width for this Xos, rounded up to the memory alignment.
                Xis = iceil(conv_input(Xos, layer.Kx, layer.Sx) * ci_block_bytes, kernel.mem_align) // ci_block_bytes
                overlap = (Xis - Xos) // 2
                # Width of the last, possibly partial, tile along X.
                tail_width = layer.Xi % Xos or Xos
                pad_words = (Xis - (tail_width + overlap)) * kernel.Ci_gran * kernel.ifm_bytes // 4
                if pad_words > config.MAX_MEMTILE_D0_PAD:
                    exceeds_pad[pos] = True
        return exceeds_pad

    def calculate_array_tilings(self, l2fusion_pass: bool = False):
        """Enumerate candidate sub-volume tilings and collect the valid ones.

        Builds grids of candidate ``Xos``, ``Yos``, ``Cos`` (and ``Cis`` unless
        the layer is a standalone depthwise conv), then for every combination
        derives the input extents, temporal splits, ping-pong configuration and
        core/memtile buffer plans, and checks the constraints. Valid
        combinations are appended to ``self.valid_core_subvols`` as
        ``ValidTiling`` objects and counted in ``self.valid_count``.

        Side effects: overwrites ``self.layer.Xo`` and, on every iteration,
        ``self.pin_wgt_bias_l1`` / ``self.pin_ifm_l1``.

        Args:
            l2fusion_pass: when True, the L2-fusion constraint is evaluated and
                the search stops after the first valid tiling.
        """

        Ci_block = (self.kernel.Ci_gran * self.kernel.ifm_bytes)
        # Smallest Cis that still satisfies the kernel's inner-loop minimum.
        Cis_min = ceildiv(self.kernel.inner_loop_min, self.layer.Ky * self.layer.Kx) * self.kernel.Ci_gran
        if self.kernel.is_a16w8:
            Cos_min = self.kernel.Co_gran * 2 if self.layer.Sx >= 2 else self.kernel.Co_gran
        else:
            Cos_min = self.kernel.Co_gran * (1 if self.layer.is_standalone_dwc else 2)

        # The X8 path is taken only for an X split of exactly 8 without add.
        if self.spatial_splits.X_split == 8 and not self.enable_add:
            self.layer.Xo = (iceil(iceil(self.layer.Xo_orig, self.spatial_splits.X_split), self.kernel.X_gran))
            Xos_min = self.kernel.X_gran
            is_X8_split = True
        else:
            self.layer.Xo = self.layer.Xo_orig
            Xos_min = iceil(ceildiv(self.layer.Xo, self.spatial_splits.X_split), self.kernel.X_gran)
            is_X8_split = False

        # Create grids with appropriate constraints
        # NOTE: require_divisible derived from disable_scheduler_constraints. Since currently scheduler does not support DMA padding along Co and Ci.
        Cos_grid = filter_grid(
            list(range(Cos_min, max(Cos_min, iceil(self.layer.Co, self.kernel.Co_gran)) + 1, self.kernel.Co_gran)),
            self.layer.Co, self.spatial_splits.Co_split, not self.layer.disable_scheduler_constraints)
        Yos_grid = filter_grid(list(range(1, ceildiv(self.layer.Yo, self.spatial_splits.Y_split) + 1, 1)),
                               self.layer.Yo, self.spatial_splits.Y_split)
        Xos_grid = filter_grid(
            list(range(Xos_min, max(Xos_min, iceil(self.layer.Xo, self.kernel.X_gran)) + 1, self.kernel.X_gran)),
            self.layer.Xo, self.spatial_splits.X_split)
        Cis_grid = filter_grid(
            list(range(Cis_min, max(Cis_min, iceil(self.layer.Ci, self.kernel.Ci_gran)) + 1, self.kernel.Ci_gran)),
            self.layer.Ci, 1, not self.layer.disable_scheduler_constraints)

        # Standalone depthwise conv has no independent Cis axis (Cis == Cos below).
        grid_list = (Xos_grid, Yos_grid, Cos_grid) + ((Cis_grid,) if not self.layer.is_standalone_dwc else ())
        grid_product = product(*grid_list)
        for grid in grid_product:
            Xos, Yos, Cos = grid[0:3]
            Cis = Cos if self.layer.is_standalone_dwc else grid[3]

            # Calculate corresponding input dimensions
            Yis = conv_input(Yos, self.layer.Ky, self.layer.Sy)
            Xis = iceil(conv_input(Xos, self.layer.Kx, self.layer.Sx) * Ci_block, self.kernel.mem_align) // Ci_block

            # Create a dict with input and output dimension
            convsubvol = ConvSubVol(Xis, Yis, Cis, Xos, Yos, Cos)

            # Calculate temporal splits
            temporalsplits = TemporalSplits(self.layer, convsubvol, self.spatial_splits)

            # NOTE: For Co_loop == 1 and Ci_loop == 1; wgt is pinned in memtile and coretile
            self.pin_wgt_bias_l1 = (temporalsplits.Co_loop == 1) and (temporalsplits.Ci_loop == 1) and (
                not self.layer.is_standalone_dwc)

            # NOTE: For Y_loop == 1 and Ci_loop == 1; ifm is pinned in memtile and coretile
            self.pin_ifm_l1 = (temporalsplits.Y_loop == 1) and (temporalsplits.Ci_loop == 1) and (not is_X8_split) and (
                not self.layer.is_standalone_dwc)

            # Determine ping-pong configuration
            if self.kernel.is_xint8 and (temporalsplits.Ci_loop == 1):
                tdm_pp_bool = False
            else:
                tdm_pp_bool = True

            # ifm pong is dropped when the ifm is pinned or (add enabled and Ci_loop == 1);
            # wgt pong is dropped when the weights are pinned; ofm is never double-buffered.
            conv_pp = ConvPingPong(
                ifm=PingPong(True,
                             False if (self.enable_add and temporalsplits.Ci_loop == 1) or self.pin_ifm_l1 else True),
                ofm=PingPong(True, False),
                wgt=PingPong(True, False if self.pin_wgt_bias_l1 else True),
                tdm=PingPong(True if self.enable_add else tdm_pp_bool, tdm_pp_bool),
            )

            # Calculate total memory requirement
            ofm_pp = conv_pp.ofm

            # Calculate memory sizes for this tiling
            l1_buffers = L1Buffers(convsubvol, conv_pp, self.kernel, self.layer, self.overlay, self.enable_add)

            memtile_params = L2Buffers(is_X8_split, convsubvol, temporalsplits, ofm_pp, l1_buffers.wgt_size, self.layer,
                                       self.kernel, self.spatial_splits, self.aie_rows, self.enable_add)

            kernel_constraints = self._check_kernel_constraints(convsubvol)
            scheduler_constraints = self._check_scheduler_constraints(
                temporalsplits,
                convsubvol,
                l1_buffers.wgt_size, is_X8_split, memtile_params.enable_ifm_streaming)
            scheduler_constraints = True if self.layer.disable_scheduler_constraints else scheduler_constraints

            other_constraints = self._check_other_constraints(convsubvol)

            # L2fusion validity constraint if l2fusion is enabled, and
            # for 'L3' residency, check min_memtile_size constraint (necessary but not sufficient)
            # NOTE(review): l2fusion_constraint is computed but not part of
            # is_valid below (it is commented out there).
            l2fusion_constraint = not l2fusion_pass or self._check_l2fusion_constraint(
                l1_buffers.wgt_size, convsubvol, is_X8_split)

            # Determine if this is a valid configuration
            # Standalone depthwise conv only has to satisfy the core memory constraint.
            if self.layer.is_standalone_dwc:
                is_valid = l1_buffers.core_mem_constraint
            else:
                is_valid = l1_buffers.core_mem_constraint and kernel_constraints and scheduler_constraints and other_constraints  # and l2fusion_constraint

            # If this is a valid configuration, store it
            if is_valid:
                # Increment valid count
                self.valid_count += 1
                # NOTE(review): this yields 1 buffer when an ifm pong address
                # exists and 2 when it does not, which looks inverted relative
                # to double buffering -- confirm the intended mapping.
                num_buffers = 1 if l1_buffers.ifm_pong_addr is not None else 2
                acc_loop = temporalsplits.Ci_loop
                inner_loop = temporalsplits.Co_loop
                outer_loop = temporalsplits.Y_loop * temporalsplits.X_loop
                loop_constraint = (acc_loop * inner_loop * outer_loop <= (
                            config.MAX_REPEAT_COUNT * config.MAX_TASK_QUEUE_SIZE * num_buffers));
                tiling_entities = ValidTiling(convsubvol, temporalsplits, conv_pp, self.spatial_splits, l1_buffers,
                                              memtile_params, loop_constraint=loop_constraint)

                self.valid_core_subvols.append(tiling_entities)
                # With scheduler constraints disabled, the first valid tiling is taken.
                if self.layer.disable_scheduler_constraints:
                    return

                # Early exit if L2 fusion is enabled
                if l2fusion_pass:
                    return

    def check_core_constraints(self):
        """Placeholder: currently performs no work and returns None."""
        return None