'''
This module implements the code generation phase that compiles the
physical buffer descriptor allocation into a control code sequence
to run on the array. External facing functions are documented below.

generate_runtime_control - takes the low-level control operation sequence
and generates the control code to configure and enqueue all data transfers.
This supports either ADF APIs for running a simulation or driver APIs for
serializing a transaction binary.

generate_overlay_graph - converts the connectivity list into
an equivalent ADF graph

generate_super_kernel - creates the super kernel entry point
with function pointers for the provided kernel list
'''


from typing import List, Dict, Tuple, Optional, Union, Type
from collections import OrderedDict
from copy import deepcopy
import os
from dataclasses import dataclass, field

from .types import (
    DevGen,
    OverlayShape, TileType, AieTile, DmaDir, DmaChannel, AieDma,
    BackEnd, CascDir, DmaConnection, CoreConnection,
    Lock, BufferDescriptor, DataBuffer,
    CoreInstr, Loop, ConfigBuffer, AcqBuffer, RelBuffer, CallKernel,
    ControlOpVisitor, BdConfig, DmaPaddingMap,
)
from . import config
from .control import LayerControl


################################################################################
#
# Code Generation for Runtime Control
#
################################################################################


def buffer_descriptor_dmas(buffer_descriptors: List[BufferDescriptor]) -> List[AieDma]:
    """Return the distinct DMAs used by ``buffer_descriptors``.

    Each DMA appears once, in the order of its first occurrence.
    """
    seen = {}
    for descriptor in buffer_descriptors:
        seen.setdefault(descriptor.aie_dma, None)
    return list(seen)


def buffer_descriptor_chains(buffer_descriptors: List[BufferDescriptor]) -> List[List[BufferDescriptor]]:
    """Group buffer descriptors into one chain per DMA.

    Chains are ordered by the first occurrence of their DMA (as returned by
    ``buffer_descriptor_dmas``). Within a chain, descriptors keep their
    original relative order.
    """
    grouped = []
    for dma in buffer_descriptor_dmas(buffer_descriptors):
        grouped.append([bd for bd in buffer_descriptors if bd.aie_dma == dma])
    return grouped


def compute_core_task_queue_depth(repeat_count: int, is_double_buffer: bool) -> int:
    """Return how many core task-queue entries ``repeat_count`` repeats occupy.

    One task covers at most ``config.MAX_REPEAT_COUNT`` repeats, or twice
    that for a double-buffered channel. On a double-buffered channel, an odd
    repeat count costs one extra task for the unpaired repeat.
    """
    extra_task = 0
    per_task = config.MAX_REPEAT_COUNT
    if is_double_buffer:
        extra_task = repeat_count % 2
        repeat_count -= extra_task
        per_task *= 2
    # Ceiling division: tasks needed to cover the remaining repeats.
    return extra_task + (repeat_count + per_task - 1) // per_task


def check_core_task_depth(cfg: ConfigBuffer):
    """Raise ``RuntimeError`` if ``cfg`` would overflow the core task queue.

    A config with a pong address is treated as double-buffered when
    computing its task-queue depth.
    """
    depth = compute_core_task_queue_depth(cfg.repeat_count, cfg.pong_addr is not None)
    if depth <= config.MAX_TASK_QUEUE_SIZE:
        return
    raise RuntimeError(f'Core task queue overflow on {cfg.dma_channel}!')


def compute_core_repeat_counts(core_instrs: List[CoreInstr]) -> None:
    """
    Fill in ``repeat_count`` on the ConfigBuffer instructions of a core program.

    The instruction stream is walked in execution order, with each Loop body
    expanded ``num_iters`` times. For each DMA channel, a ConfigBuffer's
    repeat count is the number of AcqBuffer instructions on that channel
    between it and the next ConfigBuffer on the same channel (or the end of
    the stream). Every finalized config is checked with
    ``check_core_task_depth``. ConfigBuffer objects are mutated in place.

    Only S2MM channels 0..config.MAX_CORE_S2MM_DMA_CHANNEL and MM2S channels
    0..config.MAX_CORE_MM2S_DMA_CHANNEL are registered. An instruction on any
    other channel raises KeyError.

    NOTE(review): a ConfigBuffer inside a Loop body is the same object on
    every iteration. Its repeat_count is overwritten on each visit and ends
    up reflecting only the acquires after its last visit — confirm intended.
    NOTE(review): AcqBuffer instructions are counted regardless of their
    ``disable`` flag — confirm intended.

    Raises:
        RuntimeError: From ``check_core_task_depth`` on task queue overflow.
    """
    # Per-channel count of AcqBuffer instructions since the channel's latest
    # ConfigBuffer.
    channel_repeats: Dict[DmaChannel, int] = {}
    # Latest ConfigBuffer per channel whose repeat count is still open.
    active_configs: Dict[DmaChannel, Optional[ConfigBuffer]] = {}
    for i in range(0, config.MAX_CORE_S2MM_DMA_CHANNEL + 1):
        channel_repeats[DmaChannel(DmaDir.S2MM, i)] = 0
        active_configs[DmaChannel(DmaDir.S2MM, i)] = None
    for i in range(0, config.MAX_CORE_MM2S_DMA_CHANNEL + 1):
        channel_repeats[DmaChannel(DmaDir.MM2S, i)] = 0
        active_configs[DmaChannel(DmaDir.MM2S, i)] = None
    def compute_core_repeat_counts_rec(instrs: List[CoreInstr]) -> None:
        # Walk instructions in execution order; loops are fully unrolled.
        for instr in instrs:
            if isinstance(instr, Loop):
                for _ in range(instr.num_iters):
                    compute_core_repeat_counts_rec(instr.loop_body)
            elif isinstance(instr, ConfigBuffer):
                # A new config closes the previous config on this channel;
                # its repeat count is the number of acquires seen since it.
                dma_channel = instr.dma_channel
                prev_config = active_configs[dma_channel]
                if prev_config is not None:
                    prev_config.repeat_count = channel_repeats[dma_channel]
                    check_core_task_depth(prev_config)
                channel_repeats[dma_channel] = 0
                active_configs[dma_channel] = instr
            elif isinstance(instr, AcqBuffer):
                channel_repeats[instr.dma_channel] += 1
    compute_core_repeat_counts_rec(core_instrs)
    # Close every config that is still open at the end of the stream.
    for cfg in active_configs.values():
        if cfg is not None:
            cfg.repeat_count = channel_repeats[cfg.dma_channel]
            check_core_task_depth(cfg)


def fuse_call_instrs(core_instrs: List[CoreInstr]) -> None:
    """
    Fuse kernel calls with their surrounding input acquires/releases, in place.

    The consecutive pattern

        AcqBuffer(S2MM), AcqBuffer(S2MM), CallKernel, RelBuffer(S2MM), RelBuffer(S2MM)

    where none of the four buffer instructions is disabled, is collapsed into
    the CallKernel alone, with its opcode set to
    ``CoreInstr.KERNEL_CALL_IN0_IN1_OP``. Loop bodies are processed
    recursively before the current level is scanned.

    NOTE(review): only direction and ``disable`` are checked. The acquires
    and releases are not required to be on distinct or matching channels —
    confirm that the producer of the stream guarantees this.
    """

    def prev_instr_pair(i: int) -> Tuple[Optional[CoreInstr], Optional[CoreInstr]]:
        # The two instructions before index i (nearest first), or
        # (None, None) when fewer than two precede it.
        instr1 = None
        instr2 = None
        if i >= 2:
            instr1 = core_instrs[i - 1]
            instr2 = core_instrs[i - 2]
        return instr1, instr2

    def next_instr_pair(i: int) -> Tuple[Optional[CoreInstr], Optional[CoreInstr]]:
        # The two instructions after index i (nearest first), or
        # (None, None) when fewer than two follow it.
        instr1 = None
        instr2 = None
        if i + 2 < len(core_instrs):
            instr1 = core_instrs[i + 1]
            instr2 = core_instrs[i + 2]
        return instr1, instr2

    def both_acq_input(pair: Tuple[Optional[CoreInstr], Optional[CoreInstr]]) -> bool:
        # True if both instructions are enabled S2MM (input) acquires.
        instr1, instr2 = pair
        return ((instr1 is not None and
                 isinstance(instr1, AcqBuffer) and
                 (not instr1.disable) and
                 instr1.dma_channel.dir == DmaDir.S2MM) and
                (instr2 is not None and
                 isinstance(instr2, AcqBuffer) and
                 (not instr2.disable) and
                 instr2.dma_channel.dir == DmaDir.S2MM))

    def both_rel_input(pair: Tuple[Optional[CoreInstr], Optional[CoreInstr]]) -> bool:
        # True if both instructions are enabled S2MM (input) releases.
        instr1, instr2 = pair
        return ((instr1 is not None and
                 isinstance(instr1, RelBuffer) and
                 (not instr1.disable) and
                 instr1.dma_channel.dir == DmaDir.S2MM) and
                (instr2 is not None and
                 isinstance(instr2, RelBuffer) and
                 (not instr2.disable) and
                 instr2.dma_channel.dir == DmaDir.S2MM))

    for instr in core_instrs:
        if isinstance(instr, Loop):
            fuse_call_instrs(instr.loop_body)
    # A fusable call needs two instructions before it, so start at index 2.
    idx = 2
    while idx < len(core_instrs):
        can_fuse_in0_in1 = (isinstance(core_instrs[idx], CallKernel) and
                            both_acq_input(prev_instr_pair(idx)) and
                            both_rel_input(next_instr_pair(idx)))
        if can_fuse_in0_in1:
            core_instrs[idx].opcode = CoreInstr.KERNEL_CALL_IN0_IN1_OP
            # Remove the two releases after the call (index idx + 1 twice),
            # then the two acquires before it.
            core_instrs.pop(idx + 1)
            core_instrs.pop(idx + 1)
            core_instrs.pop(idx - 1)
            core_instrs.pop(idx - 2)
            # The call now sits at idx - 2; resume scanning right after it.
            idx -= 2
        idx += 1


def generate_layer_params(
    core_instrs: List[Type[CoreInstr]],
    kernel_names: Union[List[str], Dict[str, int]],
    word_size=4
) -> bytes:
    """
    Serialize a core instruction stream into its layer-parameter buffer.

    Args:
        core_instrs: Core instruction sequence. It is deep-copied first, so
            the repeat-count computation and call fusion applied here do not
            modify the caller's instructions.
        kernel_names: Either a list of kernel names or an explicit
            name -> id mapping (copied). For a list, each name's id is its
            list position, and a repeated name keeps its last position.
        word_size: The serialized buffer is zero-padded to a multiple of this
            many bytes.

    Returns:
        The serialized, zero-padded parameter bytes.

    Raises:
        RuntimeError: If the padded buffer is larger than
            ``config.MAX_CORE_LAYER_PARAM_SIZE``, or if a core task queue
            overflows while computing repeat counts.
    """
    if isinstance(kernel_names, list):
        kernel_ids: Dict[str, int] = {name: kernel_id for kernel_id, name in enumerate(kernel_names)}
    else:
        kernel_ids = deepcopy(kernel_names)

    # Every core DMA channel gets a buffer id, numbered consecutively from 0:
    # S2MM channels first, then MM2S channels.
    channels = (
        [DmaChannel(DmaDir.S2MM, ch) for ch in range(config.MAX_CORE_S2MM_DMA_CHANNEL + 1)]
        + [DmaChannel(DmaDir.MM2S, ch) for ch in range(config.MAX_CORE_MM2S_DMA_CHANNEL + 1)]
    )
    buffer_ids: Dict[DmaChannel, int] = {channel: buffer_id for buffer_id, channel in enumerate(channels)}

    instrs = deepcopy(core_instrs)
    compute_core_repeat_counts(instrs)
    fuse_call_instrs(instrs)

    # The whole program is serialized as a single-iteration loop.
    layer_params = Loop(1, instrs).to_bytes(kernel_ids, buffer_ids)
    # Zero-pad to the next multiple of word_size: (-n) % word_size is the
    # number of bytes missing to reach that boundary (0 when already aligned).
    layer_params += b'\x00' * (-len(layer_params) % word_size)
    if len(layer_params) > config.MAX_CORE_LAYER_PARAM_SIZE:
        raise RuntimeError('Core layer parameter buffer overflow!')

    return layer_params


