'''
This module implements the optimizer that lowers the data transfer IR
down to the physical buffer descriptors. External facing functions
are documented below.

compile_data_transfers - lowers a list of memtile and shim data transfers
to the physical buffer descriptor level
'''

# When True, task_queue_optimization() prints its inputs on entry via
# print_task_queue_optimization_inputs().
DEBUG_MODE = False

from typing import List, Dict, Tuple, Optional, Union, Sequence, Any
from copy import copy, deepcopy
from collections import OrderedDict, defaultdict

from .types import (
    DevGen,
    OverlayShape, TileType, AieTile, DmaDir, DmaChannel, AieDma,
    TransferParams, DataTransfer, SyncStrategy,
    Lock, BufferDescriptor, BufferTask, DataBuffer,
)
from . import config
from .print_run_layer_compilation import print_task_queue_optimization_inputs

class DmaAllocator:
    '''
        The DmaAllocator class tracks the availability of resources: BDs, Locks,
        Remote Barriers, and Task Queue Depth.

        For BDs, every pool stored in `bd_counter` is a list of zeros whose length depends on
        the number of BDs available in that pool, which is determined by DEV_GEN. When a BD ID
        is used, its corresponding index in the list is set to 1. The first zero inside the
        searched index range indicates the next available BD ID.

        Locks, Remote Barriers, and Task Queue Depth are tracked using counters. A counter is
        only incremented after the request has been validated, so a failed allocation leaves
        the allocator state untouched.
    '''
    __slots__ = ('bd_counter', 'lock_counters', 'task_counters', 'barrier_count', 'shape')

    def __init__(self, shape: OverlayShape):
        '''
            Initialize the BD pools, lock counters and task counters for every
            column of `shape`.

            Raises AssertionError if config.DEV_GEN is neither Aie2p nor Aie4.
        '''
        self.bd_counter: Dict[str, List[int]] = {}
        self.lock_counters: Dict[AieTile, int] = {}
        self.task_counters: Dict[AieDma, int] = {}
        self.barrier_count: int = 1
        self.shape = shape

        if config.DEV_GEN == DevGen.Aie2p:
            self._aie2p_init_bd_counters()
        elif config.DEV_GEN == DevGen.Aie4:
            self._aie4_init_bd_counters()
        else:
            # Raised explicitly (instead of `assert False`) so the check
            # survives running Python with -O.
            raise AssertionError(f'Unsupported DEV_GEN: {config.DEV_GEN}')

        for col in range(shape.start_col, shape.start_col + shape.num_cols):
            # NOTE: Lock ID 0 is reserved as a "NOP" lock to be used for BDs that
            #       only perform an acquire op or only perform a release op.
            self.lock_counters[AieTile(TileType.Memtile, col, 0)] = 1
            self.lock_counters[AieTile(TileType.Shim, col, 0)] = 1
            for ch_dir in (DmaDir.S2MM, DmaDir.MM2S):
                max_memtile_dma_channel = (
                    config.MAX_MEMTILE_S2MM_DMA_CHANNEL if ch_dir == DmaDir.S2MM else
                    config.MAX_MEMTILE_MM2S_DMA_CHANNEL
                )
                max_shim_dma_channel = (
                    config.MAX_SHIM_S2MM_DMA_CHANNEL if ch_dir == DmaDir.S2MM else
                    config.MAX_SHIM_MM2S_DMA_CHANNEL
                )
                for ch_id in range(max_memtile_dma_channel + 1):
                    self.task_counters[AieDma(AieTile(TileType.Memtile, col, 0), DmaChannel(ch_dir, ch_id))] = 0
                for ch_id in range(max_shim_dma_channel + 1):
                    self.task_counters[AieDma(AieTile(TileType.Shim, col, 0), DmaChannel(ch_dir, ch_id))] = 0

    def _aie2p_init_bd_counters(self):
        '''Create one BD pool per memtile and one per shim tile, keyed by str(tile).'''
        # NOTE: Shim BD ID 0 is reserved for control packet transfer, so it
        #       starts out marked as used.
        for col in range(self.shape.start_col, self.shape.start_col + self.shape.num_cols):
            self.bd_counter[str(AieTile(TileType.Memtile, col, 0))] = [0] * (config.MAX_MEMTILE_BD_HI_ID + 1)
            self.bd_counter[str(AieTile(TileType.Shim, col, 0))] = [1] + [0] * config.MAX_SHIM_BD_ID

    def _aie4_init_bd_counters(self):
        '''
            Create one private BD pool per memtile DMA channel, keyed by str(dma),
            and one BD pool per shim tile, keyed by str(tile).
        '''
        memtile_channel_counts = (
            (DmaDir.S2MM, config.MAX_MEMTILE_S2MM_DMA_CHANNEL + 1),
            (DmaDir.MM2S, config.MAX_MEMTILE_MM2S_DMA_CHANNEL + 1),
        )
        for col in range(self.shape.start_col, self.shape.start_col + self.shape.num_cols):
            memtile = AieTile(TileType.Memtile, col, 0)
            for ch_dir, num_channels in memtile_channel_counts:
                for ch_id in range(num_channels):
                    pool_key = str(AieDma(memtile, DmaChannel(ch_dir, ch_id)))
                    self.bd_counter[pool_key] = [0] * (config.MAX_MEMTILE_PRIVATE_BD_ID + 1)
            shim = AieTile(TileType.Shim, col, 0)
            self.bd_counter[str(shim)] = [0] * (config.MAX_SHIM_HI_BD_ID + 1)

    def _index_of_first_zero(self, nums: List[int], start_index: int, stop_index: int) -> int:
        '''
            Return the index of the first zero in `nums` within
            [start_index, stop_index] (inclusive, clipped to the list length).

            Raises RuntimeError if the range contains no zero.
        '''
        for idx in range(start_index, min(stop_index + 1, len(nums))):
            if nums[idx] == 0:
                return idx
        raise RuntimeError('No more BDs available!')

    def _find_last_positive_index(self, values: List[int]) -> Optional[int]:
        '''Return the highest index holding a value > 0, or None if there is none.'''
        for idx in range(len(values) - 1, -1, -1):
            if values[idx] > 0:
                return idx
        return None

    def _extract_aiedma_config(self, dma: AieDma):
        '''Return (column, tile type, channel direction, channel id) of `dma`.'''
        dma_tile = dma.tile
        dma_channel = dma.channel
        return (dma_tile.col, dma_tile.type, dma_channel.dir, dma_channel.id)

    def _claim_first_free_bd(self, pool_key: str, start_index: int, stop_index: int) -> int:
        '''
            Mark the first free slot of pool `pool_key` inside
            [start_index, stop_index] as used and return its index.

            Raises RuntimeError (from _index_of_first_zero) when the range is
            exhausted; the pool is left unmodified in that case.
        '''
        pool = self.bd_counter[pool_key]
        slot = self._index_of_first_zero(pool, start_index, stop_index)
        pool[slot] = 1
        return slot

    def _aie2p_bd_id(self, dma: AieDma) -> int:
        '''
            Allocate a BD ID for `dma` on Aie2p.

            Memtile channels with an even ID draw from the low BD range and
            channels with an odd ID draw from the high BD range. Shim channels
            draw from IDs 1..MAX_SHIM_BD_ID.

            Raises RuntimeError if the range is exhausted or the tile type is
            neither Memtile nor Shim.
        '''
        if dma.tile.type == TileType.Memtile:
            if dma.channel.id % 2 == 0:
                start_index, stop_index = 0, config.MAX_MEMTILE_BD_LO_ID
            else:
                start_index, stop_index = config.MAX_MEMTILE_BD_LO_ID + 1, config.MAX_MEMTILE_BD_HI_ID
        elif dma.tile.type == TileType.Shim:
            # NOTE: Shim BD ID 0 is reserved for control packet transfer.
            start_index, stop_index = 1, config.MAX_SHIM_BD_ID
        else:
            raise RuntimeError(f'Failed to allocate BD ID on {dma}: unsupported tile type!')
        # The search is bounded by stop_index, so the returned ID can never
        # exceed the range; exhaustion is reported by _index_of_first_zero.
        return self._claim_first_free_bd(str(dma.tile), start_index, stop_index)

    def _aie4_bd_id(self, dma: AieDma) -> int:
        '''
            Allocate a BD ID for `dma` on Aie4.

            Each memtile channel owns a private pool; the returned ID is the
            private slot offset by the channel's position (S2MM channels first,
            then MM2S channels) times the private pool size. Shim channels up
            to the configured "LO" channel limit draw from the low BD range and
            the remaining channels draw from the high BD range.

            Raises RuntimeError if the range is exhausted or the tile type is
            neither Memtile nor Shim.
        '''
        if dma.tile.type == TileType.Memtile:
            channel_index = dma.channel.id
            if dma.channel.dir == DmaDir.MM2S:
                channel_index += config.MAX_MEMTILE_S2MM_DMA_CHANNEL + 1
            private_id = self._claim_first_free_bd(str(dma), 0, config.MAX_MEMTILE_PRIVATE_BD_ID)
            return private_id + channel_index * (config.MAX_MEMTILE_PRIVATE_BD_ID + 1)

        if dma.tile.type == TileType.Shim:
            max_lo_shim_channel = (
                config.MAX_SHIM_LO_S2MM_DMA_CHANNEL if dma.channel.dir == DmaDir.S2MM else
                config.MAX_SHIM_LO_MM2S_DMA_CHANNEL
            )
            if dma.channel.id <= max_lo_shim_channel:
                start_index, stop_index = 0, config.MAX_SHIM_LO_BD_ID
            else:
                start_index, stop_index = config.MAX_SHIM_LO_BD_ID + 1, config.MAX_SHIM_HI_BD_ID
            return self._claim_first_free_bd(str(dma.tile), start_index, stop_index)

        raise RuntimeError(f'Failed to allocate BD ID on {dma}: unsupported tile type!')

    def bd_id(self, dma: AieDma) -> int:
        '''
            Allocate and return a BD ID for `dma` using the scheme of the
            configured device generation.

            Raises AssertionError if config.DEV_GEN is neither Aie2p nor Aie4.
        '''
        if config.DEV_GEN == DevGen.Aie2p:
            return self._aie2p_bd_id(dma)
        if config.DEV_GEN == DevGen.Aie4:
            return self._aie4_bd_id(dma)
        # Raised explicitly so the check is not stripped under -O (which would
        # otherwise make this method silently return None).
        raise AssertionError(f'Unsupported DEV_GEN: {config.DEV_GEN}')

    def lock_id(self, tile: AieTile) -> int:
        '''
            Allocate and return the next lock ID on `tile`.

            Raises RuntimeError if the next ID exceeds the maximum lock ID
            configured for the tile type; the counter is not advanced then.
        '''
        next_id = self.lock_counters[tile]
        is_invalid = ((tile.type == TileType.Memtile and next_id > config.MAX_MEMTILE_LOCK_ID) or
                      (tile.type == TileType.Shim and next_id > config.MAX_SHIM_LOCK_ID))
        if is_invalid:
            raise RuntimeError(f'Failed to allocate lock ID on {tile}!')
        self.lock_counters[tile] = next_id + 1
        return next_id

    def barrier_id(self) -> int:
        '''
            Allocate and return the next remote barrier ID.

            Raises RuntimeError on Aie2p, or when the next ID exceeds
            config.MAX_REMOTE_BARRIER.
        '''
        if config.DEV_GEN == DevGen.Aie2p:
            raise RuntimeError('Remote Barrier is only support on CERT based devices!')
        next_id = self.barrier_count
        if next_id > config.MAX_REMOTE_BARRIER:
            raise RuntimeError('Failed to allocate Remote Barrier!')
        self.barrier_count += 1
        return next_id

    def check_next_available_lock_id(self, tile: AieTile) -> int:
        '''Return the lock ID the next lock_id(tile) call would hand out, without allocating it.'''
        return self.lock_counters[tile]

    def enqueue_task(self, dma: AieDma):
        '''
            Account for one more task in the queue of `dma`.

            Raises RuntimeError if the queue already holds
            config.MAX_TASK_QUEUE_SIZE tasks; the counter is not advanced then.
        '''
        if self.task_counters[dma] >= config.MAX_TASK_QUEUE_SIZE:
            raise RuntimeError(f'Failed to enqueue task on {dma}!')
        self.task_counters[dma] += 1

    def clear_tasks(self):
        '''Reset the task queue depth of every tracked DMA channel to zero.'''
        for dma in self.task_counters:
            self.task_counters[dma] = 0


def transfer_param_dmas(transfer_params: List[TransferParams]) -> List[AieDma]:
    """
        Return the distinct DMA channels referenced by `transfer_params`,
        ordered by first appearance.
    """
    seen_dmas = OrderedDict()
    for transfer in transfer_params:
        # Only the keys matter; setdefault keeps the first-seen position.
        seen_dmas.setdefault(transfer.dma, None)
    return list(seen_dmas)


def transfer_param_chains(transfer_params: List[TransferParams]) -> List[List[TransferParams]]:
    """
        Group transfer params into chains that share the same DMA channel.

        Chains are ordered by the first appearance of their DMA channel and
        the params inside a chain keep their original relative order. The
        grouping is done in a single pass instead of rescanning the whole
        list once per distinct channel.
    """
    chains_by_dma = OrderedDict()
    for param in transfer_params:
        chains_by_dma.setdefault(param.dma, []).append(param)
    return list(chains_by_dma.values())


def buffer_descriptor_dmas(buffer_descriptors: List[BufferDescriptor]) -> List[AieDma]:
    """
        Return the distinct DMA channels referenced by `buffer_descriptors`,
        ordered by first appearance.
    """
    seen_dmas = OrderedDict()
    for descriptor in buffer_descriptors:
        # Only the keys matter; setdefault keeps the first-seen position.
        seen_dmas.setdefault(descriptor.aie_dma, None)
    return list(seen_dmas)


def buffer_descriptor_chains(buffer_descriptors: List[BufferDescriptor]) -> List[List[BufferDescriptor]]:
    """
        Group buffer descriptors into chains that share the same DMA channel.

        Chains are ordered by the first appearance of their DMA channel and
        the BDs inside a chain keep their original relative order. The
        grouping is done in a single pass instead of rescanning the whole
        list once per distinct channel.
    """
    chains_by_dma = OrderedDict()
    for bd in buffer_descriptors:
        chains_by_dma.setdefault(bd.aie_dma, []).append(bd)
    return list(chains_by_dma.values())


def conv_param_to_bd(
    alloc: DmaAllocator,
    tile: AieTile,
    buffer_addr: int,
    param: TransferParams,
    remote_barrier_id: int = None,
) -> BufferDescriptor:
    """
        Build the BufferDescriptor for a single transfer param, allocating
        its BD ID from `alloc`.

        For memtile DMAs the buffer address is rebased by
        config.MEMTILE_BASE_ADDR plus one memtile address window per column of
        distance between `tile` and the DMA's tile (the distance is shifted by
        config.MAX_NEIGHBOR_ACCESS). For any other DMA the address is replaced
        by `param.shim_buffer_index` when that index is not None.

        Packet mode is enabled only on Aie2p for shim MM2S transfers issued on
        config.SHIM_CTRL_MM2S_CHANNEL_ID; those use config.DATA_TRANSFER_PKT_ID.
    """
    dma = param.dma

    if dma.tile.type == TileType.Memtile:
        column_delta = (tile.col - dma.tile.col) + config.MAX_NEIGHBOR_ACCESS
        memtile_window = config.MAX_MEMTILE_ADDR + 1
        buffer_addr += config.MEMTILE_BASE_ADDR + (column_delta * memtile_window)
    elif param.shim_buffer_index is not None:
        buffer_addr = param.shim_buffer_index

    # Keep the short-circuit chain: later terms are only read when the
    # earlier ones hold.
    uses_packets = (
        (config.DEV_GEN == DevGen.Aie2p) and
        (dma.tile.type == TileType.Shim) and
        (dma.channel.dir == DmaDir.MM2S) and
        (dma.channel.id == config.SHIM_CTRL_MM2S_CHANNEL_ID)
    )
    if uses_packets:
        pkt_id = config.DATA_TRANSFER_PKT_ID
    else:
        pkt_id = None

    allocated_bd_id = alloc.bd_id(dma)
    return BufferDescriptor(
        dma, allocated_bd_id,
        buffer_addr=buffer_addr,
        offset=param._offset,
        length=param._length,
        step=param._step,
        wrap=param._wrap,
        padding=param._padding,
        iter_step=param._iter_step,
        iter_wrap=param._iter_wrap,
        packet_enable=uses_packets,
        packet_id=pkt_id,
        name=param.name,
        barrier_id=remote_barrier_id,
    )


