'''
# Common definitions for convolution dataflow
# This file contains shared definitions and utility functions for convolution operations
'''
from dataclasses import dataclass
from typing import no_type_check

from kernel.gemm_qdq_int16x8.gemm_qdq_int16x8_params import generate_gemm_qdq_a16w8_params
from kernel.gemm_qdq_int16x4.gemm_qdq_int16x4_params import generate_gemm_qdq_a16w4_params
from kernel.conv.conv_noqdq_a8w8_params import generate_conv_noqdq_a8w8_params
from kernel.conv_qdq_int16x8.conv_qdq_a16w8_params import generate_conv_qdq_a16w8_params
from kernel.gemm_qdq_int16x16.gemm_qdq_int16x16_params import generate_gemm_qdq_a16a16_params
from kernel.gemm_qdq_int16x16_transpose.gemm_qdq_int16x16_transpose_params import generate_gemm_qdq_a16a16_transpose_params

from dmacompiler import (
    generate_core_buffer_config,
    core_dma,
    DmaDir,
    DmaChannel,
    AcqBuffer,
    RelBuffer,
    CallKernel,
    set_dev_gen,
    DevGen,
    compute_buffer_size,
    Loop,
    ConfigBuffer,
)

from utils.utils_common import (
    overlay_3x4_core_stack_addr,
    BaseDims, log,
    split_to_mode,
    core_to_split,
)

from scheduler.common import (
    LinearOpType,
    TensorSlicer,
    broadcast_channels,
    overlay_3x4_F_ids,
    unicast_channels,
    overlay_3x4_O_ids,
)

from scheduler.conv.conv_config_builders import (
    ConvDims,
    ConvMapping,
)

# Module-level side effect: selects the Aie4 device generation in dmacompiler
# as soon as this module is imported.
set_dev_gen(DevGen.Aie4)


@dataclass
class convL2Memory:
    '''L2 memory buffer configuration for convolution operations - struct-like implementation

    Stores the start address and size of the param, IFM, WGT, OFM and QDQ
    regions. The ping and pong buffers of a tensor share a single size.
    The layout is validated on construction.

    Raises:
        ValueError: if any two non-empty regions overlap.
    '''
    __slots__ = ('param_addr', 'param_size', 'ifm_ping_addr', 'ifm_pong_addr', 'ifm_size',
                 'wgt_ping_addr', 'wgt_pong_addr', 'wgt_size', 'ofm_ping_addr', 'ofm_pong_addr', 'ofm_size',
                 'qdq_size', 'qdq_addr')

    def __init__(
        self,
        param_addr: int,
        param_size: int,
        ifm_ping_addr: int,
        ifm_pong_addr: int,
        ifm_size: int,
        wgt_ping_addr: int,
        wgt_pong_addr: int,
        wgt_size: int,
        ofm_ping_addr: int,
        ofm_pong_addr: int,
        ofm_size: int,
        qdq_size: int = 0,
        qdq_addr: int = 0,
    ):
        '''Initialize the convL2Memory with buffer addresses and sizes'''
        self.param_addr = param_addr
        self.param_size = param_size
        self.ifm_ping_addr = ifm_ping_addr
        self.ifm_pong_addr = ifm_pong_addr
        self.ifm_size = ifm_size
        self.wgt_ping_addr = wgt_ping_addr
        self.wgt_pong_addr = wgt_pong_addr
        self.wgt_size = wgt_size
        self.ofm_ping_addr = ofm_ping_addr
        self.ofm_pong_addr = ofm_pong_addr
        self.ofm_size = ofm_size
        self.qdq_size = qdq_size
        self.qdq_addr = qdq_addr

        # Validate that the buffers do not overlap
        self._validate_no_overlap()

    def _validate_no_overlap(self):
        '''Raise ValueError if any two non-empty buffers overlap'''
        # Every region as (start_addr, end_addr, name); end is exclusive
        regions = [
            (self.param_addr, self.param_addr + self.param_size, "param"),
            (self.ifm_ping_addr, self.ifm_ping_addr + self.ifm_size, "ifm_ping"),
            (self.ifm_pong_addr, self.ifm_pong_addr + self.ifm_size, "ifm_pong"),
            (self.wgt_ping_addr, self.wgt_ping_addr + self.wgt_size, "wgt_ping"),
            (self.wgt_pong_addr, self.wgt_pong_addr + self.wgt_size, "wgt_pong"),
            (self.ofm_ping_addr, self.ofm_ping_addr + self.ofm_size, "ofm_ping"),
            (self.ofm_pong_addr, self.ofm_pong_addr + self.ofm_size, "ofm_pong"),
            (self.qdq_addr, self.qdq_addr + self.qdq_size, "qdq"),
        ]

        # Zero-size regions occupy no memory. They must be removed BEFORE the
        # sort: left in place, an empty region (e.g. the default qdq at
        # address 0) can land between two real regions and hide their overlap
        # from the adjacent-pair check below.
        occupied = sorted(
            (region for region in regions if region[1] != region[0]),
            key=lambda region: region[0],
        )

        # After sorting by start address, any overlap shows up between
        # neighbouring regions.
        for current, following in zip(occupied, occupied[1:]):
            current_start, current_end, current_name = current
            next_start, next_end, next_name = following
            if current_end > next_start:
                overlap_start = next_start
                overlap_end = min(current_end, next_end)
                raise ValueError(
                    f"Memory buffer overlap detected between {current_name} "
                    f"[0x{current_start:x}-0x{current_end:x}) and {next_name} "
                    f"[0x{next_start:x}-0x{next_end:x}). "
                    f"Overlapping region: [0x{overlap_start:x}-0x{overlap_end:x})"
                )


