import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
from typing import Tuple, List

from dmacompiler import \
    OverlayShape, BackEnd, \
    DataTransfer, SyncStrategy, \
    AieTile, TileType, \
    AieDma, DmaDir, TransferParams, core_dma, memtile_dma, shim_dma, \
    compute_buffer_size, \
    generate_core_buffer_config, \
    generate_transfer_params, \
    generate_shim_data_transfer, \
    pack_reconfig_transfers, \
    generate_packed_shim_data_transfer, \
    run_layer_compilation, \
    set_dev_gen, DevGen, config

from dataflow_common import \
    overlay_8x4_dma_connections, \
    overlay_4x4_dma_connections, \
    overlay_stack_addr, \
    clean_overlay, \
    build_sim_overlay, \
    ceildiv, \
    shim_alloc, \
    prm_shim_memory, \
    conv_kernel_prm_shim_memory, \
    prm_shim_mm2s, \
    prm_memtile_memory, \
    conv_kernel_prm_memtile_memory, \
    prm_memtile_s2mm, \
    prm_memtile_mm2s

from conv_common import \
    ConvDims, \
    iceil, \
    X_index, \
    conv_input, \
    Co_index, \
    Xi_slice, \
    ifm_core_memory, \
    ifm_core_s2mm, \
    conv_preproc_directives, \
    conv_core_alloc, \
    conv_core_instrs

from dataflow_utils import \
    Yi_slice_iter, \
    Xi_slice_iter, \
    Yo_slice_iter, \
    Xo_slice_iter, \
    Yi_split_iters, \
    Yo_split_iters, \
    ifm_chain_length, \
    wgt_chain_length


# Module-level dmacompiler configuration, applied when this file is imported:
# select the Aie2p device generation and set the ENABLE_BUSY_POLL option.
# NOTE(review): the exact effect of ENABLE_BUSY_POLL is defined in dmacompiler.config -- confirm there.
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True

def Yi_repeat_counts(dims: ConvDims, col: int, scale: int) -> List[int]:
    """Memtile IFM repeat counts, one entry per (X, Y) loop iteration.

    Non-zero entries sit at the first Y iteration of each Yi split. X8 split
    always uses column 0's split (with a 64-iteration cap) and packs
    ``num_iters * scale`` into that entry. Otherwise the packed count is used
    only when it is below 1024; larger splits instead get ``scale`` on every
    iteration of the split that lies inside ``Y_loop``.

    Args:
        dims: Layer dimensions.
        col: AIE column (ignored for X8 split).
        scale: Repeats contributed by each Y iteration.

    Returns:
        The per-Y list tiled ``dims.X_loop`` times.
    """
    counts = [0] * dims.Y_loop
    if not dims.is_X8_split:
        for first, length in Yi_split_iters(dims, col):
            total = length * scale
            if total < 1024:
                counts[first] = total
                continue
            # Counts of 1024 or more are not packed into one entry.
            for y in range(first, first + length):
                if y < dims.Y_loop:
                    counts[y] = scale
    else:
        for first, length in Yi_split_iters(dims, 0, 64):
            counts[first] = length * scale
    return counts * dims.X_loop

def Yo_repeat_counts(dims: ConvDims, col: int, scale: int, is_mt: bool = True) -> List[int]:
    """OFM repeat counts, one entry per (X, Y) loop iteration.

    Non-zero entries sit at the first Y iteration of each Yo split (column 0's
    split for X8 split). Memtile transfers (``is_mt``) get ``num_iters * scale``;
    other transfers get plain ``scale``.

    Returns:
        The per-Y list tiled ``dims.X_loop`` times.
    """
    split_col = 0 if dims.is_X8_split else col
    counts = [0] * dims.Y_loop
    for first, length in Yo_split_iters(dims, split_col):
        counts[first] = length * scale if is_mt else scale
    return counts * dims.X_loop

def ifm_shim_memory(dims: ConvDims) -> str:
    """IFM shim buffer layout: HWC order, or Ci split into Ci_gran blocks around Xi."""
    hwc_layout = f'Yi:{dims.Yi} Xi:{dims.Xi} Ci:{dims.Ci}'
    blocked_layout = f'Yi:{dims.Yi} Ci:{dims.Ci} Xi:{dims.Xi} Ci:{dims.Ci_gran}'
    return hwc_layout if dims.ifm_use_hwc_format else blocked_layout

def ifm_shim_mm2s(dims: ConvDims, col: int) -> List[str]:
    """IFM shim MM2S access patterns for ``col``, ordered X-outer / Y-inner.

    With IFM streaming (``Co_loop >= 1`` and not a standalone DWC) there is one
    single-iteration pattern per Y iteration. Otherwise there is one pattern
    per Yi split; X8 split uses column 0's split with a 64-iteration cap. Slice
    bounds are clamped to ``[0, Yi)`` / ``[0, Xi)``.
    """
    y_col = 0 if dims.is_X8_split else col
    split_limit = 64 if dims.is_X8_split else 0
    hwc = dims.ifm_use_hwc_format

    def pattern(y_iter: int, x_iter: int, count: int) -> str:
        y_start, y_stop, y_stride, _ = Yi_slice_iter(dims, y_col, y_iter)
        x_start, x_stop, _, _ = Xi_slice_iter(dims, col, x_iter)
        y_lo, y_hi = max(0, y_start), min(y_stop, dims.Yi)
        x_lo, x_hi = max(0, x_start), min(x_stop, dims.Xi)
        prefix = f'Yi:0:{count * y_stride}:{y_stride} Ci:0:{dims.Ci}:{dims.Cim} '
        if hwc:
            return prefix + f'Yi:{y_lo}:{y_hi} Xi:{x_lo}:{x_hi} Ci:0:{dims.Cim}'
        return prefix + f'Yi:{y_lo}:{y_hi} Ci:0:{dims.Cim}:{dims.Ci_gran} Xi:{x_lo}:{x_hi} Ci:0:{dims.Ci_gran}'

    streaming = dims.enable_ifm_streaming and (dims.Co_loop >= 1) and (not dims.is_standalone_dwc)
    patterns = []
    for x_iter in range(dims.X_loop):
        if streaming:
            patterns.extend(pattern(y_iter, x_iter, 1) for y_iter in range(dims.Y_loop))
        else:
            patterns.extend(
                pattern(y_iter, x_iter, n)
                for y_iter, n in Yi_split_iters(dims, y_col, split_limit)
            )
    return patterns

