import os
import sys
from typing import List

# Absolute path of the directory that contains this file.
CURRDIR = os.path.dirname(os.path.abspath(__file__))

from kerneltest.helpers import round_up_to_multiple
from resize_nni_helpers import (
    CoordinateTransfromationMode,
    ResizeNNIDims,
    resize_nni_wgt_subvol_dims,
    resize_nni_input_subvol_dims,
    generate_resize_nni_noqdq_a8_params,
)

from dmacompiler import (
    DevGen,
    OverlayShape, DataTransfer,  BackEnd, generate_transfer_params,
    DmaChannel, DmaDir, AieTile, TileType,
    memtile_dma, shim_dma,
    ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop,
    run_layer_compilation, set_dev_gen,
)

from kerneltest.overlay_1x1 import (
    overlay_stack_addr,
    aie4_overlay_dma_connections,
    shim_alloc,
)

from scheduler.common import (
    prm_memtile_memory,
    prm_shim_memory,
    prm_memtile_mm2s,
    prm_memtile_s2mm,
    prm_shim_mm2s,
)

# Select the Aie4 device generation in dmacompiler. This is a module-level
# call, so it runs as a side effect of importing this file.
set_dev_gen(DevGen.Aie4)


def memtile_param_memory(dims: ResizeNNIDims) -> str:
    """Return the memtile layout string for the parameter buffer.

    The layout is ``dims.aie_rows`` rows of 1024 bytes each.
    """
    fields = (f'row:{dims.aie_rows}', 'Bytes:1024')
    return ' '.join(fields)


def memtile_param_s2mm(dims: ResizeNNIDims) -> str:
    """Return the memtile S2MM range string for the parameter buffer.

    Covers rows ``0..dims.aie_rows`` and bytes ``0..1024``.
    """
    row_range = f'row:0:{dims.aie_rows}'
    byte_range = 'Bytes:0:1024'
    return f'{row_range} {byte_range}'


def memtile_param_mm2s(dims: ResizeNNIDims, row: int) -> str:
    """Return the memtile MM2S range string selecting one row of parameters.

    Selects rows ``row..row+1`` and bytes ``0..1024``. ``dims`` is accepted
    for signature parity with the sibling helpers but is not read.
    """
    next_row = row + 1
    return 'row:{}:{} Bytes:0:1024'.format(row, next_row)


def shimtile_param_memory(dims: ResizeNNIDims) -> str:
    """Return the shim layout string for the parameter buffer.

    The layout is ``dims.aie_cols`` x ``dims.aie_rows`` entries of 1024 bytes.
    """
    return 'col:{} row:{} Bytes:1024'.format(dims.aie_cols, dims.aie_rows)


def shimtile_param_mm2s(dims: ResizeNNIDims, col: int) -> str:
    """Return the shim MM2S range string for one column's parameters.

    Selects columns ``col..col+1``, rows ``0..dims.aie_rows`` and bytes
    ``0..1024``.
    """
    col_range = f'col:{col}:{col + 1}'
    row_range = f'row:0:{dims.aie_rows}'
    return ' '.join((col_range, row_range, 'Bytes:0:1024'))


def memtile_ifm_memory(dims: ResizeNNIDims) -> str:
    """Return the memtile layout string for the IFM buffer.

    The axes are emitted in the order ``Ci``, ``Yi``, ``Xi``, ``Ci``; the
    trailing ``Ci`` entry carries ``dims.Ci_gran``.
    """
    extents = (
        ('Ci', dims.Ci),
        ('Yi', dims.Yi),
        ('Xi', dims.Xi),
        ('Ci', dims.Ci_gran),
    )
    return ' '.join(f'{axis}:{size}' for axis, size in extents)


def memtile_ifm_s2mm(dims: ResizeNNIDims) -> str:
    """Return the memtile S2MM range string for the IFM buffer.

    Covers the full ``Yi``, ``Xi`` and ``Ci`` extents of ``dims``, each
    starting at 0.
    """
    stops = (('Yi', dims.Yi), ('Xi', dims.Xi), ('Ci', dims.Ci))
    return ' '.join(f'{axis}:0:{stop}' for axis, stop in stops)


def memtile_ifm_mm2s(dims: ResizeNNIDims) -> str:
    """Return the memtile MM2S range string for one IFM sub-volume.

    Uses the sub-volume extents ``Cis``, ``Yis`` and ``Xis``. The leading
    ``Ci`` entry carries ``Ci_gran`` as a fourth field (presumably a step --
    confirm against the dmacompiler format) and the trailing ``Ci`` entry
    spans ``0..Ci_gran``.
    """
    parts = [
        f'Ci:0:{dims.Cis}:{dims.Ci_gran}',
        f'Yi:0:{dims.Yis}',
        f'Xi:0:{dims.Xis}',
        f'Ci:0:{dims.Ci_gran}',
    ]
    return ' '.join(parts)