@dataclass
class ConvL2MemoryAllocator:
    '''Allocate L2 memory buffers for convolution operations.

    Buffers are packed back-to-back from address 0 in the order param,
    IFM ping[/pong], WGT ping[/pong], OFM ping[/pong] and, in 'act' mode, QDQ.
    The layout is exposed through the ``*_addr`` / ``*_size`` attributes and
    as a ``convL2Memory`` instance in ``memory``.

    Args:
        dims: ``ifm_bits`` and ``ofm_bits`` are always read; ``wgt_bits`` and
            ``qdq_param_size`` are read in 'act' mode only.
        param_l2_memory, ifm_l2_memory, wgt_l2_memory, ofm_l2_memory:
            memory descriptions handed to ``compute_buffer_size``.
        ifm_double_buffer, wgt_double_buffer, ofm_double_buffer: when true a
            pong buffer is placed directly after the ping buffer; otherwise
            the pong address is a placeholder far above the 3 MiB budget.
        gemm_mode: 'wgt' (WGT size computed without a bit width, no QDQ
            buffer) or 'act' (WGT size uses ``dims.wgt_bits``, QDQ buffer
            appended after OFM).

    Raises:
        ValueError: for an unsupported ``gemm_mode``, when the buffers exceed
            the 3 MiB budget, or when ``convL2Memory`` detects an overlap.
    '''
    def __init__(
        self,
        dims: ConvDims,
        param_l2_memory: str,
        ifm_l2_memory: str,
        wgt_l2_memory: str,
        ofm_l2_memory: str,
        ifm_double_buffer: bool = False,
        wgt_double_buffer: bool = False,
        ofm_double_buffer: bool = False,
        gemm_mode: str = 'wgt'
    ):
        if gemm_mode not in ('wgt', 'act'):
            raise ValueError(f"Unsupported gemm_mode={gemm_mode}. Supported modes are 'wgt' and 'act'.")
        act_mode = gemm_mode == 'act'

        memtile_size = 3*2**20  # 3 MiB budget for all L2 buffers
        unused_base = 9*2**20   # placeholder base, well outside the budget

        # Compute the sizes of each L2 memory buffer
        self.param_size = compute_buffer_size(param_l2_memory)
        self.ifm_size = compute_buffer_size(ifm_l2_memory, dims.ifm_bits)
        if act_mode:
            self.wgt_size = compute_buffer_size(wgt_l2_memory, dims.wgt_bits)
        else:
            self.wgt_size = compute_buffer_size(wgt_l2_memory)  # WGT subvol is Bytes
        self.ofm_size = compute_buffer_size(ofm_l2_memory, dims.ofm_bits)
        # 'wgt' mode has no QDQ buffer; keep the attribute defined in both modes
        self.qdq_size = dims.qdq_param_size if act_mode else 0

        # Placeholder addresses for pong buffers that are not allocated.
        if act_mode:
            # Fixed, distinct slot per tensor.
            ifm_unused, wgt_unused, ofm_unused = unused_base, 2*unused_base, 4*unused_base
        else:
            # Historically every unused pong sat at 9 MiB, so two or more
            # single-buffered tensors always collided in convL2Memory's overlap
            # check. Hand out distinct slots in IFM, WGT, OFM order instead:
            # a lone single-buffered tensor still gets 9 MiB, so every layout
            # that validated before keeps its addresses.
            slots = iter((unused_base, 2*unused_base, 4*unused_base))
            ifm_unused = None if ifm_double_buffer else next(slots)
            wgt_unused = None if wgt_double_buffer else next(slots)
            ofm_unused = None if ofm_double_buffer else next(slots)

        # Allocate the L2 buffers sequentially
        self.param_addr = 0
        self.ifm_ping_addr = self.param_addr + self.param_size
        self.ifm_pong_addr, ifm_end = self._place_pong(
            self.ifm_ping_addr, self.ifm_size, ifm_double_buffer, ifm_unused)
        self.wgt_ping_addr = ifm_end
        self.wgt_pong_addr, wgt_end = self._place_pong(
            self.wgt_ping_addr, self.wgt_size, wgt_double_buffer, wgt_unused)
        self.ofm_ping_addr = wgt_end
        self.ofm_pong_addr, ofm_end = self._place_pong(
            self.ofm_ping_addr, self.ofm_size, ofm_double_buffer, ofm_unused)
        # Fixed (non ping-pong) QDQ buffer follows the OFM region in 'act' mode;
        # in 'wgt' mode it keeps convL2Memory's default address of 0.
        self.qdq_addr = ofm_end if act_mode else 0

        # Check if the allocated buffers fit within the budget
        total_size = (
            self.param_size +
            self.ifm_size * (2 if ifm_double_buffer else 1) +
            self.wgt_size * (2 if wgt_double_buffer else 1) +
            self.ofm_size * (2 if ofm_double_buffer else 1) +
            self.qdq_size
        )
        if act_mode:
            too_big = self.qdq_addr + self.qdq_size > memtile_size or total_size > memtile_size
            hint = "Reduce sizes or disable some double buffers."
        else:
            too_big = total_size > memtile_size
            hint = "Consider reducing buffer sizes or using fewer double buffers."
        if too_big:
            raise ValueError(
                f"Allocated L2 buffers exceed memtile size: {total_size} > {memtile_size}. "
                + hint
            )

        # Create the convL2Memory object (validates that nothing overlaps)
        self.memory = convL2Memory(
            param_addr=self.param_addr,
            param_size=self.param_size,
            ifm_ping_addr=self.ifm_ping_addr,
            ifm_pong_addr=self.ifm_pong_addr,
            ifm_size=self.ifm_size,
            wgt_ping_addr=self.wgt_ping_addr,
            wgt_pong_addr=self.wgt_pong_addr,
            wgt_size=self.wgt_size,
            ofm_ping_addr=self.ofm_ping_addr,
            ofm_pong_addr=self.ofm_pong_addr,
            ofm_size=self.ofm_size,
            qdq_size=self.qdq_size,
            qdq_addr=self.qdq_addr,
        )

    @staticmethod
    def _place_pong(ping_addr: int, size: int, double_buffer: bool, unused_addr) -> tuple:
        '''Return (pong_addr, end_addr) for a tensor whose ping buffer starts at ping_addr.

        With double buffering the pong buffer directly follows the ping buffer
        and the region ends after it; otherwise the pong address is the given
        placeholder and the region ends after the ping buffer.
        '''
        if double_buffer:
            pong_addr = ping_addr + size
            return pong_addr, pong_addr + size
        return unused_addr, ping_addr + size


@dataclass(frozen=True)
class ConvDataFlowRepeats:
    '''
    Define the number of repeats for each conv dataflow operation
    Each stage has a dictionary mapping the column index to a list of repeats

    Fields are named <tensor>_<level>_<direction>_repeats, with tensor in
    {ifm, wgt, ofm}, level in {L3, L2} and direction in {mm2s, s2mm}.
    The dataclass is frozen, so fields cannot be reassigned after construction
    (the dictionaries they reference are not themselves frozen).
    '''
    # L3-level transfers
    ifm_L3_mm2s_repeats: dict[int, list[int]]
    wgt_L3_mm2s_repeats: dict[int, list[int]]
    ofm_L3_s2mm_repeats: dict[int, list[int]]
    # L2-level transfers
    ifm_L2_s2mm_repeats: dict[int, list[int]]
    ifm_L2_mm2s_repeats: dict[int, list[int]]
    ofm_L2_mm2s_repeats: dict[int, list[int]]
    ofm_L2_s2mm_repeats: dict[int, list[int]]
    wgt_L2_mm2s_repeats: dict[int, list[int]]
    wgt_L2_s2mm_repeats: dict[int, list[int]]


def map_shim_ch_memtile_ch(shim_channel: int) -> int:
    '''Look up the memtile fill channel paired with a shim mm2s channel.

    The shim channel is used as an index into the list returned by
    overlay_3x4_F_ids().
    '''
    return overlay_3x4_F_ids()[shim_channel]