def ifm_shim_repeat_counts(dims: ConvDims, idx: int, loop_iter_mode: bool = True) -> List[int]:
    """Repeat counts for an IFM shim MM2S transfer, one entry per (X, Y) loop iteration.

    Args:
        dims: Layer dimensions.
        idx: Index of the access pattern in ``ifm_shim_mm2s(dims, col)``. It is only
            used for non-X8 split with IFM streaming in ``loop_iter_mode``. There the
            patterns are ordered X-outer / Y-inner, so ``idx`` equals the loop
            iteration ``x * Y_loop + y`` at which the pattern is issued.
        loop_iter_mode: True when one shim transfer is generated per access pattern
            (BD chaining). False for the packed transfer, which (outside X8 split)
            gets ``Co_loop`` on every iteration.

    Returns:
        A list of length ``dims.Y_loop * dims.X_loop``.
    """
    num_iters = dims.Y_loop * dims.X_loop
    if dims.is_X8_split:
        # NOTE(review): this branch tests Co_loop > 1 while ifm_shim_mm2s tests
        # Co_loop >= 1 for streaming; confirm the Co_loop == 1 case.
        if dims.enable_ifm_streaming and (dims.Co_loop > 1) and (not dims.is_standalone_dwc):
            repeat_counts = [dims.Co_loop] * dims.Y_loop
        else:
            repeat_counts = [0] * dims.Y_loop
            for start_iter, _ in Yi_split_iters(dims, 0, 64):
                repeat_counts[start_iter] = 1
        return repeat_counts * dims.X_loop
    if not loop_iter_mode:
        return [dims.Co_loop] * num_iters
    if dims.enable_ifm_streaming and (dims.Co_loop >= 1) and (not dims.is_standalone_dwc):
        # One pattern per loop iteration: issue pattern ``idx`` only at iteration ``idx``.
        # The list must span all Y_loop * X_loop iterations, because idx reaches
        # Y_loop * X_loop - 1 (a Y_loop-long list would overflow once X_loop > 1).
        repeat_counts = [0] * num_iters
        repeat_counts[idx] = dims.Co_loop
        return repeat_counts
    repeat_counts = [0] * dims.Y_loop
    repeat_counts[0] = 1
    return repeat_counts * dims.X_loop

