import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
from typing import Tuple, List

from dmacompiler import \
    OverlayShape, BackEnd, \
    DataTransfer, SyncStrategy, \
    AieTile, TileType, \
    AieDma, DmaDir, TransferParams, core_dma, memtile_dma, shim_dma, \
    compute_buffer_size, \
    generate_core_buffer_config, \
    generate_transfer_params, \
    generate_shim_data_transfer, \
    run_layer_compilation, \
    set_dev_gen, DevGen, config

from dataflow_common import \
    overlay_8x4_dma_connections, \
    overlay_4x4_dma_connections, \
    overlay_stack_addr, \
    clean_overlay, \
    build_sim_overlay, \
    ceildiv, \
    shim_alloc, \
    prm_shim_memory, \
    prm_shim_mm2s, \
    prm_memtile_memory, \
    prm_memtile_s2mm, \
    prm_memtile_mm2s

from conv_common import \
    ConvDims, \
    iceil, \
    X_index, \
    conv_input, \
    Co_index, \
    Xi_slice, \
    ifm_core_memory, \
    ifm_core_s2mm, \
    conv_preproc_directives, \
    conv_core_alloc, \
    conv_core_instrs
# Select DevGen.Aie2p as the device generation before any dataflow is built.
# NOTE(review): presumably a process-wide dmacompiler setting - confirm.
set_dev_gen(DevGen.Aie2p)
# Switch on the ENABLE_BUSY_POLL option of the imported dmacompiler `config`.
config.ENABLE_BUSY_POLL = True