def ifm_memtile_channels(dims: ConvDims, col: int) -> list[tuple[int, int]]:
    '''Select the IFM channel allocations (row, id) for a column.

    Split mode 0 uses the unicast channels, mode 1 the broadcast channels of
    the column. Any other mode raises KeyError.
    '''
    split_mode = split_to_mode(dims)
    # Both candidates are built before the lookup.
    unicast = unicast_channels()
    broadcast = broadcast_channels(col)
    return {0: unicast, 1: broadcast}[split_mode]


def wgt_memtile_channels(dims: ConvDims, col: int) -> list[tuple[int, int]]:
    '''Select the WGT channel allocations (row, id) for a column.

    Mirror image of the IFM selection: split mode 0 uses the broadcast
    channels of the column, mode 1 the unicast channels. Any other mode
    raises KeyError.
    '''
    split_mode = split_to_mode(dims)
    # Both candidates are built before the lookup.
    broadcast = broadcast_channels(col)
    unicast = unicast_channels()
    return {0: broadcast, 1: unicast}[split_mode]


def ofm_memtile_channels() -> list[tuple[int, int]]:
    '''Pair each OFM channel id from overlay_3x4_O_ids() with its position: (row, id)'''
    return [(row, channel_id) for row, channel_id in enumerate(overlay_3x4_O_ids())]


def ifm_core_channel(dims: ConvDims) -> int:
    '''Core channel ID carrying the IFM: equal to the split mode (0 or 1).

    Any other mode raises KeyError.
    '''
    split_mode = split_to_mode(dims)
    ifm_channel_by_mode = dict(zip((0, 1), (0, 1)))
    return ifm_channel_by_mode[split_mode]


def wgt_core_channel(dims: ConvDims) -> int:
    '''Core channel ID carrying the WGT: the opposite of the split mode (0 -> 1, 1 -> 0).

    Any other mode raises KeyError.
    '''
    split_mode = split_to_mode(dims)
    wgt_channel_by_mode = dict(zip((0, 1), (1, 0)))
    return wgt_channel_by_mode[split_mode]


# Backward compatibility functions - these use the TensorSlicer class
def Yi_slice(dims: BaseDims, col: int, row: int, i: int) -> tuple[int, int, int]:
    '''Yi-axis slice for core (col, row) at iteration i of Y_loop (delegates to TensorSlicer)'''
    return TensorSlicer(dims).Yi_slice(col, row, i)


def Xi_slice(dims: BaseDims, col: int, row: int, i: int) -> tuple[int, int, int]:
    '''Xi-axis slice for core (col, row) at iteration i of X_loop (delegates to TensorSlicer)'''
    return TensorSlicer(dims).Xi_slice(col, row, i)


def Yo_slice(dims: BaseDims, col: int, row: int, i: int) -> tuple[int, int, int]:
    '''Yo-axis slice for core (col, row) at iteration i of Y_loop (delegates to TensorSlicer)'''
    return TensorSlicer(dims).Yo_slice(col, row, i)


def Xo_slice(dims: BaseDims, col: int, row: int, i: int) -> tuple[int, int, int]:
    '''Xo-axis slice for core (col, row) at iteration i of the X_loop (delegates to TensorSlicer)'''
    return TensorSlicer(dims).Xo_slice(col, row, i)


def Co_slice(dims: BaseDims, col: int, row: int, i: int) -> tuple[int, int, int]:
    '''Co-axis slice for core (col, row) at iteration i of the Co_loop (delegates to TensorSlicer)'''
    return TensorSlicer(dims).Co_slice(col, row, i)


def Co_split_idxs(dims: ConvDims, col: int) -> list[int]:
    '''Collect the Cout split index of every row in a column, in row order.

    The Cout index is the fourth element of the core_to_split() result.
    '''
    co_idxs = []
    for row in range(dims.aie_rows):
        _, _, _, co_idx = core_to_split(dims, col, row)
        co_idxs.append(co_idx)
    return co_idxs


def Co_split_size(dims: ConvDims) -> int:
    '''Number of Cout split blocks covered by a single column (max index - min index + 1)'''
    # NOTE: Column zero stands in for every column because the Cout split
    # size is regular across columns.
    column0 = Co_split_idxs(dims, 0)
    return max(column0) - min(column0) + 1


def Yo_split_idxs(dims: ConvDims, col: int) -> list[int]:
    '''Collect the Yo split index of every row in a column, in row order.

    The Yo index is the second element of the core_to_split() result.
    '''
    yo_idxs = []
    for row in range(dims.aie_rows):
        _, yo_idx, _, _ = core_to_split(dims, col, row)
        yo_idxs.append(yo_idx)
    return yo_idxs


def Yo_split_size(dims: ConvDims) -> int:
    '''Number of Yo split blocks covered by a single column (max index - min index + 1)'''
    # NOTE: Column zero stands in for every column because the Yo split
    # size is regular across columns.
    column0 = Yo_split_idxs(dims, 0)
    log(f"Yo split idxs: {column0}")
    return max(column0) - min(column0) + 1


def Xo_split_idxs(dims: ConvDims, col: int) -> list[int]:
    '''Collect the Xo split index of every row in a column, in row order.

    The Xo index is the third element of the core_to_split() result.
    '''
    xo_idxs = []
    for row in range(dims.aie_rows):
        _, _, xo_idx, _ = core_to_split(dims, col, row)
        xo_idxs.append(xo_idx)
    return xo_idxs


def Xo_split_size(dims: ConvDims) -> int:
    '''Number of Xo split blocks covered by a single column (max index - min index + 1)'''
    # NOTE: Column zero stands in for every column because the Xo split
    # size is regular across columns.
    column0 = Xo_split_idxs(dims, 0)
    return max(column0) - min(column0) + 1


def Co_split_offset(dims: ConvDims, col: int, row: int) -> int:
    '''Cout split index of core (col, row) relative to the smallest Cout index in its column'''
    column_idxs = Co_split_idxs(dims, col)
    _, _, _, core_co_idx = core_to_split(dims, col, row)
    return core_co_idx - min(column_idxs)


def Yo_split_offset(dims: ConvDims, col: int, row: int) -> int:
    '''Yo split index of core (col, row) relative to the smallest Yo index in its column'''
    column_idxs = Yo_split_idxs(dims, col)
    _, core_yo_idx, _, _ = core_to_split(dims, col, row)
    return core_yo_idx - min(column_idxs)


def Xo_split_offset(dims: ConvDims, col: int, row: int) -> int:
    '''Xo split index of core (col, row) relative to the smallest Xo index in its column'''
    column_idxs = Xo_split_idxs(dims, col)
    _, _, core_xo_idx, _ = core_to_split(dims, col, row)
    return core_xo_idx - min(column_idxs)


def ifm_core_memory(dims: ConvDims) -> str:
    '''IFM L1 data order and shape: "Ci:<Cis> Yi:<Yis> Xi:<Xis> Ci:<Ci_gran>"'''
    axes = (
        f'Ci:{dims.Cis}',
        f'Yi:{dims.Yis}',
        f'Xi:{dims.Xis}',
        f'Ci:{dims.Ci_gran}',
    )
    return ' '.join(axes)