def memtile_wgt_memory(dims: ResizeNNIDims) -> str:
    """Return the memtile layout string for the weight buffer.

    The layout is ``dims.aie_rows`` rows of ``dims.wgt_subv_bytes`` bytes.
    """
    return 'row:{rows} Bytes:{size}'.format(
        rows=dims.aie_rows,
        size=dims.wgt_subv_bytes,
    )


def memtile_wgt_s2mm(dims: ResizeNNIDims) -> str:
    """Return the memtile S2MM range string for the weight buffer.

    Covers rows ``0..dims.aie_rows`` and bytes ``0..dims.wgt_subv_bytes``.
    """
    row_range = f'row:0:{dims.aie_rows}'
    byte_range = f'Bytes:0:{dims.wgt_subv_bytes}'
    return row_range + ' ' + byte_range


def memtile_wgt_mm2s(dims: ResizeNNIDims, row: int) -> str:
    """Return the memtile MM2S range string selecting one row of weights.

    Selects rows ``row..row+1`` and bytes ``0..dims.wgt_subv_bytes``.
    """
    row_end = row + 1
    return 'row:{}:{} Bytes:0:{}'.format(row, row_end, dims.wgt_subv_bytes)


def memtile_ofm_memory(dims: ResizeNNIDims) -> str:
    """Return the memtile layout string for the OFM buffer.

    Uses the sub-volume extents ``Yos``, ``Xos`` and ``Cos`` of ``dims``.
    """
    extents = (('Yo', dims.Yos), ('Xo', dims.Xos), ('Co', dims.Cos))
    return ' '.join(f'{axis}:{size}' for axis, size in extents)


def memtile_ofm_s2mm(dims: ResizeNNIDims, row: int) -> str:
    """Return the memtile S2MM range string for one OFM sub-volume.

    Uses the sub-volume extents ``Cos``, ``Yos`` and ``Xos``. ``row`` is
    accepted but not read, so every row gets the same string.

    NOTE(review): the ``Co`` entries use ``dims.Ci_gran`` as their
    granularity -- confirm the input and output granularities are meant to
    be the same.
    """
    parts = [
        f'Co:0:{dims.Cos}:{dims.Ci_gran}',
        f'Yo:0:{dims.Yos}',
        f'Xo:0:{dims.Xos}',
        f'Co:0:{dims.Ci_gran}',
    ]
    return ' '.join(parts)


def memtile_ofm_mm2s(dims: ResizeNNIDims) -> str:
    """Return the memtile MM2S range string for the OFM buffer.

    Covers the sub-volume extents ``Yos``, ``Xos`` and ``Cos``, each
    starting at 0.
    """
    stops = (('Yo', dims.Yos), ('Xo', dims.Xos), ('Co', dims.Cos))
    return ' '.join(f'{axis}:0:{stop}' for axis, stop in stops)


def shimtile_ifm_memory(dims: ResizeNNIDims) -> str:
    """Return the shim layout string for the IFM buffer.

    Uses the full extents ``Yi``, ``Xi`` and ``Ci`` of ``dims``.
    """
    extents = (('Yi', dims.Yi), ('Xi', dims.Xi), ('Ci', dims.Ci))
    return ' '.join(f'{axis}:{size}' for axis, size in extents)


def shimtile_ifm_mm2s(dims: ResizeNNIDims, col: int) -> str:
    """Return the shim MM2S range string for the IFM buffer.

    The leading ``Ci`` entry is ``0:Ci:Ci`` and is followed by the full
    ``Yi``, ``Xi`` and ``Ci`` ranges. ``col`` is accepted but not read.
    """
    channels = dims.Ci
    parts = [
        f'Ci:0:{channels}:{channels}',
        f'Yi:0:{dims.Yi}',
        f'Xi:0:{dims.Xi}',
        f'Ci:0:{channels}',
    ]
    return ' '.join(parts)


def shimtile_wgt_memory(dims: ResizeNNIDims) -> str:
    """Return the shim layout string for the weight buffer.

    The layout is a single row of ``dims.wgt_subv_bytes`` bytes.
    """
    return 'row:1 Bytes:{}'.format(dims.wgt_subv_bytes)