#
# Optimization Pass #1: Data transfer resource mapping
#
#       a. Buffer descriptor allocation
#       b. Lock allocation
#       c. Buffer synchronization
#
# Inputs: Data transfer, DMA allocators
# Output: BD and lock allocations
#
# NOTE: The output List[List[BufferDescriptor]] is a nested list
#       where the inner list represents the allocations for a single buffer
#       and the outer list represents all allocations for a multiple
#       buffering scheme (single, double, triple, etc). We need to organize
#       allocations this way to perform odd repeat count detection during
#       queue depth allocation in optimization pass #3.
#


# Result type of the resource-mapping functions below:
# (BD allocations grouped per buffer, locks allocated for those buffers).
BufferAllocation = Tuple[List[List[BufferDescriptor]], List[Lock]]


def compile_parallel_1_to_N_transfer(
    alloc: DmaAllocator,
    tile: AieTile,
    buffer_addrs: List[int],
    write_params: List[TransferParams],
    read_params: List[TransferParams],
    reuse_ratio: int,
) -> BufferAllocation:
    """
        Allocate BDs and locks for one writer channel feeding N reader
        channels, once per address in `buffer_addrs`.

        Per buffer, one producer lock (initial value N * reuse_ratio) and one
        consumer lock per reader (initial value 0) are allocated. The writer
        chain acquires the producer lock with -N * reuse_ratio and releases
        +reuse_ratio to every consumer lock: the first through its last data
        BD and the remaining ones through extra lock-only BDs. Each reader
        chain acquires its own consumer lock with -1 and releases +1 to the
        producer lock.

        Returns (BDs grouped per buffer, all allocated locks).
    """
    writers = transfer_param_dmas(write_params)
    readers = transfer_param_dmas(read_params)
    num_writers = len(writers)
    num_readers = len(readers)

    assert num_writers == 1
    assert num_readers >= 1

    # The per-channel grouping does not depend on the buffer address, so it
    # is computed once rather than once per buffer.
    write_param_chains = transfer_param_chains(write_params)
    read_param_chains = transfer_param_chains(read_params)
    writer_dma = writers[0]

    alloc_bds = []
    alloc_locks = []

    for buffer_addr in buffer_addrs:
        # Allocate one producer lock plus one consumer lock per reader channel
        prod_lock = Lock(tile, alloc.lock_id(tile), +num_readers * reuse_ratio)
        cons_locks = [Lock(tile, alloc.lock_id(tile), +0)
                      for _ in range(num_readers)]

        bds = []

        # Allocate BDs for the writer transfer chain
        write_bds = [conv_param_to_bd(alloc, tile, buffer_addr, param)
                     for param in write_param_chains[0]]
        # First BD in the chain must acquire the producer lock
        write_bds[0].lock_enable = True
        write_bds[0].lock_acq = prod_lock
        write_bds[0].lock_acq_value = -num_readers * reuse_ratio
        # Last BD in the chain will give credits to the first consumer lock
        write_bds[-1].lock_enable = True
        write_bds[-1].lock_rel = cons_locks[0]
        write_bds[-1].lock_rel_value = +reuse_ratio
        bds += write_bds
        # Create N - 1 zero length BDs to give credits to the remaining consumer locks
        for cons_lock in cons_locks[1:]:
            bds.append(BufferDescriptor(
                writer_dma, alloc.bd_id(writer_dma),
                lock_enable=True,
                lock_rel=cons_lock,
                lock_rel_value=+reuse_ratio,
                is_lock_bd=True,
                name=f"Lock BD - {write_bds[0].name}",
            ))

        # Allocate BDs for each read transfer chain
        for read_chain, cons_lock in zip(read_param_chains, cons_locks):
            read_bds = [conv_param_to_bd(alloc, tile, buffer_addr, param)
                        for param in read_chain]
            # First BD in the chain must decrement this reader's consumer lock
            read_bds[0].lock_enable = True
            read_bds[0].lock_acq = cons_lock
            read_bds[0].lock_acq_value = -1
            # Last BD in the chain must increment the producer lock
            read_bds[-1].lock_enable = True
            read_bds[-1].lock_rel = prod_lock
            read_bds[-1].lock_rel_value = +1
            bds += read_bds

        # Save allocations from this buffer
        alloc_bds.append(bds)
        alloc_locks.append(prod_lock)
        alloc_locks += cons_locks

    return (alloc_bds, alloc_locks)


def compile_parallel_N_to_1_transfer(
    alloc: DmaAllocator,
    tile: AieTile,
    buffer_addrs: List[int],
    write_params: List[TransferParams],
    read_params: List[TransferParams],
) -> BufferAllocation:
    """
        Allocate BDs and locks for N writer channels feeding one reader
        channel, once per address in `buffer_addrs`.

        Per buffer, one producer lock per writer (initial value 1) and one
        consumer lock (initial value 0) are allocated. Each writer chain
        acquires its own producer lock with -1 and releases +1 to the consumer
        lock. The reader chain acquires the consumer lock with -N and releases
        +1 to every producer lock: the first through its last data BD and the
        remaining ones through extra lock-only BDs.

        Returns (BDs grouped per buffer, all allocated locks).
    """
    writers = transfer_param_dmas(write_params)
    readers = transfer_param_dmas(read_params)
    num_writers = len(writers)
    num_readers = len(readers)

    assert num_writers >= 1
    assert num_readers == 1

    # The per-channel grouping does not depend on the buffer address, so it
    # is computed once rather than once per buffer.
    write_param_chains = transfer_param_chains(write_params)
    read_param_chains = transfer_param_chains(read_params)
    reader_dma = readers[0]

    alloc_bds = []
    alloc_locks = []

    for buffer_addr in buffer_addrs:
        # Allocate one producer lock per writer channel plus one consumer lock
        prod_locks = [Lock(tile, alloc.lock_id(tile), +1)
                      for _ in range(num_writers)]
        cons_lock = Lock(tile, alloc.lock_id(tile), +0)

        bds = []

        for write_chain, prod_lock in zip(write_param_chains, prod_locks):
            write_bds = [conv_param_to_bd(alloc, tile, buffer_addr, param)
                         for param in write_chain]
            # First BD in the chain must decrement this writer's producer lock
            write_bds[0].lock_enable = True
            write_bds[0].lock_acq = prod_lock
            write_bds[0].lock_acq_value = -1
            # Last BD in the chain must increment the consumer lock
            write_bds[-1].lock_enable = True
            write_bds[-1].lock_rel = cons_lock
            write_bds[-1].lock_rel_value = +1
            bds += write_bds

        read_bds = [conv_param_to_bd(alloc, tile, buffer_addr, param)
                    for param in read_param_chains[0]]
        # First BD in the chain must acquire the consumer lock
        # after all producers have completed.
        read_bds[0].lock_enable = True
        read_bds[0].lock_acq = cons_lock
        read_bds[0].lock_acq_value = -num_writers
        # Last BD in the chain must increment the first producer lock
        read_bds[-1].lock_enable = True
        read_bds[-1].lock_rel = prod_locks[0]
        read_bds[-1].lock_rel_value = +1
        bds += read_bds
        # Create N - 1 zero length BDs to increment the remaining producer locks
        for prod_lock in prod_locks[1:]:
            bds.append(BufferDescriptor(
                reader_dma, alloc.bd_id(reader_dma),
                lock_enable=True,
                lock_rel=prod_lock,
                lock_rel_value=+1,
                is_lock_bd=True,
                name=f"Lock BD - {read_bds[0].name}",
            ))

        alloc_bds.append(bds)
        alloc_locks += prod_locks
        alloc_locks.append(cons_lock)

    return (alloc_bds, alloc_locks)


def compile_serial_M_to_N_transfer(
    alloc: DmaAllocator,
    tile: AieTile,
    buffer_addrs: List[int],
    write_params: List[TransferParams],
    read_params: List[TransferParams],
    reuse_ratio: int,
) -> BufferAllocation:
    """
        Allocate BDs and locks for M writer channels and N reader channels
        that access a buffer one after another, once per address in
        `buffer_addrs`.

        Per buffer, one lock per writer and one per reader are allocated and
        arranged as a ring (writers first, then readers). Only the first
        writer's lock starts with credits (+reuse_ratio). Every writer chain
        acquires its own lock with -reuse_ratio and releases +reuse_ratio to
        the next lock in the ring. Every reader chain acquires its own lock
        with -1 and releases +1 to the next lock in the ring, the last reader
        wrapping around to the first writer's lock.

        Returns (BDs grouped per buffer, all allocated locks).
    """
    writers = transfer_param_dmas(write_params)
    readers = transfer_param_dmas(read_params)
    num_writers = len(writers)
    num_readers = len(readers)

    assert num_writers >= 1
    assert num_readers >= 1

    # The per-channel grouping does not depend on the buffer address, so it
    # is computed once rather than once per buffer.
    write_param_chains = transfer_param_chains(write_params)
    read_param_chains = transfer_param_chains(read_params)

    alloc_bds = []
    alloc_locks = []

    for buffer_addr in buffer_addrs:
        # Allocate one lock per DMA channel
        prod_locks = [Lock(tile, alloc.lock_id(tile), (+reuse_ratio if i == 0 else +0))
                      for i in range(num_writers)]
        cons_locks = [Lock(tile, alloc.lock_id(tile), +0)
                      for _ in range(num_readers)]
        locks = prod_locks + cons_locks

        bds = []

        # Allocate BDs for write param chains
        for i, write_chain in enumerate(write_param_chains):
            write_bds = [conv_param_to_bd(alloc, tile, buffer_addr, param)
                         for param in write_chain]
            # First BD in the chain must acquire the i'th producer lock
            write_bds[0].lock_enable = True
            write_bds[0].lock_acq = prod_locks[i]
            write_bds[0].lock_acq_value = -reuse_ratio
            # Last BD in the chain must release the next producer lock or
            # the first consumer lock (if this is the last producer)
            write_bds[-1].lock_enable = True
            write_bds[-1].lock_rel = locks[i + 1]
            write_bds[-1].lock_rel_value = +reuse_ratio
            bds += write_bds

        # Allocate BDs for read param chains
        for i, read_chain in enumerate(read_param_chains):
            read_bds = [conv_param_to_bd(alloc, tile, buffer_addr, param)
                        for param in read_chain]
            # First BD in the chain must acquire the i'th consumer lock
            read_bds[0].lock_enable = True
            read_bds[0].lock_acq = cons_locks[i]
            read_bds[0].lock_acq_value = -1
            # Last BD in the chain must release the next consumer lock or
            # the first producer lock (if this is the last consumer)
            read_bds[-1].lock_enable = True
            read_bds[-1].lock_rel = locks[(num_writers + i + 1) % len(locks)]
            read_bds[-1].lock_rel_value = +1
            bds += read_bds

        alloc_bds.append(bds)
        alloc_locks += locks

    return (alloc_bds, alloc_locks)


def compile_async_transfer(
    alloc: DmaAllocator,
    tile: AieTile,
    buffer_addrs: List[int],
    transfer_params: List[TransferParams],
    sync_strategy: SyncStrategy,
) -> BufferAllocation:
    """
        Allocate BDs for a transfer whose params all run in one DMA
        direction; no locks are allocated.

        A remote barrier ID is allocated and attached to every BD only when
        `sync_strategy` is SyncStrategy.Remote_Barrier and config.IS_MULTI_UC
        is set. Returns (BDs grouped per buffer, []).
    """
    # NOTE: Async transfers must be read-only or write-only, so every
    #       param has to share the direction of the first one.
    for other_param in transfer_params[1:]:
        assert other_param.dma.channel.dir == transfer_params[0].dma.channel.dir

    remote_barrier_id = None
    if sync_strategy == SyncStrategy.Remote_Barrier and config.IS_MULTI_UC:
        remote_barrier_id = alloc.barrier_id()

    # Allocate BDs for each transfer, one group per buffer address
    alloc_bds = []
    for buffer_addr in buffer_addrs:
        buffer_bds = []
        for param in transfer_params:
            buffer_bds.append(
                conv_param_to_bd(alloc, tile, buffer_addr, param, remote_barrier_id))
        alloc_bds.append(buffer_bds)

    return (alloc_bds, [])


#
# Optimization Pass #2: Buffer descriptor chaining
#
#       a. Multiple buffer chaining
#       b. Channel re-use chaining
#
# Inputs: BD allocations
# Output: None (BDs are chained inplace)
#


def chain_buffer_descriptors(buffer_descriptors: List[List[BufferDescriptor]]):
    """
        Link BDs that run on the same DMA channel into next-BD chains, in place.

        NOTE: Two kinds of chaining fall out of a single pass here.

              1. Multiple buffer chaining
              2. Channel re-use chaining

              Flattening the per-buffer BD lists puts BDs of the same channel
              from consecutive buffers next to each other in that channel's
              chain. Within a single buffer, several write/read transfer
              params may target the same channel; they keep the order in
              which they were specified, so they are chained in that order.
    """
    flat_bds = [bd for buffer in buffer_descriptors for bd in buffer]
    for dma_chain in buffer_descriptor_chains(flat_bds):
        # Every BD except the last one of a chain points at its successor.
        for current_bd, following_bd in zip(dma_chain, dma_chain[1:]):
            current_bd.use_next_bd = True
            current_bd.next_bd = following_bd


#
# Optimization Pass #3: Task queue management
#
#       a. Queue depth allocation
#       b. DMA Channel synchronization
#
# Inputs: Allocated BDs, repeat counts
# Output: Buffer tasks
#
# NOTE: Each element of the output List[List[BufferTask]] implies a synchronization
#       point in the host configuration code. The elements of each List[BufferTask]
#       will be enqueued followed by a wait for DMA completion. No tasks
#       from the next List[BufferTask] will start until all tasks from
#       the previous one finish.
#


def enqueue_multiple_buffer(
    alloc: DmaAllocator,
    buffer_descriptors: List[List[BufferDescriptor]],
    repeat_count: int
) -> List[BufferTask]:
    """
        Create the buffer tasks needed to run `repeat_count` iterations of a
        multiple buffering scheme, reserving task queue depth in `alloc`.

        For every DMA channel, the repeat count is split into tasks that start
        at the first buffer (capped at config.MAX_REPEAT_COUNT repeats each)
        and, when the count is not divisible by the number of buffers, one
        extra single-repeat task that starts from a later buffer in the chain.
    """
    num_buffers = len(buffer_descriptors)

    # Keep only the first BD of each DMA channel for every buffer. These are
    # the possible entry points, including mid-chain ones used for the
    # leftover iterations.
    start_bds_per_buffer = []
    for buffer in buffer_descriptors:
        channel_chains = buffer_descriptor_chains(buffer)
        start_bds_per_buffer.append([channel_chain[0] for channel_chain in channel_chains])

    # NOTE: Every buffer runs the same access pattern on a different address,
    #       so all buffers must expose the same number of start BDs.
    first_buffer_starts = start_bds_per_buffer[0]
    for start_bds in start_bds_per_buffer:
        assert len(start_bds) == len(first_buffer_starts)

    buffer_tasks = []

    for channel_index, first_start_bd in enumerate(first_buffer_starts):
        remaining = repeat_count
        full_batch = num_buffers * config.MAX_REPEAT_COUNT
        # Peel off maximum-size tasks that start at the first buffer.
        while remaining >= full_batch:
            alloc.enqueue_task(first_start_bd.aie_dma)
            buffer_tasks.append(BufferTask(first_start_bd, config.MAX_REPEAT_COUNT))
            remaining -= full_batch

        # Two enqueues are needed when the remaining count is not divisible
        # by the number of buffers; the second one starts in the middle of
        # the multiple buffer chain.
        whole_passes, leftover = divmod(remaining, num_buffers)
        if whole_passes > 0:
            alloc.enqueue_task(first_start_bd.aie_dma)
            buffer_tasks.append(BufferTask(first_start_bd, whole_passes))
        if leftover > 0:
            mid_chain_bd = start_bds_per_buffer[num_buffers - leftover][channel_index]
            alloc.enqueue_task(mid_chain_bd.aie_dma)
            buffer_tasks.append(BufferTask(mid_chain_bd, 1))

    return buffer_tasks