def ifm_shim_memory(dims: ConvDims) -> str:
    '''IFM DDR data order and shape: "Yi:<Yi> Xi:<Xi> Ci:<Ci>"'''
    axes = (f'Yi:{dims.Yi}', f'Xi:{dims.Xi}', f'Ci:{dims.Ci}')
    return ' '.join(axes)


def wgt_core_memory_wgt(dims: ConvDims) -> str:
    '''WGT L1 data order in 'wgt' mode: one flat "Subv" axis of wgt_L1_size'''
    subvolume = dims.wgt_L1_size
    return 'Subv:' + f'{subvolume}'


def wgt_core_memory_act(dims: ConvDims) -> str:
    '''WGT L1 data order in 'act' mode.

    With transposed weights the layout is "Co Ci"; otherwise an inner
    Co_gran_wgt axis is appended (the result then ends with a space).
    '''
    transposed = dims.transpose_wgts
    axes = [f'Co:{dims.Cos}', f'Ci:{dims.Cis}']
    if not transposed:
        axes.append(f'Co:{dims.Co_gran_wgt} ')
    return ' '.join(axes)


# NOTE: How does this compare to conv_L2_schedule
def wgt_core_memory(dims: ConvDims, gemm_mode: str = 'wgt') -> str:
    '''WGT L1 data order for the given gemm_mode ('wgt' or 'act').

    Raises ValueError for any other mode.
    '''
    if gemm_mode not in ('wgt', 'act'):
        raise ValueError(f"Unsupported gemm_mode={gemm_mode}. Supported modes are 'wgt' and 'act'.")
    return wgt_core_memory_wgt(dims) if gemm_mode == 'wgt' else wgt_core_memory_act(dims)


def wgt_shim_memory_wgt(dims: ConvDims) -> str:
    '''WGT DDR data order in 'wgt' mode: "Cib Cob Subv"'''
    # NOTE: Weights are pre-formatted in subvolume-sized blocks and traversed
    # Cout first, then Cin. This is deliberate: the Cout traversal for
    # Co_split folds into a linear access, which keeps the shim access
    # pattern within five dimensions.
    cin_blocks = dims.Ci_loop
    cout_blocks = dims.Co_loop * dims.Co_split
    return ' '.join((f'Cib:{cin_blocks}', f'Cob:{cout_blocks}', f'Subv:{dims.wgt_L1_size}'))


def wgt_shim_memory_act(dims: ConvDims) -> str:
    '''WGT DDR data order in 'act' mode.

    Transposed weights are laid out "Yi Co Ci" using Ci; otherwise the layout
    is "Yi Ci Co" using Ci_orig.
    '''
    # NOTE(review): the original carried the block-formatting note of the
    # 'wgt' variant here; the layouts below contain no Subv blocks, so that
    # note looks stale - confirm.
    if dims.transpose_wgts:
        axes = (f'Yi:{dims.Yi}', f'Co:{dims.Co}', f'Ci:{dims.Ci}')
    else:
        axes = (f'Yi:{dims.Yi}', f'Ci:{dims.Ci_orig}', f'Co:{dims.Co}')
    return ' '.join(axes)


def wgt_shim_memory(dims: ConvDims, gemm_mode: str = 'wgt') -> str:
    '''WGT DDR data order for the given gemm_mode ('wgt' or 'act').

    Raises ValueError for any other mode.
    '''
    if gemm_mode not in ('wgt', 'act'):
        raise ValueError(f"Unsupported gemm_mode={gemm_mode}. Supported modes are 'wgt' and 'act'.")
    return wgt_shim_memory_wgt(dims) if gemm_mode == 'wgt' else wgt_shim_memory_act(dims)


def ofm_core_memory(dims: ConvDims) -> str:
    '''OFM L1 data order and shape: "Yo:<Yos> Xo:<Xos> Co:<Cos>".

    Asserts that Cos equals Co_gran.
    '''
    assert dims.Cos == dims.Co_gran
    axes = (f'Yo:{dims.Yos}', f'Xo:{dims.Xos}', f'Co:{dims.Cos}')
    return ' '.join(axes)


def ofm_shim_memory(dims: ConvDims) -> str:
    '''OFM DDR data order and shape: "Yo:<Yo> Xo:<Xo> Co:<Co>"'''
    axes = (f'Yo:{dims.Yo}', f'Xo:{dims.Xo}', f'Co:{dims.Co}')
    return ' '.join(axes)


def ifm_core_s2mm(dims: ConvDims) -> str:
    '''IFM core s2mm access pattern.

    Covers the full Cis/Yis/Xis subvolume from offset 0, with Ci walked in
    steps of Ci_gran outside and a Ci_gran-wide run inside.
    '''
    # NOTE: The pattern depends only on dims (not on the tile or iteration),
    # so the same string is used for every transfer.
    gran = dims.Ci_gran
    axes = (
        f'Ci:0:{dims.Cis}:{gran}',
        f'Yi:0:{dims.Yis}',
        f'Xi:0:{dims.Xis}',
        f'Ci:0:{gran}',
    )
    return ' '.join(axes)


def wgt_core_s2mm_wgt(dims: ConvDims) -> str:
    '''WGT core s2mm access pattern in 'wgt' mode: the whole "Subv" axis from offset 0'''
    subvolume = dims.wgt_L1_size
    return 'Subv:0:' + f'{subvolume}'


def wgt_core_s2mm_act(dims: ConvDims) -> str:
    '''WGT core s2mm access pattern in 'act' mode.

    With transposed weights the pattern is "Co Ci" from offset 0; otherwise Co
    is walked in steps of Co_gran_wgt with an inner Co_gran_wgt-wide run (the
    result then ends with a space).
    '''
    if dims.transpose_wgts:
        axes = (f'Co:0:{dims.Cos}', f'Ci:0:{dims.Cis}')
    else:
        axes = (
            f'Co:0:{dims.Cos}:{dims.Co_gran_wgt}',
            f'Ci:0:{dims.Cis}',
            f'Co:0:{dims.Co_gran_wgt} ',
        )
    return ' '.join(axes)


def wgt_core_s2mm(dims: ConvDims, gemm_mode: str = 'wgt') -> str:
    '''WGT core s2mm access pattern for the given gemm_mode ('wgt' or 'act').

    Raises ValueError for any other mode.
    '''
    if gemm_mode not in ('wgt', 'act'):
        raise ValueError(f"Unsupported gemm_mode={gemm_mode}. Supported modes are 'wgt' and 'act'.")
    return wgt_core_s2mm_wgt(dims) if gemm_mode == 'wgt' else wgt_core_s2mm_act(dims)


def ofm_core_mm2s(dims: ConvDims) -> str:
    '''OFM core mm2s access pattern: the full Yos/Xos/Cos subvolume from offset 0'''
    axes = (f'Yo:0:{dims.Yos}', f'Xo:0:{dims.Xos}', f'Co:0:{dims.Cos}')
    return ' '.join(axes)


