from gc import enable
import os
import sys
from typing import List

# Absolute path of the directory that contains this file.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Add `<this dir>/../../dmacompiler` to the import search path so that the
# `from dmacompiler import ...` statement below can resolve.
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))

from kerneltest.maxpool.maxpool_helpers import \
    MaxpoolDims, \
    gen_aie4_maxpool_params \

from kerneltest.helpers import \
    ceildiv, \
    iceil, \
    round_up_to_multiple
    
from dmacompiler import \
    DevGen, \
    config, \
    OverlayShape, DataTransfer, TransferParams, SyncStrategy, BackEnd, generate_transfer_params, \
    DmaChannel, DmaDir, AieDma, AieTile, TileType, DmaConnection, \
    memtile_dma, shim_dma, \
    ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop, \
    run_layer_compilation, set_dev_gen

from kerneltest.overlay_1x1 import \
    overlay_stack_addr, \
    aie4_overlay_dma_connections, \
    shim_alloc 

# Runs at import time: configure dmacompiler with the DevGen.Aie4 device generation.
set_dev_gen(DevGen.Aie4)

def memtile_param_memory(dims: MaxpoolDims) -> str:
    """Return the memtile parameter-buffer descriptor.

    The result has the form ``'row:<aie_rows> Bytes:<Param_size>'`` and is
    passed to ``generate_transfer_params`` together with the memtile
    parameter S2MM/MM2S access patterns.
    """
    fields = (
        f'row:{dims.aie_rows}',
        f'Bytes:{dims.Param_size}',
    )
    return ' '.join(fields)

def memtile_param_s2mm(dims: MaxpoolDims) -> str:
    """Return the memtile parameter S2MM access pattern.

    The result has the form ``'row:0:<aie_rows> Bytes:0:<Param_size>'``,
    i.e. a single pattern that spans every row of the parameter buffer.
    """
    return 'row:0:{} Bytes:0:{}'.format(dims.aie_rows, dims.Param_size)

def memtile_param_mm2s(dims: MaxpoolDims, row: int) -> str:
    """Return the memtile parameter MM2S access pattern for one row.

    The result has the form ``'row:<row>:<row+1> Bytes:0:<Param_size>'``,
    selecting only the slice of the parameter buffer that belongs to ``row``.
    """
    row_end = row + 1
    row_slice = f'row:{row}:{row_end}'
    byte_slice = f'Bytes:0:{dims.Param_size}'
    return f'{row_slice} {byte_slice}'

def memtile_wgt_memory(dims: MaxpoolDims) -> str:
    """Return the memtile weight-buffer descriptor ``' Bytes:<wgt_subv_bytes>'``.

    NOTE: the returned string starts with a single space, unlike the other
    descriptors in this file; that leading space is preserved as-is.
    """
    return ' Bytes:{}'.format(dims.wgt_subv_bytes)

def memtile_wgt_s2mm(dims: MaxpoolDims) -> str:
    """Return the memtile weight S2MM access pattern ``'Bytes:0:<wgt_subv_bytes>'``."""
    n_bytes = dims.wgt_subv_bytes
    return 'Bytes:0:{}'.format(n_bytes)

def memtile_wgt_mm2s(dims: MaxpoolDims) -> str:
    """Return the memtile weight MM2S access pattern ``'Bytes:0:<wgt_subv_bytes>'``.

    The pattern covers the same byte range as the weight S2MM pattern.
    """
    return ':'.join(('Bytes', '0', f'{dims.wgt_subv_bytes}'))

def memtile_ifm_memory(dims: MaxpoolDims) -> str:
    """Return the memtile IFM buffer descriptor.

    The result has the form ``'Yi:<Yis> Ci:<Cs> Xi:<Xis> Ci:<C_gran>'``;
    note that ``Ci`` appears twice, once with ``Cs`` and once with ``C_gran``.
    """
    fields = (
        f'Yi:{dims.Yis}',
        f'Ci:{dims.Cs}',
        f'Xi:{dims.Xis}',
        f'Ci:{dims.C_gran}',
    )
    return ' '.join(fields)

def memtile_ifm_s2mm(dims: MaxpoolDims) -> str:
    """Return the memtile IFM S2MM access pattern.

    The result has the form ``'Yi:0:<Yis> Xi:0:<Xis> Ci:0:<Cs>'``.
    """
    return 'Yi:0:{} Xi:0:{} Ci:0:{}'.format(dims.Yis, dims.Xis, dims.Cs)

