"""
Maxpooling layer dataflow generation
"""

from dataclasses import dataclass
import ctypes
import struct
import math
from typing import List, Any

from dmacompiler import (
    DevGen,
    OverlayShape,
    DataTransfer,
    TransferParams,
    SyncStrategy,
    BackEnd,
    generate_transfer_params,
    DmaChannel,
    DmaDir,
    AieTile,
    TileType,
    memtile_dma,
    shim_dma,
    ConfigBuffer,
    AcqBuffer,
    RelBuffer,
    CallKernel,
    Loop,
    run_layer_compilation,
    set_dev_gen,
    core_tile,
    compute_buffer_size,
    memory_tile,
    generate_shim_data_transfer,
    generate_core_buffer_config,
    core_dma,
)
from utils.utils_common import (
    overlay_3x4_core_stack_addr,
    ceildiv,
    iceil,
    L2Alloc,
    log,
)
from scheduler.common import (
    overlay_3x4_dma_connections,
    overlay_3x4_param_channel_id,
    overlay_3x4_A_ids,
    overlay_3x4_B_ids,
    overlay_3x4_F_ids,
    overlay_3x4_O_ids,
    overlay_3x4_S_ids,
    prm_memtile_memory,
    prm_shim_memory,
    prm_memtile_mm2s,
    prm_memtile_s2mm,
    prm_shim_mm2s,
    shim_alloc,
    L3Alloc,
    L3Alloc_to_Shim,
)
from buildscripts.common import ScheduleInputs

# Select the AIE4 device generation in dmacompiler (import-time side effect of this module).
set_dev_gen(DevGen.Aie4)

# Spatial split modes considered for maxpool, keyed by mode name.
# NOTE: each value is [Y_split, X_split, Co_split]; every mode splits the work 12 ways.
maxpool_split_modes = dict(
    Y1X1C12=[1, 1, 12],
    Y4X3C1=[4, 3, 1],
    Y3X4C1=[3, 4, 1],
    Y12X1C1=[12, 1, 1],
    Y1X12C1=[1, 12, 1],
)


class Ctrl(ctypes.Structure):
    """
    Control word appended (as 4 raw bytes) to the maxpool kernel parameters.

    A 32-bit ctypes bit-field structure; the field widths below sum to 32. All fields start at
    zero and may be set by keyword or position, e.g. ``Ctrl(zero_init=1)``. Bit positions follow
    the platform's ctypes bit-field allocation (first field in the least-significant bits on the
    usual little-endian ABIs -- confirm for the target).

    This class is intentionally not a ``@dataclass``. It has no annotated fields, so the
    decorator would replace ``Structure.__init__`` with a no-argument ``__init__``, making
    field initialisation impossible. It would also make every instance compare equal and be
    unhashable.
    """

    _fields_ = [
        ("zero_init", ctypes.c_uint32, 1),
        ("sign_N", ctypes.c_uint32, 1),
        ("sign_O", ctypes.c_uint32, 1),
        ("reserved3", ctypes.c_uint32, 3),
        ("skip_casc_in", ctypes.c_uint32, 1),
        ("skip_casc_out", ctypes.c_uint32, 1),
        ("sign_W", ctypes.c_uint32, 1),
        ("sign_A", ctypes.c_uint32, 1),
        ("reserved10", ctypes.c_uint32, 14),
        ("norm_ch_g", ctypes.c_uint32, 8),
    ]


def unicast_channels() -> list[tuple[int, int]]:
    """Return the unicast memtile MM2S channels as (row, channel_id) pairs; row is the list position."""
    return [(row, channel_id) for row, channel_id in enumerate(overlay_3x4_A_ids())]


def broadcast_channels(col: int) -> list[tuple[int, int]]:
    """Return the broadcast memtile MM2S channels of column `col` as (row, channel_id) pairs."""
    return [(row, channel_id) for row, channel_id in enumerate(overlay_3x4_B_ids(col))]


class MaxpoolShape:
    """
    Mapping state of a maxpooling layer for one spatial split.

    The constructor records the layer geometry (OFM/IFM dims, filter, stride, padding), the
    output granularity, the minimum channel subvolume, the core grid and the element widths.
    The split factors, subvolumes, loop counts and L1 addresses start zeroed and are filled in
    by ``map_maxpool``. Double buffering (2 L1 buffers each for IFM and OFM) is enabled by
    default.
    """

    __slots__ = (
        "ofm_dims",
        "ifm_dims",
        "ofm_gran",
        "filter_dims",
        "stride",
        "padding",
        "feasible",
        "ofm_subvol_dims",
        "ifm_subvol_dims",
        "iter_dims",
        "Y_split",
        "X_split",
        "C_split",
        "Cs_min",
        "ofm_addr",
        "ifm_addr",
        "ifm_no_l1_buffers",
        "ofm_no_l1_buffers",
        "ifm_bits",
        "ofm_bits",
        "bits_per_byte",
        "aie_cols",
        "aie_rows",
        "ifm_l1_ping_addr",
        "ifm_l1_pong_addr",
        "ofm_l1_ping_addr",
        "ofm_l1_pong_addr",
        "wgt_l1_addr",
        "wgt_l1_size",
    )

    def __init__(
        self,
        Yo: int,
        Xo: int,
        C: int,
        Yi: int,
        Xi: int,
        Ky: int,
        Kx: int,
        Sy: int,
        Sx: int,
        Py: int,
        Px: int,
        Y_gran: int,
        X_gran: int,
        C_gran: int,
        Cs_min: int,
        aie_cols: int,
        aie_rows: int,
        ifm_bits: int,
        ofm_bits: int,
        bits_per_byte: int,
    ):
        # Layer geometry (channels are shared by IFM and OFM).
        self.ofm_dims = (Yo, Xo, C)
        self.ifm_dims = (Yi, Xi, C)
        self.filter_dims = (Ky, Kx)
        self.stride = (Sy, Sx)
        self.padding = (Py, Px)
        self.ofm_gran = (Y_gran, X_gran, C_gran)
        self.Cs_min = Cs_min

        # Core grid and element widths.
        self.aie_cols = aie_cols
        self.aie_rows = aie_rows
        self.ifm_bits = ifm_bits
        self.ofm_bits = ofm_bits
        self.bits_per_byte = bits_per_byte

        # Mapping results, populated by map_maxpool.
        self.feasible = False
        self.Y_split = self.X_split = self.C_split = 0
        self.ofm_subvol_dims = self.ifm_subvol_dims = self.iter_dims = (0, 0, 0)
        self.ofm_addr = [0, 0]
        self.ifm_addr = [0, 0]
        self.ifm_no_l1_buffers = 2
        self.ofm_no_l1_buffers = 2

        # L1 layout; the runtime-parameter ("wgt") buffer is a fixed 128 bytes.
        self.wgt_l1_size = 128
        self.ifm_l1_ping_addr = self.ifm_l1_pong_addr = 0
        self.ofm_l1_ping_addr = self.ofm_l1_pong_addr = 0
        self.wgt_l1_addr = 0


def calc_in_subvol_dims(mapped_pool: MaxpoolShape) -> None:
    """
    Derive the IFM subvolume of a maxpooling layer from its OFM subvolume.

    Consecutive outputs are Sy (Sx) input rows (columns) apart and each reads a Ky x Kx
    window, so Yos output rows need ``Sy * (Yos - 1) + Ky`` input rows (likewise for X).
    The channel count passes through unchanged, and padding does not enter the extent.
    The result is stored in ``mapped_pool.ifm_subvol_dims`` as (Yis, Xis, Cs).
    """
    Yos, Xos, Cs = mapped_pool.ofm_subvol_dims
    Sy, Sx = mapped_pool.stride
    Ky, Kx = mapped_pool.filter_dims
    Yis = Sy * (Yos - 1) + Ky
    Xis = Sx * (Xos - 1) + Kx
    mapped_pool.ifm_subvol_dims = (Yis, Xis, Cs)


