from curses import raw
import os
import sys
from typing import List

# Absolute directory containing this file.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Add <this dir>/../../dmacompiler to the module search path so modules that
# live there can be imported (presumably the `dmacompiler` import below).
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))

from gemm_helpers import GemmDims, gen_aie4_gemm_params

from kerneltest.helpers import \
    ceildiv, \
    iceil

from dmacompiler import \
    DevGen, \
    config, \
    OverlayShape, DataTransfer, TransferParams, SyncStrategy, BackEnd, generate_transfer_params, \
    DmaChannel, DmaDir, AieDma, AieTile, TileType, DmaConnection, \
    memtile_dma, shim_dma, \
    ConfigBuffer, AcqBuffer, RelBuffer, CallKernel, Loop, \
    run_layer_compilation, set_dev_gen

from kerneltest.overlay_1x1 import \
    overlay_stack_addr, \
    aie4_overlay_dma_connections, \
    shim_alloc 

# Module-level side effect: runs whenever this file is imported or executed,
# before any dataflow function below is called.
set_dev_gen(DevGen.Aie4)

def param_memtile_memory(dims: GemmDims) -> str:
    """Return the memory-format string for the param buffer in a memtile.

    The result is ``'Row:<dims.aie_rows> Bytes:<dims.param_size>'``.
    """
    row_dim = f'Row:{dims.aie_rows}'
    byte_dim = f'Bytes:{dims.param_size}'
    return ' '.join((row_dim, byte_dim))

def param_memtile_s2mm(dims: GemmDims) -> str:
    """Return the S2MM access string for the memtile param buffer.

    The result is ``'Row:0:<dims.aie_rows> Bytes:0:<dims.param_size>'``,
    i.e. every row and every byte of the buffer (presumably ``start:stop``
    ranges -- confirm against generate_transfer_params).
    """
    template = 'Row:0:{} Bytes:0:{}'
    return template.format(dims.aie_rows, dims.param_size)

def param_memtile_mm2s(dims: GemmDims, row: int) -> str:
    """Return the MM2S access string selecting one row of the memtile param buffer.

    Args:
        dims: supplies ``param_size``.
        row: index of the row to select; the row range is ``row:row + 1``.

    Returns:
        ``'Row:<row>:<row + 1> Bytes:0:<dims.param_size>'``.
    """
    row_span = f'{row}:{row + 1}'
    return f'Row:{row_span} Bytes:0:{dims.param_size}'

def ifm_memtile_memory(dims: GemmDims) -> str:
    """Return the memory-format string for the IFM buffer in a memtile.

    The result is ``'K:<dims.K> M:<dims.Msubv> K:<dims.K_gran>'``; note that
    the ``K`` dimension appears twice (full size first, granule size last).
    """
    fields = [
        f'K:{dims.K}',
        f'M:{dims.Msubv}',
        f'K:{dims.K_gran}',
    ]
    return ' '.join(fields)

def ifm_memtile_s2mm(dims: GemmDims) -> str:
    """Return the S2MM access string for the memtile IFM buffer.

    The result is ``'M:0:<dims.Msubv> K:0:<dims.K>'``.
    """
    m_range = f'M:0:{dims.Msubv}'
    k_range = f'K:0:{dims.K}'
    return f'{m_range} {k_range}'

def ifm_memtile_mm2s(dims: GemmDims) -> str:
    """Return the MM2S access string for the memtile IFM buffer.

    The result is
    ``'K:0:<dims.K>:<dims.K_gran> M:0:<dims.Msubv> K:0:<dims.K_gran>'``:
    an outer ``K`` range with a fourth field equal to ``K_gran`` (presumably
    a step -- confirm against generate_transfer_params), then ``M``, then an
    inner ``K`` range of one granule.
    """
    total_k = dims.K
    granule = dims.K_gran
    outer_k = f'K:0:{total_k}:{granule}'
    m_range = f'M:0:{dims.Msubv}'
    inner_k = f'K:0:{granule}'
    return ' '.join((outer_k, m_range, inner_k))