def memtile_ifm_mm2s(dims: MaxpoolDims) -> str:
    """Return the memtile IFM MM2S access pattern.

    The result has the form
    ``'Yi:0:<Yis> Ci:0:<Cs_pad>:<C_gran> Xi:0:<Xis> Ci:0:<C_gran>'``.
    The outer ``Ci`` range runs to ``Cs_pad`` (not ``Cs``) in steps of
    ``C_gran``.
    """
    gran = dims.C_gran
    segments = [
        f'Yi:0:{dims.Yis}',
        f'Ci:0:{dims.Cs_pad}:{gran}',
        f'Xi:0:{dims.Xis}',
        f'Ci:0:{gran}',
    ]
    return ' '.join(segments)

def memtile_ofm_memory(dims: MaxpoolDims) -> str:
    """Return the memtile OFM buffer descriptor.

    The result has the form ``'Yo:<Yos> Xo:<Xos> Co:<aie_rows * Cs_pad>'``.
    """
    # NOTE: Assuming Co is split across rows, so one column holds
    # aie_rows padded channel slices.
    channels_in_column = dims.Cs_pad * dims.aie_rows
    return 'Yo:{} Xo:{} Co:{}'.format(dims.Yos, dims.Xos, channels_in_column)

def memtile_ofm_s2mm(dims: MaxpoolDims, row: int) -> str:
    """Return the memtile OFM S2MM access pattern.

    The result has the form
    ``'Yo:0:<Yos> Co:0:<Cs_pad>:<C_gran> Xo:0:<Xos> Co:0:<C_gran>'``.

    Args:
        dims: Maxpool dimensions supplying ``Yos``, ``Xos``, ``Cs_pad`` and
            ``C_gran``.
        row: Core row the pattern is requested for.  It is accepted for
            interface compatibility but does not influence the result.
    """
    # NOTE(review): `row` is unused, so every row gets the same Co range
    # starting at 0.  That is equivalent to a per-row offset only while
    # aie_rows == 1 -- confirm before enabling multi-row overlays.
    return (
        f'Yo:0:{dims.Yos} '
        f'Co:0:{dims.Cs_pad}:{dims.C_gran} '
        f'Xo:0:{dims.Xos} '
        f'Co:0:{dims.C_gran}'
    )

def memtile_ofm_mm2s(dims: MaxpoolDims) -> str:
    """Return the memtile OFM MM2S access pattern.

    The result has the form ``'Yo:0:<Yos> Xo:0:<Xos> Co:0:<aie_rows * Cs>'``.
    The channel extent uses the unpadded ``Cs``.
    """
    # NOTE: Assuming Co is split across rows.
    channels_in_column = dims.Cs * dims.aie_rows
    segments = (
        f'Yo:0:{dims.Yos}',
        f'Xo:0:{dims.Xos}',
        f'Co:0:{channels_in_column}',
    )
    return ' '.join(segments)

def shimtile_param_memory(dims: MaxpoolDims) -> str:
    """Return the shim parameter-buffer descriptor.

    The result has the form
    ``'col:<aie_cols> row:<aie_rows> Bytes:<Param_size>'``.
    """
    return 'col:{} row:{} Bytes:{}'.format(
        dims.aie_cols,
        dims.aie_rows,
        dims.Param_size,
    )

def shimtile_param_mm2s(dims: MaxpoolDims, col: int) -> str:
    """Return the shim parameter MM2S access pattern for one column.

    The result has the form
    ``'col:<col>:<col+1> row:0:<aie_rows> Bytes:0:<Param_size>'``: the slice
    for ``col`` across all rows of the parameter buffer.
    """
    col_end = col + 1
    segments = (
        f'col:{col}:{col_end}',
        f'row:0:{dims.aie_rows}',
        f'Bytes:0:{dims.Param_size}',
    )
    return ' '.join(segments)

def shimtile_wgt_memory(dims: MaxpoolDims) -> str:
    """Return the shim weight-buffer descriptor ``'Bytes:<wgt_subv_bytes>'``."""
    return 'Bytes:{}'.format(dims.wgt_subv_bytes)

def shimtile_wgt_mm2s(dims: MaxpoolDims) -> str:
    """Return the shim weight MM2S access pattern ``'Bytes:0:<wgt_subv_bytes>'``."""
    n_bytes = dims.wgt_subv_bytes
    return 'Bytes:0:' + f'{n_bytes}'