def compute_buffer_tasks(
    alloc: DmaAllocator,
    buffer_descriptors: List[List[BufferDescriptor]],
    repeat_count: int,
    reuse_ratio: int
) -> List[BufferTask]:
    """
        Build the buffer tasks for one data transfer.

        The BDs of every buffer are split by channel direction. S2MM (writer)
        BDs are enqueued first with `repeat_count` repeats, followed by MM2S
        (reader) BDs with `repeat_count * reuse_ratio` repeats.
    """
    def select_direction(direction):
        # Preserve the per-buffer grouping while filtering on direction.
        return [[bd for bd in buffer if bd.aie_dma.channel.dir == direction]
                for buffer in buffer_descriptors]

    writer_bds = select_direction(DmaDir.S2MM)
    reader_bds = select_direction(DmaDir.MM2S)

    write_tasks = enqueue_multiple_buffer(alloc, writer_bds, repeat_count)
    read_tasks = enqueue_multiple_buffer(alloc, reader_bds, repeat_count * reuse_ratio)
    return write_tasks + read_tasks

'''
    Optimization Pass #4: Maximize Task Queue Depth Utilization

    This optimization pass aims to maximize Task Queue Depth utilization, thereby reducing the number of 
    Mask_Poll operations in the dataflow schedule. Each Mask_Poll accounts for approximately 95 µs in execution time.

    Inputs: transfer_buffers, transfer_locks, transfer_repeats, transfer_reuse

    Steps:
    1. This pass targets BDs scheduled for reconfiguration. Instead of reconfiguring and executing them after a 
    Mask_Poll, we replace them with new DataTransfer objects that have the same configuration as the BD 
    post-reconfiguration. These new objects are queued for immediate execution, skipping the Mask_Poll.

    2. New DataTransfer objects consume new BDs and Locks. Thus, this optimization is limited by the remaining 
    available resources. However, it guarantees that if the initial resource allocation succeeded, 
    this optimization will not degrade performance.

    3. We first create a snapshot of the initial resource allocations, referred to as "virtual allocations". 
    These virtual allocations are updated while determining which DataTransfer objects can be folded and 
    how many times. Once all available resources are exhausted in the virtual space, we update the 
    actual resource allocations with the new folded DataTransfer objects.

    4. The decision to fold a DataTransfer object depends on four key factors:
    - The total number of BDs used by the DataTransfer
    - The total number of Locks it requires
    - The Task Queue depth required for folding, which depends on its repeat_count
    - Whether the DataTransfer is involved in channel sharing, and if so, the valid folding window 
        (i.e., from which phase to which phase folding is allowed)

    5. If we fold a few reconfigurations and then run out of Task Queue Depth, we do reconfigurations until the 
    depth becomes available again. Folding can then resume. However, if BDs or Locks are exhausted 
    after folding a few reconfigurations, no further folding can occur.
'''
def get_positive_indices(values: List[int]) -> List[int]:
    """
        Collect, in ascending order, the positions of all elements of
        `values` that are strictly greater than zero.
    """
    positive_positions = []
    for position, element in enumerate(values):
        if element > 0:
            positive_positions.append(position)
    return positive_positions


def count_non_zero(values: List[int]) -> int:
    '''
        Count how many elements of `values` differ from zero.
    '''
    non_zero_total = 0
    for element in values:
        if element != 0:
            non_zero_total += 1
    return non_zero_total


def count_zeros_in_range(lst: List[int], start_idx: int, end_idx: int):
    """
        Return how many entries of `lst` equal zero between `start_idx` and
        `end_idx`, both bounds included.

        Raises ValueError unless 0 <= start_idx <= end_idx < len(lst).
    """
    range_is_valid = 0 <= start_idx <= end_idx < len(lst)
    if not range_is_valid:
        raise ValueError("Invalid index range")
    window = lst[start_idx:end_idx + 1]
    return window.count(0)