def wgt_memtile_memory(dims: GemmDims) -> str:
    """Return the memory-format string for the weight buffer in a memtile.

    The result is
    ``'Row:<dims.K_loop> Col:<dims.N_loop> Bytes:<dims.wgt_subv_bytes>'``.
    """
    template = 'Row:{} Col:{} Bytes:{}'
    return template.format(dims.K_loop, dims.N_loop, dims.wgt_subv_bytes)

def wgt_memtile_mm2s(dims: GemmDims) -> str:
    """Return the MM2S access string for the memtile weight buffer.

    The result is ``'Row:0:<dims.K_loop> Col:0:<dims.N_loop>
    Bytes:0:<dims.wgt_subv_bytes>'`` (on one line), covering the whole buffer.
    """
    ranges = (
        f'Row:0:{dims.K_loop}',
        f'Col:0:{dims.N_loop}',
        f'Bytes:0:{dims.wgt_subv_bytes}',
    )
    return ' '.join(ranges)

def wgt_memtile_s2mm(dims: GemmDims) -> str:
    """Return the S2MM access string for the memtile weight buffer.

    The result is ``'Row:0:<dims.K_loop> Col:0:<dims.N_loop>
    Bytes:0:<dims.wgt_subv_bytes>'`` (on one line) -- the same text the MM2S
    counterpart produces.
    """
    row_range = f'Row:0:{dims.K_loop}'
    col_range = f'Col:0:{dims.N_loop}'
    byte_range = f'Bytes:0:{dims.wgt_subv_bytes}'
    return f'{row_range} {col_range} {byte_range}'

def out_memtile_memory(dims: GemmDims) -> str:
    """Return the memory-format string for the output buffer in a memtile.

    The result is ``'M:<dims.Msubv> N:<dims.Nsubv>'``.
    """
    return ' '.join([f'M:{dims.Msubv}', f'N:{dims.Nsubv}'])

def out_memtile_s2mm(dims: GemmDims) -> str:
    """Return the S2MM access string for the memtile output buffer.

    The result is ``'M:0:<dims.Msubv> N:0:<dims.Nsubv>'``.
    """
    template = 'M:0:{} N:0:{}'
    return template.format(dims.Msubv, dims.Nsubv)

def out_memtile_mm2s(dims: GemmDims) -> str:
    """Return the MM2S access string for the memtile output buffer.

    The result is ``'M:0:<dims.Msubv> N:0:<dims.Nsubv>'`` -- the same text
    the S2MM counterpart produces.
    """
    m_range = f'M:0:{dims.Msubv}'
    n_range = f'N:0:{dims.Nsubv}'
    return f'{m_range} {n_range}'

def param_shim_memory(dims: GemmDims) -> str:
    """Return the memory-format string for the param buffer seen by the shim.

    The result is
    ``'Col:<dims.aie_cols> Row:<dims.aie_rows> Bytes:<dims.param_size>'``.
    """
    fields = (
        f'Col:{dims.aie_cols}',
        f'Row:{dims.aie_rows}',
        f'Bytes:{dims.param_size}',
    )
    return ' '.join(fields)

def param_shim_mm2s(dims: GemmDims, col: int) -> str:
    """Return the shim MM2S access string selecting one column of the param buffer.

    Args:
        dims: supplies ``aie_rows`` and ``param_size``.
        col: index of the column to select; the column range is ``col:col + 1``.

    Returns:
        ``'Col:<col>:<col + 1> Row:0:<dims.aie_rows> Bytes:0:<dims.param_size>'``.
    """
    col_span = f'{col}:{col + 1}'
    row_range = f'Row:0:{dims.aie_rows}'
    byte_range = f'Bytes:0:{dims.param_size}'
    return f'Col:{col_span} {row_range} {byte_range}'

def ifm_shim_memory(dims: GemmDims) -> str:
    """Return the memory-format string for the IFM buffer seen by the shim.

    The result is ``'M:<dims.M> K:<dims.K>'``.
    """
    return 'M:{} K:{}'.format(dims.M, dims.K)

