import os
import sys
import json
import shutil
import logging
from typing import List, Optional
import numpy as np
import copy
import math
from math import floor, sqrt

from bilinear_resize_common import (
    bilinear_resize_input_subvol_dims,
    bilinear_resize_wgt_subvol_dims,
    generate_bilinear_resize_layer_params,
)

from kernels.bilinear_pixel_resize_bf16.bilinear_pixel_resize_bf16_kernel_params import (
    CoordinateTransfromationMode,
    genereate_bilinear_resize_kernel_params,
    mode_map,
    BilinearResizeShape,
)

from dataflow_common import (
    ceildiv, iceil,
    overlay_stack_addr, overlay_8x4_dma_connections,
    prm_memtile_memory, prm_memtile_mm2s, prm_memtile_s2mm,
    prm_shim_memory, prm_shim_mm2s,
    shim_alloc,
)

from dmacompiler import (
    OverlayShape, BackEnd,
    DataTransfer, SyncStrategy,
    AieTile, TileType, core_tile,
    AieDma, DmaDir, memtile_dma, shim_dma, core_dma, DmaChannel,
    CoreInstr, ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop,
    compute_buffer_size,
    TransferParams, generate_transfer_params,
    generate_shim_data_transfer,
    run_layer_compilation,
    set_dev_gen, DevGen, config,
    generate_memtile_data_transfers_N_to_1,
    generate_memtile_data_transfers_1_to_N,
    pack_reconfig_transfers,
    generate_core_buffer_config,
)
# Module-level side effects executed once at import time.
# Select the AIE2P device generation in dmacompiler for everything built below.
set_dev_gen(DevGen.Aie2p)
# Turn on dmacompiler's ENABLE_BUSY_POLL option (its exact effect is defined by dmacompiler).
config.ENABLE_BUSY_POLL = True


