import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
from typing import Tuple, List

from dmacompiler import \
    OverlayShape, BackEnd, \
    DataTransfer, SyncStrategy, \
    AieTile, TileType, \
    AieDma, DmaDir, TransferParams, core_dma, memtile_dma, shim_dma, \
    compute_buffer_size, \
    generate_core_buffer_config, \
    generate_transfer_params, \
    generate_shim_data_transfer, \
    run_layer_compilation, \
    set_dev_gen, DevGen, config

from dataflow_common import \
    overlay_8x4_dma_connections, \
    overlay_4x4_dma_connections, \
    overlay_stack_addr, \
    clean_overlay, \
    build_sim_overlay, \
    ceildiv, \
    shim_alloc, \
    prm_shim_memory, \
    conv_kernel_prm_shim_memory, \
    prm_shim_mm2s, \
    prm_memtile_memory, \
    conv_kernel_prm_memtile_memory, \
    prm_memtile_s2mm, \
    prm_memtile_mm2s

from conv_common import \
    ConvDims, \
    iceil, \
    X_index, \
    Co_index, \
    Xi_slice, \
    ifm_core_memory, \
    ifm_core_s2mm, \
    conv_preproc_directives, \
    conv_core_alloc, \
    conv_core_instrs
# Module-level side effects: these run on import, before any dataflow is built.
# Target the Aie2p device generation.
set_dev_gen(DevGen.Aie2p)
# NOTE(review): presumably makes the DMA compiler emit busy-poll based
# synchronization - confirm against dmacompiler.config.
config.ENABLE_BUSY_POLL = True