def ifm_shim_mm2s(dims: GemmDims) -> str:
    """Return the shim MM2S access string for the IFM buffer.

    The result is ``'M:0:<dims.M> K:0:<dims.K>'``, covering the whole buffer.
    """
    m_range = f'M:0:{dims.M}'
    k_range = f'K:0:{dims.K}'
    return ' '.join((m_range, k_range))

def wgt_shim_memory(dims: GemmDims) -> str:
    """Return the memory-format string for the weight buffer seen by the shim.

    The result is
    ``'Row:<dims.K_loop> Col:<dims.N_loop> Bytes:<dims.wgt_subv_bytes>'``.
    """
    row_dim = f'Row:{dims.K_loop}'
    col_dim = f'Col:{dims.N_loop}'
    byte_dim = f'Bytes:{dims.wgt_subv_bytes}'
    return f'{row_dim} {col_dim} {byte_dim}'

def wgt_shim_mm2s(dims: GemmDims) -> str:
    """Return the shim MM2S access string for the weight buffer.

    The result is ``'Row:0:<dims.K_loop> Col:0:<dims.N_loop>
    Bytes:0:<dims.wgt_subv_bytes>'`` (on one line), covering the whole buffer.
    """
    template = 'Row:0:{} Col:0:{} Bytes:0:{}'
    return template.format(dims.K_loop, dims.N_loop, dims.wgt_subv_bytes)

def out_shim_memory(dims: GemmDims) -> str:
    """Return the memory-format string for the output buffer seen by the shim.

    The result is ``'M:<dims.M> N:<dims.N>'``.
    """
    return ' '.join([f'M:{dims.M}', f'N:{dims.N}'])

def out_shim_s2mm(dims: GemmDims) -> str:
    """Return the shim S2MM access string for the output buffer.

    The result is ``'M:0:<dims.M> N:0:<dims.N>'``, covering the whole buffer.
    """
    return 'M:0:{} N:0:{}'.format(dims.M, dims.N)