class MaxPoolDims:
    """
    Flat view of a mapped maxpooling layer, used for dataflow generation.

    Built from a ``MaxpoolShape``. It copies the tensor, filter, stride, padding, subvolume,
    split and loop dimensions into individual attributes. It then derives the padded channel
    subvolume ``Cs_pad`` (at least ``Cs_min``), the per-subvolume IFM/OFM sizes in bytes, the
    element size in bytes and the padded input extents ``Yi_pad``/``Xi_pad``.
    """

    __slots__ = (
        "N",
        "param_size",
        "Yo",
        "Xo",
        "C",
        "Yi",
        "Xi",
        "Ky",
        "Kx",
        "Yi_pad",
        "Xi_pad",
        "Y_gran",
        "X_gran",
        "C_gran",
        "Sy",
        "Sx",
        "Py",
        "Px",
        "Yos",
        "Xos",
        "Cs",
        "Cs_pad",
        "Yis",
        "Xis",
        "Y_split",
        "X_split",
        "C_split",
        "Y_loop",
        "X_loop",
        "C_loop",
        "aie_cols",
        "aie_rows",
        "ifm_bits",
        "ofm_bits",
        "param_bits",
        "wgt_bits",
        "bits_per_byte",
        "ifm_subv_size",
        "ofm_subv_size",
        "wgt_subv_size",
        "size_bytes",
    )

    def __init__(self, mappings: MaxpoolShape):
        # Fixed layer constants.
        self.N = 1
        self.param_size = 1024
        self.param_bits = 8
        self.wgt_bits = 8
        # NOTE: the "wgt" stream carries the kernel's runtime parameters.
        self.wgt_subv_size = mappings.wgt_l1_size

        # Full tensors; the channel count is taken from the IFM dims.
        self.Yo, self.Xo, _ = mappings.ofm_dims
        self.Yi, self.Xi, self.C = mappings.ifm_dims
        self.Ky, self.Kx = mappings.filter_dims
        self.Sy, self.Sx = mappings.stride
        self.Py, self.Px = mappings.padding

        # Per-core subvolumes; the channel subvolume is taken from the IFM subvolume.
        self.Yos, self.Xos, _ = mappings.ofm_subvol_dims
        self.Yis, self.Xis, self.Cs = mappings.ifm_subvol_dims
        self.Y_gran, self.X_gran, self.C_gran = mappings.ofm_gran

        # Spatial split across cores and temporal loop counts.
        self.Y_split, self.X_split, self.C_split = mappings.Y_split, mappings.X_split, mappings.C_split
        self.Y_loop, self.X_loop, self.C_loop = mappings.iter_dims

        # Core grid and element widths.
        self.aie_cols, self.aie_rows = mappings.aie_cols, mappings.aie_rows
        self.ifm_bits, self.ofm_bits = mappings.ifm_bits, mappings.ofm_bits
        self.bits_per_byte = mappings.bits_per_byte

        # Derived sizes.
        self.Cs_pad = max(self.Cs, mappings.Cs_min)
        ifm_elems = self.Yis * self.Xis * self.Cs_pad
        ofm_elems = self.Yos * self.Xos * self.Cs_pad
        self.ifm_subv_size = (ifm_elems * self.ifm_bits) // self.bits_per_byte
        self.ofm_subv_size = (ofm_elems * self.ofm_bits) // self.bits_per_byte
        self.size_bytes = self.ifm_bits // self.bits_per_byte

        # Input extent covered by all split/loop iterations, minus the padding on both sides.
        total_Yo = self.Yos * self.Y_split * self.Y_loop
        total_Xo = self.Xos * self.X_split * self.X_loop
        self.Yi_pad = self.Sy * (total_Yo - 1) + self.Ky - 2 * self.Py
        self.Xi_pad = self.Sx * (total_Xo - 1) + self.Kx - 2 * self.Px


def _l1_subvol_bytes(subvol_dims: tuple[int, int, int], no_buffers: int, bits: int, bits_per_byte: int) -> int:
    """Return the bytes taken by `no_buffers` copies of a (Y, X, C) subvolume of `bits`-wide elements."""
    return (math.prod(subvol_dims) * no_buffers * bits) // bits_per_byte


def _occupied_l1_bytes(mapped_pool: MaxpoolShape) -> int:
    """Return the IFM + OFM L1 footprint in bytes for the current subvolumes and buffer counts."""
    ifm_bytes = _l1_subvol_bytes(
        mapped_pool.ifm_subvol_dims, mapped_pool.ifm_no_l1_buffers, mapped_pool.ifm_bits, mapped_pool.bits_per_byte
    )
    ofm_bytes = _l1_subvol_bytes(
        mapped_pool.ofm_subvol_dims, mapped_pool.ofm_no_l1_buffers, mapped_pool.ofm_bits, mapped_pool.bits_per_byte
    )
    return ifm_bytes + ofm_bytes


def map_maxpool(spatial_split: tuple[int, int, int], mapped_pool: MaxpoolShape) -> None:
    """
    Choose the subvolume for one spatial split and lay out the L1 buffers.

    The split is infeasible if some core would get less than one output granule along a
    split axis. Otherwise the search starts from the largest subvolume (one temporal iteration
    per axis). While IFM + OFM do not fit in the L1 space below the core stack that is left
    after the runtime-parameter ("wgt") buffer, one axis is halved: channels first, then Y,
    then X, each only while it is larger than and divisible by its granule. When no axis can
    shrink, double buffering is dropped. If it still does not fit, the split is infeasible.

    Results are written into ``mapped_pool``: split factors, ``feasible``, ``ofm_subvol_dims``,
    ``ifm_subvol_dims``, ``iter_dims``, buffer counts and the L1 addresses. For an infeasible
    split only ``feasible = False`` is meaningful, and ``maxpool_ranker`` filters it out.
    """
    Yo, Xo, Co = mapped_pool.ofm_dims
    Y_gran, X_gran, C_gran = mapped_pool.ofm_gran
    Y_split, X_split, C_split = spatial_split
    mapped_pool.Y_split = Y_split
    mapped_pool.X_split = X_split
    mapped_pool.C_split = C_split
    # Every core must receive at least one output granule along each split axis.
    if Yo < (Y_gran * Y_split) or Xo < (X_gran * X_split) or Co < (C_gran * C_split):
        mapped_pool.feasible = False
        return

    # Start from the largest subvolume: a single temporal iteration per axis.
    Yo_subvol = ceildiv(Yo, Y_split)
    Xo_subvol = ceildiv(Xo, X_split)
    C_subvol = ceildiv(Co, C_split)
    Y_temporal_iters = 1
    X_temporal_iters = 1
    C_temporal_iters = 1
    mapped_pool.feasible = True
    mapped_pool.ofm_subvol_dims = (Yo_subvol, Xo_subvol, C_subvol)
    mapped_pool.iter_dims = (Y_temporal_iters, X_temporal_iters, C_temporal_iters)
    calc_in_subvol_dims(mapped_pool)

    core_stack_addr = overlay_3x4_core_stack_addr()
    available_l1_size = core_stack_addr - mapped_pool.wgt_l1_size
    # Measure the footprint in bytes (element width applied) on every pass. The loop must
    # stop on the same quantity that the address allocation below consumes.
    while _occupied_l1_bytes(mapped_pool) > available_l1_size:
        # NOTE: First choice is to split Channel dimension
        if (C_subvol > C_gran) and (C_subvol % C_gran == 0):
            C_subvol = ceildiv(C_subvol, 2)
            C_temporal_iters = ceildiv(Co, (C_subvol * C_split))
        # NOTE: Second choice is to split Y dimension
        elif (Yo_subvol > Y_gran) and (Yo_subvol % Y_gran == 0):
            Yo_subvol = ceildiv(Yo_subvol, 2)
            Y_temporal_iters = ceildiv(Yo, (Yo_subvol * Y_split))
        # NOTE: Third choice is to split X dimension
        elif (Xo_subvol > X_gran) and (Xo_subvol % X_gran == 0):
            Xo_subvol = ceildiv(Xo_subvol, 2)
            X_temporal_iters = ceildiv(Xo, (Xo_subvol * X_split))
        else:
            if mapped_pool.ifm_no_l1_buffers == 1 and mapped_pool.ofm_no_l1_buffers == 1:
                # Nothing left to shrink, so this split cannot fit in L1. Return before the
                # address allocation, which would otherwise overrun the core stack.
                mapped_pool.feasible = False
                return
            # Last resort: give up ping/pong double buffering.
            mapped_pool.ifm_no_l1_buffers = 1
            mapped_pool.ofm_no_l1_buffers = 1
        mapped_pool.ofm_subvol_dims = (Yo_subvol, Xo_subvol, C_subvol)
        mapped_pool.iter_dims = (Y_temporal_iters, X_temporal_iters, C_temporal_iters)
        calc_in_subvol_dims(mapped_pool)

    ifm_subv_size = _l1_subvol_bytes(mapped_pool.ifm_subvol_dims, 1, mapped_pool.ifm_bits, mapped_pool.bits_per_byte)
    ofm_subv_size = _l1_subvol_bytes(mapped_pool.ofm_subvol_dims, 1, mapped_pool.ofm_bits, mapped_pool.bits_per_byte)
    # The allocation scheme for the L1 memory is:
    #   ifm_ping + ofm_ping [+ ifm_pong + ofm_pong] + wgt
    # Each buffer is appended to the previous one, and its start is rounded up to core_alignment.
    # Bank alignment of the buffers is not considered in this allocation.
    core_alignment = 128
    mapped_pool.ifm_l1_ping_addr = 0
    mapped_pool.ofm_l1_ping_addr = iceil(mapped_pool.ifm_l1_ping_addr + ifm_subv_size, core_alignment)
    if mapped_pool.ifm_no_l1_buffers == 2 and mapped_pool.ofm_no_l1_buffers == 2:
        mapped_pool.ifm_l1_pong_addr = iceil(mapped_pool.ofm_l1_ping_addr + ofm_subv_size, core_alignment)
        mapped_pool.ofm_l1_pong_addr = iceil(mapped_pool.ifm_l1_pong_addr + ifm_subv_size, core_alignment)
        mapped_pool.wgt_l1_addr = iceil(mapped_pool.ofm_l1_pong_addr + ofm_subv_size, core_alignment)
    else:
        mapped_pool.wgt_l1_addr = iceil(mapped_pool.ofm_l1_ping_addr + ofm_subv_size, core_alignment)
    # The footprint check above ignores alignment padding, so the aligned layout can still
    # reach the core stack. Report that as infeasible instead of failing an assert.
    if mapped_pool.wgt_l1_addr + mapped_pool.wgt_l1_size > core_stack_addr:
        mapped_pool.feasible = False