def shimtile_ifm_memory(dims: MaxpoolDims) -> str:
    """Return the shim IFM buffer descriptor ``'Yi:<Yi> Xi:<Xi> Ci:<C>'``."""
    fields = (
        f'Yi:{dims.Yi}',
        f'Xi:{dims.Xi}',
        f'Ci:{dims.C}',
    )
    return ' '.join(fields)

def shimtile_ifm_mm2s(dims: MaxpoolDims, col: int) -> str:
    """Return the shim IFM MM2S access pattern.

    The result has the form
    ``'Ci:0:<C>:<Cs> Yi:0:<Yis> Xi:0:<Xis> Ci:0:<Cs>'``: an outer ``Ci``
    range over ``C`` in steps of ``Cs``, followed by ``Yis``/``Xis``/``Cs``
    ranges.  ``col`` is accepted but does not influence the result.
    """
    channel_step = dims.Cs
    segments = (
        f'Ci:0:{dims.C}:{channel_step}',
        f'Yi:0:{dims.Yis}',
        f'Xi:0:{dims.Xis}',
        f'Ci:0:{channel_step}',
    )
    return ' '.join(segments)

def shimtile_ofm_memory(dims: MaxpoolDims) -> str:
    """Return the shim OFM buffer descriptor ``'Yo:<Yo> Xo:<Xo> Co:<C>'``."""
    return 'Yo:{} Xo:{} Co:{}'.format(dims.Yo, dims.Xo, dims.C)

def shimtile_ofm_s2mm(dims: MaxpoolDims, col: int) -> str:
    """Return the shim OFM S2MM access pattern.

    The result has the form ``'Yo:0:<Yos> Xo:0:<Xos> Co:0:<Cs>'``.

    Args:
        dims: Maxpool dimensions supplying ``Yos``, ``Xos`` and ``Cs``.
        col: Overlay column the pattern is requested for.  It is accepted
            for interface compatibility but does not influence the result.
    """
    # NOTE(review): `col` is unused, so every column writes the same
    # Co range starting at 0.  That is only equivalent to a per-column
    # offset while aie_cols == 1 -- confirm before using more columns.
    return f'Yo:0:{dims.Yos} Xo:0:{dims.Xos} Co:0:{dims.Cs}'