def compile_gemm_dataflow(
    dims: GemmDims,
    backend: BackEnd
):
    """Lay out core, memtile and shim buffers for one GEMM and compile the layer.

    The function works in five steps:

    1. Place the core buffers back to back from address 0 (IFM, WGT, OUT, a
       1536-byte spill area, an IFM-sized temp buffer and a 256-byte coeff
       temp buffer), print the layout, and assert that it ends at or below
       ``overlay_stack_addr()``.
    2. Build the core instruction list: three ``ConfigBuffer`` entries, then
       ``M_loop * N_loop`` groups of acquire/release instructions, then a
       single ``CallKernel`` for ``run_gemm_int16x8``.
    3. Place the memtile buffers back to back from address 0 and print them.
    4. Build one ``DataTransfer`` per column for each of param / IFM / WGT /
       OUT, once for the memtile and once for the shim.
    5. Pass everything to ``run_layer_compilation``.

    Args:
        dims: GEMM description. Attributes read here: ``M``, ``K``, ``N``,
            ``Msubv``, ``Ksubv``, ``Nsubv``, ``M_loop``, ``N_loop``,
            ``K_loop``, ``aie_rows``, ``aie_cols``, ``param_size``,
            ``wgt_subv_bytes`` and the ``ifm_bits`` / ``wgt_bits`` /
            ``out_bits`` / ``param_bits`` widths.
        backend: forwarded unchanged as ``back_end`` to
            ``run_layer_compilation``.

    Raises:
        AssertionError: if the core buffers would extend past the overlay
            stack address.
    """
    kernel_names = ['run_gemm_int16x8']
    kernel_includes = ['super.hh', 'gemm_qdq_int16x8/gemm_int16x8_wrapper.cc']

    gemm_shim_alloc = shim_alloc()

    # ---- Core buffer layout -------------------------------------------------
    # Every core buffer size is rounded with iceil(..., 128).
    align = 128
    ifm_bytes = iceil(dims.Msubv * dims.Ksubv * dims.ifm_bits // 8, align)
    wgt_bytes = iceil(dims.wgt_subv_bytes, align)
    out_bytes = iceil(dims.Msubv * dims.Nsubv * dims.out_bits // 8, align)

    stack_addr = overlay_stack_addr()

    ifm_addr = 0
    wgt_addr = ifm_addr + ifm_bytes
    out_addr = wgt_addr + wgt_bytes

    # Scratch areas placed after the three DMA buffers: 1536 bytes of spill,
    # then an IFM-sized temp buffer, then the coeff temp buffer (256 bytes,
    # per the message and the assertion below).
    spill_addr = out_addr + out_bytes
    ifm_tmp_addr = spill_addr + 1536
    coeff_tmp_addr = ifm_tmp_addr + ifm_bytes

    print(f'Core IFM Ping Address: {ifm_addr}, core IFM Size: {ifm_bytes}')
    print(f'Core WGT Ping Address: {wgt_addr}, core WGT Size: {wgt_bytes}')
    print(f'Core Out Ping Address: {out_addr}, core Out Size: {out_bytes}')
    print(f'Core Spill Buffer: {spill_addr}, core Spill Size: 1536')
    print(f'Core IFM Temp Buffer: {ifm_tmp_addr}, core IFM Temp Size: {ifm_bytes}')
    print(f'Core Coeff Temp Buffer: {coeff_tmp_addr}, core Coeff Temp Size: {256}')

    assert coeff_tmp_addr + 256 <= stack_addr, \
        f'Core stack address {stack_addr} is not sufficient for core buffers. '

    # ---- Core instruction stream ---------------------------------------------
    # A fresh DmaChannel object is created for every instruction.
    def ifm_channel():
        return DmaChannel(DmaDir.S2MM, 0)

    def wgt_channel():
        return DmaChannel(DmaDir.S2MM, 1)

    def out_channel():
        return DmaChannel(DmaDir.MM2S, 0)

    core_instrs = [
        ConfigBuffer(ifm_channel(), ifm_addr, None, ifm_bytes),
        ConfigBuffer(wgt_channel(), wgt_addr, None, wgt_bytes),
        ConfigBuffer(out_channel(), out_addr, None, out_bytes),
    ]
    for _ in range(dims.M_loop * dims.N_loop):
        # The first K_loop - 1 steps touch only the two input channels.
        for _ in range(dims.K_loop - 1):
            core_instrs += [
                AcqBuffer(ifm_channel(), disable=True),
                AcqBuffer(wgt_channel(), disable=True),
                RelBuffer(ifm_channel(), disable=True),
                RelBuffer(wgt_channel(), disable=True),
            ]
        # The last step of each group also acquires/releases the output channel.
        core_instrs += [
            AcqBuffer(ifm_channel(), disable=True),
            AcqBuffer(wgt_channel(), disable=True),
            AcqBuffer(out_channel(), disable=True),
            RelBuffer(ifm_channel(), disable=True),
            RelBuffer(wgt_channel(), disable=True),
            RelBuffer(out_channel(), disable=True),
        ]
    core_instrs.append(
        CallKernel(
            'run_gemm_int16x8',
            gen_aie4_gemm_params(0, dims, spill_addr, ifm_tmp_addr, coeff_tmp_addr),
        )
    )

    # ---- Memtile allocation ----------------------------------------------------
    mt_ifm_bytes = dims.M * dims.K * dims.ifm_bits // 8
    mt_wgt_bytes = dims.K_loop * dims.N_loop * dims.wgt_subv_bytes
    mt_out_bytes = dims.M * dims.N * dims.out_bits // 8
    mt_param_bytes = dims.aie_rows * dims.param_size

    # Buffers are placed back to back. The pong addresses are only printed;
    # every DataTransfer below is given just the ping address.
    mt_param_addr = 0
    mt_ifm_ping = mt_param_addr + mt_param_bytes
    mt_ifm_pong = mt_ifm_ping + mt_ifm_bytes
    mt_wgt_ping = mt_ifm_pong + mt_ifm_bytes
    mt_wgt_pong = mt_wgt_ping + mt_wgt_bytes
    mt_out_ping = mt_wgt_pong + mt_wgt_bytes
    mt_out_pong = mt_out_ping + mt_out_bytes
    print(f"Param Memtile Address: {mt_param_addr}, Size: {mt_param_bytes}")
    print(f"IFM Memtile Ping Address: {mt_ifm_ping}, Size: {mt_ifm_bytes}")
    print(f"IFM Memtile Pong Address: {mt_ifm_pong}, Size: {mt_ifm_bytes}")
    print(f"WGT Memtile Ping Address: {mt_wgt_ping}, Size: {mt_wgt_bytes}")
    print(f"WGT Memtile Pong Address: {mt_wgt_pong}, Size: {mt_wgt_bytes}")
    print(f"OUT Memtile Ping Address: {mt_out_ping}, Size: {mt_out_bytes}")
    print(f"OUT Memtile Pong Address: {mt_out_pong}, Size: {mt_out_bytes}")

    # ---- Shim allocation -------------------------------------------------------
    shim_ifm_bytes = dims.M * dims.K * dims.ifm_bits // 8
    shim_wgt_bytes = dims.K_loop * dims.N_loop * dims.wgt_subv_bytes
    shim_out_bytes = dims.M * dims.N * dims.out_bits // 8

    # ---- Memtile transfers (one per column and per tensor) --------------------
    # In each DataTransfer the fifth argument holds the S2MM transfer params
    # and the sixth the MM2S ones. The leading [1] is passed through as-is
    # (presumably a repeat count -- confirm against DataTransfer).
    def memtile_param_transfer(col):
        tile = AieTile(TileType.Memtile, col)
        s2mm = [generate_transfer_params(
            memtile_dma(col, DmaDir.S2MM, 0),
            param_memtile_memory(dims),
            param_memtile_s2mm(dims),
            dims.param_bits,
        )]
        # One MM2S entry per AIE row, all on MM2S channel 0.
        mm2s = [
            generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, 0),
                param_memtile_memory(dims),
                param_memtile_mm2s(dims, row),
                dims.param_bits,
            )
            for row in range(dims.aie_rows)
        ]
        return DataTransfer([1], tile, [mt_param_addr], mt_param_bytes, s2mm, mm2s)

    def memtile_ifm_transfer(col):
        tile = AieTile(TileType.Memtile, col)
        s2mm = [generate_transfer_params(
            memtile_dma(col, DmaDir.S2MM, 0),
            ifm_memtile_memory(dims),
            ifm_memtile_s2mm(dims),
            dims.ifm_bits
        )]
        # One MM2S entry per AIE row, on MM2S channel <row>.
        mm2s = [
            generate_transfer_params(
                memtile_dma(col, DmaDir.MM2S, row),
                ifm_memtile_memory(dims),
                ifm_memtile_mm2s(dims),
                dims.ifm_bits
            )
            for row in range(dims.aie_rows)
        ]
        return DataTransfer([1], tile, [mt_ifm_ping], mt_ifm_bytes, s2mm, mm2s)

    def memtile_wgt_transfer(col):
        tile = AieTile(TileType.Memtile, col)
        s2mm = [generate_transfer_params(
            memtile_dma(col, DmaDir.S2MM, 1),
            wgt_memtile_memory(dims),
            wgt_memtile_s2mm(dims),
            dims.wgt_bits
        )]
        mm2s = [generate_transfer_params(
            memtile_dma(col, DmaDir.MM2S, 4),
            wgt_memtile_memory(dims),
            wgt_memtile_mm2s(dims),
            dims.wgt_bits
        )]
        return DataTransfer([1], tile, [mt_wgt_ping], mt_wgt_bytes, s2mm, mm2s)

    def memtile_out_transfer(col):
        tile = AieTile(TileType.Memtile, col)
        # One S2MM entry per AIE row, on S2MM channel 2 + <row>.
        s2mm = [
            generate_transfer_params(
                memtile_dma(col, DmaDir.S2MM, 2 + row),
                out_memtile_memory(dims),
                out_memtile_s2mm(dims),
                dims.out_bits
            )
            for row in range(dims.aie_rows)
        ]
        mm2s = [generate_transfer_params(
            memtile_dma(col, DmaDir.MM2S, 5),
            out_memtile_memory(dims),
            out_memtile_mm2s(dims),
            dims.out_bits
        )]
        return DataTransfer([1], tile, [mt_out_ping], mt_out_bytes, s2mm, mm2s)

    memtile_transfers = []
    for make_transfer in (
        memtile_param_transfer,
        memtile_ifm_transfer,
        memtile_wgt_transfer,
        memtile_out_transfer,
    ):
        memtile_transfers += [make_transfer(col) for col in range(dims.aie_cols)]

    # ---- Shim transfers (one per column and per tensor) ------------------------
    def shim_param_transfer(col):
        tile = AieTile(TileType.Shim, col)
        buffer_ids = [gemm_shim_alloc.prm_buffer_id]
        mm2s = [generate_transfer_params(
            shim_dma(col, DmaDir.MM2S, 0),
            param_shim_memory(dims),
            param_shim_mm2s(dims, col),
            dims.param_bits,
        )]
        # The size is the per-memtile param size (aie_rows * param_size).
        return DataTransfer([1], tile, buffer_ids, mt_param_bytes, [], mm2s)

    def shim_ifm_transfer(col):
        tile = AieTile(TileType.Shim, col)
        buffer_ids = [gemm_shim_alloc.ifm_buffer_id]
        mm2s = [generate_transfer_params(
            shim_dma(col, DmaDir.MM2S, 0),
            ifm_shim_memory(dims),
            ifm_shim_mm2s(dims),
            dims.ifm_bits
        )]
        return DataTransfer([1], tile, buffer_ids, shim_ifm_bytes, [], mm2s)

    def shim_wgt_transfer(col):
        tile = AieTile(TileType.Shim, col)
        buffer_ids = [gemm_shim_alloc.wgt_buffer_id]
        mm2s = [generate_transfer_params(
            shim_dma(col, DmaDir.MM2S, 1),
            wgt_shim_memory(dims),
            wgt_shim_mm2s(dims),
            dims.wgt_bits
        )]
        return DataTransfer([1], tile, buffer_ids, shim_wgt_bytes, [], mm2s)

    def shim_out_transfer(col):
        tile = AieTile(TileType.Shim, col)
        buffer_ids = [gemm_shim_alloc.ofm_buffer_id]
        s2mm = [generate_transfer_params(
            shim_dma(col, DmaDir.S2MM, 0),
            out_shim_memory(dims),
            out_shim_s2mm(dims),
            dims.out_bits
        )]
        return DataTransfer([1], tile, buffer_ids, shim_out_bytes, s2mm, [])

    shim_transfers = []
    for make_transfer in (
        shim_param_transfer,
        shim_ifm_transfer,
        shim_wgt_transfer,
        shim_out_transfer,
    ):
        shim_transfers += [make_transfer(col) for col in range(dims.aie_cols)]

    # ---- Compile ---------------------------------------------------------------
    run_layer_compilation(
        OverlayShape(dims.aie_cols, dims.aie_rows),
        kernel_names,
        kernel_includes,
        core_instrs,
        memtile_transfers,
        shim_transfers,
        aie4_overlay_dma_connections(dims.aie_cols, dims.aie_rows),
        back_end=backend,
        core_stack_addr=overlay_stack_addr(),
        param_channel_id=0,
    )