def task_queue_optimization(shape: OverlayShape,
                            alloc: DmaAllocator,
                            transfer_buffers: List[List[List[BufferDescriptor]]],
                            transfer_locks: List[List[Lock]],
                            transfer_repeats: List[List[int]],
                            transfer_reuse: List[int]):
    if DEBUG_MODE:
        print_task_queue_optimization_inputs(shape,
                            alloc,
                            transfer_buffers,
                            transfer_locks,
                            transfer_repeats,
                            transfer_reuse)
    
    NUM_COLS = shape.num_cols
    '''
        Sync Data Transfer or SyncDT
        NOTE: The purpose of SyncDT is to maintain synchronization when a phase's DataTransfer is folded
    '''
    def print_table(headers, rows):
        '''
            Print `headers` followed by `rows` as a left-aligned text table.
            Every column is padded to the width of its longest cell
            (header included); cells are joined with " | " and a dashed
            separator line is printed right below the header.
        '''
        widths = []
        for column in zip(headers, *rows):
            widths.append(max(len(str(cell)) for cell in column))

        def render(cells):
            padded = [str(cell).ljust(width) for cell, width in zip(cells, widths)]
            return " | ".join(padded)

        separator = "-+-".join(['-' * width for width in widths])
        for line in (render(headers), separator):
            print(line)
        for row in rows:
            print(render(row))
            

    def get_bd_dma_config(bd: Union[BufferDescriptor, AieDma]):
        '''
            Unpack the DMA coordinates of `bd` as the tuple
            (column, tile type, channel direction, channel id).

            A BufferDescriptor is resolved through its `aie_dma`; any other
            argument is treated as the AieDma itself.
        '''
        dma = bd.aie_dma if isinstance(bd, BufferDescriptor) else bd
        tile, channel = dma.tile, dma.channel
        return tile.col, tile.type, channel.dir, channel.id

        
    def check_if_not_async(bds: List[BufferDescriptor]):
        '''
            Report whether `bds` contains both kinds of endpoints.

            Returns True only when at least one BD sits on an S2MM channel
            and at least one BD sits on an MM2S channel. A list that holds a
            single direction only (or is empty) yields False.
        '''
        directions = [bd.aie_dma.channel.dir for bd in bds]
        has_s2mm = any(direction == DmaDir.S2MM for direction in directions)
        has_mm2s = any(direction == DmaDir.MM2S for direction in directions)
        return has_s2mm and has_mm2s


    sync_dt_locks = {}
    def sync_dt_lock_management(writer: AieDma, local_initial_lock_left):
        '''
            Return the (ACQ, REL) Lock pair used by the SyncDT of `writer`.

            The pair is allocated from `alloc` the first time a writer is
            seen and archived in `sync_dt_locks` (keyed by `str(writer)`),
            so every later fold of the same writer reuses the same two Locks
            instead of reserving new ones.

            On first allocation the two Locks are subtracted from the
            per-tile "locks left" bookkeeping: once in
            `local_initial_lock_left` and once in the shared
            `initial_lock_left`. `generate_sync_dt` passes the shared dict
            itself as the local one, so the shared dict is only debited
            separately when it is a different object; otherwise the tile
            would be charged 4 Locks for the 2 actually allocated.
        '''
        writer_key = str(writer)
        if writer_key not in sync_dt_locks:
            tile = writer.tile
            tile_key = str(tile)
            # Build the whole entry before publishing it so a failed
            # allocation cannot leave a half-initialized cache entry behind.
            # ACQ is allocated before REL (dict literals evaluate in order).
            sync_dt_locks[writer_key] = {
                "ACQ": Lock(tile, alloc.lock_id(tile), +1),
                "REL": Lock(tile, alloc.lock_id(tile), +0),
            }
            local_initial_lock_left[tile_key] -= 2
            if local_initial_lock_left is not initial_lock_left:
                initial_lock_left[tile_key] -= 2
        archived = sync_dt_locks[writer_key]
        return archived["ACQ"], archived["REL"]


    def bd_id_offset_adjust(bd_id: int, tile_type: TileType) -> int:
        '''
            Translate `bd_id` into the index used by the virtual BD trackers.

            On AIE4 Memtiles the ID is folded into the private BD range via
            modulo (MAX_MEMTILE_PRIVATE_BD_ID + 1); in every other case the
            ID is returned unchanged.
        '''
        is_aie4_memtile = config.DEV_GEN == DevGen.Aie4 and tile_type == TileType.Memtile
        if not is_aie4_memtile:
            return bd_id
        return bd_id % (config.MAX_MEMTILE_PRIVATE_BD_ID + 1)


    sync_dt_bds = {}
    def sync_dt_bd_management(writer: AieDma, reader: AieDma,
                              local_bd_id_tracking_queue, local_bd_left_queue):
        '''
            Return the (writer BD ID, reader BD ID) pair used by the SyncDT.

            A BD ID is allocated from `alloc` the first time a DMA is seen
            and archived in `sync_dt_bds` (keyed by `str(dma)`), so every
            later fold involving the same DMA reuses the same BD instead of
            reserving a new one.

            Each fresh allocation is recorded in the local trackers passed
            by the caller and in the shared `bd_id_tracking_queue` /
            `bd_left_queue`. `generate_sync_dt` passes the shared trackers
            themselves as the local ones, so the shared side is only updated
            separately when it is a different object; otherwise a single BD
            would be counted twice.
        '''
        for task in (writer, reader):
            task_key = str(task)
            if task_key in sync_dt_bds:
                continue

            new_bd_id = alloc.bd_id(task)
            sync_dt_bds[task_key] = new_bd_id
            # Adjust with the tile type of the DMA that owns this BD
            # (previously the writer's tile type was used for the reader too).
            new_bd_id_local = bd_id_offset_adjust(new_bd_id, task.tile.type)

            col, tiletype, ch_dir, ch_id = get_bd_dma_config(task)
            search_key = get_search_key(col, tiletype, ch_dir, ch_id)

            local_pool = local_bd_id_tracking_queue[search_key]
            assert local_pool[new_bd_id_local] == 0, "BD_ID has already been used by another buffer"
            local_pool[new_bd_id_local] += 1
            local_bd_left_queue[search_key] -= 1

            # Mirror into the shared trackers only when they are distinct
            # objects from the local ones (avoids double counting).
            shared_pool = bd_id_tracking_queue[search_key]
            if shared_pool is not local_pool:
                shared_pool[new_bd_id_local] += 1
            if bd_left_queue is not local_bd_left_queue:
                bd_left_queue[search_key] -= 1
        return sync_dt_bds[str(writer)], sync_dt_bds[str(reader)]
    

    def generate_sync_dt(col, writer: AieDma, reader: AieDma, fold_id: int):
        '''
            Build the SyncDT for the given writer/reader pair on the Memtile
            of column `col`.

            Returns a two-element list of lock-only BufferDescriptors
            (writer side first, reader side second) together with the list
            [writer_lock, reader_lock]. The writer BD acquires the writer
            lock and releases the reader lock; the reader BD does the
            opposite. Locks are resolved before BD IDs.
        '''
        writer_lock, reader_lock = sync_dt_lock_management(writer, initial_lock_left)
        writer_bd_id, reader_bd_id = sync_dt_bd_management(writer, reader, bd_id_tracking_queue, bd_left_queue)

        def make_sync_bd(channel, bd_id, acquire, release):
            # Lock-only BD on the Memtile of `col`: acquire (-1), release (+1).
            return BufferDescriptor(AieDma(AieTile(TileType.Memtile, col, 0), channel),
                                    id=bd_id,
                                    lock_enable=True,
                                    lock_acq=acquire,
                                    lock_acq_value=-1,
                                    lock_rel=release,
                                    lock_rel_value=+1,
                                    fold=fold_id,
                                    name="SyncDT")

        writer_bd = make_sync_bd(writer.channel, writer_bd_id, writer_lock, reader_lock)
        reader_bd = make_sync_bd(reader.channel, reader_bd_id, reader_lock, writer_lock)
        return [writer_bd, reader_bd], [writer_lock, reader_lock]


    def get_writer_reader_pair(bds: List[BufferDescriptor]):
        '''
            Pick the single writer/reader DMA pair used to synchronize a
            DataTransfer.

            Only Memtile BDs are considered. The writer is the DMA of the
            S2MM BD with the smallest channel ID, the reader is the DMA of
            the MM2S BD with the largest channel ID (first occurrence wins
            on ties). Either element is None when no matching BD exists.

            Per the design note of this optimizer, syncing this one pair is
            treated as sufficient for all other writer-reader pairs.
        '''
        s2mm_candidates = []
        mm2s_candidates = []
        for bd in bds:
            _, tile_type, ch_dir, ch_id = get_bd_dma_config(bd)
            if tile_type is TileType.Memtile:
                if ch_dir is DmaDir.S2MM:
                    s2mm_candidates.append((ch_id, bd))
                elif ch_dir is DmaDir.MM2S:
                    mm2s_candidates.append((ch_id, bd))

        writer = None
        if s2mm_candidates:
            # min() keeps the first entry among equal channel IDs
            writer = min(s2mm_candidates, key=lambda entry: entry[0])[1].aie_dma
        reader = None
        if mm2s_candidates:
            # max() keeps the first entry among equal channel IDs
            reader = max(mm2s_candidates, key=lambda entry: entry[0])[1].aie_dma
        return writer, reader


    def insert_sync_dt(bds: List[BufferDescriptor], num_phase: int, phase: int, fold_id: int):
        '''
            Create the SyncDT that accompanies the DataTransfer made of `bds`.

            Returns, in this order: the SyncDT BD pair, the Locks they use,
            a per-phase repeat count list (2 at index `phase`, 0 elsewhere)
            and the reuse ratio (always 1). The column is taken from the
            first BD of `bds`.
        '''
        per_phase_repeats = [0 for _ in range(num_phase)]
        per_phase_repeats[phase] = 2
        writer, reader = get_writer_reader_pair(bds)
        column = bds[0].aie_dma.tile.col
        sync_bd_pair, sync_locks = generate_sync_dt(column, writer, reader, fold_id)
        return sync_bd_pair, sync_locks, per_phase_repeats, 1
        

    def index_of_first_zero(nums: List[int], start_index: int, stop_index: int) -> int:
        """
            Return the index of the first zero in `nums` between
            `start_index` and `stop_index` (inclusive). The search window is
            clipped to the end of the list when `stop_index` exceeds it.

            Raises RuntimeError when the window contains no zero, i.e. no
            virtual BD is available; this function never returns None.
        """
        last_index = min(stop_index, len(nums) - 1)
        for idx in range(start_index, last_index + 1):
            if nums[idx] == 0:
                return idx
        raise RuntimeError('No more Virtual BDs available!')


    def find_last_positive_index(values: List[int]) -> Optional[int]:
        '''
            Locate the highest index of `values` holding a value greater
            than zero. Returns None when no such value exists.
        '''
        backward_positions = reversed(range(len(values)))
        return next((pos for pos in backward_positions if values[pos] > 0), None)


    def has_multiple_non_zero(values: List[int]) -> bool:
        '''
            Tell whether `values` contains at least two entries that differ
            from zero.
        '''
        non_zero_seen = 0
        for entry in values:
            if entry != 0:
                non_zero_seen += 1
        return non_zero_seen > 1
        

    def analyze_bd_member_types(buf_desc: BufferDescriptor) -> dict:
        '''
            Classify the raw members of `buf_desc` by nesting level.

            Returns a dict with the keys "offset", "length", "step", "wrap",
            "padding", "iter_step" and "iter_wrap". A value of True marks the
            "one entry per list element" form of the member, False the plain
            form:
              - offset / length : True for a list of ints
              - step            : True for a list of lists of ints
                                  (an empty list also yields True)
              - wrap            : True for a non-empty list of lists of ints
              - padding         : True for a non-empty list of lists of
                                  2-int tuples
              - iter_step / iter_wrap : True for any list

            Before classification, offset/length/step/wrap/padding are
            normalized: every leaf that `int()` accepts is converted to int,
            while list/tuple nesting is preserved.
        '''
        def normalize(value):
            if isinstance(value, list):
                return [normalize(item) for item in value]
            if isinstance(value, tuple):
                return tuple(normalize(item) for item in value)
            try:
                return int(value)
            except (ValueError, TypeError):
                return value  # leave unchanged if not castable

        def all_ints(items):
            return all(isinstance(item, int) for item in items)

        def is_int_list(val):
            return isinstance(val, list) and all_ints(val)

        def is_nested_int_list(val):
            return isinstance(val, list) and all(
                isinstance(row, list) and all_ints(row) for row in val
            )

        def is_int_pair(item):
            return isinstance(item, tuple) and len(item) == 2 and all_ints(item)

        def is_nested_pair_list(val):
            return isinstance(val, list) and all(
                isinstance(row, list) and all(is_int_pair(item) for item in row)
                for row in val
            )

        # Same read/normalize order as the member declaration order
        offset = normalize(buf_desc._offset)
        length = normalize(buf_desc._length)
        step = normalize(buf_desc._step)
        wrap = normalize(buf_desc._wrap)
        padding = normalize(buf_desc._padding)
        iter_step = buf_desc._iter_step
        iter_wrap = buf_desc._iter_wrap

        return {
            # offset / length: List[int] vs int
            "offset": is_int_list(offset),
            "length": is_int_list(length),
            # step: List[List[int]] vs List[int] (no emptiness check here)
            "step": is_nested_int_list(step),
            # wrap: List[List[int]] vs List[int]; an empty list is the plain form
            "wrap": isinstance(wrap, list) and len(wrap) > 0 and is_nested_int_list(wrap),
            # padding: List[List[Tuple[int, int]]] vs List[Tuple[int, int]]
            "padding": isinstance(padding, list) and len(padding) > 0 and is_nested_pair_list(padding),
            # iter_step / iter_wrap: List[Optional[int]] vs Optional[int]
            "iter_step": isinstance(iter_step, list),
            "iter_wrap": isinstance(iter_wrap, list),
        }


    def construct_buffer_descriptor(bd: BufferDescriptor,
                                    new_bd_id: int = None,
                                    num_phases: int = None,
                                    folded_phase_indices: List[int] = None,
                                    phase_idx_mapping: List[int] = None,
                                    fold_id: int = None,
                                    new_lock_ids: Dict = None,
                                    ):
        '''
            Build the folded copy of `bd`.

            The copy keeps the DMA, buffer address, lock values, chaining and
            packet settings of `bd`. It uses `new_bd_id` when given (else the
            original ID), is tagged with `fold_id`, and has its ACQ/REL locks
            replaced through `new_lock_ids` (looked up by `str(lock)`).

            For lock BDs and BDs whose `_length` is a plain int, the data
            members are copied as-is. Otherwise every member flagged as
            per-phase by `analyze_bd_member_types` is expanded to a list of
            `num_phases` entries: the segments delimited by
            `folded_phase_indices` are filled with the value that
            `phase_idx_mapping` selects for that segment, and the last
            segment runs to the end of the list.
        '''
        per_phase_flags = analyze_bd_member_types(bd)

        def pick(val, idx):
            return val[idx] if isinstance(val, list) else val

        def expand(val, filler, field):
            if not per_phase_flags[field]:
                return val
            expanded = [filler] * num_phases
            for seg in range(1, len(folded_phase_indices)):
                # The first segment always starts at phase 0
                seg_start = 0 if seg == 1 else folded_phase_indices[seg - 1]
                seg_end = folded_phase_indices[seg]
                expanded[seg_start:seg_end] = [pick(val, phase_idx_mapping[seg - 1])] * (seg_end - seg_start)

            # Handle the final segment
            tail_start = folded_phase_indices[-1]
            expanded[tail_start:] = [pick(val, phase_idx_mapping[-1])] * (len(expanded) - tail_start)
            return expanded

        scalar_config = bd.is_lock_bd or isinstance(bd._length, int)
        if scalar_config:
            offset = bd._offset
            length = bd._length
            step = bd._step
            wrap = bd._wrap
            padding = bd._padding
            iter_step = bd._iter_step
            iter_wrap = bd._iter_wrap
        else:
            offset = expand(bd._offset, 0, "offset")
            length = expand(bd._length, 0, "length")
            step = expand(bd._step, [1], "step")
            wrap = expand(bd._wrap, [], "wrap")
            padding = expand(bd._padding, [], "padding")
            iter_step = expand(bd._iter_step, None, "iter_step")
            iter_wrap = expand(bd._iter_wrap, None, "iter_wrap")

        return BufferDescriptor(
            bd.aie_dma,
            id=new_bd_id if new_bd_id is not None else bd.id,
            buffer_addr=bd.buffer_addr,
            offset=offset,
            length=length,
            step=step,
            wrap=wrap,
            padding=padding,
            iter_step=iter_step,
            iter_wrap=iter_wrap,
            lock_enable=bd.lock_enable,
            lock_acq=new_lock_ids[str(bd.lock_acq)] if bd.lock_acq is not None else None,
            lock_acq_value=bd.lock_acq_value,
            lock_rel=new_lock_ids[str(bd.lock_rel)] if bd.lock_rel is not None else None,
            lock_rel_value=bd.lock_rel_value,
            use_next_bd=bd.use_next_bd,
            next_bd=bd.next_bd,
            packet_enable=bd.packet_enable,
            packet_id=bd.packet_id,
            fold=fold_id,
            name=f"{bd.name} -- Fold: {fold_id}",
        )


    '''
        All of the following functions are related to BD tracking and 
        BD bottleneck computation. The bottleneck computation handles both 
        channel-sharing and non-channel-sharing cases.
    '''

    
    def init_bd_tracker_aie2p(col, tracker):
        '''
            Add the AIE2p virtual BD pools of column `col` to `tracker`.

            One list per tile, keyed by `str(tile)`: the Memtile pool has
            MAX_MEMTILE_BD_HI_ID + 1 free (0) slots; the Shim pool has
            MAX_SHIM_BD_ID + 1 slots with slot 0 pre-marked as used (1).
        '''
        memtile = AieTile(TileType.Memtile, col, 0)
        shim = AieTile(TileType.Shim, col, 0)
        memtile_pool_size = config.MAX_MEMTILE_BD_HI_ID + 1
        tracker[str(memtile)] = [0] * memtile_pool_size
        # Shim BD 0 is never handed out (the Shim search range starts at 1)
        tracker[str(shim)] = [1] + [0] * config.MAX_SHIM_BD_ID


    def init_bd_tracker_aie4(col, tracker):
        '''
            Add the AIE4 virtual BD pools of column `col` to `tracker`.

            Every Memtile DMA channel (S2MM first, then MM2S) gets its own
            list of MAX_MEMTILE_PRIVATE_BD_ID + 1 free slots, keyed by
            `str(AieDma)`. The Shim tile gets a single list of
            MAX_SHIM_HI_BD_ID + 1 free slots, keyed by `str(AieTile)`.
        '''
        memtile = AieTile(TileType.Memtile, col, 0)
        private_pool_size = config.MAX_MEMTILE_PRIVATE_BD_ID + 1
        for ch_dir in (DmaDir.S2MM, DmaDir.MM2S):
            num_channels = config.NUM_CHANNEL_LUT[TileType.Memtile][ch_dir]
            for ch_id in range(num_channels):
                channel_dma = AieDma(memtile, DmaChannel(ch_dir, ch_id))
                tracker[str(channel_dma)] = [0] * private_pool_size

        shim = AieTile(TileType.Shim, col, 0)
        tracker[str(shim)] = [0] * (config.MAX_SHIM_HI_BD_ID + 1)


    def increment_bd_usage(bd, tracker):
        '''
            Flag the BD ID of `bd` as taken in `tracker` (slot goes 0 -> 1).

            The pool is looked up by `str(tile)`, except for Memtile BDs on a
            non-AIE2p device, whose pool is keyed by `str(dma)`. Asserts that
            the slot was still free.
        '''
        dma = bd.aie_dma
        bd_id = bd_id_offset_adjust(bd.id, dma.tile.type)

        pool_key = str(dma.tile)
        is_aie2p = config.DEV_GEN == DevGen.Aie2p
        if not is_aie2p:  # Aie4
            _, tiletype, _, _ = get_bd_dma_config(bd)
            if tiletype == TileType.Memtile:
                pool_key = str(dma)

        pool = tracker[pool_key]
        assert pool[bd_id] == 0, "BD_ID has already been used by another buffer"
        pool[bd_id] += 1


    def compute_pre_opt_bd_tracking_queue():
        '''
            Build the virtual BD tracker that mirrors the initial allocation.

            Empty pools are created for columns 0 .. NUM_COLS - 1 according
            to DEV_GEN, then every BD of every ping/pong buffer of every
            DataTransfer is marked as used.

            NOTE(review): columns are enumerated from 0, whereas DmaAllocator
            iterates from `shape.start_col` -- confirm start_col is always 0.
        '''
        tracker = {}

        dev_gen = config.DEV_GEN
        if dev_gen == DevGen.Aie2p:
            init_tracker = init_bd_tracker_aie2p
        elif dev_gen == DevGen.Aie4:
            init_tracker = init_bd_tracker_aie4
        else:
            init_tracker = None

        if init_tracker is not None:
            for col in range(NUM_COLS):
                init_tracker(col, tracker)

        for data_transfer in transfer_buffers:
            for ping_pong_buffer in data_transfer:
                for bd in ping_pong_buffer:
                    increment_bd_usage(bd, tracker)

        return tracker

    bd_id_tracking_queue = compute_pre_opt_bd_tracking_queue()
    
    if DEBUG_MODE:
        # ====== PRE-OPTIMIZATION HW BD TRACKER ======
        print("\n====== PRE-OPTIMIZATION HW BD TRACKER ======")
        rows = [[str(k), v] for k, v in alloc.bd_counter.items()]
        print_table(["DMA Tile", "HW BD Counter"], rows)

        # ====== PRE-OPTIMIZATION VIRTUAL BD TRACKER ======
        print("\n====== PRE-OPTIMIZATION VIRTUAL BD TRACKER ======")
        rows = [[str(k), v] for k, v in bd_id_tracking_queue.items()]
        print_table(["DMA Tile", "Virtual BD IDs"], rows)

    
    assert bd_id_tracking_queue == alloc.bd_counter, "Initial Virtual BD tracking should match with actual HW resources allocated"


    def get_bd_key(tile_type, col, ch_dir=None, ch_id=None):
        '''
            Build the string key of a BD pool.

            With a channel direction the key is `str(AieDma)` (per-channel
            pool, as used for AIE4 Memtiles); without one it is
            `str(AieTile)` (per-tile pool, as used for AIE4 Shim tiles and
            all AIE2p tiles).
        '''
        tile = AieTile(tile_type, col, 0)
        pool_owner = tile if ch_dir is None else AieDma(tile, DmaChannel(ch_dir, ch_id))
        return str(pool_owner)


    def compute_bd_usage_aie2p(col, tile_type, bds_used, bds_left):
        '''
            Record, for the AIE2p tile (`tile_type`, `col`), how many BDs are
            marked used in `bd_id_tracking_queue` and how many remain.
            Results are written to `bds_used` / `bds_left` under the tile key.
        '''
        tile_key = get_bd_key(tile_type, col)
        if tile_type == TileType.Memtile:
            pool_size = config.MAX_MEMTILE_BD_HI_ID + 1
        else:
            pool_size = config.MAX_SHIM_BD_ID + 1
        in_use = count_non_zero(bd_id_tracking_queue[tile_key])
        bds_used[tile_key] = in_use
        bds_left[tile_key] = pool_size - in_use


    def compute_bd_usage_aie4(col, tile_type, bds_used, bds_left):
        '''
            Record used/remaining BD counts for the AIE4 tile
            (`tile_type`, `col`) into `bds_used` / `bds_left`.

            Memtiles get one entry per DMA channel (private pools of
            MAX_MEMTILE_PRIVATE_BD_ID + 1 BDs); any other tile type gets one
            entry per tile (pool of MAX_SHIM_HI_BD_ID + 1 BDs).
        '''
        def record(pool_key, pool_size):
            in_use = count_non_zero(bd_id_tracking_queue[pool_key])
            bds_used[pool_key] = in_use
            bds_left[pool_key] = pool_size - in_use

        tile_key = get_bd_key(tile_type, col)

        if tile_type != TileType.Memtile:  # Shim
            record(tile_key, config.MAX_SHIM_HI_BD_ID + 1)
            return

        for ch_dir in (DmaDir.S2MM, DmaDir.MM2S):
            for ch_id in range(config.NUM_CHANNEL_LUT[tile_type][ch_dir]):
                record(get_bd_key(tile_type, col, ch_dir, ch_id),
                       config.MAX_MEMTILE_PRIVATE_BD_ID + 1)


    def compute_pre_opt_num_bd_used():
        '''
            Summarize the virtual BD tracker into per-pool counters.

            Returns the pair (bds_left, bds_used): two dicts keyed by BD pool
            that hold the number of free and used BDs for the Memtile and
            Shim of every column 0 .. NUM_COLS - 1.
        '''
        bds_left: Dict[str, int] = {}
        bds_used: Dict[str, int] = {}

        dev_gen = config.DEV_GEN
        if dev_gen == DevGen.Aie2p:
            compute_usage = compute_bd_usage_aie2p
        elif dev_gen == DevGen.Aie4:
            compute_usage = compute_bd_usage_aie4
        else:
            compute_usage = None

        for col in range(NUM_COLS):
            for tile_type in (TileType.Memtile, TileType.Shim):
                if compute_usage is not None:
                    compute_usage(col, tile_type, bds_used, bds_left)

        return bds_left, bds_used
    
    bd_left_queue, bd_used_queue = compute_pre_opt_num_bd_used()
    
    if DEBUG_MODE:
        # ====== PRE-OPTIMIZATION BD USED ======
        print("\n====== PRE-OPTIMIZATION BD USED ======")
        rows = [[str(k), v] for k, v in bd_used_queue.items()]
        print_table(["DMA Tile", "BD Used"], rows)

        # ====== PRE-OPTIMIZATION BD LEFT ======
        print("\n====== PRE-OPTIMIZATION BD LEFT ======")
        rows = [[str(k), v] for k, v in bd_left_queue.items()]
        print_table(["DMA Tile", "BD Left"], rows)


    def valid_bd_id_range(tile_type, misc=None):
        """
            Return the inclusive (start, stop) BD ID window to scan for a
            free slot during new BD ID allocation, depending on how the BD
            pool is shared on the current DEV_GEN.

            `misc` is the shard index produced by `get_bd_pool_idx`:
              - AIE2p Memtile: misc == 0 (even channel) selects
                0 .. MAX_MEMTILE_BD_LO_ID; any other value (odd channel)
                selects MAX_MEMTILE_BD_LO_ID + 1 .. MAX_MEMTILE_BD_HI_ID.
              - AIE2p Shim: 1 .. MAX_SHIM_BD_ID (`misc` is ignored).
              - AIE4 Memtile: 0 .. MAX_MEMTILE_PRIVATE_BD_ID (`misc` is
                ignored).
              - AIE4 Shim: misc == 0 selects 0 .. MAX_SHIM_LO_BD_ID; any
                other value selects
                MAX_SHIM_LO_BD_ID + 1 .. MAX_SHIM_HI_BD_ID.

            Raises ValueError for any other DEV_GEN / tile type combination
            (this used to surface as an UnboundLocalError).
        """
        dev_gen = config.DEV_GEN

        if dev_gen == DevGen.Aie2p:
            if tile_type == TileType.Memtile:
                if misc == 0:
                    # Even channels use the low half of the pool
                    return 0, config.MAX_MEMTILE_BD_LO_ID
                # Odd channels use the high half of the pool
                return config.MAX_MEMTILE_BD_LO_ID + 1, config.MAX_MEMTILE_BD_HI_ID
            if tile_type == TileType.Shim:
                # Shim BD_ID = 0 is reserved for control packets
                return 1, config.MAX_SHIM_BD_ID

        elif dev_gen == DevGen.Aie4:
            if tile_type == TileType.Memtile:
                return 0, config.MAX_MEMTILE_PRIVATE_BD_ID
            if tile_type == TileType.Shim:
                if misc == 0:
                    return 0, config.MAX_SHIM_LO_BD_ID
                return config.MAX_SHIM_LO_BD_ID + 1, config.MAX_SHIM_HI_BD_ID

        raise ValueError(
            f"No BD ID range defined for DEV_GEN={dev_gen}, tile_type={tile_type}"
        )


    def get_bd_pool_idx(tiletype, ch_dir, ch_id):
        '''
            Return the index of the BD pool shard that a channel draws from,
            or None when the pool is not sharded this way.

              - AIE2p Memtile: `ch_id % 2`
              - AIE2p Shim: always 0
              - AIE4 Memtile: None
              - AIE4 Shim: 0 when `ch_id` does not exceed the "low" channel
                limit of its direction, otherwise 1
              - anything else: None
        '''
        dev_gen = config.DEV_GEN

        if dev_gen == DevGen.Aie2p:
            if tiletype == TileType.Memtile:
                return ch_id % 2
            if tiletype == TileType.Shim:
                return 0
            return None

        if dev_gen == DevGen.Aie4 and tiletype != TileType.Memtile and tiletype == TileType.Shim:
            if ch_dir == DmaDir.S2MM:
                max_lo = config.MAX_SHIM_LO_S2MM_DMA_CHANNEL
            else:
                max_lo = config.MAX_SHIM_LO_MM2S_DMA_CHANNEL
            return 0 if ch_id <= max_lo else 1

        return None
    
    
    def bd_bottleneck_for_transfer_buffer_queue(buffer_idx_list: Union[int, List[int]],
                                                local_bd_id_tracking_queue=bd_id_tracking_queue):
        '''
            Returns a boolean indicating whether BD availability is a
            bottleneck for folding a single DataTransfer or a group of
            DataTransfers (in the case of channel sharing).

            First, the number of BDs required per BD pool is accumulated:
            one per BD of every ping/pong buffer, plus one per SyncDT
            endpoint (writer and reader) that has no archived BD in
            `sync_dt_bds` yet and was not already counted in this call.
            SyncDT endpoints are only counted for transfers that have both
            S2MM and MM2S BDs. Then each demand is compared against the
            number of free slots in the matching pool of
            `local_bd_id_tracking_queue` (defaults to the shared tracker).

            Pool keys are 4-tuples (col, tiletype, ch_dir, ch_id) for AIE4
            Memtile private pools and 3-tuples (col, tiletype, shard) for
            every other pool.
        '''
        if isinstance(buffer_idx_list, int):
            buffer_idx_list = [buffer_idx_list]

        dev_gen = config.DEV_GEN

        def pool_key_for(item):
            # `item` is a BufferDescriptor or an AieDma
            col, tiletype, ch_dir, ch_id = get_bd_dma_config(item)
            if dev_gen == DevGen.Aie4 and tiletype == TileType.Memtile:
                return (col, tiletype, ch_dir, ch_id)
            return (col, tiletype, get_bd_pool_idx(tiletype, ch_dir, ch_id))

        max_bds_needed = defaultdict(int)
        sync_tasks_counted = set()

        for buffer_idx in buffer_idx_list:
            for buffer in transfer_buffers[buffer_idx]:
                for bd in buffer:
                    max_bds_needed[pool_key_for(bd)] += 1

            # SyncDT changes
            first_buffer = transfer_buffers[buffer_idx][0]
            is_not_async = check_if_not_async(first_buffer)
            writer, reader = get_writer_reader_pair(first_buffer)
            if not is_not_async:
                continue
            for task in (writer, reader):
                # A task reuses the same BD_ID every time it is enqueued, so
                # skip tasks that already own a BD (`sync_dt_bds`) or were
                # already counted during this call.
                task_key = str(task)
                if task_key in sync_dt_bds or task_key in sync_tasks_counted:
                    continue
                max_bds_needed[pool_key_for(task)] += 1
                sync_tasks_counted.add(task_key)

        is_bottleneck = False
        for unique_key, bds_needed in max_bds_needed.items():
            # Dispatch on the shape of the key itself. The previous check
            # reused a `tiletype` variable left over from the loops above,
            # which could mismatch the key and fail on unpacking.
            if len(unique_key) == 4:
                col, tiletype, ch_dir, ch_id = unique_key
                search_key = str(AieDma(AieTile(tiletype, col, 0), DmaChannel(ch_dir, ch_id)))
                start, stop = valid_bd_id_range(tiletype)
            else:
                col, tiletype, misc = unique_key
                search_key = str(AieTile(tiletype, col, 0))
                start, stop = valid_bd_id_range(tiletype, misc)
            available_bds = count_zeros_in_range(local_bd_id_tracking_queue[search_key], start, stop)

            if bds_needed > available_bds: # NOTE: Need to check if > or >=
                is_bottleneck = True

        return is_bottleneck


    '''
        All of the following functions are related to Lock tracking and 
        Lock bottleneck computation. The bottleneck computation handles both 
        channel-sharing and non-channel-sharing cases.
    '''
    def compute_pre_opt_locks_used():
        '''
            Snapshot the Locks that remain after the initial allocation.

            Despite its name, the returned dict maps `str(tile)` to the
            number of Locks still free: the tile's total Lock count minus the
            next available Lock ID reported by `alloc`. Shim tiles are
            inserted before Memtiles, columns 0 .. NUM_COLS - 1 each.
        '''
        total_locks_by_type = (
            (TileType.Shim, config.MAX_SHIM_LOCK_ID + 1),
            (TileType.Memtile, config.MAX_MEMTILE_LOCK_ID + 1),
        )

        locks_left = {}
        for tile_type, total_locks in total_locks_by_type:
            for col in range(NUM_COLS):
                tile = AieTile(tile_type, col, 0)
                locks_left[str(tile)] = total_locks - alloc.check_next_available_lock_id(tile)

        return locks_left
    
    initial_lock_left = compute_pre_opt_locks_used()
    
    if DEBUG_MODE:
        # ====== PRE-OPTIMIZATION LOCKS LEFT ======
        print("\n====== PRE-OPTIMIZATION LOCKS LEFT ======")
        rows = [[str(k), v] for k, v in initial_lock_left.items()]
        print_table(["Tile", "Locks Left"], rows)


    def unique_locks_in_buffer(transfer_buffer_idx: int):
        '''
            Gather the distinct Locks referenced by a DataTransfer.

            Walks every BD of every ping/pong buffer of
            `transfer_buffers[transfer_buffer_idx]` and returns a dict that
            maps `str(lock)` to the first Lock instance seen with that
            string form (ACQ is inspected before REL; None is skipped).
        '''
        seen_locks = {}
        for ping_pong_buffer in transfer_buffers[transfer_buffer_idx]:
            for bd in ping_pong_buffer:
                for lock in (bd.lock_acq, bd.lock_rel):
                    if lock is not None:
                        # keep the first instance registered under this key
                        seen_locks.setdefault(str(lock), lock)

        return seen_locks


    def get_new_lock_ids(transfer_buffer_idx, repeat_batch_idx):
        '''
            Build the lookup table "original Lock -> Lock to use in the fold".

            Keys are the `str(lock)` forms of the unique Locks of the
            DataTransfer. For `repeat_batch_idx == 0` each key maps to the
            original Lock; for any other batch a new Lock is allocated from
            `alloc` on the same tile with the same initial value.

            The table is needed because writer and reader BDs reference the
            same Locks in swapped ACQ/REL roles, so one replacement must be
            shared by every BD that mentions the original Lock.
        '''
        keep_original = repeat_batch_idx == 0
        replacements = {}

        for key, original_lock in unique_locks_in_buffer(transfer_buffer_idx).items():
            tile = original_lock.aie_tile
            init_value = original_lock.init_value
            if keep_original:
                replacements[key] = original_lock
            else:
                replacements[key] = Lock(tile, alloc.lock_id(tile), init_value)

        return replacements


    def compute_lock_bottleneck(buffer_idx_list: Union[int, List[int]],
                                local_initial_lock_left=initial_lock_left):
        '''
            Tell whether Lock availability prevents folding a DataTransfer or
            a group of DataTransfers (channel sharing).

            Per tile, the demand is the number of unique Locks of each
            transfer, plus 2 SyncDT Locks for every transfer that has both
            S2MM and MM2S BDs and whose writer neither owns archived SyncDT
            Locks (`sync_dt_locks`) nor was already counted in this call.
            A tile is a bottleneck when its demand is greater than or equal
            to its entry in `local_initial_lock_left` (defaults to the shared
            snapshot).
        '''
        if isinstance(buffer_idx_list, int):
            indices = [buffer_idx_list]
        else:
            indices = buffer_idx_list

        locks_needed_per_tile = defaultdict(int)
        writers_counted = set()

        for transfer_buffer_idx in indices:
            for lock in unique_locks_in_buffer(transfer_buffer_idx).values():
                locks_needed_per_tile[str(lock.aie_tile)] += 1

            # SyncDT changes
            first_buffer = transfer_buffers[transfer_buffer_idx][0]
            is_not_async = check_if_not_async(first_buffer)
            writer, _ = get_writer_reader_pair(first_buffer)
            # A writer keeps the same SyncDT Locks for every enqueue, so it
            # must be charged at most once.
            if is_not_async:
                writer_key = str(writer)
                if writer_key not in sync_dt_locks and writer_key not in writers_counted:
                    locks_needed_per_tile[str(writer.tile)] += 2
                    writers_counted.add(writer_key)

        # NOTE: Need to check if > or >=
        exhausted_tiles = [
            tile_key
            for tile_key, locks_needed in locks_needed_per_tile.items()
            if locks_needed >= local_initial_lock_left[tile_key]
        ]
        return len(exhausted_tiles) > 0
           

    '''
        All of the following functions are related to Task Queue Depth tracking and 
        Task Queue Depth bottleneck computation. The bottleneck computation handles both 
        channel-sharing and non-channel-sharing cases.
    '''
    def get_num_enqueue_task(counter, num_buffers):
        '''
            Count the task queue entries needed to run a task `counter` times
            over `num_buffers` ping-pong buffers.

            Each full chunk of `num_buffers * config.MAX_REPEAT_COUNT` repeats
            costs one entry. What is left after the full chunks costs one
            entry if it covers at least one whole round over all buffers, plus
            one more entry if a partial round remains.
        '''
        chunk = num_buffers * config.MAX_REPEAT_COUNT
        enqueues = 0
        remaining = counter
        while remaining >= chunk:
            remaining -= chunk
            enqueues += 1

        whole_rounds, partial_round = divmod(remaining, num_buffers)
        if whole_rounds > 0:
            enqueues += 1
        if partial_round > 0:
            enqueues += 1

        return enqueues


    def compute_pre_opt_task_queue_depth():
        '''
            Snapshot the Task Queue Depth usage after the initial resource
            allocation, before any folding.

            Returns:
                (used, left): two dicts of the form phase -> str(AieDma) -> int.
                `used` counts the task queue entries consumed on that DMA in
                that phase; `left` is config.MAX_TASK_QUEUE_SIZE minus `used`.
        '''
        phase_count = len(transfer_repeats[0])

        used_by_phase = {}
        left_by_phase = {}

        # One entry per Memtile/Shim DMA channel of every column, per phase.
        for phase in range(phase_count):
            dma_keys = [
                str(AieDma(AieTile(tile, col, 0), DmaChannel(direction, ch_id)))
                for col in range(NUM_COLS)
                for tile in [TileType.Memtile, TileType.Shim]
                for direction in [DmaDir.S2MM, DmaDir.MM2S]
                for ch_id in range(config.NUM_CHANNEL_LUT[tile][direction])
            ]
            used_by_phase[phase] = dict.fromkeys(dma_keys, 0)
            left_by_phase[phase] = dict.fromkeys(dma_keys, config.MAX_TASK_QUEUE_SIZE)

        # Charge every active transfer buffer to each DMA it touches.
        for phase in range(phase_count):
            for buf_idx, ping_pong_buffers in enumerate(transfer_buffers):
                repeat_count = transfer_repeats[buf_idx][phase]
                if repeat_count <= 0:
                    continue

                # A DMA is charged once per transfer buffer, however many BDs use it.
                touched_dmas = {
                    str(bd.aie_dma)
                    for buffer in ping_pong_buffers
                    for bd in buffer
                }
                cost = get_num_enqueue_task(repeat_count, len(ping_pong_buffers))

                for dma_key in touched_dmas:
                    used_by_phase[phase][dma_key] += cost
                    left_by_phase[phase][dma_key] -= cost

        # NOTE: The disabled code below was added during BMMs debug based on the
        #       conclusion that because of WGT/ QDQ Param on Even Col Broadcast
        #       channel, these Col consume more resources. Hence, Odd Col do more
        #       folding leading to synchronization error.
        #
        #   if config.DEV_GEN == DevGen.Aie2p:
        #       for i in range(num_phases):
        #           for tile in [TileType.Memtile, TileType.Shim]:
        #               for direction in [DmaDir.S2MM, DmaDir.MM2S]:
        #                   for ch_id in range(config.NUM_CHANNEL_LUT[tile][direction]):
        #                       max_task_queue_used = float("-inf")
        #                       min_task_queue_used = float("inf")
        #                       for col in range(NUM_COLS):
        #                           dma_key = str(AieDma(AieTile(tile, col, 0), DmaChannel(direction, ch_id)))
        #                           max_task_queue_used = max(max_task_queue_used, tq_used_per_phase[i][dma_key])
        #                           min_task_queue_used = min(min_task_queue_used, tq_left_per_phase[i][dma_key])
        #
        #                       for col in range(NUM_COLS):
        #                           dma_key = str(AieDma(AieTile(tile, col, 0), DmaChannel(direction, ch_id)))
        #                           tq_used_per_phase[i][dma_key] = max_task_queue_used
        #                           tq_left_per_phase[i][dma_key] = min_task_queue_used

        return used_by_phase, left_by_phase

    # Baseline task queue usage; the folding helpers below update `tq_left_per_phase` in place.
    tq_used_per_phase, tq_left_per_phase = compute_pre_opt_task_queue_depth()

    if DEBUG_MODE:
        # Dump one Used/Left table per phase.
        print("\n====== PRE-OPTIMIZATION TASK QUEUE DEPTH ======")
        for phase, used_by_dma in tq_used_per_phase.items():
            print(f"\nPhase {phase}:")
            left_by_dma = tq_left_per_phase[phase]
            rows = [
                [str(dma_key), used, left_by_dma[dma_key]]
                for dma_key, used in used_by_dma.items()
            ]
            print_table(["DMA Key", "Used", "Left"], rows)

    
    def task_queue_bottleneck_for_transfer_buffer(buffer_idx_list: Union[int, List[int]],
                                                  phase: int,
                                                  local_tq_left_per_phase=tq_left_per_phase):
        '''
            Tell whether moving the given DataTransfer(s) into `phase` would
            overflow the task queue of any DMA channel they use.

            Args:
                buffer_idx_list: one transfer buffer index, or a list of them
                    (a channel-sharing group that must fold together).
                phase: the phase whose remaining task queue depth is checked;
                    also the phase whose repeat count sizes the requirement.
                local_tq_left_per_phase: phase -> str(AieDma) -> entries left.
                    Only read here, never modified.

            Returns:
                True if any DMA has fewer entries left than required.
        '''
        if isinstance(buffer_idx_list, int):
            indices = [buffer_idx_list]
        else:
            indices = buffer_idx_list

        enqueues_per_dma = defaultdict(int)
        for buf_idx in indices:
            ping_pong_buffers = transfer_buffers[buf_idx]
            cost = get_num_enqueue_task(transfer_repeats[buf_idx][phase],
                                        len(ping_pong_buffers))
            # NOTE: the cost is added once per BD, so a DMA referenced by
            #       several BDs is charged several times.
            for buffer in ping_pong_buffers:
                for bd in buffer:
                    enqueues_per_dma[str(bd.aie_dma)] += cost

            # SyncDT changes
            needs_sync = check_if_not_async(ping_pong_buffers[0])
            writer, reader = get_writer_reader_pair(ping_pong_buffers[0])
            if needs_sync:
                # One extra entry on the writer and on the reader, but only
                # for channels that already have a requirement recorded.
                for endpoint in (writer, reader):
                    endpoint_key = str(endpoint)
                    if endpoint_key in enqueues_per_dma:
                        enqueues_per_dma[endpoint_key] += 1

        # Every DMA is evaluated (no early exit) before the verdict is formed.
        dma_is_short = [
            local_tq_left_per_phase[phase][dma_key] < needed
            for dma_key, needed in enqueues_per_dma.items()
        ]
        return any(dma_is_short)



    '''
        All of the following functions are related to
        resource updates performed during folding.
    ''' 
    def get_search_key(col, tiletype, ch_dir, ch_id):
        '''
            Return the key under which the BD trackers store the pool for a
            given DMA channel: the channel itself (str(AieDma)) for Memtiles
            when config.DEV_GEN is Aie4, the whole tile (str(AieTile)) in every
            other case.
        '''
        tile = AieTile(tiletype, col, 0)
        keyed_per_channel = config.DEV_GEN == DevGen.Aie4 and tiletype == TileType.Memtile
        if not keyed_per_channel:
            return str(tile)
        return str(AieDma(tile, DmaChannel(ch_dir, ch_id)))


    def fold_bd_resource_update(transfer_buffer_idx_list: Union[int, List[int]],
                                     local_bd_id_tracking_queue=bd_id_tracking_queue,
                                     local_bd_left_queue=bd_left_queue):
        '''
            Reserve one extra BD ID for every BD of the given transfer
            buffer(s), as required when a fold is created.

            Args:
                transfer_buffer_idx_list: one transfer buffer index or a list.
                local_bd_id_tracking_queue: pool key -> list of 0/1 usage
                    flags indexed by BD ID. Updated in place.
                local_bd_left_queue: pool key -> number of BDs still free.
                    Updated in place.

            Raises:
                AssertionError: if the slot picked by `index_of_first_zero`
                    is not actually free.
        '''
        if isinstance(transfer_buffer_idx_list, int):
            indices = [transfer_buffer_idx_list]
        else:
            indices = transfer_buffer_idx_list

        # Only for debug purpose to match with HW Allocator.
        # Keyed by ping-pong buffer id, so a later transfer buffer in the list
        # replaces the entries recorded for an earlier one.
        bds_used = {}
        for buf_idx in indices:
            for pp_id, pp_buffer in enumerate(transfer_buffers[buf_idx]):
                id_pairs = []
                bds_used[pp_id] = id_pairs
                for bd in pp_buffer:
                    col, tiletype, ch_dir, ch_id = get_bd_dma_config(bd)
                    pool_idx = get_bd_pool_idx(tiletype, ch_dir, ch_id)
                    start, stop = valid_bd_id_range(tiletype, pool_idx)
                    pool_key = get_search_key(col, tiletype, ch_dir, ch_id)
                    id_slots = local_bd_id_tracking_queue[pool_key]
                    new_bd_id = index_of_first_zero(id_slots, start, stop)
                    id_pairs.append([bd.id, new_bd_id])
                    assert id_slots[new_bd_id] == 0, "BD_ID has already been used by another buffer"
                    id_slots[new_bd_id] += 1
                    local_bd_left_queue[pool_key] -= 1

        if DEBUG_MODE:
            print("====== BD ID POOL MAPPING ======")
            rows = [
                [pp_id, str(old_id), str(new_id)]
                for pp_id, id_pairs in bds_used.items()
                for old_id, new_id in id_pairs
            ]
            print_table(["bd_id_pool", "old_bd_id", "new_bd_id"], rows)


            
    def fold_phase_resource_update(transfer_buffer_idx_list: Union[int, List[int]],
                                   folded_phase_idx, initial_phase_idx,
                                   local_tq_left_per_phase=tq_left_per_phase):
        '''
            Move the task queue cost of the given transfer buffer(s) from
            `initial_phase_idx` to `folded_phase_idx` after a fold.

            Args:
                transfer_buffer_idx_list: one transfer buffer index or a list.
                folded_phase_idx: phase the task is being moved into; its
                    remaining depth is decreased.
                initial_phase_idx: phase the task originally ran in; its
                    remaining depth is increased. The repeat count of this
                    phase sizes the number of entries moved.
                local_tq_left_per_phase: phase -> str(AieDma) -> entries left.
                    Updated in place.

            Raises:
                AssertionError: if a DMA's remaining depth in the folded phase
                    drops below zero.
        '''
        if isinstance(transfer_buffer_idx_list, int):
            transfer_buffer_idx_list = [transfer_buffer_idx_list]

        for transfer_buffer_idx in transfer_buffer_idx_list:
            ping_pong_buffers = transfer_buffers[transfer_buffer_idx]
            curr_repeat_count = transfer_repeats[transfer_buffer_idx][initial_phase_idx]

            # Collect the DMAs of ALL ping-pong buffers. The set used to be
            # re-created inside the buffer loop, which kept only the last
            # buffer's DMAs and left the name unbound for an empty buffer
            # list. The union matches how compute_pre_opt_task_queue_depth
            # charged the entries in the first place.
            unique_aie_dma = {
                str(bd.aie_dma)
                for buffer in ping_pong_buffers
                for bd in buffer
            }
            num_enqueues = get_num_enqueue_task(curr_repeat_count, len(ping_pong_buffers))

            # SyncDT changes: a synchronising transfer costs one extra entry
            # on its writer and on its reader in the folded phase.
            is_not_async = check_if_not_async(ping_pong_buffers[0])
            writer, reader = get_writer_reader_pair(ping_pong_buffers[0])
            if is_not_async and str(writer) in unique_aie_dma:
                local_tq_left_per_phase[folded_phase_idx][str(writer)] -= 1
            if is_not_async and str(reader) in unique_aie_dma:
                local_tq_left_per_phase[folded_phase_idx][str(reader)] -= 1

            for bd_aie_dma in unique_aie_dma:
                # Give the entries back to the original phase, take them from the folded one.
                local_tq_left_per_phase[initial_phase_idx][bd_aie_dma] += num_enqueues
                local_tq_left_per_phase[folded_phase_idx][bd_aie_dma] -= num_enqueues
                assert local_tq_left_per_phase[folded_phase_idx][bd_aie_dma] >= 0, "Task Queue Depth exceeded during Task Queue Optimization"
            
    
    def fold_lock_resource_update(transfer_buffer_idx_list: Union[int, List[int]],
                                  local_initial_lock_left=initial_lock_left):
        '''
            Charge the Locks consumed by a fold: for every unique Lock used by
            the given transfer buffer(s), one Lock is taken from the count
            left on that Lock's tile.

            `local_initial_lock_left` (str(tile) -> Locks left) is updated in place.
        '''
        if isinstance(transfer_buffer_idx_list, int):
            indices = [transfer_buffer_idx_list]
        else:
            indices = transfer_buffer_idx_list

        for buf_idx in indices:
            for lock in unique_locks_in_buffer(buf_idx).values():
                local_initial_lock_left[str(lock.aie_tile)] -= 1

        
    def track_channel_sharing():
        """
        Tracks folding eligibility of transfer buffers (TB_IDs) across phases based on channel sharing.

        Steps:
        1. Identify TB_IDs sharing the same DMA channel per tile, column, and phase.
        2. For each tile, column, and phase, create a union of TB_IDs using S2MM and
           MM2S directions to form sharing groups.
        3. For each group of TB_IDs, determine the phases where all TB_IDs have a
           non-zero repeat_count, these are candidate phases for folding.
        4. For candidate folding phases, compute the cumulative resource bottleneck
           (Task Queue depth, BDs, Locks), and confirm folding feasibility if all
           TB_IDs in the group satisfy the constraint.
        5. Build a dictionary mapping each TB_ID to a per-phase 0/1 list, where 1
           marks a phase in which the dry run below placed the task in a row of
           its own (the first valid phase, or a phase that could be folded):
        {
            TB_ID_0: [0, 1, 0, 1, ...],
            TB_ID_1: [1, 0, 1, 0, ...],
            ...
        }

        The resource checks are a dry run: they operate on deep copies of the
        BD / Lock / Task Queue trackers, so the real trackers are not modified here.

        Returns a 5-tuple:
            dependency_graph               phase -> str(AieDma) -> set of TB_IDs
            phase_tile_tb_groups           phase -> tile key -> list of sorted TB_ID groups
            cleaned_group_to_valid_phases  group tuple -> list of candidate phases
            group_to_phases_with_resources group tuple -> per-phase 0/1 list
            tbid_to_phases                 TB_ID -> per-phase 0/1 list (step 5)
        """
        num_phases = len(transfer_repeats[0])
        
        # One empty TB_ID set per Memtile/Shim DMA channel of every column.
        # NOTE(review): AieTile is built with two arguments here, while the other
        #               helpers in this scope pass a third argument of 0. Presumably
        #               that argument defaults to 0 -- confirm, because the keys
        #               must match str(bd.aie_dma) below.
        all_channels = {}
        for col in range(NUM_COLS):
            for tile_type in [TileType.Memtile, TileType.Shim]:
                for ch_dir in [DmaDir.S2MM, DmaDir.MM2S]:
                    for ch_id in range(config.NUM_CHANNEL_LUT[tile_type][ch_dir]):
                        all_channels[str(AieDma(AieTile(tile_type, col), DmaChannel(ch_dir, ch_id)))] = set()
                    
        # Each phase gets its own independent copy of the channel table.
        dependency_graph = {}
        for phase in range(num_phases):
            dependency_graph[phase] = deepcopy(all_channels)
            
        # Register every transfer buffer accepted by `has_multiple_non_zero` on each
        # DMA its BDs use, for each phase returned by `get_positive_indices`.
        for transfer_buffer_idx, transfer_buffer in enumerate(transfer_buffers):
            if has_multiple_non_zero(transfer_repeats[transfer_buffer_idx]):
                phase_indices = get_positive_indices(transfer_repeats[transfer_buffer_idx])
                for buffers in transfer_buffer:
                    for bd in buffers:
                        for phase in phase_indices:
                            dependency_graph[phase][str(bd.aie_dma)].add(transfer_buffer_idx)
         
        # -------------------------------------------------------------------------------------------------------  
        
        def merge_overlapping_sets(sets):
            # Merge sets that share at least one element (transitively) and
            # return the resulting groups as sorted lists. The input sets are
            # not modified: each group starts from a copy.
            merged = []
            for s in sets:
                found = False
                for group in merged:
                    if s & group:
                        group |= s  # Merge in-place
                        found = True
                        break
                if not found:
                    merged.append(set(s))
            # Re-run to catch transitive overlaps
            changed = True
            while changed:
                changed = False
                new_merged = []
                while merged:
                    # Absorb into `first` every remaining set that overlaps it;
                    # the rest is kept for the next round of the inner loop.
                    first, rest = merged[0], merged[1:]
                    merged = []
                    for r in rest:
                        if first & r:
                            first |= r
                            changed = True
                        else:
                            merged.append(r)
                    new_merged.append(first)
                merged = new_merged
            return [sorted(group) for group in merged]

        # Final result: phase -> tile -> list of TB index groups
        phase_tile_tb_groups = {}

        for phase, phase_data in dependency_graph.items():
            tile_to_sets = defaultdict(list)

            for dma_key, tb_indices in phase_data.items():
                if not tb_indices:
                    continue
                # The first three '_'-separated fields of the DMA key are used as the
                # tile key. NOTE(review): this depends on the str(AieDma) format,
                # which is defined outside this function -- confirm.
                tile = '_'.join(dma_key.split('_')[:3])
                tile_to_sets[tile].append(set(tb_indices))

            # Merge overlaps
            phase_tile_tb_groups[phase] = {}
            for tile, sets in tile_to_sets.items():
                groups = merge_overlapping_sets(sets)
                phase_tile_tb_groups[phase][tile] = groups
                
        # -------------------------------------------------------------------------------------------------------  
        # Step 1: Build the initial group_to_valid_phases
        group_to_valid_phases = defaultdict(list)

        
        # Loop over each phase
        for phase, tile_groups in phase_tile_tb_groups.items():
            for tile, groups in tile_groups.items():
                for group in groups:
                    group_key = tuple(sorted(group))
                    # Check that all TBs in the group have a positive repeat count in this phase
                    if all(transfer_repeats[tb_id][phase] > 0 for tb_id in group):
                        # Step 2: For each TB, find the previous phase index where repeat was positive
                        prev_positive_indices = []
                        for tb_id in group:
                            repeats = transfer_repeats[tb_id]
                            # Find previous index before 'phase' with repeat > 0
                            # (-1 when there is none)
                            prev_index = -1
                            for i in range(phase - 1, -1, -1):
                                if repeats[i] > 0:
                                    prev_index = i
                                    break
                            prev_positive_indices.append(prev_index)
                        # Only if all previous positive indices match
                        if len(set(prev_positive_indices)) == 1:
                            group_to_valid_phases[group_key].append(phase)
                        
        # Step 2: Remove singleton groups that appear inside multi-member groups
        # Build a set of all TB_IDs that appear in multi-member groups
        tb_ids_in_multi_groups = {
            tb_id for group in group_to_valid_phases
            if len(group) > 1
            for tb_id in group
        }

        # Filter the dictionary
        cleaned_group_to_valid_phases = {
            group: phases
            for group, phases in group_to_valid_phases.items()
            if not (len(group) == 1 and group[0] in tb_ids_in_multi_groups)
        }
    
        # -------------------------------------------------------------------------------------------------------  
        # Copy of Allocations for resource update during tracking channel sharing
        local_bd_id_tracking_queue = deepcopy(bd_id_tracking_queue)
        local_bd_left_queue = deepcopy(bd_left_queue)
        local_initial_lock_left = deepcopy(initial_lock_left)
        local_tq_left_per_phase = deepcopy(tq_left_per_phase)
        
        # group tuple -> per-phase 0/1 list; entries are created on first access.
        group_to_phases_with_resources = defaultdict(lambda: [0] * num_phases)
        
        for group_key, valid_phases in cleaned_group_to_valid_phases.items():
            group = list(group_key)
            # Rows of the simulated fold; only the positions of the 1s matter here.
            temp_folded_phases = []
            for phase_idx, phase in enumerate(valid_phases):
                bd_bottleneck = bd_bottleneck_for_transfer_buffer_queue(group, local_bd_id_tracking_queue)
                lock_bottleneck = compute_lock_bottleneck(group, local_initial_lock_left)

                if phase_idx == 0:
                    # The first valid phase always gets a row of its own; no extra resources are charged.
                    new_row = [0] * num_phases
                    new_row[phase] = 1
                    temp_folded_phases.append(new_row)
                    group_to_phases_with_resources[group_key][phase] = 1

                else:
                    # The task queue check is made against the phase the task would be moved into.
                    previous_phase = find_last_positive_index(temp_folded_phases[-1])
                    task_queue_bottleneck = task_queue_bottleneck_for_transfer_buffer(group,
                                                                                 previous_phase,
                                                                                 local_tq_left_per_phase)               
                    if bd_bottleneck or lock_bottleneck or task_queue_bottleneck:
                        # Reuse the previous row to add another task
                        temp_folded_phases[-1][phase] = 1
                    # No update to BD_IDs or task queue tracking required
                    
                    # This condition is the exact negation of the `if` above.
                    elif not bd_bottleneck and not lock_bottleneck and not task_queue_bottleneck:
                        new_row = [0] * num_phases
                        new_row[previous_phase] = 1
                        temp_folded_phases.append(new_row)
                        group_to_phases_with_resources[group_key][phase] = 1

                        # Update BD_ID tracking per channel
                        # (sync transfers first go through sync_dt_bd_management on the
                        # local copies; the returned IDs are not used here)
                        for buffer_idx in group:
                            is_not_async = check_if_not_async(transfer_buffers[buffer_idx][0])
                            writer, reader = get_writer_reader_pair(transfer_buffers[buffer_idx][0])
                            if is_not_async:
                                writer_bd_id, reader_bd_id = sync_dt_bd_management(writer, reader, local_bd_id_tracking_queue, local_bd_left_queue) 
                        fold_bd_resource_update(group, local_bd_id_tracking_queue, local_bd_left_queue)

                        # Update task queue tracking
                        fold_phase_resource_update(group, previous_phase, phase, local_tq_left_per_phase)

                        # Update lock id tracking
                        for buffer_idx in group:
                            is_not_async = check_if_not_async(transfer_buffers[buffer_idx][0])
                            writer, reader = get_writer_reader_pair(transfer_buffers[buffer_idx][0])
                            if is_not_async:
                                writer_lock, reader_lock = sync_dt_lock_management(writer, local_initial_lock_left) # This step also reduce the virtual lock used by 2
                        fold_lock_resource_update(group, local_initial_lock_left)     

        # -------------------------------------------------------------------------------------------------------  
        
        # Flatten group -> row into TB_ID -> row.
        # NOTE: all TB_IDs of a group share the SAME list object, and a TB_ID that
        #       belongs to more than one group key keeps the row of the group
        #       iterated last. A TB_ID that never entered a valid group has no entry.
        tbid_to_phases = {}
        for group, phases in group_to_phases_with_resources.items():
            for tb_id in group:
                tbid_to_phases[tb_id] = phases
            
        return dependency_graph, phase_tile_tb_groups, cleaned_group_to_valid_phases, group_to_phases_with_resources, tbid_to_phases 
    
    # Channel-sharing analysis; `phases_with_valid_fold` is consulted by fold_phases below.
    dep_graph, dep_groups, valid_phases, phases_with_resources, phases_with_valid_fold  = track_channel_sharing()

    if DEBUG_MODE:
        print("====== DEP GRAPH TB INDICES ======")
        for phase in sorted(dep_graph):
            print(f"\n{'=' * 60}\nPHASE {phase}\n{'=' * 60}")
            last_col_seen = -1
            pending_rows = []
            for key, tb_indices in dep_graph[phase].items():
                col = key.split('_')[1]
                if col != last_col_seen:
                    # New column: flush the table of the previous one, then print a header.
                    if pending_rows:
                        print_table(["AieDma", "TB Indices"], pending_rows)
                        pending_rows = []
                    print(f"\n{'-' * 50}")
                    print(f"Phase {phase} | Column {col}")
                    last_col_seen = col
                pending_rows.append([key, tb_indices])
            # Flush the table of the last column.
            if pending_rows:
                print_table(["AieDma", "TB Indices"], pending_rows)

        # ====== TB Index Groups Per Phase and Tile ======
        print("\n=== TB Index Groups Per Phase and Tile ===")
        for phase in sorted(dep_groups):
            print(f"\n{'=' * 60}\nPHASE {phase}\n{'=' * 60}")
            for tile, groups in dep_groups[phase].items():
                print(f"\n-- Tile: {tile} --")
                for group in groups:
                    print(f"  Group: {group}")

        # ====== VALID PHASES PER GROUP ======
        rows = [
            [str(group), str(phases) if phases else "None"]
            for group, phases in valid_phases.items()
        ]
        print_table(["Group", "Folding Phases"], rows)

        # ====== PHASES WITH RESOURCES ======
        rows = [
            [str(key), str(value) if value else "None"]
            for key, value in phases_with_resources.items()
        ]
        print_table(["TB_IDs", "Valid Folding Phases"], rows)

        # ====== PHASES WITH VALID FOLD ======
        rows = [
            [str(tb_id), str(phases_with_valid_fold[tb_id])]
            for tb_id in sorted(phases_with_valid_fold)
        ]
        print_table(["TB_ID", "Valid Folding Phases"], rows)


    def fold_phases(transfer_buffer_idx: int) -> List[List[int]]:
        """
        Split the repeat counts of one transfer buffer into folds (rows).

        The phases with a positive repeat count are visited in order:
        - the first one always opens a new row, at its own phase index;
        - a later one opens a new row only if BDs, Locks and Task Queue Depth
          all allow it AND the channel-sharing analysis marked the phase as
          foldable for this transfer buffer. Its repeat count is then stored
          at `previous_phase` (the index returned by find_last_positive_index
          for the previous row), and the BD / Task Queue / Lock trackers are
          updated;
        - otherwise its repeat count is added to the current last row, at its
          own phase index, and no tracker is touched.

        A transfer buffer with no entry in `phases_with_valid_fold` (it never
        ended up in a valid channel-sharing group) is treated as not foldable
        instead of raising KeyError.

        Returns:
        - result: List[List[int]] -- one row per fold, each as long as the
          original repeat count list.
        """
        repeat_count_list = transfer_repeats[transfer_buffer_idx]
        non_zero_phase_indices = get_positive_indices(repeat_count_list)
        total_num_phases = len(repeat_count_list)

        # None when the channel-sharing analysis produced no row for this TB.
        valid_fold_row = phases_with_valid_fold.get(transfer_buffer_idx)

        folded_phases = []
        for unfold_id, phase in enumerate(non_zero_phase_indices):
            repeat_count = repeat_count_list[phase]
            bd_bottleneck = bd_bottleneck_for_transfer_buffer_queue(transfer_buffer_idx)
            lock_bottleneck = compute_lock_bottleneck(transfer_buffer_idx)

            if unfold_id == 0:
                # First phase: just drop the task into a new row
                new_row = [0] * total_num_phases
                new_row[phase] = repeat_count
                folded_phases.append(new_row)
                continue

            previous_phase = find_last_positive_index(folded_phases[-1])
            task_queue_bottleneck = task_queue_bottleneck_for_transfer_buffer(transfer_buffer_idx, previous_phase)
            sharing_allows_fold = valid_fold_row is not None and bool(valid_fold_row[phase])

            if bd_bottleneck or lock_bottleneck or task_queue_bottleneck or not sharing_allows_fold:
                # Reuse the previous row to add another task.
                # No update to BD_IDs or task queue tracking required.
                folded_phases[-1][phase] = repeat_count
                continue

            # Create a new row holding this phase's task at the previous phase index
            new_row = [0] * total_num_phases
            new_row[previous_phase] = repeat_count
            folded_phases.append(new_row)

            # Update BD_ID tracking per channel
            fold_bd_resource_update(transfer_buffer_idx)

            # Update task queue tracking
            fold_phase_resource_update(transfer_buffer_idx, previous_phase, phase)

            # Update lock id tracking
            fold_lock_resource_update(transfer_buffer_idx)

        return folded_phases


    # Outputs of the optimization, filled by the assembly loop further down.
    transfer_buffers_optimized = []
    transfer_locks_optimized   = []
    transfer_repeats_optimized = []
    transfer_reuse_optimized   = []

    # One list of folds (rows of repeat counts) per original transfer buffer.
    transfer_repeats_batches   = []
    for idx, transfer_repeat in enumerate(transfer_repeats):
        # Buffers that are not candidates for folding keep their single, unchanged row.
        if not has_multiple_non_zero(transfer_repeat):
            transfer_repeats_batches.append([transfer_repeat])
            continue
        transfer_repeats_batches.append(fold_phases(idx))

    if DEBUG_MODE:
        # ====== TRANSFER REPEATS ======
        print("\n====== TRANSFER REPEATS ======")
        rows = [
            [i, len(entry), entry]
            for i, entry in enumerate(transfer_repeats)
        ]
        print_table(["Index", "Length", "Value"], rows)

        # ====== TRANSFER REPEATS BATCHES ======
        print("\n====== TRANSFER REPEATS BATCHES ======")
        rows = [
            [i, len(entry), entry]
            for i, entry in enumerate(transfer_repeats_batches)
        ]
        print_table(["Index", "Length", "Value"], rows)


    def map_to_new_phase_idx(transfer_repeat_batch, phase_indices):
        '''
            Split `phase_indices` into consecutive slices, one per fold.

            Each fold receives as many entries as `count_non_zero` reports for
            it, taken in order from where the previous fold stopped.
            Returns the list of slices, in fold order.
        '''
        slices = []
        start = 0
        for fold in transfer_repeat_batch:
            stop = start + count_non_zero(fold)
            slices.append(phase_indices[start:stop])
            start = stop
        return slices
    
    # Assemble the optimized outputs. The four `*_optimized` lists are appended
    # to in lockstep, so entry k of each list describes the same transfer.
    for transfer_buffer_idx in range(len(transfer_buffers)):
        if has_multiple_non_zero(transfer_repeats[transfer_buffer_idx]):
            # Candidate for folding: emit one entry per fold computed by fold_phases.
            for fold_id, fold in enumerate(transfer_repeats_batches[transfer_buffer_idx]):
                ping_pong_buffer_bds = []
                lock_group_for_transfer_buffer = []
                
                num_phases = len(transfer_repeats[transfer_buffer_idx])
                phase_indices = get_positive_indices(transfer_repeats[transfer_buffer_idx])
                folded_phase_indices = get_positive_indices(fold)
                # Original phase indices, sliced per fold.
                phase_idx_map = map_to_new_phase_idx(transfer_repeats_batches[transfer_buffer_idx], phase_indices)
                # Fold 0 keeps the original Locks; later folds get newly allocated ones.
                new_locks = get_new_lock_ids(transfer_buffer_idx, fold_id)
                lock_group_for_transfer_buffer.extend(list(new_locks.values()))
                
                # Rebuild every BD of every ping-pong buffer for this fold.
                for ping_pong_buffer_idx, buffers in enumerate(transfer_buffers[transfer_buffer_idx]):
                    bd_group = []
                    for bd in buffers:
                        # The first fold does not require a new BD ID
                        new_bd_id = bd.id if fold_id == 0 else alloc.bd_id(bd.aie_dma)
                        bd_group.append(construct_buffer_descriptor(bd, new_bd_id,
                                                                    num_phases,
                                                                    folded_phase_indices,
                                                                    phase_idx_map[fold_id],
                                                                    fold_id,
                                                                    new_locks))
                    ping_pong_buffer_bds.append(bd_group)

                # For every fold after the first, a transfer for which check_if_not_async
                # is true gets a SyncDT entry emitted right BEFORE the fold itself.
                is_not_async = check_if_not_async(transfer_buffers[transfer_buffer_idx][0])
                if is_not_async and fold_id > 0:
                    syncdt_bds,\
                    syncdt_locks,\
                    syncdt_repeat_count,\
                    syncdt_reuse_ratio = insert_sync_dt(transfer_buffers[transfer_buffer_idx][0],
                                                        num_phases, get_positive_indices(fold)[0],
                                                        fold_id)
                    # SyncDT
                    transfer_buffers_optimized.append([syncdt_bds])        
                    transfer_locks_optimized.append(syncdt_locks) 
                    transfer_reuse_optimized.append(syncdt_reuse_ratio)
                    transfer_repeats_optimized.append(syncdt_repeat_count)
                
                # The fold itself: rebuilt BDs, its Locks, the original reuse ratio
                # and the fold's row of repeat counts.
                transfer_buffers_optimized.append(ping_pong_buffer_bds)
                transfer_locks_optimized.append(lock_group_for_transfer_buffer)
                transfer_reuse_optimized.append(transfer_reuse[transfer_buffer_idx])
                transfer_repeats_optimized.append(transfer_repeats_batches[transfer_buffer_idx][fold_id])
        else:
            # DataTransfer that are not optimized: passed through unchanged.
            transfer_buffers_optimized.append(transfer_buffers[transfer_buffer_idx])
            transfer_locks_optimized.append(transfer_locks[transfer_buffer_idx])
            transfer_reuse_optimized.append(transfer_reuse[transfer_buffer_idx])
            # The batch is a one-element list here, so `+=` adds exactly one row.
            transfer_repeats_optimized += transfer_repeats_batches[transfer_buffer_idx]
         
    # Four parallel lists: buffers, locks, repeats, reuse.
    return (
        transfer_buffers_optimized,
        transfer_locks_optimized,
        transfer_repeats_optimized,
        transfer_reuse_optimized
        )