def maxpool_ranker(mapped_maxpool: dict) -> dict:
    """
    Return the feasible maxpool mappings, best first.

    Infeasible mappings are dropped. The rest are ordered by their padded OFM volume
    (per axis: subvolume * split factor * temporal iterations, i.e. the work actually computed,
    including overcompute), then by the total number of temporal iterations. Remaining ties
    keep the input order.
    """

    def padded_ofm_volume(mode: str, mapping) -> int:
        split = maxpool_split_modes[mode]
        return math.prod(
            subvol * parts * iters
            for subvol, parts, iters in zip(mapping.ofm_subvol_dims, split, mapping.iter_dims)
        )

    def rank(entry) -> tuple[int, int]:
        mode, mapping = entry
        return padded_ofm_volume(mode, mapping), math.prod(mapping.iter_dims)

    survivors = [(mode, mapping) for mode, mapping in mapped_maxpool.items() if mapping.feasible]
    return dict(sorted(survivors, key=rank))


def setup_maxpool_kernel_params(
    N: int,
    Sx: int,
    Ky: int,
    Kx: int,
    Xis: int,
    Cs: int,
    Y_gran: int,
    X_gran: int,
    C_gran: int,
    Yos: int,
    Xos: int,
    size_bytes: int,
) -> bytes:
    """
    Build and pack the maxpool kernel parameter struct.

    NOTE: The smallest output subvol granularity for the maxpool kernel is 1x1x64.

    Each field is logged as "<label> <value>". The fields are then packed with ``struct``
    format ``"BBBbBBBBHHbbbbHHHHHHHHi4s"`` (native byte order and alignment, no explicit
    prefix). The trailing 4-byte field is a zero-initialised ``Ctrl`` word.
    """
    x_iters = ceildiv(Xos, X_gran)
    y_iters = ceildiv(Yos, Y_gran)
    co_iters = ceildiv(Cs, C_gran)
    in_row_bytes = Xis * size_bytes * Cs  # bytes in one IFM row across all Cs channels
    channels_tiled = Cs > C_gran
    ctrl = Ctrl()

    # (log label, value) in struct field order; labels are logged verbatim.
    fields = (
        ("Kx_g,       ", Kx),
        ("Ky_g,       ", Ky),
        ("Ci_g,       ", 1),  # NOTE: channel iterations are controlled by Co_g
        ("S_g,        ", Sx),
        ("N_g,        ", N),
        ("X_g,        ", x_iters),
        ("Y_g,        ", y_iters),
        ("Co_g,       ", co_iters),
        ("inner_g,    ", Kx * Ky),
        ("outer_g,    ", x_iters * y_iters * co_iters),
        ("shift_tdm,  ", 0),
        ("shift_res,  ", 0),
        ("shift_norm, ", 0),
        ("shift_bias, ", 0),
        ("step_Kx,    ", C_gran),  # NOTE: unused param for maxpool kernel
        ("step_Ky,    ", in_row_bytes),
        ("step_Ci,    ", (Xis * C_gran * size_bytes) if channels_tiled else 1),
        ("step_Xi,    ", C_gran),  # NOTE: unused param for maxpool kernel
        ("step_Yi,    ", in_row_bytes * Sx),
        ("step_Xo,    ", C_gran * size_bytes),
        ("step_Yo,    ", Xos * Cs * size_bytes),
        ("step_Co,    ", (Xos * C_gran * size_bytes) if channels_tiled else 1),
        ("param_value,", 0),
    )
    for label, value in fields:
        log(label, value)

    ctrl_bytes = ctypes.string_at(ctypes.addressof(ctrl), ctypes.sizeof(ctrl))
    return struct.pack("BBBbBBBBHHbbbbHHHHHHHHi4s", *(value for _, value in fields), ctrl_bytes)


def generate_aie4_maxpool_params(
    dims: MaxPoolDims,
    mode: int,
) -> bytes:
    """
    Build the complete AIE4 maxpool kernel argument blob.

    The packed kernel parameters for ``dims`` (using the padded channel subvolume ``Cs_pad``)
    are followed by ``mode`` as a 2-byte little-endian unsigned integer.
    """
    packed_kernel = setup_maxpool_kernel_params(
        N=dims.N,
        Ky=dims.Ky,
        Kx=dims.Kx,
        Sx=dims.Sx,
        Yos=dims.Yos,
        Xos=dims.Xos,
        Xis=dims.Xis,
        Cs=dims.Cs_pad,
        Y_gran=dims.Y_gran,
        X_gran=dims.X_gran,
        C_gran=dims.C_gran,
        size_bytes=dims.size_bytes,
    )
    mode_suffix = mode.to_bytes(2, "little", signed=False)
    return b"".join((packed_kernel, mode_suffix))