def compile_maxpool_dataflow(
    dims: MaxpoolDims,
    back_end: BackEnd,
    core_ifm_s2mm_channel: int = 0,
):
    """Build the core, memtile and shim programs for a maxpool layer and compile them.

    The function lays out the core and memtile buffers, creates the core
    instruction list and the per-column memtile/shim ``DataTransfer`` lists,
    and hands everything to ``run_layer_compilation``.  As a side effect it
    prints the core buffer layout to stdout.  It returns ``None``.

    Args:
        dims: Maxpool dimensions; supplies the overlay shape
            (``aie_cols``/``aie_rows``), sub-volume byte sizes, loop counts
            and element bit widths read below.
        back_end: Back end forwarded to ``run_layer_compilation``.
        core_ifm_s2mm_channel: Currently ignored -- the core IFM S2MM channel
            is pinned to 0 inside the function (see the NOTE in the body).

    Raises:
        AssertionError: If the core IFM, OFM and weight buffers do not end
            strictly below ``overlay_stack_addr()``.
    """
    kernel_names = ['run_maxpool_int8x8']
    kernel_includes = ['super.hh', 'maxpool/maxpool_int8x8_wrapper.cc']

    # --- DMA channel assignment -------------------------------------------
    shim_ifm_mm2s_channel = 1
    shim_param_mm2s_channel = 0
    shim_wgt_mm2s_channel = 0
    shim_ofm_s2mm_channel = 0

    memtile_param_s2mm_channel = 0
    memtile_wgt_s2mm_channel = 0
    memtile_ifm_s2mm_channel = 1
    # One memtile channel per core row.  Parameter and IFM MM2S use the same
    # channel index for a given row; OFM S2MM channels start at 2.
    memtile_param_mm2s_channel = list(range(dims.aie_rows))
    memtile_ifm_mm2s_channel = list(range(dims.aie_rows))
    memtile_ofm_s2mm_channel = [row + 2 for row in range(dims.aie_rows)]
    memtile_ofm_mm2s_channel = 5
    memtile_wgt_mm2s_channel = 4

    # NOTE(review): this overwrites the caller-supplied argument, so the
    # parameter has no effect.  Kept as-is because the overlay routing may
    # rely on channel 0 -- confirm, then drop either this line or the
    # parameter.
    core_ifm_s2mm_channel = 0
    core_wgt_s2mm_channel = 1
    core_ofm_mm2s_channel = 0

    # --- Core buffer layout: IFM | OFM | weights, starting at address 0 ----
    CoreParamSize = dims.Param_size
    ConvShimAlloc = shim_alloc()

    CoreAlignSize = 128
    print("act_subv_bytes", dims.act_subv_bytes, "out_subv_bytes", dims.out_subv_bytes)
    CoreIfmSize = round_up_to_multiple(dims.act_subv_bytes, CoreAlignSize)
    CoreOfmSize = round_up_to_multiple(dims.out_subv_bytes, CoreAlignSize)
    CoreStackAddr = overlay_stack_addr()
    CoreIfmPingAddr = 0
    CoreOfmPingAddr = CoreIfmPingAddr + CoreIfmSize
    CoreWgtAddr = CoreOfmPingAddr + CoreOfmSize
    print("CoreIfmPingAddr", CoreIfmPingAddr, "CoreIfmSize", CoreIfmSize)
    print("CoreOfmPingAddr", CoreOfmPingAddr, "CoreOfmSize", CoreOfmSize)
    # The three buffers must end strictly below the stack address.
    CoreBuffersEnd = CoreWgtAddr + dims.wgt_subv_bytes
    assert CoreBuffersEnd < CoreStackAddr, (
        f"core buffers end at {CoreBuffersEnd}, "
        f"which is not below the stack address {CoreStackAddr}"
    )

    # --- Core program -------------------------------------------------------
    # Only ping addresses are configured (the third ConfigBuffer argument is
    # None) -- presumably single buffering; confirm against ConfigBuffer.
    # Weights are acquired and released once up front; the loop then
    # processes one IFM/OFM buffer pair per iteration.
    core_instructions = [
        ConfigBuffer(DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel), CoreIfmPingAddr, None, dims.act_subv_bytes),
        ConfigBuffer(DmaChannel(DmaDir.S2MM, core_wgt_s2mm_channel), CoreWgtAddr, None, dims.wgt_subv_bytes),
        ConfigBuffer(DmaChannel(DmaDir.MM2S, core_ofm_mm2s_channel), CoreOfmPingAddr, None, dims.out_subv_bytes),
        AcqBuffer(DmaChannel(DmaDir.S2MM, core_wgt_s2mm_channel)),
        RelBuffer(DmaChannel(DmaDir.S2MM, core_wgt_s2mm_channel)),
        Loop(
            (dims.C_loop * dims.Y_loop * dims.X_loop),
            [
                AcqBuffer(DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel)),
                AcqBuffer(DmaChannel(DmaDir.MM2S, core_ofm_mm2s_channel)),
                CallKernel('run_maxpool_int8x8', gen_aie4_maxpool_params(dims, core_ifm_s2mm_channel)),
                RelBuffer(DmaChannel(DmaDir.S2MM, core_ifm_s2mm_channel)),
                RelBuffer(DmaChannel(DmaDir.MM2S, core_ofm_mm2s_channel)),
            ],
        ),
    ]

    # --- Memtile buffer layout: params | weights | IFM ping/pong | OFM ping/pong
    MemtileParamSize = dims.aie_rows * CoreParamSize
    MemtileIfmSize = dims.act_subv_bytes
    # NOTE(review): the OFM buffer is sized from act_subv_bytes rather than
    # out_subv_bytes -- confirm this is intended.
    MemtileOfmSize = dims.aie_rows * dims.act_subv_bytes
    MemtileWgtSize = dims.wgt_subv_bytes
    MemtileParamAddr = 0
    MemtileWgtAddr = MemtileParamAddr + MemtileParamSize
    MemtileIfmPingAddr = MemtileWgtAddr + MemtileWgtSize
    MemtileIfmPongAddr = MemtileIfmPingAddr + MemtileIfmSize
    MemtileOfmPingAddr = MemtileIfmPongAddr + MemtileIfmSize
    MemtileOfmPongAddr = MemtileOfmPingAddr + MemtileOfmSize

    ShimIfmSize = dims.Yi * dims.Xi * dims.C * dims.act_bits // 8
    ShimOfmSize = dims.Yo * dims.Xo * dims.C * dims.out_bits // 8

    # Loop count shared by the memtile IFM and OFM transfers.
    subvol_iterations = dims.Y_loop * dims.X_loop * dims.C_loop
    # Padding is requested on the IFM MM2S side only when Cs is below Cs_pad.
    ifm_needs_padding = bool(dims.Cs < dims.Cs_pad)

    # --- Memtile transfers (one of each kind per column) --------------------
    # Parameters: one S2MM write, then one MM2S read per core row.
    memtile_param_transfers = [
        DataTransfer(
            [1],
            AieTile(TileType.Memtile, col), [MemtileParamAddr], MemtileParamSize,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, memtile_param_s2mm_channel),
                memtile_param_memory(dims),
                memtile_param_s2mm(dims),
                dims.param_bits,
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, memtile_param_mm2s_channel[row]),
                memtile_param_memory(dims),
                memtile_param_mm2s(dims, row),
                dims.param_bits,
            ) for row in range(dims.aie_rows)],
        ) for col in range(dims.aie_cols)
    ]
    # Weights: one S2MM write and one MM2S read.
    memtile_wgt_transfers = [
        DataTransfer(
            [1],
            AieTile(TileType.Memtile, col), [MemtileWgtAddr], MemtileWgtSize,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, memtile_wgt_s2mm_channel),
                memtile_wgt_memory(dims),
                memtile_wgt_s2mm(dims),
                dims.param_bits,
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, memtile_wgt_mm2s_channel),
                memtile_wgt_memory(dims),
                memtile_wgt_mm2s(dims),
                dims.param_bits,
            )],
        ) for col in range(dims.aie_cols)
    ]
    # IFM: ping/pong buffers; the same MM2S pattern is issued for every row.
    memtile_ifm_transfers = [
        DataTransfer(
            [subvol_iterations],
            AieTile(TileType.Memtile, col), [MemtileIfmPingAddr, MemtileIfmPongAddr], MemtileIfmSize,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, memtile_ifm_s2mm_channel),
                memtile_ifm_memory(dims),
                memtile_ifm_s2mm(dims),
                dims.act_bits,
            )],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, memtile_ifm_mm2s_channel[row]),
                memtile_ifm_memory(dims),
                memtile_ifm_mm2s(dims),
                dims.act_bits,
                enable_padding=ifm_needs_padding,
            ) for row in range(dims.aie_rows)],
        ) for col in range(dims.aie_cols)
    ]
    # OFM: ping/pong buffers; one S2MM write per row, one MM2S read.
    memtile_ofm_transfers = [
        DataTransfer(
            [subvol_iterations],
            AieTile(TileType.Memtile, col), [MemtileOfmPingAddr, MemtileOfmPongAddr], MemtileOfmSize,
            [generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, memtile_ofm_s2mm_channel[row]),
                memtile_ofm_memory(dims),
                memtile_ofm_s2mm(dims, row),
                dims.out_bits,
            ) for row in range(dims.aie_rows)],
            [generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, memtile_ofm_mm2s_channel),
                memtile_ofm_memory(dims),
                memtile_ofm_mm2s(dims),
                dims.out_bits,
            )],
        ) for col in range(dims.aie_cols)
    ]
    memtile_transfers = (
        memtile_param_transfers
        + memtile_wgt_transfers
        + memtile_ifm_transfers
        + memtile_ofm_transfers
    )

    # --- Shim transfers (one of each kind per column) -----------------------
    # Inputs (params, weights, IFM) only have an MM2S side; OFM only has an
    # S2MM side.  Buffers are addressed through the ids from shim_alloc().
    shim_param_transfers = [
        DataTransfer(
            [1],
            AieTile(TileType.Shim, col), [ConvShimAlloc.prm_buffer_id], MemtileParamSize,
            [],
            [generate_transfer_params(
                shim_dma(col, DmaDir.MM2S, shim_param_mm2s_channel),
                shimtile_param_memory(dims),
                shimtile_param_mm2s(dims, col),
                dims.param_bits,
            )],
        ) for col in range(dims.aie_cols)
    ]
    shim_wgt_transfers = [
        DataTransfer(
            [1],
            AieTile(TileType.Shim, col), [ConvShimAlloc.wgt_buffer_id], MemtileWgtSize,
            [],
            [generate_transfer_params(
                shim_dma(col, DmaDir.MM2S, shim_wgt_mm2s_channel),
                shimtile_wgt_memory(dims),
                shimtile_wgt_mm2s(dims),
                dims.param_bits,
            )],
        ) for col in range(dims.aie_cols)
    ]
    shim_ifm_transfers = [
        DataTransfer(
            [1],
            AieTile(TileType.Shim, col), [ConvShimAlloc.ifm_buffer_id], ShimIfmSize,
            [],
            [generate_transfer_params(
                shim_dma(col, DmaDir.MM2S, shim_ifm_mm2s_channel),
                shimtile_ifm_memory(dims),
                shimtile_ifm_mm2s(dims, col),
                dims.act_bits,
            )],
        ) for col in range(dims.aie_cols)
    ]
    shim_ofm_transfers = [
        DataTransfer(
            [1],
            AieTile(TileType.Shim, col), [ConvShimAlloc.ofm_buffer_id], ShimOfmSize,
            [generate_transfer_params(
                shim_dma(col, DmaDir.S2MM, shim_ofm_s2mm_channel),
                shimtile_ofm_memory(dims),
                shimtile_ofm_s2mm(dims, col),
                dims.out_bits,
            )],
            [],
        ) for col in range(dims.aie_cols)
    ]
    shim_transfers = (
        shim_param_transfers
        + shim_wgt_transfers
        + shim_ifm_transfers
        + shim_ofm_transfers
    )

    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        kernel_names,
        kernel_includes,
        core_instructions,
        memtile_transfers,
        shim_transfers,
        aie4_overlay_dma_connections(dims.aie_cols, dims.aie_rows),
        back_end=back_end,
        core_stack_addr=CoreStackAddr,
        param_channel_id=0,
    )

