'''
This module implements the tiling expression engine.
External facing functions are documented below.

generate_transfer_params - compiles a tiling expression
to a single TransferParams object or optionally a BD chain
when iter_step is enabled

generate_shim_data_transfer - compiles a tiling expression
to a DataTransfer object at the shim, where the iter_step usage
is handled automatically with scaling of repeat counts

generate_core_buffer_config - compiles a tiling expression
to a ConfigBuffer core instruction

compute_buffer_size - computes the underlying buffer size from a
memory format string and data type
'''


from typing import List, Dict, Tuple, Optional, Union
from math import gcd

from .types import (
    DevGen, TileType, DmaDir, AieDma, TransferParams, DataTransfer, ConfigBuffer,
    SyncStrategy, AieTile,
)
from . import config


# (axis name, axis size) parsed from one `axis:size` memory format token.
MemoryFormatDim = Tuple[str, int]
# (axis name, start, stop, stride) parsed from one tiling format token.
# Bounds that are omitted in the expression are represented as None.
TilingFormatDim = Tuple[str, Optional[int], Optional[int], Optional[int]]
# Maps an axis name to its list of (step, wrap) segments, listed in the
# order the axis appears in the parsed memory format (increasing step).
AxisMap = Dict[str, List[Tuple[int, int]]]


class SliceDim:
    '''
    Logical slice of one memory axis in (start, stop, stride) form,
    plus the padding counts applied before and after the slice.
    '''
    __slots__ = ('axis', 'start', 'stop', 'stride', 'pad_before', 'pad_after')

    def __init__(self, axis: str, start: int, stop: int, stride: int,
                 pad_before: int, pad_after: int):
        self.axis, self.start, self.stop, self.stride = axis, start, stop, stride
        self.pad_before, self.pad_after = pad_before, pad_after


class LoopDim:
    '''
    Physical loop dimension in (step, wrap) form, plus the padding
    counts applied before and after the loop.
    '''
    __slots__ = ('step', 'wrap', 'pad_before', 'pad_after')

    def __init__(self, step: int, wrap: int, pad_before: int, pad_after: int):
        self.step, self.wrap = step, wrap
        self.pad_before, self.pad_after = pad_before, pad_after


# A compiled tiling: (start offset, loop dimensions).
Tiling = Tuple[int, List[LoopDim]]


def prod(xs: List[int]) -> int:
    '''Return the product of all values in xs (1 for an empty list).'''
    result = 1
    for factor in xs:
        result = result * factor
    return result


def ceildiv(x: int, d: int) -> int:
    '''Return x divided by d, rounded up towards positive infinity.'''
    quotient, remainder = divmod(x, d)
    if remainder != 0:
        quotient += 1
    return quotient


def iceil(x: int, d: int) -> int:
    '''Round x up to the nearest multiple of d.'''
    remainder = x % d
    if remainder == 0:
        return x
    return x + (d - remainder)


def parse_memory_dim(dim: str) -> MemoryFormatDim:
    '''
    Parse one `axis:size` token of a memory format string.

    The axis must be alphabetic and the size a positive integer;
    anything else raises ValueError.
    '''
    usage = ValueError('Invalid memory dimension format, '
                       'expected axis:size!')
    fields = dim.split(':')
    if len(fields) != 2:
        raise usage
    axis, size_text = fields
    if not axis.isalpha():
        raise usage
    try:
        size = int(size_text)
    except ValueError:
        raise usage
    if size <= 0:
        raise usage
    return (axis, size)


def parse_memory_format(format: str) -> List[MemoryFormatDim]:
    '''
    Parse a whitespace-separated memory format string into (axis, size)
    pairs, returned in reverse order of their appearance in the string.
    '''
    tokens = format.split()
    tokens.reverse()
    return list(map(parse_memory_dim, tokens))


def parse_tiling_dim(dim: str) -> TilingFormatDim:
    '''
    Parse one `axis:[start]:[stop]:[stride]` token of a tiling expression.

    The axis must be alphabetic or the wildcard `_`. Bounds that are
    missing or empty are returned as None. Raises ValueError on any
    malformed token.
    '''
    usage = ValueError('Invalid tiling dimension format, '
                       'expected axis:[start]:[stop]:[stride]!')
    fields = dim.split(':')
    if len(fields) < 1 or len(fields) > 4:
        raise usage
    axis = fields[0]
    if not (axis.isalpha() or axis == '_'):
        raise usage
    bounds: List[Optional[int]] = []
    for position in (1, 2, 3):
        text = fields[position] if position < len(fields) else ''
        if text == '':
            bounds.append(None)
            continue
        try:
            bounds.append(int(text))
        except ValueError:
            raise usage
    start, stop, stride = bounds
    return (axis, start, stop, stride)


def parse_tiling_format(format: str) -> List[TilingFormatDim]:
    '''
    Parse a whitespace-separated tiling expression into
    (axis, start, stop, stride) tuples, returned in reverse order
    of their appearance in the string.
    '''
    tokens = format.split()
    tokens.reverse()
    return list(map(parse_tiling_dim, tokens))


def memory_dims_to_axis_map(dims: List[MemoryFormatDim]) -> AxisMap:
    '''
    Build the per-axis list of (step, wrap) segments from parsed memory
    dimensions.

    An axis may appear several times; each later appearance gives the
    cumulative size of the axis so far and must be divisible by the
    size already covered, otherwise ValueError is raised.
    '''
    axis_map: AxisMap = {}
    covered: Dict[str, int] = {}
    next_step = 1
    for axis, size in dims:
        already = covered.get(axis, 1)
        if size % already:
            raise ValueError(f'Invalid size {size} for axis {axis}, '
                             f'must be divisible by {already}!')
        wrap = size // already
        axis_map.setdefault(axis, []).append((next_step, wrap))
        next_step *= wrap
        covered[axis] = already * wrap
    # NOTE: The underscore _ is a wildcard axis that allows a tiling
    # expression to loop over the same data with a step size of zero.
    axis_map['_'] = [(0, 2**32 - 1)]
    return axis_map