def core_to_split(dims: MaxPoolDims, col: int, row: int) -> tuple[int, int, int]:
    """
    Map core (col, row) to its logical image block (Y_idx, X_idx, Co_idx).

    The core is flattened to ``col * aie_rows + row``. That id is then distributed over the
    (Y_split, X_split, Co_split) grid of the active split mode. Each mapping must hit every
    block of its grid exactly once; otherwise some blocks are never computed while others are
    computed twice.

    Raises:
        KeyError: if (Y_split, X_split, C_split) is not one of the supported split modes.
    """

    def coreid(c: int, r: int) -> int:
        """Flatten core (c, r) to a 1d index."""
        return (c * dims.aie_rows) + r

    # Key: (Y_split, X_split, Co_split). Value: maps a flat core id to (Y_idx, X_idx, Co_idx).
    mode_lookup = {
        (1, 1, 12): (lambda core_id: (0, 0, core_id)),
        (4, 3, 1): (lambda core_id: (core_id % 4, core_id // 4, 0)),
        # Same (id % Y_split, id // Y_split) pattern as (4, 3, 1). For core ids 0..11 this
        # keeps Y_idx in [0, 3) and X_idx in [0, 4). The swapped (id // 3, id % 3) produced
        # Y_idx == 3 and never reached X_idx == 3.
        (3, 4, 1): (lambda core_id: (core_id % 3, core_id // 3, 0)),
        (12, 1, 1): (lambda core_id: (core_id, 0, 0)),
        (1, 12, 1): (lambda core_id: (0, core_id, 0)),
    }
    split_key = (dims.Y_split, dims.X_split, dims.C_split)
    Y_idx, X_idx, Co_idx = mode_lookup[split_key](coreid(col, row))
    return Y_idx, X_idx, Co_idx


def Yi_slice(dims: MaxPoolDims, col: int, row: int, i: int) -> tuple[int, int, int]:
    """
    Input-row slice (start, stop, stride) read by core (col, row) on iteration `i` of Y_loop.

    ``start`` is offset by ``-Py`` and can be negative when Py > 0. ``stride`` is the input-row
    distance between consecutive Y iterations. A core whose output block starts at or past
    ``Yo`` gets the empty slice (Yi, Yi).
    """
    Y_idx, _, _ = core_to_split(dims, col, row)
    log("Y_idx = ", Y_idx)
    block_rows = dims.Yos
    Yi_stride = block_rows * dims.Sy * dims.Y_split
    first_out_row = Y_idx * block_rows + i * block_rows * dims.Y_split
    if first_out_row >= dims.Yo:
        return dims.Yi, dims.Yi, Yi_stride
    Yi_start = Y_idx * block_rows * dims.Sy + i * Yi_stride - dims.Py
    return Yi_start, Yi_start + dims.Yis, Yi_stride


def Xi_slice(dims: MaxPoolDims, col: int, row: int, i: int) -> tuple[int, int, int]:
    """
    Input-column slice (start, stop, stride) read by core (col, row) on iteration `i` of X_loop.

    ``start`` is offset by ``-Px`` and can be negative when Px > 0. A core whose output block
    starts at or past ``Xo`` gets the empty slice (Xi, Xi).
    """
    _, X_idx, _ = core_to_split(dims, col, row)
    block_cols = dims.Xos
    Xi_stride = block_cols * dims.Sx * dims.X_split
    first_out_col = X_idx * block_cols + i * block_cols * dims.X_split
    if first_out_col >= dims.Xo:
        return dims.Xi, dims.Xi, Xi_stride
    Xi_start = X_idx * block_cols * dims.Sx + i * Xi_stride - dims.Px
    return Xi_start, Xi_start + dims.Xis, Xi_stride


def Yo_slice(dims: MaxPoolDims, col: int, row: int, i: int) -> tuple[int, int, int]:
    """Output-row slice (start, stop, stride) of core (col, row) on iteration `i` of Y_loop, clamped to Yo."""
    Y_idx, _, _ = core_to_split(dims, col, row)
    Yo_stride = dims.Y_split * dims.Yos
    unclamped_start = Y_idx * dims.Yos + i * Yo_stride
    Yo_start = min(unclamped_start, dims.Yo)
    return Yo_start, min(Yo_start + dims.Yos, dims.Yo), Yo_stride


def Xo_slice(dims: MaxPoolDims, col: int, row: int, i: int) -> tuple[int, int, int]:
    """Output-column slice (start, stop, stride) of core (col, row) on iteration `i` of X_loop, clamped to Xo."""
    _, X_idx, _ = core_to_split(dims, col, row)
    Xo_stride = dims.X_split * dims.Xos
    unclamped_start = X_idx * dims.Xos + i * Xo_stride
    Xo_start = min(unclamped_start, dims.Xo)
    return Xo_start, min(Xo_start + dims.Xos, dims.Xo), Xo_stride


def ifm_memtile_channels() -> list[tuple[int, int]]:
    """IFM memtile MM2S channel allocation as (row, id) pairs; IFM is always sent unicast."""
    return unicast_channels()


def wgt_memtile_channels(col: int) -> list[tuple[int, int]]:
    """WGT memtile MM2S channel allocation of column `col` as (row, id) pairs; WGT is always broadcast."""
    return broadcast_channels(col)


def ofm_memtile_channels() -> list[tuple[int, int]]:
    """OFM memtile channel allocation as (row, id) pairs; row is the position in the id list."""
    return [(row, channel_id) for row, channel_id in enumerate(overlay_3x4_O_ids())]


def ifm_core_channel() -> int:
    """Core S2MM channel on which the IFM is received (the unicast channel)."""
    return 0


def wgt_core_channel() -> int:
    """Core S2MM channel on which WGT (runtime params) is received (the broadcast channel)."""
    return 1


def ifm_memtile_memory(dims: MaxPoolDims) -> str:
    """IFM layout in the memory tile: the full Yi x Xi x C tensor."""
    return "Yi:{} Xi:{} C:{}".format(dims.Yi, dims.Xi, dims.C)


def ifm_memtile_s2mm(dims: MaxPoolDims) -> str:
    """Access pattern for filling the memtile IFM from the shim: the whole tensor."""
    extents = (("Yi", dims.Yi), ("Xi", dims.Xi), ("C", dims.C))
    return " ".join(f"{axis}:0:{size}" for axis, size in extents)


def ofm_memtile_memory(dims: MaxPoolDims) -> str:
    """OFM layout in the memory tile: the full Yo x Xo x C tensor."""
    return "Yo:{} Xo:{} C:{}".format(dims.Yo, dims.Xo, dims.C)


def ofm_memtile_mm2s(dims: MaxPoolDims) -> str:
    """Access pattern for spilling the memtile OFM to the shim: the whole tensor."""
    extents = (("Yo", dims.Yo), ("Xo", dims.Xo), ("C", dims.C))
    return " ".join(f"{axis}:0:{size}" for axis, size in extents)


def wgt_memtile_memory(dims: MaxPoolDims) -> str:
    """WGT (runtime-parameter) layout in the memory tile: a flat byte buffer."""
    return "Bytes:{}".format(dims.wgt_subv_size)


def wgt_memtile_s2mm(dims: MaxPoolDims) -> str:
    """Access pattern for filling the memtile WGT buffer from the shim: all bytes."""
    return "Bytes:0:{}".format(dims.wgt_subv_size)


def wgt_memtile_mm2s(dims: MaxPoolDims) -> str:
    """Memtile WGT MM2S access pattern: all bytes of the runtime-parameter buffer."""
    return "Bytes:0:{}".format(dims.wgt_subv_size)


def wgt_shim_memory(dims: MaxPoolDims) -> str:
    """WGT layout in shim (L3) memory: a flat byte buffer."""
    return "Bytes:{}".format(dims.wgt_subv_size)


def wgt_shim_mm2s(dims: MaxPoolDims) -> str:
    """Shim WGT MM2S access pattern: all bytes."""
    return "Bytes:0:{}".format(dims.wgt_subv_size)


def ifm_shim_memory(dims: MaxPoolDims) -> str:
    """IFM layout in shim (L3) memory: the full Yi x Xi x C tensor."""
    return "Yi:{} Xi:{} C:{}".format(dims.Yi, dims.Xi, dims.C)


def ifm_shim_mm2s(dims: MaxPoolDims) -> str:
    """Shim IFM MM2S access pattern: the whole tensor."""
    extents = (("Yi", dims.Yi), ("Xi", dims.Xi), ("C", dims.C))
    return " ".join(f"{axis}:0:{size}" for axis, size in extents)


def ofm_shim_memory(dims: MaxPoolDims) -> str:
    """OFM layout in shim (L3) memory: the full Yo x Xo x C tensor."""
    return "Yo:{} Xo:{} C:{}".format(dims.Yo, dims.Xo, dims.C)


def ofm_shim_s2mm(dims: MaxPoolDims) -> str:
    """Shim OFM S2MM access pattern: the whole tensor."""
    extents = (("Yo", dims.Yo), ("Xo", dims.Xo), ("C", dims.C))
    return " ".join(f"{axis}:0:{size}" for axis, size in extents)


def dummy_transfer() -> str:
    """Access pattern with an empty C range; used for BD slots that must not move data."""
    empty_channel_range = "C:0:0"
    return empty_channel_range


def ifm_core_s2mm(dims: MaxPoolDims, col: int, row: int) -> str:
    """
    Core IFM S2MM access pattern for the core at (col, row).

    NOTE: the core access pattern must not change over time, so the extents of iteration 0
    are used.
    NOTE(review): the channel range uses ``Cs``, while ``ifm_memtile_mm2s`` sends ``Cs_pad``
    channels. The two only agree when Cs >= Cs_min -- confirm.
    """
    y_begin, y_end, _ = Yi_slice(dims, col, row, 0)
    x_begin, x_end, _ = Xi_slice(dims, col, row, 0)
    return "Yi:0:{} C:0:{}:{} Xi:0:{} C:0:{}".format(
        y_end - y_begin, dims.Cs, dims.C_gran, x_end - x_begin, dims.C_gran
    )


def ifm_memtile_mm2s(dims: MaxPoolDims, col: int, row: int, Y_loop: int, X_loop: int) -> str:
    """
    Memtile IFM MM2S access pattern feeding core (col, row) on loop step (Y_loop, X_loop).

    NOTE(review): the outer ``C:0:C:Cs`` range is identical for every core and does not use
    the Co_idx that ``core_to_split`` returns. Confirm how the channel-split mode (Y1X1C12)
    is meant to partition channels.
    """
    y_begin, y_end, _ = Yi_slice(dims, col, row, Y_loop)
    x_begin, x_end, _ = Xi_slice(dims, col, row, X_loop)
    pattern = [
        f"C:0:{dims.C}:{dims.Cs}",
        f"Yi:{y_begin}:{y_end}",
        f"C:0:{dims.Cs_pad}:{dims.C_gran}",
        f"Xi:{x_begin}:{x_end}",
        f"C:0:{dims.C_gran}",
    ]
    return " ".join(pattern)


def ofm_memtile_s2mm(dims: MaxPoolDims, col: int, row: int, Y_loop: int, X_loop: int) -> str:
    """
    Memtile OFM S2MM access pattern receiving from core (col, row) on loop step (Y_loop, X_loop).

    NOTE(review): like the IFM pattern, the outer C range does not depend on Co_idx -- confirm
    for the channel-split mode.
    """
    y_begin, y_end, _ = Yo_slice(dims, col, row, Y_loop)
    x_begin, x_end, _ = Xo_slice(dims, col, row, X_loop)
    pattern = [
        f"C:0:{dims.C}:{dims.Cs}",
        f"Yo:{y_begin}:{y_end}",
        f"C:0:{dims.Cs_pad}:{dims.C_gran}",
        f"Xo:{x_begin}:{x_end}",
        f"C:0:{dims.C_gran}",
    ]
    return " ".join(pattern)


def ofm_core_mm2s(dims: MaxPoolDims, col: int, row: int) -> str:
    """
    Core OFM MM2S access pattern for the core at (col, row).

    NOTE: the OFM access pattern must be repeatable, so the extents of iteration 0 are used.
    """
    y_begin, y_end, _ = Yo_slice(dims, col, row, 0)
    x_begin, x_end, _ = Xo_slice(dims, col, row, 0)
    return "Yo:0:{} C:0:{}:{} Xo:0:{} C:0:{}".format(
        y_end - y_begin, dims.Cs_pad, dims.C_gran, x_end - x_begin, dims.C_gran
    )


def ifm_core_memory(dims: MaxPoolDims) -> str:
    """IFM L1 layout: Yis rows, each holding Cs/C_gran channel groups of Xis x C_gran."""
    return " Yi:{} C:{} Xi:{} C:{}".format(dims.Yis, dims.Cs, dims.Xis, dims.C_gran)


def ofm_core_memory(dims: MaxPoolDims) -> str:
    """
    OFM L1 layout: Yos rows, each holding Cs/C_gran channel groups of Xos x C_gran.

    NOTE(review): declares ``Cs`` channels, while ``ofm_core_mm2s`` reads ``Cs_pad``. These
    coincide when Cs >= Cs_min -- confirm.
    """
    return "Yo:{} C:{} Xo:{} C:{}".format(dims.Yos, dims.Cs, dims.Xos, dims.C_gran)


def generate_maxpool_core_instrs(dims: MaxPoolDims, mapped_maxpool: MaxpoolShape, col: int, row: int) -> List[Any]:
    """
    Build the instruction list for the maxpool core at (col, row).

    The list does the following, in order:
    - configures the IFM (S2MM) and OFM (MM2S) ping/pong buffers and the runtime-parameter
      buffer;
    - acquires and releases the runtime-parameter buffer once;
    - loops ``C_loop * Y_loop * X_loop`` times, acquiring one IFM and one OFM buffer around
      each kernel call.
    """
    ifm_channel = ifm_core_channel()  # unicast
    wgt_channel = wgt_core_channel()  # broadcast
    ofm_channel = 0

    def s2mm(channel: int) -> DmaChannel:
        return DmaChannel(DmaDir.S2MM, channel)

    def mm2s(channel: int) -> DmaChannel:
        return DmaChannel(DmaDir.MM2S, channel)

    # NOTE(review): single-buffered mappings leave the pong addresses at their initial 0.
    # Confirm that generate_core_buffer_config treats that as "no pong buffer".
    ifm_config = generate_core_buffer_config(
        core_dma(col, row, DmaDir.S2MM, ifm_channel),
        mapped_maxpool.ifm_l1_ping_addr,
        mapped_maxpool.ifm_l1_pong_addr,
        ifm_core_memory(dims),
        ifm_core_s2mm(dims, col, row),
    )
    ofm_config = generate_core_buffer_config(
        core_dma(col, row, DmaDir.MM2S, ofm_channel),
        mapped_maxpool.ofm_l1_ping_addr,
        mapped_maxpool.ofm_l1_pong_addr,
        ofm_core_memory(dims),
        ofm_core_mm2s(dims, col, row),
    )
    # NOTE(review): the value passed as the kernel "mode" is the IFM channel id (0).
    # Confirm this is intended.
    kernel_args = generate_aie4_maxpool_params(dims, ifm_channel)
    per_iteration = [
        AcqBuffer(s2mm(ifm_channel)),
        AcqBuffer(mm2s(ofm_channel)),
        CallKernel("run_maxpool_int8x8", kernel_args),
        RelBuffer(s2mm(ifm_channel)),
        RelBuffer(mm2s(ofm_channel)),
    ]
    return [
        ifm_config,
        ofm_config,
        ConfigBuffer(s2mm(wgt_channel), mapped_maxpool.wgt_l1_addr, None, dims.wgt_subv_size),
        AcqBuffer(s2mm(wgt_channel)),
        RelBuffer(s2mm(wgt_channel)),
        Loop(dims.C_loop * dims.Y_loop * dims.X_loop, per_iteration),
    ]


def _pack_reconfig_transfer_params(reconfigs: list) -> TransferParams:
    """
    Merge a sequence of single-step TransferParams into one TransferParams.

    Entry ``k`` of every per-field list is taken from ``reconfigs[k]`` (its first step). The
    result keeps the DMA of ``reconfigs[0]``.
    """
    return TransferParams(
        reconfigs[0].dma,
        length=[params.length_i(0) for params in reconfigs],
        offset=[params.offset_i(0) for params in reconfigs],
        step=[params.step_i(0) for params in reconfigs],
        wrap=[params.wrap_i(0) for params in reconfigs],
        padding=[params.padding_i(0) for params in reconfigs],
        iter_step=[params.iter_step_i(0) for params in reconfigs],
        iter_wrap=[params.iter_wrap_i(0) for params in reconfigs],
    )


def generate_memtile_ifm_transfers(dims: MaxPoolDims, ifm_l2_pinned_col: int, ifm_l2_addr: int, ifm_l2_size: int, enable_ifm_fill: bool) -> list[DataTransfer]:
    """
    Build the memtile IFM DataTransfer for the maxpool layer.

    The IFM is pinned in the memtile of ``ifm_l2_pinned_col``. Its S2MM chain is
    [fill, one empty step per (Y_loop, X_loop) iteration, empty spill]. The fill brings in the
    whole IFM when ``enable_ifm_fill`` is set and is empty otherwise. In that case the IFM is
    assumed to have been transferred into L2 earlier over the same S2MM channel.

    Every (col, row) core gets an MM2S chain of
    [empty, one padded IFM subvolume per (Y_loop, X_loop) iteration, empty]. Its leading empty
    step keeps it in sync with the S2MM fill.
    """

    def fill_side_params(access_pattern: str) -> TransferParams:
        return generate_transfer_params(
            memtile_dma(ifm_l2_pinned_col, DmaDir.S2MM, overlay_3x4_F_ids()[1]),
            ifm_memtile_memory(dims),
            access_pattern,
            dims.ifm_bits,
        )

    def core_side_params(col: int, channel_id: int, access_pattern: str, **options) -> TransferParams:
        return generate_transfer_params(
            memtile_dma(col, DmaDir.MM2S, channel_id),
            ifm_memtile_memory(dims),
            access_pattern,
            dims.ifm_bits,
            **options,
        )

    fill_pattern = ifm_memtile_s2mm(dims) if enable_ifm_fill else dummy_transfer()
    s2mm_reconfigs = [fill_side_params(fill_pattern)]
    s2mm_reconfigs.extend(fill_side_params(dummy_transfer()) for _ in range(dims.Y_loop) for _ in range(dims.X_loop))
    s2mm_reconfigs.append(fill_side_params(dummy_transfer()))
    ifm_s2mm_packed_transfers = [_pack_reconfig_transfer_params(s2mm_reconfigs)]

    # NOTE: pack the reconfigurations of every Y_loop x X_loop step for each core row.
    ifm_mm2s_packed_transfers = []
    for col in range(dims.aie_cols):
        for row, channel_id in ifm_memtile_channels():
            mm2s_reconfigs = [core_side_params(col, channel_id, dummy_transfer())]
            mm2s_reconfigs.extend(
                core_side_params(col, channel_id, ifm_memtile_mm2s(dims, col, row, y_iter, x_iter), enable_padding=True)
                for y_iter in range(dims.Y_loop)
                for x_iter in range(dims.X_loop)
            )
            mm2s_reconfigs.append(core_side_params(col, channel_id, dummy_transfer()))
            ifm_mm2s_packed_transfers.append(_pack_reconfig_transfer_params(mm2s_reconfigs))

    # Repeat count per step: fill, each (Y, X) iteration, then 0 for the trailing spill slot.
    repeat_counts = [1] + [1] * (dims.Y_loop * dims.X_loop) + [0]
    data_transfer = DataTransfer(
        repeat_counts,
        AieTile(TileType.Memtile, ifm_l2_pinned_col),
        [ifm_l2_addr],
        ifm_l2_size,
        ifm_s2mm_packed_transfers,
        ifm_mm2s_packed_transfers,
        sync_strategy=SyncStrategy.Parallel_1_to_N,
    )
    return [data_transfer]


def generate_memtile_ofm_transfers(dims: MaxPoolDims, ofm_l2_pinned_col: int, ofm_l2_addr: int, ofm_l2_size: int, enable_ofm_spill: bool) -> list[DataTransfer]:
    """
    Generate the memtile DataTransfer for the pinned OFM L2 buffer.

    S2MM side: for every column in range(dims.aie_cols) and every
    (row, channel_id) from ofm_memtile_channels(), one packed TransferParams
    with dims.Y_loop * dims.X_loop + 2 reconfiguration slots:

    - a leading dummy slot;
    - one ofm_memtile_s2mm(dims, col, row, i, j) pattern per (i, j) iteration
      of Y_loop x X_loop;
    - a trailing dummy slot.

    MM2S side: one packed TransferParams on channel overlay_3x4_S_ids(col)[0]
    of the pinned column. Every slot is a dummy used only for
    synchronization, except the last one. The last slot carries
    ofm_memtile_mm2s(dims) when enable_ofm_spill is set. It is assumed that
    the same MM2S channel is used to transfer the OFM data to L3 after these
    data transfers.

    Args:
        dims: Maxpool tiling dimensions.
        ofm_l2_pinned_col: Memtile column that holds the OFM L2 buffer.
        ofm_l2_addr: Address of the OFM buffer in that memtile.
        ofm_l2_size: Size of the OFM L2 buffer.
        enable_ofm_spill: If True, the final MM2S slot uses the
            ofm_memtile_mm2s pattern; otherwise it is a dummy as well.

    Returns:
        A single-element list holding the OFM DataTransfer.
    """
    num_iters = dims.Y_loop * dims.X_loop

    def make_params(col: int, dma_dir: DmaDir, channel_id: int, pattern):
        # OFM transfer params for one slot on the given memtile DMA channel.
        return generate_transfer_params(
            memtile_dma(col, dma_dir, channel_id),
            ofm_memtile_memory(dims),
            pattern,
            dims.ofm_bits,
        )

    def pack(reconfigs: list) -> TransferParams:
        # Merge the index-0 entry of each per-slot TransferParams into one
        # TransferParams whose fields are lists indexed by reconfiguration slot.
        return TransferParams(
            reconfigs[0].dma,
            length=[t.length_i(0) for t in reconfigs],
            offset=[t.offset_i(0) for t in reconfigs],
            step=[t.step_i(0) for t in reconfigs],
            wrap=[t.wrap_i(0) for t in reconfigs],
            padding=[t.padding_i(0) for t in reconfigs],
            iter_step=[t.iter_step_i(0) for t in reconfigs],
            iter_wrap=[t.iter_wrap_i(0) for t in reconfigs],
        )

    # S2MM: pack all Y_loop x X_loop reconfigurations for each (col, row).
    # Each set is framed by a leading (fill) and trailing (spill) dummy slot.
    ofm_s2mm_packed_transfers = []
    for col in range(dims.aie_cols):
        for row, channel_id in ofm_memtile_channels():
            reconfigs = [make_params(col, DmaDir.S2MM, channel_id, dummy_transfer())]
            reconfigs += [
                make_params(col, DmaDir.S2MM, channel_id, ofm_memtile_s2mm(dims, col, row, i, j))
                for i in range(dims.Y_loop)
                for j in range(dims.X_loop)
            ]
            reconfigs.append(make_params(col, DmaDir.S2MM, channel_id, dummy_transfer()))
            ofm_s2mm_packed_transfers.append(pack(reconfigs))

    # MM2S: only the pinned column is involved. Every slot is a sync-only
    # dummy except the last, which spills the OFM when enabled.
    mm2s_col = ofm_l2_pinned_col
    mm2s_channel = overlay_3x4_S_ids(mm2s_col)[0]
    mm2s_reconfigs = [make_params(mm2s_col, DmaDir.MM2S, mm2s_channel, dummy_transfer())]
    mm2s_reconfigs += [
        make_params(mm2s_col, DmaDir.MM2S, mm2s_channel, dummy_transfer())
        for _ in range(num_iters)
    ]
    mm2s_reconfigs.append(
        make_params(
            mm2s_col,
            DmaDir.MM2S,
            mm2s_channel,
            ofm_memtile_mm2s(dims) if enable_ofm_spill else dummy_transfer(),
        )
    )
    ofm_mm2s_packed_transfers = [pack(mm2s_reconfigs)]

    # A single DataTransfer covers the pinned buffer. The first argument has
    # one entry per reconfiguration slot: 0 for the leading slot, 1 for the rest.
    data_transfer = DataTransfer(
        [0] + [1] * num_iters + [1],
        AieTile(TileType.Memtile, ofm_l2_pinned_col),
        [ofm_l2_addr],
        ofm_l2_size,
        ofm_s2mm_packed_transfers,
        ofm_mm2s_packed_transfers,
        sync_strategy=SyncStrategy.Parallel_N_to_1,
    )
    return [data_transfer]


def compile_maxpool_dataflow(schedule_input: ScheduleInputs) -> tuple:
    """
    Compile the maxpool dataflow for one layer and return next-layer shim offsets.

    Builds and hands to run_layer_compilation:

    - core instructions for every (col, row) core tile;
    - memtile transfers: params, weights, pinned IFM and pinned OFM;
    - shim transfers: params and weights, plus IFM fill and OFM spill when
      the L2 allocation enables them.

    Args:
        schedule_input: Provides the mapped shape (.shape), the L2 allocation
            (.L2_alloc), the L3 allocation (.L3_alloc), kernel names and
            includes, the backend, the output layer file name and the DMA
            padding map.

    Returns:
        (shim_prm_offset_next_layer, shim_wgt_offset_next_layer): this layer's
        shim param and weight XRT offsets, each advanced by the corresponding
        shim buffer size.
    """
    mapped_maxpool: MaxpoolShape = schedule_input.shape
    fusion_params: L2Alloc = schedule_input.L2_alloc
    L3_alloc: L3Alloc = schedule_input.L3_alloc
    maxpool_shim = L3Alloc_to_Shim(L3_alloc)
    log(f"Shim Allocator: {maxpool_shim}")
    dims = MaxPoolDims(mapped_maxpool)
    shim_prm_size = compute_buffer_size(prm_shim_memory())
    shim_wgt_size = compute_buffer_size(wgt_shim_memory(dims))
    overlay_shape = OverlayShape(dims.aie_cols, dims.aie_rows)
    dma_connections = overlay_3x4_dma_connections()
    no_of_reconfigs = dims.Y_loop * dims.X_loop

    def first_slot_only() -> list[int]:
        # Per-slot list for transfers that are active only in the leading slot:
        # 1 for the first slot, 0 for every Y_loop x X_loop slot and the trailing slot.
        # A fresh list is returned for every DataTransfer.
        return [1] + [0] * no_of_reconfigs + [0]

    ShimAlloc = shim_alloc()
    # L2 buffer sizes. Each buffer is sized with the element width of the data
    # it holds; these widths match the ones used by the transfers below.
    ifm_l2_size = compute_buffer_size(ifm_memtile_memory(dims), dims.ifm_bits)
    wgt_l2_size = compute_buffer_size(wgt_memtile_memory(dims), dims.wgt_bits)
    ofm_l2_size = compute_buffer_size(ofm_memtile_memory(dims), dims.ofm_bits)
    param_l2_size = compute_buffer_size(prm_memtile_memory())

    # L2 placement comes from the fusion (L2) allocation.
    ifm_l2_tile, ifm_l2_addr = fusion_params.ifm_L2_loc
    ofm_l2_tile, ofm_l2_addr = fusion_params.ofm_L2_loc
    wgt_L2_alloc_tiles = [entry[0] for entry in fusion_params.wgt_l2_loc]
    log(f"wgt_L2_alloc_tiles: {wgt_L2_alloc_tiles}")
    wgt_L2_ping_addrs = [entry[1] for entry in fusion_params.wgt_l2_loc]
    log(f"wgt_ping_addrs: {wgt_L2_ping_addrs}")
    wgt_L2_pong_addrs = [entry[2] for entry in fusion_params.wgt_l2_loc]
    log(f"wgt_L2_pong_addrs: {wgt_L2_pong_addrs}")
    prm_L2_alloc_tiles = [entry[0] for entry in fusion_params.prm_l2_loc]
    prm_L2_addrs = [entry[1] for entry in fusion_params.prm_l2_loc]
    log(f"prm_L2_alloc_tiles: {prm_L2_alloc_tiles}")
    log(f"prm_L2_addr: {prm_L2_addrs}")

    core_instrs = {}
    for col in range(dims.aie_cols):
        for row in range(dims.aie_rows):
            core_instrs[core_tile(col, row)] = generate_maxpool_core_instrs(
                dims,
                mapped_maxpool,
                col,
                row,
            )

    # Params: one S2MM fill per column memtile, then one MM2S per core row.
    memtile_param_transfers = [
        DataTransfer(
            first_slot_only(),
            prm_L2_alloc_tiles[col],
            [prm_L2_addrs[col]],
            param_l2_size,
            [
                generate_transfer_params(
                    memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[0]),
                    prm_memtile_memory(),
                    prm_memtile_s2mm(),
                    dims.param_bits,
                )
            ],
            [
                generate_transfer_params(
                    memtile_dma(col, DmaDir.MM2S, overlay_3x4_A_ids()[row]),
                    prm_memtile_memory(),
                    prm_memtile_mm2s(row),
                    dims.param_bits,
                )
                for row in range(dims.aie_rows)
            ],
        )
        for col in range(dims.aie_cols)
    ]
    # Weights: ping/pong buffers per column memtile.
    # MM2S uses the channels returned by broadcast_channels(col).
    memtile_wgt_transfers = [
        DataTransfer(
            first_slot_only(),
            wgt_L2_alloc_tiles[col],
            [wgt_L2_ping_addrs[col], wgt_L2_pong_addrs[col]],
            wgt_l2_size,
            [
                generate_transfer_params(
                    memtile_dma(col, DmaDir.S2MM, overlay_3x4_F_ids()[0]),
                    wgt_memtile_memory(dims),
                    wgt_memtile_s2mm(dims),
                    dims.wgt_bits,
                )
            ],
            [
                generate_transfer_params(
                    memtile_dma(col, DmaDir.MM2S, channel_id),
                    wgt_memtile_memory(dims),
                    wgt_memtile_mm2s(dims),
                    dims.wgt_bits,
                )
                for row, channel_id in broadcast_channels(col)
            ],
        )
        for col in range(dims.aie_cols)
    ]
    # IFM is pinned in a single memtile (ifm_l2_tile.col), and all columns
    # read from that pinned buffer.
    memtile_ifm_transfers = generate_memtile_ifm_transfers(dims, ifm_l2_tile.col, ifm_l2_addr, ifm_l2_size, fusion_params.enable_ifm_fill)
    # OFM is pinned in a single memtile (ofm_l2_tile.col). Its MM2S is sync-only
    # unless OFM spill is enabled.
    memtile_ofm_transfers = generate_memtile_ofm_transfers(dims, ofm_l2_tile.col, ofm_l2_addr, ofm_l2_size, fusion_params.enable_ofm_spill)

    # Shim transfers
    shim_param_transfers = [
        DataTransfer(
            first_slot_only(),
            AieTile(TileType.Shim, col),
            [ShimAlloc.prm_buffer_id],
            param_l2_size,
            [],
            [
                generate_transfer_params(
                    shim_dma(col, DmaDir.MM2S, 0),
                    prm_shim_memory(),
                    prm_shim_mm2s(col),
                    dims.param_bits,
                    buffer_offset=maxpool_shim.prm_xrt_offset,
                )
            ],
        )
        for col in range(dims.aie_cols)
    ]
    shim_wgt_transfers = [
        DataTransfer(
            first_slot_only(),
            AieTile(TileType.Shim, col),
            [ShimAlloc.wgt_buffer_id],
            wgt_l2_size,
            [],
            [
                generate_transfer_params(
                    shim_dma(col, DmaDir.MM2S, 0),
                    wgt_shim_memory(dims),
                    wgt_shim_mm2s(dims),
                    dims.wgt_bits,
                    buffer_offset=maxpool_shim.wgt_xrt_offset,
                )
            ],
        )
        for col in range(dims.aie_cols)
    ]
    # NOTE(review): shim IFM fill and OFM spill are hard-coded to shim column 1.
    # The memtile side uses ifm_l2_tile.col / ofm_l2_tile.col instead.
    # Confirm these always agree.
    shim_ifm_transfers = []
    if fusion_params.enable_ifm_fill:
        shim_ifm_transfers += [
            generate_shim_data_transfer(
                first_slot_only(),
                shim_dma(col, DmaDir.MM2S, 1),
                ShimAlloc.ifm_buffer_id,
                ifm_shim_memory(dims),
                ifm_shim_mm2s(dims),
                dims.ifm_bits,
            )
            for col in range(1, 2)
        ]

    shim_ofm_transfers = []
    if fusion_params.enable_ofm_spill:
        shim_ofm_transfers += [
            generate_shim_data_transfer(
                first_slot_only(),
                shim_dma(col, DmaDir.S2MM, 0),
                ShimAlloc.ofm_buffer_id,
                ofm_shim_memory(dims),
                ofm_shim_s2mm(dims),
                # OFM data is described with the OFM element width (was ifm_bits).
                dims.ofm_bits,
            )
            for col in range(1, 2)
        ]

    memtile_transfers = (
        memtile_param_transfers
        + memtile_wgt_transfers
        + memtile_ifm_transfers
        + memtile_ofm_transfers
    )
    shimtile_transfers = (
        shim_param_transfers
        + shim_wgt_transfers
        + shim_ifm_transfers
        + shim_ofm_transfers
    )

    run_layer_compilation(
        overlay_shape,
        schedule_input.kernel_names,
        schedule_input.kernel_includes,
        core_instrs,
        memtile_transfers,
        shimtile_transfers,
        dma_connections,
        core_stack_addr=overlay_3x4_core_stack_addr(),
        param_channel_id=overlay_3x4_param_channel_id(),
        back_end=schedule_input.backend,
        layer_file=schedule_input.layer_file_name,
        dma_padding_map=schedule_input.dma_pad,
    )

    # The next layer's shim param/weight regions start right after this layer's.
    shim_prm_offset_next_layer = maxpool_shim.prm_xrt_offset + shim_prm_size
    shim_wgt_offset_next_layer = maxpool_shim.wgt_xrt_offset + shim_wgt_size
    # Single pre-formatted message, consistent with the other log calls above.
    log(f"shim_prm_offset_next_layer {shim_prm_offset_next_layer}")
    log(f"shim_wgt_offset_next_layer {shim_wgt_offset_next_layer}")

    return shim_prm_offset_next_layer, shim_wgt_offset_next_layer


def generate_maxpool_mappings(
    Yi: int,
    Xi: int,
    C: int,
    Yo: int,
    Xo: int,
    Ky: int,
    Kx: int,
    Sy: int,
    Sx: int,
    Py: int,
    Px: int,
    Y_gran: int = 1,
    X_gran: int = 1,
    C_gran: int = 64,
    Cs_min: int = 64,
    ifm_bits: int = 8,
    ofm_bits: int = 8,
    bits_per_byte: int = 8,
    aie_cols: int = 3,
    aie_rows: int = 4,
) -> list:
    """
    Generate one mapped MaxpoolShape per split mode and return them ranked.

    For every entry of maxpool_split_modes, a fresh MaxpoolShape is built and
    mapped with map_maxpool using that mode's [Y_split, X_split, Co_split]
    list. All candidates are then passed to maxpool_ranker.

    Args:
        Yi, Xi, C: Input height, width and channel count.
        Yo, Xo: Output height and width.
        Ky, Kx: Kernel size.
        Sy, Sx: Stride.
        Py, Px: Padding.
        Y_gran, X_gran, C_gran: Granularities forwarded to MaxpoolShape.
        Cs_min: Minimum channel subvolume forwarded to MaxpoolShape.
        ifm_bits, ofm_bits, bits_per_byte: Element and byte widths.
        aie_cols, aie_rows: AIE array shape.

    Returns:
        The mapped shapes, in the iteration order of the mapping returned by
        maxpool_ranker.
    """
    mapped_pool_soln = {}
    for key, split_mode in maxpool_split_modes.items():
        # NOTE(review): output dims (Yo, Xo) are passed before input dims
        # (Yi, Xi), the reverse of this function's signature. This presumably
        # follows MaxpoolShape's constructor order; verify.
        shape = MaxpoolShape(
            Yo,
            Xo,
            C,
            Yi,
            Xi,
            Ky,
            Kx,
            Sy,
            Sx,
            Py,
            Px,
            Y_gran,
            X_gran,
            C_gran,
            Cs_min,
            aie_cols,
            aie_rows,
            ifm_bits,
            ofm_bits,
            bits_per_byte,
        )
        mapped_pool_soln[key] = shape
        map_maxpool(split_mode, shape)
    ranked_maxpool = maxpool_ranker(mapped_pool_soln)
    return list(ranked_maxpool.values())


def main():
    """
    Standalone entry point for compiling one maxpool layer.

    Builds a fixed 112x112x64 -> 56x56x64 maxpool problem (3x3 kernel,
    stride 2, no padding) on a 3x4 AIE array. It also sets up a test L2
    allocation with both memtile buffers on memory_tile(1). It then ranks the
    split-mode mappings and compiles the first (highest-ranked) one.

    NOTE(review): the final call does not match compile_maxpool_dataflow's
    current signature; see the note on that call.
    """
    kernel_names = ["run_maxpool_int8x8"]
    kernel_includes = ["super.hh", "maxpool_int8x8_wrapper.cc"]
    layer_file_name = "dma.hpp"
    aie_cols = 3
    aie_rows = 4
    # Output size; C is reassigned to the same value on the next line.
    Yo, Xo, C = (56, 56, 64)
    Yi, Xi, C = (112, 112, 64)
    Ky, Kx = (3, 3)
    Sy, Sx = (2, 2)
    Py, Px = (0, 0)
    ifm_bits = 8
    ofm_bits = 8
    bits_per_byte = 8
    Y_gran, X_gran, C_gran = (1, 1, 64)
    # NOTE(review): an older comment claimed the kernel unrolls the channel
    # dimension by 2, but the factor used here is 1. Confirm which is intended.
    kernel_Cs_unroll = 1
    Cs_min = C_gran * kernel_Cs_unroll

    # Fusion params assumed for generic testing; in the real flow these are
    # input params from the graph parser.
    ifm_l2_addr = 0  # offset 0
    ofm_l2_addr = 1 * (2**20)  # 1MB
    # stream address is offset per mem tile
    stream_l2_addr = 2 * (2**20)  # 2MB
    stream_l2_size = 1 * (2**20)  # 1MB
    enable_ifm_fill = True
    enable_ofm_spill = True
    # NOTE(review): meaning of the final positional True depends on L2Alloc's
    # signature (not visible here); confirm.
    fusion_params = L2Alloc(
        (memory_tile(1), ifm_l2_addr),
        (memory_tile(1), ofm_l2_addr),
        stream_l2_addr,
        stream_l2_size,
        enable_ifm_fill,
        enable_ofm_spill,
        True,
    )
    maxpool_mapped_soln = generate_maxpool_mappings(
        Yi,
        Xi,
        C,
        Yo,
        Xo,
        Ky,
        Kx,
        Sy,
        Sx,
        Py,
        Px,
        Y_gran,
        X_gran,
        C_gran,
        Cs_min,
        ifm_bits,
        ofm_bits,
        bits_per_byte,
        aie_cols,
        aie_rows,
    )

    # NOTE(review): compile_maxpool_dataflow takes a single ScheduleInputs.
    # It reads .shape, .L2_alloc, .L3_alloc, .kernel_names, .kernel_includes,
    # .backend, .layer_file_name and .dma_pad from it. This six-argument call
    # will therefore raise TypeError. It needs to build a ScheduleInputs,
    # including an L3Alloc, before calling. TODO fix once that constructor is
    # confirmed.
    compile_maxpool_dataflow(maxpool_mapped_soln[0], fusion_params, kernel_names, kernel_includes, layer_file_name, BackEnd.Adf)


if __name__ == "__main__":
    main()