def conv_noqdq_a8w8_qdq_kernel_name() -> str:
    '''Kernel name used for conv noqdq_a8w8'''
    kernel_name = 'run_conv_noqdq_a8w8'
    return kernel_name


def gemm_a16w8_qdq_kernel_name() -> str:
    '''Kernel name used for gemm a16w8_qdq'''
    kernel_name = 'run_gemm_int16x8'
    return kernel_name


def gemm_a16w4_qdq_kernel_name() -> str:
    '''Kernel name used for gemm a16w4_qdq'''
    kernel_name = 'run_gemm_int16x4'
    return kernel_name


def gemm_A16A16_v1_kernel_name() -> str:
    '''Kernel name used for gemm A16A16 v1 (int16x16)'''
    kernel_name = 'run_gemm_int16x16'
    return kernel_name


def gemm_A16A16_v2_kernel_name() -> str:
    '''Kernel name used for gemm A16A16 v2 (int16x16 transpose)'''
    kernel_name = 'run_gemm_int16x16_transpose'
    return kernel_name


def conv_call_kernel(
    dims: ConvDims,
    ifm_ch_num: int,
    mapping: ConvMapping,
    full_iters: bool = True,
) -> CallKernel:
    '''Generate a CallKernel instruction for a conv kernel chosen by bit widths.

    Supported (ifm_bits, wgt_bits, bias_bits) combinations:
      * (8, 8, 16)  -> conv noqdq_a8w8 kernel
      * (16, 8, 32) -> 'run_conv_qdq_a16w8' kernel; L1 scratch addresses are
        read from ``mapping``

    Args:
        dims: convolution dimensions (subvolume sizes, kernel/stride, loops).
        ifm_ch_num: forwarded to the kernel parameter generator.
        mapping: only read for the A16W8 QDQ kernel.
        full_iters: forwarded to the kernel parameter generator.

    Raises:
        ValueError: for any other bit-width combination.
    '''
    if dims.ifm_bits == 8 and dims.wgt_bits == 8 and dims.bias_bits == 16:
        Ci = dims.Ci
        Yis, Xis, Cis = dims.Yis, dims.Xis, dims.Cis
        Ky, Kx = dims.Ky, dims.Kx
        Sy, Sx = dims.Sy, dims.Sx
        Yos, Xos, Cos = dims.Yos, dims.Xos, dims.Cos
        Y_loop, X_loop, Co_loop, Ci_loop = dims.Y_loop, dims.X_loop, dims.Co_loop, dims.Ci_loop
        params = generate_conv_noqdq_a8w8_params(Ci, Yos, Xos, Cos, Yis, Xis, Cis, Ky, Kx, Sy, Sx,
                                                 Y_loop, X_loop, Co_loop, Ci_loop, ifm_ch_num, full_iters)
        kernel_name = conv_noqdq_a8w8_qdq_kernel_name()
    elif dims.ifm_bits == 16 and dims.wgt_bits == 8 and dims.bias_bits == 32:
        core_coeff_tmp_buffer = mapping.qdq_L1_ping_addr
        # Only logged: this kernel's parameter generator does not take it.
        core_ifm_tmp_buffer = mapping.vec_L1_ping_addr
        core_spill_buf = mapping.tdm_L1_ping_addr
        log("conv A16W8 QDQ core_spill_buf", core_spill_buf)
        log("conv A16W8 QDQ core_ifm_tmp_buffer", core_ifm_tmp_buffer)
        log("conv A16W8 QDQ core_coeff_tmp_buffer", core_coeff_tmp_buffer)
        Yis, Xis, Cis = dims.Yis, dims.Xis, dims.Cis
        Ky, Kx = dims.Ky, dims.Kx
        Sy, Sx = dims.Sy, dims.Sx
        Yos, Xos, Cos = dims.Yos, dims.Xos, dims.Cos
        Y_loop, X_loop, Co_loop, Ci_loop = dims.Y_loop, dims.X_loop, dims.Co_loop, dims.Ci_loop
        params = generate_conv_qdq_a16w8_params(Yos=Yos, Xos=Xos, Cos=Cos, Yis=Yis, Xis=Xis, Cis=Cis,
                                                Ky=Ky, Kx=Kx, Sy=Sy, Sx=Sx,
                                                Y_loop=Y_loop, X_loop=X_loop, Co_loop=Co_loop, Ci_loop=Ci_loop,
                                                ifm_ch_num=ifm_ch_num, core_spill_buf=core_spill_buf,
                                                core_coeff_tmp_buffer=core_coeff_tmp_buffer, full_iters=full_iters,)
        kernel_name = 'run_conv_qdq_a16w8'
    else:
        # The two literals are concatenated, so the separating space is explicit.
        raise ValueError(f"Unsupported ifm_bits={dims.ifm_bits} and wgt_bits={dims.wgt_bits} and bias_bits={dims.bias_bits} "
                         "for conv kernel")
    return CallKernel(kernel_name, params)