#
# Run all optimization passes
#
# Inputs: Overlay location, memtile transfers, shim transfers
# Output: Compiled data buffers
#


def compile_transfer_alloc(
    alloc: DmaAllocator,
    transfer: DataTransfer,
) -> BufferAllocation:
    '''
    Allocate the buffers and locks for a single data transfer.

    The transfer's sync strategy is resolved first:
        - Remote_Barrier is handled exactly like Default
        - Default becomes Async when the transfer has no writers or no
          readers, and Serial_M_to_N otherwise

    The resolved strategy then selects which compile_*_transfer routine
    is called with the allocator and the transfer's fields.

    Returns the (buffers, locks) pair produced by the selected routine.

    Raises RuntimeError when reuse_ratio is greater than one and the resolved
    strategy is neither Parallel_1_to_N nor Serial_M_to_N.
    '''
    config.check_init()

    resolved = transfer.sync_strategy
    if resolved == SyncStrategy.Remote_Barrier:
        resolved = SyncStrategy.Default
    if resolved == SyncStrategy.Default:
        is_one_sided = (
            len(transfer.write_params) == 0 or
            len(transfer.read_params) == 0
        )
        # TODO: Should default strategy be serial or parallel?
        resolved = SyncStrategy.Async if is_one_sided else SyncStrategy.Serial_M_to_N

    # Only these two strategies are allowed to carry a reuse ratio above one.
    reuse_capable = (SyncStrategy.Parallel_1_to_N, SyncStrategy.Serial_M_to_N)
    if (transfer.reuse_ratio > 1) and (resolved not in reuse_capable):
        raise RuntimeError(f'Invalid reuse ratio {transfer.reuse_ratio} for {resolved}!')

    if resolved == SyncStrategy.Parallel_1_to_N:
        buffers, locks = compile_parallel_1_to_N_transfer(
            alloc, transfer.tile, transfer.buffer_addrs,
            transfer.write_params, transfer.read_params,
            transfer.reuse_ratio,
        )
    elif resolved == SyncStrategy.Parallel_N_to_1:
        buffers, locks = compile_parallel_N_to_1_transfer(
            alloc, transfer.tile, transfer.buffer_addrs,
            transfer.write_params, transfer.read_params,
        )
    elif resolved == SyncStrategy.Serial_M_to_N:
        buffers, locks = compile_serial_M_to_N_transfer(
            alloc, transfer.tile, transfer.buffer_addrs,
            transfer.write_params, transfer.read_params,
            transfer.reuse_ratio,
        )
    else:
        # Every remaining strategy takes the async route, which handles a
        # single direction only: pick whichever parameter list is in use.
        if len(transfer.read_params) == 0:
            active_params = transfer.write_params
        else:
            assert len(transfer.write_params) == 0
            active_params = transfer.read_params
        # NOTE: The transfer's original sync strategy is forwarded here,
        #       not the resolved one.
        buffers, locks = compile_async_transfer(
            alloc, transfer.tile, transfer.buffer_addrs,
            active_params,
            transfer.sync_strategy,
        )
    return buffers, locks