def shimtile_wgt_mm2s(dims: ResizeNNIDims, col: int) -> str:
    """Return the shim MM2S range string for the weight buffer.

    Selects row ``0..1`` and bytes ``0..dims.wgt_subv_bytes``. ``col`` is
    accepted but not read. The returned string starts with a space; that
    is kept exactly as the consumer currently receives it.
    """
    byte_range = f'Bytes:0:{dims.wgt_subv_bytes}'
    return ' row:0:1 ' + byte_range


def shimtile_ofm_memory(dims: ResizeNNIDims) -> str:
    """Return the shim layout string for the OFM buffer.

    Uses the full extents ``Yo``, ``Xo`` and ``Co`` of ``dims``.
    """
    extents = (('Yo', dims.Yo), ('Xo', dims.Xo), ('Co', dims.Co))
    return ' '.join(f'{axis}:{size}' for axis, size in extents)


def shimtile_ofm_s2mm(dims: ResizeNNIDims, col: int) -> str:
    """Return the shim S2MM range string for the OFM buffer.

    Covers the sub-volume extents ``Yos``, ``Xos`` and ``Cos``, each
    starting at 0. ``col`` is accepted but not read.

    NOTE(review): the matching shim layout uses the full ``Yo``/``Xo``/``Co``
    extents -- confirm the sub-volume ranges here are intended.
    """
    stops = (('Yo', dims.Yos), ('Xo', dims.Xos), ('Co', dims.Cos))
    return ' '.join(f'{axis}:0:{stop}' for axis, stop in stops)