def gemm_call_kernel(
    dims: ConvDims,
    ifm_ch_num: int,
    mapping: ConvMapping,
    optype: LinearOpType,
    gemm_mode: str = 'wgt',
    full_iters: bool = True
) -> CallKernel:
    '''Generate a CallKernel instruction for a gemm kernel.

    Selection order:
      1. ifm_bits == 16, wgt_bits == 8, Kx == Ky == 1 -> gemm a16w8 kernel
      2. ifm_bits == 16, wgt_bits == 4, Kx == Ky == 1 -> gemm a16w4 kernel
      3. gemm_mode == 'act' -> A16A16 v1 / v2 (transpose) kernel, by ``optype``

    Args:
        dims: gemm dimensions (subvolume sizes, kernel size, loop counts).
        ifm_ch_num: forwarded to the a16w8 / a16w4 parameter generators.
        mapping: provides the L1 scratch addresses / sizes.
        optype: only consulted in 'act' mode.
        gemm_mode: 'wgt' or 'act'.
        full_iters: forwarded to the a16w8 / a16w4 parameter generators;
            not used by the 'act' kernels.

    Raises:
        ValueError: when no kernel matches, including an ``optype`` other
            than the two A16A16 variants in 'act' mode (this previously
            produced a CallKernel with an empty kernel name).
    '''
    if dims.ifm_bits == 16 and dims.wgt_bits == 8 and dims.Kx == 1 and dims.Ky == 1:
        core_coeff_tmp_buffer = mapping.qdq_L1_ping_addr
        core_ifm_tmp_buffer = mapping.vec_L1_ping_addr
        core_spill_buf = mapping.tdm_L1_ping_addr
        log("core_spill_buf", core_spill_buf)
        log("core_ifm_tmp_buffer", core_ifm_tmp_buffer)
        log("core_coeff_tmp_buffer", core_coeff_tmp_buffer)
        Cis = dims.Cis
        Ky, Kx = dims.Ky, dims.Kx
        Xos, Cos = dims.Xos, dims.Cos
        X_loop, Co_loop, Ci_loop = dims.X_loop, dims.Co_loop, dims.Ci_loop
        params = generate_gemm_qdq_a16w8_params(Xos, Cos, Cis, Ky, Kx, X_loop, Co_loop, Ci_loop,
                                                ifm_ch_num, core_spill_buf, core_ifm_tmp_buffer, core_coeff_tmp_buffer, full_iters)
        kernel_name = gemm_a16w8_qdq_kernel_name()
    elif dims.ifm_bits == 16 and dims.wgt_bits == 4 and dims.Kx == 1 and dims.Ky == 1:
        # Scratch buffers are carved downwards from the core stack address.
        core_coeff_tmp_buffer = overlay_3x4_core_stack_addr() - mapping.ifm_L1_size
        core_ifm_tmp_buffer = core_coeff_tmp_buffer - 1536
        core_spill_buf = core_ifm_tmp_buffer - mapping.ofm_L1_size
        log("core_spill_buf", core_spill_buf)
        log("core_ifm_tmp_buffer", core_ifm_tmp_buffer)
        log("core_coeff_tmp_buffer", core_coeff_tmp_buffer)
        Cis = dims.Cis
        Ky, Kx = dims.Ky, dims.Kx
        Xos, Cos = dims.Xos, dims.Cos
        X_loop, Co_loop, Ci_loop = dims.X_loop, dims.Co_loop, dims.Ci_loop
        params = generate_gemm_qdq_a16w4_params(Xos, Cos, Cis, Ky, Kx, X_loop, Co_loop, Ci_loop,
                                                ifm_ch_num, core_spill_buf, core_ifm_tmp_buffer, core_coeff_tmp_buffer, full_iters)
        kernel_name = gemm_a16w4_qdq_kernel_name()
    elif gemm_mode == 'act':
        if optype == LinearOpType.gemm_A16A16_v1:
            params = generate_gemm_qdq_a16a16_params(dims, mapping.tdm_L1_ping_addr, mapping.vec_L1_ping_addr, mapping.qdq_L1_ping_addr)
            kernel_name = gemm_A16A16_v1_kernel_name()
        elif optype == LinearOpType.gemm_A16A16_v2:
            params = generate_gemm_qdq_a16a16_transpose_params(dims, mapping.tdm_L1_ping_addr, mapping.wght_transpose_sb_L1_ping_addr, mapping.vec_L1_ping_addr, mapping.qdq_L1_ping_addr)
            kernel_name = gemm_A16A16_v2_kernel_name()
        else:
            # Fail loudly instead of emitting a kernel call with no name.
            raise ValueError(f"Unsupported optype={optype} for gemm kernel in 'act' mode")
    else:
        # The two literals are concatenated, so the separating space is explicit.
        raise ValueError(f"Unsupported ifm_bits={dims.ifm_bits} or wgt_bits={dims.wgt_bits} "
                         "for gemm kernel")
    return CallKernel(kernel_name, params)


def generic_call_kernel(
    dims: ConvDims,
    ifm_ch_num: int,
    mapping: ConvMapping,
    full_iters: bool,
    optype: LinearOpType,
) -> CallKernel:
    '''Dispatch to the conv or gemm CallKernel builder that matches optype.

    Conv op types go to conv_call_kernel; the A16W8/A16W4 gemm op types go to
    gemm_call_kernel in 'wgt' mode; the A16A16 op types go to gemm_call_kernel
    in 'act' mode (full_iters is not forwarded there). Any other optype
    raises ValueError.
    '''
    conv_ops = (LinearOpType.conv_A8W8_noqdq, LinearOpType.conv_A16W8_qdq)
    wgt_gemm_ops = (LinearOpType.gemm_A16W8_qdq, LinearOpType.gemm_A16W4_qdq)
    act_gemm_ops = (LinearOpType.gemm_A16A16_v1, LinearOpType.gemm_A16A16_v2)

    if optype in conv_ops:
        return conv_call_kernel(dims, ifm_ch_num, mapping, full_iters)
    if optype in wgt_gemm_ops:
        return gemm_call_kernel(dims, ifm_ch_num, mapping, optype, gemm_mode='wgt', full_iters=full_iters)
    if optype in act_gemm_ops:
        return gemm_call_kernel(dims, ifm_ch_num, mapping, optype, gemm_mode='act')
    raise ValueError(f"Unsupported optype={optype} for generic_call_kernel")


@no_type_check
def generate_conv_core_instrs_wgt(
    dims: ConvDims,
    mapping: ConvMapping,
    optype: LinearOpType,
    col: int,
    row: int,
    full_iters: bool = True
) -> list:
    '''Generate conv core instructions for core (col, row) in 'wgt' mode.

    The returned list holds, in order:
      1. the IFM, WGT and OFM core buffer configurations;
      2. for each of Y_loop * X_loop * Co_loop outer iterations: Ci_loop - 1
         WGT/IFM acquire+release groups followed by one group that also
         acquires and releases the OFM channel (all created with
         ``disable=True``);
      3. the kernel call - bare when ``full_iters`` is true, otherwise wrapped
         in ``Loop(Y_loop, [Loop(X_loop, [...])])``.

    Both ``full_iters`` variants share items 1 and 2, so they are built once.
    '''
    # The channel assignment depends only on dims; look it up once.
    ifm_channel = ifm_core_channel(dims)
    wgt_channel = wgt_core_channel(dims)
    ofm_channel = 0

    ifm_config = generate_core_buffer_config(
        core_dma(col, row, DmaDir.S2MM, ifm_channel),
        mapping.ifm_L1_ping_addr, mapping.ifm_L1_pong_addr,
        ifm_core_memory(dims),
        ifm_core_s2mm(dims),
        bits_per_block=dims.ifm_bits
    )
    wgt_config = generate_core_buffer_config(
        core_dma(col, row, DmaDir.S2MM, wgt_channel),
        mapping.wgt_L1_ping_addr, mapping.wgt_L1_pong_addr,
        wgt_core_memory(dims, gemm_mode='wgt'),
        wgt_core_s2mm(dims, gemm_mode='wgt'),
    )
    ofm_config = generate_core_buffer_config(
        core_dma(col, row, DmaDir.MM2S, ofm_channel),
        mapping.ofm_L1_ping_addr, mapping.ofm_L1_pong_addr,
        ofm_core_memory(dims),
        ofm_core_mm2s(dims),
        bits_per_block=dims.ofm_bits
    )

    def input_only_group():
        # Fresh instruction objects on every call, as in the original expansion.
        return [
            AcqBuffer(DmaChannel(DmaDir.S2MM, wgt_channel), disable=True),
            AcqBuffer(DmaChannel(DmaDir.S2MM, ifm_channel), disable=True),
            RelBuffer(DmaChannel(DmaDir.S2MM, wgt_channel), disable=True),
            RelBuffer(DmaChannel(DmaDir.S2MM, ifm_channel), disable=True),
        ]

    def input_and_output_group():
        # Last Ci iteration of an outer iteration: also touches the OFM channel.
        return [
            AcqBuffer(DmaChannel(DmaDir.S2MM, wgt_channel), disable=True),
            AcqBuffer(DmaChannel(DmaDir.S2MM, ifm_channel), disable=True),
            AcqBuffer(DmaChannel(DmaDir.MM2S, ofm_channel), disable=True),
            RelBuffer(DmaChannel(DmaDir.S2MM, wgt_channel), disable=True),
            RelBuffer(DmaChannel(DmaDir.S2MM, ifm_channel), disable=True),
            RelBuffer(DmaChannel(DmaDir.MM2S, ofm_channel), disable=True),
        ]

    core_instrs = [ifm_config, wgt_config, ofm_config]
    for _outer in range(dims.Y_loop * dims.X_loop * dims.Co_loop):
        for _inner in range(dims.Ci_loop - 1):
            core_instrs += input_only_group()
        core_instrs += input_and_output_group()

    kernel_call = generic_call_kernel(dims, ifm_channel, mapping, full_iters, optype)
    if full_iters:
        core_instrs.append(kernel_call)
    else:
        core_instrs.append(
            Loop(dims.Y_loop, [
                Loop(dims.X_loop, [
                    kernel_call,
                ]),
            ])
        )
    return core_instrs


