import os
import sys
from typing import List

CURRDIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(CURRDIR, '..', 'dmacompiler'))

from dmacompiler import \
    DmaChannel, DmaDir, AieDma, \
    AieTile, TileType, DmaConnection

def overlay_stack_addr() -> int:
    """Return the overlay stack address: 120 * 1024 (122880)."""
    kib = 1024
    return 120 * kib

def overlay_stack_size() -> int:
    """Return the overlay stack size: 2 * 1024 (2048; presumably bytes)."""
    kib = 1024
    return 2 * kib

def overlay_heap_size() -> int:
    """Return the overlay heap size: 2 * 1024 (2048; presumably bytes)."""
    kib = 1024
    return 2 * kib

class ShimAlloc:
    """Plain record of the four shim-side buffer IDs.

    Each constructor argument is stored unchanged on an attribute of the
    same name: ``ifm_buffer_id``, ``wgt_buffer_id``, ``ofm_buffer_id`` and
    ``prm_buffer_id``. No validation is performed.
    """

    def __init__(
        self,
        ifm_buffer_id: int,
        wgt_buffer_id: int,
        ofm_buffer_id: int,
        prm_buffer_id: int,
    ) -> None:
        # Tuple assignment binds left-to-right, so attributes are created in
        # the same order as the parameters.
        (
            self.ifm_buffer_id,
            self.wgt_buffer_id,
            self.ofm_buffer_id,
            self.prm_buffer_id,
        ) = (ifm_buffer_id, wgt_buffer_id, ofm_buffer_id, prm_buffer_id)
        
def shim_alloc() -> ShimAlloc:
    """Return the fixed shim buffer-ID assignment used by this overlay.

    IDs: ofm -> 0, ifm -> 1, wgt -> 2, prm -> 3.
    """
    return ShimAlloc(
        ifm_buffer_id=1,
        wgt_buffer_id=2,
        ofm_buffer_id=0,
        prm_buffer_id=3,
    )

def aie4_overlay_dma_connections(NumCols: int, NumRows: int) -> List[DmaConnection]:
    """Build the fixed DMA routing table for the AIE4 overlay.

    Connections are emitted in six groups. Each group covers every column
    (and, where relevant, every core row r of that column) before the next
    group starts:

    1. Shim MM2S 0       -> Memtile S2MM 0
    2. Shim MM2S 1       -> Memtile S2MM 1
    3. Memtile MM2S r    -> Core(col, r) S2MM 0
    4. Memtile MM2S 4    -> Core(col, r) S2MM 1   (one source, every row)
    5. Core(col, r) MM2S 0 -> Memtile S2MM 2 + r
    6. Memtile MM2S 5    -> Shim S2MM 0

    Shim and memtile tiles are always addressed at row 0 of their column.

    Args:
        NumCols: number of array columns to wire up.
        NumRows: number of core rows per column.

    Returns:
        The list of ``DmaConnection`` objects, in the group order above.
    """
    # NOTE(review): memtile MM2S channels 4 and 5 are hard-coded, so with
    # NumRows > 4 the per-row channels (group 3) would reuse 4/5. The layout
    # was presumably designed for NumCols == 3, NumRows == 4 - confirm before
    # using other sizes.
    memtile_fanout_mm2s = 4
    memtile_to_shim_mm2s = 5
    memtile_core_s2mm_base = 2

    connections = []

    # Groups 1 and 2: two shim -> memtile streams per column, channel-matched.
    for shim_ch in (0, 1):
        for c in range(NumCols):
            connections.append(DmaConnection(
                AieDma(AieTile(TileType.Shim, c, 0), DmaChannel(DmaDir.MM2S, shim_ch)),
                AieDma(AieTile(TileType.Memtile, c, 0), DmaChannel(DmaDir.S2MM, shim_ch)),
            ))

    # Group 3: one dedicated memtile MM2S channel per core row.
    for c in range(NumCols):
        for r in range(NumRows):
            connections.append(DmaConnection(
                AieDma(AieTile(TileType.Memtile, c, 0), DmaChannel(DmaDir.MM2S, r)),
                AieDma(AieTile(TileType.Core, c, r), DmaChannel(DmaDir.S2MM, 0)),
            ))

    # Group 4: a single memtile MM2S channel feeding every core's S2MM 1.
    for c in range(NumCols):
        for r in range(NumRows):
            connections.append(DmaConnection(
                AieDma(AieTile(TileType.Memtile, c, 0), DmaChannel(DmaDir.MM2S, memtile_fanout_mm2s)),
                AieDma(AieTile(TileType.Core, c, r), DmaChannel(DmaDir.S2MM, 1)),
            ))

    # Group 5: each core row writes back into its own memtile S2MM channel.
    for c in range(NumCols):
        for r in range(NumRows):
            connections.append(DmaConnection(
                AieDma(AieTile(TileType.Core, c, r), DmaChannel(DmaDir.MM2S, 0)),
                AieDma(AieTile(TileType.Memtile, c, 0), DmaChannel(DmaDir.S2MM, memtile_core_s2mm_base + r)),
            ))

    # Group 6: memtile -> shim return path, one per column.
    for c in range(NumCols):
        connections.append(DmaConnection(
            AieDma(AieTile(TileType.Memtile, c, 0), DmaChannel(DmaDir.MM2S, memtile_to_shim_mm2s)),
            AieDma(AieTile(TileType.Shim, c, 0), DmaChannel(DmaDir.S2MM, 0)),
        ))

    return connections