import os
import sys
from typing import List

# Absolute directory containing this file.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Append <CURRDIR>/../../dmacompiler to the module search path; presumably this
# is what lets the `from dmacompiler import ...` statement below resolve.
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))

from kerneltest.DWC.dwc_helpers import DwcDims

from kerneltest.helpers import (
    ceildiv,
    iceil,
    round_up_to_multiple,
)

from kerneltest.overlay_1x1 import (
    overlay_stack_addr,
    aie4_overlay_dma_connections,
    shim_alloc,
)

from kernel.dwc_int16x8.dwc_qdq_a16w8_params import (
    generate_dwc_qdq_a16w8_params
)
    
from dmacompiler import (
    DevGen,
    OverlayShape, DataTransfer,  BackEnd, generate_transfer_params,
    DmaChannel, DmaDir, AieTile, TileType,
    memtile_dma, shim_dma,
    ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop,
    run_layer_compilation, set_dev_gen,
)


# Module-level side effect: runs once at import time and passes DevGen.Aie4 to
# dmacompiler's set_dev_gen before any dataflow below is compiled.
set_dev_gen(DevGen.Aie4)


def memtile_param_memory(dims: DwcDims) -> str:
    """Return the memtile layout string for the kernel-parameter buffer.

    The result has the form ``row:<aie_rows> Bytes:<Param_size>``.
    """
    rows, param_bytes = dims.aie_rows, dims.Param_size
    return 'row:{} Bytes:{}'.format(rows, param_bytes)

def memtile_param_s2mm(dims: DwcDims) -> str:
    """Return the S2MM access string for the memtile parameter buffer.

    The result has the form ``row:0:<aie_rows> Bytes:0:<Param_size>``.
    """
    row_field = f'row:0:{dims.aie_rows}'
    byte_field = f'Bytes:0:{dims.Param_size}'
    return row_field + ' ' + byte_field

def memtile_param_mm2s(dims: DwcDims, row: int) -> str:
    """Return the MM2S access string selecting one row of the parameter buffer.

    The result has the form ``row:<row>:<row+1> Bytes:0:<Param_size>``.
    """
    next_row = row + 1
    return ' '.join((f'row:{row}:{next_row}', f'Bytes:0:{dims.Param_size}'))

def memtile_ifm_memory(dims: DwcDims) -> str:
    """Return the memtile layout string for the IFM buffer.

    The result has the form ``Ci:<Cis> Yi:<Yis> Xi:<Xis> Ci:<Ci_gran>``.
    """
    fields = (
        ('Ci', dims.Cis),
        ('Yi', dims.Yis),
        ('Xi', dims.Xis),
        ('Ci', dims.Ci_gran),
    )
    return ' '.join(f'{name}:{extent}' for name, extent in fields)

def memtile_ifm_s2mm(dims: DwcDims) -> str:
    """Return the S2MM access string for the memtile IFM buffer.

    The result has the form ``Yi:0:<Yis> Xi:0:<Xis> Ci:0:<Cis>``.
    """
    return 'Yi:0:{} Xi:0:{} Ci:0:{}'.format(dims.Yis, dims.Xis, dims.Cis)

def memtile_ifm_mm2s(dims: DwcDims) -> str:
    """Return the MM2S access string for the memtile IFM buffer.

    The result has the form
    `` Ci:0:<Cis>:<Ci_gran> Yi:0:<Yis> Xi:0:<Xis> Ci:0:<Ci_gran>``.
    The string starts with a single blank, which is kept exactly as-is.
    """
    gran = dims.Ci_gran
    outer_ci = f'Ci:0:{dims.Cis}:{gran}'
    spatial = f'Yi:0:{dims.Yis} Xi:0:{dims.Xis}'
    inner_ci = f'Ci:0:{gran}'
    # The leading blank is part of the returned value.
    return ' ' + ' '.join((outer_ci, spatial, inner_ci))

def memtile_wgt_memory(dims: DwcDims) -> str:
    """Return the memtile layout string for the weight buffer.

    The result has the form ``row:<aie_rows> Bytes:<wgt_subv_bytes>``.
    """
    layout = {'row': dims.aie_rows, 'Bytes': dims.wgt_subv_bytes}
    return ' '.join(f'{name}:{extent}' for name, extent in layout.items())