class BilinearResizeL1MemoryAllocator:
    '''
    Static L1 (core-local) memory plan for the bilinear resize kernel.

    Layout, from address 0 upward:
        OFM ping | OFM pong (optional) | IFM+WGT ping | IFM+WGT pong (optional) | QDQ

    The pong buffers are omitted when the matching ``*_l1_no_buff`` is 1.
    Region sizes are padded with ``iceil(..., mem_align)`` (mem_align = 64 B).

    IFM and WGT subvolumes are unicast to each core and arrive as one combined
    transfer. Inside each combined buffer, the WGT part starts
    ``wgt_subvol_offset`` (= ``ifm_subvol_alloc_size``) bytes after the IFM base.

    Args:
        dims: layer shape / tiling description.
        stack_addr: first byte of the core stack; every buffer must end at or
            before this address.

    Raises:
        AssertionError: if the QDQ region would run into the stack.
    '''
    def __init__(self, dims: BilinearResizeShape, stack_addr: int):
        mem_align = 64
        # Exact payload sizes in bytes.
        self.ifm_subvol_size = (dims.Cis * dims.Yis * dims.Xis * dims.act_bits) // 8
        self.wgt_subvol_size = (dims.wgt_subvol_dims * dims.act_bits) // 8
        self.ofm_subvol_size = (dims.Cos * dims.Yos * dims.Xos * dims.act_bits) // 8
        # Padded sizes actually reserved in L1 (element count first padded to 64, then bytes to mem_align).
        self.ifm_subvol_alloc_size = iceil((iceil(dims.Cis * dims.Yis * dims.Xis, 64) * dims.act_bits) // 8, mem_align)
        self.wgt_subvol_alloc_size = iceil((dims.wgt_subvol_dims * dims.act_bits) // 8, mem_align)
        self.ofm_subvol_alloc_size = iceil((iceil(dims.Cos * dims.Yos * dims.Xos, 64) * dims.act_bits) // 8, mem_align)
        # NOTE: IFM and WGT subvolumes are unicast, so both tensors are received as a single
        # combined subvolume; the kernel finds the WGT part at this offset from the IFM base.
        self.wgt_subvol_offset = self.ifm_subvol_alloc_size
        self.qdq_bytes = dims.qdq_bytes
        self.ifm_wgt_subvol_size = self.ifm_subvol_size + self.wgt_subvol_size
        self.ifm_wgt_subvol_alloc_size = self.ifm_subvol_alloc_size + self.wgt_subvol_alloc_size
        try:
            self.ofm_ping_addr = 0
            self.ofm_pong_addr = (
                None if dims.ofm_l1_no_buff == 1
                else self.ofm_ping_addr + self.ofm_subvol_alloc_size
            )
            # IFM starts right after the last OFM buffer; the pong buffer may be absent
            # (single buffering), in which case it follows the ping buffer.
            last_ofm_addr = self.ofm_ping_addr if self.ofm_pong_addr is None else self.ofm_pong_addr
            self.ifm_ping_addr = last_ofm_addr + self.ofm_subvol_alloc_size
            self.wgt_ping_addr = self.ifm_ping_addr + self.wgt_subvol_offset
            if dims.ifm_l1_no_buff == 1:
                self.ifm_pong_addr = None
                self.wgt_pong_addr = None
            else:
                self.ifm_pong_addr = self.ifm_ping_addr + self.ifm_wgt_subvol_alloc_size
                # WGT sits at the same offset inside the pong buffer as inside the ping buffer.
                self.wgt_pong_addr = self.ifm_pong_addr + self.wgt_subvol_offset
            # QDQ follows the last combined IFM+WGT buffer.
            last_ifm_addr = self.ifm_ping_addr if self.ifm_pong_addr is None else self.ifm_pong_addr
            self.qdq_addr = last_ifm_addr + self.ifm_wgt_subvol_alloc_size
            # A region ending exactly at stack_addr does not overlap the stack, hence <=.
            assert self.qdq_addr + dims.qdq_bytes <= stack_addr, "QDQ memory allocation overlaps with stack"
        except Exception as e:
            # TODO: if scheme 1 fails, try scheme 2
            print(f"Error in L1 memory allocation: {e}")
            raise


class BilinearResizeL2MemoryAllocator:
    '''
    Static L2 (memtile) memory plan for the bilinear resize layer.

    Layout, contiguous from address 0:
        params | IFM ping | WGT ping | IFM pong | WGT pong | OFM ping | OFM pong | QDQ

    Each tensor region holds one subvolume per AIE row: the exact L1 payload
    size multiplied by ``dims.aie_rows``. No alignment padding is inserted
    between regions.

    Args:
        dims: layer shape / tiling description (``param_subv_size``, ``aie_rows``).
        l1alloc: L1 plan providing per-core subvolume sizes and ``qdq_bytes``.

    Raises:
        AssertionError: if the plan does not fit in the 512 KiB memtile.
    '''
    def __init__(
        self,
        dims: BilinearResizeShape,
        l1alloc: BilinearResizeL1MemoryAllocator,
    ):
        memtile_size = 512 * 1024
        self.param_addr = 0
        self.param_bytes = dims.param_subv_size * dims.aie_rows
        self.ifm_size = l1alloc.ifm_subvol_size * dims.aie_rows
        self.wgt_size = l1alloc.wgt_subvol_size * dims.aie_rows
        self.ofm_size = l1alloc.ofm_subvol_size * dims.aie_rows
        # Each WGT block immediately follows its IFM block: the memtile transfers address
        # WGT at ifm_ping_addr + aie_rows * ifm_subvol_size.
        self.ifm_ping_addr = self.param_addr + self.param_bytes
        self.wgt_ping_addr = self.ifm_ping_addr + self.ifm_size
        self.ifm_pong_addr = self.wgt_ping_addr + self.wgt_size
        self.wgt_pong_addr = self.ifm_pong_addr + self.ifm_size
        self.ofm_ping_addr = self.wgt_pong_addr + self.wgt_size
        self.ofm_pong_addr = self.ofm_ping_addr + self.ofm_size
        self.qdq_addr = self.ofm_pong_addr + self.ofm_size
        self.qdq_bytes = l1alloc.qdq_bytes
        # A region ending exactly at memtile_size still fits, hence <=.
        assert self.qdq_addr + self.qdq_bytes <= memtile_size, "Memtile memory allocation exceeds memtile size"


def qdq_memtile_memory(dims) -> str:
    '''Memtile buffer format for the QDQ parameters: a flat array of ``dims.qdq_bytes`` bytes.'''
    size = dims.qdq_bytes
    return f'Bytes:{size}'

def qdq_memtile_s2mm() -> str:
    '''Memtile S2MM (write) tiling for QDQ: the whole byte buffer in one pass.'''
    return 'Bytes'

def qdq_memtile_mm2s() -> str:
    '''Memtile MM2S (read) tiling for QDQ: the whole byte buffer in one pass.'''
    return 'Bytes'

def qdq_shim_memory(dims) -> str:
    '''Shim (DDR) buffer format for the QDQ parameters: ``dims.qdq_bytes`` contiguous bytes.'''
    size = dims.qdq_bytes
    return f'Bytes:{size}'

def qdq_shim_mm2s() -> str:
    '''Shim MM2S (read from DDR) tiling for QDQ: the whole byte buffer in one pass.'''
    return 'Bytes'


def core_to_split(dims: BilinearResizeShape, col: int, row: int) -> tuple[int, int, int]:
    '''
    Translate a physical core position into its logical image block.

    The core is first flattened to ``col * aie_rows + row``. That linear id is
    then decoded according to the layer's (Y_split, X_split, C_split)
    partitioning. Only the three partitionings listed below are supported;
    any other raises KeyError.

    Returns:
        (Y_idx, X_idx, C_idx) block coordinates for this core.
    '''
    decoders = {
        # (Y_split, X_split, C_split) -> linear core id -> (Y_idx, X_idx, C_idx)
        (1, 1, 32): lambda n: (0, 0, n),
        (4, 8, 1): lambda n: (n % 4, n // 4, 0),
        (8, 4, 1): lambda n: (n // 4, n % 4, 0),
    }
    decode = decoders[(dims.Y_split, dims.X_split, dims.C_split)]
    y_idx, x_idx, c_idx = decode(col * dims.aie_rows + row)
    return y_idx, x_idx, c_idx


def Yo_slice(dims: BilinearResizeShape, col: int, row: int, i: int) -> tuple[int, int, int]:
    '''
    Output-row window handled by core (col, row) on iteration ``i`` of Y_loop.

    Tiles are interleaved across the Y_split cores, so iteration ``i`` advances
    by ``Yos * Y_split`` rows. Both ends are clamped to ``Yo``. A tile that
    starts past the last row therefore collapses to the empty window [Yo, Yo).

    Returns:
        (start, stop, stride) along Yo.
    '''
    y_block = core_to_split(dims, col, row)[0]
    stride = dims.Yos * dims.Y_split
    start = min(y_block * dims.Yos + i * stride, dims.Yo)
    stop = min(start + dims.Yos, dims.Yo)
    return start, stop, stride


def Yi_slice(dims: BilinearResizeShape, col: int, row: int, i: int) -> tuple[int, int, int]:
    '''
    Input-row window fetched for core (col, row) on iteration ``i`` of Y_loop.

    The bounds test uses OUTPUT coordinates. If the matching output tile
    starts at or beyond ``Yo``, the window is empty ([Yi, Yi)). Otherwise it
    starts at ``Y_idx * Yis_step + i * Yis_step * Y_split + Yis_offset`` and
    spans ``Yis`` rows. The start may be negative and the stop may exceed
    ``Yi``; callers clip or pad.

    Returns:
        (start, stop, stride) along Yi.
    '''
    y_block = core_to_split(dims, col, row)[0]
    out_row0 = y_block * dims.Yos + i * dims.Yos * dims.Y_split
    stride = dims.Yis_step * dims.Y_split
    if out_row0 >= dims.Yo:
        # No output rows for this tile: return an empty window at the image end.
        return dims.Yi, dims.Yi, stride
    start = y_block * dims.Yis_step + i * stride + dims.Yis_offset
    return start, start + dims.Yis, stride


def Xi_slice(dims: BilinearResizeShape, col: int, row: int, i: int) -> tuple[int, int, int]:
    '''
    Input-column window fetched for core (col, row) on iteration ``i`` of X_loop.

    The bounds test is done in OUTPUT coordinates, mirroring Yi_slice and
    Xo_slice: an input window is fetched only when the matching output tile
    (``X_idx * Xos + i * Xos * X_split``) starts inside ``Xo``. Otherwise the
    empty window [Xi, Xi) is returned.

    Returns:
        (start, stop, stride) along Xi. The start may be negative and the stop
        may exceed ``Xi``; callers clip or pad.
    '''
    _, X_idx, _ = core_to_split(dims, col, row)
    Xi_stride = dims.Xis_step * dims.X_split
    # Compare the OUTPUT tile start against Xo, as Yi_slice does. Using the input
    # step (Xis_step) here let cores with no output tile fetch windows past the image.
    in_bounds = ((X_idx * dims.Xos) + (i * dims.Xos * dims.X_split)) < dims.Xo
    if not in_bounds:
        return dims.Xi, dims.Xi, Xi_stride
    Xi_start = (X_idx * dims.Xis_step) + (i * Xi_stride) + dims.Xis_offset
    return Xi_start, Xi_start + dims.Xis, Xi_stride


def Xo_slice(dims: BilinearResizeShape, col: int, row: int, i: int) -> tuple[int, int, int]:
    '''
    Output-column window handled by core (col, row) on iteration ``i`` of X_loop.

    Tiles are interleaved across the X_split cores with stride ``Xos * X_split``.
    Both ends are clamped to ``Xo``, so a tile past the right edge becomes [Xo, Xo).

    Returns:
        (start, stop, stride) along Xo.
    '''
    x_block = core_to_split(dims, col, row)[1]
    stride = dims.Xos * dims.X_split
    first_col = x_block * dims.Xos + i * stride
    start = min(first_col, dims.Xo)
    return start, min(start + dims.Xos, dims.Xo), stride


def C_slice(dims: BilinearResizeShape, col: int, row: int, i: int) -> tuple[int, int, int]:
    '''
    Channel window handled by core (col, row) on iteration ``i`` of C_loop.

    Channel chunks of size ``Cos`` are interleaved across the C_split cores
    (stride ``Cos * C_split``). Both ends are clamped to ``Co``.

    Returns:
        (start, stop, stride) along C.
    '''
    c_block = core_to_split(dims, col, row)[2]
    stride = dims.Cos * dims.C_split
    first_ch = c_block * dims.Cos + i * stride
    start = min(first_ch, dims.Co)
    return start, min(start + dims.Cos, dims.Co), stride


def ifm_memtile_memory(dims: BilinearResizeShape) -> str:
    '''
    Memtile buffer format for the IFM: one subvolume slot per AIE row.

    Each core needs a single IFM subvolume per fetch. Nothing is reused across
    iterations, so pinning is not supported.

    NOTE(review): ``C`` appears twice (``C:{Cis}`` outer, ``C:8`` inner). How
    dmacompiler sizes a repeated dim is not visible here. Confirm it agrees
    with the L2 plan, which reserves ``Cis * Yis * Xis * act_bits / 8`` bytes per row.
    '''
    parts = [
        f'row:{dims.aie_rows}',
        f'C:{dims.Cis}',
        f'Yi:{dims.Yis}',
        f'Xi:{dims.Xis}',
        'C:8',
    ]
    return ' '.join(parts)


def ifm_memtile_mm2s(dims: BilinearResizeShape, row: int) -> str:
    '''
    Memtile MM2S (read) tiling for the IFM subvolume sent to AIE row ``row``.

    The outer ``C`` term walks ``0:Cis`` with step 8, and the inner term is
    ``C:0:8``. Channels presumably go out in blocks of 8.
    '''
    row_span = f'row:{row}:{row + 1}'
    return f'{row_span} C:0:{dims.Cis}:8 Yi:0:{dims.Yis} Xi:0:{dims.Xis} C:0:8'


def ifm_memtile_s2mm(dims: BilinearResizeShape, col: int, row: int) -> list[str]:
    '''
    Memtile S2MM (write) tiling patterns for the IFM of core (col, row).

    Returns one pattern per (y_iter, x_iter), Y outer and X inner. The window
    from Yi_slice / Xi_slice may begin before the image (negative start). The
    shim clips such windows, so the leading rows/columns are skipped in the
    memtile slot, and the end is clipped to the image extent. All coordinates
    are relative to the window start.
    '''
    def window(lo, hi, extent):
        # (leading part to skip, end) of a fetched window, relative to its start
        return max(0, -lo), min(hi, extent) - lo

    patterns = []
    for y_iter in range(dims.Y_loop):
        for x_iter in range(dims.X_loop):
            y_lo, y_hi, _ = Yi_slice(dims, col, row, y_iter)
            x_lo, x_hi, _ = Xi_slice(dims, col, row, x_iter)
            y_skip, y_end = window(y_lo, y_hi, dims.Yi)
            x_skip, x_end = window(x_lo, x_hi, dims.Xi)
            patterns.append(
                f'row:{row}:{row+1} Yi:{y_skip}:{y_end} Xi:{x_skip}:{x_end} C:0:{dims.Cis}'
            )
    return patterns


def ofm_memtile_memory(dims: BilinearResizeShape) -> str:
    '''
    Memtile buffer format for the OFM: one ``Yos x Xos x Cos`` subvolume slot
    per AIE row.

    Each core produces a single OFM subvolume per step. Nothing is reused, so
    pinning is not supported.
    '''
    dims_fmt = ' '.join((
        f'row:{dims.aie_rows}',
        f'Yo:{dims.Yos}',
        f'Xo:{dims.Xos}',
        f'C:{dims.Cos}',
    ))
    return dims_fmt


def ofm_memtile_s2mm(dims: BilinearResizeShape, row: int) -> str:
    '''
    Memtile S2MM (write) tiling for the OFM subvolume arriving from AIE row ``row``.

    The pattern is the same for every iteration. The outer ``C`` term walks
    ``0:Cos`` with step 8 and the inner ``C:0:8`` covers one block, so channels
    presumably arrive in blocks of 8.
    '''
    row_span = f'row:{row}:{row + 1}'
    spatial = f'Yo:0:{dims.Yos} Xo:0:{dims.Xos}'
    return f'{row_span} C:0:{dims.Cos}:8 {spatial} C:0:8'


def ofm_memtile_mm2s(dims: BilinearResizeShape, col: int, row: int) -> list[str]:
    '''
    Memtile MM2S (read) tiling patterns for the OFM of core (col, row).

    Returns one pattern per (y_iter, x_iter), Y outer and X inner. Each pattern
    reads only the valid part of the staged subvolume, i.e. the output tile
    size from Yo_slice / Xo_slice. At the image edge this is smaller than
    Yos x Xos and may be empty.
    '''
    def tile_pattern(y_iter, x_iter):
        y_lo, y_hi, _ = Yo_slice(dims, col, row, y_iter)
        x_lo, x_hi, _ = Xo_slice(dims, col, row, x_iter)
        return f'row:{row}:{row+1} Yo:0:{y_hi - y_lo} Xo:0:{x_hi - x_lo} C:0:{dims.Cos}'

    return [
        tile_pattern(y_iter, x_iter)
        for y_iter in range(dims.Y_loop)
        for x_iter in range(dims.X_loop)
    ]


def ofm_shimtile_memory(dims: BilinearResizeShape) -> str:
    '''DDR format of the full OFM tensor as Yo x Xo x C.'''
    shape = (('Yo', dims.Yo), ('Xo', dims.Xo), ('C', dims.Co))
    return ' '.join(f'{name}:{size}' for name, size in shape)


def ofm_shimtile_s2mm(dims: BilinearResizeShape, col: int, row: int) -> list[str]:
    '''
    Shim S2MM (write to DDR) tiling patterns for the OFM produced by core (col, row).

    Returns one pattern per (y_iter, x_iter), Y outer and X inner. Each pattern
    covers all C_loop channel chunks of this core. The outer ``C`` term walks
    the chunk start offsets (stride ``Cos * C_split``), and the inner ``C``
    term is the width of one chunk.
    '''
    # Channel chunks do not depend on (y_iter, x_iter); compute them once.
    C_start, C_stop, C_stride = C_slice(dims, col, row, 0)
    _, C_last_stop, _ = C_slice(dims, col, row, dims.C_loop - 1)
    # The inner C term is relative to the outer offset, so it must be the chunk WIDTH.
    # Emitting the absolute stop is only correct when C_start == 0 (i.e. C_idx == 0).
    # NOTE(review): assumes dmacompiler treats a repeated dim as (outer offset, inner
    # extent), as ofm_memtile_s2mm's "C:0:{Cos}:8 ... C:0:8" does.
    C_width = C_stop - C_start
    access_pattern = []
    for y_iter in range(dims.Y_loop):
        for x_iter in range(dims.X_loop):
            Yo_start, Yo_stop, _ = Yo_slice(dims, col, row, y_iter)
            Xo_start, Xo_stop, _ = Xo_slice(dims, col, row, x_iter)
            access_pattern.append(
                f'C:{C_start}:{C_last_stop}:{C_stride} Yo:{Yo_start}:{Yo_stop} Xo:{Xo_start}:{Xo_stop} C:{0}:{C_width}'
            )
    return access_pattern


def ifm_shimtile_memory(dims: BilinearResizeShape) -> str:
    '''DDR format of the full IFM tensor as Yi x Xi x C.'''
    shape = (('Yi', dims.Yi), ('Xi', dims.Xi), ('C', dims.Ci))
    return ' '.join(f'{name}:{size}' for name, size in shape)


def ifm_shimtile_s2mm(dims: BilinearResizeShape, col: int, row: int) -> list[str]:
    '''
    Shim tiling patterns for reading the IFM of core (col, row) from DDR.

    Despite the ``s2mm`` suffix (kept for compatibility), these patterns are
    passed with the shim MM2S channel that feeds the memtile.

    Returns one pattern per (y_iter, x_iter), Y outer and X inner. The outer
    ``C`` term walks this core's C_loop channel-chunk offsets, and the inner
    ``C`` term is the width of one chunk. The shim cannot pad, so windows are
    clipped to the image. The memtile write side (ifm_memtile_s2mm) skips the
    clipped leading part.
    '''
    # Channel chunks do not depend on (y_iter, x_iter); compute them once.
    C_start, C_stop, C_stride = C_slice(dims, col, row, 0)
    _, C_last_stop, _ = C_slice(dims, col, row, dims.C_loop - 1)
    # The inner C term is relative to the outer offset, so it must be the chunk WIDTH.
    # Emitting the absolute stop is only correct when C_start == 0 (i.e. C_idx == 0).
    # NOTE(review): assumes dmacompiler treats a repeated dim as (outer offset, inner
    # extent), as ofm_memtile_s2mm's "C:0:{Cos}:8 ... C:0:8" does.
    C_width = C_stop - C_start
    access_pattern = []
    for y_iter in range(dims.Y_loop):
        for x_iter in range(dims.X_loop):
            Yi_start, Yi_stop, _ = Yi_slice(dims, col, row, y_iter)
            Xi_start, Xi_stop, _ = Xi_slice(dims, col, row, x_iter)
            # NOTE: Since shimtile cannot do padding any out of bound access is capped within the image dimensions
            Yi_start = max(0, Yi_start)
            Xi_start = max(0, Xi_start)
            Yi_stop = min(dims.Yi, Yi_stop)
            Xi_stop = min(dims.Xi, Xi_stop)
            access_pattern.append(
                f'C:{C_start}:{C_last_stop}:{C_stride} Yi:{Yi_start}:{Yi_stop} Xi:{Xi_start}:{Xi_stop} C:{0}:{C_width}'
            )
    return access_pattern


def wgt_memtile_memory(dims: BilinearResizeShape) -> str:
    '''
    Memtile buffer format for the WGT: one flat byte slot per AIE row.

    Each slot holds ``wgt_subvol_dims * act_bits / 8`` bytes. Each core needs
    one WGT subvolume per fetch and nothing is reused, so pinning is not supported.
    '''
    slot_bytes = dims.wgt_subvol_dims * dims.act_bits // 8
    return 'row:{} Bytes:{}'.format(dims.aie_rows, slot_bytes)


def wgt_memtile_s2mm(dims: BilinearResizeShape) -> str:
    '''
    Memtile S2MM (write) tiling for the WGT.

    A single pattern fills the full WGT subvolume for every AIE row (``row:0:aie_rows``).
    '''
    slot_bytes = dims.wgt_subvol_dims * dims.act_bits // 8
    all_rows = f'row:0:{dims.aie_rows}'
    return f'{all_rows} Bytes:0:{slot_bytes}'


def wgt_memtile_mm2s(dims: BilinearResizeShape, row: int) -> str:
    '''
    Memtile MM2S (read) tiling for the WGT subvolume sent to AIE row ``row``.

    Reads the whole byte slot of that single row.
    '''
    slot_bytes = dims.wgt_subvol_dims * dims.act_bits // 8
    one_row = f'row:{row}:{row + 1}'
    return f'{one_row} Bytes:0:{slot_bytes}'


def wgt_shimtile_memory(dims: BilinearResizeShape) -> str:
    '''
    DDR format of the full WGT tensor.

    There is one WGT subvolume per output tile, i.e. a
    ``(Y_loop * Y_split) x (X_loop * X_split)`` grid of byte blocks.
    '''
    block_bytes = dims.wgt_subvol_dims * dims.act_bits // 8
    n_y = dims.Y_loop * dims.Y_split
    n_x = dims.X_loop * dims.X_split
    return ' '.join((f'Yblocks:{n_y}', f'Xblocks:{n_x}', f'Bytes:{block_bytes}'))


def wgt_shimtile_mm2s(dims: BilinearResizeShape, col: int, row: int) -> list[str]:
    '''
    Shim MM2S (read from DDR) tiling patterns for the WGT of core (col, row).

    Returns one pattern per (y_iter, x_iter), Y outer and X inner. Each pattern
    selects the single WGT block of the output tile this core handles on that
    iteration: block ``(Y_idx + y_iter * Y_split, X_idx + x_iter * X_split)``.
    '''
    def block_pattern(y_iter, x_iter):
        y_idx, x_idx, _ = core_to_split(dims, col, row)
        y_blk = y_idx + y_iter * dims.Y_split
        x_blk = x_idx + x_iter * dims.X_split
        nbytes = dims.wgt_subvol_dims * dims.act_bits // 8
        return f'Yblocks:{y_blk}:{y_blk+1} Xblocks:{x_blk}:{x_blk+1} Bytes:0:{nbytes}'

    return [
        block_pattern(y_iter, x_iter)
        for y_iter in range(dims.Y_loop)
        for x_iter in range(dims.X_loop)
    ]

def gen_core_instrs(
    dims: BilinearResizeShape,
    l1_alloc: BilinearResizeL1MemoryAllocator,
    col: int,
    row: int,
) -> list:
    '''
    Build the instruction stream for core (col, row).

    Sequence:
      1. Receive the QDQ parameters once on S2MM channel 1 (config, acquire, release).
      2. Configure the OFM ping/pong buffers on MM2S channel 0.
      3. For every Y_loop x X_loop x C_loop iteration:
         - (re)configure the combined IFM+WGT input buffer on S2MM channel 0;
         - acquire the input and output buffers;
         - run ``run_bilinear_resize_bf16``;
         - release both buffers.
    '''
    def factor_near_sqrt(value):
        # Largest divisor d with 1 < d <= floor(sqrt(value)), paired with value // d;
        # falls back to (1, value) when no such divisor exists.
        d = floor(sqrt(value))
        while d > 1:
            if value % d == 0:
                return (d, value // d)
            d -= 1
        return (1, value)

    # Sizes are divided by 4 here, presumably because the BD counts 4-byte words.
    # The IFM payload is split into an inner_wrap x outer_wrap block. The third
    # step equals ifm_subvol_alloc_size // 4 (= wgt_subvol_offset in the same
    # units), which is where the WGT part of the combined transfer is placed.
    inner_wrap, outer_wrap = factor_near_sqrt(l1_alloc.ifm_subvol_size // 4)
    in_buf_cfg = ConfigBuffer(
        DmaChannel(DmaDir.S2MM, 0), l1_alloc.ifm_ping_addr, l1_alloc.ifm_pong_addr,
        l1_alloc.ifm_wgt_subvol_size,
        step=[1, inner_wrap, l1_alloc.ifm_subvol_alloc_size // 4],
        wrap=[inner_wrap, outer_wrap],
    )

    # One-shot QDQ parameter reception.
    qdq_setup = [
        ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), l1_alloc.qdq_addr, None, dims.qdq_bytes),
        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
    ]
    ofm_setup = ConfigBuffer(
        DmaChannel(DmaDir.MM2S, 0), l1_alloc.ofm_ping_addr, l1_alloc.ofm_pong_addr, l1_alloc.ofm_subvol_size,
    )
    # Body executed once per (y, x, c) tile.
    tile_body = [
        in_buf_cfg,
        AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
        AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
        CallKernel(
            'run_bilinear_resize_bf16',
            generate_bilinear_resize_layer_params(dims, l1_alloc.wgt_subvol_offset, col, row),
        ),
        RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
        RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
    ]
    c_loop = Loop(dims.C_loop, tile_body)
    x_loop = Loop(dims.X_loop, [c_loop])
    y_loop = Loop(dims.Y_loop, [x_loop])
    return qdq_setup + [ofm_setup, y_loop]


def gen_bilinear_resize_schedule(
    dims: BilinearResizeShape,
    back_end: BackEnd,
    kernel_names: dict,
    kernel_include: List[str],
    verbose: bool = False,
) -> BilinearResizeShape:
    '''
    Build and compile the complete dataflow schedule for one bilinear resize layer.

    Steps:
      1. Plan L1 memory and generate the per-core instruction streams.
      2. Plan L2 (memtile) memory and generate memtile transfers for:
         layer params and QDQ params (1-to-N), OFM (N-to-1), and IFM with WGT
         chained into the same 1-to-N transfer.
      3. Generate shim transfers for params, QDQ, OFM writes and IFM+WGT reads.
      4. Hand everything to run_layer_compilation.

    Args:
        dims: layer shape / tiling description.
        back_end: compilation back end. With BackEnd.Adf, WGT and QDQ are read
            from the IFM DDR buffer at offsets after the IFM. Otherwise they
            come from the separate WGT buffer.
        kernel_names: forwarded to run_layer_compilation.
        kernel_include: forwarded to run_layer_compilation.
        verbose: print the L1 / L2 memory plans.

    Returns:
        The same ``dims`` object that was passed in.
    '''
    l1_alloc = BilinearResizeL1MemoryAllocator(dims, overlay_stack_addr())
    if verbose:
        print(f"L1 OFM Ping Addr: {l1_alloc.ofm_ping_addr} OFM_SIZE:{l1_alloc.ofm_subvol_alloc_size}")
        print(f"L1 OFM Pong Addr: {l1_alloc.ofm_pong_addr} OFM_SIZE:{l1_alloc.ofm_subvol_alloc_size}")
        print(f"L1 IFM Ping Addr: {l1_alloc.ifm_ping_addr} IFM_SIZE:{l1_alloc.ifm_subvol_alloc_size}")
        print(f"L1 WGT Ping Addr: {l1_alloc.ifm_ping_addr + l1_alloc.wgt_subvol_offset} WGT_SIZE:{l1_alloc.wgt_subvol_alloc_size}")
        print(f"L1 IFM Pong Addr: {l1_alloc.ifm_pong_addr} IFM_SIZE:{l1_alloc.ifm_subvol_alloc_size}")
        print(f"L1 WGT Pong Addr: {l1_alloc.ifm_pong_addr + l1_alloc.wgt_subvol_offset} WGT_SIZE:{l1_alloc.wgt_subvol_alloc_size}")
        print(f"L1 QDQ Addr: {l1_alloc.qdq_addr} QDQ_SIZE:{dims.qdq_bytes}")
    # Per-core instruction streams.
    core_instrs = {}
    for col in range(dims.aie_cols):
        for row in range(dims.aie_rows):
            core_instrs[core_tile(col, row)] = gen_core_instrs(
                dims, l1_alloc, col, row,
            )
    l2_alloc = BilinearResizeL2MemoryAllocator(dims, l1_alloc)
    if verbose:
        print(f"L2 Param Addr: {l2_alloc.param_addr} PARAM_SIZE:{l2_alloc.param_bytes}")
        print(f"L2 IFM Ping Addr: {l2_alloc.ifm_ping_addr} IFM_SIZE:{l2_alloc.ifm_size}")
        print(f"L2 WGT Ping Addr: {l2_alloc.wgt_ping_addr} WGT_SIZE:{l2_alloc.wgt_size}")
        print(f"L2 IFM Pong Addr: {l2_alloc.ifm_pong_addr} IFM_SIZE:{l2_alloc.ifm_size}")
        print(f"L2 WGT Pong Addr: {l2_alloc.wgt_pong_addr} WGT_SIZE:{l2_alloc.wgt_size}")
        print(f"L2 OFM Ping Addr: {l2_alloc.ofm_ping_addr} OFM_SIZE:{l2_alloc.ofm_size}")
        print(f"L2 OFM Pong Addr: {l2_alloc.ofm_pong_addr} OFM_SIZE:{l2_alloc.ofm_size}")
        print(f"L2 QDQ Addr: {l2_alloc.qdq_addr} QDQ_SIZE:{l2_alloc.qdq_bytes}")

    bilinear_shim_alloc = shim_alloc()
    # DDR sizes in bytes. The WGT tensor holds one subvolume per output tile.
    # The combined IFM+WGT(+QDQ) size is used for the shared IFM/WGT shim buffer.
    ddr_wgt_size = l1_alloc.wgt_subvol_size * (dims.Y_loop* dims.Y_split) * (dims.X_loop * dims.X_split)
    ddr_ifm_size = dims.Yi * dims.Xi * dims.Ci * dims.act_bits // 8
    ddr_ofm_size = dims.Yo * dims.Xo * dims.Co * dims.act_bits // 8
    ddr_ifm_wgt_size = ddr_ifm_size + ddr_wgt_size + dims.qdq_bytes
    # NOTE(review): unlike the L1/L2 plan prints above, these are not gated by `verbose` — confirm intent.
    print(f"DDR IFM Size: {ddr_ifm_size}")
    print(f"DDR WGT Size: {ddr_wgt_size}")
    print(f"DDR IFM WGT Size: {ddr_ifm_wgt_size}")
    print(f"DDR OFM Size: {ddr_ofm_size}")
    # Memtile transfers generation
    memtile_transfers = []
    for col in range(dims.aie_cols):
        # Layer params generation
        memtile_transfers += generate_memtile_data_transfers_1_to_N(
                repeat_counts=[1] + [0] * (dims.Y_loop * dims.X_loop - 1),
                write_dma=memtile_dma(col, DmaDir.S2MM, 0),
                read_dma_list=[memtile_dma(col, DmaDir.MM2S, row) for row in range(dims.aie_rows)],
                buffer_addrs=[l2_alloc.param_addr],
                buffer_size=l2_alloc.param_bytes,
                memory_format=prm_memtile_memory(dims),
                write_tiling_format=prm_memtile_s2mm(),
                read_tiling_format=[prm_memtile_mm2s(row) for row in range(dims.aie_rows)],
                parallel_locking=False,
            )
    # QDQ is staged only on even memtile columns and read out on MM2S channel 4.
    for col in range(0, dims.aie_cols, 2):
        # QDQ transfers generation - broadcast
        memtile_transfers += generate_memtile_data_transfers_1_to_N(
                repeat_counts=[1] + [0] * (dims.Y_loop * dims.X_loop - 1),
                write_dma=memtile_dma(col, DmaDir.S2MM, 0),
                read_dma_list=[memtile_dma(col, DmaDir.MM2S, 4)],
                buffer_addrs=[l2_alloc.qdq_addr],
                buffer_size=l2_alloc.qdq_bytes,
                memory_format=qdq_memtile_memory(dims),
                write_tiling_format=qdq_memtile_s2mm(),
                read_tiling_format=[qdq_memtile_mm2s()],
        )
    # OFM: every row writes its subvolume (S2MM 2+row); one MM2S channel (5) drains the column.
    memtile_ofm_transfer = []
    for col in range(dims.aie_cols):
        repeat_counts = [dims.C_loop] * (dims.Y_loop * dims.X_loop)
        ofm_s2mm_access_patterns = []
        ofm_memory_format = []
        ofm_mm2s_access_patterns = []
        for row in range(dims.aie_rows):
            ofm_memory_format.append([])
            ofm_s2mm_access_patterns.append([])
            for _ in range(dims.X_loop):
                for _ in range(dims.Y_loop):
                    '''
                    If there is reconfig for x_iter and y_iter,
                    the memory format and write access pattern will not change
                    they are replicated for all the iters
                    '''
                    ofm_memory_format[row].append(ofm_memtile_memory(dims))
                    ofm_s2mm_access_patterns[row].append(ofm_memtile_s2mm(dims, row))
            # NOTE: ofm_memtile_mm2s returns list of access patterns for x_iter and y_iter
            ofm_mm2s_access_patterns.append(ofm_memtile_mm2s(dims, col, row))
        memtile_ofm_transfer += generate_memtile_data_transfers_N_to_1(
            repeat_counts=repeat_counts,
            write_dma_list=[memtile_dma(col, DmaDir.S2MM, 2+row) for row in range(dims.aie_rows)],
            read_dma=memtile_dma(col, DmaDir.MM2S, 5),
            buffer_addrs=[l2_alloc.ofm_ping_addr],  # NOTE: OFM ping-pong is disabled as the memtile runs out of BD
            buffer_size=l2_alloc.ofm_size,
            memory_format=ofm_memory_format,
            write_tiling_format=ofm_s2mm_access_patterns,
            read_tiling_format=ofm_mm2s_access_patterns,
            bits_per_block=dims.act_bits,
            parallel_locking=True,
            write_buffer_offset=[0],
            read_buffer_offset=[0],
        )
    memtile_transfers += memtile_ofm_transfer
    # IFM + WGT: one S2MM channel (1) fills the column, each row reads its own slot (MM2S row).
    # The WGT BDs are appended to the first IFM transfer so both travel through the same channels.
    memtile_ifm_transfer = []
    for col in range(dims.aie_cols):
        repeat_counts = [dims.C_loop] * (dims.Y_loop * dims.X_loop)
        ifm_s2mm_access_patterns = []
        ifm_memory_format = []
        ifm_mm2s_access_patterns = []
        wgt_mm2s_bd_list = []
        for row in range(dims.aie_rows):
            ifm_memory_format.append([])
            ifm_mm2s_access_patterns.append([])
            for _ in range(dims.X_loop):
                for _ in range(dims.Y_loop):
                    '''
                    If there is reconfig for x_iter and y_iter,
                    the memory format and write access pattern will not change
                    they are replicated for all the iters
                    '''
                    ifm_memory_format[row].append(ifm_memtile_memory(dims))
                    ifm_mm2s_access_patterns[row].append(ifm_memtile_mm2s(dims, row))
            # NOTE: ifm_memtile_s2mm returns list of access patterns for x_iter and y_iter
            ifm_s2mm_access_patterns.append(ifm_memtile_s2mm(dims, col, row))
            # WGT slot sits right after the IFM slots (offset = aie_rows * ifm_subvol_size).
            wgt_mm2s_bd_list.append(generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, row),
                wgt_memtile_memory(dims),
                wgt_memtile_mm2s(dims, row),
                buffer_offset=dims.aie_rows * l1_alloc.ifm_subvol_size,
            ))
        ifm_transfer = generate_memtile_data_transfers_1_to_N(
            repeat_counts=repeat_counts,
            write_dma=memtile_dma(col, DmaDir.S2MM, 1),
            read_dma_list=[memtile_dma(col, DmaDir.MM2S, row) for row in range(dims.aie_rows)],
            buffer_addrs=[l2_alloc.ifm_ping_addr],  # NOTE: IFM ping-pong is disabled as the memtile runs out of BD
            buffer_size=l2_alloc.ifm_size,
            memory_format=ifm_memory_format,
            write_tiling_format=ifm_s2mm_access_patterns,
            read_tiling_format=ifm_mm2s_access_patterns,
            bits_per_block=dims.act_bits,
            parallel_locking=False,
            write_buffer_offset=[0],
            read_buffer_offset=[0],
        )
        wgt_s2mm_writre_bd = generate_transfer_params(
            memtile_dma(col, DmaDir.S2MM, 1),
            wgt_memtile_memory(dims),
            wgt_memtile_s2mm(dims),
            buffer_offset=dims.aie_rows * l1_alloc.ifm_subvol_size,
        )
        ifm_transfer[0].write_params.append(wgt_s2mm_writre_bd)
        ifm_transfer[0].read_params+=wgt_mm2s_bd_list
        memtile_ifm_transfer += ifm_transfer
    memtile_transfers += memtile_ifm_transfer

    # Shimtile transfers generation
    ''' 
    NOTE: In the dataflow, the ifm and wgt are fetched on the same shim channel
    Due to AIESIM limitation, both the ifm and wgt must be part of the same GMIO buffer
    However there is no such limitation in the hardware and the ifm and wgt can be fetched
    on different BOs on DDR.
    '''
    shimtile_transfers = []
    for col in range(dims.aie_cols):
        # Layer params generation
        shimtile_transfers += [generate_shim_data_transfer(
            [1] + [0] * (dims.Y_loop * dims.X_loop - 1),
            shim_dma(col, DmaDir.MM2S, 0),
            bilinear_shim_alloc.prm_buffer_id,
            prm_shim_memory(dims),
            prm_shim_mm2s(col),
        )]
    for col in range(0, dims.aie_cols, 2):
        # QDQ transfers generation
        # Adf: QDQ sits after IFM+WGT in the IFM buffer; otherwise after WGT in the WGT buffer.
        shimtile_transfers += [generate_shim_data_transfer(
            [1] + [0] * (dims.Y_loop * dims.X_loop - 1),
            shim_dma(col, DmaDir.MM2S, 0),
            bilinear_shim_alloc.ifm_buffer_id if back_end is BackEnd.Adf else bilinear_shim_alloc.wgt_buffer_id,
            qdq_shim_memory(dims),
            qdq_shim_mm2s(),
            buffer_offset=(ddr_ifm_size+ddr_wgt_size) if back_end is BackEnd.Adf else ddr_wgt_size,
        )]
    shim_ofm_transfers = []
    for col in range(dims.aie_cols):
        shim_ofm_repeat_counts = [dims.C_loop] * (dims.Y_loop * dims.X_loop)
        write_transfers = []
        for row in range(dims.aie_rows):
            '''
            Each row within a column will produce different shards of X, Y or C
            based on the Y_split, X_split and C_split
            hence there is a 4 way chained BD
            '''
            ofm_shim_memory_format = []
            for _ in range(dims.X_loop):
                for _ in range(dims.Y_loop):
                    '''
                    If there is reconfig for x_iter and y_iter,
                    the memory format will not change
                    they are replicated for all the iters
                    '''
                    ofm_shim_memory_format.append(ofm_shimtile_memory(dims))
            # With C_loop > 1 the call passes use_iter_step and its result is unpacked as
            # (repeat_coeffs, transfer); every coeff is then checked against C_loop.
            if dims.C_loop > 1:
                repeat_coeffs, gen_write_transfer = pack_reconfig_transfers(
                                                        shim_dma(col, DmaDir.S2MM, 0),
                                                        ofm_shim_memory_format,
                                                        ofm_shimtile_s2mm(dims, col, row),
                                                        bits_per_elem=dims.act_bits,
                                                        use_iter_step=[True],
                                                    )
                write_transfers.append(gen_write_transfer)
                assert all(coeff_val  == repeat_coeffs[0] for coeff_val in repeat_coeffs), "Repeat coeffs should be same for all the iterations"
                assert repeat_coeffs[0] == dims.C_loop, f"Repeat coeffs of {repeat_coeffs[0]} and C_loop of {dims.C_loop} mismatch for shim OFM transfers"
            else:
                gen_write_transfer = pack_reconfig_transfers(
                                                        shim_dma(col, DmaDir.S2MM, 0),
                                                        ofm_shim_memory_format,
                                                        ofm_shimtile_s2mm(dims, col, row),
                                                        bits_per_elem=dims.act_bits,
                                                        # use_iter_step=[True],
                                                    )
                write_transfers.append(gen_write_transfer)
        shim_ofm_transfers += [
            DataTransfer(
                shim_ofm_repeat_counts,
                AieTile(TileType.Shim, col), [bilinear_shim_alloc.ofm_buffer_id], ddr_ofm_size,
                write_transfers,
                [],
            )
        ]
    shimtile_transfers += shim_ofm_transfers
    '''
    IFM / WGT transfers generation
    Since IFM and WGT are fetched per subvol of OFM and there is no accumulation, 
    both the tensors traverse through the unicast path.
    hence a design choice is made to fetch both the tensors through the same channel
    in iterleved fashion
    '''
    shimtile_ifm_wgt_transfers = []
    for col in range(dims.aie_cols):
        repeat_counts = [dims.C_loop] * (dims.Y_loop * dims.X_loop)
        read_transfers = []
        for row in range(dims.aie_rows):
            ifm_shim_memory_format = []
            for _ in range(dims.X_loop):
                for _ in range(dims.Y_loop):
                    '''
                    If there is reconfig for x_iter and y_iter,
                    the memory format will not change
                    they are replicated for all the iters
                    '''
                    ifm_shim_memory_format.append(ifm_shimtile_memory(dims))
            # Same C_loop > 1 / use_iter_step split as for the OFM writes above.
            if dims.C_loop > 1:
                repeat_coeffs, gen_ifm_read_transfer = pack_reconfig_transfers(
                                                        shim_dma(col, DmaDir.MM2S, 1),
                                                        ifm_shim_memory_format,
                                                        ifm_shimtile_s2mm(dims, col, row),  # This function returns list of access patterns for x_iter and y_iter
                                                        bits_per_elem=dims.act_bits,
                                                        use_iter_step=[True],
                                                    )
                read_transfers.append(gen_ifm_read_transfer)
                assert all(coeff_val == repeat_coeffs[0] for coeff_val in repeat_coeffs), "Repeat coeffs should be same for all the iterations"
                assert repeat_coeffs[0] == dims.C_loop, "Repeat coeffs and C_loop mismatch for shim IFM transfers"
            else:
                gen_ifm_read_transfer = pack_reconfig_transfers(
                                                        shim_dma(col, DmaDir.MM2S, 1),
                                                        ifm_shim_memory_format,
                                                        ifm_shimtile_s2mm(dims, col, row),  # This function returns list of access patterns for x_iter and y_iter
                                                        bits_per_elem=dims.act_bits,
                                                        # use_iter_step=[True],
                                                    )
                read_transfers.append(gen_ifm_read_transfer)
        ''' 
        NOTE: In the dataflow, the ifm and wgt are fetched on the same shim channel
        Due to AIESIM limitation, both the ifm and wgt must be part of the same GMIO buffer
        However there is no such limitation in the hardware and the ifm and wgt can be fetched
        on different BOs on DDR.
        '''
        for row in range(dims.aie_rows):
            wgt_shim_memory_format = []
            for _ in range(dims.X_loop):
                for _ in range(dims.Y_loop):
                    wgt_shim_memory_format.append(wgt_shimtile_memory(dims))
            wgt_read_transfer = None
            # Adf: WGT lives in the IFM buffer right after the IFM tensor.
            # Otherwise it is read from its own WGT buffer.
            if back_end is BackEnd.Adf:
                wgt_read_transfer = pack_reconfig_transfers(
                                                        shim_dma(col, DmaDir.MM2S, 1),
                                                        wgt_shim_memory_format,
                                                        wgt_shimtile_mm2s(dims, col, row),  # This function returns list of access patterns for x_iter and y_iter
                                                        buffer_offset=[ddr_ifm_size],
                                                    )
            else:
                wgt_read_transfer = pack_reconfig_transfers(
                                                        shim_dma(col, DmaDir.MM2S, 1),
                                                        wgt_shim_memory_format,
                                                        wgt_shimtile_mm2s(dims, col, row),  # This function returns list of access patterns for x_iter and y_iter
                                                    )
                wgt_read_transfer.shim_buffer_index = bilinear_shim_alloc.wgt_buffer_id
            read_transfers.append(wgt_read_transfer)
        shimtile_ifm_wgt_transfers += [
            DataTransfer(
                repeat_counts,
                AieTile(TileType.Shim, col), [bilinear_shim_alloc.ifm_buffer_id], ddr_ifm_wgt_size,
                [],
                read_transfers,
            )
        ]
    shimtile_transfers += shimtile_ifm_wgt_transfers

    # NOTE(review): enable_debug_print is hard-coded to True regardless of `verbose`.
    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        kernel_names,
        kernel_include,
        core_instrs,
        memtile_transfers,
        shimtile_transfers,
        overlay_8x4_dma_connections(),
        back_end=back_end,
        core_stack_addr=overlay_stack_addr(),
        param_channel_id=0,
        enable_debug_print=True,
    )

    return dims