@no_type_check
def generate_conv_core_instrs_act(
    dims: ConvDims,
    mapping: ConvMapping,
    optype: LinearOpType,
    col: int,
    row: int,
    full_iters: bool = True
) -> list:
    '''Generate conv core instructions for core (col, row) in 'act' gemm mode.

    The returned list always starts with a ConfigBuffer/AcqBuffer/RelBuffer
    triple on S2MM channel 1 that uses the QDQ L1 address and size from
    mapping, followed by the IFM, WGT and OFM buffer configurations. What
    comes after depends on full_iters:

    * full_iters=True: a log line is emitted and the acquire/release
      sequence is written out fully unrolled in Python
      (Y_loop * X_loop * Co_loop outer repetitions, each made of
      Ci_loop - 1 IFM/WGT acquire-release groups plus one final group that
      also acquires/releases MM2S channel 0), followed by a single
      generic_call_kernel instruction at the end.
    * full_iters=False: the same acquire/release groups are wrapped in
      nested Loop instructions (Y_loop -> X_loop -> Co_loop -> Ci_loop - 1),
      and generic_call_kernel is the first entry of every Co_loop body.

    Args:
        dims: convolution dimensions; provides the loop counts and the
            ifm/wgt/ofm bit widths used here.
        mapping: L1 address mapping (ping/pong addresses for ifm, wgt, ofm
            and the qdq address/size).
        optype: forwarded unchanged to generic_call_kernel.
        col: column index of the core, passed to core_dma.
        row: row index of the core, passed to core_dma.
        full_iters: selects the unrolled (True) or Loop-based (False) form;
            also forwarded to generic_call_kernel.

    Returns:
        The list of core instructions for this core.
    '''
    # Ping/pong buffer configuration for the IFM input (S2MM) channel.
    ifm_config = generate_core_buffer_config(
        core_dma(col, row, DmaDir.S2MM, ifm_core_channel(dims)),
        mapping.ifm_L1_ping_addr, mapping.ifm_L1_pong_addr,
        ifm_core_memory(dims),
        ifm_core_s2mm(dims),
        bits_per_block=dims.ifm_bits
    )
    # Ping/pong buffer configuration for the WGT input (S2MM) channel; the
    # memory/transfer formats are requested with gemm_mode='act'.
    wgt_config = generate_core_buffer_config(
        core_dma(col, row, DmaDir.S2MM, wgt_core_channel(dims)),
        mapping.wgt_L1_ping_addr, mapping.wgt_L1_pong_addr,
        wgt_core_memory(dims, gemm_mode='act'),
        wgt_core_s2mm(dims, gemm_mode='act'),
        bits_per_block=dims.wgt_bits
    )
    # Ping/pong buffer configuration for the OFM output, always MM2S channel 0.
    ofm_config = generate_core_buffer_config(
        core_dma(col, row, DmaDir.MM2S, 0),
        mapping.ofm_L1_ping_addr, mapping.ofm_L1_pong_addr,
        ofm_core_memory(dims),
        ofm_core_mm2s(dims),
        bits_per_block=dims.ofm_bits
    )
    # Common prologue: configure S2MM channel 1 with the QDQ address/size
    # (third argument is None -- presumably "no pong buffer"; confirm against
    # dmacompiler) and do one acquire/release on it. Unlike the Acq/Rel
    # instructions below, these are created without disable=True.
    core_instrs = [ConfigBuffer(DmaChannel(DmaDir.S2MM, 1), mapping.qdq_L1_ping_addr, None, mapping.qdq_L1_size),
                   AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
                   RelBuffer(DmaChannel(DmaDir.S2MM, 1)),]
    if full_iters:
        log(f"====================NEW Core ({col}, {row}) conv full iters====================")
        # Unrolled form: the comprehension below builds a flat Python list,
        # creating fresh instruction objects for every repetition.
        core_instrs += [
            ifm_config,
            wgt_config,
            ofm_config,
        ] + [
            # One outer repetition per (Y, X, Co) iteration.
            outer_instr for _ in range(dims.Y_loop * dims.X_loop * dims.Co_loop) for outer_instr in [
                # First Ci_loop - 1 steps: acquire/release the two inputs only.
                # NOTE(review): IFM is acquired before WGT here, whereas the
                # Loop-based branch below acquires WGT first -- confirm the
                # difference is intentional.
                inner_instr for _ in range(dims.Ci_loop - 1) for inner_instr in [
                            AcqBuffer(DmaChannel(DmaDir.S2MM,
                                                 ifm_core_channel(dims)), disable=True),
                            AcqBuffer(DmaChannel(DmaDir.S2MM,
                                                 wgt_core_channel(dims)), disable=True),
                            RelBuffer(DmaChannel(DmaDir.S2MM,
                                                 ifm_core_channel(dims)), disable=True),
                            RelBuffer(DmaChannel(DmaDir.S2MM,
                                                 wgt_core_channel(dims)), disable=True),
                ]
             ] + [
                # Last Ci step: additionally acquire/release the output
                # channel (MM2S 0).
                AcqBuffer(DmaChannel(DmaDir.S2MM,
                                     ifm_core_channel(dims)), disable=True),
                AcqBuffer(DmaChannel(DmaDir.S2MM,
                                     wgt_core_channel(dims)), disable=True),
                AcqBuffer(DmaChannel(DmaDir.MM2S, 0), disable=True),
                RelBuffer(DmaChannel(DmaDir.S2MM,
                                     ifm_core_channel(dims)), disable=True),
                RelBuffer(DmaChannel(DmaDir.S2MM,
                                     wgt_core_channel(dims)), disable=True),
                RelBuffer(DmaChannel(DmaDir.MM2S, 0), disable=True),
            ]
        ] + [
            # A single kernel call closes the unrolled instruction stream.
            generic_call_kernel(dims, ifm_core_channel(dims), mapping, full_iters, optype)
        ]
    else:
        # Loop-based form: the iteration structure is expressed with nested
        # Loop instructions instead of being unrolled in Python.
        core_instrs += [
            ifm_config,
            wgt_config,
            ofm_config,
        ] + [
            Loop(dims.Y_loop, [
                Loop(dims.X_loop, [
                    Loop(dims.Co_loop, [
                        # Kernel call comes first in every Co_loop body.
                        generic_call_kernel(dims, ifm_core_channel(dims), mapping, full_iters, optype),
                        # First Ci_loop - 1 steps: inputs only (WGT before IFM).
                        Loop(dims.Ci_loop - 1, [
                            AcqBuffer(DmaChannel(DmaDir.S2MM,
                                                 wgt_core_channel(dims)), disable=True),
                            AcqBuffer(DmaChannel(DmaDir.S2MM,
                                                 ifm_core_channel(dims)), disable=True),
                            RelBuffer(DmaChannel(DmaDir.S2MM,
                                                 wgt_core_channel(dims)), disable=True),
                            RelBuffer(DmaChannel(DmaDir.S2MM,
                                                 ifm_core_channel(dims)), disable=True),
                        ]),
                        # Last Ci step: inputs plus the output channel (MM2S 0).
                        AcqBuffer(DmaChannel(DmaDir.S2MM,
                                             wgt_core_channel(dims)), disable=True),
                        AcqBuffer(DmaChannel(DmaDir.S2MM,
                                             ifm_core_channel(dims)), disable=True),
                        AcqBuffer(DmaChannel(DmaDir.MM2S, 0), disable=True),
                        RelBuffer(DmaChannel(DmaDir.S2MM,
                                             wgt_core_channel(dims)), disable=True),
                        RelBuffer(DmaChannel(DmaDir.S2MM,
                                             ifm_core_channel(dims)), disable=True),
                        RelBuffer(DmaChannel(DmaDir.MM2S, 0), disable=True),
                    ]),
                ]),
            ]),
        ]

    return core_instrs