def generate_buffer_comment(data_buffer: DataBuffer) -> str:
    """Render a C++ comment block summarizing a data buffer's transfers.

    The comment reports the number of writer (S2MM) and reader (MM2S) chains,
    the tile of the first buffer descriptor, each chain's channel with its BD
    ids, and each lock's id and signed initial value.
    """
    separator = '\n    // '
    descriptors = data_buffer.buffer_descriptors

    def describe(direction):
        # Returns (number of chains, rendered chain list) for one direction.
        chains = buffer_descriptor_chains(
            [bd for bd in descriptors if bd.aie_dma.channel.dir == direction])
        if not chains:
            return 0, 'None'
        entries = [f'{chain[0].aie_dma.channel} BDs: ' + ' -> '.join(str(bd.id) for bd in chain)
                   for chain in chains]
        return len(chains), separator.join(entries)

    num_writers, writer_str = describe(DmaDir.S2MM)
    num_readers, reader_str = describe(DmaDir.MM2S)
    lock_str = separator.join(f'Id: {lock.id}, Init: {lock.init_value:+}'
                              for lock in data_buffer.locks)
    tile = descriptors[0].aie_dma.tile if len(descriptors) > 0 else 'None'
    return f'''
    //
    // {num_writers} to {num_readers} Data Transfer
    //
    // Location: {tile}
    //
    // Writers
    // ----------------
    // {writer_str}
    //
    // Readers
    // ----------------
    // {reader_str}
    //
    // Locks
    // ----------------
    // {lock_str}
    //
'''


def generate_task_iter_comment(iter: int) -> str:
    """Return a C++ comment banner announcing task iteration ``iter``."""
    banner = ['', '    //', f'    // Task Iteration {iter}', '    //', '']
    return '\n'.join(banner)


def compute_lock_access_config(lock: Optional[Lock], dma_tile: AieTile) -> Tuple[int, int]:
    """Split a lock reference into ``(lock_offset, lock_id)`` relative to ``dma_tile``.

    Each tile's locks occupy a window of ``config.MAX_MEMTILE_LOCK_ID + 1``
    ids. The window index is the column distance from ``dma_tile`` to the
    lock's tile, shifted by ``config.MAX_NEIGHBOR_ACCESS``. Without a lock,
    lock id 0 is used: in the DMA tile's own window for memtiles, and with
    offset 0 for other tiles.
    """
    locks_per_tile = config.MAX_MEMTILE_LOCK_ID + 1
    if lock is None:
        own_window = config.MAX_NEIGHBOR_ACCESS if dma_tile.type == TileType.Memtile else 0
        return own_window * locks_per_tile, 0
    column_slot = config.MAX_NEIGHBOR_ACCESS + (lock.aie_tile.col - dma_tile.col)
    return column_slot * locks_per_tile, lock.id


def compute_buffer_access_config(cfg: BdConfig) -> Tuple[int, int]:
    """Split a BD's buffer address into ``(buffer_offset, buffer_addr)``.

    The neighbor-tile offset and the local address are kept as two separate
    numbers so the generated code is easier to read. Only memtile BDs are
    split: ``buffer_offset`` is the base of the
    ``config.MAX_MEMTILE_ADDR + 1``-sized window containing the address, and
    ``buffer_addr`` is the offset within it. Zero-length memtile BDs default
    to the local memtile window (index ``config.MAX_NEIGHBOR_ACCESS``) with
    offset 0. All other tile types yield ``(0, 0)``.
    """
    bd = cfg.bd
    transfer_length = cfg.length_i()
    if bd.aie_dma.tile.type != TileType.Memtile:
        return 0, 0
    base = config.MEMTILE_BASE_ADDR
    window = config.MAX_MEMTILE_ADDR + 1
    if transfer_length == 0:
        return base + (config.MAX_NEIGHBOR_ACCESS * window), 0
    window_index, local_addr = divmod(bd.buffer_addr - base, window)
    window_base = (window_index * window) + base
    assert bd.buffer_addr == (window_base + local_addr)
    return window_base, local_addr


def generate_param_vector(
    shape: OverlayShape,
    layer_params: Dict[AieTile, bytes],
):
    """Emit a C++ ``std::vector<uint8_t>`` initializer holding all core layer params.

    Core tiles of ``shape`` are visited column by column (rows inner). Each
    tile's parameters are zero-padded to ``config.MAX_CORE_LAYER_PARAM_SIZE``
    bytes and concatenated. For AIE4 the concatenated bytes are also written
    to ``param.bin`` in the current working directory. The initializer lists
    16 bytes per line.
    """
    slot_size = config.MAX_CORE_LAYER_PARAM_SIZE
    chunks = []
    for col in range(shape.start_col, shape.start_col + shape.num_cols):
        for row in range(shape.start_row, shape.start_row + shape.num_rows):
            tile_params = layer_params[AieTile(TileType.Core, col, row)]
            assert len(tile_params) <= slot_size
            chunks.append(tile_params.ljust(slot_size, b'\x00'))
    param_bytes = b''.join(chunks)

    # Save `layer_params` to binary file for AIE4
    if config.DEV_GEN == DevGen.Aie4:
        with open("param.bin", "wb") as f:
            f.write(param_bytes)

    line_len = 16
    assert (len(param_bytes) % line_len) == 0
    num_lines = len(param_bytes) // line_len
    rows = [', '.join(f'0x{byte:x}' for byte in param_bytes[n * line_len:(n + 1) * line_len])
            for n in range(num_lines)]
    body = (',\n' + ' ' * 8).join(rows)
    return 'std::vector<uint8_t> layer_params = {\n        ' + body + '\n    };'


def adf_code_header(config_name: str) -> str:
    """Return the opening of the ADF control function named ``config_name``.

    The emitted function takes the compute graph and three DDR buffer
    pointers, calls ``adf::syncPSToGM()`` under ``__AIESIM__``, and runs the
    graph once.
    """
    lines = [
        '',
        f'void {config_name}(',
        '    ComputeGraph& graph,',
        '    void* ddr_addr_idx0,',
        '    void* ddr_addr_idx1,',
        '    void* ddr_addr_idx2)',
        '{',
        '#ifdef __AIESIM__',
        '    adf::syncPSToGM();',
        '#endif // __AIESIM__',
        '',
        '    graph.run(1);',
        '',
    ]
    return '\n'.join(lines)


def adf_code_footer() -> str:
    """Return the closing of the ADF control function.

    Waits for the graph, calls ``adf::syncPSFromGM()`` under ``__AIESIM__``,
    and frees the GMIO buffer at index ``config.SHIM_PARAM_BUFFER_IDX``.
    """
    param_idx = config.SHIM_PARAM_BUFFER_IDX
    return '\n'.join([
        '',
        '    graph.wait();',
        '',
        '#ifdef __AIESIM__',
        '    adf::syncPSFromGM();',
        '#endif // __AIESIM__',
        '',
        f'    adf::GMIO::free(ddr_addr_idx{param_idx});',
        '}',
        '',
    ])


def adf_create_param_buffer(
    shape: OverlayShape,
    layer_params: Dict[AieTile, bytes],
) -> str:
    """Return ADF code that allocates a GMIO buffer and fills it with the layer params.

    The buffer holds ``config.MAX_CORE_LAYER_PARAM_SIZE`` bytes per core tile
    of ``shape``. It is bound to ``ddr_addr_idx<config.SHIM_PARAM_BUFFER_IDX>``.
    """
    param_size = shape.num_cols * shape.num_rows * config.MAX_CORE_LAYER_PARAM_SIZE
    buf_idx = config.SHIM_PARAM_BUFFER_IDX
    vector_decl = generate_param_vector(shape, layer_params)
    return (
        f'\n    void* ddr_addr_idx{buf_idx} = adf::GMIO::malloc({param_size});'
        f'\n    {vector_decl}'
        f'\n    assert(layer_params.size() == {param_size});'
        f'\n    memcpy(ddr_addr_idx{buf_idx}, layer_params.data(), {param_size});\n'
    )


def adf_start_txn(name: str, asm_name: str = "placeholder") -> str:
    """Return transaction-start code; the ADF backend emits nothing.

    Both arguments exist for signature parity with the other backends and
    are ignored.
    """
    del name, asm_name
    return str()


def adf_export_txn(name: str) -> str:
    """Return transaction-export code; the ADF backend emits nothing.

    ``name`` exists for signature parity with the other backends and is ignored.
    """
    del name
    return str()


def adf_tile_type(type: TileType) -> str:
    """Map a tile type to its ADF tile enumerator; anything not core or memtile maps to shim."""
    if type == TileType.Memtile:
        return 'adf::memory_tile'
    if type == TileType.Core:
        return 'adf::aie_tile'
    return 'adf::shim_tile'


def adf_dma_dir(dir: DmaDir) -> str:
    """Map a DMA direction to its ADF enumerator; anything other than MM2S maps to S2MM."""
    return 'adf::dma_mm2s' if dir == DmaDir.MM2S else 'adf::dma_s2mm'


def lock_id_sanity_check(col: int, tile_type: TileType, lock_id: int, lock_type: str) -> None:
    """
    Check that a memtile lock id lies in the window allowed for ``col``.

    Only memtile locks are validated; for other tile types this is a no-op.
    The local memtile lock ids are 448..511. Column 0 may also use the 128
    ids above that window, column 2 the 128 ids below it, and column 1
    64 ids on either side. Only columns 0, 1 and 2 have a window.

    Args:
        col: Column of the tile owning the lock.
        tile_type: Tile type the lock belongs to.
        lock_id: Lock id to validate.
        lock_type: Label included in the error message.

    Raises:
        AssertionError: If a memtile ``lock_id`` is outside the allowed
            window, or no window is defined for ``col``.
    """
    # Check the tile type before consulting the per-column table, so that
    # non-memtile locks in columns without a window are not rejected.
    if tile_type != TileType.Memtile:
        return

    local_lock_low, local_lock_high = 448, 511

    ranges = {
        0: (local_lock_low, local_lock_high + 128),
        1: (local_lock_low - 64, local_lock_high + 64),
        2: (local_lock_low - 128, local_lock_high),
    }

    if col not in ranges:
        raise AssertionError(f'No "{lock_type}" Lock ID range defined for col "{col}"')
    low, high = ranges[col]
    if not (low <= lock_id <= high):
        raise AssertionError(f'Wrong "{lock_type}" Lock ID "{lock_id}" for col "{col}"')


def adf_configure_buffer_descriptor(cfg: BdConfig) -> str:
    """
    Return ADF code that declares and fills an ``adf::dma_buffer_descriptor``.

    The descriptor variable is named after ``cfg``. It is given the address,
    length, step/wrap dimensions, padding, lock settings and next-BD link.
    Iteration step/wrap and packet settings are added only when present.

    For memtile BDs, the address adds the memtile window base and local
    address (divided by ``sizeof(uint32_t)``) to the access offset. An
    ``adf::configureBufferDescriptor`` call is appended. For other (non-core)
    tiles, only the access offset is used and no configure call is emitted.
    Presumably those BDs are consumed by the GMIO enqueue in
    ``adf_enqueue_task`` — confirm.

    Raises:
        AssertionError: If the BD is on a core tile, or if a BD with locking
            disabled still specifies a lock or a non-zero lock value.
    """

    bd = cfg.bd

    assert bd.aie_dma.tile.type != TileType.Core

    # Emit 0 as the next BD when chaining is off (use_next_bd is false then).
    next_bd_id = f'{bd.next_bd.id}' if bd.use_next_bd else '0'

    buffer_offset, buffer_addr = compute_buffer_access_config(cfg)

    lock_acq_offset, lock_acq_id = compute_lock_access_config(bd.lock_acq, bd.aie_dma.tile)

    lock_rel_offset, lock_rel_id = compute_lock_access_config(bd.lock_rel, bd.aie_dma.tile)

    access_offset = cfg.offset_i()

    if bd.aie_dma.tile.type == TileType.Memtile:
        address = f'{access_offset} + ((0x{buffer_offset:x} + 0x{buffer_addr:x}) / sizeof(uint32_t))'
        call = f'''
    adf::configureBufferDescriptor({adf_tile_type(bd.aie_dma.tile.type)}, {bd.aie_dma.tile.col}, {bd.aie_dma.tile.row}, {bd.id}, {cfg});'''
    else:
        address = f'{access_offset}'
        call = ''

    # Optional iteration settings: emitted only when the config provides them.
    if cfg.iter_step_i() is not None:
        iter_step = f'''
    {cfg}.iteration_stepsize = {cfg.iter_step_i()};'''
    else:
        iter_step = ''

    if cfg.iter_wrap_i() is not None:
        iter_wrap = f'''
    {cfg}.iteration_wrap = {cfg.iter_wrap_i()};'''
    else:
        iter_wrap = ''

    if bd.packet_enable:
        packet_id = f'''
    {cfg}.enable_packet = {str(bd.packet_enable).lower()};
    {cfg}.packet_id = {bd.packet_id};'''
    else:
        packet_id = ''

    # NOTE: There is a bug in the simulator that will incorrectly crash
    # if a BD running on memtile channels 4 or 5 has locking disabled.
    # To get around this, we always enable locking. BDs with lock_enable
    # set to False will have lock ID 0 as the acq/rel lock, which
    # will succeed immediately.
    lock_enable = True
    assert bd.lock_enable or ((bd.lock_acq is None) and (bd.lock_acq_value == +0))
    assert bd.lock_enable or ((bd.lock_rel is None) and (bd.lock_rel_value == +0))

    return f'''
    // BDConfig | {bd.name} | Col: {bd.aie_dma.tile.col} | Dir: {bd.aie_dma.channel.dir} | Channel_ID: {bd.aie_dma.channel.id} | BD_ID: {bd.id}
    adf::dma_buffer_descriptor {cfg};
    {cfg}.address = {address};
    {cfg}.length = {cfg.length_i()};
    {cfg}.stepsize = {{{", ".join([str(x) for x in cfg.step_i()])}}};
    {cfg}.wrap = {{{", ".join([str(x) for x in cfg.wrap_i()])}}};
    {cfg}.padding = {{{", ".join([f'std::pair<uint32_t, uint32_t>({x}, {y})' for x, y in cfg.padding_i()])}}};
    {cfg}.lock_acq_enable = {str(lock_enable).lower()};
    {cfg}.lock_acq_value = {bd.lock_acq_value:+};
    {cfg}.lock_acq_id = {lock_acq_offset} + {lock_acq_id};
    {cfg}.lock_rel_value = {bd.lock_rel_value:+};
    {cfg}.lock_rel_id = {lock_rel_offset} + {lock_rel_id};
    {cfg}.use_next_bd = {str(bd.use_next_bd).lower()};
    {cfg}.next_bd = {next_bd_id};{iter_step}{iter_wrap}{packet_id}{call}
'''