def Yi_slice_iter(dims: ConvDims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    """Return the input-row window for column `col` at Y iteration `start_iter`.

    The result is ``(Yi_start, Yi_stop, Yi_stride, Yi_size)``. Start and stop
    are unclamped (start may be negative because of the ``Py_b`` offset, stop
    may exceed ``dims.Yi``); the size is the extent left after clamping both
    ends into ``[0, dims.Yi]``.
    """
    rows_per_slot = dims.Yos * dims.Sy
    if dims.is_X8_split:
        # In X8-split mode consecutive iterations advance by a single slot.
        stride = rows_per_slot
    else:
        stride = dims.aie_cols * rows_per_slot

    first = col * rows_per_slot + start_iter * stride - dims.Py_b
    # A window that starts past the end of the image is reported as empty.
    last = first + dims.Yis if first <= dims.Yi else first

    def clamp(row: int) -> int:
        return max(0, min(row, dims.Yi))

    return (first, last, stride, clamp(last) - clamp(first))


def Xi_slice_iter(dims: ConvDims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    """Return the input-column window for column `col` at X iteration `start_iter`.

    The result is ``(Xi_start, Xi_stop, Xi_stride, Xi_size)``. The stop is
    capped at ``dims.Xi + dims.Px_a``; a window starting past ``dims.Xi`` is
    reported with ``stop == start``. The size is not clamped at zero, so it
    can be negative for such windows.
    """
    cols_per_slot = dims.Xos * dims.Sx
    stride = dims.aie_cols * cols_per_slot
    first = col * cols_per_slot + start_iter * stride - dims.Px_b

    if first <= dims.Xi:
        window_end = first + conv_input(dims.Xos, dims.Kx, dims.Sx)
        last = min(window_end, dims.Xi + dims.Px_a)
    else:
        last = first

    size = min(dims.Xi, last) - max(0, first)
    return (first, last, stride, size)


def Xi_slice_mt(dims: ConvDims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    """Variant of the X input window with start/stop rebased for the memtile.

    The window is first computed exactly as in ``Xi_slice_iter``. The returned
    start is then 0 unless the absolute start is negative (in which case the
    negative value is kept), and the returned stop is ``stop - start`` when the
    absolute stop exceeds ``dims.Xi``, otherwise the window size.

    Returns ``(Xi_start, Xi_stop, Xi_stride, Xi_size)``.
    """
    cols_per_slot = dims.Xos * dims.Sx
    stride = dims.aie_cols * cols_per_slot
    abs_first = col * cols_per_slot + start_iter * stride - dims.Px_b

    if abs_first <= dims.Xi:
        window_end = abs_first + conv_input(dims.Xos, dims.Kx, dims.Sx)
        abs_last = min(window_end, dims.Xi + dims.Px_a)
    else:
        abs_last = abs_first

    size = min(dims.Xi, abs_last) - max(0, abs_first)

    # Rebase the stop: full span when the window runs past the image edge,
    # otherwise just the in-image size.
    rel_last = abs_last - abs_first if abs_last > dims.Xi else size
    # Rebase the start: non-negative starts collapse to 0.
    rel_first = 0 if abs_first >= 0 else abs_first

    return (rel_first, rel_last, stride, size)


def Yo_slice_iter(dims: ConvDims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    """Return the output-row slice for column `col` at Y iteration `start_iter`.

    The result is ``(Yo_start, Yo_stop, Yo_stride, Yo_size)``; the stop is
    capped at ``dims.Yo`` and a slice starting past ``dims.Yo`` is empty.
    """
    stride = dims.Yos if dims.is_X8_split else dims.aie_cols * dims.Yos
    first = col * dims.Yos + start_iter * stride
    last = first
    if first <= dims.Yo:
        last = min(first + dims.Yos, dims.Yo)
    return (first, last, stride, last - first)


def Xo_slice_iter(dims: ConvDims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    """Return the output-column slice for column `col` at X iteration `start_iter`.

    The result is ``(Xo_start, Xo_stop, Xo_stride, Xo_size)``; the stop is
    capped at ``dims.Xo`` and a slice starting past ``dims.Xo`` is empty.
    """
    stride = dims.aie_cols * dims.Xos
    first = col * dims.Xos + start_iter * stride
    last = first
    if first <= dims.Xo:
        last = min(first + dims.Xos, dims.Xo)
    return (first, last, stride, last - first)

def Co_slice_iter(dims: ConvDims, start_iter: int) -> Tuple[int, int, int, int]:
    """Return the output-channel slice covered by Co iteration `start_iter`.

    Each iteration spans ``dims.Cos * dims.Co_split`` channels. The result is
    ``(Co_start, Co_stop, Co_stride, Co_size)`` with the stop capped at
    ``dims.Co``; a slice starting past ``dims.Co`` is empty.
    """
    stride = dims.Cos * dims.Co_split
    first = start_iter * stride
    last = first
    if first <= dims.Co:
        last = min(first + stride, dims.Co)
    return (first, last, stride, last - first)



def Yi_split_iters(dims: ConvDims, col: int, max_iter: int = 0) -> List[Tuple[int, int]]:
    """Group the Y iterations of column `col` into runs of (start_iter, num_iters).

    A run is extended while the next iteration's input window needs no
    padding (start >= 0 and stop <= ``dims.Yi``). An iteration that does need
    padding always forms a run of length 1. When ``max_iter`` is non-zero a
    run is additionally limited to ``max_iter`` iterations; 0 means no limit.
    """
    def fits(first: int, length: int) -> bool:
        # Inspect the last iteration the run would contain at this length.
        lo, hi, _, _ = Yi_slice_iter(dims, col, first + length - 1)
        needs_padding = lo < 0 or hi > dims.Yi
        over_cap = max_iter != 0 and length == max_iter + 1
        return not needs_padding and not over_cap

    runs = []
    cursor = 0
    while cursor < dims.Y_loop:
        length = 1
        if fits(cursor, length):
            while fits(cursor, length + 1):
                length += 1
        runs.append((cursor, length))
        cursor += length
    return runs



def Yo_split_iters(dims: ConvDims, col: int) -> List[Tuple[int, int]]:
    """Group the Y iterations of column `col` into runs of (start_iter, num_iters).

    Consecutive iterations whose output slice is a full ``dims.Yos`` rows are
    merged into one run; an iteration with a partial slice forms a run of
    length 1.
    """
    def is_full(y_iter: int) -> bool:
        return Yo_slice_iter(dims, col, y_iter)[3] == dims.Yos

    runs = []
    first = 0
    while first < dims.Y_loop:
        length = 1
        if is_full(first):
            # Keep absorbing the following iterations while they are full too.
            while is_full(first + length):
                length += 1
        runs.append((first, length))
        first += length
    return runs


def Yi_repeat_counts(dims: ConvDims, col: int, scale: int) -> List[int]:
    """Build the per-phase repeat counts that match the Yi run splitting.

    The list has ``Y_loop * X_loop`` entries. For every X iteration, the entry
    at the first phase of each Yi run (runs computed for column 0, capped at
    64 iterations) is ``num_iters * scale``; all other entries stay 0.
    `col` is accepted but not used.
    """
    y_phases = dims.Y_loop
    counts = [0] * (y_phases * dims.X_loop)
    for x_iter in range(dims.X_loop):
        base = x_iter * y_phases
        for first, length in Yi_split_iters(dims, 0, 64):
            counts[first + base] = length * scale
    return counts

def Yo_repeat_counts(dims: ConvDims, col: int, scale: int, is_mt: bool = True) -> List[int]:
    """Build the per-phase repeat counts that match the Yo run splitting.

    The list has ``Y_loop * X_loop`` entries. For every X iteration, the entry
    at the first phase of each Yo run of column `col` is ``num_iters * scale``
    when ``is_mt`` is true and plain ``scale`` otherwise; the rest stay 0.
    """
    y_phases = dims.Y_loop
    counts = [0] * (y_phases * dims.X_loop)
    for x_iter in range(dims.X_loop):
        base = x_iter * y_phases
        for first, length in Yo_split_iters(dims, col):
            counts[first + base] = length * scale if is_mt else scale
    return counts


def pack_transfers(
    dma: AieDma,
    memory_fmts: List[str],
    tiling_fmts: List[str],
    tiling_iters: List[int],
    bits_per_elem: int,
    use_iter_step: bool = False,
) -> TransferParams:
    """Combine several per-format transfers into one ``TransferParams``.

    One transfer is generated per ``(memory_fmts[i], tiling_fmts[i])`` pair;
    the index-0 length/offset/step/wrap/padding of each is then repeated
    ``tiling_iters[i]`` times, in order, to form the packed parameter lists.
    Padding is enabled only when the DMA channel direction is MM2S.
    The three input lists must have equal length.
    """
    assert len(memory_fmts) == len(tiling_fmts)
    assert len(tiling_fmts) == len(tiling_iters)

    def expand(values: list) -> list:
        # Repeat each value by the iteration count of its format.
        assert len(values) == len(tiling_iters)
        expanded = []
        for value, count in zip(values, tiling_iters):
            expanded.extend([value] * count)
        return expanded

    per_fmt = []
    for i, tiling_fmt in enumerate(tiling_fmts):
        per_fmt.append(
            generate_transfer_params(
                dma,
                memory_fmts[i],
                tiling_fmt,
                bits_per_block=bits_per_elem,
                enable_padding=(dma.channel.dir == DmaDir.MM2S),
                use_iter_step=use_iter_step,
            )
        )

    lengths = expand([p.length_i(0) for p in per_fmt])
    offsets = expand([p.offset_i(0) for p in per_fmt])
    steps = expand([p.step_i(0) for p in per_fmt])
    wraps = expand([p.wrap_i(0) for p in per_fmt])
    paddings = expand([p.padding_i(0) for p in per_fmt])
    return TransferParams(
        dma,
        lengths,
        offset=offsets,
        step=steps,
        wrap=wraps,
        padding=paddings,
    )

def generate_packed_shim_data_transfer(
    repeat_counts: List[int],
    dma: AieDma,
    shim_buffer_idx: int,
    memory_fmts: List[str],
    tiling_fmts: List[str],
    tiling_iter_nums: List[int],
    tiling_start_iter: List[int],
    bits_per_elem: int,
    max_chain_length: int = 4, 
) -> DataTransfer:
    '''
    Reconfigures a BD with different transfer
    params at the shim for poll and re-enqueue.

    One transfer chain is generated per (memory_fmts[i], tiling_fmts[i]) pair
    with iteration stepping enabled and padding disabled; the index-0
    parameters of every generated transfer are then repeated according to
    tiling_iter_nums and packed into a single TransferParams. The packed
    params become the write side for an S2MM channel and the read side
    otherwise.

    repeat_counts is modified in place: the entry at tiling_start_iter[k] is
    multiplied by the repeat coefficient returned for format k. The same list
    object is passed on to the returned DataTransfer.

    The buffer size is computed from memory_fmts[0] only.
    '''
    assert len(memory_fmts) == len(tiling_fmts)
    assert len(tiling_fmts) == len(tiling_iter_nums)
    assert len(tiling_start_iter) == len(tiling_iter_nums)
    def pack(items: list) -> list:
        # Repeat each item by the iteration count of its format.
        assert len(items) == len(tiling_iter_nums)
        res = []
        for item, num in zip(items, tiling_iter_nums):
            res += [item] * num
        return res
    num_fmts = len(tiling_fmts)
    params = []
    repeat_coeff_iter = [0] * len(tiling_iter_nums)
    for i in range(num_fmts):
        # With max_chain_length supplied, the result is unpacked as a
        # (repeat_coeff, transfer_chain) pair.
        repeat_coeff, transfer_chain = generate_transfer_params(
                dma,
                memory_fmts[i],
                tiling_fmts[i],
                bits_per_block=bits_per_elem,
                enable_padding=False,
                use_iter_step=True,
                max_chain_length=max_chain_length
        )
        repeat_coeff_iter[i] = repeat_coeff
        # NOTE(review): every chain element is appended, but pack() asserts
        # len(items) == len(tiling_iter_nums); this appears to require exactly
        # one transfer per chain - confirm against generate_transfer_params.
        for transfer in transfer_chain:
            params.append(transfer)
    packed_params = TransferParams(
        dma,
        pack([param.length_i(0) for param in params]),
        offset=pack([param.offset_i(0) for param in params]),
        step=pack([param.step_i(0) for param in params]),
        wrap=pack([param.wrap_i(0) for param in params]),
        padding=pack([param.padding_i(0) for param in params]),
        iter_step=pack([param.iter_step_i(0) for param in params]),
        iter_wrap=pack([param.iter_wrap_i(0) for param in params]),
    )
    buffer_size = compute_buffer_size(memory_fmts[0], bits_per_elem)
    if dma.channel.dir == DmaDir.S2MM:
        write_params = [packed_params]
        read_params = []
    else:
        read_params = [packed_params]
        write_params = []            
    # Scale the caller's repeat counts in place (here `count` is a phase
    # index into repeat_counts, not a count).
    for idx, count in enumerate( tiling_start_iter):
        repeat_counts[count] *= repeat_coeff_iter[idx]
        
    return DataTransfer(
        repeat_counts,
        dma.tile, [shim_buffer_idx], buffer_size,
        write_params,
        read_params
)


def Xo_slice(dims: ConvDims, row: int) -> Tuple[int, int, int, int]:
    """Return the output-column slice handled by core row `row`.

    The result is ``(Xo_start, Xo_stop, Xo_stride, Xo_size)``; the slice is
    ``dims.Xos`` wide, capped at ``dims.Xo``, and empty when it starts at or
    past ``dims.Xo``.
    """
    stride = dims.Xos
    first = X_index(dims, row) * stride
    if first < dims.Xo:
        last = min(first + stride, dims.Xo)
    else:
        last = first
    return (first, last, stride, last - first)

def Co_slice(dims: ConvDims, row: int) -> Tuple[int, int, int, int]:
    """Return the output-channel slice handled by core row `row`.

    The result is ``(Co_start, Co_stop, Co_stride, Co_size)``; the slice is
    ``dims.Cos`` channels wide with the stop capped at ``dims.Co``, and the
    stride is ``dims.Cos * dims.Co_split``.
    """
    stride = dims.Cos * dims.Co_split
    first = Co_index(dims, row) * dims.Cos
    last = min(first + dims.Cos, dims.Co)
    return (first, last, stride, last - first)

def ofm_core_memory(dims: ConvDims) -> str:
    """Return the memory-format string of the OFM buffer in a core tile."""
    axes = (
        ('Yo', dims.Yos),
        ('Co', dims.Cos),
        ('Xo', dims.Xos),
        ('Co', dims.Co_gran),
    )
    return ' '.join(f'{axis}:{extent}' for axis, extent in axes)

def ofm_core_mm2s(dims: ConvDims, row: int) -> str:
    """Return the tiling-format string for reading the OFM out of core row `row`.

    The Xo and Co extents are the slice sizes reported by ``Xo_slice`` and
    ``Co_slice`` for this row.
    """
    *_, Xo_size = Xo_slice(dims, row)
    *_, Co_size = Co_slice(dims, row)
    parts = [
        f'Yo:0:{dims.Yos}',
        f'Co:0:{Co_size}:{dims.Co_gran}',
        f'Xo:0:{Xo_size}',
        f'Co:0:{dims.Co_gran}',
    ]
    return ' '.join(parts)

def ifm_shim_memory(dims: ConvDims) -> str:
    """Return the memory-format string of the IFM buffer at the shim.

    The axis order is Yi/Xi/Ci when ``dims.ifm_use_hwc_format`` is set, and
    Yi/Ci/Xi/Ci (inner Ci of size ``dims.Ci_gran``) otherwise.
    """
    if dims.ifm_use_hwc_format:
        axes = [('Yi', dims.Yi), ('Xi', dims.Xi), ('Ci', dims.Ci)]
    else:
        axes = [('Yi', dims.Yi), ('Ci', dims.Ci), ('Xi', dims.Xi), ('Ci', dims.Ci_gran)]
    return ' '.join(f'{axis}:{extent}' for axis, extent in axes)

def ifm_shim_mm2s(dims: ConvDims, col: int, Ci_split: int) -> List[str]:
    """Return the shim MM2S tiling formats for the IFM of column `col`.

    One format string is produced per (X iteration, Y run). Y windows are
    taken from column 0, X windows from `col`; both are clamped to the image
    bounds. ``Ci_split`` selects which ``Cim // shim_ci_split``-wide channel
    chunk this transfer reads.

    When IFM streaming is enabled with ``Co_loop > 1`` (and the layer is not a
    standalone DWC) every Y iteration gets its own single-iteration format;
    otherwise Y iterations are grouped by ``Yi_split_iters`` (cap 64).
    """
    def fmt(y_iter: int, x_iter: int, run_len: int) -> str:
        y_lo, y_hi, y_stride, _ = Yi_slice_iter(dims, 0, y_iter)
        x_lo, x_hi, _, _ = Xi_slice_iter(dims, col, x_iter)
        ci_width = dims.Cim // dims.shim_ci_split
        ci_lo = Ci_split * ci_width
        ci_hi = ci_lo + ci_width

        outer = f'Yi:0:{run_len * y_stride}:{y_stride} Ci:0:{dims.Ci}:{dims.Cim} '
        y_part = f'Yi:{max(0, y_lo)}:{min(y_hi, dims.Yi)}'
        x_part = f'Xi:{max(0, x_lo)}:{min(x_hi, dims.Xi)}'
        if dims.ifm_use_hwc_format:
            return outer + f'{y_part} {x_part} Ci:{ci_lo}:{ci_hi}'
        return outer + f'{y_part} Ci:{ci_lo}:{ci_hi}:{dims.Ci_gran} {x_part} Ci:0:{dims.Ci_gran}'

    per_iteration = dims.enable_ifm_streaming and (dims.Co_loop > 1) and (not dims.is_standalone_dwc)
    fs = []
    for x_iter in range(dims.X_loop):
        if per_iteration:
            y_runs = [(y_iter, 1) for y_iter in range(dims.Y_loop)]
        else:
            y_runs = Yi_split_iters(dims, 0, 64)
        for y_iter, run_len in y_runs:
            fs.append(fmt(y_iter, x_iter, run_len))
    return fs

# loop_iter_mode:
#   True:  each transfer is split per phase of Y_loop.
#   False: each transfer is split by iteration mode (usually 1 or 3 modes).
#   In loop mode the BD is reconfigured in every Y_loop phase; performance will be good.
#   NOTE(review): the original text read "will be be good" - possibly "will not be good"
#   was intended, since the next sentence treats per-phase splitting as extra reconfiguration.
#   Only when a large size produces a huge repeat_count that exceeds the repeat_count
#   limitation do we need to split over phases, at the cost of more reconfiguration.

def ifm_shim_repeat_counts(dims: ConvDims, idx: int,  loop_iter_mode: bool = True) -> List[int]:
    """Return the per-phase repeat counts for the shim IFM transfer.

    The list has ``Y_loop * X_loop`` entries. With IFM streaming enabled,
    ``Co_loop > 1`` and a layer that is not a standalone DWC, every phase
    repeats ``Co_loop`` times. Otherwise only the first phase of each Yi run
    (column 0, cap 64) gets a count of 1 and the rest stay 0.
    `idx` and `loop_iter_mode` are accepted but not used.
    """
    num_phases = dims.Y_loop * dims.X_loop

    if dims.enable_ifm_streaming and (dims.Co_loop > 1) and (not dims.is_standalone_dwc):
        return [dims.Co_loop] * max(num_phases, 0)

    counts = [0] * num_phases
    for x_iter in range(dims.X_loop):
        base = x_iter * dims.Y_loop
        for first, _ in Yi_split_iters(dims, 0, 64):
            counts[first + base] = 1
    return counts
 
   
# NOTE:
#   When the weight memtile repeat count is too big (> 1024), it has to be split
#   across multiple phases.
def wgt_repeat_counts(dims: ConvDims, scale: int)  -> List[int]:
    """Return the per-phase repeat counts for the weight transfers.

    The list has ``Y_loop * X_loop`` entries.

    * When ``dims.pin_wgt_bias_l1`` is set, only phase 0 carries ``scale``.
    * Otherwise a ``Y_loop``-long pattern is built and tiled ``X_loop`` times.
      A ``scale`` of at most 1024 goes into phase 0; a larger ``scale`` is
      split into chunks of 1024 (plus a remainder) placed ``iter_step``
      phases apart.

    Raises:
        ValueError: if, in the non-pinned case, ``scale`` exceeds
            ``1024 * Y_loop``. The chunks then cannot be spread over the
            available phases: the phase step would be 0, every chunk would
            land on phase 0, and the total repeat count would be silently
            truncated.
    """
    max_repeat = 1024  # largest repeat count placed in a single phase
    repeat_counts = [0] * dims.Y_loop

    if dims.pin_wgt_bias_l1:
        repeat_counts = repeat_counts * dims.X_loop
        repeat_counts[0] = scale
        return repeat_counts

    if scale <= max_repeat:
        repeat_counts[0] = scale
        return repeat_counts * dims.X_loop

    if scale > max_repeat * dims.Y_loop:
        raise ValueError(
            f'weight repeat count {scale} cannot be split over '
            f'{dims.Y_loop} phases of at most {max_repeat} repeats each'
        )

    phase_split = ceildiv(scale, dims.Y_loop)
    # Number of phases one full chunk of max_repeat repeats spans.
    iter_step = max_repeat // phase_split
    start_iter = 0
    remain_repeat = scale
    while remain_repeat > max_repeat:
        repeat_counts[start_iter] = max_repeat
        remain_repeat -= max_repeat
        start_iter += iter_step
    repeat_counts[start_iter] = remain_repeat
    return repeat_counts * dims.X_loop
   
def ifm_memtile_memory(dims: ConvDims, col: int) -> List[str]:
    """Return the memtile memory formats for the IFM of column `col`.

    One format string is produced per (X iteration, Yi run). The Yi/Xi extents
    are the clamped window sizes; a non-positive size falls back to the full
    sub-volume extent (``dims.Yis`` / ``dims.Xis``).
    """
    def fmt(y_iter: int, x_iter: int) -> str:
        *_, y_size = Yi_slice_iter(dims, 0, y_iter)
        *_, x_size = Xi_slice_iter(dims, col, x_iter)
        rows = y_size if y_size > 0 else dims.Yis
        cols = x_size if x_size > 0 else dims.Xis
        return f'Yi:{rows} Ci:{dims.Cim // dims.mt_ci_split} Xi:{cols} Ci:{dims.Ci_gran}'

    fs = []
    for x_iter in range(dims.X_loop):
        for y_iter, _ in Yi_split_iters(dims, 0):
            fs.append(fmt(y_iter, x_iter))
    return fs


def ifm_memtile_s2mm(dims: ConvDims, col: int) -> List[str]:
    """Return the memtile S2MM tiling formats for the IFM of column `col`.

    One format string is produced per (X iteration, Yi run). Window sizes that
    are zero or negative are written as 0. The axis order follows
    ``dims.ifm_use_hwc_format``.
    """
    def fmt(y_iter: int, x_iter: int) -> str:
        rows = max(0, Yi_slice_iter(dims, 0, y_iter)[3])
        cols = max(0, Xi_slice_iter(dims, col, x_iter)[3])
        ci_extent = dims.Cim // dims.mt_ci_split
        if dims.ifm_use_hwc_format:
            return f'Yi:0:{rows} Xi:0:{cols} Ci:0:{ci_extent}'
        return f'Yi:0:{rows} Ci:0:{ci_extent}:{dims.Ci_gran} Xi:0:{cols} Ci:0:{dims.Ci_gran}'

    fs = []
    for x_iter in range(dims.X_loop):
        for y_iter, _ in Yi_split_iters(dims, 0):
            fs.append(fmt(y_iter, x_iter))
    return fs


def ifm_memtile_mm2s(dims: ConvDims, col: int, row: int) -> List[str]:
    """Return the memtile MM2S tiling formats feeding the IFM to core `row`.

    One format string is produced per (X iteration, Yi run). The Yi range
    starts at ``min(Yi_start, 0)`` and spans ``dims.Yis`` rows; the Xi range
    starts at the memtile-relative start from ``Xi_slice_mt`` and spans the
    convolution input width. For a standalone DWC the channel range is the
    slice selected by ``Co_index(dims, row)``; otherwise it is the first
    ``dims.Cis`` channels.
    """
    def fmt(y_iter: int, x_iter: int) -> str:
        y_first = Yi_slice_iter(dims, 0, y_iter)[0]
        x_first = Xi_slice_mt(dims, col, x_iter)[0]
        x_last = x_first + conv_input(dims.Xos, dims.Kx, dims.Sx)

        if dims.is_standalone_dwc:
            assert dims.Cis == dims.Cos
            ci_first = Co_index(dims, row) * dims.Cis
            ci_last = ci_first + dims.Cis
            ci_step = dims.Cis * dims.Co_split
        else:
            ci_first, ci_last, ci_step = 0, dims.Cis, dims.Cis

        y_base = min(y_first, 0)
        pieces = [
            f'Ci:0:{dims.Cim // dims.mt_ci_split}:{ci_step}',
            f'Yi:{y_base}:{y_base + dims.Yis}',
            f'Ci:{ci_first}:{ci_last}:{dims.Ci_gran}',
            f'Xi:{x_first}:{x_last}',
            f'Ci:0:{dims.Ci_gran}',
        ]
        return ' '.join(pieces)

    fs = []
    for x_iter in range(dims.X_loop):
        for y_iter, _ in Yi_split_iters(dims, 0):
            fs.append(fmt(y_iter, x_iter))
    return fs

# NOTE: if Cos == Co == 8, then Co_loop == 1 and the memory allocation here could be optimized.
def wgt_shim_memory(dims: ConvDims) -> str:
    """Return the memory-format string of the weight buffer at the shim."""
    num_cos = dims.Co_loop * dims.Co_split
    axes = (('Cos', num_cos), ('Cis', dims.Ci_loop), ('Subv', dims.wgt_subv_size))
    return ' '.join(f'{axis}:{extent}' for axis, extent in axes)

def wgt_shim_mm2s(dims: ConvDims, row: int) -> str:
    """Return the shim MM2S tiling format for the weights of core row `row`.

    The outer Cos axis steps by ``dims.Co_split``; the inner Cos range picks
    the single entry at ``Co_index(dims, row)``.
    """
    first = Co_index(dims, row)
    outer = f'Cos:0:{dims.Co_loop * dims.Co_split}:{dims.Co_split}'
    inner = f'Cos:{first}:{first + 1}'
    return f'{outer} {inner} Cis Subv'

def wgt_memtile_memory(dims: ConvDims) -> str:
    """Return the memory-format string of the weight buffer in the memtile.

    With weight reuse enabled the buffer is laid out as Cos/Cis/Subv;
    otherwise it is a flat ``Subv`` of ``dims.wgt_memtile_size``.
    """
    if not dims.enable_wgt_reuse:
        return f'Subv:{dims.wgt_memtile_size}'
    axes = (('Cos', dims.Co_loop), ('Cis', dims.Ci_loop), ('Subv', dims.wgt_subv_size))
    return ' '.join(f'{axis}:{extent}' for axis, extent in axes)

def wgt_memtile_s2mm(dims: ConvDims) -> str:
    """Return the memtile S2MM tiling format for the weights.

    The format covers the full Cos/Cis/Subv layout when weight reuse is
    enabled and the flat Subv layout otherwise.
    """
    return 'Cos Cis Subv' if dims.enable_wgt_reuse else 'Subv'

def wgt_memtile_mm2s(dims: ConvDims) -> str:
    """Return the memtile MM2S tiling format for the weights.

    The format covers the full Cos/Cis/Subv layout when weight reuse is
    enabled and the flat Subv layout otherwise.
    """
    layout_by_reuse = {True: 'Cos Cis Subv', False: 'Subv'}
    return layout_by_reuse[bool(dims.enable_wgt_reuse)]

def ofm_shim_memory(dims: ConvDims) -> str:
    """Return the memory-format string of the OFM buffer at the shim.

    The axis order is Yo/Xo/Co when ``dims.ofm_use_hwc_format`` is set, and
    Yo/Co/Xo/Co (inner Co of size ``dims.Co_gran``) otherwise.
    """
    if dims.ofm_use_hwc_format:
        axes = [('Yo', dims.Yo), ('Xo', dims.Xo), ('Co', dims.Co)]
    else:
        axes = [('Yo', dims.Yo), ('Co', dims.Co), ('Xo', dims.Xo), ('Co', dims.Co_gran)]
    return ' '.join(f'{axis}:{extent}' for axis, extent in axes)


def ofm_shim_s2mm(dims: ConvDims, col: int, Co_split: int) -> List[str]:
    """Return the shim S2MM tiling formats for the OFM of column `col`.

    One format string is produced per (X iteration, Yo run). Y slices are
    taken from column 0 and X slices from `col`. ``Co_split`` selects which
    ``min(Com * mt_co_pack, Co)``-wide channel chunk this transfer writes.
    """
    assert 0 <= col < dims.aie_cols

    def fmt(y_iter: int, x_iter: int, run_len: int) -> str:
        y_lo, y_hi, y_stride, _ = Yo_slice_iter(dims, 0, y_iter)
        x_lo, x_hi, _, _ = Xo_slice_iter(dims, col, x_iter)

        co_width = min(dims.Com * dims.mt_co_pack, dims.Co)
        co_lo = Co_split * co_width
        co_hi = co_lo + co_width

        outer = f'Yo:0:{y_stride * run_len}:{y_stride} Co:0:{co_width}:{co_width} '
        if dims.ofm_use_hwc_format:
            inner = f'Yo:{y_lo}:{y_hi} Xo:{x_lo}:{x_hi} Co:{co_lo}:{co_hi}'
        else:
            inner = (
                f'Yo:{y_lo}:{y_hi} Co:{co_lo}:{co_hi}:{dims.Co_gran} '
                f'Xo:{x_lo}:{x_hi} Co:0:{dims.Co_gran}'
            )
        return outer + inner

    fs = []
    for x_iter in range(dims.X_loop):
        for y_iter, run_len in Yo_split_iters(dims, 0):
            fs.append(fmt(y_iter, x_iter, run_len))
    return fs


def ofm_memtile_memory(dims: ConvDims, col: int) -> List[str]:
    """Return the memtile memory formats for the OFM buffer.

    The same layout string is repeated once per Yo run of column 0.
    `col` is accepted but not used.

    NOTE(review): the list is not multiplied by ``dims.X_loop``, unlike the
    lists built by ofm_memtile_s2mm / ofm_memtile_mm2s - confirm that callers
    account for this when X_loop > 1.
    """
    layout = f'Yo:{dims.Yos} Xo:{dims.Xom} Co:{dims.Com * dims.mt_co_pack}'
    num_runs = len(Yo_split_iters(dims, 0))
    return [layout for _ in range(num_runs)]


def ofm_memtile_s2mm(dims: ConvDims, col: int, row: int, co_index: int) -> List[str]:
    """Return the memtile S2MM tiling formats receiving the OFM from core `row`.

    The same format string is repeated ``len(Yo runs of column 0) * X_loop``
    times. The outer Co range selects the ``co_index``-th ``dims.Com`` chunk;
    the inner Co range is the ``dims.Cos`` slice picked by
    ``Co_index(dims, row)``. `col` is accepted but not used.
    """
    co_lo = Co_index(dims, row) * dims.Cos
    co_hi = co_lo + dims.Cos
    pieces = [
        f'Co:{co_index * dims.Com}:{(co_index + 1) * dims.Com}:{dims.Com * dims.mt_co_pack}',
        'Yo',
        f'Co:{co_lo}:{co_hi}:{dims.Co_gran}',
        f'Xo:{0}:{0 + dims.Xos}',
        f'Co:0:{dims.Co_gran}',
    ]
    layout = ' '.join(pieces)
    copies = len(Yo_split_iters(dims, 0)) * dims.X_loop
    return [layout] * copies


def ofm_memtile_mm2s(dims: ConvDims, col: int) -> List[str]:
    """Return the memtile MM2S tiling formats sending the OFM of column `col`.

    One format string is produced per (X iteration, Yo run). Y sizes come from
    column 0, X sizes from `col`; the Co extent is
    ``min(Com * mt_co_pack, Co)``. The axis order follows
    ``dims.ofm_use_hwc_format``.
    """
    def fmt(y_iter: int, x_iter: int) -> str:
        rows = Yo_slice_iter(dims, 0, y_iter)[3]
        cols = Xo_slice_iter(dims, col, x_iter)[3]
        co_extent = min(dims.Com * dims.mt_co_pack, dims.Co)
        if dims.ofm_use_hwc_format:
            return f'Yo:0:{rows} Xo:0:{cols} Co:0:{co_extent}'
        return f'Yo:0:{rows} Co:0:{co_extent}:{dims.Co_gran} Xo:0:{cols} Co:0:{dims.Co_gran}'

    fs = []
    for x_iter in range(dims.X_loop):
        for y_iter, _ in Yo_split_iters(dims, 0):
            fs.append(fmt(y_iter, x_iter))
    return fs


def split_memtile_ifm_subv(
    dims: ConvDims,
    size_cutoff: int,
) -> int:
    """Pick how many IFM sub-volumes to hold in the memtile at once.

    Returns the largest divisor of the loop count (``Co_loop`` for a
    standalone DWC, ``Ci_loop`` otherwise) whose combined size in bytes does
    not exceed ``size_cutoff``; falls back to 1 when nothing fits.
    """
    if dims.is_standalone_dwc:
        loop_count = dims.Co_loop
        channels = dims.Cos * dims.Co_split
        subv_bytes = (dims.Yis * dims.Xi * channels * dims.ifm_bits) // 8
    else:
        loop_count = dims.Ci_loop
        subv_bytes = (dims.Yis * dims.Xi * dims.Cis * dims.ifm_bits) // 8

    # Walk candidates from largest to smallest and take the first that fits.
    fitting = (
        count
        for count in range(loop_count, 0, -1)
        if subv_bytes * count <= size_cutoff and loop_count % count == 0
    )
    return next(fitting, 1)

def pack_memtile_wgt_subv(
    dims: ConvDims,
    size_cutoff: int,
    is_pingpong: bool = True,
) -> int:
    """Pick how many weight sub-volumes to pack into one memtile buffer.

    Packing is only attempted when the total number of weight transfers
    (``Y_loop * Ci_loop`` for a standalone DWC, ``Y_loop * Co_loop * Ci_loop``
    otherwise) exceeds 256; below that the function returns 1.

    Otherwise it returns the largest divisor of the loop count (``Co_loop``
    for a standalone DWC, ``Ci_loop`` otherwise) whose packed size fits in
    ``size_cutoff``. The packed size is doubled when ``is_pingpong`` is true,
    since two buffers are then needed. Falls back to 1 when nothing fits.
    """
    subv_size = dims.wgt_subv_size
    if dims.is_standalone_dwc:
        loop_count = dims.Co_loop
        total_count = dims.Y_loop * dims.Ci_loop
    else:
        loop_count = dims.Ci_loop
        total_count = dims.Y_loop * dims.Co_loop * dims.Ci_loop
    if total_count <= 256:
        return 1

    # The buffer multiplier must be parenthesised: without it the conditional
    # expression made the whole size collapse to 1 for single buffering, so
    # the cutoff was never enforced.
    num_buffers = 2 if is_pingpong else 1
    for num_wgt_subv in range(loop_count, 0, -1):
        memtile_size = subv_size * num_wgt_subv * num_buffers
        if memtile_size <= size_cutoff and loop_count % num_wgt_subv == 0:
            return num_wgt_subv
    return 1


def compile_dataflow(
    dims: ConvDims,
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
):
    # Core Buffer Allocation

    core_alloc = conv_core_alloc(dims, overlay_stack_addr())

    """Split for reuse strategy:
    
        1. to caculate each out pixel with full Co channel (C:0:Co, Yo, Xo),  it needs the entire  input channel pixel(C 0:Ci, Yi, Xi) and entire weight (Co_loop * Ci_loop * wgt_subv_size)
        2. to caculate each out pixel with partial Co channel (C:0:Co_part, Yo, Xo), it needs the entire input channle pixel(C 0:Ci, Yi, Xi) and partial weight (Ci_loop * wgt_subv_size for easy partition)
        3. to reuse meaning:  the wgt or ifm shard kept in memtile memory, and the other part keep refetching for N times or till the end of the caculation. 
                              when reuse, the reused buffer will be single buffer, not for pingpong mode
           
           ***reuse wgt***  
        4. to enable the reuse of wgt,  meaning the entire weight (Co_loop * Ci_loop * wgt_subv_size) can be kept into memtile, then
           1) going thru the entire Y_loop and X_loop of pixel to get entire output pixel
           2) no need to refetch the ifm data (from shim to memtile)
           3) only fetach once the entire wgt into memtile and reuse the wight with reuse_ratio = Y_loop
           4) the outer most loop is Y_loop
        5. if the wgt is too big, we might enable the partial reuse of wgt, the resue part is (Ci_loop * wgt_subv_size)
           1) only if the (Ci_loop * wgt_subv_size) can be hold in memtile, wgt reuse_ratio = Co_loop
           2) the outer most loop is Co_loop
           3) fetch entire ifm to get the partital(Co_part) pixel(C:0:Co_part, Yo, Xo)
           4) then re-fetch next Co_loop of wgt size = (Ci_loop * wgt_subv_size) -- NOTE the refetch has to be happening after all the Co pixel genrated of ofm( phase alignment)
           5) then re-fetch entire ifm to get the partital(Co_part) pixel(C:Co_part:Co_part*2, Yo, Xo)
           6) repeat until the Co_loop finished. 
            
            ***reuse ifm*** 
        6. to enable the reuse of ifm,  meaning the entire ifm (Ci * Yi * Xi )//aie_cols can be kept into memtile, then
           1) Y_loop = 1 , X_split = 1  ( small X and Y, small Ci, might be huge Co)
           2) ifm reuse ratio = Co_loop
           2) keep refetching (Ci_loop * wgt_subv_size) for Co_loop times to get the enire out pixel(C:0:Co, Yo, Xo)-- NOTE the refetch has to be happening after all the Co pixel genrated of ofm( phase alignment)
           
        NOTE: we are supporting 4 right now, 5 and 6 can be added later if for performance optimization sake. 
    """
    

    # the ifm or wgt shard size might be big, so further cut in the Ci_loop direction 
    # to make each phase repeat as mt_ci_split
    
    """"
    rules:
         1.  Ci // mt_ci_split = n * Cis,  n == 1, 2...
    

    """
    
    """
    ifm_shard_size : 
        1. the size that hold the ifm_subv along the Ci dimension
        2. for y8 split,  it is dims.Yis * dims.Xi * dims.Ci
        3. for x8 split,  it is dims.Yis * dims.Xis(=dims.Xi//aie_col) * dims.Ci 
        4. if the ifm_shard_size can fit in memtile, then there is benefit for ifm reuse in memtile:
            1) if can be reuse --- this is callled NOT enable_ifm_streaming mode (non-streaming mode)
               -- meaning, each ifm_shard_size transfered once from shim to memtile and then be reused for co_loop times
            2) otherwise, it is called enable_ifm_streaming mode (streaming mode)
               -- meaning the ifm has to be refetched co_loop times for each cos caculation
            3) when in streaming mode or non streaming mode, we can further make judgement if the wgt can be reused. 
               -- because we are in Y_loop as outer most loop, wgt_reuse only support full wgt reuse, not partial reuse
               -- for non-streaming mode, we can also enable the wgt_reuse mode , this is differ from current design
            4) when non-streaming mode --> resuse the ifm , then the ifm in memtile can only be ping address mode(not pingpong)
        5. summary 
           1) ifm_shard_size can decide the ifm in memtile reuse or not
           2) even ifm being reused, we can still seek the wgt reuse ( for small wgt,  and beneficial with X8 split)
    """
    
    
    """mt_ci_split and mt_co_pack:
        1. are used for increasing the size of the buffer
        2. or chained more BD
        3. to reduce the repeat count/ the memory dimension restriction and etc.. 
    """
    
    """ for wgt NON in reuse mode in memtile:
        1. total repeat will be (dims.Y_loop * dims.Co_loop * dims.Ci_loop) * wgt_subv,  
           --- the repeat might be huge, like (1024, 512, 512), repeat up to 262,144
        2. to reduce the repeat:
           1) we could pack part of all of dims.Ci_loop * wgt_subv as wgt memtile buffer, then repeat counts = Y_loop*Co_loop
           2) phase split like to Y_loop phase, each phase repeat will only = Co_loop  
    """

    dims.mt_ci_split = 1 
    dims.shim_ci_split = dims.mt_ci_split

    if core_alloc.ofm_pong_addr is not None:
        memtile_ofm_buffering = 2
    else:
        memtile_ofm_buffering = 1
    has_memtile_ofm_double_buffer = memtile_ofm_buffering == 2

    ofm_memtile_ping_addr = None
    ofm_memtile_pong_addr = None
    ofm_memtile_addr = None
    if dims.enable_ifm_streaming:

        prm_memtile_addr = 0
        ifm_memtile_ping_addr = prm_memtile_addr + dims.prm_memtile_size
        ifm_memtile_pong_addr = ifm_memtile_ping_addr + dims.ifm_memtile_size
        if dims.enable_wgt_reuse:
            wgt_memtile_addr = ifm_memtile_pong_addr + dims.ifm_memtile_size
            if has_memtile_ofm_double_buffer:
                ofm_memtile_ping_addr = wgt_memtile_addr + dims.wgt_memtile_size
                ofm_memtile_pong_addr = ofm_memtile_ping_addr + dims.ofm_memtile_size
            else:
                ofm_memtile_addr = wgt_memtile_addr + dims.wgt_memtile_size
        else:
            wgt_memtile_ping_addr = ifm_memtile_pong_addr + dims.ifm_memtile_size
            wgt_memtile_pong_addr = wgt_memtile_ping_addr + dims.wgt_memtile_size
            if has_memtile_ofm_double_buffer:
                ofm_memtile_ping_addr = wgt_memtile_pong_addr + dims.wgt_memtile_size
                ofm_memtile_pong_addr = ofm_memtile_ping_addr + dims.ofm_memtile_size
            else:
                ofm_memtile_addr = wgt_memtile_pong_addr + dims.wgt_memtile_size

        ifm_memtile_addrs = [ifm_memtile_ping_addr, ifm_memtile_pong_addr]
        ifm_memtile_repeat_scale = (dims.Co_loop * dims.Ci_loop) // dims.num_ifm_subv
        ifm_memtile_reuse_ratio = 1
    else:

        prm_memtile_addr = 0
        """
            Enable ifm memtile double buffering for Co_loop == 1 and Ci_loop == 1 though enable_ifm_streaming is False.
            In this special case; Co_loop == 1 is enabler for double buffering since ifm_memtile_reuse_ratio = Co_loop i.e. 1.
            Along with this with Ci_loop == 1; ensures that double buffered ifm will fit in memtile since ifm data
            brought in memtile is consumed in coretile.
        """
        if (dims.Co_loop == 1) and (dims.Ci_loop == 1):
            ifm_memtile_ping_addr = prm_memtile_addr + dims.prm_memtile_size
            ifm_memtile_pong_addr = ifm_memtile_ping_addr + dims.ifm_memtile_size
            ifm_memtile_addr = ifm_memtile_pong_addr
            ifm_memtile_addrs = [ifm_memtile_ping_addr, ifm_memtile_pong_addr]
        else:
            ifm_memtile_addr = prm_memtile_addr + dims.prm_memtile_size
            ifm_memtile_addrs = [ifm_memtile_addr]

        if dims.enable_wgt_reuse:
            wgt_memtile_addr = ifm_memtile_addr + dims.ifm_memtile_size
            wgt_memtile_pong_addr = wgt_memtile_addr
        else:
            wgt_memtile_ping_addr = ifm_memtile_addr + dims.ifm_memtile_size
            wgt_memtile_pong_addr = wgt_memtile_ping_addr + dims.wgt_memtile_size
        if has_memtile_ofm_double_buffer:
            ofm_memtile_ping_addr = wgt_memtile_pong_addr + dims.wgt_memtile_size
            ofm_memtile_pong_addr = ofm_memtile_ping_addr + dims.ofm_memtile_size
        else:
            ofm_memtile_addr = wgt_memtile_pong_addr + dims.wgt_memtile_size

        ifm_memtile_repeat_scale = 1
        ifm_memtile_reuse_ratio = dims.Co_loop

    if dims.enable_wgt_reuse:
        wgt_memtile_addrs = [wgt_memtile_addr]
        
        # temp code to reuse the wgt
        wgt_memtile_reuse_ratio = dims.Y_loop
        wgt_memtile_repeat_count = 1
        wgt_shim_repeat_count = 1
    else:
        wgt_memtile_addrs = [wgt_memtile_ping_addr, wgt_memtile_pong_addr]
        wgt_memtile_repeat_count = (dims.Y_loop * dims.Co_loop * dims.Ci_loop // dims.num_pack_wgt_subv ) if not dims.pin_wgt_bias_l1 else 1
        wgt_memtile_reuse_ratio = 1
        wgt_shim_repeat_count = dims.Y_loop if not dims.pin_wgt_bias_l1 else 1

    if has_memtile_ofm_double_buffer:
        assert(ofm_memtile_pong_addr is not None and ofm_memtile_ping_addr is not None)
        ofm_memtile_addrs = [ofm_memtile_ping_addr, ofm_memtile_pong_addr]
    else:
        assert(ofm_memtile_addr is not None)
        ofm_memtile_addrs = [ofm_memtile_addr]

    
    # Shim Buffer Allocation

    conv_shim_alloc = shim_alloc()

    core_instrs = {}
    for col in range(dims.aie_cols):
        for row in range(dims.aie_rows):
            core_instrs[AieTile(TileType.Core, col, row)] = conv_core_instrs(
                dims, core_alloc,
                dims.Y_loop * dims.X_loop,
                dims.Co_loop,
                dims.Ci_loop,
                ifm_config=generate_core_buffer_config(
                    core_dma(col, row, DmaDir.S2MM, 0),
                    core_alloc.ifm_ping_addr, core_alloc.ifm_pong_addr,
                    ifm_core_memory(dims),
                    ifm_core_s2mm(dims, row),
                    bits_per_block=dims.ifm_bits,
                ),
            )

    bcst_col_step = 2 if (dims.aie_cols, dims.aie_rows) == (8, 4) else 1
  
    
    # shim_ifm_repeat_count = 1 if dims.enable_wgt_reuse else (1 if dims.mt_ci_split == 1 else dims.Co_loop)
    """ to remove the shim ifm restrictions
    
        1. what are the restrictions:
            1) the shim dimension is 3 max
            2) the shim wrap = 1024 words
            3) the chained BDs <= 4
        2. best performance (in order of preference)
            1) one BD with less configuration (usually 3)
            2) otherwise <= 4 chained BDs with less configuration (usually 3)
            3) otherwise one or multiple BDs with phase split (= num_reconfig) = Y_loop
        3. current solution
            1) if shim_ifm_phase_split = 0, goes to 2.1)
            2) if shim_ifm_phase_split = 1, goes to 2.3)
            3) future optimization: mix 2.1), 2.2) and 2.3)
    """
     
     
    memtile_transfers = [
        DataTransfer(
            [1] + [0] * (dims.Y_loop * dims.X_loop- 1),
            AieTile(TileType.Memtile, col), [prm_memtile_addr], dims.prm_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 0),
                prm_memtile_memory(dims),
                prm_memtile_s2mm(),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, row),
                prm_memtile_memory(dims),
                prm_memtile_mm2s(row),
            ) for row in range(dims.aie_rows)],
            sync_strategy=SyncStrategy.Serial_M_to_N,
        ) for col in range(dims.aie_cols)
    ] 

    memtile_ifm_transfer =[
        DataTransfer(
            Yi_repeat_counts(dims, col, ifm_memtile_repeat_scale),
            AieTile(TileType.Memtile, col), ifm_memtile_addrs, dims.ifm_memtile_size,
            [pack_transfers(
                memtile_dma(col, DmaDir.S2MM, 0),
                ifm_memtile_memory(dims, col),
                ifm_memtile_s2mm(dims, col),
                [n for _, n in Yi_split_iters(dims, 0)] * dims.X_loop,
                dims.ifm_bits,
            )],
            [pack_transfers(
                memtile_dma(col, DmaDir.MM2S, row),
                ifm_memtile_memory(dims, col),
                ifm_memtile_mm2s(dims, col, row),
                [n for _, n in Yi_split_iters(dims, 0)] * dims.X_loop,
                dims.ifm_bits,
            ) for row in range(dims.aie_rows)],
            sync_strategy=SyncStrategy.Parallel_1_to_N,

            reuse_ratio = ifm_memtile_reuse_ratio,
        ) for col in range(dims.aie_cols)
    ]               
    
    memtile_wgt_transfer = [
        DataTransfer(
            wgt_repeat_counts(dims, wgt_memtile_repeat_count),
            AieTile(TileType.Memtile, col), wgt_memtile_addrs, dims.wgt_memtile_size,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 1),
                wgt_memtile_memory(dims),
                wgt_memtile_s2mm(dims),
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, 4),
                wgt_memtile_memory(dims),
                wgt_memtile_mm2s(dims),
            )],
            reuse_ratio=wgt_memtile_reuse_ratio if not dims.pin_wgt_bias_l1 else 1,
        ) for col in range(0, dims.aie_cols, bcst_col_step)
    ] 
    
    memtile_ofm_transfer = [
        DataTransfer(
            Yo_repeat_counts(dims, 0, dims.Co_loop // dims.mt_co_pack),
            AieTile(TileType.Memtile, col), ofm_memtile_addrs, dims.ofm_memtile_size,
            [pack_transfers(
                memtile_dma(col, DmaDir.S2MM, 2 + row),
                ofm_memtile_memory(dims, col) * dims.X_loop,
                ofm_memtile_s2mm(dims, col, row, co_index),
                [n for _, n in Yo_split_iters(dims, 0)] * dims.X_loop,
                dims.ofm_bits,
            ) for row in range(dims.aie_rows) for co_index in range(dims.mt_co_pack)],
            [pack_transfers(
                memtile_dma(col, DmaDir.MM2S, 5),
                ofm_memtile_memory(dims, col) * dims.X_loop,
                ofm_memtile_mm2s(dims, col),
                [n for _, n in Yo_split_iters(dims, 0)] * dims.X_loop,
                dims.ofm_bits,
            )],
            sync_strategy=SyncStrategy.Parallel_N_to_1,
        ) for col in range(dims.aie_cols)
    ]
 
    
    memtile_transfers += (memtile_ifm_transfer + memtile_wgt_transfer + memtile_ofm_transfer)

    shim_transfers = [
        generate_shim_data_transfer(
            [1] + [0] * (dims.Y_loop * dims.X_loop- 1),
            shim_dma(col, DmaDir.MM2S, 0), conv_shim_alloc.prm_buffer_id,
            prm_shim_memory(dims),
            prm_shim_mm2s(col),
            max_chain_length=dims.shim_BD_num['prm'],
        ) for col in range(dims.aie_cols)
    ] 
    shim_ifm_size = dims.Ci * dims.Yi * dims.Xi * dims.ifm_bits //8
    shim_ofm_size = dims.Co * dims.Yo * dims.Xo * dims.ofm_bits //8


    """ if mt_ci_split !=1 , it means the whole ifm Ci shard (for each Y split) can't fit into the memtile
                             so ifm reuse can't be enabled in the memtile (mm2s to core),
                             and each Cos calculation will require refetching both wgt and ifm.
        and
        shim_ci_split (= mt_ci_split; named differently only to leave room for future changes) is used to match the fetch scheme.
        if shim_ci_split = 1, the access pattern will be (in elements, not converted to words)
                                step0: 1,  wrap0 = Ci
                                step1: Ci, wrap1 = Xis
                                step2: Ci*Xi
                            but the memtile needs the format below:
                                step0: 1,     wrap0 = Ci // mt_ci_split
                                step1: Ci,    wrap1 = Xis
                                step2: Ci*Xi, wrap2 = Yis
                                step3: Ci // mt_ci_split
                            which would exceed the shim dimension restriction
        if shim_ci_split > 1,  the shim transfer will be chained with up to 4 BDs,
                            each BD being split with offset = Ci // mt_ci_split

    """
       
    shim_ifm_transfer = [generate_packed_shim_data_transfer(
        ifm_shim_repeat_counts(dims, 0),
        shim_dma(col, DmaDir.MM2S, 0), conv_shim_alloc.ifm_buffer_id,
        [ifm_shim_memory(dims)] * (len(Yi_split_iters(dims, 0, 64))) * dims.X_loop,
        ifm_shim_mm2s(dims, col, 0),
        [n for _, n in Yi_split_iters(dims, 0, 64)] * dims.X_loop,
        [s + i*dims.Y_loop for i in range(dims.X_loop) for s, _ in Yi_split_iters(dims, 0, 64)],
        dims.ifm_bits
    ) for col in range(dims.aie_cols)] 
         
    shim_wgt_transfer =[
        generate_shim_data_transfer(
            (([wgt_shim_repeat_count] + [0] * (dims.Y_loop - 1)) * dims.X_loop) if not dims.pin_wgt_bias_l1 else ([wgt_shim_repeat_count] + ([0] * (dims.Y_loop * dims.X_loop - 1))),
            shim_dma(col, DmaDir.MM2S, 1), conv_shim_alloc.wgt_buffer_id,
            wgt_shim_memory(dims),
            wgt_shim_mm2s(dims, (col // bcst_col_step)),
            max_chain_length=dims.shim_BD_num['wgt'],
        ) for col in range(0, dims.aie_cols, bcst_col_step)
    ]
  
    shim_ofm_transfer = [
        DataTransfer(
            Yo_repeat_counts(dims, 0, 1, False),
            AieTile(TileType.Shim, col), [conv_shim_alloc.ofm_buffer_id], shim_ofm_size,
            [pack_transfers(
                shim_dma(col, DmaDir.S2MM, 0),
                [ofm_shim_memory(dims)] * (len(Yo_split_iters(dims, 0))) * dims.X_loop,
                ofm_shim_s2mm(dims, col, Co_split),
                [n for _, n in Yo_split_iters(dims, 0)] * dims.X_loop,
                dims.ofm_bits,
            ) for Co_split in range(dims.Co_loop // dims.mt_co_pack)] ,
            [],
        ) for col in range(dims.aie_cols)
    ]  
 
  
    shim_transfers += (shim_ifm_transfer + shim_wgt_transfer + shim_ofm_transfer)
    
   
    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        kernel_names,
        kernel_includes,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        overlay_8x4_dma_connections() if (dims.aie_cols, dims.aie_rows) == (8, 4) else overlay_4x4_dma_connections(),
        back_end,
        core_stack_addr=overlay_stack_addr(),
        param_channel_id=0
    )

def main():
    """Compile and build the simulation overlay for one fixed conv test shape.

    Every layer parameter below is a hard-coded constant. The constants are
    packed into a ConvDims, the dataflow is compiled for the Adf back end,
    and the simulation overlay is built against the host source file chosen
    by ``is_standalone_dwc``.
    """
    # Full tensor extents (channels, rows, columns) for input and output.
    Ci, Co = 128, 128
    Yi, Yo = 56, 56
    Xi, Xo = 56, 56

    # 's'-suffixed sizes -- presumably the per-core sub-volume extents;
    # confirm against ConvDims.
    Cis, Cos = 32, 32
    Yis, Yos = 5, 4
    Xis, Xos = 64, 32

    # Kernel window, stride and padding.
    Ky = Kx = 3
    Sy = Sx = 1
    Py_b = Px_b = Py_a = Px_a = 1

    # Split factors, granularities and alignment.
    X_split, Co_split = 2, 2
    Ci_gran = Co_gran = X_gran = 8
    X_align = 64

    # Element widths in bits.
    ifm_bits, wgt_bits, ofm_bits, tdm_bits = 16, 8, 16, 32
    bias_bits = 64

    # Weight sub-volume size: the weight payload and the bias + qdq payload
    # are each passed through iceil with X_align before being summed.
    qdq_param_bytes = 5 * 4
    wgt_payload_bytes = (Cos * Cis * Ky * Kx * wgt_bits) // 8
    bias_payload_bytes = ((Cos * bias_bits) // 8) + qdq_param_bytes
    wgt_subv_size = (
        iceil(wgt_payload_bytes, X_align) + iceil(bias_payload_bytes, X_align)
    )

    is_standalone_dwc = True

    # ConvDims takes its arguments positionally; the order here must not change.
    dims = ConvDims(
        Ci, Cis, Ci_gran, Co, Cos, Co_gran, Co_split,
        Yi, Yis, Yo, Yos,
        Xi, Xis, Xo, Xos, X_gran, X_align, X_split,
        Ky, Kx,
        Sy, Sx,
        Py_b, Px_b, Py_a, Px_a,
        ifm_bits, wgt_bits, ofm_bits, tdm_bits,
        wgt_subv_size,
        is_standalone_dwc,
    )

    back_end = BackEnd.Adf
    kernel_names = ['run_conv_a16w8_qdq', 'run_conv_xint8']
    kernel_includes = [
        'super.hh',
        'conv/direct_conv_int16x8_generic/direct_conv_int16x8_generic_wrapper.cc',
        'conv/procyon_kernels/conv/run_conv_wrapper.cc',
    ]

    # The host test bench source depends on whether this is a standalone DWC run.
    if is_standalone_dwc:
        host_cpp = 'dwc_main.cpp'
    else:
        host_cpp = 'conv_main.cpp'

    # Clean first, then compile the dataflow, then build the simulation overlay.
    clean_overlay()
    compile_dataflow(dims, back_end, kernel_names, kernel_includes)
    preproc_directives = conv_preproc_directives(dims, back_end)
    build_sim_overlay(host_cpp, preproc_directives)

# Run the hard-coded example build only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == '__main__':
    main()

