'''
Common definitions for the depthwise convolution (DWC) dataflow.
This file contains shared definitions and utility functions for DWC operations.
'''
from dataclasses import dataclass

from kernel.dwc_int16x8.dwc_qdq_a16w8_params import generate_dwc_qdq_a16w8_params

from dmacompiler import (
    generate_core_buffer_config,
    core_dma,
    DmaDir,
    DmaChannel,
    AcqBuffer,
    RelBuffer,
    CallKernel,
    set_dev_gen,
    DevGen,
    compute_buffer_size,
    Loop,
)

from scheduler.common import (
    LinearOpType,
)

from scheduler.conv.conv_config_builders import (
    ConvDims,
    ConvMapping,
)

from scheduler.conv.conv_common import (
    convL2Memory,
    ifm_core_channel,
    wgt_core_channel,
)

# Module-level side effect: configures dmacompiler for the AIE4 device
# generation as soon as this module is imported.
# NOTE(review): presumably this sets process-wide state inside dmacompiler --
# confirm before importing this module alongside code targeting another gen.
set_dev_gen(DevGen.Aie4)


def dwc_get_aligned_Xis(Xos: int, Sx: int, Kx: int, Kx_gran: int) -> int:
    '''Return the input width Xis required to produce Xos output columns.

    Uses the sliding-window extent ``(Xos - 1) * Sx + footprint``. The kernel
    footprint is ``Kx_gran`` when the stride is exactly 2 and ``Kx``
    otherwise. NOTE(review): ``Kx_gran`` is presumably Kx rounded up to the
    kernel's processing granularity -- confirm with the kernel owners.
    '''
    footprint = Kx_gran if Sx == 2 else Kx
    return (Xos - 1) * Sx + footprint


class DwcL2MemoryAllocator:
    '''Allocate the L2 (memtile) buffers used by a DWC layer.

    Buffers are packed back to back starting at address 0, in the order
    param, IFM ping[, IFM pong], WGT ping[, WGT pong], OFM ping[, OFM pong].
    A pong buffer that is not double-buffered is given ``UNUSED_PONG_ADDR``,
    an address beyond ``MEMTILE_SIZE``, so it can never alias a real buffer.

    After construction the layout is available as ``*_addr`` / ``*_size``
    attributes and bundled as ``self.memory`` (a ``convL2Memory``).

    This is intentionally a plain class, not a dataclass: it declares no
    fields, so a generated ``__eq__`` would treat every instance as equal.

    Args:
        dims: layer dimensions; only ``ifm_bits``, ``wgt_bits`` and
            ``ofm_bits`` are read here.
        param_l2_memory, ifm_l2_memory, wgt_l2_memory, ofm_l2_memory:
            dmacompiler memory-format strings, sized via
            ``compute_buffer_size``.
        ifm_double_buffer, wgt_double_buffer, ofm_double_buffer: when True,
            reserve a pong buffer of the same size right after the ping.

    Raises:
        ValueError: if the packed layout exceeds ``MEMTILE_SIZE``.
    '''

    # Capacity the packed layout is checked against: 3 MiB.
    # NOTE(review): confirm 3 MiB against the AIE4 memtile specification.
    MEMTILE_SIZE = 3 * 2**20
    # Placeholder address for unused pong buffers; lies outside the memtile.
    UNUSED_PONG_ADDR = 9 * 2**20

    def __init__(
        self,
        dims: ConvDims,
        param_l2_memory: str,
        ifm_l2_memory: str,
        wgt_l2_memory: str,
        ofm_l2_memory: str,
        ifm_double_buffer: bool = False,
        wgt_double_buffer: bool = False,
        ofm_double_buffer: bool = False,
    ):
        # Buffer sizes in the units compute_buffer_size returns; the
        # addresses below are offsets in those same units.
        self.param_size = compute_buffer_size(param_l2_memory)
        self.ifm_size = compute_buffer_size(ifm_l2_memory, dims.ifm_bits)
        self.wgt_size = compute_buffer_size(wgt_l2_memory, dims.wgt_bits)  # WGT subvol is Bytes
        self.ofm_size = compute_buffer_size(ofm_l2_memory, dims.ofm_bits)

        # Bump-allocate: each buffer starts where the previous one ended.
        self.param_addr = 0
        next_free = self.param_addr + self.param_size
        self.ifm_ping_addr, self.ifm_pong_addr, next_free = self._place_ping_pong(
            next_free, self.ifm_size, ifm_double_buffer)
        self.wgt_ping_addr, self.wgt_pong_addr, next_free = self._place_ping_pong(
            next_free, self.wgt_size, wgt_double_buffer)
        self.ofm_ping_addr, self.ofm_pong_addr, next_free = self._place_ping_pong(
            next_free, self.ofm_size, ofm_double_buffer)

        # The end of the last placed buffer is the total footprint.
        total_size = next_free
        if total_size > self.MEMTILE_SIZE:
            raise ValueError(
                f"Allocated L2 buffers exceed memtile size: {total_size} > {self.MEMTILE_SIZE}. "
                "Consider reducing buffer sizes or using fewer double buffers."
            )

        self.memory = convL2Memory(
            param_addr=self.param_addr,
            param_size=self.param_size,
            ifm_ping_addr=self.ifm_ping_addr,
            ifm_pong_addr=self.ifm_pong_addr,
            ifm_size=self.ifm_size,
            wgt_ping_addr=self.wgt_ping_addr,
            wgt_pong_addr=self.wgt_pong_addr,
            wgt_size=self.wgt_size,
            ofm_ping_addr=self.ofm_ping_addr,
            ofm_pong_addr=self.ofm_pong_addr,
            ofm_size=self.ofm_size,
        )

    @classmethod
    def _place_ping_pong(cls, start: int, size: int, double_buffer: bool) -> tuple:
        '''Place one ping[/pong] pair at ``start``; return (ping, pong, next_free).'''
        if double_buffer:
            return start, start + size, start + 2 * size
        return start, cls.UNUSED_PONG_ADDR, start + size