def memtile_wgt_s2mm(dims: DwcDims) -> str:
    """Return the S2MM access string for the memtile weight buffer.

    The result has the form ``row:0:<aie_rows> Bytes:0:<wgt_subv_bytes>``.
    """
    template = 'row:0:{rows} Bytes:0:{nbytes}'
    return template.format(rows=dims.aie_rows, nbytes=dims.wgt_subv_bytes)

def memtile_wgt_mm2s(dims: DwcDims, row: int) -> str:
    """Return the MM2S access string selecting one row of the weight buffer.

    The result has the form ``row:<row>:<row+1> Bytes:0:<wgt_subv_bytes>``.
    """
    row_stop = row + 1
    row_field = f'row:{row}:{row_stop}'
    return f'{row_field} Bytes:0:{dims.wgt_subv_bytes}'

def memtile_ofm_memory(dims: DwcDims) -> str:
    """Return the memtile layout string for the OFM buffer.

    The result has the form ``Yo:<Yos> Xo:<Xos> Co:<aie_rows * Cos>``.
    """
    # NOTE: the channel extent assumes Co is split across the rows of a column.
    channels_per_column = dims.aie_rows * dims.Cos
    return 'Yo:{} Xo:{} Co:{}'.format(dims.Yos, dims.Xos, channels_per_column)

def memtile_ofm_s2mm(dims: DwcDims, row: int) -> str:
    """Return the S2MM access string for one row's slice of the memtile OFM buffer.

    The memtile OFM buffer is declared with ``aie_rows * Cos`` output channels
    (see ``memtile_ofm_memory``), so row ``r`` addresses the channel range
    ``[r * Cos, (r + 1) * Cos)``. For ``row == 0`` this is ``Co:0:<Cos>``.

    Args:
        dims: Tiling description; ``Cos``, ``Ci_gran``, ``Yos`` and ``Xos`` are read.
        row: Index of the core row whose output is being written.

    Returns:
        ``Co:<start>:<stop>:<Ci_gran> Yo:0:<Yos> Xo:0:<Xos> Co:0:<Ci_gran>``.
    """
    co_start = row * dims.Cos
    co_stop = co_start + dims.Cos
    return (
        f'Co:{co_start}:{co_stop}:{dims.Ci_gran} '
        f'Yo:0:{dims.Yos} Xo:0:{dims.Xos} Co:0:{dims.Ci_gran}'
    )

def memtile_ofm_mm2s(dims: DwcDims) -> str:
    """Return the MM2S access string for the whole memtile OFM buffer.

    The result has the form ``Yo:0:<Yos> Xo:0:<Xos> Co:0:<aie_rows * Cos>``.
    """
    # NOTE: the channel extent assumes Co is split across the rows of a column.
    channels_per_column = dims.aie_rows * dims.Cos
    pieces = (
        f'Yo:0:{dims.Yos}',
        f'Xo:0:{dims.Xos}',
        f'Co:0:{channels_per_column}',
    )
    return ' '.join(pieces)

def shimtile_param_memory(dims: DwcDims) -> str:
    """Return the shim-side layout string for the kernel-parameter buffer.

    The result has the form ``col:<aie_cols> row:<aie_rows> Bytes:<Param_size>``.
    """
    return 'col:{} row:{} Bytes:{}'.format(
        dims.aie_cols, dims.aie_rows, dims.Param_size,
    )

def shimtile_param_mm2s(dims: DwcDims, col: int) -> str:
    """Return the shim MM2S access string selecting one column's parameters.

    The result has the form
    ``col:<col>:<col+1> row:0:<aie_rows> Bytes:0:<Param_size>``.
    """
    col_stop = col + 1
    fields = [
        f'col:{col}:{col_stop}',
        f'row:0:{dims.aie_rows}',
        f'Bytes:0:{dims.Param_size}',
    ]
    return ' '.join(fields)

def shimtile_ifm_memory(dims: DwcDims) -> str:
    """Return the shim-side layout string for the full IFM tensor.

    The result has the form ``Yi:<Yi> Xi:<Xi> Ci:<Ci>``.
    """
    height, width, channels = dims.Yi, dims.Xi, dims.Ci
    return f'Yi:{height} Xi:{width} Ci:{channels}'