def compile_resize_nni_dataflow(
    dims: ResizeNNIDims,
    back_end: BackEnd,
):
    """Describe the resize-NNI dataflow for ``dims`` and compile it.

    Prints ``dims`` and every buffer address/size derived from it, builds
    the core instruction list and the memtile and shim ``DataTransfer``
    lists, and passes them to ``run_layer_compilation``. Nothing is
    returned.

    Args:
        dims: Shape/tiling description. This function reads ``aie_cols``,
            ``aie_rows``, ``ifm_subv_bytes``, ``wgt_subv_bytes``,
            ``ofm_subv_bytes``, ``ifm_bits``, ``ofm_bits``,
            ``Yi``/``Xi``/``Ci`` and ``Yo``/``Xo``/``Co``; the format
            helpers it calls read further attributes.
        back_end: Forwarded unchanged as ``back_end`` to
            ``run_layer_compilation``.

    Raises:
        AssertionError: If the core OFM buffer would not end below
            ``overlay_stack_addr()``.
    """
    print(f"ResizeNNIDims: {dims}")
    kernel_names = ['run_resize_nni_a8']
    kernel_includes = ['super.hh', 'resize_nni/resize_nearest_wrapper.cc']
    shim_buffers = shim_alloc()

    # ---- Core buffers: IFM at 0, then WGT and OFM; the WGT and OFM start
    # addresses are passed through round_up_to_multiple(..., 128).
    core_param_bytes = 1024
    core_ifm_bytes = dims.ifm_subv_bytes
    core_wgt_bytes = dims.wgt_subv_bytes
    core_ofm_bytes = dims.ofm_subv_bytes
    core_ifm_addr = 0
    core_wgt_addr = round_up_to_multiple(core_ifm_addr + core_ifm_bytes, 128)
    core_ofm_addr = round_up_to_multiple(core_wgt_addr + core_wgt_bytes, 128)
    for addr_label, addr, size_label, size in (
        ("CoreIfmPingAddr", core_ifm_addr, "CoreIfmSize", core_ifm_bytes),
        ("CoreWgtPingAddr", core_wgt_addr, "CoreWgtSize", core_wgt_bytes),
        ("CoreOfmPingAddr", core_ofm_addr, "CoreOfmSize", core_ofm_bytes),
    ):
        print(addr_label, addr, size_label, size)
    assert core_ofm_addr + core_ofm_bytes < overlay_stack_addr()

    # (direction, channel, address, size) for the three core buffers.
    core_buffers = (
        (DmaDir.S2MM, 0, core_ifm_addr, core_ifm_bytes),
        (DmaDir.S2MM, 1, core_wgt_addr, core_wgt_bytes),
        (DmaDir.MM2S, 0, core_ofm_addr, core_ofm_bytes),
    )
    # Configure all three buffers, acquire them, run the kernel once, then
    # release them. A fresh DmaChannel is built for every instruction.
    core_instructions = [
        ConfigBuffer(DmaChannel(direction, channel), addr, None, size)
        for direction, channel, addr, size in core_buffers
    ]
    core_instructions += [
        AcqBuffer(DmaChannel(direction, channel))
        for direction, channel, _addr, _size in core_buffers
    ]
    core_instructions.append(
        CallKernel('run_resize_nni_a8', generate_resize_nni_noqdq_a8_params(dims))
    )
    core_instructions += [
        RelBuffer(DmaChannel(direction, channel))
        for direction, channel, _addr, _size in core_buffers
    ]

    # ---- Memtile buffers, packed back to back starting at address 0.
    memtile_param_bytes = dims.aie_rows * core_param_bytes
    memtile_ifm_bytes = dims.ifm_subv_bytes
    memtile_wgt_bytes = dims.aie_rows * dims.wgt_subv_bytes
    memtile_ofm_bytes = dims.aie_rows * dims.ofm_subv_bytes
    memtile_param_addr = 0
    memtile_ifm_ping = memtile_param_addr + memtile_param_bytes
    memtile_ifm_pong = memtile_ifm_ping + memtile_ifm_bytes
    memtile_wgt_ping = memtile_ifm_pong + memtile_ifm_bytes
    memtile_wgt_pong = memtile_wgt_ping + memtile_wgt_bytes
    memtile_ofm_ping = memtile_wgt_pong + memtile_wgt_bytes
    memtile_ofm_pong = memtile_ofm_ping + memtile_ofm_bytes
    # The pong addresses are only printed; the transfers below use the
    # ping addresses exclusively.
    for addr_label, addr, size_label, size in (
        ("MemtileParamAddr", memtile_param_addr, "MemtileParamSize", memtile_param_bytes),
        ("MemtileIfmPingAddr", memtile_ifm_ping, "MemtileIfmSize", memtile_ifm_bytes),
        ("MemtileIfmPongAddr", memtile_ifm_pong, "MemtileIfmSize", memtile_ifm_bytes),
        ("MemtileWgtPingAddr", memtile_wgt_ping, "MemtileWgtSize", memtile_wgt_bytes),
        ("MemtileWgtPongAddr", memtile_wgt_pong, "MemtileWgtSize", memtile_wgt_bytes),
        ("MemtileOfmPingAddr", memtile_ofm_ping, "MemtileOfmSize", memtile_ofm_bytes),
        ("MemtileOfmPongAddr", memtile_ofm_pong, "MemtileOfmSize", memtile_ofm_bytes),
    ):
        print(addr_label, addr, size_label, size)

    # ---- Shim buffer sizes in bytes (element count * bits, floor-divided by 8).
    shim_ifm_bytes = (dims.Yi * dims.Xi * dims.Ci * dims.ifm_bits) // 8
    shim_wgt_bytes = dims.wgt_subv_bytes
    shim_ofm_bytes = (dims.Yo * dims.Xo * dims.Co * dims.ofm_bits) // 8
    print("ShimIfmSize", shim_ifm_bytes)
    print("ShimWgtSize", shim_wgt_bytes)
    print("ShimOfmSize", shim_ofm_bytes)

    cols = range(dims.aie_cols)
    rows = range(dims.aie_rows)

    # ---- Memtile transfers. Order matters to the consumer, so the list is
    # built category by category: params, weights, IFM, OFM.
    memtile_transfers = []

    # Parameters: S2MM channel 0 in, one MM2S entry per row on channel `row`.
    for col in cols:
        tile = AieTile(TileType.Memtile, col)
        writes = [generate_transfer_params(
            memtile_dma(col, DmaDir.S2MM, 0),
            memtile_param_memory(dims),
            memtile_param_s2mm(dims),
        )]
        reads = []
        for row in rows:
            reads.append(generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, row),
                memtile_param_memory(dims),
                memtile_param_mm2s(dims, row),
            ))
        memtile_transfers.append(DataTransfer(
            [1],
            tile, [memtile_param_addr], memtile_param_bytes,
            writes,
            reads,
        ))

    # Weights: S2MM channel 1 in, one MM2S entry per row, all on channel 4.
    for col in cols:
        tile = AieTile(TileType.Memtile, col)
        writes = [generate_transfer_params(
            memtile_dma(col, DmaDir.S2MM, 1),
            memtile_wgt_memory(dims),
            memtile_wgt_s2mm(dims),
        )]
        reads = []
        for row in rows:
            reads.append(generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, 4),
                memtile_wgt_memory(dims),
                memtile_wgt_mm2s(dims, row),
            ))
        memtile_transfers.append(DataTransfer(
            [1],
            tile, [memtile_wgt_ping], memtile_wgt_bytes,
            writes,
            reads,
        ))

    # IFM: S2MM channel 0 in, MM2S channel 0 out with padding enabled.
    for col in cols:
        tile = AieTile(TileType.Memtile, col)
        writes = [generate_transfer_params(
            memtile_dma(col, DmaDir.S2MM, 0),
            memtile_ifm_memory(dims),
            memtile_ifm_s2mm(dims),
            dims.ifm_bits,
        )]
        reads = [generate_transfer_params(
            memtile_dma(col, DmaDir.MM2S, 0),
            memtile_ifm_memory(dims),
            memtile_ifm_mm2s(dims),
            dims.ifm_bits,
            enable_padding=True,
        )]
        memtile_transfers.append(DataTransfer(
            [1],
            tile, [memtile_ifm_ping], memtile_ifm_bytes,
            writes,
            reads,
        ))

    # OFM: one S2MM entry per row on channel 2, MM2S channel 5 out.
    for col in cols:
        tile = AieTile(TileType.Memtile, col)
        writes = []
        for row in rows:
            writes.append(generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 2),
                memtile_ofm_memory(dims),
                memtile_ofm_s2mm(dims, row),
                dims.ofm_bits,
            ))
        reads = [generate_transfer_params(
            memtile_dma(col, DmaDir.MM2S, 5),
            memtile_ofm_memory(dims),
            memtile_ofm_mm2s(dims),
            dims.ofm_bits,
        )]
        memtile_transfers.append(DataTransfer(
            [1],
            tile, [memtile_ofm_ping], memtile_ofm_bytes,
            writes,
            reads,
        ))

    # ---- Shim transfers, in the same category order. Inputs only have an
    # MM2S side; the OFM only has an S2MM side.
    shim_transfers = []

    # Parameters: MM2S channel 0, sized like the memtile parameter buffer.
    for col in cols:
        tile = AieTile(TileType.Shim, col)
        buffer_ids = [shim_buffers.prm_buffer_id]
        reads = [generate_transfer_params(
            shim_dma(col, DmaDir.MM2S, 0),
            shimtile_param_memory(dims),
            shimtile_param_mm2s(dims, col),
        )]
        shim_transfers.append(DataTransfer(
            [1],
            tile, buffer_ids, memtile_param_bytes,
            [],
            reads,
        ))

    # Weights: MM2S channel 1.
    for col in cols:
        tile = AieTile(TileType.Shim, col)
        buffer_ids = [shim_buffers.wgt_buffer_id]
        reads = [generate_transfer_params(
            shim_dma(col, DmaDir.MM2S, 1),
            shimtile_wgt_memory(dims),
            shimtile_wgt_mm2s(dims, col),
        )]
        shim_transfers.append(DataTransfer(
            [1],
            tile, buffer_ids, shim_wgt_bytes,
            [],
            reads,
        ))

    # IFM: MM2S channel 0.
    for col in cols:
        tile = AieTile(TileType.Shim, col)
        buffer_ids = [shim_buffers.ifm_buffer_id]
        reads = [generate_transfer_params(
            shim_dma(col, DmaDir.MM2S, 0),
            shimtile_ifm_memory(dims),
            shimtile_ifm_mm2s(dims, col),
            dims.ifm_bits,
        )]
        shim_transfers.append(DataTransfer(
            [1],
            tile, buffer_ids, shim_ifm_bytes,
            [],
            reads,
        ))

    # OFM: S2MM channel 0.
    for col in cols:
        tile = AieTile(TileType.Shim, col)
        buffer_ids = [shim_buffers.ofm_buffer_id]
        writes = [generate_transfer_params(
            shim_dma(col, DmaDir.S2MM, 0),
            shimtile_ofm_memory(dims),
            shimtile_ofm_s2mm(dims, col),
            dims.ofm_bits,
        )]
        shim_transfers.append(DataTransfer(
            [1],
            tile, buffer_ids, shim_ofm_bytes,
            writes,
            [],
        ))

    overlay = OverlayShape(dims.aie_cols, dims.aie_rows)
    dma_connections = aie4_overlay_dma_connections(dims.aie_cols, dims.aie_rows)
    run_layer_compilation(
        overlay,
        kernel_names,
        kernel_includes,
        core_instructions,
        memtile_transfers,
        shim_transfers,
        dma_connections,
        back_end=back_end,
        core_stack_addr=overlay_stack_addr(),
        param_channel_id=0,
    )