# NOTE: A single repeat-count entry is limited to 1024. When the weight memtile
#       repeat count is larger than that, it is split into several phases, each
#       starting at a later Y iteration.
def wgt_repeat_counts(dims: ConvDims, scale: int) -> List[int]:
    """Repeat counts for the weight memtile transfer, one entry per (X, Y) loop iteration.

    Args:
        dims: Layer dimensions.
        scale: Total weight repeats per X iteration (or in total with ``pin_wgt_bias_l1``).

    Returns:
        A list of length ``dims.Y_loop * dims.X_loop``. With ``pin_wgt_bias_l1`` only
        the very first entry (``scale``) is non-zero. Otherwise each X iteration
        gets ``scale`` repeats, split into phases of at most 1024 when needed.
        Each phase covers whole Y iterations.

    Raises:
        ValueError: If one Y iteration alone would need more than 1024 repeats.
    """
    if dims.pin_wgt_bias_l1:
        repeat_counts = [0] * (dims.Y_loop * dims.X_loop)
        repeat_counts[0] = scale
        return repeat_counts

    repeat_counts = [0] * dims.Y_loop
    if scale <= 1024:
        repeat_counts[0] = scale
    else:
        # Repeats per Y iteration (ceil division of scale over Y_loop).
        phase_split = -(-scale // dims.Y_loop)
        # Whole Y iterations that fit into one phase of at most 1024 repeats.
        iter_step = 1024 // phase_split
        if iter_step == 0:
            raise ValueError(
                f'wgt repeat count {scale} needs {phase_split} repeats per Y iteration, '
                f'more than the 1024 allowed per phase'
            )
        # A phase must end exactly where the next one starts (iter_step iterations
        # later), so it carries iter_step * phase_split repeats, not a flat 1024.
        phase_repeat = iter_step * phase_split
        start_iter = 0
        remain_repeat = scale
        while remain_repeat > 1024:
            repeat_counts[start_iter] = phase_repeat
            remain_repeat -= phase_repeat
            start_iter += iter_step
        repeat_counts[start_iter] = remain_repeat
    return repeat_counts * dims.X_loop

def ifm_memtile_memory(dims: ConvDims, col: int) -> List[str]:
    """IFM memtile buffer layouts, one per (X iteration, Yi split), X outermost.

    Each layout is ``Yi Ci(Cim) Xi Ci(Ci_gran)``. Non-positive slice sizes fall
    back to the subvolume sizes ``Yis`` / ``Xis``.
    """
    y_col = 0 if dims.is_X8_split else col
    layouts = []
    for x_iter in range(dims.X_loop):
        for y_iter, _ in Yi_split_iters(dims, y_col):
            _, _, _, yi = Yi_slice_iter(dims, y_col, y_iter)
            _, _, _, xi = Xi_slice_iter(dims, col, x_iter)
            yi = yi if yi > 0 else dims.Yis
            xi = xi if xi > 0 else dims.Xis
            layouts.append(f'Yi:{yi} Ci:{dims.Cim} Xi:{xi} Ci:{dims.Ci_gran}')
    return layouts

def ifm_memtile_s2mm(dims: ConvDims, col: int) -> List[str]:
    """IFM memtile S2MM write patterns, one per (X iteration, Yi split), X outermost.

    Each pattern covers the full Yi x Xi slice with all ``Cim`` channels. The
    order is Yi/Xi/Ci when ``ifm_use_hwc_format`` is set; otherwise Ci is split
    into Ci_gran blocks around Xi.
    """
    y_col = 0 if dims.is_X8_split else col
    result = []
    for x_iter in range(dims.X_loop):
        for y_iter, _ in Yi_split_iters(dims, y_col):
            _, _, _, yi = Yi_slice_iter(dims, y_col, y_iter)
            _, _, _, xi = Xi_slice_iter(dims, col, x_iter)
            if not dims.ifm_use_hwc_format:
                result.append(f'Yi:0:{yi} Ci:0:{dims.Cim}:{dims.Ci_gran} Xi:0:{xi} Ci:0:{dims.Ci_gran}')
            else:
                result.append(f'Yi:0:{yi} Xi:0:{xi} Ci:0:{dims.Cim}')
    return result

def ifm_memtile_mm2s(dims: ConvDims, col: int, row: int) -> List[str]:
    """IFM memtile MM2S read patterns for core ``row``, one per (X iteration, Yi split).

    The Y window starts at ``min(Yi_start, 0)`` and spans ``Yis``. For X8 split,
    the X window does the same with ``Xis``; otherwise it comes from
    ``Xi_slice(dims, row)``. A standalone DWC reads only this row's Co block of
    channels (requires ``Cis == Cos``); otherwise all ``Cis`` channels.
    """
    y_col = 0 if dims.is_X8_split else col

    def pattern(y_iter: int, x_iter: int) -> str:
        y_start, _, _, _ = Yi_slice_iter(dims, y_col, y_iter)
        y_lo = min(y_start, 0)
        if not dims.is_X8_split:
            x_lo, x_hi, _, _ = Xi_slice(dims, row)
        else:
            x_start, _, _, _ = Xi_slice_iter(dims, col, x_iter)
            x_lo = min(x_start, 0)
            x_hi = x_lo + dims.Xis
        if not dims.is_standalone_dwc:
            c_lo, c_hi, c_step = 0, dims.Cis, dims.Cis
        else:
            assert dims.Cis == dims.Cos
            c_lo = Co_index(dims, row) * dims.Cis
            c_hi = c_lo + dims.Cis
            c_step = dims.Cis * dims.Co_split
        return (
            f'Ci:0:{dims.Cim}:{c_step} '
            f'Yi:{y_lo}:{y_lo + dims.Yis} Ci:{c_lo}:{c_hi}:{dims.Ci_gran} Xi:{x_lo}:{x_hi} Ci:0:{dims.Ci_gran}'
        )

    return [
        pattern(y_iter, x_iter)
        for x_iter in range(dims.X_loop)
        for y_iter, _ in Yi_split_iters(dims, y_col)
    ]

def wgt_shim_memory(dims: ConvDims) -> str:
    """Weight shim buffer layout: Co subvolumes (Co_loop * Co_split) x Ci subvolumes x subvolume size."""
    total_cos = dims.Co_loop * dims.Co_split
    return f'Cos:{total_cos} Cis:{dims.Ci_loop} Subv:{dims.wgt_subv_size}'

def wgt_shim_mm2s(dims: ConvDims, row: int) -> str:
    """Weight shim MM2S pattern for the Co index of core ``row``.

    Steps over all Co subvolumes in strides of ``Co_split`` and, within each
    stride, selects the single slot ``Co_index(dims, row)``.
    """
    co = Co_index(dims, row)
    total_cos = dims.Co_loop * dims.Co_split
    return f'Cos:0:{total_cos}:{dims.Co_split} Cos:{co}:{co + 1} Cis Subv'

def wgt_memtile_memory(dims: ConvDims) -> str:
    """Weight memtile buffer layout.

    With weight reuse, the buffer holds every Co x Ci subvolume. Otherwise it is
    a flat buffer of one subvolume, or of ``wgt_memtile_size`` for X8 split.
    """
    if not dims.enable_wgt_reuse:
        subv = dims.wgt_memtile_size if dims.is_X8_split else dims.wgt_subv_size
        return f'Subv:{subv}'
    return f'Cos:{dims.Co_loop} Cis:{dims.Ci_loop} Subv:{dims.wgt_subv_size}'

def wgt_memtile_s2mm(dims: ConvDims) -> str:
    """Weight memtile S2MM pattern: a Cos/Cis/Subv walk with weight reuse, else a flat Subv."""
    return 'Cos Cis Subv' if dims.enable_wgt_reuse else 'Subv'

def wgt_memtile_mm2s(dims: ConvDims) -> str:
    """Weight memtile MM2S pattern: a Cos/Cis/Subv walk with weight reuse, else a flat Subv."""
    if not dims.enable_wgt_reuse:
        return 'Subv'
    return 'Cos Cis Subv'

def ofm_shim_memory(dims: ConvDims) -> str:
    """OFM shim buffer layout: HWC order, or Co split into Co_gran blocks around Xo."""
    if not dims.ofm_use_hwc_format:
        return f'Yo:{dims.Yo} Co:{dims.Co} Xo:{dims.Xo} Co:{dims.Co_gran}'
    return f'Yo:{dims.Yo} Xo:{dims.Xo} Co:{dims.Co}'

def ofm_shim_s2mm(dims: ConvDims, col: int, Co_split: int) -> List[str]:
    """OFM shim S2MM write patterns for ``col``, one per (X iteration, Yo split).

    For X8 split, each pattern writes the ``Co_split``-th Co window of width
    ``min(Com * mt_co_pack, Co)``. Otherwise it walks all of Co in
    ``min(Com, Co)`` steps.
    """
    y_col = 0 if dims.is_X8_split else col
    assert 0 <= col < dims.aie_cols

    def pattern(y_iter: int, x_iter: int, count: int) -> str:
        y_lo, y_hi, y_step, _ = Yo_slice_iter(dims, y_col, y_iter)
        x_lo, x_hi, _, _ = Xo_slice_iter(dims, col, x_iter)
        if dims.is_X8_split:
            co_step = min(dims.Com * dims.mt_co_pack, dims.Co)
            co_lo = Co_split * co_step
            co_hi = co_lo + co_step
            co_outer = co_step
        else:
            co_step = min(dims.Com, dims.Co)
            co_lo, co_hi = 0, co_step
            co_outer = dims.Co
        head = f'Yo:0:{y_step * count}:{y_step} Co:0:{co_outer}:{co_step} '
        if dims.ofm_use_hwc_format:
            return head + f'Yo:{y_lo}:{y_hi} Xo:{x_lo}:{x_hi} Co:{co_lo}:{co_hi}'
        return head + f'Yo:{y_lo}:{y_hi} Co:{co_lo}:{co_hi}:{dims.Co_gran} Xo:{x_lo}:{x_hi} Co:0:{dims.Co_gran}'

    patterns = []
    for x_iter in range(dims.X_loop):
        for y_iter, count in Yo_split_iters(dims, y_col):
            patterns.append(pattern(y_iter, x_iter, count))
    return patterns

def ofm_memtile_memory(dims: ConvDims, col: int) -> List[str]:
    """OFM memtile buffer layout, repeated once per (X iteration, Yo split)."""
    layout = f'Yo:{dims.Yos} Xo:{dims.Xom} Co:{dims.Com * dims.mt_co_pack}'
    num_splits = len(Yo_split_iters(dims, col if not dims.is_X8_split else 0))
    return [layout] * (num_splits * dims.X_loop)

def ofm_memtile_s2mm(dims: ConvDims, col: int, row: int, co_index: int) -> List[str]:
    """OFM memtile S2MM pattern for core ``row``, repeated per (X iteration, Yo split).

    The pattern writes this row's Co block (``Co_index * Cos``) and its X window
    (``X_index * Xos``, or 0 for X8 split). For X8 split, it is prefixed with a
    Co dimension that selects the ``co_index``-th ``Com``-wide chunk of the
    packed ``Com * mt_co_pack`` axis.
    """
    x_lo = 0 if dims.is_X8_split else X_index(dims, row) * dims.Xos
    co_lo = Co_index(dims, row) * dims.Cos
    pattern = (
        f'Yo Co:{co_lo}:{co_lo + dims.Cos}:{dims.Co_gran} '
        f'Xo:{x_lo}:{x_lo + dims.Xos} Co:0:{dims.Co_gran}'
    )
    if dims.is_X8_split:
        packed_co = dims.Com * dims.mt_co_pack
        pattern = f'Co:{co_index * dims.Com}:{(co_index + 1) * dims.Com}:{packed_co} ' + pattern
    count = len(Yo_split_iters(dims, col if not dims.is_X8_split else 0)) * dims.X_loop
    return [pattern] * count

def ofm_memtile_mm2s(dims: ConvDims, col: int) -> List[str]:
    """OFM memtile MM2S read patterns, one per (X iteration, Yo split), X outermost.

    Each pattern reads the full Yo x Xo slice with ``min(Com * mt_co_pack, Co)``
    channels, in HWC order or with Co split into Co_gran blocks around Xo.
    """
    y_col = 0 if dims.is_X8_split else col
    result = []
    for x_iter in range(dims.X_loop):
        for y_iter, _ in Yo_split_iters(dims, y_col):
            _, _, _, yo = Yo_slice_iter(dims, y_col, y_iter)
            _, _, _, xo = Xo_slice_iter(dims, col, x_iter)
            co = min(dims.Com * dims.mt_co_pack, dims.Co)
            if not dims.ofm_use_hwc_format:
                result.append(f'Yo:0:{yo} Co:0:{co}:{dims.Co_gran} Xo:0:{xo} Co:0:{dims.Co_gran}')
            else:
                result.append(f'Yo:0:{yo} Xo:0:{xo} Co:0:{co}')
    return result


def compile_dataflow(
    dims: ConvDims,
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
):
    """Describe all data movement of a conv layer and compile it for the overlay.

    Allocates core and memtile buffers and generates per-core instruction
    streams. It then builds the memtile and shim ``DataTransfer`` lists
    (layer params, xint8-only conv kernel params, IFM, weights, OFM) and
    passes everything to ``run_layer_compilation``.

    Args:
        dims: Layer shape/tiling description.
        back_end: Back end forwarded to ``run_layer_compilation``.
        kernel_names: Kernel entry points forwarded to ``run_layer_compilation``.
        kernel_includes: Kernel include files forwarded to ``run_layer_compilation``.
    """
    # Core Buffer Allocation

    core_alloc = conv_core_alloc(dims, overlay_stack_addr())

    # The memtile OFM is double-buffered iff the core allocation has an OFM pong buffer.
    if core_alloc.ofm_pong_addr is not None:
        memtile_ofm_buffering = 2
    else:
        memtile_ofm_buffering = 1
    has_memtile_ofm_double_buffer = memtile_ofm_buffering == 2

    # Memtile Buffer Allocation
    #
    # Buffers are laid out back to back from address 0:
    #   params | conv kernel params (xint8 only) | IFM | weights | OFM
    # The IFM is ping/pong only with IFM streaming. Weights use a single buffer
    # with weight reuse and ping/pong otherwise. The OFM follows the core OFM
    # buffering.
    ofm_memtile_ping_addr = None
    ofm_memtile_pong_addr = None
    ofm_memtile_addr = None
    if dims.enable_ifm_streaming:
        prm_memtile_addr = 0
        conv_kernel_param_memtile_addr = prm_memtile_addr + dims.prm_memtile_size
        ifm_memtile_ping_addr = (
            (conv_kernel_param_memtile_addr + dims.conv_kernel_param_size)
            if dims.is_xint8 else
            (prm_memtile_addr + dims.prm_memtile_size)
        )
        ifm_memtile_pong_addr = ifm_memtile_ping_addr + dims.ifm_memtile_size
        if dims.enable_wgt_reuse:
            wgt_memtile_addr = ifm_memtile_pong_addr + dims.ifm_memtile_size
            if has_memtile_ofm_double_buffer:
                ofm_memtile_ping_addr = wgt_memtile_addr + dims.wgt_memtile_size
                ofm_memtile_pong_addr = ofm_memtile_ping_addr + dims.ofm_memtile_size
            else:
                ofm_memtile_addr = wgt_memtile_addr + dims.wgt_memtile_size
        else:
            wgt_memtile_ping_addr = ifm_memtile_pong_addr + dims.ifm_memtile_size
            wgt_memtile_pong_addr = wgt_memtile_ping_addr + dims.wgt_memtile_size
            if has_memtile_ofm_double_buffer:
                ofm_memtile_ping_addr = wgt_memtile_pong_addr + dims.wgt_memtile_size
                ofm_memtile_pong_addr = ofm_memtile_ping_addr + dims.ofm_memtile_size
            else:
                ofm_memtile_addr = wgt_memtile_pong_addr + dims.wgt_memtile_size

        ifm_memtile_addrs = [ifm_memtile_ping_addr, ifm_memtile_pong_addr]
        ifm_memtile_repeat_scale = (dims.Co_loop * dims.Ci_loop) // dims.num_ifm_subv
        ifm_memtile_reuse_ratio = 1
    else:
        prm_memtile_addr = 0
        conv_kernel_param_memtile_addr = prm_memtile_addr + dims.prm_memtile_size
        ifm_memtile_addr = (
            (conv_kernel_param_memtile_addr + dims.conv_kernel_param_size)
            if dims.is_xint8 else
            (prm_memtile_addr + dims.prm_memtile_size)
        )
        if dims.enable_wgt_reuse:
            wgt_memtile_addr = ifm_memtile_addr + dims.ifm_memtile_size
            # Single weight buffer: "pong" aliases it so the OFM goes right after it.
            wgt_memtile_pong_addr = wgt_memtile_addr
        else:
            wgt_memtile_ping_addr = ifm_memtile_addr + dims.ifm_memtile_size
            wgt_memtile_pong_addr = wgt_memtile_ping_addr + dims.wgt_memtile_size
        if has_memtile_ofm_double_buffer:
            ofm_memtile_ping_addr = wgt_memtile_pong_addr + dims.wgt_memtile_size
            ofm_memtile_pong_addr = ofm_memtile_ping_addr + dims.ofm_memtile_size
        else:
            ofm_memtile_addr = wgt_memtile_pong_addr + dims.wgt_memtile_size

        ifm_memtile_addrs = [ifm_memtile_addr]
        ifm_memtile_repeat_scale = 1
        if dims.pin_ifm_l1:
            ifm_memtile_reuse_ratio = 1
        else:
            ifm_memtile_reuse_ratio = dims.Co_loop

    if dims.enable_wgt_reuse:
        wgt_memtile_addrs = [wgt_memtile_addr]
        wgt_memtile_repeat_count = 1
        wgt_memtile_repeat_count_list = [wgt_memtile_repeat_count] + [0] * (dims.Y_loop - 1)
        if dims.pin_wgt_bias_l1 and not dims.is_X8_split:
            wgt_memtile_reuse_ratio = 1
        else:
            wgt_memtile_reuse_ratio = dims.Y_loop
        wgt_shim_repeat_count = 1
    else:
        wgt_memtile_addrs = [wgt_memtile_ping_addr, wgt_memtile_pong_addr]
        wgt_memtile_repeat_count = dims.Co_loop * dims.Ci_loop * dims.Y_loop if not dims.is_X8_split else \
                                   (dims.Y_loop * dims.Co_loop * dims.Ci_loop // dims.num_pack_wgt_subv ) if not dims.pin_wgt_bias_l1 else 1
        # Above 1024 the count is spread as Co_loop * Ci_loop repeats on every Y iteration.
        wgt_memtile_repeat_count_list = [dims.Co_loop * dims.Ci_loop] * dims.Y_loop if wgt_memtile_repeat_count > 1024 else [wgt_memtile_repeat_count] + [0] * (dims.Y_loop - 1)
        wgt_memtile_reuse_ratio = 1
        wgt_shim_repeat_count = dims.Y_loop if not dims.is_X8_split else (dims.Y_loop if not dims.pin_wgt_bias_l1 else 1)

    if has_memtile_ofm_double_buffer:
        assert(ofm_memtile_pong_addr is not None and ofm_memtile_ping_addr is not None)
        ofm_memtile_addrs = [ofm_memtile_ping_addr, ofm_memtile_pong_addr]
    else:
        assert(ofm_memtile_addr is not None)
        ofm_memtile_addrs = [ofm_memtile_addr]

    # Shim Buffer Allocation

    conv_shim_alloc = shim_alloc()

    # Core instructions: conv_core_instrs gets the flattened Y * X loop count plus
    # Co_loop and Ci_loop, and an IFM buffer config on core S2MM channel 0.
    core_instrs = {}
    for col in range(dims.aie_cols):
        for row in range(dims.aie_rows):
            core_instrs[AieTile(TileType.Core, col, row)] = conv_core_instrs(
                dims, core_alloc,
                dims.Y_loop * dims.X_loop,
                dims.Co_loop,
                dims.Ci_loop,
                ifm_config=generate_core_buffer_config(
                    core_dma(col, DmaDir.S2MM, 0),
                    core_alloc.ifm_ping_addr, core_alloc.ifm_pong_addr,
                    ifm_core_memory(dims),
                    ifm_core_s2mm(dims, row),
                    bits_per_block=dims.ifm_bits,
                ),
            )
    # On the 8x4 overlay, weight and kernel-param transfers are issued only from
    # every other column (presumably shared with the neighbour through the 8x4
    # DMA connections -- confirm in overlay_8x4_dma_connections).
    bcst_col_step = 2 if (dims.aie_cols, dims.aie_rows) == (8, 4) else 1
    # Layer params: received once on memtile S2MM 0, sent to every core row.
    memtile_transfers = [
        DataTransfer(
            [1] + [0] * (dims.Y_loop * dims.X_loop - 1),
            AieTile(TileType.Memtile, col), [prm_memtile_addr], dims.prm_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 0),
                prm_memtile_memory(dims),
                prm_memtile_s2mm(),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, row),
                prm_memtile_memory(dims),
                prm_memtile_mm2s(row),
            ) for row in range(dims.aie_rows)],
            sync_strategy=SyncStrategy.Serial_M_to_N,
        ) for col in range(dims.aie_cols)
    ]
    # NOTE: Extra transfer of the conv kernel params, only for xint8 (kernel requirement).
    memtile_kernel_params_transfer = [
        DataTransfer(
            [1] + [0] * (dims.Y_loop * dims.X_loop - 1),
            AieTile(TileType.Memtile, col), [conv_kernel_param_memtile_addr], dims.conv_kernel_param_size,
            [generate_transfer_params(
                # Reuses the weight data channel (S2MM 1) to receive the kernel params from the shim.
                memtile_dma(col, DmaDir.S2MM, 1),
                conv_kernel_prm_memtile_memory(dims, dims.conv_kernel_param_size),
                prm_memtile_s2mm(),
            )],
            [generate_transfer_params(
                # Reuses the weight data channel (MM2S 4) to send the kernel params to the cores.
                memtile_dma(col, DmaDir.MM2S, 4),
                conv_kernel_prm_memtile_memory(dims, dims.conv_kernel_param_size),
                # Always 0: the kernel params are always read from offset 0.
                prm_memtile_mm2s(0))],
        ) for col in range(0, dims.aie_cols, bcst_col_step) if dims.is_xint8
    ]
    memtile_ifm_transfer = [
        DataTransfer(
            Yi_repeat_counts(dims, col, ifm_memtile_repeat_scale),
            AieTile(TileType.Memtile, col), ifm_memtile_addrs, dims.ifm_memtile_size,
            [pack_reconfig_transfers(
                memtile_dma(col, DmaDir.S2MM, 0),
                ifm_memtile_memory(dims, col),
                ifm_memtile_s2mm(dims, col),
                [n for _, n in Yi_split_iters(dims, 0 if dims.is_X8_split else col)] * dims.X_loop,
                dims.ifm_bits,
            )],
            [pack_reconfig_transfers(
                memtile_dma(col, DmaDir.MM2S, row),
                ifm_memtile_memory(dims, col),
                ifm_memtile_mm2s(dims, col, row),
                [n for _, n in Yi_split_iters(dims, 0 if dims.is_X8_split else col)] * dims.X_loop,
                dims.ifm_bits,
            ) for row in range(dims.aie_rows)
            for _ in range(ifm_chain_length(dims) if not dims.enable_ifm_streaming else 1)],
            sync_strategy=SyncStrategy.Parallel_1_to_N,
            reuse_ratio=ifm_memtile_reuse_ratio // ifm_chain_length(dims) if not dims.enable_ifm_streaming else ifm_memtile_reuse_ratio,
        ) for col in range(dims.aie_cols)
    ]
    if dims.is_X8_split:
        wgt_chain = 1
        wgt_reuse = wgt_memtile_reuse_ratio if not dims.pin_wgt_bias_l1 else 1
    else:
        wgt_chain = wgt_chain_length(wgt_memtile_reuse_ratio) if dims.enable_wgt_reuse else 1
        wgt_reuse = wgt_memtile_reuse_ratio // wgt_chain_length(wgt_memtile_reuse_ratio) if dims.enable_wgt_reuse else wgt_memtile_reuse_ratio
    memtile_wgt_transfer = [
        DataTransfer(
            wgt_repeat_counts(dims, wgt_memtile_repeat_count) if dims.is_X8_split else wgt_memtile_repeat_count_list,
            AieTile(TileType.Memtile, col), wgt_memtile_addrs, dims.wgt_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 1),
                wgt_memtile_memory(dims),
                wgt_memtile_s2mm(dims),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, 4),
                wgt_memtile_memory(dims),
                wgt_memtile_mm2s(dims),
            ) for _ in range(wgt_chain)],
            reuse_ratio=wgt_reuse,
        ) for col in range(0, dims.aie_cols, bcst_col_step)
    ]
    # NOTE: why dims.mt_co_pack
    #   1. With the X8 split, the shim OFM may be out of dimension because of the Co split.
    #   2. Chained (parallel) BDs could then be needed for the shim to receive the
    #      OFM data, but the total number of shim BDs is limited.
    #   3. So Co is combined (packed in the memtile), as far as memtile space allows,
    #      before being sent to the shim.
    #   4. This reduces the number of BDs used in the shim.
    memtile_ofm_transfer = [
        DataTransfer(
            Yo_repeat_counts(dims, col, dims.Co_loop // dims.mt_co_pack),
            AieTile(TileType.Memtile, col), ofm_memtile_addrs, dims.ofm_memtile_size,
            [pack_reconfig_transfers(
                memtile_dma(col, DmaDir.S2MM, 2 + row),
                ofm_memtile_memory(dims, col),
                ofm_memtile_s2mm(dims, col, row, co_index),
                [n for _, n in Yo_split_iters(dims, 0 if dims.is_X8_split else col)] * dims.X_loop,
                dims.ofm_bits,
            ) for row in range(dims.aie_rows) for co_index in range(dims.mt_co_pack)],
            [pack_reconfig_transfers(
                memtile_dma(col, DmaDir.MM2S, 5),
                ofm_memtile_memory(dims, col),
                ofm_memtile_mm2s(dims, col),
                [n for _, n in Yo_split_iters(dims, 0 if dims.is_X8_split else col)] * dims.X_loop,
                dims.ofm_bits,
            )],
            sync_strategy=SyncStrategy.Parallel_N_to_1,
        ) for col in range(dims.aie_cols)
    ]
    memtile_transfers += (memtile_kernel_params_transfer + memtile_ifm_transfer + memtile_wgt_transfer + memtile_ofm_transfer)

    max_chain_length = 4

    shim_transfers = [
        generate_shim_data_transfer(
            [1] + [0] * (dims.Y_loop * dims.X_loop - 1),
            shim_dma(col, DmaDir.MM2S, 0), conv_shim_alloc.prm_buffer_id,
            prm_shim_memory(dims),
            prm_shim_mm2s(col),
        ) for col in range(dims.aie_cols)
    ]
    # NOTE: conv kernel params are sent only for xint8 (conv kernel requirement).
    shim_kernel_params_transfer = [
        generate_shim_data_transfer(
            [1] + [0] * (dims.Y_loop * dims.X_loop - 1),
            shim_dma(col, DmaDir.MM2S, 1), conv_shim_alloc.wgt_buffer_id,
            conv_kernel_prm_shim_memory(dims,  dims.conv_kernel_param_size),
            # Always 0, so every column reads from the start of the memory format.
            prm_shim_mm2s(0),
            ) for col in range(0,dims.aie_cols,bcst_col_step) if dims.is_xint8
    ]

    def ifm_shim_packed_iters(col: int) -> Tuple[int, List[int], List[int]]:
        """Return ``(len_mem, num_iter, start_iter)`` for ``col``'s packed IFM shim transfer."""
        if dims.enable_ifm_streaming:
            # NOTE(review): ifm_shim_mm2s only emits per-iteration patterns when
            # Co_loop >= 1 and not is_standalone_dwc; confirm the lengths still match
            # for a streaming standalone DWC. Also, unlike the split case below,
            # start_iter is not offset by i * Y_loop for later X iterations.
            return dims.Y_loop, [1] * dims.Y_loop, list(range(dims.Y_loop)) * dims.X_loop
        max_iter = 64 if dims.is_X8_split else 0
        # Computed per column: outside X8 split, the Yi split depends on the column,
        # and these lists must line up with ifm_shim_mm2s(dims, col) for the same column.
        splits = Yi_split_iters(dims, 0 if dims.is_X8_split else col, max_iter)
        return (
            len(splits),
            [n for _, n in splits],
            [s + i * dims.Y_loop for i in range(dims.X_loop) for s, _ in splits],
        )

    if len(ifm_shim_mm2s(dims, 0)) <= max_chain_length and not dims.is_X8_split:
        # NOTE: BD chaining -- one shim transfer per IFM access pattern.
        shim_ifm_transfer = [
            generate_shim_data_transfer(
                ifm_shim_repeat_counts(dims, idx),
                shim_dma(col, DmaDir.MM2S, 0), conv_shim_alloc.ifm_buffer_id,
                ifm_shim_memory(dims),
                fmt,
                bits_per_block=dims.ifm_bits,
            ) for col in range(dims.aie_cols) for idx, fmt in enumerate(ifm_shim_mm2s(dims, col))
        ]
    else:
        # NOTE: Poll and re-enqueue -- one packed shim transfer per column.
        shim_ifm_transfer = [
            generate_packed_shim_data_transfer(
                ifm_shim_repeat_counts(dims, 0, False),
                shim_dma(col, DmaDir.MM2S, 0), conv_shim_alloc.ifm_buffer_id,
                ([ifm_shim_memory(dims)] * len_mem) * dims.X_loop,
                ifm_shim_mm2s(dims, col),
                num_iter * dims.X_loop,
                start_iter,
                dims.ifm_bits
            )
            for col in range(dims.aie_cols)
            for len_mem, num_iter, start_iter in [ifm_shim_packed_iters(col)]
        ]
    shim_wgt_transfer = [
        generate_shim_data_transfer(
            ([wgt_shim_repeat_count] + [0] * (dims.Y_loop - 1)) if not dims.is_X8_split else \
            (([wgt_shim_repeat_count] + [0] * (dims.Y_loop - 1)) * dims.X_loop) if not dims.pin_wgt_bias_l1 else ([wgt_shim_repeat_count] + ([0] * (dims.Y_loop * dims.X_loop - 1))),
            shim_dma(col, DmaDir.MM2S, 1), conv_shim_alloc.wgt_buffer_id,
            wgt_shim_memory(dims),
            wgt_shim_mm2s(dims, (col // bcst_col_step)),
            # xint8 only: skip the conv kernel params stored ahead of the weights in the same shim buffer.
            buffer_offset= ((dims.conv_kernel_param_size) if dims.is_xint8 else 0),
        ) for col in range(0, dims.aie_cols, bcst_col_step)
    ]

    shim_ofm_size = dims.Co * dims.Yo * dims.Xo * dims.ofm_bits // 8
    shim_ofm_transfer = [
        DataTransfer(
            Yo_repeat_counts(dims, col, 1, False),
            AieTile(TileType.Shim, col), [conv_shim_alloc.ofm_buffer_id], shim_ofm_size,
            [pack_reconfig_transfers(
                shim_dma(col, DmaDir.S2MM, 0),
                [ofm_shim_memory(dims)] * (len(Yo_split_iters(dims, 0 if dims.is_X8_split else col))) * dims.X_loop,
                ofm_shim_s2mm(dims, col, Co_split),
                [n for _, n in Yo_split_iters(dims, 0 if dims.is_X8_split else col)] * dims.X_loop,
                dims.ofm_bits,
            ) for Co_split in range(dims.Co_loop // dims.mt_co_pack)],
            [],
        ) for col in range(dims.aie_cols)
    ]  if dims.is_X8_split else [
        generate_shim_data_transfer(
            [1] + [0] * (dims.Y_loop - 1),
            shim_dma(col, DmaDir.S2MM, 0), conv_shim_alloc.ofm_buffer_id,
            ofm_shim_memory(dims),
            fmt,
            bits_per_block=dims.ofm_bits,
        ) for col in range(dims.aie_cols) for fmt in ofm_shim_s2mm(dims, col, 0)
    ]
    shim_transfers += (shim_kernel_params_transfer + shim_ifm_transfer + shim_wgt_transfer + shim_ofm_transfer)

    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        kernel_names,
        kernel_includes,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        overlay_8x4_dma_connections() if (dims.aie_cols, dims.aie_rows) == (8, 4) else overlay_4x4_dma_connections(),
        back_end,
        core_stack_addr=overlay_stack_addr(),
        param_channel_id=0
    )

def main():
    """Compile a sample conv layer and build its simulation overlay.

    The sample layer is 128 -> 128 channels at 56x56, with a 3x3 kernel,
    stride 1, padding 1 on every side and ``is_standalone_dwc = True``. It is
    therefore built with the ``dwc_main.cpp`` host program. The same
    ``backend`` value is used for both compilation and the simulation build.
    """
    Ci, Yi, Xi = 128, 56, 56
    Co, Yo, Xo = 128, 56, 56
    Cis, Yis, Xis = 32, 5, 64
    Cos, Yos, Xos = 32, 4, 32
    Ky, Kx = 3, 3
    Sy, Sx = 1, 1
    Py_b, Px_b, Py_a, Px_a = 1, 1, 1, 1
    X_split = 2
    Co_split = 2

    Ci_gran = 8
    Co_gran = 8
    X_gran = 4
    X_align = 64

    ifm_bits = 16
    wgt_bits = 8
    ofm_bits = 16
    tdm_bits = 32

    bias_bits = 64
    # C1/C2, shift tdm/res, zp_wgt are int32
    qdq_param_bytes = 5 * 4
    # Weight subvolume: packed weights plus (bias + QDQ params), each rounded
    # with iceil to X_align.
    wgt_subv_size = (
        iceil((Cos * Cis * Ky * Kx * wgt_bits) // 8, X_align) +
        iceil(((Cos * bias_bits) // 8) + qdq_param_bytes, X_align)
    )

    is_standalone_dwc = True

    dims = ConvDims(
        Ci, Cis, Ci_gran, Co, Cos, Co_gran, Co_split,
        Yi, Yis, Yo, Yos,
        Xi, Xis, Xo, Xos, X_gran, X_align, X_split,
        Ky, Kx,
        Sy, Sx,
        Py_b, Px_b, Py_a, Px_a,
        ifm_bits, wgt_bits, ofm_bits, tdm_bits,
        wgt_subv_size,
        is_standalone_dwc,
    )

    kernel_names = ['run_conv_a16w8_qdq']
    kernel_includes = [
        'super.hh',
        'conv/direct_conv_int16x8_generic/direct_conv_int16x8_generic_wrapper.cc',
    ]

    host_cpp = 'dwc_main.cpp' if is_standalone_dwc else 'conv_main.cpp'
    backend = BackEnd.Adf

    clean_overlay()
    # Pass the same back end to compilation and to the simulation build.
    compile_dataflow(dims, backend, kernel_names, kernel_includes)
    build_sim_overlay(backend, host_cpp, conv_preproc_directives(dims, backend))

if __name__ == '__main__':
    main()