def adf_patch_ddr_addr(cfg: BdConfig) -> str:
    """Return DDR-address patch code for a shim BD; the ADF backend emits none."""
    tile_type = cfg.bd.aie_dma.tile.type
    assert tile_type == TileType.Shim
    return str()


def adf_init_lock(lock: Lock) -> str:
    """Return ADF code that initializes ``lock`` to its ``init_value`` (non-core tiles only)."""
    tile = lock.aie_tile
    assert tile.type != TileType.Core
    comment = f'    // Init Lock | Lock_ID: {lock.id} | Col: {tile.col} | Init_Value: {lock.init_value}'
    call = (f'    adf::initializeLock({adf_tile_type(tile.type)}, {tile.col}, {tile.row}, '
            f'{lock.id}, {lock.init_value:+});')
    return f'\n{comment}\n{call}\n'


def adf_enqueue_task(cfg_chain: List[BdConfig], repeat_count: int) -> str:
    """
    Return ADF code that enqueues a BD chain on its DMA channel.

    Memtile chains are started from the head BD with ``adf::enqueueTask``.
    Shim chains go through the graph's GMIO port for the channel:
    ``gmio_in_col*_ch*.gm2aie_nb`` for MM2S, ``gmio_out_col*_ch*.aie2gm_nb``
    otherwise. The call receives ``ddr_addr_idx<buffer_addr>`` and every BD
    name and id in the chain.

    Raises:
        AssertionError: If the chain is on a core tile, or if a shim chain
            uses more than one BD buffer address.
    """
    bd = cfg_chain[0].bd
    assert bd.aie_dma.tile.type != TileType.Core
    if bd.aie_dma.tile.type == TileType.Memtile:
        enqueue = f'''
    // Enqueue | {bd.name} | Col: {bd.aie_dma.tile.col} | Channel_ID: {bd.aie_dma.channel.id} | BD_ID: {bd.id} | Repeat Count: {repeat_count}
    adf::enqueueTask({adf_tile_type(bd.aie_dma.tile.type)}, {bd.aie_dma.tile.col}, {bd.aie_dma.tile.row}, {adf_dma_dir(bd.aie_dma.channel.dir)}, {bd.aie_dma.channel.id}, {bd.id}, {repeat_count}, false);
'''
    else:
        bd_names = ', '.join([f'{cfg}' for cfg in cfg_chain])
        bd_ids = ', '.join([f'{cfg.bd.id}' for cfg in cfg_chain])
        bd_addrs = [cfg.bd.buffer_addr for cfg in cfg_chain]
        # NOTE: We can only have one shim BD address in a chain
        #       due to a limitation of aiesim. The TXN backend
        #       does not have this limitation.
        assert len(set(bd_addrs)) == 1
        if bd.aie_dma.channel.dir == DmaDir.MM2S:
            enqueue = f'''
    // Enqueue | {bd.name} | Col: {bd.aie_dma.tile.col} | Channel_ID: {bd.aie_dma.channel.id} | BD_Names: {bd_names} | BD_ID: {bd_ids} | Repeat Count: {repeat_count}
    graph.gmio_in_col{bd.aie_dma.tile.col}_ch{bd.aie_dma.channel.id}.gm2aie_nb(ddr_addr_idx{bd.buffer_addr}, {{{bd_names}}}, {{{bd_ids}}}, {repeat_count});
'''
        else:
            enqueue = f'''
    // Enqueue | {bd.name} | Col: {bd.aie_dma.tile.col} | Channel_ID: {bd.aie_dma.channel.id} | BD_Names: {bd_names} | BD_ID: {bd_ids} | Repeat Count: {repeat_count}
    graph.gmio_out_col{bd.aie_dma.tile.col}_ch{bd.aie_dma.channel.id}.aie2gm_nb(ddr_addr_idx{bd.buffer_addr}, {{{bd_names}}}, {{{bd_ids}}}, {repeat_count});
'''
    return enqueue


def adf_wait_dma_completion(aie_dma: AieDma) -> str:
    """Return ADF code that waits until ``aie_dma`` has finished its transfers.

    Memtile channels use ``adf::waitDMAChannelDone``, passing
    ``config.ENABLE_BUSY_POLL`` as an extra argument when it is not None.
    Shim channels wait on the graph's GMIO port for the channel.
    """
    tile = aie_dma.tile
    channel = aie_dma.channel
    assert tile.type != TileType.Core
    header = f'\n    // Col: {tile.col} | Channel_ID: {channel.id}\n'
    if tile.type != TileType.Memtile:
        port = 'gmio_in' if channel.dir == DmaDir.MM2S else 'gmio_out'
        return f'{header}    graph.{port}_col{tile.col}_ch{channel.id}.wait();\n'
    busy = config.ENABLE_BUSY_POLL
    enable_busy = '' if busy is None else f', {str(busy).lower()}'
    return (f'{header}    adf::waitDMAChannelDone({adf_tile_type(tile.type)}, {tile.col}, {tile.row}, '
            f'{adf_dma_dir(channel.dir)}, {channel.id}{enable_busy});\n')