def generate_maxpool_dataflow(
        Yi: int,
        Xi: int,
        Yo: int,
        Xo: int,
        Co: int,
        Yis: int,
        Xis: int,
        Yos: int,
        Xos: int,
        Cs: int,
        Ky: int,
        Kx: int,
        Py: int,
        Px: int,
        Sy: int,
        Sx: int,
        Backend: int = 0,
):
    """Create a ``MaxpoolDims`` for a 1x1 overlay and compile its dataflow.

    The shape, sub-volume, kernel, padding and stride arguments are forwarded
    positionally to ``MaxpoolDims`` together with fixed settings chosen here:
    ``N = 1``, one AIE row and one AIE column, 8-bit activations, outputs and
    parameters, and granularities ``Y_gran = X_gran = 1``, ``C_gran = 64``.
    ``Cs_pad`` is ``Cs`` raised to at least 128.

    Args:
        Yi, Xi, Yo, Xo, Co: Forwarded to ``MaxpoolDims`` (presumably the full
            input/output spatial sizes and channel count -- confirm against
            ``MaxpoolDims``).
        Yis, Xis, Yos, Xos, Cs: Forwarded to ``MaxpoolDims`` (presumably the
            per-iteration sub-volume sizes -- confirm).
        Ky, Kx, Py, Px, Sy, Sx: Forwarded to ``MaxpoolDims`` (presumably
            kernel, padding and stride -- confirm).
        Backend: ``0`` selects ``BackEnd.Adf``; any other value selects
            ``BackEnd.TxnHostPatch``.
    """
    # Fixed batch size and single-tile overlay.
    N, aie_rows, aie_cols = 1, 1, 1
    # Element widths in bits.
    act_bits, out_bits, param_bits = 8, 8, 8
    # NOTE: output sub-volume granularities for maxpool.
    Y_gran, X_gran, C_gran = 1, 1, 64

    # Channel sub-volume is padded up to a minimum of Cs_min.
    Cs_min = 128
    Cs_pad = max(Cs_min, Cs)

    dims = MaxpoolDims(
        N,
        Yi, Xi, Co, Yo, Xo,
        Yis, Xis, Cs, Cs_pad, Yos, Xos,
        Ky, Kx, Py, Px, Sy, Sx,
        aie_rows, aie_cols,
        act_bits, out_bits, param_bits,
        Y_gran, X_gran, C_gran,
    )

    back_end = BackEnd.Adf if Backend == 0 else BackEnd.TxnHostPatch
    compile_maxpool_dataflow(dims, back_end)