@no_type_check
def generate_conv_core_instrs(
    dims: ConvDims,
    mapping: ConvMapping,
    optype: LinearOpType,
    col: int,
    row: int,
    full_iters: bool = True,
    gemm_mode: str = 'wgt'
) -> list:
    '''Build the core instruction list for core (col, row).

    Picks the 'wgt' or 'act' instruction generator according to gemm_mode
    and forwards all other arguments to it unchanged.

    Raises:
        ValueError: if gemm_mode is neither 'wgt' nor 'act'.
    '''
    # Select the generator first so there is exactly one delegation call.
    if gemm_mode == 'wgt':
        builder = generate_conv_core_instrs_wgt
    elif gemm_mode == 'act':
        builder = generate_conv_core_instrs_act
    else:
        raise ValueError(f"Unsupported gemm_mode={gemm_mode}. Supported modes are 'wgt' and 'act'.")
    return builder(dims, mapping, optype, col, row, full_iters)


def Yi_slice_per_column(dims: ConvDims, col: int, y_iter: int) -> tuple[int, int, int]:
    '''Return (start, stop, size) of the Yi range covered by column col.

    The range is the union bound of the per-row Yi slices of every row in the
    column at Y_loop iteration y_iter, clamped to [0, dims.Yi].
    '''
    row_starts = []
    row_stops = []
    for row in range(dims.aie_rows):
        # Only the start/stop of each row slice matter; its size is ignored.
        start, stop, _ = Yi_slice(dims, col, row, y_iter)
        row_starts.append(start)
        row_stops.append(stop)
    # Smallest start / largest stop over all rows, clipped to the tensor.
    lower = max(min(row_starts), 0)
    upper = min(max(row_stops), dims.Yi)
    return lower, upper, upper - lower


def Yo_slice_per_column(dims: ConvDims, col: int, y_iter: int) -> tuple[int, int, int]:
    '''Return (start, stop, size) of the Yo range covered by column col.

    The range is the union bound of the per-row Yo slices of every row in the
    column at Y_loop iteration y_iter, clamped to [0, dims.Yo].
    '''
    row_starts = []
    row_stops = []
    for row in range(dims.aie_rows):
        # Only the start/stop of each row slice matter; its size is ignored.
        start, stop, _ = Yo_slice(dims, col, row, y_iter)
        row_starts.append(start)
        row_stops.append(stop)
    # Smallest start / largest stop over all rows, clipped to the tensor.
    lower = max(min(row_starts), 0)
    upper = min(max(row_stops), dims.Yo)
    return lower, upper, upper - lower


def Xi_slice_per_column(dims: ConvDims, col: int, x_iter: int) -> tuple[int, int, int]:
    '''Return (start, stop, size) of the Xi range covered by column col.

    The range is the union bound of the per-row Xi slices of every row in the
    column at X_loop iteration x_iter, clamped to [0, dims.Xi].
    '''
    row_starts = []
    row_stops = []
    for row in range(dims.aie_rows):
        # Only the start/stop of each row slice matter; its size is ignored.
        start, stop, _ = Xi_slice(dims, col, row, x_iter)
        row_starts.append(start)
        row_stops.append(stop)
    # Smallest start / largest stop over all rows, clipped to the tensor.
    lower = max(min(row_starts), 0)
    upper = min(max(row_stops), dims.Xi)
    return lower, upper, upper - lower


def Xo_slice_per_column(dims: ConvDims, col: int, x_iter: int) -> tuple[int, int, int]:
    '''Return (start, stop, size) of the Xo range covered by column col.

    The range is the union bound of the per-row Xo slices of every row in the
    column at X_loop iteration x_iter, clamped to [0, dims.Xo].
    '''
    row_starts = []
    row_stops = []
    for row in range(dims.aie_rows):
        # Only the start/stop of each row slice matter; its size is ignored.
        start, stop, _ = Xo_slice(dims, col, row, x_iter)
        row_starts.append(start)
        row_stops.append(stop)
    # Smallest start / largest stop over all rows, clipped to the tensor.
    lower = max(min(row_starts), 0)
    upper = min(max(row_stops), dims.Xo)
    return lower, upper, upper - lower


def Co_slice_per_column(dims: ConvDims, col: int, co_iter: int) -> tuple[int, int, int]:
    '''Return (start, stop, size) of the Co range covered by column col.

    The range is the union bound of the per-row Co slices of every row in the
    column at Co_loop iteration co_iter, clamped to [0, dims.Co].
    '''
    row_starts = []
    row_stops = []
    for row in range(dims.aie_rows):
        # Only the start/stop of each row slice matter; its size is ignored.
        start, stop, _ = Co_slice(dims, col, row, co_iter)
        row_starts.append(start)
        row_stops.append(stop)
    # Smallest start / largest stop over all rows, clipped to the tensor.
    lower = max(min(row_starts), 0)
    upper = min(max(row_stops), dims.Co)
    return lower, upper, upper - lower