def shimtile_ifm_mm2s(dims: DwcDims, col: int) -> str:
    """Return the shim MM2S access string for the IFM tensor.

    The result has the form
    ``Ci:0:<Ci>:<Cis> Yi:0:<Yis> Xi:0:<Xis> Ci:0:<Cis>``.
    ``col`` is accepted but not used: the same string is returned for every
    column.
    """
    sub_channels = dims.Cis
    fields = [
        f'Ci:0:{dims.Ci}:{sub_channels}',
        f'Yi:0:{dims.Yis}',
        f'Xi:0:{dims.Xis}',
        f'Ci:0:{sub_channels}',
    ]
    return ' '.join(fields)

def shimtile_wgt_memory(dims: DwcDims) -> str:
    """Return the shim-side layout string for the weight tensor.

    The result has the form ``col:<c> row:<r> Bytes:<wgt_subv_bytes>`` where
    ``c = Co // (Cos * aie_rows)`` and ``r = Ci // Cis``.
    """
    channels_per_col_shard = dims.Cos * dims.aie_rows
    col_shards = dims.Co // channels_per_col_shard
    row_shards = dims.Ci // dims.Cis
    return 'col:{} row:{} Bytes:{}'.format(
        col_shards, row_shards, dims.wgt_subv_bytes,
    )

def shimtile_wgt_mm2s(dims: DwcDims, col: int) -> str:
    """Return the shim MM2S access string for one column's weights.

    The result has the form
    ``col:<s>:<s + aie_rows>:<aie_rows> row:0:<Ci // Cis>:1 Bytes:0:<wgt_subv_bytes>``
    with ``s = col * aie_rows``.
    """
    rows = dims.aie_rows
    row_shards = dims.Ci // dims.Cis
    first = col * rows
    col_field = f'col:{first}:{first + rows}:{rows}'
    row_field = f'row:0:{row_shards}:1'
    byte_field = f'Bytes:0:{dims.wgt_subv_bytes}'
    return ' '.join((col_field, row_field, byte_field))

def shimtile_ofm_memory(dims: DwcDims) -> str:
    """Return the shim-side layout string for the full OFM tensor.

    The result has the form ``Yo:<Yo> Xo:<Xo> Co:<Co>``.
    """
    extents = (('Yo', dims.Yo), ('Xo', dims.Xo), ('Co', dims.Co))
    return ' '.join(f'{axis}:{size}' for axis, size in extents)

def shimtile_ofm_s2mm(dims: DwcDims, col: int) -> str:
    """Return the shim S2MM access string for the OFM tensor.

    Args:
        dims: Tiling description; ``Yos``, ``Xos`` and ``Cos`` are read.
        col: Shim column index. Currently not used.

    Returns:
        ``Yo:0:<Yos> Xo:0:<Xos> Co:0:<Cos>``.
    """
    # NOTE(review): the string does not depend on `col`, so every column gets
    # the same pattern. Earlier unused start/stop/step locals suggested a
    # per-column Co offset may have been intended -- confirm before using this
    # with more than one column.
    return f'Yo:0:{dims.Yos} Xo:0:{dims.Xos} Co:0:{dims.Cos}'