def split_data_transfer(
    transfer: DataTransfer,
) -> List[DataTransfer]:
    '''
    Break a data transfer into transfer.buffer_split equally sized pieces.

    A transfer with buffer_split == 1 is returned unchanged as a one element
    list. Otherwise every write and read parameter must be a plain linear
    access, and each piece receives parameters covering one contiguous
    slice of length (original length // buffer_split) at the matching offset.
    All pieces keep the original repeat counts, tile, buffer addresses,
    buffer size, sync strategy and reuse ratio.

    Raises ValueError when any parameter is not linear, or when a parameter
    length is not divisible by the number of splits.
    '''
    num_splits = transfer.buffer_split
    if num_splits == 1:
        return [transfer]

    def is_linear(param: TransferParams) -> bool:
        # Checks are staged in a fixed order; the first failing one decides.
        if not (param._num_reconfig() == 1 and isinstance(param._length, int)):
            return False
        if not (param._offset == 0 and param._step == [1]):
            return False
        if not (param._wrap == [] and param._padding == []):
            return False
        return (
            param._iter_step is None and
            param._iter_wrap is None and
            param.shim_buffer_index is None
        )

    def slice_linear_param(
        param: TransferParams,
        piece: int,
        num_pieces: int,
    ) -> TransferParams:
        assert is_linear(param)
        assert piece < num_pieces
        piece_length, leftover = divmod(param._length, num_pieces)
        if leftover != 0:
            raise ValueError(f'Invalid transfer length {param._length}, ' +
                             f'must be divisible by buffer splits {num_pieces}!')
        return TransferParams(param.dma, piece_length, offset=piece_length * piece)

    has_nonlinear_param = (
        any(not is_linear(param) for param in transfer.write_params) or
        any(not is_linear(param) for param in transfer.read_params)
    )
    if has_nonlinear_param:
        raise ValueError(f'Invalid buffer split {transfer.buffer_split}, '
                         f'requires linear access pattern!')

    return [
        DataTransfer(
            transfer.repeat_counts,
            transfer.tile, transfer.buffer_addrs, transfer.buffer_size,
            [slice_linear_param(param, piece, num_splits)
             for param in transfer.write_params],
            [slice_linear_param(param, piece, num_splits)
             for param in transfer.read_params],
            sync_strategy=transfer.sync_strategy,
            reuse_ratio=transfer.reuse_ratio,
        )
        for piece in range(num_splits)
    ]