def infer_slice_dims(
    tiling_dims: List[TilingFormatDim],
    memory_dims: List[MemoryFormatDim],
    enable_padding: bool,
) -> List[SliceDim]:
    '''
    Infer slice (start, stop, stride) dimensions from parsed tiling format
    with bounds checking of memory axis dimensions.

        1. Default start index is zero
        2. Default stop index is the axis size
        3. Default stride is 1
        4. Negative start indices imply padding before the loop
        5. Positive stop indices larger than the axis size imply padding after the loop

    Returns an empty list as soon as any tiling dimension has zero
    iterations (start >= stop). Raises ValueError for out-of-range bounds,
    for padding when enable_padding is False, and when the combined
    maximum index of all slices on one axis reaches the axis size.
    Raises KeyError if a tiling axis is not present in memory_dims.
    '''
    def max_index(start: int, stop: int, stride: int) -> int:
        # Largest index actually visited by range(start, stop, stride).
        return start + (((stop - start - 1) // stride) * stride)
    axis_size = {axis: size for axis, size in memory_dims}
    # The wildcard axis gets the same size as the wrap used for it
    # in memory_dims_to_axis_map.
    axis_size['_'] = 2**32 - 1
    infer_dims = []
    for axis, start, stop, stride in tiling_dims:
        size = axis_size[axis]
        start = 0 if start is None else start
        stop = size if stop is None else stop
        stride = 1 if stride is None else stride
        # NOTE: If any tiling dimension has zero iterations, then
        # we automatically generate a zero-length transfer.
        if start >= stop:
            return []
        if not ((start < size) and (stop > 0) and (stride > 0)):
            raise ValueError(f'Invalid start {start} and stop {stop} '
                             f'for axis {axis} with size {size}!')
        # Round the span up to a whole number of strides.
        stop = start + iceil(stop - start, stride)
        # Padding is counted in loop iterations (units of stride):
        # iterations below index zero and whole strides past the axis size.
        pad_before = -(min(start, 0) // stride)
        pad_after = max(0, stop - size) // stride
        # Bounds of the part of the slice that is not covered by padding.
        start_no_pad = start + (pad_before * stride)
        stop_no_pad = stop - (pad_after * stride)
        if (pad_before > 0) and (not enable_padding):
            raise ValueError(f'Invalid start {start} for axis {axis}, '
                             f'less than zero but padding is not enabled!')
        if (pad_after > 0) and (not enable_padding):
            raise ValueError(f'Invalid stop {stop} for axis {axis}, '
                             f'exceeds size {size} but padding is not enabled!')
        # Sanity checks: the unpadded slice is non-empty, in bounds, and
        # padded + unpadded iterations add up to the total iteration count.
        assert start_no_pad < stop_no_pad
        assert start_no_pad >= 0
        assert stop_no_pad - stride < size
        assert (
            (((stop_no_pad - start_no_pad) // stride) + pad_before + pad_after) ==
            ((stop - start) // stride)
        )
        infer_dims.append(SliceDim(
            axis,
            start_no_pad,
            stop_no_pad,
            stride,
            pad_before,
            pad_after,
        ))
    # Several slices may address the same axis; the sum of their maximum
    # indices is the largest combined index and must stay inside the axis.
    for axis in axis_size:
        axis_dims = [(dim.start, dim.stop, dim.stride) for dim in infer_dims
                     if dim.axis == axis]
        index = sum([max_index(start, stop, stride)
                     for start, stop, stride in axis_dims])
        size = axis_size[axis]
        if index >= size:
            raise ValueError(f'Invalid index {index} for axis {axis}, '
                             f'must be less than size {size}!')
    return infer_dims


def factor_loop_dims(
    axis_map: AxisMap,
    slice_dims: List[SliceDim],
) -> List[LoopDim]:
    '''
    Convert logical slice dimensions in (start, stop, stride) format
    to physical loop dimensions in (step, wrap) format.

    This uses the underlying memory order to determine how many physical
    loop dimensions are required for a logical slice. One logical slice
    may require several physical loops if the memory axis is not contiguous,
    as is the case with C8 data in convolution feature maps for example.

    Raises ValueError when the start, stop, stride or padding of a slice
    is not divisible by the segment sizes it has to align with.
    '''
    def nesting_level(wraps: List[int], size: int) -> int:
        # Index of the first axis segment at which the running product of
        # wraps reaches `size`; the last index if it is never reached.
        total_wrap = 1
        nesting = -1
        for wrap in wraps:
            total_wrap *= wrap
            nesting += 1
            if (size <= total_wrap):
                break
        assert 0 <= nesting < len(wraps)
        return nesting
    loop_dims = []
    for dim in slice_dims:
        axis_steps = [step for step, _ in axis_map[dim.axis]]
        axis_wraps = [wrap for _, wrap in axis_map[dim.axis]]
        # Segments inner_nesting..outer_nesting (inclusive) are the ones
        # this slice is turned into loops over.
        inner_nesting = nesting_level(axis_wraps, dim.stride)
        outer_nesting = nesting_level(axis_wraps, dim.stop)
        # Number of axis elements covered by all segments below each level.
        inner_wrap = prod(axis_wraps[:inner_nesting])
        outer_wrap = prod(axis_wraps[:outer_nesting])
        if dim.start % outer_wrap != 0:
            raise ValueError(f'Invalid start {dim.start} for axis {dim.axis}, '
                             f'must be divisible by {outer_wrap}!')
        if dim.stop % outer_wrap != 0:
            raise ValueError(f'Invalid stop {dim.stop} for axis {dim.axis}, '
                             f'must be divisible by {outer_wrap}!')
        loop_steps = [step for step in axis_steps]
        loop_wraps = [wrap for wrap in axis_wraps]
        # The outermost segment only iterates over the sliced range.
        loop_wraps[outer_nesting] = (dim.stop // outer_wrap) - (dim.start // outer_wrap)
        if dim.stride % inner_wrap != 0:
            raise ValueError(f'Invalid stride {dim.stride} for axis {dim.axis}, '
                             f'must be divisible by {inner_wrap}!')
        # Apply the stride to the innermost segment: larger step, fewer iterations.
        stride_coeff = dim.stride // inner_wrap
        if loop_wraps[inner_nesting] % stride_coeff != 0:
            raise ValueError(f'Invalid loop dimension {loop_wraps[inner_nesting]} '
                             f'for axis {dim.axis}, must be divisible by {stride_coeff}!')
        loop_steps[inner_nesting] *= stride_coeff
        loop_wraps[inner_nesting] //= stride_coeff
        # Padding is carried only by the outermost loop, so the slice padding
        # must be a multiple of the iterations of the loops nested inside it.
        pad_divisor = prod(loop_wraps[inner_nesting:outer_nesting])
        if dim.pad_before % pad_divisor != 0:
            raise ValueError(f'Invalid pad before {dim.pad_before} for axis {dim.axis}, '
                             f'must be divisible by {pad_divisor}!')
        if dim.pad_after % pad_divisor != 0:
            raise ValueError(f'Invalid pad after {dim.pad_after} for axis {dim.axis}, '
                             f'must be divisible by {pad_divisor}!')
        for i in range(inner_nesting, outer_nesting + 1):
            step = loop_steps[i]
            wrap = loop_wraps[i]
            pad_before = (dim.pad_before // pad_divisor) if i == outer_nesting else 0
            pad_after = (dim.pad_after // pad_divisor) if i == outer_nesting else 0
            loop_dims.append(LoopDim(step, wrap, pad_before, pad_after))
    return loop_dims


def filter_loop_dims(dims: List[LoopDim]) -> List[LoopDim]:
    '''
    Drop loops that iterate exactly once and carry no padding.

    If every loop would be dropped, the first input dimension is kept
    so that the result is never empty.
    '''
    kept = []
    for dim in dims:
        is_trivial = dim.wrap == 1 and dim.pad_before == 0 and dim.pad_after == 0
        if not is_trivial:
            kept.append(dim)
    # NOTE: Other optimization passes expect the dimension list
    # to be non-empty, so keep one dimension in that special case.
    if not kept:
        kept.append(dims[0])
    return kept


def fold_loop_dims(dims: List[LoopDim]) -> List[LoopDim]:
    '''
    Merge runs of adjacent loops that together walk memory with one
    fixed step size into a single loop.

    A loop joins the current run when its step equals the step of the
    run's first loop times the iterations accumulated so far, and no loop
    already in the run carries padding. The padding of the last loop in
    a run is scaled by the iterations of the loops merged below it.
    '''
    folded = []
    count = len(dims)
    i = 0
    while i < count:
        first = dims[i]
        j = i + 1
        # Running totals over dims[i:j].
        run_wrap = first.wrap
        run_pad_before = first.pad_before
        run_pad_after = first.pad_after
        # Product of wraps over dims[i:j - 1], i.e. everything below the last loop.
        below_last_wrap = 1
        while (
            (j < count) and
            (dims[j].step == first.step * run_wrap) and
            (run_pad_before == 0) and
            (run_pad_after == 0)
        ):
            below_last_wrap = run_wrap
            run_wrap *= dims[j].wrap
            run_pad_before += dims[j].pad_before
            run_pad_after += dims[j].pad_after
            j += 1
        last = dims[j - 1]
        folded.append(LoopDim(
            first.step,
            run_wrap,
            last.pad_before * below_last_wrap,
            last.pad_after * below_last_wrap,
        ))
        i = j
    return folded


def compute_axis_offset(axis_map: AxisMap, axis: str, idx: int) -> int:
    '''
    Translate a logical index along `axis` into a memory offset by
    decomposing it over the axis's (step, wrap) segments in order.
    '''
    total = 0
    remaining = idx
    for step, wrap in axis_map[axis]:
        remaining, digit = divmod(remaining, wrap)
        total += digit * step
    return total


def compute_start_offset(
    axis_map: AxisMap,
    slice_dims: List[SliceDim],
) -> int:
    '''Sum the memory offsets of the start index of every slice.'''
    return sum(
        compute_axis_offset(axis_map, dim.axis, dim.start)
        for dim in slice_dims
    )


def generate_virtual_tiling(
    memory_format: str,
    tiling_format: str,
    enable_padding: bool = False,
) -> Tiling:
    '''
    Compile a tiling expression against a memory format into a start
    offset and a list of loop dimensions, both in units of elements.

    Returns (0, []) when the tiling expression has zero iterations.
    '''
    memory_dims = parse_memory_format(memory_format)
    tiling_dims = parse_tiling_format(tiling_format)
    axis_map = memory_dims_to_axis_map(memory_dims)
    slice_dims = infer_slice_dims(tiling_dims, memory_dims, enable_padding)
    if not slice_dims:
        return 0, []
    # Pipeline: factor slices into loops, drop trivial loops, merge loops.
    loop_dims = fold_loop_dims(
        filter_loop_dims(
            factor_loop_dims(axis_map, slice_dims)
        )
    )
    start_offset = compute_start_offset(axis_map, slice_dims)
    return start_offset, loop_dims


def convert_physical_dims(
    dims: List[LoopDim],
    bits_per_block: int,
    elements_per_block: int,
    word_size: int,
) -> List[LoopDim]:
    '''
    Convert loop dimensions from element units to words of
    `word_size` bytes.

    The innermost output loop always has a step of one word. Raises
    ValueError when a size is not a whole number of blocks or words.
    '''
    assert len(dims) > 0
    bits_per_word = word_size * 8

    def to_words(num_elements: int):
        num_blocks, stray_elements = divmod(num_elements, elements_per_block)
        if stray_elements != 0:
            raise ValueError(f'Invalid dimension size {num_elements} elements, '
                             f'must be divisible by {elements_per_block}!')
        num_bits = num_blocks * bits_per_block
        num_words, stray_bits = divmod(num_bits, bits_per_word)
        if stray_bits != 0:
            raise ValueError(f'Invalid dimension size {num_bits} bits, '
                             f'must be divisible by {bits_per_word}!')
        return num_words

    head = dims[0]
    if head.step == 1:
        # Contiguous innermost loop: its wrap and padding become word counts.
        conv_dims = [
            LoopDim(1, to_words(head.wrap), to_words(head.pad_before), to_words(head.pad_after)),
        ]
    else:
        # Strided innermost loop: prepend a unit-step loop covering one element.
        conv_dims = [
            LoopDim(1, to_words(1), 0, 0),
            LoopDim(to_words(head.step), head.wrap, head.pad_before, head.pad_after),
        ]
    conv_dims.extend(
        LoopDim(to_words(dim.step), dim.wrap, dim.pad_before, dim.pad_after)
        for dim in dims[1:]
    )
    # NOTE: The conversion of wrap units to words may have created
    # an unnecessary dimension, so we re-run the filtering pass.
    if config.DEV_GEN == DevGen.Aie2p:
        conv_dims = filter_loop_dims(conv_dims)
    elif config.DEV_GEN == DevGen.Aie4:
        assert conv_dims[0].step == 1
    else:
        assert False
    return conv_dims


def convert_physical_offset(
    offset: int,
    bits_per_block: int,
    elements_per_block: int,
    word_size: int,
) -> int:
    '''
    Convert an offset in elements to an offset in words of
    `word_size` bytes.

    Raises ValueError when the offset is not a whole number of blocks
    or the resulting bit offset is not a whole number of words.
    '''
    offset_blocks, stray_elements = divmod(offset, elements_per_block)
    if stray_elements:
        raise ValueError(f'Invalid offset {offset} elements, '
                         f'must be divisible by {elements_per_block}!')
    bits_per_word = word_size * 8
    offset_bits = offset_blocks * bits_per_block
    offset_words, stray_bits = divmod(offset_bits, bits_per_word)
    if stray_bits:
        raise ValueError(f'Invalid offset {offset_bits} bits, '
                         f'must be divisible by {bits_per_word}!')
    return offset_words


def generate_physical_tiling(
    memory_format: str,
    tiling_format: str,
    bits_per_block: int = 8,
    elements_per_block: int = 1,
    enable_padding: bool = False,
) -> Tiling:
    '''
    Compile a tiling expression into a start offset and loop dimensions
    expressed in 4-byte words.

    Returns (0, []) when the tiling expression has zero iterations.
    '''
    word_size: int = 4
    virtual_offset, virtual_dims = generate_virtual_tiling(
        memory_format,
        tiling_format,
        enable_padding=enable_padding,
    )
    if not virtual_dims:
        return 0, []
    block_layout = (bits_per_block, elements_per_block, word_size)
    physical_offset = convert_physical_offset(virtual_offset, *block_layout)
    physical_dims = convert_physical_dims(virtual_dims, *block_layout)
    return physical_offset, physical_dims


def factor_dimension_overflow(
    tile_type: TileType,
    dims: List[LoopDim],
) -> List[LoopDim]:
    '''
    Factor a loop into multiple loops when the wrap or padding
    fields overflow the maximum allowed value. This logic will first
    try factoring one loop into two loops. If that fails, then it will try
    factoring one loop into three loops. If both factorizations fail, then
    the error is reported.

    The wrap limit and the per-dimension padding limits are read from
    `config` according to `tile_type`. Dimensions that do not overflow are
    passed through unchanged. Raises ValueError when no valid factorization
    is found.
    '''

    if tile_type == TileType.Core:
        max_wrap = config.MAX_CORE_WRAP
        max_pads = []
    elif tile_type == TileType.Memtile:
        max_wrap = config.MAX_MEMTILE_WRAP
        max_pads = [config.MAX_MEMTILE_D0_PAD, config.MAX_MEMTILE_D1_PAD, config.MAX_MEMTILE_D2_PAD]
    elif tile_type == TileType.Shim:
        max_wrap = config.MAX_SHIM_WRAP
        max_pads = []
    else:
        assert False

    def has_overflow(wrap: int, pad_before: int, pad_after: int, i: int) -> bool:
        # Dimensions without an entry in max_pads allow no padding at all,
        # so any non-zero padding there counts as overflow.
        if i < len(max_pads):
            max_pad = max_pads[i]
        else:
            max_pad = 0
        has_overflow = (
            (wrap > max_wrap) or
            (pad_before > max_pad) or
            (pad_after > max_pad)
        )
        return has_overflow

    def dim_has_overflow(dim: LoopDim, i: int):
        return has_overflow(dim.wrap, dim.pad_before, dim.pad_after, i)

    def factor2_wrap_with_pad(
        wrap: int,
        pad_before: int,
        pad_after: int,
        i: int,
    ) -> Tuple[Tuple[int, int], Tuple[int, int]]:
        # Search inner wraps d from largest to smallest (d == 1 excluded).
        # d must divide the wrap and both paddings so that the padding can
        # move entirely onto the outer loop.
        for d in range(gcd(gcd(wrap, pad_before), pad_after), 1, -1):
            is_divisible = (
                ((pad_before % d) == 0) and
                ((pad_after % d) == 0) and
                ((wrap % d) == 0)
            )
            if is_divisible:
                wrap0 = d
                wrap1 = wrap // d
                pad1_before = pad_before // d
                pad1_after = pad_after // d
                # The inner loop lands at position i, the outer one at i + 1.
                is_valid = (
                    (not has_overflow(wrap0, 0, 0, i)) and
                    (not has_overflow(wrap1, pad1_before, pad1_after, i + 1))
                )
                if is_valid:
                    return ((wrap0, wrap1), (pad1_before, pad1_after))
        raise ValueError(f'Invalid wrap {wrap} with pad before {pad_before} and '
                         f'pad after {pad_after}, failed to factor dimension!')

    def factor3_wrap_with_pad(
        wrap: int,
        pad_before: int,
        pad_after: int,
        i: int,
    ) -> Tuple[Tuple[int, int, int], Tuple[int, int]]:
        # Same search as factor2_wrap_with_pad, applied twice: d_inner splits
        # off the innermost loop, d_outer splits the remainder again. Only the
        # outermost of the three loops carries padding.
        for d_inner in range(gcd(gcd(wrap, pad_before), pad_after), 1, -1):
            is_inner_divisible = (
                ((pad_before % d_inner) == 0) and
                ((pad_after % d_inner) == 0) and
                ((wrap % d_inner) == 0)
            )
            if is_inner_divisible:
                wrap0 = d_inner
                inner_wrap = wrap // d_inner
                inner_pad_before = pad_before // d_inner
                inner_pad_after = pad_after // d_inner
                for d_outer in range(gcd(gcd(inner_wrap, inner_pad_before), inner_pad_after), 1, -1):
                    is_outer_divisible = (
                        ((inner_pad_before % d_outer) == 0) and
                        ((inner_pad_after % d_outer) == 0) and
                        ((inner_wrap % d_outer) == 0)
                    )
                    if is_outer_divisible:
                        wrap1 = d_outer
                        wrap2 = inner_wrap // d_outer
                        pad2_before = inner_pad_before // d_outer
                        pad2_after = inner_pad_after // d_outer
                        is_valid = (
                            (not has_overflow(wrap0, 0, 0, i)) and
                            (not has_overflow(wrap1, 0, 0, i + 1)) and
                            (not has_overflow(wrap2, pad2_before, pad2_after, i + 2))
                        )
                        if is_valid:
                            return ((wrap0, wrap1, wrap2), (pad2_before, pad2_after))
        raise ValueError(f'Invalid wrap {wrap} with pad before {pad_before} and '
                         f'pad after {pad_after}, failed to factor dimension!')

    factor_dims = []
    # NOTE(review): `i` is the position in the input list; it is not shifted
    # when earlier dimensions were split into several loops, so the padding
    # limit looked up for later dimensions may not match their final
    # position - confirm this is intended.
    for i in range(len(dims)):
        has_padding = ((dims[i].pad_before > 0) or (dims[i].pad_after > 0))
        is_last_dim = (i == len(dims) - 1)
        # An overflowing last dimension without padding is left untouched.
        if dim_has_overflow(dims[i], i) and ((not is_last_dim) or has_padding):
            try:
                (wrap0, wrap1), (pad1_before, pad1_after) = factor2_wrap_with_pad(
                    dims[i].wrap, dims[i].pad_before, dims[i].pad_after, i,
                )
                # Each new loop steps over everything the loops below it cover.
                step0 = dims[i].step
                step1 = dims[i].step * wrap0
                factor_dims.append(LoopDim(step0, wrap0, 0, 0))
                factor_dims.append(LoopDim(step1, wrap1, pad1_before, pad1_after))
            except ValueError:
                # Two loops were not enough; a failure here propagates to the caller.
                (wrap0, wrap1, wrap2), (pad2_before, pad2_after) = factor3_wrap_with_pad(
                    dims[i].wrap, dims[i].pad_before, dims[i].pad_after, i,
                )
                step0 = dims[i].step
                step1 = dims[i].step * wrap0
                step2 = dims[i].step * wrap0 * wrap1
                factor_dims.append(LoopDim(step0, wrap0, 0, 0))
                factor_dims.append(LoopDim(step1, wrap1, 0, 0))
                factor_dims.append(LoopDim(step2, wrap2, pad2_before, pad2_after))
        else:
            factor_dims.append(dims[i])
    return factor_dims


def assign_physical_dims(
    dma: AieDma,
    offset: int,
    dims: List[LoopDim],
    use_iter_step: bool,
    max_chain_length: int,
    name: str = "",
) -> Union[TransferParams, Tuple[int, List[TransferParams]]]:
    '''
    Assign loops to physical BD dimensions. This may involve splitting
    loops when the wrap or padding values overflow the maximum allowed fields.

    The user can optionally enable factoring a dimension into the iteration
    step field. This requires multiplying the repeat count of the BD by the
    iteration wrap value.

    When `use_iter_step` is set and there is more than one dimension, the
    last entry of `dims` is mapped to the iteration step/wrap fields. If its
    wrap exceeds config.MAX_ITER_WRAP, it is spread over a chain of up to
    `max_chain_length` transfers that differ only in their start offset.

    Returns a single TransferParams when `use_iter_step` is False, otherwise
    a tuple (repeat_coeff, transfer_chain).

    Raises RuntimeError when the iteration dimension carries padding or no
    valid chain length exists, and ValueError (from
    factor_dimension_overflow) when a buffer dimension cannot be factored.
    '''
    error_msg = RuntimeError('Failed to assign tiling dimensions!')

    def pack_wrap(dims: List[LoopDim]) -> List[int]:
        # The wrap of the last dimension is only emitted when that
        # dimension carries padding. Uses the `dims` argument rather than
        # the enclosing `factor_dims` variable.
        if (dims[-1].pad_before > 0) or (dims[-1].pad_after > 0):
            return [dim.wrap for dim in dims]
        return [dim.wrap for dim in dims[:-1]]

    def pack_padding(dims: List[LoopDim]) -> List[Tuple[int, int]]:
        # Drop trailing dimensions that have no padding.
        padding = [(dim.pad_before, dim.pad_after) for dim in dims]
        end = len(padding)
        while end > 0 and padding[end - 1] == (0, 0):
            end -= 1
        return padding[:end]

    def iteration_chain_length(wrap: int) -> int:
        # Shortest chain that divides the wrap evenly and brings the
        # per-transfer iteration wrap within the limit.
        for chain_length in range(1, max_chain_length + 1):
            is_valid = (
                ((wrap % chain_length) == 0) and
                ((wrap // chain_length) <= config.MAX_ITER_WRAP)
            )
            if is_valid:
                return chain_length
        raise error_msg

    has_iter_step = use_iter_step and (len(dims) > 1)
    num_buffer_dims = len(dims) - 1 if has_iter_step else len(dims)
    # Transfer length counts padded iterations as well.
    length = prod([dim.wrap + dim.pad_before + dim.pad_after
                   for dim in dims[:num_buffer_dims]])
    factor_dims = factor_dimension_overflow(dma.tile.type, dims[:num_buffer_dims])
    buffer_step = [dim.step for dim in factor_dims]
    buffer_wrap = pack_wrap(factor_dims)
    padding = pack_padding(factor_dims)
    if has_iter_step:
        assert len(dims) == num_buffer_dims + 1
        dim = dims[num_buffer_dims]
        if (dim.pad_before > 0) or (dim.pad_after > 0):
            raise error_msg
        chain_length = iteration_chain_length(dim.wrap)
        # Chained transfers interleave: link i starts i steps in and
        # every link advances by chain_length steps per iteration.
        iter_step = dim.step * chain_length
        iter_wrap = dim.wrap // chain_length
        chain_offsets = [offset + (i * dim.step)
                         for i in range(chain_length)]
    else:
        assert len(dims) == num_buffer_dims
        iter_step = None
        iter_wrap = None
        chain_offsets = [offset]
    repeat_coeff = iter_wrap if has_iter_step else 1
    transfer_chain = [
        TransferParams(
            dma, length,
            offset=chain_offset,
            step=buffer_step,
            wrap=buffer_wrap,
            padding=padding,
            iter_step=iter_step,
            iter_wrap=iter_wrap,
            name=name,
        ) for chain_offset in chain_offsets
    ]
    return (repeat_coeff, transfer_chain) if use_iter_step else transfer_chain[0]


def _outer_dim_chain_formats(tiling_format: str) -> List[str]:
    '''
    Split the first (outermost) dimension of a tiling expression into one
    tiling expression per stride-sized piece. Returns an empty list when
    the dimension has no explicit stride or its bounds are not integers.
    '''
    tiling_parts = tiling_format.split()
    if not tiling_parts:
        return []
    fields = tiling_parts[0].split(':')
    if len(fields) != 4:
        # No stride given on the outer dimension: nothing to split.
        return []
    axis = fields[0]
    try:
        start = int(fields[1]) if fields[1] != '' else 0
        stop = int(fields[2])
        stride = int(fields[3])
    except ValueError:
        return []
    if (stride <= 0) or (stop <= start):
        return []
    # Number of iterations of the outer loop, rounded up.
    num_links = -((stop - start) // -stride)
    # NOTE(review): each link covers the range [start_i, start_i + stride)
    # of the outer axis - confirm this matches the intended footprint of
    # one outer iteration.
    formats = []
    for i in range(num_links):
        link_start = start + (i * stride)
        link_stop = link_start + stride
        outer_dim = f'{axis}:{link_start}:{link_stop}'
        formats.append(' '.join([outer_dim] + tiling_parts[1:]))
    return formats


def generate_transfer_params(
    dma: AieDma,
    memory_format: str,
    tiling_format: str,
    bits_per_block: int = 8,
    elements_per_block: int = 1,
    enable_padding: bool = False,
    use_iter_step: bool = False,
    max_chain_length: int = 4,
    buffer_offset: int = 0,
    verbose: bool = True,
    use_bd_chain_for_dims: bool = False,
    name: str = "",
) -> Union[TransferParams, Tuple[int, List[TransferParams]]]:
    '''
    Compile a tiling expression to transfer parameters for `dma`.

    Returns a single TransferParams, or a tuple (repeat_coeff, chain) when
    `use_iter_step` is enabled. A tiling expression with zero iterations
    yields a zero-length transfer. `buffer_offset` is given in bytes and
    must be a multiple of the 4-byte word size.

    When `use_bd_chain_for_dims` is set and the dimensions cannot be
    assigned, the outermost tiling dimension is split into one transfer per
    stride-sized piece and (1, chain) is returned. If the outer dimension
    cannot be split, the original error is re-raised.
    '''
    config.check_init()
    if verbose:
        print(dma, memory_format, tiling_format, sep=', ')
    word_size = 4
    try:
        offset, dims = generate_physical_tiling(
            memory_format,
            tiling_format,
            bits_per_block=bits_per_block,
            elements_per_block=elements_per_block,
            enable_padding=enable_padding,
        )
        assert (buffer_offset % word_size) == 0
        offset += buffer_offset // word_size
        if len(dims) == 0:
            if use_iter_step:
                return (1, [TransferParams(dma, 0)])
            return TransferParams(dma, 0)
        return assign_physical_dims(
            dma, offset, dims,
            use_iter_step,
            max_chain_length,
            name=name,
        )
    except (ValueError, RuntimeError) as e:
        # The assignment failure is raised as a RuntimeError, so both types
        # are caught and the message is compared as a string.
        wants_chain = (
            use_bd_chain_for_dims and
            (str(e) == 'Failed to assign tiling dimensions!')
        )
        chain_formats = _outer_dim_chain_formats(tiling_format) if wants_chain else []
        if not chain_formats:
            raise
        # This pass uses BD chaining for the outermost tiling dimension
        # because the direct transfer assignment failed.
        chained_transfers = []
        for new_tiling_format in chain_formats:
            print(f'Attempting to use BD chaining for tiling format: {new_tiling_format}')
            offset, dims = generate_physical_tiling(
                memory_format,
                new_tiling_format,
                bits_per_block=bits_per_block,
                elements_per_block=elements_per_block,
                enable_padding=enable_padding,
            )
            offset += buffer_offset // word_size
            # NOTE(review): with use_iter_step enabled each result is itself
            # a (repeat_coeff, chain) tuple - confirm how callers expect
            # these to be combined.
            chained_transfers.append(assign_physical_dims(
                dma, offset, dims,
                use_iter_step,
                max_chain_length,
                name=name,
            ))
        return (1, chained_transfers)


def generate_shim_data_transfer(
    repeat_counts: List[int],
    dma: AieDma,
    shim_buffer_idx: int,
    memory_format: str,
    tiling_format: str,
    bits_per_block: int = 8,
    elements_per_block: int = 1,
    max_chain_length: int = 4,
    buffer_offset: int = 0,
    verbose: bool = True,
    name: str = "",
) -> DataTransfer:
    '''
    Compile a tiling expression to a DataTransfer on a shim tile.

    The transfer is first mapped without the iteration step. If that fails,
    it is mapped again with the iteration step enabled, and every entry of
    `repeat_counts` is multiplied by the resulting repeat coefficient.

    The transfer chain is used as write parameters for an S2MM channel and
    as read parameters for an MM2S channel. `dma` must belong to a shim
    tile.
    '''
    config.check_init()
    assert dma.tile.type == TileType.Shim
    if verbose:
        print(dma, memory_format, tiling_format, sep=', ')
    # NOTE: Here we first attempt to map the transfer without
    # iter step and chaining, then revert to those if the mapping fails
    try:
        transfer_params = generate_transfer_params(
            dma,
            memory_format,
            tiling_format,
            bits_per_block=bits_per_block,
            elements_per_block=elements_per_block,
            max_chain_length=max_chain_length,
            buffer_offset=buffer_offset,
            verbose=False,
            name=name,
        )
        repeat_coeff = 1
        transfer_chain = [transfer_params]
    except Exception:
        # Catch Exception rather than everything, so that KeyboardInterrupt
        # and SystemExit are not swallowed by the retry.
        repeat_coeff, transfer_chain = generate_transfer_params(
            dma,
            memory_format,
            tiling_format,
            bits_per_block=bits_per_block,
            elements_per_block=elements_per_block,
            use_iter_step=True,
            max_chain_length=max_chain_length,
            buffer_offset=buffer_offset,
            verbose=False,
            name=name,
        )
    buffer_size = compute_buffer_size(memory_format, bits_per_block)
    write_params = transfer_chain if dma.channel.dir == DmaDir.S2MM else []
    read_params = transfer_chain if dma.channel.dir == DmaDir.MM2S else []
    return DataTransfer(
        [count * repeat_coeff for count in repeat_counts],
        dma.tile, [shim_buffer_idx], buffer_size,
        write_params,
        read_params,
    )


def generate_core_buffer_config(
    dma: AieDma,
    ping_addr: int,
    pong_addr: Optional[int],
    memory_format: str,
    tiling_format: str,
    bits_per_block: int = 8,
    elements_per_block: int = 1,
    verbose: bool = True,
) -> ConfigBuffer:
    '''
    Compile a tiling expression to a ConfigBuffer instruction for a core
    tile DMA channel.

    A tiling expression with zero iterations yields a zero-length buffer
    configuration. Padding is not enabled for this path. `dma` must belong
    to a core tile.
    '''
    config.check_init()
    assert dma.tile.type == TileType.Core
    if verbose:
        print(dma, memory_format, tiling_format, sep=', ')
    offset, dims = generate_physical_tiling(
        memory_format,
        tiling_format,
        bits_per_block=bits_per_block,
        elements_per_block=elements_per_block,
    )
    if not dims:
        return ConfigBuffer(dma.channel, ping_addr, pong_addr, 0)
    dims = factor_dimension_overflow(TileType.Core, dims)
    assert all(dim.pad_before == 0 for dim in dims)
    assert all(dim.pad_after == 0 for dim in dims)
    word_size = 4
    steps = []
    wraps = []
    for dim in dims:
        steps.append(dim.step)
        wraps.append(dim.wrap)
    # Length is passed in bytes; the wrap of the last dimension is not emitted.
    length = prod(wraps) * word_size
    return ConfigBuffer(
        dma.channel, ping_addr, pong_addr, length,
        offset=offset,
        step=steps,
        wrap=wraps[:-1],
    )


def compute_reuse_chain_length(
    reuse_ratio: int,
    num_consumers: int,
    max_chain_length: int = 4
) -> int:
    '''
    Return the shortest chain length that can carry the requested reuse.

    A chain length i in 1..max_chain_length is accepted when it divides
    reuse_ratio evenly and (reuse_ratio // i) * num_consumers does not
    exceed 63 (named max_lock_value below). The smallest accepted length
    is returned.

    Raises RuntimeError when no length in the range qualifies.
    '''
    max_lock_value = 63
    candidates = (
        length
        for length in range(1, max_chain_length + 1)
        if reuse_ratio % length == 0
        and (reuse_ratio // length) * num_consumers <= max_lock_value
    )
    shortest = next(candidates, None)
    if shortest is None:
        raise RuntimeError('Failed to allocate reuse chain!')
    return shortest


def pack_reconfig_transfers(
    dma: AieDma,
    memory_fmts: List[str],
    tiling_fmts: List[str],
    tiling_iters: List[int] = [1],
    bits_per_elem: int = 8,
    use_iter_step: List[bool] = [False],
    buffer_offset: List[int] = [0],
    name: str = "",
) -> Union[TransferParams, Tuple[List[int], TransferParams]]:
    '''
    Generate packed transfer parameters for multiple memory and tiling formats.

    One transfer is compiled per (memory format, tiling format) pair with
    generate_transfer_params, and the first entry of every field of each
    result is packed into a single TransferParams object. Padding is enabled
    only when the DMA channel direction is MM2S.

    dma - DMA the transfers are generated for
    memory_fmts / tiling_fmts - format strings, lists of equal length
    tiling_iters - how many times each format's entry is repeated in the
        packed fields
    bits_per_elem - forwarded to generate_transfer_params as bits_per_block
    use_iter_step - whether to use the iteration step for each format
    buffer_offset - buffer offset for each format
    name - forwarded to the generated TransferParams

    tiling_iters, use_iter_step and buffer_offset may each be given as a
    single scalar value, a one-element list (both are repeated for every
    format), or a list with one entry per format. Any other length raises
    ValueError.

    Returns (repeat_coeff_list, packed_params) if use_iter_step is True for
    any format, where repeat_coeff_list holds one coefficient per format
    (1 for formats compiled without iter step). Otherwise returns just the
    packed TransferParams.
    '''
    assert len(memory_fmts) == len(tiling_fmts)
    num_fmts = len(tiling_fmts)

    def per_format(value, error_message: str) -> list:
        # Callers pass plain ints/bools as well as lists, so wrap scalars
        # instead of calling len() on them. Lists are copied, which also
        # keeps the shared default lists in the signature from being aliased.
        values = list(value) if isinstance(value, (list, tuple)) else [value]
        if len(values) == 1:
            return values * num_fmts
        if len(values) != num_fmts:
            raise ValueError(error_message)
        return values

    # Normalize inputs to lists with one entry per format
    tiling_iters = per_format(
        tiling_iters,
        'tiling_iters must be a list of the same length as memory_fmts and tiling_fmts',
    )
    buffer_offset = per_format(
        buffer_offset,
        'buffer_offset must be the same length as memory_fmts and tiling_fmts',
    )
    use_iter_step = per_format(
        use_iter_step,
        'use_iter_step must be a list of the same length as memory_fmts and tiling_fmts',
    )

    def pack(items: list) -> list:
        # Packing pairs every generated transfer with one tiling_iters entry,
        # so exactly one transfer per format is required. A format whose
        # iter-step compilation yields a longer chain cannot be packed here.
        assert len(items) == len(tiling_iters), \
            'expected exactly one transfer per format when packing'
        res = []
        for item, num in zip(items, tiling_iters):
            res += [item] * num
        return res

    params = []
    repeat_coeff_list = []
    has_iter_step = any(use_iter_step)

    # Generate transfer parameters for each format
    for memory_fmt, tiling_fmt, iter_step, offset in zip(
        memory_fmts, tiling_fmts, use_iter_step, buffer_offset
    ):
        transfers = generate_transfer_params(
            dma,
            memory_fmt,
            tiling_fmt,
            bits_per_block=bits_per_elem,
            enable_padding=(dma.channel.dir == DmaDir.MM2S),
            use_iter_step=iter_step,
            buffer_offset=offset,
            name=name,
        )
        if isinstance(transfers, tuple):
            # Iter-step result: (repeat coefficient, transfer chain)
            params.extend(transfers[1])
            repeat_coeff_list.append(transfers[0])
        else:
            params.append(transfers)
            repeat_coeff_list.append(1)

    # Pack all parameters into a single TransferParams object
    packed_param = TransferParams(
        dma,
        pack([param.length_i(0) for param in params]),
        offset=pack([param.offset_i(0) for param in params]),
        step=pack([param.step_i(0) for param in params]),
        wrap=pack([param.wrap_i(0) for param in params]),
        padding=pack([param.padding_i(0) for param in params]),
        iter_step=pack([param.iter_step_i(0) for param in params]),
        iter_wrap=pack([param.iter_wrap_i(0) for param in params]),
        name=name,
    )

    # Return appropriate result based on whether iter_step was used
    if has_iter_step:
        return (repeat_coeff_list, packed_param)
    return packed_param


def compute_buffer_size(
    memory_format: str,
    bits_per_block: int = 8,
    elements_per_block: int = 1,
) -> int:
    '''
    Compute the size in bytes of the buffer described by a memory format.

    The per-axis sizes parsed from memory_format are multiplied together to
    get the element count, which is converted to blocks of
    elements_per_block elements, then to bits (bits_per_block per block),
    then to bytes.

    Raises ValueError when the element count is not a multiple of
    elements_per_block, or when the bit count is not a multiple of 8.
    '''
    config.check_init()

    # Keyed by axis name; a repeated axis keeps the last size seen.
    size_by_axis = {}
    for axis, size in parse_memory_format(memory_format):
        size_by_axis[axis] = size
    total_elements = prod(list(size_by_axis.values()))

    total_blocks, leftover_elements = divmod(total_elements, elements_per_block)
    if leftover_elements != 0:
        raise ValueError(f'Invalid buffer size {total_elements} elements, '
                         f'must be divisible by {elements_per_block}!')

    bits_per_byte = 8
    total_bits = total_blocks * bits_per_block
    total_bytes, leftover_bits = divmod(total_bits, bits_per_byte)
    if leftover_bits != 0:
        raise ValueError(f'Invalid buffer size {total_bits} bits, '
                         f'must be divisible by {bits_per_byte}!')
    return total_bytes


def generate_packed_shim_data_transfer(
    repeat_counts: List[int],
    dma: AieDma,
    shim_buffer_idx: int,
    memory_fmts: List[str],
    tiling_fmts: List[str],
    tiling_iter_nums: List[int],
    tiling_start_iter: List[int],
    bits_per_elem: int,
    max_chain_length: int = 4,
    name: str = "",
) -> DataTransfer:
    '''
    Reconfigures a BD with different transfer
    params at the shim for poll and re-enqueue

    Each (memory format, tiling format) pair is compiled with iter step
    enabled and padding disabled. The first entry of every field of the
    resulting transfers is packed into one TransferParams, repeating the
    entry for format i tiling_iter_nums[i] times.

    repeat_counts - base repeat counts; the caller's list is left untouched
    dma - DMA performing the transfer; its channel direction selects
        whether the packed params become write or read params
    shim_buffer_idx - buffer index stored in the DataTransfer
    memory_fmts / tiling_fmts / tiling_iter_nums / tiling_start_iter -
        per-format lists, all of the same length (asserted)
    tiling_start_iter - for format i, the index into repeat_counts whose
        entry is multiplied by that format's repeat coefficient
    bits_per_elem - forwarded as bits_per_block
    max_chain_length / name - forwarded to generate_transfer_params

    Returns a DataTransfer on dma.tile carrying the scaled repeat counts.
    '''
    assert len(memory_fmts) == len(tiling_fmts)
    assert len(tiling_fmts) == len(tiling_iter_nums)
    assert len(tiling_start_iter) == len(tiling_iter_nums)

    def pack(items: list) -> list:
        # One generated transfer per format is required for packing.
        assert len(items) == len(tiling_iter_nums)
        res = []
        for item, num in zip(items, tiling_iter_nums):
            res += [item] * num
        return res

    params = []
    repeat_coeff_iter = []
    for memory_fmt, tiling_fmt in zip(memory_fmts, tiling_fmts):
        repeat_coeff, transfer_chain = generate_transfer_params(
            dma,
            memory_fmt,
            tiling_fmt,
            bits_per_block=bits_per_elem,
            enable_padding=False,
            use_iter_step=True,
            max_chain_length=max_chain_length,
            name=name,
        )
        repeat_coeff_iter.append(repeat_coeff)
        params.extend(transfer_chain)

    packed_params = TransferParams(
        dma,
        pack([param.length_i(0) for param in params]),
        offset=pack([param.offset_i(0) for param in params]),
        step=pack([param.step_i(0) for param in params]),
        wrap=pack([param.wrap_i(0) for param in params]),
        padding=pack([param.padding_i(0) for param in params]),
        iter_step=pack([param.iter_step_i(0) for param in params]),
        iter_wrap=pack([param.iter_wrap_i(0) for param in params]),
        name=name,
    )

    # The buffer size is derived from the first memory format only.
    buffer_size = compute_buffer_size(memory_fmts[0], bits_per_elem)

    if dma.channel.dir == DmaDir.S2MM:
        write_params = [packed_params]
        read_params = []
    else:
        read_params = [packed_params]
        write_params = []

    # Scale a copy: multiplying the caller's list in place would compound the
    # coefficients if the same repeat_counts list is reused for another call.
    scaled_repeat_counts = list(repeat_counts)
    for start_iter, repeat_coeff in zip(tiling_start_iter, repeat_coeff_iter):
        scaled_repeat_counts[start_iter] *= repeat_coeff

    return DataTransfer(
        scaled_repeat_counts,
        dma.tile, [shim_buffer_idx], buffer_size,
        write_params,
        read_params
    )


'''
NOTE: The following function lets schedule developers
generate packed memtile transfers.
The same function can also serve 1-to-1 transfers.
repeat_counts:
    1. Has to be a list of integers.
    2. if there is reconfig, the length of the list must be equal to the number of reconfigs
    3. If there is no reconfig, the same access is repeated.
write_dma:
    1. This is the write DMA for the memtile
    2. This is the S2MM DMA
read_dma_list:
    1. This is a list of read DMAs for the memtile
    2. These are the MM2S DMAs
    3. Can also be a list of same DMA if there are chained transfers
    with each BD having multiple access patterns
buffer_addrs:
    1. This is a list of buffer addresses for the memtile buffers
buffer_size:
    1. This is the size of the buffer in bytes
Memory formats:
    1. Can be a single string
    2. Can be a list of strings, This is used for case of chained transfers
    3. Can be a list of list of strings, This is used for case of reconfig with chained transfers
    4. The number of access patterns for each DMA must be the same and
    equal to the number of memory formats
Write tiling formats:
    1. Can be a single string
    2. Can be a list of strings, This is used for case of chained transfers
    3. Can be a list of list of strings, This is used for case of reconfig with chained transfers
    4. The number of access patterns for each DMA must be the same and
    equal to the number of memory formats
Read tiling formats:
    1. These are a list of strings per DMA
    The outer most list is the number of read DMAs
    The inner most list is the number of access patterns for each DMA
    The number of access patterns for each DMA must be the same and
    equal to the number of memory formats
bits_per_block:
    1. This is the number of bits per block
elements_per_block:
    1. This is the number of elements per block
max_chain_length:
    1. If the repeat is too large, it can be broken into multiple BDs
write_buffer_offset:
    1. This is the offset for the write DMA within the buffers
read_buffer_offset:
    1. This is the offset for each read DMA within the buffers
reuse_ratio:
    1. If the buffer is reused, this is the ratio of the reuse
parallel_locking:
    1. Enable or disable parallel locking
    2. This is a switch exposed to the user to control based on the BD resources available.
'''
def generate_memtile_data_transfers_1_to_N(
    repeat_counts: List[int],
    write_dma: AieDma,
    read_dma_list: List[AieDma],
    buffer_addrs: List[int],
    buffer_size: int,
    memory_format: Union[str, List[str], List[List[str]]],
    write_tiling_format: Union[str, List[str], List[List[str]]],
    read_tiling_format: Union[List[str], List[List[str]]],
    bits_per_block: int = 8,
    elements_per_block: int = 1,
    max_chain_length: int = 4,
    write_buffer_offset: Union[int, List[int]] = 0,
    read_buffer_offset: Union[int, List[int]] = 0,
    reuse_ratio: int = 1,
    parallel_locking: bool = False,
    name: str = "",
) -> List[DataTransfer]:
    '''
    Build the memtile DataTransfer for one S2MM write DMA and N MM2S read DMAs.

    The asserts below check the validity of the input parameters.
    This function does not support use of iter_step at the memtile level
    as it affects both the read and write transfers and it cannot be determined
    at just the single transfer level, must be handled at the OP schedule level.

    Write side:
        - write_tiling_format is a string: a single transfer is compiled with
          generate_transfer_params.
        - otherwise every entry is handled in turn. A string entry is compiled
          with generate_transfer_params against memory_format[idx] (requires
          exactly one repeat count). A list entry is packed into one transfer
          with pack_reconfig_transfers.
    Read side (one entry of read_tiling_format per read DMA):
        - a string entry is compiled with generate_transfer_params.
        - a list entry is packed with pack_reconfig_transfers against
          memory_format[idx].

    elements_per_block and max_chain_length are only forwarded to
    generate_transfer_params; pack_reconfig_transfers does not take them.

    Where a RuntimeError during transfer generation is caught, the failure is
    printed and None is recorded in place of the transfer.

    Returns a single-element list holding the DataTransfer, which uses
    SyncStrategy.Parallel_1_to_N when parallel_locking is set and
    SyncStrategy.Serial_M_to_N otherwise.
    '''
    config.check_init()
    assert write_dma.tile.type == TileType.Memtile and write_dma.channel.dir == DmaDir.S2MM, \
    f"Invalid write DMA {write_dma}, expected memtile S2MM DMA!"
    if len(read_dma_list) == 0:
        raise ValueError(f"Invalid read DMA {read_dma_list}, expected at least one read DMA!")
    for dma in read_dma_list:
        assert dma.tile.type == TileType.Memtile and dma.channel.dir == DmaDir.MM2S, \
        f"Invalid list of read DMA {dma}, expected memtile MM2S DMA!"
    if isinstance(memory_format, str):
        assert isinstance(write_tiling_format, str), \
        "memory format and write tiling both must be string per DMA"
        for read_fmt in read_tiling_format:
            assert isinstance(read_fmt, str), \
            "memory format and read tiling both must be string per DMA"
    if isinstance(memory_format, list):
        assert isinstance(write_tiling_format, list) and len(write_tiling_format) == len(memory_format), \
        "memory format and write tiling both must be list of strings of equal length"
        assert isinstance(read_tiling_format, list) and len(read_tiling_format) == len(memory_format), \
        "memory format and read tiling both must be list of strings of equal length"

    # pack_reconfig_transfers takes its offsets as a list (a one-element list
    # is repeated for every format), so scalar offsets are wrapped before
    # being handed to it.
    packed_write_offset = (
        write_buffer_offset if isinstance(write_buffer_offset, list) else [write_buffer_offset]
    )
    packed_read_offset = (
        read_buffer_offset if isinstance(read_buffer_offset, list) else [read_buffer_offset]
    )

    memtile_data_transfers = []
    write_data_transfers = []
    read_data_transfers = []

    # NOTE: This section generates the write transfer for the memtile
    if isinstance(write_tiling_format, str):
        # NOTE: This is a case of no reconfig for the write DMA
        write_transfer = None
        try:
            write_transfer = generate_transfer_params(
                write_dma,
                memory_format,
                write_tiling_format,
                bits_per_block=bits_per_block,
                elements_per_block=elements_per_block,
                max_chain_length=max_chain_length,
                # The write DMA must use the write offset, not the read one.
                buffer_offset=write_buffer_offset,
                name=name,
            )
        except RuntimeError as e:
            print(f"Failed to generate write transfer for {write_dma} with "
                  f"memory format {memory_format} and tiling format {write_tiling_format}: {e}")
        write_data_transfers.append(write_transfer)
    else:
        # NOTE: This is a case of chained transfers for the write DMA
        for idx, write_fmt in enumerate(write_tiling_format):
            write_transfer = None
            if isinstance(write_fmt, str):
                # NOTE: This is a case of no reconfig with chained transfers
                print("No reconfigs case with chained transfers for write DMA")
                assert len(repeat_counts) == 1, \
                f"Invalid repeat counts {repeat_counts}, expected single repeat count for no reconfig!"
                write_transfer = generate_transfer_params(
                    write_dma,
                    memory_format[idx],
                    write_fmt,
                    bits_per_block=bits_per_block,
                    elements_per_block=elements_per_block,
                    max_chain_length=max_chain_length,
                    # The write DMA must use the write offset, not the read one.
                    buffer_offset=write_buffer_offset,
                    name=name,
                )
                write_data_transfers.append(write_transfer)
            else:
                # NOTE: Within a chained BD for each reconfig for the write DMA pack the transfers
                print("Reconfig case with chained transfers for write DMA")
                try:
                    write_transfer = pack_reconfig_transfers(
                        write_dma,
                        memory_format[idx],
                        write_fmt,
                        bits_per_elem=bits_per_block,
                        buffer_offset=packed_write_offset,
                        name=name,
                    )
                except RuntimeError as e:
                    print(f"Failed to generate write transfer for {write_dma} with "
                          f"memory format {memory_format} and tiling format {write_fmt}: {e}")
                write_data_transfers.append(write_transfer)

    # NOTE: This section generates the read transfer for the memtile
    # It iterates over the read DMAs and generates the transfer for each DMA
    # If there are multiple access patterns for a DMA, it packs the transfers
    # into a single transfer per DMA
    for idx, read_fmt_per_dma in enumerate(read_tiling_format):
        read_transfer = None
        if isinstance(read_fmt_per_dma, str):
            # This is a case of no reconfig for the read DMAs:
            # generate the read transfer with generate_transfer_params.
            # NOTE(review): memory_format and read_buffer_offset are forwarded
            # whole here even when they are lists - confirm that is intended.
            try:
                read_transfer = generate_transfer_params(
                    read_dma_list[idx],
                    memory_format,
                    read_fmt_per_dma,
                    bits_per_block=bits_per_block,
                    elements_per_block=elements_per_block,
                    max_chain_length=max_chain_length,
                    buffer_offset=read_buffer_offset,
                    name=name,
                )
            except RuntimeError as e:
                print(f"Failed to generate read transfer for {read_dma_list[idx]} with "
                      f"memory format {memory_format} and tiling format {read_fmt_per_dma}: {e}")
        else:
            # NOTE: This is a case of reconfig for the read DMAs
            try:
                read_transfer = pack_reconfig_transfers(
                    read_dma_list[idx],
                    memory_format[idx],
                    read_fmt_per_dma,
                    bits_per_elem=bits_per_block,
                    buffer_offset=packed_read_offset,
                    name=name,
                )
            except RuntimeError as e:
                print(f"Failed to generate packed read transfer for {read_dma_list[idx]} with "
                      f"memory format {memory_format} and tiling format {read_fmt_per_dma}: {e}")
        read_data_transfers.append(read_transfer)

    # At this point all the write and read transfers are generated
    # Now we need to create the DataTransfer object for the memtile
    data_transfer_obj = DataTransfer(
        repeat_counts,
        write_dma.tile,
        buffer_addrs,
        buffer_size,
        write_data_transfers,
        read_data_transfers,
        sync_strategy=SyncStrategy.Parallel_1_to_N if parallel_locking else SyncStrategy.Serial_M_to_N,
        reuse_ratio=reuse_ratio,
    )
    memtile_data_transfers.append(data_transfer_obj)

    return memtile_data_transfers


'''
NOTE: The following function lets schedule developers
generate packed memtile transfers.
repeat_counts:
    1. Has to be a list of integers.
    2. if there is reconfig, the length of the list must be equal to the number of reconfigs
    3. If there is no reconfig, the same access is repeated.
write_dma_list:
    1. This is a list of write DMAs for the memtile
    2. These are the S2MM DMAs
    3. Can also be a list of same DMA if there are chained transfers
    with each BD having multiple access patterns
read_dma:
    1. This is the read DMA for the memtile
    2. This is the MM2S DMA
buffer_addrs:
    1. This is a list of buffer addresses for the memtile buffers
buffer_size:
    1. This is the size of the buffer in bytes
memory_format:
    1. Can be a single string
    2. Can be a list of strings, This is used for case of reconfig
    3. Can be a list of list of strings, This is used for case of reconfig with chained transfers
write_tiling_format:
    1. These are a list of strings per DMA
    2. Can be a list of list of strings, This is used for case of reconfig
    The outer most list is the number of write DMAs
    The inner most list is the number of access patterns for each DMA
    The number of access patterns for each DMA must be the same and
    equal to the number of memory formats
read_tiling_format:
    1. Can be a single string
    2. Can be a list of strings, This is used for case of chained transfers
    3. Can be a list of list of strings, This is used for case of reconfig with chained transfers
    4. The number of access patterns for each DMA must be the same and
    equal to the number of memory formats
bits_per_block:
    1. This is the number of bits per block
elements_per_block:
    1. This is the number of elements per block
max_chain_length:
    1. If the repeat is too large, it can be broken into multiple BDs
write_buffer_offset:
    1. This is the offset for the write DMA within the buffers
read_buffer_offset:
    1. This is the offset for each read DMA within the buffers
parallel_locking:
    1. Enable or disable parallel locking
    2. This is a switch exposed to the user to control based on the BD resources available.
'''
def generate_memtile_data_transfers_N_to_1(
    repeat_counts: List[int],
    write_dma_list: List[AieDma],
    read_dma: AieDma,
    buffer_addrs: List[int],
    buffer_size: int,
    memory_format: Union[str, List[str], List[List[str]]],
    write_tiling_format: Union[List[str], List[List[str]]],
    read_tiling_format: Union[str, List[str], List[List[str]]],
    bits_per_block: int = 8,
    elements_per_block: int = 1,
    max_chain_length: int = 4,
    write_buffer_offset: Union[int, List[int]] = 0,
    read_buffer_offset: Union[int, List[int]] = 0,
    parallel_locking: bool = False,
    name: str = "",
) -> List[DataTransfer]:
    '''
    Build the memtile DataTransfer for N S2MM write DMAs and one MM2S read DMA.

    The asserts below check the validity of the input parameters.
    This function does not support use of iter_step at the memtile level
    as it affects both the read and write transfers and it cannot be determined
    at just the single transfer level, must be handled at the OP schedule level.

    Write side (one entry of write_tiling_format per write DMA; a bare string
    is treated as a one-entry list):
        - a string entry is compiled with generate_transfer_params.
        - a list entry is packed into one transfer with
          pack_reconfig_transfers against memory_format[idx].
    Read side:
        - read_tiling_format is a string: a single transfer is compiled with
          generate_transfer_params.
        - otherwise every entry is handled in turn against memory_format[idx].
          A string entry is compiled with generate_transfer_params. A list
          entry is packed with pack_reconfig_transfers.

    elements_per_block and max_chain_length are only forwarded to
    generate_transfer_params; pack_reconfig_transfers does not take them.

    If generating a transfer raises RuntimeError, the failure is printed and
    None is recorded in place of the transfer.

    Returns a single-element list holding the DataTransfer, which uses
    SyncStrategy.Parallel_N_to_1 when parallel_locking is set and
    SyncStrategy.Serial_M_to_N otherwise.
    '''
    config.check_init()
    assert read_dma.tile.type == TileType.Memtile and read_dma.channel.dir == DmaDir.MM2S, \
    f"Invalid read DMA {read_dma}, expected memtile MM2S DMA!"
    if len(write_dma_list) == 0:
        raise ValueError(f"Invalid write DMA {write_dma_list}, expected at least one write DMA!")
    for dma in write_dma_list:
        assert dma.tile.type == TileType.Memtile and dma.channel.dir == DmaDir.S2MM, \
        f"Invalid list of write DMA {dma}, expected memtile S2MM DMA!"
    if isinstance(memory_format, str):
        # The write tilings are enumerated per DMA below, so a bare string
        # must be wrapped (otherwise it would be walked character by
        # character) and a list of per-DMA strings must be accepted.
        if isinstance(write_tiling_format, str):
            write_tiling_format = [write_tiling_format]
        for write_fmt in write_tiling_format:
            assert isinstance(write_fmt, str), \
            "memory format and write tiling both must be string per DMA"
        assert isinstance(read_tiling_format, str), \
        "memory format and read tiling both must be string per DMA"
    if isinstance(memory_format, list):
        assert isinstance(write_tiling_format, list) and len(write_tiling_format) == len(memory_format), \
        "memory format and write tiling both must be list of strings of equal length"
        assert isinstance(read_tiling_format, list) and len(read_tiling_format) == len(memory_format), \
        "memory format and read tiling both must be list of strings of equal length"

    # pack_reconfig_transfers takes its offsets as a list (a one-element list
    # is repeated for every format), so scalar offsets are wrapped before
    # being handed to it.
    packed_write_offset = (
        write_buffer_offset if isinstance(write_buffer_offset, list) else [write_buffer_offset]
    )
    packed_read_offset = (
        read_buffer_offset if isinstance(read_buffer_offset, list) else [read_buffer_offset]
    )

    memtile_data_transfers = []
    write_data_transfers = []
    read_data_transfers = []

    # NOTE: This section generates the write transfers for the memtile
    for idx, write_fmt in enumerate(write_tiling_format):
        write_transfer = None
        if isinstance(write_fmt, str):
            # This is a case of no reconfig for the write DMAs:
            # generate the write transfer with generate_transfer_params.
            # NOTE(review): memory_format and write_buffer_offset are forwarded
            # whole here even when they are lists - confirm that is intended.
            try:
                write_transfer = generate_transfer_params(
                    write_dma_list[idx],
                    memory_format,
                    write_fmt,
                    bits_per_block=bits_per_block,
                    elements_per_block=elements_per_block,
                    max_chain_length=max_chain_length,
                    buffer_offset=write_buffer_offset,
                    name=name,
                )
            except RuntimeError as e:
                print(f"Failed to generate write transfer for {write_dma_list[idx]} with "
                      f"memory format {memory_format} and tiling format {write_fmt}: {e}")
        else:
            # NOTE: This is a case of reconfig for the write DMA
            try:
                write_transfer = pack_reconfig_transfers(
                    write_dma_list[idx],
                    memory_format[idx],
                    write_fmt,
                    bits_per_elem=bits_per_block,
                    buffer_offset=packed_write_offset,
                    name=name,
                )
            except RuntimeError as e:
                print(f"Failed to generate packed write transfer for {write_dma_list[idx]} with "
                      f"memory format {memory_format} and tiling format {write_fmt}: {e}")
        write_data_transfers.append(write_transfer)

    # NOTE: This section generates the read transfer for the memtile
    if isinstance(read_tiling_format, str):
        # NOTE: This is a case of no reconfig for the read DMA
        read_transfer = None
        try:
            read_transfer = generate_transfer_params(
                read_dma,
                memory_format,
                read_tiling_format,
                bits_per_block=bits_per_block,
                elements_per_block=elements_per_block,
                max_chain_length=max_chain_length,
                buffer_offset=read_buffer_offset,
                name=name,
            )
        except RuntimeError as e:
            print(f"Failed to generate read transfer for {read_dma} with "
                  f"memory format {memory_format} and tiling format {read_tiling_format}: {e}")
        read_data_transfers.append(read_transfer)
    else:
        # NOTE: This is a case of chained transfers for the read DMA
        for idx, read_fmt in enumerate(read_tiling_format):
            read_transfer = None
            try:
                if isinstance(read_fmt, str):
                    # Chained transfer without reconfig: a single format
                    # string cannot be packed (pack_reconfig_transfers
                    # expects lists), so compile it directly.
                    read_transfer = generate_transfer_params(
                        read_dma,
                        memory_format[idx],
                        read_fmt,
                        bits_per_block=bits_per_block,
                        elements_per_block=elements_per_block,
                        max_chain_length=max_chain_length,
                        buffer_offset=read_buffer_offset,
                        name=name,
                    )
                else:
                    # NOTE: Within a chained BD for each reconfig for the
                    # read DMA pack the transfers
                    read_transfer = pack_reconfig_transfers(
                        read_dma,
                        memory_format[idx],
                        read_fmt,
                        bits_per_elem=bits_per_block,
                        buffer_offset=packed_read_offset,
                        name=name,
                    )
            except RuntimeError as e:
                print(f"Failed to generate packed read transfer for {read_dma} with "
                      f"memory format {memory_format} and tiling format {read_tiling_format}: {e}")
            read_data_transfers.append(read_transfer)

    # At this point all the write and read transfers are generated
    # Now we need to create the DataTransfer object for the memtile
    data_transfer_obj = DataTransfer(
        repeat_counts,
        read_dma.tile,
        buffer_addrs,
        buffer_size,
        write_data_transfers,
        read_data_transfers,
        sync_strategy=SyncStrategy.Parallel_N_to_1 if parallel_locking else SyncStrategy.Serial_M_to_N,
    )
    memtile_data_transfers.append(data_transfer_obj)
    return memtile_data_transfers


def generate_memtile_data_transfer(
    repeat_counts: list[int],
    dma: AieDma,
    buffer_tile: AieTile,
    buffer_addr: int,
    memory_format: str,
    tiling_format: str,
    bits_per_block: int = 8,
    elements_per_block: int = 1,
    enable_padding: bool = False,
    max_chain_length: int = 4,
    verbose: bool = True,
    name: str = "",
) -> DataTransfer:
    '''
    Helper function to map a single DMA memtile transfer possibly with iter step

    repeat_counts - base repeat counts; each is multiplied by the repeat
        coefficient of the generated transfer (1 when iter step is not used)
    dma - memtile DMA (asserted); its channel direction selects whether the
        transfer chain becomes the write or the read params
    buffer_tile / buffer_addr - tile and address stored in the DataTransfer
    memory_format / tiling_format - tiling expression strings
    bits_per_block / elements_per_block - block description, used both to
        compile the transfer and to size the buffer
    enable_padding / max_chain_length / name - forwarded to
        generate_transfer_params
    verbose - print the DMA and both format strings before compiling
    '''
    assert dma.tile.type == TileType.Memtile
    if verbose:
        print(dma, memory_format, tiling_format, sep=', ')

    shared_args = dict(
        bits_per_block=bits_per_block,
        elements_per_block=elements_per_block,
        enable_padding=enable_padding,
        max_chain_length=max_chain_length,
        verbose=False,
        name=name,
    )
    # NOTE: Here we first attempt to map the transfer without
    # iter step and chaining, then revert to those if the mapping fails
    try:
        transfer_params = generate_transfer_params(
            dma, memory_format, tiling_format, **shared_args,
        )
        repeat_coeff = 1
        transfer_chain = [transfer_params]
    except (RuntimeError, ValueError):
        repeat_coeff, transfer_chain = generate_transfer_params(
            dma, memory_format, tiling_format, use_iter_step=True, **shared_args,
        )

    # Size the buffer with the same block description used to compile the
    # transfer; leaving out elements_per_block would size it as if every
    # element were its own block.
    buffer_size = compute_buffer_size(memory_format, bits_per_block, elements_per_block)

    write_params = transfer_chain if dma.channel.dir == DmaDir.S2MM else []
    read_params = transfer_chain if dma.channel.dir == DmaDir.MM2S else []
    return DataTransfer(
        [count * repeat_coeff for count in repeat_counts],
        buffer_tile, [buffer_addr], buffer_size,
        write_params,
        read_params,
    )