def adf_acquire_lock(lock: Lock) -> str:
    """Lock acquisition is not supported by the ADF backend.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError


class AdfControlOpVisitor(ControlOpVisitor):
    """Control-op visitor that renders each operation as ADF API code.

    Every ``visit_*`` method returns a code string for one op. Remote
    barriers produce no code, and lock acquisition raises
    ``NotImplementedError``.
    """
    __slots__ = ()

    def visit_config_buffer_descriptor(self, op):
        return adf_configure_buffer_descriptor(op.cfg)

    def visit_patch_ddr_addr(self, op):
        return adf_patch_ddr_addr(op.cfg)

    def visit_set_lock_value(self, op):
        return adf_init_lock(op.lock)

    def visit_enqueue_task(self, op):
        return adf_enqueue_task(op.cfg_chain, op.repeat_count)

    def visit_wait_dma_done(self, op):
        return adf_wait_dma_completion(op.dma)

    # NOTE(review): the misspelling presumably matches the dispatch name in
    # ControlOpVisitor; do not rename it here alone.
    def visit_aqcuire_lock(self, op):
        return adf_acquire_lock(op.lock)

    def visit_remote_barrier(self, op):
        return ''


def xrt_start_txn(name: str, asm_name: str = "placeholder") -> str:
    """Return code that starts an XAie transaction with auto-flush disabled.

    Both arguments exist for signature parity with the other backends and
    are ignored.
    """
    del name, asm_name
    call = 'XAie_StartTransaction(&DevInst, XAIE_TRANSACTION_DISABLE_AUTO_FLUSH)'
    return f'\n    XRT_ERRCHK({call});\n'


def xrt_export_txn(name: str) -> str:
    """Return code that serializes the current XAie transaction into ``std::vector<char> name``.

    The exported buffer is copied into the vector and then freed (except on
    Windows), after which the transaction is cleared.
    """
    ptr, header, size = f'{name}_ptr', f'{name}_header', f'{name}_size'
    lines = [
        '',
        f'    unsigned char* {ptr} = XAie_ExportSerializedTransaction(&DevInst, 0, 0);',
        f'    auto* {header} = reinterpret_cast<XAie_TxnHeader*>({ptr});',
        f'    uint32_t {size} = {header}->TxnSize;',
        f'    std::vector<char> {name}({ptr}, {ptr} + {size});',
        '#ifndef _WIN32',
        '    // Skip free on Windows - XAie library may use different CRT heap',
        f'    free({ptr});',
        '#endif',
        '    XRT_ERRCHK(XAie_ClearTransaction(&DevInst));',
        '',
    ]
    return '\n'.join(lines)


def xrt_row_offset(type: TileType) -> int:
    """Row offset added to a tile row in ``XAie_TileLoc``: 1 for memtiles, 0 otherwise.

    Core tiles are rejected with an AssertionError.
    """
    assert type != TileType.Core
    return 1 if type == TileType.Memtile else 0


def xrt_dma_dir(d: DmaDir) -> str:
    """Map a DMA direction to its XAie enumerator; anything other than S2MM maps to MM2S."""
    return 'DMA_S2MM' if d == DmaDir.S2MM else 'DMA_MM2S'


def xrt_configure_buffer_descriptor(cfg: BdConfig) -> str:
    """
    Return XAie C code that builds one DMA buffer descriptor and writes it to its tile.

    The generated scope declares an ``XAie_DmaDesc`` named after ``cfg``.
    It initializes the descriptor with the multi-dimensional address,
    length, steps and wraps, then adds these optional settings:
    - padding, when ``cfg.padding_i()`` is non-empty
    - iteration, when ``cfg.iter_step_i()`` is set
    - locks, only when ``bd.lock_enable``
    - next-BD link
    - packet id
    - AXI settings, for shim BDs

    The BD is written with ``XAie_DmaWriteBdGeneric`` on AIE4, or with
    ``XAie_DmaEnableBd`` + ``XAie_DmaWriteBd`` otherwise.

    Raises:
        AssertionError: If the BD is on a core tile, if a shim byte offset
            or ``buffer_addr`` is out of range, or if the wrap and step
            counts cannot be matched.
    """

    bd = cfg.bd

    assert bd.aie_dma.tile.type != TileType.Core

    buffer_offset, buffer_addr = compute_buffer_access_config(cfg)

    lock_acq_offset, lock_acq_id = compute_lock_access_config(bd.lock_acq, bd.aie_dma.tile)

    lock_rel_offset, lock_rel_id = compute_lock_access_config(bd.lock_rel, bd.aie_dma.tile)

    access_offset = cfg.offset_i()

    # Byte address expression; cfg.offset_i() is scaled by sizeof(u32).
    if bd.aie_dma.tile.type == TileType.Memtile:
        # Memtile: scaled offset + memtile window base + address in the window.
        address = f'({access_offset} * sizeof(u32)) + (0x{buffer_offset:x} + 0x{buffer_addr:x})'
    else:
        # Shim: the byte offset must fit in 32 bits and buffer_addr in 8 bits.
        # Outside AIE4, buffer_addr goes into the upper 32 address bits. On
        # AIE4 the upper bits are 0, and buffer_addr is passed to
        # XAie_AddressPatching below instead.
        # NOTE(review): buffer_addr looks like a DDR buffer index rather than
        # a byte address (cf. `ddr_addr_idx{bd.buffer_addr}` in
        # adf_enqueue_task) — confirm.
        assert 0 <= access_offset * 4 < 2**32
        assert 0 <= bd.buffer_addr < 2**8
        if config.DEV_GEN == DevGen.Aie4:
            address = f'({access_offset} * sizeof(u32)) | (u64(0) << 32)'
        else:
            address = f'({access_offset} * sizeof(u32)) | (u64({bd.buffer_addr}) << 32)'

    # Init step and wrap dimensions
    steps = cfg.step_i()
    wraps = cfg.wrap_i()
    # The outermost dimension may come without a wrap; pad it with 0 so that
    # every dimension descriptor gets both fields.
    if len(wraps) == len(steps) - 1:
        wraps = cfg.wrap_i() + [0]
    assert len(steps) == len(wraps)
    step_init = '\n    '.join([f'{cfg}_dims[{i}].AieMlDimDesc.StepSize = {steps[i]};' for i in range(len(steps))])
    wrap_init = '\n    '.join([f'{cfg}_dims[{i}].AieMlDimDesc.Wrap = {wraps[i]};' for i in range(len(wraps))])

    # Init padding parameters
    padding_decl = ''
    padding_init = ''
    padding_call = ''
    if len(cfg.padding_i()) > 0:
        padding_decl = f'''
    XAie_DmaPadTensor {cfg}_pad_tensor;
    XAie_PadDesc {cfg}_pad_dims[{len(cfg.padding_i())}];'''
        padding_before = '\n    '.join([f'{cfg}_pad_dims[{i}].Before = {cfg.padding_i()[i][0]};' for i in range(len(cfg.padding_i()))])
        padding_after = '\n    '.join([f'{cfg}_pad_dims[{i}].After = {cfg.padding_i()[i][1]};' for i in range(len(cfg.padding_i()))])
        padding_init = f'''
    {padding_before}
    {padding_after}
    {cfg}_pad_tensor.NumDim = {len(cfg.padding_i())};
    {cfg}_pad_tensor.PadDesc = {cfg}_pad_dims;'''
        padding_call = f'''
    XRT_ERRCHK(XAie_DmaSetPadding(&{cfg}, &{cfg}_pad_tensor));'''

    # Init iteration parameters
    # A missing or zero iteration wrap is emitted as 1.
    iteration_call = ''
    if cfg.iter_step_i() is not None:
        step = cfg.iter_step_i()
        wrap = cfg.iter_wrap_i() if cfg.iter_wrap_i() is not None else 0
        iteration_call = f'''
    XRT_ERRCHK(XAie_DmaSetBdIteration(&{cfg}, {step}, {max(1, wrap)}, 0));'''

    # Init locking parameters
    # Unlike adf_configure_buffer_descriptor, locks are only programmed when
    # the BD enables locking.
    locking_decl = ''
    locking_init = ''
    locking_call = ''
    if bd.lock_enable:
        locking_decl = f'''
    XAie_Lock {cfg}_acq;
    XAie_Lock {cfg}_rel;'''
        locking_init = f'''
    {cfg}_acq.LockId = {lock_acq_offset} + {lock_acq_id};
    {cfg}_acq.LockVal = {bd.lock_acq_value:+};
    {cfg}_rel.LockId = {lock_rel_offset} + {lock_rel_id};
    {cfg}_rel.LockVal = {bd.lock_rel_value:+};'''
        locking_call = f'''
    XRT_ERRCHK(XAie_DmaSetLock(&{cfg}, {cfg}_acq, {cfg}_rel));'''

    # Init next BD parameters
    # On AIE4 the next BD id goes through cert_bd_id_offset_adjust (memtile
    # ids modulo 16).
    # NOTE(review): these generated lines are indented 16/8 spaces instead
    # of 4 like the rest of the output; harmless for C, but inconsistent.
    next_bd_call = ''
    if bd.use_next_bd:
        if config.DEV_GEN == DevGen.Aie4:
            next_bd_call = f'''
                XRT_ERRCHK(XAie_DmaSetNextBd(&{cfg}, {cert_bd_id_offset_adjust(bd.next_bd.id, bd.aie_dma.tile.type)}, XAIE_ENABLE));'''
        else:
            next_bd_call = f'''
        XRT_ERRCHK(XAie_DmaSetNextBd(&{cfg}, {bd.next_bd.id}, XAIE_ENABLE));'''

    # Init packet ID parameters
    packet_id_call = ''
    if bd.packet_enable:
        packet_id_call = f'''
    XRT_ERRCHK(XAie_DmaSetPkt(&{cfg}, XAie_PacketInit({bd.packet_id}, 0)));'''

    # Init AXI parameters
    axi_call = ''
    if bd.aie_dma.tile.type == TileType.Shim:
        axi_call = f'''
    XRT_ERRCHK(XAie_DmaSetAxi(&{cfg}, 0, {config.MAX_DDR_BURST_LENGTH}, 0, 2, 0));'''

    # Address patching for shim BDs; only emitted on the AIE4 path below.
    cert_ddr_addr_patch = ''
    if bd.aie_dma.tile.type == TileType.Shim:
        cert_ddr_addr_patch = f'''
    XRT_ERRCHK(XAie_AddressPatching(&DevInst, {bd.buffer_addr}, 1));'''

    # Init WriteBD call
    write_bd_call = ''
    if config.DEV_GEN == DevGen.Aie4:
        write_bd_call = f'''{cert_ddr_addr_patch}{axi_call}
    XRT_ERRCHK(XAie_DmaWriteBdGeneric(&DevInst, &{cfg}, XAie_TileLoc({bd.aie_dma.tile.col}, {xrt_row_offset(bd.aie_dma.tile.type)} + {bd.aie_dma.tile.row}), {bd.aie_dma.channel.id}, DMA_{bd.aie_dma.channel.dir.name}, {cert_bd_id_offset_adjust(bd.id, bd.aie_dma.tile.type)}));'''
    else:
        write_bd_call = f'''
    XRT_ERRCHK(XAie_DmaEnableBd(&{cfg}));{axi_call}
    XRT_ERRCHK(XAie_DmaWriteBd(&DevInst, &{cfg}, XAie_TileLoc({bd.aie_dma.tile.col}, {xrt_row_offset(bd.aie_dma.tile.type)} + {bd.aie_dma.tile.row}), {bd.id}));'''

    # Return the generated code
    return f'''
    // BDConfig | {bd.name} | Col: {bd.aie_dma.tile.col} | Dir: {bd.aie_dma.channel.dir} | Channel_ID: {bd.aie_dma.channel.id} | BD_ID: {bd.id}
    {{
    XAie_DmaDesc {cfg};
    XAie_DmaDimDesc {cfg}_dims[{len(cfg.step_i())}];
    XAie_DmaTensor {cfg}_tensor;{padding_decl}{locking_decl}
    u64 {cfg}_addr = {address};
    u32 {cfg}_len = {cfg.length_i()} * sizeof(u32);
    {step_init}
    {wrap_init}
    {cfg}_tensor.NumDim = {len(cfg.step_i())};
    {cfg}_tensor.Dim = {cfg}_dims;{padding_init}{locking_init}
    XRT_ERRCHK(XAie_DmaDescInit(&DevInst, &{cfg}, XAie_TileLoc({bd.aie_dma.tile.col}, {xrt_row_offset(bd.aie_dma.tile.type)} + {bd.aie_dma.tile.row})));
    XRT_ERRCHK(XAie_DmaSetMultiDimAddr(&{cfg}, &{cfg}_tensor, {cfg}_addr, {cfg}_len));{padding_call}{iteration_call}{locking_call}{next_bd_call}{packet_id_call}{write_bd_call}
    }}
'''


def xrt_patch_ddr_addr(cfg: BdConfig) -> str:
    """Return an ``AddDDRCustomOp`` DDR-patch op for a shim BD.

    The op carries the BD's ``buffer_addr``, its byte offset
    (``cfg.offset_i()`` times ``sizeof(u32)``), its column and its BD id.
    """
    bd = cfg.bd
    tile = bd.aie_dma.tile
    assert tile.type == TileType.Shim
    args = ', '.join([
        '0',
        f'{bd.buffer_addr}',
        f'({cfg.offset_i()} * sizeof(u32))',
        f'{tile.col}',
        f'{bd.id}',
        'XAIE_IO_CUSTOM_OP_DDR_PATCH',
    ])
    return f'\n    AddDDRCustomOp({args});\n'


def xrt_init_lock(lock: Lock) -> str:
    """Return XAie code that sets ``lock`` to its ``init_value`` (non-core tiles only)."""
    tile = lock.aie_tile
    location = f'XAie_TileLoc({tile.col}, {xrt_row_offset(tile.type)} + {tile.row})'
    init = f'XAie_LockInit({lock.id}, {lock.init_value:+})'
    return (f'\n    // Init Lock | Lock_ID: {lock.id} | Col: {tile.col} | Init_Value: {lock.init_value}'
            f'\n    XRT_ERRCHK(XAie_LockSetValue(&DevInst, {location}, {init}));\n')


def xrt_enqueue_task(cfg_chain: List[BdConfig], repeat_count: int) -> str:
    """Return XAie code that queues the head BD of ``cfg_chain`` on its DMA channel.

    Only the head BD id is passed to the start queue. The rest of the chain
    is presumably reached through each BD's next-BD link.
    """
    head = cfg_chain[0].bd
    tile = head.aie_dma.tile
    channel = head.aie_dma.channel
    comment = (f'    // Enqueue | {head.name} | Col: {tile.col} | Channel_ID: {channel.id} '
               f'| BD_ID: {head.id} | Repeat Count: {repeat_count}')
    call = (f'    XRT_ERRCHK(XAie_DmaChannelSetStartQueue(&DevInst, '
            f'XAie_TileLoc({tile.col}, {xrt_row_offset(tile.type)} + {tile.row}), '
            f'{channel.id}, {xrt_dma_dir(channel.dir)}, {head.id}, {repeat_count}, XAIE_DISABLE));')
    return f'\n{comment}\n{call}\n'


def xrt_wait_dma_completion(aie_dma: AieDma) -> str:
    """Return XAie code that waits for ``aie_dma`` to finish.

    The busy-poll variant is used when ``config.ENABLE_BUSY_POLL`` is truthy.
    """
    callname = 'XAie_DmaWaitForDoneBusy' if config.ENABLE_BUSY_POLL else 'XAie_DmaWaitForDone'
    tile = aie_dma.tile
    channel = aie_dma.channel
    location = f'XAie_TileLoc({tile.col}, {xrt_row_offset(tile.type)} + {tile.row})'
    return (f'\n    // Col: {tile.col} | Channel_ID: {channel.id}'
            f'\n    XRT_ERRCHK({callname}(&DevInst, {location}, {channel.id}, '
            f'{xrt_dma_dir(channel.dir)}, 0));\n')


def xrt_acquire_lock(lock: Lock) -> str:
    """Return XAie code that acquires ``lock``.

    The lock's ``init_value`` is used as the acquire value, with a timeout
    argument of ``1000000U``. The busy-poll variant is used when
    ``config.ENABLE_BUSY_POLL`` is truthy.
    """
    callname = 'XAie_LockAcquireBusy' if config.ENABLE_BUSY_POLL else 'XAie_LockAcquire'
    tile = lock.aie_tile
    location = f'XAie_TileLoc({tile.col}, {xrt_row_offset(tile.type)} + {tile.row})'
    lock_init = f'XAie_LockInit({lock.id}, {lock.init_value:+})'
    return f'\n    XRT_ERRCHK({callname}(&DevInst, {location}, {lock_init}, 1000000U));\n'


def txn_code_header(config_name: str) -> str:
    """Return the opening of the TXN control function ``DmaBins <config_name>()``."""
    return '\n'.join(['', f'DmaBins {config_name}()', '{', ''])


def txn_lx7_patch_code_footer() -> str:
    """Return the closing of the TXN control function for the LX7-patch flow.

    The emitted code packs the LX7 transaction, layer params and control
    packets into a ``DmaBins`` and returns it.
    """
    body = [
        '    DmaBins bins;',
        '    bins.txn_bin = lx7_txn;',
        '    bins.layer_params = layer_params;',
        '    bins.ctrl_pkts = generate_control_packets(pkt_txn);',
        '',
        '    return bins;',
        '}',
    ]
    return '\n' + '\n'.join(body) + '\n'


def txn_host_patch_code_footer() -> str:
    """Return the closing of the TXN control function for the host-patch flow.

    Same as the LX7 flow, plus a patch JSON generated from the control
    packets.
    """
    body = [
        '    DmaBins bins;',
        '    bins.txn_bin = lx7_txn;',
        '    bins.layer_params = layer_params;',
        '    bins.ctrl_pkts = generate_control_packets(pkt_txn);',
        '    bins.patch_json = generate_patch_json(bins.ctrl_pkts);',
        '',
        '    return bins;',
        '}',
    ]
    return '\n' + '\n'.join(body) + '\n'


def txn_create_param_buffer(
    shape: OverlayShape,
    layer_params: Dict[AieTile, bytes],
) -> str:
    """Return the TXN-side layer-param vector declaration (no device allocation)."""
    vector_decl = generate_param_vector(shape, layer_params)
    return '\n    ' + vector_decl + '\n'


class TxnControlOpVisitor(ControlOpVisitor):
    """Control-op visitor that renders each operation as XAie transaction code.

    Every ``visit_*`` method returns a code string for one op. Remote
    barriers produce no code.
    """
    __slots__ = ()

    def visit_config_buffer_descriptor(self, op):
        return xrt_configure_buffer_descriptor(op.cfg)

    def visit_patch_ddr_addr(self, op):
        return xrt_patch_ddr_addr(op.cfg)

    def visit_set_lock_value(self, op):
        return xrt_init_lock(op.lock)

    def visit_enqueue_task(self, op):
        return xrt_enqueue_task(op.cfg_chain, op.repeat_count)

    def visit_wait_dma_done(self, op):
        return xrt_wait_dma_completion(op.dma)

    # NOTE(review): the misspelling presumably matches the dispatch name in
    # ControlOpVisitor; do not rename it here alone.
    def visit_aqcuire_lock(self, op):
        return xrt_acquire_lock(op.lock)

    def visit_remote_barrier(self, op):
        return ''


def cert_start_asm(name: str, asm_name: str = "test.asm") -> str:
    """Return code that opens the CERT control-code file ``Work_AIE4/<asm_name>``.

    Without multi-uC support (``config.IS_MULTI_UC`` falsy), the file is
    always ``test.asm``. Nothing is emitted for the ``lx7_txn`` stream.
    """
    # TODO: Temporary workaround while CERT does not support ctrl_pkts.
    target = asm_name if config.IS_MULTI_UC else "test.asm"
    if name == 'lx7_txn':
        return ''
    return f'\n    XAie_OpenControlCodeFile(&DevInst, "Work_AIE4/{target}", 8192);\n'


def cert_export_asm(name: str) -> str:
    """Return code that closes the CERT control-code file; nothing for the ``pkt_txn`` stream."""
    # TODO: Temporary workaround while CERT does not support ctrl_pkts.
    if name != 'pkt_txn':
        return '\n    XAie_CloseControlCodeFile(&DevInst);\n'
    return ''


def cert_code_header(config_name: str) -> str:
    """Return the opening of the CERT ``main()`` program.

    ``config_name`` exists for signature parity with the other backends'
    headers and is ignored.
    """
    del config_name
    return '\n'.join([
        '',
        'int main() {',
        '    XAie_DevInst DevInst = Device_Configure_Intialization();',
        '',
    ])


def cert_code_footer() -> str:
    """Return the closing of the CERT ``main()``: save the layer params and return 0."""
    body = [
        '    DmaBins bins;',
        '    bins.layer_params = layer_params;',
        '    bins.save();',
        '    return 0;',
        '}',
    ]
    return '\n' + '\n'.join(body) + '\n'

def cert_bd_id_offset_adjust(bd_id: int, tile_type: TileType) -> int:
    """Return the BD id as used by CERT: memtile ids are reduced modulo 16; others pass through."""
    return bd_id % 16 if tile_type == TileType.Memtile else bd_id


def cert_patch_ddr_addr(cfg: BdConfig) -> str:
    """Return DDR-address patch code for a shim BD; the CERT backend emits none.

    On AIE4, ``xrt_configure_buffer_descriptor`` emits the address patch
    inline with the BD write.
    """
    assert cfg.bd.aie_dma.tile.type == TileType.Shim
    return ''


def cert_enqueue_task(cfg_chain: List[BdConfig], repeat_count: int) -> str:
    """Emit a commented ``XAie_DmaChannelSetStartQueue`` call for a BD chain.

    Only the first descriptor of ``cfg_chain`` is consulted. Its DMA gives
    the tile and channel, and its id (after ``cert_bd_id_offset_adjust``)
    is the start BD. ``cfg_chain`` must be non-empty.
    """
    head = cfg_chain[0].bd
    dma = head.aie_dma
    tile, channel = dma.tile, dma.channel
    comment = (
        f'    // {head.name} | Col: {tile.col} | Channel_ID: {channel.id} '
        f'| BD_ID: {head.id} | Repeat Count: {repeat_count}'
    )
    tile_loc = f'XAie_TileLoc({tile.col}, {xrt_row_offset(tile.type)} + {tile.row})'
    call = (
        f'    XRT_ERRCHK(XAie_DmaChannelSetStartQueue(&DevInst, {tile_loc}, '
        f'{channel.id}, {xrt_dma_dir(channel.dir)}, '
        f'{cert_bd_id_offset_adjust(head.id, tile.type)}, {repeat_count}, XAIE_DISABLE));'
    )
    return f'\n{comment}\n{call}\n'


class CertAsmOpVisitor(ControlOpVisitor):
    """Render control ops for the CertAsm backend.

    Most ops reuse the ``xrt_*`` generators. DDR-address patching and task
    enqueueing use the ``cert_*`` variants, and remote barriers go through
    ``generate_remote_barrier``.
    """
    __slots__ = ()

    def visit_config_buffer_descriptor(self, op):
        """Delegate to ``xrt_configure_buffer_descriptor`` for ``op.cfg``."""
        return xrt_configure_buffer_descriptor(op.cfg)

    def visit_patch_ddr_addr(self, op):
        """Delegate to ``cert_patch_ddr_addr`` for ``op.cfg``."""
        return cert_patch_ddr_addr(op.cfg)

    def visit_set_lock_value(self, op):
        """Delegate to ``xrt_init_lock`` for ``op.lock``."""
        return xrt_init_lock(op.lock)

    def visit_enqueue_task(self, op):
        """Delegate to ``cert_enqueue_task`` with the BD chain and repeat count."""
        return cert_enqueue_task(op.cfg_chain, op.repeat_count)

    def visit_wait_dma_done(self, op):
        """Delegate to ``xrt_wait_dma_completion`` for ``op.dma``."""
        return xrt_wait_dma_completion(op.dma)

    # NOTE: the misspelled name is the method the visitor interface dispatches to.
    def visit_aqcuire_lock(self, op):
        """Delegate to ``xrt_acquire_lock`` for ``op.lock``."""
        return xrt_acquire_lock(op.lock)

    def visit_remote_barrier(self, op):
        """Delegate to ``generate_remote_barrier`` with the barrier id."""
        return generate_remote_barrier(op.id)


def append_code(op, visitor, code_list):
    '''
    Render ``op`` with ``visitor`` and append the text to ``code_list``.

    Ops whose class is named ``ConfigBufferDescriptor`` are skipped when
    identical text is already anywhere in ``code_list``. Those repeats come
    from reuse of lock BDs in Optimization Phase 1.1. Every other op is
    always appended.
    '''
    rendered = op.apply(visitor)
    is_bd_config = type(op).__name__ == "ConfigBufferDescriptor"
    if is_bd_config and rendered in code_list:
        return
    code_list.append(rendered)


@dataclass
class ChannelGroup:
    """Channel ids of a single memtile column, split into unicast and broadcast."""
    # Channel ids classified as unicast; each instance gets its own list.
    unicast: list[int] = field(default_factory=list)
    # Channel ids classified as broadcast; each instance gets its own list.
    broadcast: list[int] = field(default_factory=list)


@dataclass
class MemtileChannelMap:
    """Per-column memtile channel classification, keyed by memtile column index."""
    # column index -> ChannelGroup; starts empty for every instance.
    columns: dict[int, ChannelGroup] = field(default_factory=dict)


def map_memtile_channels(dma_connections: list[DmaConnection]) -> MemtileChannelMap:
    """
    Classify memtile MM2S channels per column as unicast or broadcast.

    A connection is considered only when its read side is a Memtile MM2S
    channel and its write side is an S2MM channel on a Core or Shim tile.
    For such a connection:
      - a Shim destination, or a Core destination on channel 0, marks the
        memtile channel as unicast;
      - a Core destination on channel 1 marks it as broadcast;
      - any other destination channel id is ignored, but the memtile column
        still gets an entry (possibly with empty lists).

    Each column's unicast and broadcast lists are deduplicated and sorted
    before returning.
    """
    channel_map = MemtileChannelMap()

    def feeds_core_or_shim(conn) -> bool:
        src, dst = conn.read_dma, conn.write_dma
        return (
            src.tile.type == TileType.Memtile
            and src.channel.dir == DmaDir.MM2S
            and dst.channel.dir == DmaDir.S2MM
            and dst.tile.type in (TileType.Core, TileType.Shim)
        )

    for conn in filter(feeds_core_or_shim, dma_connections):
        src, dst = conn.read_dma, conn.write_dma
        group = channel_map.columns.setdefault(src.tile.col, ChannelGroup())
        if dst.tile.type == TileType.Shim or dst.channel.id == 0:
            group.unicast.append(src.channel.id)
        elif dst.channel.id == 1:
            group.broadcast.append(src.channel.id)

    for group in channel_map.columns.values():
        group.unicast = sorted(set(group.unicast))
        group.broadcast = sorted(set(group.broadcast))

    return channel_map


def generate_dma_padding(
    shape: OverlayShape,
    dma_connections: list[DmaConnection],
    backend: BackEnd,
    pad_values: DmaPaddingMap,
    col: int
) -> str:
    """Generate API calls that set the DMA pad value on memtile MM2S channels.

    For every channel reported by ``map_memtile_channels`` in the selected
    columns (unicast first, then broadcast), one comment line and one API
    call are emitted.

    Args:
        shape: overlay shape; selects the columns in single-uC mode.
        dma_connections: connections scanned for memtile MM2S channels.
        backend: TxnHostPatch/CertAsm emit ``XAie_DmaSetPadValue``;
            Adf emits ``adf::configurePadValue``.
        pad_values: padding settings. Nothing is emitted unless its
            ``enable_dma_pad`` attribute is truthy.
        col: the only column processed in multi-uC mode.

    Returns:
        ``""`` when padding is disabled; otherwise the generated code,
        which always starts with a newline.

    Raises:
        ValueError: if ``backend`` is unsupported and some channel needs padding.
    """
    if not getattr(pad_values, "enable_dma_pad", False):
        return ""

    channel_map = map_memtile_channels(dma_connections)
    pad_value = pad_values.pad_value

    if config.IS_MULTI_UC:
        columns = range(col, col + 1)
    else:
        columns = range(shape.start_col, shape.start_col + shape.num_cols)

    chunks: List[str] = ["\n"]
    for column in columns:
        group = channel_map.columns.get(column)
        if group is None:
            continue

        tagged_channels = [("unicast", ch) for ch in group.unicast]
        tagged_channels += [("broadcast", ch) for ch in group.broadcast]
        for kind, ch in tagged_channels:
            chunks.append(
                f"\n    // DMA Padding | Col: {column} | Ch_Type: {kind.capitalize()} | "
                f"Channel_ID: {ch} | Pad_Value: {pad_value}\n"
            )
            if backend in (BackEnd.TxnHostPatch, BackEnd.CertAsm):
                chunks.append(f"    XAie_DmaSetPadValue(&DevInst, XAie_TileLoc({column}, 1), {ch}, {pad_value});\n")
            elif backend == BackEnd.Adf:
                chunks.append(f"    adf::configurePadValue(adf::memory_tile, {column}, 0, {ch}, {pad_value});\n")
            else:
                raise ValueError(f"Invalid BackEnd: {backend}")

    return "".join(chunks)


def generate_end_page() -> str:
    """Emit the ``XAie_EndPage`` call, followed by a newline and a 4-space indent."""
    return '\n    XRT_ERRCHK(XAie_EndPage(&DevInst));\n    '


def generate_start_job() -> str:
    """Emit the ``XAie_StartNewJob`` call, followed by a newline and a 4-space indent."""
    return '\n    XRT_ERRCHK(XAie_StartNewJob(&DevInst));\n    '


def generate_attach_to_group(uc_index: int) -> str:
    """Emit ``XAie_AttachToGroup`` for uC ``uc_index``; return ``''`` outside multi-uC mode."""
    if not config.IS_MULTI_UC:
        return ''
    return f'\n    XRT_ERRCHK(XAie_AttachToGroup(&DevInst, {uc_index}));\n    '


def generate_remote_barrier(rb_id: int) -> str:
    """Emit ``XAie_RemoteBarrier`` for barrier ``rb_id``; return ``''`` outside multi-uC mode."""
    if not config.IS_MULTI_UC:
        return ''
    # Third argument is the constant 0x15, rendered in decimal (21);
    # NOTE(review): its meaning is defined by the XAie API, confirm there.
    barrier_arg = 0x15
    return f'\n    XRT_ERRCHK(XAie_RemoteBarrier(&DevInst, {rb_id}, {barrier_arg}));'


def generate_runtime_control(
    shape: OverlayShape,
    dma_connections: List[DmaConnection],
    layer_control: list[LayerControl],
    data_buffers: List[DataBuffer],
    core_instrs: Union[List[Type[CoreInstr]], Dict[AieTile, List[Type[CoreInstr]]]],
    kernel_names: Union[List[str], Dict[str, int]],
    back_end: BackEnd,
    config_name: str,
    dma_padding_map: DmaPaddingMap,
) -> str:
    '''
    Generate the host-side runtime control program for ``back_end``.

    The output is the concatenation of:
      * ``#define FAST_PM`` when ``config.ENABLE_FAST_PM`` is set;
      * the backend's host runtime source files (read from ``host_runtime/``);
      * the backend header, the layer parameter buffer, and one comment
        per data buffer;
      * for every uC program, a ``pkt_txn`` section followed by an
        ``lx7_txn`` section. The ``pkt_txn`` section holds the group
        attach, DMA padding (AIE4 only) and control packet ops. The
        ``lx7_txn`` section holds startup ops (TxnHostPatch only), the
        dataflow phases and the final barrier;
      * the backend footer.

    Args:
        shape: overlay shape. Defines the core tiles and, in multi-uC
            mode, the columns that each get their own uC program.
        dma_connections: forwarded to DMA padding generation.
        layer_control: per-uC control ops. In multi-uC mode it is indexed
            by column; otherwise only ``layer_control[0]`` is read.
        data_buffers: buffers described by comments in the output.
        core_instrs: one instruction list shared by every core, or a
            mapping from each core tile to its own list.
        kernel_names: kernel list / id map used to encode layer params.
        back_end: ``BackEnd.Adf``, ``BackEnd.TxnHostPatch`` or
            ``BackEnd.CertAsm``.
        config_name: passed to the backend's header generator.
        dma_padding_map: DMA pad value settings.

    Returns:
        The generated source as a single string.

    Raises:
        ValueError: if ``back_end`` is not a supported backend.
    '''
    core_tiles = [
        AieTile(TileType.Core, col, row)
        for col in range(shape.start_col, shape.start_col + shape.num_cols)
        for row in range(shape.start_row, shape.start_row + shape.num_rows)
    ]
    if isinstance(core_instrs, list):
        # Every core runs the same program: encode it once and share it.
        shared_param = generate_layer_params(core_instrs, kernel_names)
        layer_params = {tile: shared_param for tile in core_tiles}
    else:
        layer_params = {
            tile: generate_layer_params(core_instrs[tile], kernel_names)
            for tile in core_tiles
        }

    # Set up backend specific functions
    if back_end == BackEnd.Adf:
        visitor = AdfControlOpVisitor()
        code_header = adf_code_header
        code_footer = adf_code_footer
        start_txn = adf_start_txn
        export_txn = adf_export_txn
        create_param = adf_create_param_buffer
        has_startup_control = False
        host_runtime_files = ['txn.hpp', 'dma.hpp']
    elif back_end == BackEnd.TxnHostPatch:
        visitor = TxnControlOpVisitor()
        code_header = txn_code_header
        code_footer = txn_host_patch_code_footer
        start_txn = xrt_start_txn
        export_txn = xrt_export_txn
        create_param = txn_create_param_buffer
        has_startup_control = True
        host_runtime_files = ['txn.hpp', 'dma.hpp']
    elif back_end == BackEnd.CertAsm:
        visitor = CertAsmOpVisitor()
        code_header = cert_code_header
        code_footer = cert_code_footer
        start_txn = cert_start_asm
        export_txn = cert_export_asm
        create_param = txn_create_param_buffer
        has_startup_control = False
        host_runtime_files = ['aie4_dma.cpp']
    else:
        # Explicit raise (not ``assert False``) so that running under -O
        # cannot fall through to an unbound ``visitor``.
        raise ValueError(f"Invalid BackEnd: {back_end}")

    currdir = os.path.dirname(__file__)

    code = []

    if config.ENABLE_FAST_PM:
        code.append('#define FAST_PM\n')

    # Inline the backend's host runtime sources ahead of the generated code.
    for file in host_runtime_files:
        impl_filename = os.path.abspath(os.path.join(currdir, 'host_runtime', file))
        with open(impl_filename, 'r') as f:
            code.append(f.read())

    code.append(code_header(config_name))

    code.append(create_param(shape, layer_params))

    for buffer in data_buffers:
        code.append(generate_buffer_comment(buffer))

    # Multi-uC: one program per overlay column. Otherwise a single program
    # built from layer_control[0].
    if config.IS_MULTI_UC:
        uc_cols = range(shape.start_col, shape.start_col + shape.num_cols)
    else:
        uc_cols = range(0, 1)

    # Earliest phase in which any uC has work; every uC starts emitting from
    # there. Each uC also contributes len(phases) so that a uC with only
    # empty phases does not lower the minimum.
    iters = []
    for col in uc_cols:
        phases = layer_control[col].dataflow_phases
        iters.extend(i for i, phase in enumerate(phases) if len(phase) > 0)
        iters.append(len(phases))
    # default=0 only matters when there are no uC columns; then the loop
    # below does not run and start_iter is unused.
    start_iter = min(iters, default=0)

    for col in uc_cols:
        # uC indices are twice the column index.
        uc_index = col * 2
        uc_control = layer_control[col]

        code.append(start_txn('pkt_txn', f"uc{uc_index}.asm"))
        code.append(generate_attach_to_group(uc_index))

        # DmaPadding
        if config.DEV_GEN == DevGen.Aie4:
            code.append(generate_dma_padding(shape, dma_connections, back_end, dma_padding_map, col))

        for op in uc_control.control_pkts:
            append_code(op, visitor, code)
        code.append(export_txn('pkt_txn'))

        code.append(start_txn('lx7_txn'))
        if has_startup_control:
            for op in uc_control.startup_control:
                append_code(op, visitor, code)

        phases = uc_control.dataflow_phases
        for phase_idx in range(start_iter, len(phases)):
            code.append(generate_task_iter_comment(phase_idx))
            for op in phases[phase_idx]:
                append_code(op, visitor, code)
        for op in uc_control.final_barrier:
            append_code(op, visitor, code)

        code.append(export_txn('lx7_txn'))

    code.append(code_footer())

    return ''.join(code)


################################################################################
#
# Code Generation for Overlay Graph
#
################################################################################


class GraphAllocator:
    """Assign ADF graph port indices to the DMA channels of the overlay's tiles.

    For every memtile, shim and core tile in ``shape``, input (S2MM) and
    output (non-S2MM) channels receive consecutive indices in the order
    they are first allocated. Allocating the same DMA again is a no-op.
    """
    __slots__ = ('shape',
                 'tile_input_counts', 'tile_output_counts',
                 'tile_input_ports', 'tile_output_ports',
                 'channel_idxs')

    def __init__(self, shape: OverlayShape):
        self.shape = shape
        self.tile_input_counts: Dict[AieTile, int] = {}
        self.tile_output_counts: Dict[AieTile, int] = {}
        self.tile_input_ports: Dict[AieTile, List[Tuple[int, DmaChannel]]] = {}
        self.tile_output_ports: Dict[AieTile, List[Tuple[int, DmaChannel]]] = {}
        self.channel_idxs: Dict[AieDma, int] = {}

        core_rows = range(shape.start_row, shape.start_row + shape.num_rows)
        for col in range(shape.start_col, shape.start_col + shape.num_cols):
            # Per column: the memtile, the shim, then each core row.
            column_tiles = [AieTile(TileType.Memtile, col, 0), AieTile(TileType.Shim, col, 0)]
            column_tiles.extend(AieTile(TileType.Core, col, row) for row in core_rows)
            for tile in column_tiles:
                self.tile_input_counts[tile] = 0
                self.tile_output_counts[tile] = 0
                self.tile_input_ports[tile] = []
                self.tile_output_ports[tile] = []

    def alloc_dma(self, dma: AieDma):
        """Give ``dma`` the next free port index on its tile, unless it already has one."""
        if dma in self.channel_idxs:
            return
        if dma.channel.dir == DmaDir.S2MM:
            counts, ports = self.tile_input_counts, self.tile_input_ports
        else:
            counts, ports = self.tile_output_counts, self.tile_output_ports
        idx = counts[dma.tile]
        counts[dma.tile] = idx + 1
        ports[dma.tile].append((idx, dma.channel))
        self.channel_idxs[dma] = idx

    def input_ports(self, tile: AieTile) -> List[Tuple[int, DmaChannel]]:
        """Return the ``(index, channel)`` pairs allocated as inputs on ``tile``."""
        return self.tile_input_ports[tile]

    def output_ports(self, tile: AieTile) -> List[Tuple[int, DmaChannel]]:
        """Return the ``(index, channel)`` pairs allocated as outputs on ``tile``."""
        return self.tile_output_ports[tile]

    def port_index(self, dma: AieDma) -> int:
        """Return the port index previously allocated to ``dma``."""
        return self.channel_idxs[dma]

    def param_index(self, tile: AieTile) -> int:
        """Return the number of input ports allocated on core ``tile``."""
        assert tile.type == TileType.Core
        return self.tile_input_counts[tile]


def adf_graph_header(
    alloc: GraphAllocator,
) -> str:
    """Emit the opening of the ``ComputeGraph`` ADF class.

    Declares the following members:
      * one ``input_gmio`` per shim output (MM2S) port;
      * one ``output_gmio`` per shim input (S2MM) port;
      * one kernel per core tile;
      * one shared buffer per memtile column.

    On AIE2P, packet-control and packet-split members are declared too. On
    AIE4 they are replaced by placeholder comments. The returned text ends
    inside the constructor body, which ``adf_graph_footer`` closes.
    """
    shape = alloc.shape
    col_range = range(shape.start_col, shape.start_col + shape.num_cols)
    row_range = range(shape.start_row, shape.start_row + shape.num_rows)
    sep = '\n    '

    input_gmios = sep.join(
        f'adf::input_gmio gmio_in_col{col}_ch{ch.id};'
        for col in col_range
        for _, ch in alloc.output_ports(AieTile(TileType.Shim, col, 0))
    )
    output_gmios = sep.join(
        f'adf::output_gmio gmio_out_col{col}_ch{ch.id};'
        for col in col_range
        for _, ch in alloc.input_ports(AieTile(TileType.Shim, col, 0))
    )
    cores = sep.join(
        f'adf::kernel core_col{col}_row{row};'
        for col in col_range
        for row in row_range
    )
    memtiles = sep.join(f'adf::shared_buffer<int8_t> memtile_col{col};' for col in col_range)

    if config.DEV_GEN == DevGen.Aie2p:
        core_controls = sep.join(
            f'adf::pktcontrol core_pktcontrol_col{col}_row{row};'
            for col in col_range
            for row in row_range
        )
        memtile_controls = sep.join(f'adf::pktcontrol memtile_pktcontrol_col{col};' for col in col_range)
        shim_controls = sep.join(f'adf::pktcontrol shim_pktcontrol_col{col};' for col in col_range)
        shim_splits = sep.join(
            f'adf::pktsplit<{config.NUM_CTRL_PKT_SPLIT}> shim_pktsplit_col{col};'
            for col in col_range
        )
    elif config.DEV_GEN == DevGen.Aie4:
        core_controls = '// Core Control Removed'
        memtile_controls = '// Memtile Control Removed'
        shim_controls = '// Shim Control Removed'
        shim_splits = '// Shim Splits Removed'
    else:
        assert False

    return f'''
class ComputeGraph : public adf::graph
{{
public:
    {input_gmios}
    {output_gmios}
private:
    {cores}
    {memtiles}
    {core_controls}
    {memtile_controls}
    {shim_controls}
    {shim_splits}

public:
    ComputeGraph()
    {{
'''


def adf_graph_footer() -> str:
    """Close the ``ComputeGraph`` constructor body and the class declaration."""
    return '\n'.join(['', '    }', '};', ''])


def adf_create_core(
    alloc: GraphAllocator,
    tile: AieTile,
    kernel_entry_point: str,
    kernel_source_filename: str,
    kernel_stack_addr: int
) -> str:
    """Emit the constructor statements for one core kernel.

    Creates the kernel, sets its source file, runtime ratio, tile location
    and stack address, and pins each allocated input/output port to its
    DMA channel with a ``{16}`` dimension. On AIE2P a packet-control node
    is also created on the tile.
    """
    assert tile.type == TileType.Core
    kernel = f'core_col{tile.col}_row{tile.row}'
    sep = '\n        '

    def port_constraints(side: str, ports) -> str:
        stmts = []
        for idx, ch in ports:
            port = f'{kernel}.{side}[{idx}]'
            stmts.append(f'adf::location<adf::dma>({port}) = adf::dma_channel(adf::aie_tile, {tile.col}, {tile.row}, {ch.id});')
            stmts.append(f'adf::dimensions({port}) = {{16}};')
        return sep.join(stmts)

    input_constraints = port_constraints('in', alloc.input_ports(tile))
    output_constraints = port_constraints('out', alloc.output_ports(tile))

    if config.DEV_GEN == DevGen.Aie2p:
        ctrl = f'core_pktcontrol_col{tile.col}_row{tile.row}'
        core_controls = sep.join([
            f'{ctrl} = adf::pktcontrol::create();',
            f'adf::location<adf::interconnect>({ctrl}) = adf::tile(adf::aie_tile, {tile.col}, {tile.row});',
        ])
    elif config.DEV_GEN == DevGen.Aie4:
        core_controls = '// Core Control Removed'
    else:
        assert False

    return f'''
        {kernel} = adf::kernel::create({kernel_entry_point});
        adf::source({kernel}) = "{kernel_source_filename}";
        adf::runtime<adf::ratio>({kernel}) = 1.0;
        adf::location<adf::kernel>({kernel}) = adf::tile({tile.col}, {tile.row});
        adf::location<adf::stack>({kernel}) = adf::address({tile.col}, {tile.row}, {kernel_stack_addr});
        {input_constraints}
        {output_constraints}
        {core_controls}
'''


def adf_create_memtile(
    alloc: GraphAllocator,
    tile: AieTile
) -> str:
    """Emit the constructor statements for one memtile shared buffer.

    Creates the shared buffer with one port per allocated input/output
    channel and pins each port to its DMA channel. On AIE2P a
    packet-control node is also created on the memtile.
    """
    assert tile.type == TileType.Memtile
    buffer = f'memtile_col{tile.col}'
    sep = '\n        '
    in_ports = alloc.input_ports(tile)
    out_ports = alloc.output_ports(tile)

    def dma_location(side: str, idx: int, channel) -> str:
        return (f'adf::location<adf::dma>({buffer}.{side}[{idx}]) = '
                f'adf::dma_channel(adf::memory_tile, {tile.col}, {tile.row}, {channel.id});')

    input_constraints = sep.join(dma_location('in', idx, ch) for idx, ch in in_ports)
    output_constraints = sep.join(dma_location('out', idx, ch) for idx, ch in out_ports)

    if config.DEV_GEN == DevGen.Aie2p:
        ctrl = f'memtile_pktcontrol_col{tile.col}'
        memtile_controls = sep.join([
            f'{ctrl} = adf::pktcontrol::create();',
            f'adf::location<adf::interconnect>({ctrl}) = adf::tile(adf::memory_tile, {tile.col}, {tile.row});',
        ])
    elif config.DEV_GEN == DevGen.Aie4:
        memtile_controls = '// Memtile Control Removed'
    else:
        assert False

    return f'''
        {buffer} = adf::shared_buffer<int8_t>::create({{16}}, {len(in_ports)}, {len(out_ports)});
        {input_constraints}
        {output_constraints}
        {memtile_controls}
'''


def adf_create_shim(
    alloc: GraphAllocator,
    tile: AieTile
) -> str:
    """Emit the constructor statements for one shim tile's GMIOs.

    Each shim output (MM2S) port becomes an ``input_gmio`` and each shim
    input (S2MM) port becomes an ``output_gmio``, pinned to its DMA
    channel. On AIE2P the shim MM2S control channel must be present. That
    channel feeds a packet split, which fans out to the shim, memtile and
    core packet-control nodes.

    Raises:
        ValueError: on AIE2P, if the shim control MM2S channel is unused.
    """
    assert tile.type == TileType.Shim
    mm2s_ids = [ch.id for _, ch in alloc.output_ports(tile)]
    if (config.DEV_GEN == DevGen.Aie2p) and (config.SHIM_CTRL_MM2S_CHANNEL_ID not in mm2s_ids):
        raise ValueError('Invalid graph connections, '
                         f'must use shim mm2s ch{config.SHIM_CTRL_MM2S_CHANNEL_ID}!')

    sep = '\n        '

    def gmio_stmts(prefix: str, gmio_type: str, endpoint: str, ports) -> str:
        stmts = []
        for _, ch in ports:
            gmio = f'{prefix}_col{tile.col}_ch{ch.id}'
            stmts.extend((
                f'{gmio} = adf::{gmio_type}::create(256, 8);',
                f'adf::location<adf::GMIO>({gmio}) = adf::shim({tile.col});',
                f'adf::location<adf::dma>({gmio}.{endpoint}) = adf::dma_channel(adf::shim_tile, {tile.col}, {tile.row}, {ch.id});',
            ))
        return sep.join(stmts)

    input_gmios = gmio_stmts('gmio_in', 'input_gmio', 'out[0]', alloc.output_ports(tile))
    output_gmios = gmio_stmts('gmio_out', 'output_gmio', 'in[0]', alloc.input_ports(tile))

    if config.DEV_GEN == DevGen.Aie2p:
        pktctrl = f'shim_pktcontrol_col{tile.col}'
        split = f'shim_pktsplit_col{tile.col}'
        shim_controls = sep.join([
            f'{pktctrl} = adf::pktcontrol::create();',
            f'adf::location<adf::interconnect>({pktctrl}) = adf::tile(adf::shim_tile, {tile.col}, {tile.row});',
        ])
        shim_split = f'{split} = adf::pktsplit<{config.NUM_CTRL_PKT_SPLIT}>::create();'
        cxn_stmts = [
            f'adf::connect(gmio_in_col{tile.col}_ch{config.SHIM_CTRL_MM2S_CHANNEL_ID}.out[0], {split}.in[0]);',
            f'adf::connect({split}.out[{config.SHIM_CTRL_PKT_SPLIT_IDX}], {pktctrl}.in[0]);',
            f'adf::connect({split}.out[{config.MEMTILE_CTRL_PKT_SPLIT_IDX}], memtile_pktcontrol_col{tile.col}.in[0]);',
        ]
        # With FAST_PM the core split index is looked up per row;
        # otherwise a single index is shared by every core row.
        fast_pm = config.ENABLE_FAST_PM
        for row in range(alloc.shape.start_row, alloc.shape.start_row + alloc.shape.num_rows):
            split_idx = config.CORE_CTRL_PKT_SPLIT_IDX[row] if fast_pm else config.CORE_CTRL_PKT_SPLIT_IDX
            cxn_stmts.append(f'adf::connect({split}.out[{split_idx}], core_pktcontrol_col{tile.col}_row{row}.in[0]);')
        split_cxns = sep.join(cxn_stmts)
    elif config.DEV_GEN == DevGen.Aie4:
        shim_controls = '// Shim Control Removed'
        shim_split = '// Shim Split Removed'
        split_cxns = '// Split Connections Removed'
    else:
        assert False

    return f'''
        {input_gmios}
        {output_gmios}
        {shim_controls}
        {shim_split}
        {split_cxns}
'''


def adf_create_connection(
    alloc: GraphAllocator,
    dma_connection: DmaConnection
) -> str:
    """Emit the ``adf::connect`` statement for one DMA-to-DMA connection.

    Both endpoints are resolved to graph port expressions using the port
    indices assigned by ``alloc``. The read DMA is the source and the
    write DMA is the destination.
    """
    def endpoint(dma: AieDma) -> str:
        idx = alloc.port_index(dma)
        tile = dma.tile
        is_input = dma.channel.dir == DmaDir.S2MM
        side = 'in' if is_input else 'out'
        if tile.type == TileType.Core:
            return f'core_col{tile.col}_row{tile.row}.{side}[{idx}]'
        if tile.type == TileType.Memtile:
            return f'memtile_col{tile.col}.{side}[{idx}]'
        # Remaining tiles are shims, addressed through their GMIO objects.
        if is_input:
            return f'gmio_out_col{tile.col}_ch{dma.channel.id}.in[0]'
        if (config.DEV_GEN == DevGen.Aie2p) and (dma.channel.id == config.SHIM_CTRL_MM2S_CHANNEL_ID):
            # On AIE2P the control MM2S channel goes through the packet split.
            return f'shim_pktsplit_col{tile.col}.out[{config.DATA_TRANSFER_PKT_SPLIT_IDX}]'
        return f'gmio_in_col{tile.col}_ch{dma.channel.id}.out[0]'

    src = endpoint(dma_connection.read_dma)
    dst = endpoint(dma_connection.write_dma)
    return f'\n        adf::connect({src}, {dst});\n'


def adf_create_casc_connection(
    alloc: GraphAllocator,
    src_core: AieTile,
    dst_core: AieTile,
) -> str:
    """Emit the ``adf::connect`` statement for a cascade link between two cores."""
    assert src_core.type == TileType.Core
    assert dst_core.type == TileType.Core
    # NOTE: Cascade ports will be located after
    # the last DMA port in the function argument list
    out_idx = alloc.tile_output_counts[src_core]
    in_idx = alloc.tile_input_counts[dst_core]
    src_port = f'core_col{src_core.col}_row{src_core.row}.out[{out_idx}]'
    dst_port = f'core_col{dst_core.col}_row{dst_core.row}.in[{in_idx}]'
    return f'\n        adf::connect({src_port}, {dst_port});\n'


def adf_create_core_connection(
    alloc: GraphAllocator,
    cxn: CoreConnection,
    shape: OverlayShape,
    casc_dir: Optional[CascDir],
) -> str:
    """Emit the ``adf::connect`` statement for a direct core-to-core link.

    The port index skips past the DMA ports and, when the core has one in
    ``casc_dir``, past the cascade port.
    """
    src, dst = cxn.src_core, cxn.dst_core
    if casc_dir == CascDir.Horizontal:
        src_has_casc_out = src.col < shape.start_col + shape.num_cols - 1
        dst_has_casc_in = dst.col > shape.start_col
    elif casc_dir == CascDir.Vertical:
        src_has_casc_out = src.row > shape.start_row
        dst_has_casc_in = dst.row < shape.start_row + shape.num_rows - 1
    else:
        src_has_casc_out = False
        dst_has_casc_in = False
    # NOTE: Core ports will be located after
    # the last cascade port in the function argument list
    out_idx = alloc.tile_output_counts[src] + int(src_has_casc_out)
    in_idx = alloc.tile_input_counts[dst] + int(dst_has_casc_in)
    src_port = f'core_col{src.col}_row{src.row}.out[{out_idx}]'
    dst_port = f'core_col{dst.col}_row{dst.row}.in[{in_idx}]'
    return f'\n        adf::connect({src_port}, {dst_port});\n'


def generate_overlay_graph(
    shape: OverlayShape,
    dma_connections: List[DmaConnection],
    kernel_source_filename: str = 'super.cc',
    core_stack_addr: int = 61440,
    casc_dir: Optional[CascDir] = None,
    core_connections: Optional[List[CoreConnection]] = None,
) -> str:
    """Generate the ADF ``ComputeGraph`` class for the overlay.

    Ports are first allocated for every DMA connection. The function then
    emits the following, in order:
      * one kernel per core, whose entry point is chosen by its port
        counts and cascade / core-stream flags;
      * the memtiles and the shims;
      * the DMA connections;
      * the cascade links along ``casc_dir``;
      * the direct core-to-core connections.

    Args:
        shape: overlay shape.
        dma_connections: DMA-to-DMA connections to realise in the graph.
        kernel_source_filename: source file assigned to every core kernel.
        core_stack_addr: stack address for every core; must lie in
            ``[0, config.MAX_CORE_ADDR]``.
        casc_dir: cascade direction, or ``None`` for no cascade.
        core_connections: direct core-to-core stream connections. ``None``
            (the default) means no connections.

    Returns:
        The generated C++ source for the graph class.
    """
    assert 0 <= core_stack_addr <= config.MAX_CORE_ADDR
    # A None default avoids sharing one mutable list across calls.
    if core_connections is None:
        core_connections = []

    alloc = GraphAllocator(shape)
    for cxn in dma_connections:
        alloc.alloc_dma(cxn.read_dma)
        alloc.alloc_dma(cxn.write_dma)

    first_col, last_col = shape.start_col, shape.start_col + shape.num_cols - 1
    first_row, last_row = shape.start_row, shape.start_row + shape.num_rows - 1
    cols = range(first_col, last_col + 1)
    rows = range(first_row, last_row + 1)

    parts = [adf_graph_header(alloc)]

    src_cores = {cxn.src_core for cxn in core_connections}
    dst_cores = {cxn.dst_core for cxn in core_connections}

    for col in cols:
        for row in rows:
            tile = AieTile(TileType.Core, col, row)
            has_casc_in = (
                ((casc_dir == CascDir.Horizontal) and (col > first_col)) or
                ((casc_dir == CascDir.Vertical) and (row < last_row))
            )
            has_casc_out = (
                ((casc_dir == CascDir.Horizontal) and (col < last_col)) or
                ((casc_dir == CascDir.Vertical) and (row > first_row))
            )
            entry_point = generate_super_kernel_name(
                len(alloc.input_ports(tile)), len(alloc.output_ports(tile)),
                has_casc_in, has_casc_out,
                tile in dst_cores, tile in src_cores,
            )
            parts.append(adf_create_core(alloc, tile, entry_point,
                                         kernel_source_filename, core_stack_addr))

    for col in cols:
        parts.append(adf_create_memtile(alloc, AieTile(TileType.Memtile, col, 0)))

    for col in cols:
        parts.append(adf_create_shim(alloc, AieTile(TileType.Shim, col, 0)))

    for cxn in dma_connections:
        parts.append(adf_create_connection(alloc, cxn))

    if casc_dir == CascDir.Horizontal:
        # Each core cascades into its right-hand neighbour.
        for col in range(first_col, last_col):
            for row in rows:
                parts.append(adf_create_casc_connection(
                    alloc,
                    AieTile(TileType.Core, col, row),
                    AieTile(TileType.Core, col + 1, row),
                ))
    elif casc_dir == CascDir.Vertical:
        # Each core cascades into the row below it, from the top row down.
        for col in cols:
            for row in range(last_row, first_row, -1):
                parts.append(adf_create_casc_connection(
                    alloc,
                    AieTile(TileType.Core, col, row),
                    AieTile(TileType.Core, col, row - 1),
                ))

    for cxn in core_connections:
        parts.append(adf_create_core_connection(alloc, cxn, shape, casc_dir))

    parts.append(adf_graph_footer())

    return ''.join(parts)


################################################################################
#
# Code Generation for Super Kernel
#
################################################################################


def generate_super_kernel_include_directives(
    include_paths: List[str]
) -> str:
    """Emit the AIE API includes, a blank line, then one quoted include per path."""
    base_headers = (
        '#include <aie_api/aie.hpp>',
        '#include <aie_api/aie_adf.hpp>',
        '#include <aie_api/utils.hpp>',
    )
    user_headers = '\n'.join(f'#include "{path}"' for path in include_paths)
    return '\n' + '\n'.join(base_headers) + '\n\n' + user_headers

def generate_control_bd(param_channel_id: int):
    """Emit ``g_control_buffer_id`` and the single-entry ``g_control_bd`` initializer.

    The descriptor reads ``config.MAX_CORE_LAYER_PARAM_SIZE`` bytes,
    expressed in 32-bit words, on buffer ``param_channel_id``. Its
    ping address starts at 0 and its pong address is disabled.
    """
    assert 0 <= param_channel_id <= config.MAX_CORE_S2MM_DMA_CHANNEL
    members = [
        ('opcode', CoreInstr.BD_CONFIG_OP),
        ('repeat_count', 1),
        ('buffer_id', 'g_control_buffer_id'),
        ('ping_addr', 0),
        ('pong_addr', ConfigBuffer.DISABLE_PONG_ADDR),
        ('length', f'{config.MAX_CORE_LAYER_PARAM_SIZE} / sizeof(uint32_t)'),
        ('offset', 0),
        ('d0_step', 1),
        ('d1_step', 1),
        ('d2_step', 1),
        ('d0_wrap', 0),
        ('d1_wrap', 0),
    ]
    initializer = ',\n'.join(f'    .{name} = {value}' for name, value in members)
    # Plain (non-f) string parts keep their doubled braces literally.
    return (
        f'\nstatic constexpr unsigned g_control_buffer_id = {param_channel_id};\n'
        '\n'
        'alignas(4) static BdConfig g_control_bd[1] = {{\n'
        f'{initializer}\n'
        '}};'
    )

def generate_super_kernel_buffer_ports():
    """Emit the ``g_buffer_ports`` table, one ``BufferPort`` per core DMA channel.

    There is one entry per S2MM channel followed by one per MM2S channel.
    Entry ``i`` uses lock ids ``2i`` (DMA) and ``2i + 1`` (core), and BD ids
    ``2i`` (ping) and ``2i + 1`` (pong).
    """
    num_ports = (config.MAX_CORE_S2MM_DMA_CHANNEL + 1) + (config.MAX_CORE_MM2S_DMA_CHANNEL + 1)
    entries = []
    for port in range(num_ports):
        even, odd = 2 * port, 2 * port + 1
        # Arguments: dma_lock_id, core_lock_id, ping_bd_id, pong_bd_id
        entries.append(f'BufferPort({even}, {odd}, {even}, {odd})')
    body = ',\n    '.join(entries)
    return f'\nstatic BufferPort g_buffer_ports[{num_ports}] = {{\n    {body}\n}};'

def generate_super_kernel_fps(kernel_names: Union[List[str], Dict[str, int]]):
    '''
    Emit the ``g_kernel_fps`` function-pointer table for the super kernel.

    ``kernel_names`` may be either of:
      * a list: entry ``i`` fills table slot ``i``;
      * a dict mapping kernel name to kernel id. This form is mainly used
        when generating a combined xclbin with multiple PM IDs. Each name
        goes into the slot given by its id, and unused slots below the
        largest id are filled with ``0``. If two names share an id, the
        first one wins. An empty dict yields an empty initializer.

    Raises:
        RuntimeError: if there are more kernels than
            ``config.MAX_CORE_NUM_KERNELS``, or a dict id lies outside
            ``[0, config.MAX_CORE_NUM_KERNELS)``. Such an id would either
            overflow the fixed-size C array or be dropped silently.
    '''
    max_kernels = config.MAX_CORE_NUM_KERNELS
    if len(kernel_names) > max_kernels:
        raise RuntimeError('Invalid number of kernels!')

    if isinstance(kernel_names, list):
        kernels = list(kernel_names)
    else:
        by_id: Dict[int, str] = {}
        for name, kernel_id in kernel_names.items():
            if not 0 <= kernel_id < max_kernels:
                raise RuntimeError(f'Invalid kernel id {kernel_id} for kernel {name}!')
            by_id.setdefault(kernel_id, name)
        num_slots = max(by_id, default=-1) + 1
        kernels = [by_id.get(slot, '0') for slot in range(num_slots)]
    kernel_list = ',\n    '.join(kernels)

    return f'''
static KernelFp g_kernel_fps[{config.MAX_CORE_NUM_KERNELS}] = {{
    {kernel_list}
}};'''


def generate_super_kernel_name(
    num_inputs: int,
    num_outputs: int,
    has_casc_in: bool,
    has_casc_out: bool,
    has_core_in: bool,
    has_core_out: bool,
) -> str:
    """Build the super-kernel entry-point name from its port configuration.

    Example: ``super_kernel_in2_out1_casc_in0_out1_core_in0_out0``.
    """
    casc_in, casc_out, core_in, core_out = (
        int(flag) for flag in (has_casc_in, has_casc_out, has_core_in, has_core_out)
    )
    return (f'super_kernel_in{num_inputs}_out{num_outputs}'
            f'_casc_in{casc_in}_out{casc_out}'
            f'_core_in{core_in}_out{core_out}')


def generate_super_kernel_entry_point_decl(
    num_inputs: int,
    num_outputs: int,
    has_casc_in: bool,
    has_casc_out: bool,
    has_core_in: bool,
    has_core_out: bool,
) -> str:
    """Emit the C++ signature (no trailing ``;`` or body) of a super-kernel entry point.

    The parameters appear in this order:
      * ``num_inputs`` input async buffers, then optional cascade and
        core-stream inputs;
      * ``num_outputs`` output async buffers, then optional cascade and
        core-stream outputs.

    The layout assumes ``num_inputs`` and ``num_outputs`` are at least 1.
    """
    sep = ',\n    '
    inputs = sep.join(f'adf::input_async_buffer<int8_t>& buf_in{i}' for i in range(num_inputs))
    if has_casc_in:
        inputs += sep + 'input_cascade<acc32>* casc_in'
    if has_core_in:
        inputs += sep + 'input_stream_int8* core_in'

    outputs = sep.join(f'adf::output_async_buffer<int8_t>& buf_out{i}' for i in range(num_outputs))
    if has_casc_out:
        outputs += sep + 'output_cascade<acc32>* casc_out'
    if has_core_out:
        outputs += sep + 'output_stream_int8* core_out'

    name = generate_super_kernel_name(num_inputs, num_outputs, has_casc_in, has_casc_out, has_core_in, has_core_out)
    return f'\nvoid {name}(\n    {inputs},\n    {outputs})'


def generate_super_kernel_entry_point_impl(
    num_inputs: int,
    num_outputs: int,
    has_casc_in: bool,
    has_casc_out: bool,
    has_core_in: bool,
    has_core_out: bool
) -> str:
    '''
    Generate the full C++ definition of a super kernel entry point.

    The emitted body points control BD 0 at `g_layer_params` (address taken
    relative to BufferPort::CORE_LOCAL_ADDR_OFFSET and divided by
    sizeof(uint32_t)), configures, acquires and releases the control buffer
    port, and then calls super_kernel_loop with index 0.
    '''
    signature = generate_super_kernel_entry_point_decl(
        num_inputs, num_outputs, has_casc_in, has_casc_out, has_core_in, has_core_out)
    body_lines = [
        '{',
        '    int constexpr core_local_addr_offset = BufferPort::CORE_LOCAL_ADDR_OFFSET;',
        '    g_control_bd[0].ping_addr = (int(g_layer_params) - core_local_addr_offset) / sizeof(uint32_t);',
        '    g_buffer_ports[g_control_buffer_id].config(g_control_bd[0]);',
        '    g_buffer_ports[g_control_buffer_id].acquire();',
        '    g_buffer_ports[g_control_buffer_id].release();',
        '    super_kernel_loop(g_buffer_ports, g_kernel_fps, g_layer_params, 0);',
        '}',
    ]
    return signature + '\n' + '\n'.join(body_lines)


def _read_kernel_runtime_source(filename: str) -> str:
    '''Read `kernel_runtime/<filename>` next to this module as UTF-8, stripped of surrounding whitespace.'''
    path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'kernel_runtime', filename))
    # Explicit encoding so decoding does not depend on the host locale.
    with open(path, 'r', encoding='utf-8') as f:
        return f.read().strip()


def generate_super_kernel(
    include_paths: List[str],
    kernel_names: Union[List[str], Dict[str, int]],
    param_channel_id: int = 0
) -> Tuple[str, str]:
    '''
    Generate the super kernel header and source.

    The static runtime code is read from kernel_runtime/super.hh and
    kernel_runtime/super.cc (relative to this module). Around it, this emits
    one entry point declaration (header) and one definition (source) for
    every port configuration: 1..MAX_CORE_S2MM_DMA_CHANNEL+1 inputs,
    1..MAX_CORE_MM2S_DMA_CHANNEL+1 outputs, and each on/off combination of
    input/output cascade and input/output core stream.

    param_channel_id is forwarded to generate_control_bd.

    Returns (super_hh, super_cc) as strings.
    '''
    from itertools import product

    super_hh_impl = _read_kernel_runtime_source('super.hh')
    super_cc_impl = _read_kernel_runtime_source('super.cc')

    include_directives = generate_super_kernel_include_directives(include_paths)

    max_num_inputs = config.MAX_CORE_S2MM_DMA_CHANNEL + 1
    max_num_outputs = config.MAX_CORE_MM2S_DMA_CHANNEL + 1

    # All port configurations as (num_inputs, num_outputs, casc_in, casc_out,
    # core_in, core_out). Computed once so declarations and definitions are
    # guaranteed to cover the same variants in the same order.
    flag_values = (False, True)
    variants = list(product(
        range(1, max_num_inputs + 1),
        range(1, max_num_outputs + 1),
        flag_values, flag_values, flag_values, flag_values,
    ))

    entry_point_decls = '\n'.join(
        f'{generate_super_kernel_entry_point_decl(*variant)};'
        for variant in variants
    )

    control_bd = generate_control_bd(param_channel_id)

    buffer_ports = generate_super_kernel_buffer_ports()

    kernel_fps = generate_super_kernel_fps(kernel_names)

    layer_params = f'''
alignas(64) static char g_layer_params[{config.MAX_CORE_LAYER_PARAM_SIZE}] = {{0}};'''

    entry_point_impls = '\n'.join(
        generate_super_kernel_entry_point_impl(*variant)
        for variant in variants
    )

    super_hh = f'''#ifndef SUPER_HH
#define SUPER_HH

{super_hh_impl}
{entry_point_decls}

#endif // SUPER_HH
'''

    super_cc = f'''{include_directives}

{super_cc_impl}
{buffer_ports}
{kernel_fps}
{layer_params}
{control_bd}
{entry_point_impls}
'''
    return (super_hh, super_cc)