def auto_chain_transfers(data_transfers: List[DataTransfer]) -> List[DataTransfer]:
    '''
    Optimization pass that merges single-DMA transfers into BD chains when
    their repeat counts consist only of zeros and ones. For example, two
    transfers on the same DMA channel that both carry the repeat count list
    [1, 1, 1] are equivalent to enqueuing one BD chain.

    This matters for L2 fused schedules with buffer fragmentation, where
    chaining some transfers automatically reduces control overhead. The
    returned list must be behaviorally equivalent to the input list; the only
    difference is that some re-enqueues are replaced by chaining.

    Transfers are grouped by their repeat count list. A group is chained
    only when every check in is_chainable holds; all other transfers are
    passed through untouched and in their original order.
    '''

    def group_by_repeats() -> Dict[Sequence[int], List[DataTransfer]]:
        groups = {}
        for candidate in data_transfers:
            groups.setdefault(tuple(candidate.repeat_counts), []).append(candidate)
        return groups

    def reenqueues_overlap(lhs: Sequence[int], rhs: Sequence[int]) -> bool:
        # True when both lists are positive in the same two adjacent iterations.
        assert len(lhs) == len(rhs)
        return any(
            lhs[i] > 0 and lhs[i + 1] > 0 and rhs[i] > 0 and rhs[i + 1] > 0
            for i in range(len(lhs) - 1)
        )

    def is_chainable(
        groups: Dict[Sequence[int], List[DataTransfer]],
        repeats: Sequence[int],
    ) -> bool:
        members = groups[repeats]
        # Count how many parameters of this group land on each DMA.
        dma_usage = {}
        for member in members:
            for param in (member.read_params + member.write_params):
                dma_usage[param.dma] = dma_usage.get(param.dma, 0) + 1
        # NOTE: Every check is evaluated, in this order, before combining.
        checks = [
            # repeat counts are only zeros and ones
            all(count in (0, 1) for count in repeats),
            # no transfer in the group has writers
            all(len(member.write_params) == 0 for member in members),
            # every transfer uses exactly one buffer address
            all(len(member.buffer_addrs) == 1 for member in members),
            # every read parameter has a single configuration
            all(param._num_reconfig() == 1
                for member in members
                for param in member.read_params),
            # no transfer uses a reuse ratio other than one
            all(member.reuse_ratio == 1 for member in members),
            # at least one DMA carries more than one parameter
            any(usage > 1 for usage in dma_usage.values()),
            # no other group re-enqueues in overlapping adjacent iterations
            not any(reenqueues_overlap(repeats, other)
                    for other in groups if other != repeats),
        ]
        return all(checks)

    def chained_repeats(rs: Sequence[int]) -> List[int]:
        # Each run of consecutive ones collapses into its length, stored at
        # the position where the run starts; all other slots stay zero.
        collapsed = [0] * len(rs)
        run_start = None
        for position, value in enumerate(rs):
            if value == 1:
                if run_start is None:
                    run_start = position
                collapsed[run_start] += 1
            else:
                run_start = None
        return collapsed

    groups = group_by_repeats()
    chain_map = {}
    for repeats, members in groups.items():
        if not is_chainable(groups, repeats):
            continue
        params_by_dma = {}
        for member in members:
            assert len(member.buffer_addrs) == 1
            assert len(member.write_params) == 0
            assert member.tile.type == TileType.Memtile
            for param in member.read_params:
                word_size = 4
                # Column distance between the buffer's tile and the DMA's tile,
                # scaled by the memtile address range.
                column_delta = member.tile.col - param.dma.tile.col
                base_words = member.buffer_addrs[0] // word_size
                neighbor_words = (column_delta * (config.MAX_MEMTILE_ADDR + 1)) // word_size
                rebased = copy(param)
                rebased._offset = base_words + neighbor_words + param._offset
                params_by_dma.setdefault(rebased.dma, []).append(rebased)
        for dma, chained_params in params_by_dma.items():
            # The chained transfer uses a placeholder buffer (address 0,
            # size 1); the buffer address is already part of each offset.
            chain_map[(repeats, dma)] = DataTransfer(
                chained_repeats(repeats),
                dma.tile, [0], 1,
                [],
                chained_params,
            )

    chained_transfers = []
    emitted = set()
    for transfer in data_transfers:
        repeats = tuple(transfer.repeat_counts)
        keys = [(repeats, param.dma)
                for param in transfer.write_params + transfer.read_params]
        has_chain = all(key in chain_map for key in keys)
        has_no_chain = not any(key in chain_map for key in keys)
        # A transfer must be either fully chained or not chained at all.
        assert has_chain != has_no_chain
        if not has_chain:
            chained_transfers.append(transfer)
            continue
        for key in keys:
            if key not in emitted:
                chained_transfers.append(chain_map[key])
                emitted.add(key)
    return chained_transfers