def generate_gemm_dataflow(
    M: int,
    K: int,
    N: int,
    Msubv: int,
    Ksubv: int,
    Nsubv: int,
    sign_A: bool,
    sign_W: bool,
    sign_O: bool,
    shift_out: int,
    vector_coeff: int,
    back_end: BackEnd,
):
    """Validate the request, build a ``GemmDims`` and compile its dataflow.

    The overlay shape (1 x 1), the bit widths (16-bit activations and
    outputs, 8-bit weights, 32-bit c0/c1/c2) and the K granule (64) are
    fixed here; everything else comes from the caller.

    Args:
        M, K, N: full GEMM dimensions, forwarded to ``GemmDims``.
        Msubv, Ksubv, Nsubv: sub-volume dimensions, forwarded to ``GemmDims``.
            ``Ksubv`` must be at least four K granules (256).
        sign_A, sign_W, sign_O: sign flags, forwarded to ``GemmDims``.
        shift_out: forwarded to ``GemmDims``.
        vector_coeff: must be 0, 1 or 2; forwarded to ``GemmDims``.
        back_end: forwarded to ``compile_gemm_dataflow``.

    Raises:
        AssertionError: if ``vector_coeff`` is not 0, 1 or 2, or if
            ``Ksubv`` is smaller than four K granules.
    """
    # Fixed overlay shape.
    aie_rows, aie_cols = 1, 1
    # Fixed operand bit widths.
    act_bits, wgt_bits, out_bits = 16, 8, 16
    c0_bits = c1_bits = c2_bits = 32
    K_gran = 64

    assert vector_coeff in [0, 1, 2], \
        f'Invalid vector coefficient {vector_coeff}. Must be 0, 1, or 2.'
    assert Ksubv >= K_gran * 4, \
        f'Ksubv {Ksubv} must be at least {K_gran * 4} to ensure inner loop of atleast 4.'

    # GemmDims takes everything positionally; the groups below are
    # concatenated in exactly the order the constructor is called with.
    problem_shape = (M, K, N)
    subvol_shape = (Msubv, Ksubv, Nsubv)
    overlay_shape = (aie_rows, aie_cols)
    bit_widths = (act_bits, wgt_bits, c0_bits, c1_bits, c2_bits, out_bits)
    quant_opts = (sign_A, sign_W, sign_O, shift_out, vector_coeff)
    dims = GemmDims(
        *problem_shape,
        *subvol_shape,
        *overlay_shape,
        *bit_widths,
        K_gran,
        *quant_opts,
    )
    compile_gemm_dataflow(dims, back_end)