def dwc_wgt_shim_memory(dims: ConvDims) -> str:
    '''Return the DDR (shim) data-order string for DWC weights.

    The weights are pre-formatted in DDR as one block per subvolume. This lets
    the traversal over every Cout block (``Co_loop * Co_split`` of them) fold
    into a single linear access, which keeps the shim access pattern within
    five dimensions.
    '''
    num_co_blocks = dims.Co_loop * dims.Co_split
    subvol_extent = dims.wgt_L1_size
    return f'Cob:{num_co_blocks} Subv:{subvol_extent}'


def dwc_ifm_core_memory(dims: ConvDims) -> str:
    '''Return the L1 memory format string for the IFM buffer.

    Dimensions, in the order they appear in the string: Ci (extent Cos),
    Yi (Yis), Xi (Xis), and Ci again with extent Co_gran.
    '''
    layout = (
        ('Ci', dims.Cos),
        ('Yi', dims.Yis),
        ('Xi', dims.Xis),
        ('Ci', dims.Co_gran),
    )
    return ' '.join(f'{axis}:{extent}' for axis, extent in layout)


def dwc_ifm_core_s2mm(dims: ConvDims) -> str:
    '''Return the IFM core S2MM access pattern.

    The outer Ci term is written ``Ci:0:Cos:Co_gran``, the remaining terms
    are ``Yi:0:Yis``, ``Xi:0:Xis`` and ``Ci:0:Co_gran``. The pattern does not
    depend on the core's (col, row); every core uses the same one.

    NOTE: the core access pattern is required not to change over time, so
    using the iteration-0 slice is safe for size calculations.
    '''
    step = dims.Co_gran
    outer_ci = f'Ci:0:{dims.Cos}:{step}'
    spatial = f'Yi:0:{dims.Yis} Xi:0:{dims.Xis}'
    inner_ci = f'Ci:0:{step}'
    return ' '.join([outer_ci, spatial, inner_ci])


def dwc_wgt_core_memory(dims: ConvDims) -> str:
    '''Return the L1 memory format for the WGT buffer: a single flat ``Subv``
    axis whose extent is ``dims.wgt_L1_size``.'''
    subvol_extent = dims.wgt_L1_size
    return f'Subv:{subvol_extent}'


def dwc_wgt_core_s2mm(dims: ConvDims) -> str:
    '''Return the WGT core S2MM access pattern: the full ``Subv`` axis from 0
    up to ``dims.wgt_L1_size``.'''
    subvol_extent = dims.wgt_L1_size
    return f'Subv:0:{subvol_extent}'


def dwc_ofm_core_memory(dims: ConvDims) -> str:
    '''Return the L1 memory format string for the OFM buffer.

    Dimensions, in the order they appear in the string: Co (extent Cos),
    Yo (Yos), Xo (Xos), and Co again with extent Co_gran.
    '''
    layout = (
        ('Co', dims.Cos),
        ('Yo', dims.Yos),
        ('Xo', dims.Xos),
        ('Co', dims.Co_gran),
    )
    return ' '.join(f'{axis}:{extent}' for axis, extent in layout)


def dwc_ofm_core_mm2s(dims: ConvDims) -> str:
    '''Return the OFM core MM2S access pattern.

    The outer Co term is written ``Co:0:Cos:Co_gran``, the remaining terms
    are ``Yo:0:Yos``, ``Xo:0:Xos`` and ``Co:0:Co_gran``. The pattern does not
    depend on the core's (col, row); every core uses the same one.
    '''
    step = dims.Co_gran
    outer_co = f'Co:0:{dims.Cos}:{step}'
    spatial = f'Yo:0:{dims.Yos} Xo:0:{dims.Xos}'
    inner_co = f'Co:0:{step}'
    return ' '.join([outer_co, spatial, inner_co])


def dwc_a16w8_qdq_kernel_name() -> str:
    '''Return the kernel name passed to CallKernel for the DWC a16w8_qdq op.'''
    return 'run_dwc_qdq_a16w8'