def auto_fold_enqueue(data_transfers: List[DataTransfer]) -> List[DataTransfer]:
    '''
    This optimization pass will automatically fold enqueues of single reader
    DMA transfers. For example, if a single DMA read transfer has the repeat
    count list with [1, 1, 1], then this can be folded to [3, 0, 0] as long
    as there aren't any conflicting tasks enqueued on the same channel.

    This optimization happens after auto chaining in L2 fused schedules to
    further reduce control overhead associated with polling. These extra optimizations
    are typically only relevant to L2 fused AIE-2p operators, which require complicated
    forms of buffer fragmentation.

    NOTE: The returned list is a new list, but it holds the same DataTransfer
          objects as the input. Each object's repeat_counts attribute is
          rebound to a copy before any folding edits are applied to it.

    NOTE: A transfer that is not a single reader and whose repeat counts are
          all zero never enqueues anything, so it is treated as imposing no
          ordering constraint on its DMAs.
    '''

    def index_of_first_positive(start: int, nums: List[int]) -> Optional[int]:
        """
        Returns the index of the first positive (> 0) element at or after
        start, or None when there is no such element.
        """
        return next((idx for idx in range(start, len(nums)) if nums[idx] > 0), None)

    def index_of_last_positive(row: List[int]) -> Optional[int]:
        """
        Returns the index of the last positive (> 0) element in the row,
        or None when the row has no positive element.
        """
        return next((idx for idx in reversed(range(len(row))) if row[idx] > 0), None)

    def is_increasing_sequence(nums: List[Optional[int]]) -> bool:
        """
        Returns True when no entry is None and the entries are strictly
        increasing. An empty list counts as increasing.
        """
        if any(num is None for num in nums):
            return False
        return all(prev < curr for prev, curr in zip(nums, nums[1:]))

    folded_transfers = copy(data_transfers)
    for transfer in folded_transfers:
        transfer.repeat_counts = copy(transfer.repeat_counts)

    # read_transfer_map  - DMA -> single reader transfers on that DMA, in input order
    # max_iteration_map  - DMA -> last iteration in which a transfer that is
    #                      not a single reader is active on that DMA
    read_transfer_map = {}
    max_iteration_map = {}
    for transfer in folded_transfers:
        assert transfer.tile.type == TileType.Memtile
        reader_dmas = {param.dma for param in transfer.read_params}
        has_single_reader = (
            (len(transfer.write_params) == 0) and
            (len(reader_dmas) == 1)
        )
        if has_single_reader:
            dma = transfer.read_params[0].dma
            read_transfer_map.setdefault(dma, []).append(transfer)
            last_blocking_iter = 0
        else:
            last_blocking_iter = index_of_last_positive(transfer.repeat_counts)
            if last_blocking_iter is None:
                # All-zero repeat counts: comparing None against an int
                # below would raise TypeError, and a transfer that never
                # runs cannot conflict with a folded enqueue.
                last_blocking_iter = 0
        for param in transfer.write_params + transfer.read_params:
            previous = max_iteration_map.get(param.dma, last_blocking_iter)
            max_iteration_map[param.dma] = max(previous, last_blocking_iter)

    for dma, dma_transfers in read_transfer_map.items():
        num_iters = len(dma_transfers[0].repeat_counts)
        curr_iter = 0
        while curr_iter < num_iters:
            # First active iteration of every transfer, searching from curr_iter.
            first_repeats = [index_of_first_positive(curr_iter, transfer.repeat_counts)
                             for transfer in dma_transfers]
            # Jump to the iteration where the first transfer becomes active.
            curr_iter = first_repeats[0] if first_repeats[0] is not None else num_iters
            can_fold_enqueue = (
                is_increasing_sequence(first_repeats) and
                (len(first_repeats) <= config.MAX_TASK_QUEUE_SIZE) and
                (curr_iter >= max_iteration_map[dma])
            )
            if can_fold_enqueue:
                # Move the next enqueue of every later transfer forward into
                # the iteration where the first transfer is enqueued.
                for transfer, prev_iter in zip(dma_transfers[1:], first_repeats[1:]):
                    assert curr_iter < prev_iter
                    assert transfer.repeat_counts[curr_iter] == 0
                    assert transfer.repeat_counts[prev_iter] > 0
                    transfer.repeat_counts[curr_iter] = transfer.repeat_counts[prev_iter]
                    transfer.repeat_counts[prev_iter] = 0
            curr_iter += 1

    return folded_transfers


   
def compile_data_transfers(
    shape: OverlayShape,
    memtile_transfers: List[DataTransfer],
    shim_transfers: List[DataTransfer],
    enable_task_queue_optimization: bool,
) -> List[DataBuffer]:
    '''
    Lower a list of memtile and shim data transfers to compiled data buffers.

    shape - overlay shape used to construct the DMA resource allocator
    memtile_transfers - transfers on memtiles; these additionally go through
                        the auto chaining and auto enqueue folding passes
    shim_transfers - transfers on shim tiles
    enable_task_queue_optimization - runs the task queue optimization pass;
                                     forced off when the repeat count lists
                                     have a single iteration

    Returns one DataBuffer per compiled transfer, holding its flattened
    buffer descriptors, its locks and its per-iteration tasks. When both
    transfer lists are empty the result is an empty list.

    All transfers must have repeat count lists of the same length.
    '''
    # The reference length comes from whichever transfer is first overall,
    # so a shim-only schedule (no memtile transfers) is accepted as well.
    requested_transfers = [*memtile_transfers, *shim_transfers]
    if not requested_transfers:
        return []

    len_repeat_list = len(requested_transfers[0].repeat_counts)
    if len_repeat_list == 1:
        enable_task_queue_optimization = False

    for transfer in requested_transfers:
        assert len(transfer.repeat_counts) == len_repeat_list, "repeat_count list for all transfers should be of equal length"

    alloc = DmaAllocator(shape)

    #
    # Run Optimization Pass #0
    #
    memtile_transfers = auto_chain_transfers(memtile_transfers)
    memtile_transfers = auto_fold_enqueue(memtile_transfers)

    #
    # Run Optimization Pass #1
    #
    # NOTE: The multi-dimensional lists have the following sizes for each axis:
    #
    #       transfer_buffers - num transfer locations x num buffers x num allocations for a single buffer
    #       transfer_locks   - num transfer locations x num locks
    #       transfer_repeats - num_transfer locations x num iters
    #       transfer_reuse   - num transfer locations
    #
    # NOTE: Each axis has the following meaning
    #
    #       - Transfer location is the specific tile (type, col, row)
    #       - Number of buffers is the pipelined buffering scheme (single, double, etc.)
    #       - Allocations are the DMA resources (BDs, locks, iterations)
    #
    transfers = memtile_transfers + shim_transfers
    transfer_buffers: List[List[List[BufferDescriptor]]] = []
    transfer_locks: List[List[Lock]] = []
    transfer_repeats: List[List[int]] = []
    transfer_reuse: List[int] = []
    for transfer in transfers:
        splits = split_data_transfer(transfer)
        buffers = [[] for _ in range(len(transfer.buffer_addrs))]
        locks = []
        for split in splits:
            bs, ls = compile_transfer_alloc(alloc, split)
            # Append the allocations of this split to the corresponding buffer
            for i, buffer_allocs in enumerate(buffers):
                buffer_allocs += bs[i]
            locks += ls
        transfer_buffers.append(buffers)
        transfer_locks.append(locks)
        transfer_repeats.append(transfer.repeat_counts)
        transfer_reuse.append(transfer.reuse_ratio)

    #
    # Run Optimization Pass #4
    #
    # NOTE: Despite its number, this pass runs here, between pass #1 and pass #2.
    #
    if enable_task_queue_optimization:
        (
            transfer_buffers_optimized,
            transfer_locks_optimized,
            transfer_repeats_optimized,
            transfer_reuse_optimized
        ) = task_queue_optimization(
            shape,
            alloc,
            transfer_buffers,
            transfer_locks,
            transfer_repeats,
            transfer_reuse
        )
    else:
        transfer_buffers_optimized = transfer_buffers
        transfer_locks_optimized = transfer_locks
        transfer_repeats_optimized = transfer_repeats
        transfer_reuse_optimized = transfer_reuse

    #
    # Run Optimization Pass #2
    #
    # NOTE: Buffer descriptors for each data transfer are chained inplace.
    #
    for buffers in transfer_buffers_optimized:
        chain_buffer_descriptors(buffers)

    #
    # Run Optimization Pass #3
    #
    # NOTE: The multi-dimensional lists have the following sizes for each axis:
    #
    #       transfer_tasks - num transfer locations x num iters x num allocations for a single iter
    #
    num_iters = len(transfer_repeats_optimized[0])
    for repeat in transfer_repeats_optimized:
        assert len(repeat) == num_iters
    transfer_tasks: List[List[List[BufferTask]]] = [[[] for _ in range(num_iters)]
                                                    for _ in range(len(transfer_buffers_optimized))]
    # NOTE: The iteration loop is the outer loop, so that the tasks of all
    #       transfers are computed for one iteration before alloc.clear_tasks()
    #       is called and the next iteration starts.
    for i in range(num_iters):
        for j, (buffers, repeats, reuse) in enumerate(zip(transfer_buffers_optimized,
                                                          transfer_repeats_optimized,
                                                          transfer_reuse_optimized)):
            transfer_tasks[j][i] = compute_buffer_tasks(
                alloc,
                buffers,
                repeats[i],
                reuse
            )
        alloc.clear_tasks()

    #
    # Pack final compilation result
    #
    data_buffers: List[DataBuffer] = []
    for buffers, locks, tasks in zip(transfer_buffers_optimized,
                                     transfer_locks_optimized,
                                     transfer_tasks):
        # Flatten the per-buffer allocation lists into a single BD list
        bds = [bd for buffer_allocs in buffers for bd in buffer_allocs]
        data_buffers.append(DataBuffer(bds, locks, tasks))

    return data_buffers