def main():
    """Compile one small GEMM configuration for the back end named on the command line.

    Usage: ``python <this file> <BackEnd member name>``.

    The previous version of this function could not run:
      * it called ``generate_gemm_dataflow`` without the required
        ``back_end`` argument (TypeError);
      * it passed ``Nsubv`` and ``Ksubv`` in swapped positions;
      * it used ``Ksubv = 64``, which fails the callee's
        ``Ksubv >= K_gran * 4`` (256) assertion.

    Raises:
        SystemExit: if the command line does not name exactly one public
            attribute of ``BackEnd``.
    """
    if len(sys.argv) != 2 or sys.argv[1].startswith('_'):
        raise SystemExit(f'usage: {sys.argv[0]} <BackEnd member name>')
    backend_name = sys.argv[1]
    # NOTE(review): assumes back ends are exposed as attributes of BackEnd
    # (e.g. Enum members) -- confirm against dmacompiler.
    back_end = getattr(BackEnd, backend_name, None)
    if back_end is None:
        raise SystemExit(f'unknown back end {backend_name!r}')

    M = 32
    # NOTE(review): K and Ksubv were both 64. They are raised to 256, the
    # smallest Ksubv generate_gemm_dataflow accepts, keeping K == Ksubv as
    # before -- confirm this is the intended test shape.
    K = 256
    N = 64

    Msubv = 32
    Ksubv = 256
    Nsubv = 64

    generate_gemm_dataflow(
        M, K, N,
        Msubv, Ksubv, Nsubv,
        False, False, False,
        0, 1,
        back_end,
    )
    
    
# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()


    
    