def dwc_call_kernel(
    dims: ConvDims,
    ifm_ch_num: int,
    mapping: ConvMapping,
    op_type: LinearOpType,
) -> CallKernel:
    '''Build the CallKernel instruction for one DWC kernel invocation.

    Only ``LinearOpType.dwc_A16W8_qdq`` is implemented. Its parameter blob is
    produced by ``generate_dwc_qdq_a16w8_params`` from the subvolume shape,
    kernel/stride sizes, loop counts, ``ifm_ch_num`` and the sign flags.
    ``mapping`` is accepted to match ``generic_call_kernel``'s signature but
    is not read here.

    Raises:
        ValueError: for any other ``op_type``.
    '''
    if op_type != LinearOpType.dwc_A16W8_qdq:
        raise ValueError(f"Unsupported DWC OP type: {op_type}, ifm_bits={dims.ifm_bits}, wgt_bits={dims.wgt_bits}, ofm_bits={dims.ofm_bits}")
    kernel_params = generate_dwc_qdq_a16w8_params(
        dims.Yos, dims.Xos, dims.Cos, dims.Yis, dims.Xis,
        dims.Ky, dims.Kx, dims.Sy, dims.Sx, dims.Y_loop, dims.X_loop, dims.Co_loop,
        ifm_ch_num, dims.sign_A, dims.sign_W, dims.sign_O,
    )
    return CallKernel(dwc_a16w8_qdq_kernel_name(), kernel_params)


def generic_call_kernel(
    dims: ConvDims,
    ifm_ch_num: int,
    mapping: ConvMapping,
    op_type: LinearOpType,
) -> CallKernel:
    '''Dispatch to the CallKernel builder for a DWC ``op_type``.

    DWC op types are forwarded unchanged to ``dwc_call_kernel``.

    Raises:
        ValueError: if ``op_type`` is not a DWC op type. ``dwc_call_kernel``
            may itself raise ValueError for DWC types it does not implement.
    '''
    # NOTE(review): dwc_A8W8_qdq appeared twice in this list; the duplicate
    # may have been meant as a different DWC type -- confirm and add it here.
    dwc_op_types = (LinearOpType.dwc_A16W8_qdq, LinearOpType.dwc_A8W8_qdq)
    if op_type in dwc_op_types:
        return dwc_call_kernel(dims, ifm_ch_num, mapping, op_type)
    raise ValueError(f"Unsupported OP type {op_type} in DWC generic call kernel")


def generate_dwc_core_instrs(
    dims: ConvDims,
    mapping: ConvMapping,
    op_type: LinearOpType,
    col: int,
    row: int,
) -> list:
    '''Build the instruction list for the DWC core at (col, row).

    The list holds three buffer configurations (IFM via S2MM, WGT via S2MM,
    OFM via MM2S channel 0) followed by a Y_loop x X_loop x Co_loop loop nest.
    Each innermost iteration acquires the WGT, IFM and OFM buffers, issues one
    kernel call, then releases the same buffers in the same order.
    '''
    ifm_ch = ifm_core_channel(dims)
    wgt_ch = wgt_core_channel(dims)
    ofm_ch = 0  # OFM always leaves the core on MM2S channel 0

    ifm_config = generate_core_buffer_config(
        core_dma(col, row, DmaDir.S2MM, ifm_ch),
        mapping.ifm_L1_ping_addr, mapping.ifm_L1_pong_addr,
        dwc_ifm_core_memory(dims),
        dwc_ifm_core_s2mm(dims),
        bits_per_block=dims.ifm_bits,
    )
    wgt_config = generate_core_buffer_config(
        core_dma(col, row, DmaDir.S2MM, wgt_ch),
        mapping.wgt_L1_ping_addr, mapping.wgt_L1_pong_addr,
        dwc_wgt_core_memory(dims),
        dwc_wgt_core_s2mm(dims),
    )
    ofm_config = generate_core_buffer_config(
        core_dma(col, row, DmaDir.MM2S, ofm_ch),
        mapping.ofm_L1_ping_addr, mapping.ofm_L1_pong_addr,
        dwc_ofm_core_memory(dims),
        dwc_ofm_core_mm2s(dims),
        bits_per_block=dims.ofm_bits,
    )

    # (direction, channel) pairs in acquire/release order: WGT, IFM, OFM.
    channel_order = (
        (DmaDir.S2MM, wgt_ch),
        (DmaDir.S2MM, ifm_ch),
        (DmaDir.MM2S, ofm_ch),
    )
    acquires = [AcqBuffer(DmaChannel(direction, ch)) for direction, ch in channel_order]
    kernel_call = generic_call_kernel(dims, ifm_ch, mapping, op_type)
    releases = [RelBuffer(DmaChannel(direction, ch)) for direction, ch in channel_order]
    iteration_body = acquires + [kernel_call] + releases

    loop_nest = Loop(dims.Y_loop, [
        Loop(dims.X_loop, [
            Loop(dims.Co_loop, iteration_body),
        ]),
    ])
    return [ifm_config, wgt_config, ofm_config, loop_nest]