def compile_dwc_dataflow(
    dims: DwcDims,
    back_end: BackEnd,
    broadcast_tensor: int = 0,
):
    """Build the ``run_dwc_qdq_a16w8`` dataflow and pass it to ``run_layer_compilation``.

    The function assigns DMA channels, lays out the core and memtile buffers,
    builds the core instruction list plus the memtile and shim transfers for
    every column, and finally calls ``run_layer_compilation``. It also prints
    the core buffer sizes and addresses.

    Args:
        dims: Shape/tiling description. Overlay size, subvolume byte sizes,
            loop counts, bit widths and kernel parameters are read from it.
        back_end: Forwarded unchanged to ``run_layer_compilation``.
        broadcast_tensor: Channel-assignment mode. ``0`` uses a single memtile
            MM2S channel for the IFM and one MM2S channel per row for the
            weights (IFM broadcast, weights unicast); ``1`` swaps the two roles.

    Raises:
        ValueError: If ``broadcast_tensor`` is neither ``0`` nor ``1``.
        AssertionError: If the core IFM/weight/OFM buffers do not end below the
            address returned by ``overlay_stack_addr()``.
    """
    kernel_names = ['run_dwc_qdq_a16w8']
    kernel_includes = ['super.hh', 'dwc_int16x8/dwc_qdq_a16w8_wrapper.cc']
    rows = range(dims.aie_rows)
    cols = range(dims.aie_cols)

    # ---- DMA channel assignment -------------------------------------------
    shim_ifm_mm2s_channel = 1
    shim_param_mm2s_channel = 0
    shim_wgt_mm2s_channel = 0
    shim_ofm_s2mm_channel = 0

    memtile_param_s2mm_channel = 0
    memtile_param_mm2s_channel = list(rows)
    if broadcast_tensor == 0:
        # NOTE: IFM broadcast, wgt unicast
        memtile_ifm_s2mm_channel = 1
        memtile_ifm_mm2s_channel = 4
        memtile_wgt_s2mm_channel = 0
        memtile_wgt_mm2s_channel = list(rows)
        core_ifm_s2mm_channel = 1
        core_wgt_s2mm_channel = 0
    elif broadcast_tensor == 1:
        # NOTE(review): the memtile transfers built below still index the
        # weight MM2S channel per row and pass the IFM MM2S channel as a single
        # value, so this mode needs matching changes there before it can work.
        memtile_ifm_s2mm_channel = 0
        memtile_ifm_mm2s_channel = list(rows)
        memtile_wgt_s2mm_channel = 1
        memtile_wgt_mm2s_channel = 1
        core_ifm_s2mm_channel = 0
        core_wgt_s2mm_channel = 1
    else:
        raise ValueError(
            f'broadcast_tensor must be 0 or 1, got {broadcast_tensor!r}'
        )
    memtile_ofm_s2mm_channel = [row + 2 for row in rows]
    memtile_ofm_mm2s_channel = 5
    core_ofm_mm2s_channel = 0

    # ---- Core buffer layout -----------------------------------------------
    CoreParamSize = dims.Param_size
    ConvShimAlloc = shim_alloc()

    CoreAlignSize = 128
    print("wgt_subv_bytes", dims.wgt_subv_bytes, "act_subv_bytes", dims.act_subv_bytes, "out_subv_bytes", dims.out_subv_bytes)
    CoreIfmSize = round_up_to_multiple(dims.act_subv_bytes, CoreAlignSize)
    CoreWgtSize = round_up_to_multiple(dims.wgt_subv_bytes, CoreAlignSize)
    CoreOfmSize = round_up_to_multiple(dims.out_subv_bytes, CoreAlignSize)
    CoreStackAddr = overlay_stack_addr()
    # Single (ping-only) buffers laid out back to back: IFM, weights, OFM.
    CoreIfmPingAddr = 0
    CoreWgtPingAddr = CoreIfmPingAddr + CoreIfmSize
    CoreOfmPingAddr = CoreWgtPingAddr + CoreWgtSize
    print("CoreIfmPingAddr", CoreIfmPingAddr, "CoreIfmSize", CoreIfmSize)
    print("CoreWgtPingAddr", CoreWgtPingAddr, "CoreWgtSize", CoreWgtSize)
    print("CoreOfmPingAddr", CoreOfmPingAddr, "CoreOfmSize", CoreOfmSize)
    assert (CoreOfmPingAddr + CoreOfmSize) < CoreStackAddr, (
        'core IFM/WGT/OFM buffers must end below the overlay stack address'
    )

    # ---- Core instructions ------------------------------------------------
    # Each call returns a fresh DmaChannel so no instruction shares an object.
    def ifm_in():
        return DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel)

    def wgt_in():
        return DmaChannel(DmaDir.S2MM, core_wgt_s2mm_channel)

    def ofm_out():
        return DmaChannel(DmaDir.MM2S, core_ofm_mm2s_channel)

    # Arguments common to both kernel-parameter calls below.
    kernel_args = (
        dims.Yos, dims.Xos, dims.Cos,
        dims.Yis, dims.Xis,
        dims.Ky, dims.Kx,
        dims.Sy, dims.Sx,
        dims.Y_loop, dims.X_loop, dims.Co_loop,
        core_ifm_s2mm_channel,
        dims.sign_act, dims.sign_wgt, dims.sign_out,
    )

    if dims.Ci_loop == 1:
        # NOTE(review): this path does not pass dims.shift_out to the parameter
        # generator while the Ci_loop != 1 path does. Preserved as found --
        # confirm which call is intended.
        core_instructions = [
            ConfigBuffer(ifm_in(), CoreIfmPingAddr, None, dims.act_subv_bytes),
            ConfigBuffer(wgt_in(), CoreWgtPingAddr, None, dims.wgt_subv_bytes),
            ConfigBuffer(ofm_out(), CoreOfmPingAddr, None, dims.out_subv_bytes),
            AcqBuffer(ifm_in()),
            AcqBuffer(ofm_out()),
            AcqBuffer(wgt_in()),
            CallKernel(
                'run_dwc_qdq_a16w8',
                generate_dwc_qdq_a16w8_params(*kernel_args),
            ),
            RelBuffer(ifm_in()),
            RelBuffer(wgt_in()),
            RelBuffer(ofm_out()),
        ]
    else:
        core_instructions = [
            ConfigBuffer(ifm_in(), CoreIfmPingAddr, None, dims.act_subv_bytes),
            ConfigBuffer(wgt_in(), CoreWgtPingAddr, None, dims.wgt_subv_bytes),
            ConfigBuffer(ofm_out(), CoreOfmPingAddr, None, dims.out_subv_bytes),
            # The first Ci_loop - 1 iterations only acquire and release the
            # IFM and weight buffers; the kernel runs once, after the loop.
            Loop(
                (dims.Ci_loop - 1),
                [
                    AcqBuffer(ifm_in()),
                    AcqBuffer(wgt_in()),
                    RelBuffer(ifm_in()),
                    RelBuffer(wgt_in()),
                ],
            ),
            AcqBuffer(ofm_out()),
            AcqBuffer(ifm_in()),
            AcqBuffer(wgt_in()),
            CallKernel(
                'run_dwc_qdq_a16w8',
                generate_dwc_qdq_a16w8_params(*kernel_args, dims.shift_out),
            ),
            RelBuffer(ifm_in()),
            RelBuffer(wgt_in()),
            RelBuffer(ofm_out()),
        ]

    # ---- Memtile buffer layout --------------------------------------------
    MemtileParamSize = dims.aie_rows * CoreParamSize
    MemtileIfmSize = dims.act_subv_bytes
    MemtileWgtSize = dims.aie_rows * dims.wgt_subv_bytes
    # NOTE(review): the OFM buffer is sized from act_subv_bytes, whereas the
    # core OFM buffer uses out_subv_bytes -- confirm this is intentional.
    MemtileOfmSize = dims.aie_rows * dims.act_subv_bytes
    # Params first, then ping/pong pairs for IFM, weights and OFM.
    MemtileParamAddr = 0
    MemtileIfmPingAddr = MemtileParamAddr + MemtileParamSize
    MemtileIfmPongAddr = MemtileIfmPingAddr + MemtileIfmSize
    MemtileWgtPingAddr = MemtileIfmPongAddr + MemtileIfmSize
    MemtileWgtPongAddr = MemtileWgtPingAddr + MemtileWgtSize
    MemtileOfmPingAddr = MemtileWgtPongAddr + MemtileWgtSize
    MemtileOfmPongAddr = MemtileOfmPingAddr + MemtileOfmSize

    ShimIfmSize = dims.Yi * dims.Xi * dims.Ci * dims.act_bits // 8
    ShimWgtSize = dims.Co * dims.Ky * dims.Kx * dims.Ci * dims.wgt_bits // 8
    ShimOfmSize = dims.Yo * dims.Xo * dims.Co * dims.out_bits // 8

    # Repeat counts: inputs are moved once per (Ci, Y, X, Co) iteration,
    # outputs once per (Y, X, Co) iteration.
    input_repeat = dims.Ci_loop * dims.Y_loop * dims.X_loop * dims.Co_loop
    output_repeat = dims.Y_loop * dims.X_loop * dims.Co_loop

    # ---- Memtile transfers (one of each kind per column) ------------------
    memtile_param_transfers = [
        DataTransfer(
            [1],
            AieTile(TileType.Memtile, col), [MemtileParamAddr], MemtileParamSize,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, memtile_param_s2mm_channel),
                memtile_param_memory(dims),
                memtile_param_s2mm(dims),
                dims.param_bits,
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, memtile_param_mm2s_channel[row]),
                memtile_param_memory(dims),
                memtile_param_mm2s(dims, row),
                dims.param_bits,
            ) for row in rows],
        ) for col in cols
    ]
    memtile_wgt_transfers = [
        DataTransfer(
            [input_repeat],
            AieTile(TileType.Memtile, col), [MemtileWgtPingAddr, MemtileWgtPongAddr], MemtileWgtSize,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, memtile_wgt_s2mm_channel),
                memtile_wgt_memory(dims),
                memtile_wgt_s2mm(dims),
                dims.wgt_bits,
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, memtile_wgt_mm2s_channel[row]),
                memtile_wgt_memory(dims),
                memtile_wgt_mm2s(dims, row),
                dims.wgt_bits,
            ) for row in rows],
        ) for col in cols
    ]
    memtile_ifm_transfers = [
        DataTransfer(
            [input_repeat],
            AieTile(TileType.Memtile, col), [MemtileIfmPingAddr, MemtileIfmPongAddr], MemtileIfmSize,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, memtile_ifm_s2mm_channel),
                memtile_ifm_memory(dims),
                memtile_ifm_s2mm(dims),
                dims.act_bits,
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, memtile_ifm_mm2s_channel),
                memtile_ifm_memory(dims),
                memtile_ifm_mm2s(dims),
                dims.act_bits,
                enable_padding=True,
            )],
        ) for col in cols
    ]
    memtile_ofm_transfers = [
        DataTransfer(
            [output_repeat],
            AieTile(TileType.Memtile, col), [MemtileOfmPingAddr, MemtileOfmPongAddr], MemtileOfmSize,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, memtile_ofm_s2mm_channel[row]),
                memtile_ofm_memory(dims),
                memtile_ofm_s2mm(dims, row),
                dims.out_bits,
            ) for row in rows],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, memtile_ofm_mm2s_channel),
                memtile_ofm_memory(dims),
                memtile_ofm_mm2s(dims),
                dims.out_bits,
            )],
        ) for col in cols
    ]
    memtile_transfers = (
        memtile_param_transfers
        + memtile_wgt_transfers
        + memtile_ifm_transfers
        + memtile_ofm_transfers
    )

    # ---- Shim transfers (one of each kind per column) ---------------------
    shim_param_transfers = [
        DataTransfer(
            [1],
            AieTile(TileType.Shim, col), [ConvShimAlloc.prm_buffer_id], MemtileParamSize,
            [],
            [generate_transfer_params(
                shim_dma(col, DmaDir.MM2S, shim_param_mm2s_channel),
                shimtile_param_memory(dims),
                shimtile_param_mm2s(dims, col),
                dims.param_bits,
            )],
        ) for col in cols
    ]
    shim_wgt_transfers = [
        DataTransfer(
            [1],
            AieTile(TileType.Shim, col), [ConvShimAlloc.wgt_buffer_id], ShimWgtSize,
            [],
            [generate_transfer_params(
                shim_dma(col, DmaDir.MM2S, shim_wgt_mm2s_channel),
                shimtile_wgt_memory(dims),
                shimtile_wgt_mm2s(dims, col),
                dims.wgt_bits,
            )],
        ) for col in cols
    ]
    shim_ifm_transfers = [
        DataTransfer(
            [1],
            AieTile(TileType.Shim, col), [ConvShimAlloc.ifm_buffer_id], ShimIfmSize,
            [],
            [generate_transfer_params(
                shim_dma(col, DmaDir.MM2S, shim_ifm_mm2s_channel),
                shimtile_ifm_memory(dims),
                shimtile_ifm_mm2s(dims, col),
                dims.act_bits,
            )],
        ) for col in cols
    ]
    shim_ofm_transfers = [
        DataTransfer(
            [1],
            AieTile(TileType.Shim, col), [ConvShimAlloc.ofm_buffer_id], ShimOfmSize,
            [generate_transfer_params(
                shim_dma(col, DmaDir.S2MM, shim_ofm_s2mm_channel),
                shimtile_ofm_memory(dims),
                shimtile_ofm_s2mm(dims, col),
                dims.out_bits,
            )],
            [],
        ) for col in cols
    ]
    shim_transfers = (
        shim_param_transfers
        + shim_wgt_transfers
        + shim_ifm_transfers
        + shim_ofm_transfers
    )

    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        kernel_names,
        kernel_includes,
        core_instructions,
        memtile_transfers,
        shim_transfers,
        aie4_overlay_dma_connections(dims.aie_cols, dims.aie_rows),
        back_end=back_end,
        core_stack_addr=CoreStackAddr,
        param_channel_id=0,
    )