def Yi_slice(dims: ConvDims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    '''
    Returns (Yi_start, Yi_stop, Yi_stride, Yi_size) for the Yi window
    handled by column col at iteration start_iter.

    Yi_start is shifted back by dims.Py_b and may therefore be negative,
    and Yi_stop may exceed dims.Yi. Yi_size only counts the part of the
    window that lies inside [0, dims.Yi].
    '''
    rows_per_col = dims.Yos * dims.Sy
    stride = dims.aie_cols * rows_per_col
    start = (col * rows_per_col) + (start_iter * stride) - dims.Py_b
    # A window that starts beyond dims.Yi is collapsed to an empty range.
    if start <= dims.Yi:
        stop = start + dims.Yis
    else:
        stop = start

    def clamp(y: int) -> int:
        return max(0, min(y, dims.Yi))

    return (start, stop, stride, clamp(stop) - clamp(start))

def Yo_slice(dims: ConvDims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    '''
    Returns (Yo_start, Yo_stop, Yo_stride, Yo_size) for the Yo window
    handled by column col at iteration start_iter.

    Yo_stop is clipped to dims.Yo, so Yo_size can be smaller than
    dims.Yos; a window starting beyond dims.Yo has size 0.
    '''
    stride = dims.aie_cols * dims.Yos
    start = (col * dims.Yos) + (start_iter * stride)
    if start <= dims.Yo:
        stop = min(start + dims.Yos, dims.Yo)
    else:
        stop = start
    return (start, stop, stride, stop - start)

def Yi_split_iters(dims: ConvDims, col: int) -> List[Tuple[int, int]]:
    '''
    Groups the dims.Y_loop iterations of column col into runs, returned
    as (start_iter, num_iters) pairs.

    Consecutive iterations whose Yi window lies fully inside [0, dims.Yi]
    are merged into one run; an iteration whose window sticks out on
    either side always forms a run of length 1.
    '''
    def is_unpadded(iteration: int) -> bool:
        lo, hi, _, _ = Yi_slice(dims, col, iteration)
        return (lo >= 0) and (hi <= dims.Yi)

    runs = []
    first = 0
    while first < dims.Y_loop:
        count = 1
        if is_unpadded(first):
            # Extend the run while the next iteration is still unpadded.
            while is_unpadded(first + count):
                count += 1
        runs.append((first, count))
        first += count
    return runs

def Yo_split_iters(dims: ConvDims, col: int) -> List[Tuple[int, int]]:
    '''
    Groups the dims.Y_loop iterations of column col into runs, returned
    as (start_iter, num_iters) pairs.

    Consecutive iterations that produce a full slice (Yo_size equal to
    dims.Yos) are merged into one run; an iteration with a partial or
    empty slice always forms a run of length 1.
    '''
    def is_full(iteration: int) -> bool:
        return Yo_slice(dims, col, iteration)[3] == dims.Yos

    runs = []
    first = 0
    while first < dims.Y_loop:
        count = 1
        if is_full(first):
            # Extend the run while the next iteration is still a full slice.
            while is_full(first + count):
                count += 1
        runs.append((first, count))
        first += count
    return runs

def Yi_repeat_counts(dims: ConvDims, col: int, scale: int) -> List[int]:
    '''
    Returns one repeat count per Y iteration (length dims.Y_loop) for
    the runs produced by Yi_split_iters.

    A run normally stores num_iters * scale at its first iteration and
    leaves the rest at 0. When that product is 1024 or more, every
    iteration of the run (below dims.Y_loop) gets scale instead.
    NOTE(review): 1024 is presumably a repeat-count limit of the DMA
    compiler or hardware - confirm.
    '''
    counts = [0] * dims.Y_loop
    for first, span in Yi_split_iters(dims, col):
        combined = span * scale
        if combined < 1024:
            counts[first] = combined
            continue
        # Too large for a single entry: spread the run over its iterations.
        for step in range(first, min(first + span, dims.Y_loop)):
            counts[step] = scale
    return counts

def Yo_repeat_counts(dims: ConvDims, col: int, scale: int) -> List[int]:
    '''
    Returns one repeat count per Y iteration (length dims.Y_loop): each
    run from Yo_split_iters stores num_iters * scale at its first
    iteration, all other entries are 0.
    '''
    counts = [0] * dims.Y_loop
    for first, span in Yo_split_iters(dims, col):
        counts[first] = span * scale
    return counts

def ifm_chain_length(dims: ConvDims) -> int:
    '''
    Returns the smallest chain length in 1..4 that divides dims.Co_loop
    and keeps dims.Co_loop // length within the IFM reuse limit of
    (2**6 - 1) // dims.aie_rows.

    Raises RuntimeError when no length in that range qualifies.
    '''
    reuse_limit = (2**6 - 1) // dims.aie_rows
    for candidate in (1, 2, 3, 4):
        if dims.Co_loop % candidate != 0:
            continue
        if dims.Co_loop // candidate > reuse_limit:
            continue
        return candidate
    raise RuntimeError('Failed to allocate IFM chain!')

def wgt_chain_length(wgt_reuse: int) -> int:
    '''
    Returns the smallest chain length in 1..4 that divides wgt_reuse and
    keeps wgt_reuse // length within the reuse limit of 2**6 - 1.

    Raises RuntimeError when no length in that range qualifies.
    '''
    reuse_limit = 2**6 - 1
    for candidate in (1, 2, 3, 4):
        if wgt_reuse % candidate != 0:
            continue
        if wgt_reuse // candidate > reuse_limit:
            continue
        return candidate
    raise RuntimeError('Failed to allocate WGT chain!')

def pack_transfers(
    dma: AieDma,
    memory_fmts: List[str],
    tiling_fmts: List[str],
    tiling_iters: List[int],
    bits_per_elem: int,
) -> TransferParams:
    '''
    Builds a single TransferParams for dma out of several
    (memory format, tiling format) pairs.

    Transfer params are generated per pair, and entry 0 of each field is
    then repeated tiling_iters[i] times, in order. Padding is enabled
    only when dma is an MM2S channel. The three lists must have equal
    length.
    '''
    assert len(memory_fmts) == len(tiling_fmts)
    assert len(tiling_fmts) == len(tiling_iters)

    def expand(values: list) -> list:
        # Repeat the i-th value tiling_iters[i] times, keeping order.
        assert len(values) == len(tiling_iters)
        return [
            value
            for value, times in zip(values, tiling_iters)
            for _ in range(times)
        ]

    per_fmt = []
    for i, tiling_fmt in enumerate(tiling_fmts):
        per_fmt.append(generate_transfer_params(
            dma,
            memory_fmts[i],
            tiling_fmt,
            bits_per_block=bits_per_elem,
            enable_padding=(dma.channel.dir == DmaDir.MM2S),
        ))

    lengths = expand([p.length_i(0) for p in per_fmt])
    offsets = expand([p.offset_i(0) for p in per_fmt])
    steps = expand([p.step_i(0) for p in per_fmt])
    wraps = expand([p.wrap_i(0) for p in per_fmt])
    paddings = expand([p.padding_i(0) for p in per_fmt])
    return TransferParams(
        dma,
        lengths,
        offset=offsets,
        step=steps,
        wrap=wraps,
        padding=paddings,
    )

def generate_packed_shim_data_transfer(
    repeat_counts: List[int],
    dma: AieDma,
    shim_buffer_idx: int,
    memory_fmts: List[str],
    tiling_fmts: List[str],
    tiling_iters: List[int],
    bits_per_elem: int,
) -> DataTransfer:
    '''
    Reconfigures a BD with different transfer
    params at the shim for poll and re-enqueue

    Transfer params are generated per (memory format, tiling format)
    pair with iteration stepping enabled and padding disabled, and all
    resulting chain entries are packed into one read-side TransferParams.
    Every entry of repeat_counts is multiplied by the repeat coefficient
    reported for the first format. The buffer size is derived from
    memory_fmts[0]; the returned DataTransfer has no write params.
    '''
    assert len(memory_fmts) == len(tiling_fmts)
    assert len(tiling_fmts) == len(tiling_iters)

    def expand(values: list) -> list:
        # Repeat the i-th value tiling_iters[i] times, keeping order.
        assert len(values) == len(tiling_iters)
        return [
            value
            for value, times in zip(values, tiling_iters)
            for _ in range(times)
        ]

    chain = []
    first_repeat_coeff = 0
    for i, tiling_fmt in enumerate(tiling_fmts):
        repeat_coeff, transfers = generate_transfer_params(
            dma,
            memory_fmts[i],
            tiling_fmt,
            bits_per_block=bits_per_elem,
            enable_padding=False,
            use_iter_step=True,
            max_chain_length=4
        )
        # Only the coefficient of the first format scales the repeat counts.
        if i == 0:
            first_repeat_coeff = repeat_coeff
        chain.extend(transfers)

    read_params = TransferParams(
        dma,
        expand([p.length_i(0) for p in chain]),
        offset=expand([p.offset_i(0) for p in chain]),
        step=expand([p.step_i(0) for p in chain]),
        wrap=expand([p.wrap_i(0) for p in chain]),
        padding=expand([p.padding_i(0) for p in chain]),
        iter_step=expand([p.iter_step_i(0) for p in chain]),
        iter_wrap=expand([p.iter_wrap_i(0) for p in chain]),
    )
    shim_buffer_size = compute_buffer_size(memory_fmts[0], bits_per_elem)
    scaled_repeats = [count * first_repeat_coeff for count in repeat_counts]
    return DataTransfer(
        scaled_repeats,
        dma.tile, [shim_buffer_idx], shim_buffer_size,
        [],
        [read_params]
    )

def ifm_shim_memory(dims: ConvDims) -> str:
    '''
    Returns the memory format string of the IFM buffer at the shim:
    Yi/Xi/Ci when dims.ifm_use_hwc_format is set, otherwise
    Yi/Ci/Xi with an innermost Ci of dims.Ci_gran.
    '''
    if not dims.ifm_use_hwc_format:
        return f'Yi:{dims.Yi} Ci:{dims.Ci} Xi:{dims.Xi} Ci:{dims.Ci_gran}'
    return f'Yi:{dims.Yi} Xi:{dims.Xi} Ci:{dims.Ci}'

def ifm_shim_mm2s(dims: ConvDims, col: int) -> List[str]:
    '''
    Returns the shim MM2S tiling format strings for the IFM of column col.

    With IFM streaming enabled (and not standalone DWC) there is one
    format per Y iteration; otherwise there is one format per run from
    Yi_split_iters. The inner Yi range is clipped to [0, dims.Yi].
    '''
    def describe(first: int, count: int) -> str:
        y_lo, y_hi, y_stride, _ = Yi_slice(dims, col, first)
        # Outer traversal is shared by both layouts.
        outer = (
            f'Yi:0:{count * y_stride}:{y_stride} '
            f'Ci:0:{dims.Ci}:{dims.Cim} '
        )
        rows = f'Yi:{max(0, y_lo)}:{min(y_hi, dims.Yi)}'
        if dims.ifm_use_hwc_format:
            inner = f'Xi:0:{dims.Xi} Ci:0:{dims.Cim}'
        else:
            inner = f'Ci:0:{dims.Cim}:{dims.Ci_gran} Xi:0:{dims.Xi} Ci:0:{dims.Ci_gran}'
        return f'{outer}{rows} {inner}'

    one_fmt_per_iter = (
        dims.enable_ifm_streaming
        and (dims.Co_loop >= 1)
        and (not dims.is_standalone_dwc)
    )
    if one_fmt_per_iter:
        return [describe(s, 1) for s in range(dims.Y_loop)]
    return [describe(s, n) for s, n in Yi_split_iters(dims, col)]

def ifm_shim_repeat_counts(dims: ConvDims, idx: int) -> List[int]:
    '''
    Returns the per-iteration repeat counts (length dims.Y_loop) of one
    shim IFM transfer.

    With IFM streaming enabled (and not standalone DWC) entry idx is set
    to dims.Co_loop; otherwise entry 0 is set to 1. All other entries
    are 0.
    '''
    counts = [0] * dims.Y_loop
    streaming = (
        dims.enable_ifm_streaming
        and (dims.Co_loop >= 1)
        and (not dims.is_standalone_dwc)
    )
    slot, value = (idx, dims.Co_loop) if streaming else (0, 1)
    counts[slot] = value
    return counts

def ifm_memtile_memory(dims: ConvDims, col: int) -> List[str]:
    '''
    Returns one memtile IFM memory format string per run from
    Yi_split_iters for column col.

    The Yi extent is the in-range Yi_size of the run's first iteration;
    when that size is not positive, dims.Yis is used instead.
    '''
    formats = []
    for first, _ in Yi_split_iters(dims, col):
        rows = Yi_slice(dims, col, first)[3]
        if rows <= 0:
            rows = dims.Yis
        formats.append(f'Yi:{rows} Ci:{dims.Cim} Xi:{dims.Xi} Ci:{dims.Ci_gran}')
    return formats

def ifm_memtile_s2mm(dims: ConvDims, col: int) -> List[str]:
    '''
    Returns one memtile S2MM tiling format string per run from
    Yi_split_iters for column col, covering Yi_size rows of the run's
    first iteration in either the HWC layout or the Ci-granule layout.
    '''
    formats = []
    for first, _ in Yi_split_iters(dims, col):
        rows = Yi_slice(dims, col, first)[3]
        if dims.ifm_use_hwc_format:
            layout = f'Xi:0:{dims.Xi} Ci:0:{dims.Cim}'
        else:
            layout = f'Ci:0:{dims.Cim}:{dims.Ci_gran} Xi:0:{dims.Xi} Ci:0:{dims.Ci_gran}'
        formats.append(f'Yi:0:{rows} {layout}')
    return formats

def ifm_memtile_mm2s(dims: ConvDims, col: int, row: int) -> List[str]:
    '''
    Returns one memtile MM2S tiling format string per run from
    Yi_split_iters, describing what core row `row` of column col reads.

    For standalone DWC the Ci window is selected via Co_index(dims, row)
    and stepped by dims.Cis * dims.Co_split (requires dims.Cis ==
    dims.Cos); otherwise the window is the first dims.Cis channels,
    stepped by dims.Cis. The Yi range starts at min(Yi_start, 0), so it
    is negative exactly when the run's first window starts below row 0.
    '''
    def describe(first: int) -> str:
        y_lo = Yi_slice(dims, col, first)[0]
        x_lo, x_hi, _, _ = Xi_slice(dims, row)
        if not dims.is_standalone_dwc:
            c_lo, c_hi, c_stride = 0, dims.Cis, dims.Cis
        else:
            assert dims.Cis == dims.Cos
            c_lo = Co_index(dims, row) * dims.Cis
            c_hi = c_lo + dims.Cis
            c_stride = dims.Cis * dims.Co_split
        y_off = min(y_lo, 0)
        return (
            f'Ci:0:{dims.Cim}:{c_stride} '
            f'Yi:{y_off}:{y_off + dims.Yis} Ci:{c_lo}:{c_hi}:{dims.Ci_gran} Xi:{x_lo}:{x_hi} Ci:0:{dims.Ci_gran}'
        )

    return [describe(first) for first, _ in Yi_split_iters(dims, col)]

def wgt_shim_memory(dims: ConvDims) -> str:
    '''
    Returns the memory format string of the weight buffer at the shim:
    dims.Co_loop * dims.Co_split Cos entries, dims.Ci_loop Cis entries,
    each a sub-volume of dims.wgt_subv_size.
    '''
    total_cos = dims.Co_loop * dims.Co_split
    return f'Cos:{total_cos} Cis:{dims.Ci_loop} Subv:{dims.wgt_subv_size}'

def wgt_shim_mm2s(dims: ConvDims, row: int) -> str:
    '''
    Returns the shim MM2S tiling format string for weights: steps over
    all Cos entries in strides of dims.Co_split and picks the single
    entry at Co_index(dims, row) within each stride.
    '''
    first = Co_index(dims, row)
    total_cos = dims.Co_loop * dims.Co_split
    return f'Cos:0:{total_cos}:{dims.Co_split} Cos:{first}:{first + 1} Cis Subv'

def wgt_memtile_memory(dims: ConvDims) -> str:
    '''
    Returns the memtile weight memory format string: the full
    Cos x Cis set of sub-volumes when dims.enable_wgt_reuse is set,
    otherwise a single sub-volume.
    '''
    if not dims.enable_wgt_reuse:
        return f'Subv:{dims.wgt_subv_size}'
    return f'Cos:{dims.Co_loop} Cis:{dims.Ci_loop} Subv:{dims.wgt_subv_size}'

def wgt_memtile_s2mm(dims: ConvDims) -> str:
    '''
    Returns the memtile S2MM tiling format string for weights, matching
    the layout chosen by wgt_memtile_memory.
    '''
    return 'Cos Cis Subv' if dims.enable_wgt_reuse else 'Subv'

def wgt_memtile_mm2s(dims: ConvDims) -> str:
    '''
    Returns the memtile MM2S tiling format string for weights, matching
    the layout chosen by wgt_memtile_memory.
    '''
    return 'Cos Cis Subv' if dims.enable_wgt_reuse else 'Subv'

def ofm_shim_memory(dims: ConvDims) -> str:
    '''
    Returns the memory format string of the OFM buffer at the shim:
    Yo/Xo/Co when dims.ofm_use_hwc_format is set, otherwise
    Yo/Co/Xo with an innermost Co of dims.Co_gran.
    '''
    if not dims.ofm_use_hwc_format:
        return f'Yo:{dims.Yo} Co:{dims.Co} Xo:{dims.Xo} Co:{dims.Co_gran}'
    return f'Yo:{dims.Yo} Xo:{dims.Xo} Co:{dims.Co}'

def ofm_shim_s2mm(dims: ConvDims, col: int) -> List[str]:
    '''
    Returns one shim S2MM tiling format string per run from
    Yo_split_iters for column col.

    Each format steps over the run's iterations in strides of Yo_stride
    and over Co in chunks of min(dims.Com, dims.Co), then writes the
    Yo window of the run's first iteration in the selected layout.
    col must be in [0, dims.aie_cols).
    '''
    assert 0 <= col < dims.aie_cols

    def describe(first: int, count: int) -> str:
        y_lo, y_hi, y_stride, _ = Yo_slice(dims, col, first)
        co_chunk = min(dims.Com, dims.Co)
        # Outer traversal is shared by both layouts.
        outer = (
            f'Yo:0:{y_stride * count}:{y_stride} '
            f'Co:0:{dims.Co}:{co_chunk} '
        )
        if dims.ofm_use_hwc_format:
            inner = f'Xo:0:{dims.Xo} Co:0:{co_chunk}'
        else:
            inner = f'Co:0:{co_chunk}:{dims.Co_gran} Xo:0:{dims.Xo} Co:0:{dims.Co_gran}'
        return f'{outer}Yo:{y_lo}:{y_hi} {inner}'

    return [describe(first, count) for first, count in Yo_split_iters(dims, col)]

def ofm_memtile_memory(dims: ConvDims, col: int) -> List[str]:
    '''
    Returns the memtile OFM memory format string, repeated once per run
    from Yo_split_iters for column col (the format does not depend on
    the run).
    '''
    layout = f'Yo:{dims.Yos} Xo:{dims.Xom} Co:{dims.Com}'
    return [layout for _ in Yo_split_iters(dims, col)]

def ofm_memtile_s2mm(dims: ConvDims, col: int, row: int) -> List[str]:
    '''
    Returns the memtile S2MM tiling format string for the OFM written by
    core row `row`, repeated once per run from Yo_split_iters.

    The Xo window is selected by X_index(dims, row) in units of dims.Xos
    and the Co window by Co_index(dims, row) in units of dims.Cos.
    '''
    x_lo = X_index(dims, row) * dims.Xos
    c_lo = Co_index(dims, row) * dims.Cos
    layout = (
        f'Yo Co:{c_lo}:{c_lo + dims.Cos}:{dims.Co_gran} '
        f'Xo:{x_lo}:{x_lo + dims.Xos} Co:0:{dims.Co_gran}'
    )
    return [layout for _ in Yo_split_iters(dims, col)]

def ofm_memtile_mm2s(dims: ConvDims, col: int) -> List[str]:
    '''
    Returns one memtile MM2S tiling format string per run from
    Yo_split_iters for column col, covering Yo_size rows of the run's
    first iteration and min(dims.Com, dims.Co) channels in the selected
    layout.
    '''
    formats = []
    for first, _ in Yo_split_iters(dims, col):
        rows = Yo_slice(dims, col, first)[3]
        co_chunk = min(dims.Com, dims.Co)
        if dims.ofm_use_hwc_format:
            layout = f'Xo:0:{dims.Xo} Co:0:{co_chunk}'
        else:
            layout = f'Co:0:{co_chunk}:{dims.Co_gran} Xo:0:{dims.Xo} Co:0:{dims.Co_gran}'
        formats.append(f'Yo:0:{rows} {layout}')
    return formats

def split_memtile_ifm_subv(
    dims: ConvDims,
    size_cutoff: int,
) -> int:
    '''
    Returns the largest number of IFM sub-volumes that divides the loop
    count and whose total byte size does not exceed size_cutoff.

    The loop count is dims.Co_loop for standalone DWC (sub-volume spans
    dims.Cos * dims.Co_split channels) and dims.Ci_loop otherwise
    (sub-volume spans dims.Cis channels). Falls back to 1 when no
    candidate fits.
    '''
    if dims.is_standalone_dwc:
        loop_count = dims.Co_loop
        channels = dims.Cos * dims.Co_split
    else:
        loop_count = dims.Ci_loop
        channels = dims.Cis
    subv_bytes = (dims.Yis * dims.Xi * channels * dims.ifm_bits) // 8

    # Search from the largest candidate down so the first hit is the maximum.
    fitting = (
        count
        for count in range(loop_count, 0, -1)
        if (subv_bytes * count <= size_cutoff) and (loop_count % count == 0)
    )
    return next(fitting, 1)

def compile_dataflow(
    dims: ConvDims,
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
):
    '''
    Builds the full dataflow description of one conv layer and passes it
    to run_layer_compilation.

    Steps, in order:
      1. allocate core buffers (conv_core_alloc) and lay out the memtile
         address map for params, IFM, WGT and OFM,
      2. build the per-core instruction lists,
      3. build the memtile DataTransfers (params, optional xint8 kernel
         params, IFM, WGT, OFM),
      4. build the shim DataTransfers (same groups),
      5. call run_layer_compilation with the 8x4 or 4x4 DMA connections.

    dims selects all shapes and mode switches; back_end, kernel_names and
    kernel_includes are forwarded unchanged. Nothing is returned.
    '''
    # Core Buffer Allocation

    core_alloc = conv_core_alloc(dims, overlay_stack_addr())

    # The memtile OFM buffer is double-buffered exactly when the core
    # allocation provides an OFM pong buffer.
    if core_alloc.ofm_pong_addr is not None:
        memtile_ofm_buffering = 2
    else:
        memtile_ofm_buffering = 1
    has_memtile_ofm_double_buffer = memtile_ofm_buffering == 2

    # Memtile address map. Buffers are placed back to back from address 0:
    # layer params, (xint8 only) conv kernel params, IFM, WGT, OFM.
    ofm_memtile_ping_addr = None
    ofm_memtile_pong_addr = None
    ofm_memtile_addr = None
    if dims.enable_ifm_streaming:
        # IFM streaming: the IFM gets a ping and a pong buffer.
        prm_memtile_addr = 0
        conv_kernel_param_memtile_addr = prm_memtile_addr + dims.prm_memtile_size
        ifm_memtile_ping_addr = ((conv_kernel_param_memtile_addr + dims.conv_kernel_param_size ) if dims.is_xint8 else (prm_memtile_addr + dims.prm_memtile_size))
        ifm_memtile_pong_addr = ifm_memtile_ping_addr + dims.ifm_memtile_size
        if dims.enable_wgt_reuse:
            # WGT reuse: a single WGT buffer.
            wgt_memtile_addr = ifm_memtile_pong_addr + dims.ifm_memtile_size
            if has_memtile_ofm_double_buffer:
                ofm_memtile_ping_addr = wgt_memtile_addr + dims.wgt_memtile_size
                ofm_memtile_pong_addr = ofm_memtile_ping_addr + dims.ofm_memtile_size
            else:
                ofm_memtile_addr = wgt_memtile_addr + dims.wgt_memtile_size
        else:
            # No WGT reuse: WGT gets a ping and a pong buffer.
            wgt_memtile_ping_addr = ifm_memtile_pong_addr + dims.ifm_memtile_size
            wgt_memtile_pong_addr = wgt_memtile_ping_addr + dims.wgt_memtile_size
            if has_memtile_ofm_double_buffer:
                ofm_memtile_ping_addr = wgt_memtile_pong_addr + dims.wgt_memtile_size
                ofm_memtile_pong_addr = ofm_memtile_ping_addr + dims.ofm_memtile_size
            else:
                ofm_memtile_addr = wgt_memtile_pong_addr + dims.wgt_memtile_size

        ifm_memtile_addrs = [ifm_memtile_ping_addr, ifm_memtile_pong_addr]
        ifm_memtile_repeat_scale = (dims.Co_loop * dims.Ci_loop) // dims.num_ifm_subv
        ifm_memtile_reuse_ratio = 1
    else:
        # No IFM streaming: a single IFM buffer, WGT always ping/pong.
        # NOTE(review): wgt_memtile_addr is never assigned on this path, so
        # dims.enable_wgt_reuse without dims.enable_ifm_streaming would raise
        # UnboundLocalError where wgt_memtile_addrs is built below - confirm
        # that this combination is excluded upstream.
        prm_memtile_addr = 0
        conv_kernel_param_memtile_addr = prm_memtile_addr + dims.prm_memtile_size
        ifm_memtile_addr = ((conv_kernel_param_memtile_addr + dims.conv_kernel_param_size ) if dims.is_xint8 else (prm_memtile_addr + dims.prm_memtile_size))
        wgt_memtile_ping_addr = ifm_memtile_addr + dims.ifm_memtile_size
        wgt_memtile_pong_addr = wgt_memtile_ping_addr + dims.wgt_memtile_size
        if has_memtile_ofm_double_buffer:
            ofm_memtile_ping_addr = wgt_memtile_pong_addr + dims.wgt_memtile_size
            ofm_memtile_pong_addr = ofm_memtile_ping_addr + dims.ofm_memtile_size
        else:
            ofm_memtile_addr = wgt_memtile_pong_addr + dims.wgt_memtile_size

        ifm_memtile_addrs = [ifm_memtile_addr]
        ifm_memtile_repeat_scale = 1
        if dims.pin_ifm_l1:
            ifm_memtile_reuse_ratio = 1
        else:
            ifm_memtile_reuse_ratio = dims.Co_loop

    # WGT repeat counts and reuse ratio for the memtile and shim transfers.
    if dims.enable_wgt_reuse:
        wgt_memtile_addrs = [wgt_memtile_addr]
        wgt_memtile_repeat_count = 1
        wgt_memtile_repeat_count_list = [wgt_memtile_repeat_count] + [0] * (dims.Y_loop - 1)
        if dims.pin_wgt_bias_l1:
            wgt_memtile_reuse_ratio = 1
        else:
            wgt_memtile_reuse_ratio = dims.Y_loop
        wgt_shim_repeat_count = 1
    else:
        wgt_memtile_addrs = [wgt_memtile_ping_addr, wgt_memtile_pong_addr]
        wgt_memtile_repeat_count = dims.Co_loop * dims.Ci_loop * dims.Y_loop
        # A total above 1024 is spread evenly over the Y iterations instead
        # of being placed on the first one.
        # NOTE(review): this splits only for > 1024, while Yi_repeat_counts
        # splits for >= 1024 - confirm which bound is intended.
        wgt_memtile_repeat_count_list = [dims.Co_loop * dims.Ci_loop] * dims.Y_loop if wgt_memtile_repeat_count > 1024 else [wgt_memtile_repeat_count] + [0] * (dims.Y_loop - 1)
        wgt_memtile_reuse_ratio = 1
        wgt_shim_repeat_count = dims.Y_loop

    if has_memtile_ofm_double_buffer:
        assert(ofm_memtile_pong_addr is not None and ofm_memtile_ping_addr is not None)
        ofm_memtile_addrs = [ofm_memtile_ping_addr, ofm_memtile_pong_addr]
    else:
        assert(ofm_memtile_addr is not None)
        ofm_memtile_addrs = [ofm_memtile_addr]

    # Shim Buffer Allocation

    conv_shim_alloc = shim_alloc()

    # One instruction list per core tile; every core gets the same loop
    # counts and an IFM buffer config on its S2MM channel 0.
    core_instrs = {}
    for col in range(dims.aie_cols):
        for row in range(dims.aie_rows):
            core_instrs[AieTile(TileType.Core, col, row)] = conv_core_instrs(
                dims, core_alloc,
                dims.Y_loop,
                dims.Co_loop,
                dims.Ci_loop,
                ifm_config=generate_core_buffer_config(
                    core_dma(col, row, DmaDir.S2MM, 0),
                    core_alloc.ifm_ping_addr, core_alloc.ifm_pong_addr,
                    ifm_core_memory(dims),
                    ifm_core_s2mm(dims, row),
                    bits_per_block=dims.ifm_bits,
                ),
            )
    # On the 8x4 overlay the WGT-channel transfers are generated only for
    # every second column.
    # NOTE(review): presumably because weights are broadcast to the
    # neighbouring column there - confirm against the overlay connections.
    bcst_col_step = 2 if (dims.aie_cols, dims.aie_rows) == (8, 4) else 1
    # Memtile transfers, in order: layer params, xint8 kernel params, IFM,
    # WGT, OFM.
    memtile_transfers = [
        DataTransfer(
            [1] + [0] * (dims.Y_loop - 1),
            AieTile(TileType.Memtile, col), [prm_memtile_addr], dims.prm_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 0),
                prm_memtile_memory(dims),
                prm_memtile_s2mm(),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, row),
                prm_memtile_memory(dims),
                prm_memtile_mm2s(row),
            ) for row in range(dims.aie_rows)],
            sync_strategy=SyncStrategy.Serial_M_to_N,
        ) for col in range(dims.aie_cols)
    ] + [
            DataTransfer(                                                              #NOTE: Extra transfer of conv_kernel_param_size, added only for xint8 because of a kernel requirement
            [1] + [0] * (dims.Y_loop - 1),
            AieTile(TileType.Memtile, col), [conv_kernel_param_memtile_addr], dims.conv_kernel_param_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 1),     #Using the WGT data channel (S2MM-1) to get the kernel params from the shim tile
                conv_kernel_prm_memtile_memory(dims, dims.conv_kernel_param_size),
                prm_memtile_s2mm(),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, 4),     #Using the WGT data channel (MM2S-4) to send the kernel params to the core tile
                conv_kernel_prm_memtile_memory(dims, dims.conv_kernel_param_size),
                prm_memtile_mm2s(0))],                                                    #NOTE: The prm_memtile_mm2s argument is 0 because the transfer always starts at offset 0
        ) for col in range(0, dims.aie_cols, bcst_col_step) if dims.is_xint8
    ] + [
        DataTransfer(
            Yi_repeat_counts(dims, col, ifm_memtile_repeat_scale),
            AieTile(TileType.Memtile, col), ifm_memtile_addrs, dims.ifm_memtile_size,
            [pack_transfers(
                memtile_dma(col, DmaDir.S2MM, 0),
                ifm_memtile_memory(dims, col),
                ifm_memtile_s2mm(dims, col),
                [n for _, n in Yi_split_iters(dims, col)],
                dims.ifm_bits,
            )],
            [pack_transfers(
                memtile_dma(col, DmaDir.MM2S, row),
                ifm_memtile_memory(dims, col),
                ifm_memtile_mm2s(dims, col, row),
                [n for _, n in Yi_split_iters(dims, col)],
                dims.ifm_bits,
            ) for row in range(dims.aie_rows)
            for _ in range(ifm_chain_length(dims) if not dims.enable_ifm_streaming else 1)],
            sync_strategy=SyncStrategy.Parallel_1_to_N,
            # NOTE(review): with dims.pin_ifm_l1 the reuse ratio is 1, so a
            # chain length above 1 would floor this to 0 - confirm.
            reuse_ratio=ifm_memtile_reuse_ratio // ifm_chain_length(dims) if not dims.enable_ifm_streaming else ifm_memtile_reuse_ratio,
        ) for col in range(dims.aie_cols)
    ] + [
        DataTransfer(
            wgt_memtile_repeat_count_list,
            AieTile(TileType.Memtile, col), wgt_memtile_addrs, dims.wgt_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 1),
                wgt_memtile_memory(dims),
                wgt_memtile_s2mm(dims),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, 4),
                wgt_memtile_memory(dims),
                wgt_memtile_mm2s(dims),
            ) for _ in range(wgt_chain_length(wgt_memtile_reuse_ratio) if dims.enable_wgt_reuse else 1)],
            reuse_ratio=wgt_memtile_reuse_ratio // wgt_chain_length(wgt_memtile_reuse_ratio) if dims.enable_wgt_reuse else wgt_memtile_reuse_ratio,
        ) for col in range(0, dims.aie_cols, bcst_col_step)
    ] + [
        DataTransfer(
            Yo_repeat_counts(dims, col, dims.Co_loop),
            AieTile(TileType.Memtile, col), ofm_memtile_addrs, dims.ofm_memtile_size,
            [pack_transfers(
                memtile_dma(col, DmaDir.S2MM, 2 + row),
                ofm_memtile_memory(dims, col),
                ofm_memtile_s2mm(dims, col, row),
                [n for _, n in Yo_split_iters(dims, col)],
                dims.ofm_bits,
            ) for row in range(dims.aie_rows)],
            [pack_transfers(
                memtile_dma(col, DmaDir.MM2S, 5),
                ofm_memtile_memory(dims, col),
                ofm_memtile_mm2s(dims, col),
                [n for _, n in Yo_split_iters(dims, col)],
                dims.ofm_bits,
            )],
            sync_strategy=SyncStrategy.Parallel_N_to_1,
        ) for col in range(dims.aie_cols)
    ]

    # Shim IFM transfers. The strategy is picked from the number of shim
    # formats of column 0: up to max_chain_length formats get one transfer
    # each, more than that get a single packed transfer per column.
    max_chain_length = 4
    ifm_shim_transfers = []
    if len(ifm_shim_mm2s(dims, 0)) <= max_chain_length:
        # NOTE: BD Chaining
        ifm_shim_transfers += [generate_shim_data_transfer(
            ifm_shim_repeat_counts(dims, idx),
            shim_dma(col, DmaDir.MM2S, 0), conv_shim_alloc.ifm_buffer_id,
            ifm_shim_memory(dims),
            fmt,
            bits_per_block=dims.ifm_bits,
        ) for col in range(dims.aie_cols) for idx, fmt in enumerate(ifm_shim_mm2s(dims, col))]
    else:
        # NOTE: Poll and Re-enqueue
        # NOTE(review): the packed helper asserts that ifm_shim_mm2s returns
        # exactly dims.Y_loop formats, which this file only guarantees on the
        # one-format-per-iteration (streaming) path - confirm.
        ifm_shim_transfers += [generate_packed_shim_data_transfer(
            [dims.Co_loop] * dims.Y_loop,
            shim_dma(col, DmaDir.MM2S, 0), conv_shim_alloc.ifm_buffer_id,
            [ifm_shim_memory(dims)] * dims.Y_loop,
            ifm_shim_mm2s(dims, col),
            [1] * dims.Y_loop,
            dims.ifm_bits
        ) for col in range(dims.aie_cols)]

    # Shim transfers, in order: layer params, xint8 kernel params, IFM,
    # WGT, OFM.
    shim_transfers = [
        generate_shim_data_transfer(
            [1] + [0] * (dims.Y_loop - 1),
            shim_dma(col, DmaDir.MM2S, 0), conv_shim_alloc.prm_buffer_id,
            prm_shim_memory(dims),
            prm_shim_mm2s(col),
        ) for col in range(dims.aie_cols)
    ] + [
        generate_shim_data_transfer(                                       #NOTE: conv_kernel_param transfer, added only for xint8 because of a conv kernel requirement
            [1] + [0] * (dims.Y_loop - 1),
            shim_dma(col, DmaDir.MM2S, 1), conv_shim_alloc.wgt_buffer_id,
            conv_kernel_prm_shim_memory(dims,  dims.conv_kernel_param_size),
            prm_shim_mm2s(0),                                              #Argument 0 so every column starts from Col0:1 in the memory format
            ) for col in range(0,dims.aie_cols,bcst_col_step) if dims.is_xint8
    ] + [
        transfer 
        for transfer in ifm_shim_transfers
    ] + [
        generate_shim_data_transfer(
            [wgt_shim_repeat_count] + [0] * (dims.Y_loop - 1),
            shim_dma(col, DmaDir.MM2S, 1), conv_shim_alloc.wgt_buffer_id,
            wgt_shim_memory(dims),
            wgt_shim_mm2s(dims, (col // bcst_col_step)),
            buffer_offset= ((dims.conv_kernel_param_size) if dims.is_xint8 else 0),     #NOTE: For xint8 only, the weights start after the conv_kernel params (extra 256 words of offset).
        ) for col in range(0, dims.aie_cols, bcst_col_step)
    ] + [
        generate_shim_data_transfer(
            [1] + [0] * (dims.Y_loop - 1),
            shim_dma(col, DmaDir.S2MM, 0), conv_shim_alloc.ofm_buffer_id,
            ofm_shim_memory(dims),
            fmt,
            bits_per_block=dims.ofm_bits,
        ) for col in range(dims.aie_cols) for fmt in ofm_shim_s2mm(dims, col)
    ]

    # Hand everything to the DMA compiler; the connection table is chosen
    # by overlay shape (8x4, otherwise 4x4).
    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        kernel_names,
        kernel_includes,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        overlay_8x4_dma_connections() if (dims.aie_cols, dims.aie_rows) == (8, 4) else overlay_4x4_dma_connections(),
        back_end,
        core_stack_addr=overlay_stack_addr(),
        param_channel_id=0
    )

def main():
    '''
    Builds one hard-coded conv layer (16-bit IFM/OFM, 8-bit weights,
    is_standalone_dwc set), cleans the overlay, compiles the dataflow
    with the Adf back end and builds the simulation overlay.
    '''
    # Full tensor shapes
    Ci = 128
    Yi = 56
    Xi = 56
    Co = 128
    Yo = 56
    Xo = 56

    # Sub-volume shapes
    Cis = 32
    Yis = 5
    Xis = 64
    Cos = 32
    Yos = 4
    Xos = 32

    # Kernel size, stride and padding (before / after)
    Ky = 3
    Kx = 3
    Sy = 1
    Sx = 1
    Py_b = 1
    Px_b = 1
    Py_a = 1
    Px_a = 1

    # Spatial / channel split
    X_split = 2
    Co_split = 2

    # Granularity and alignment
    Ci_gran = 8
    Co_gran = 8
    X_gran = 4
    X_align = 64

    # Element widths in bits
    ifm_bits = 16
    wgt_bits = 8
    ofm_bits = 16
    tdm_bits = 32
    bias_bits = 64

    # C1/C2, shift tdm/res, zp_wgt are int32
    qdq_param_bytes = 5 * 4

    # Weight sub-volume: weights and (bias + qdq params), each rounded
    # with iceil to X_align.
    wgt_bytes = (Cos * Cis * Ky * Kx * wgt_bits) // 8
    bias_qdq_bytes = ((Cos * bias_bits) // 8) + qdq_param_bytes
    wgt_subv_size = iceil(wgt_bytes, X_align) + iceil(bias_qdq_bytes, X_align)

    is_standalone_dwc = True

    dims = ConvDims(
        Ci, Cis, Ci_gran, Co, Cos, Co_gran, Co_split,
        Yi, Yis, Yo, Yos,
        Xi, Xis, Xo, Xos, X_gran, X_align, X_split,
        Ky, Kx,
        Sy, Sx,
        Py_b, Px_b, Py_a, Px_a,
        ifm_bits, wgt_bits, ofm_bits, tdm_bits,
        wgt_subv_size,
        is_standalone_dwc,
    )

    kernel_names = ['run_conv_a16w8_qdq']
    kernel_includes = [
        'super.hh',
        'conv/direct_conv_int16x8_generic/direct_conv_int16x8_generic_wrapper.cc',
    ]

    # The host test program depends on the conv flavour.
    if is_standalone_dwc:
        host_cpp = 'dwc_main.cpp'
    else:
        host_cpp = 'conv_main.cpp'

    clean_overlay()
    compile_dataflow(dims, BackEnd.Adf, kernel_names, kernel_includes)
    build_sim_overlay(host_cpp, conv_preproc_directives(dims, BackEnd.Adf))

# Run the example layer only when this file is executed as a script.
if __name__ == '__main__':
    main()
