import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'kernels', 'transpose'))

from typing import List, Tuple, Optional, Type

from dmacompiler import \
    AieDma, DmaDir, \
    TransferParams, generate_transfer_params

from transpose_common import \
    TransposeKernelDims, \
    iceil

from dataflow_common import \
    ceildiv

def pack_transfers(
    dma: AieDma,
    memory_fmts: List[str],
    tiling_fmts: List[str],
    tiling_iters: List[int],
    bits_per_elem: int,
) -> TransferParams:
    """Merge several (memory format, tiling format) transfers into one TransferParams.

    For every index ``i`` a TransferParams is generated from
    ``memory_fmts[i]`` / ``tiling_fmts[i]``. Its index-0 length, offset,
    step, wrap and padding entries are then repeated ``tiling_iters[i]``
    times. The repetitions are concatenated in order to form the returned
    TransferParams.

    Padding is requested from ``generate_transfer_params`` only when the
    DMA channel direction is ``DmaDir.MM2S``.
    """
    assert len(memory_fmts) == len(tiling_fmts)
    assert len(tiling_fmts) == len(tiling_iters)

    def expand(per_fmt: list) -> list:
        # Repeat the i-th entry tiling_iters[i] times, preserving order.
        assert len(per_fmt) == len(tiling_iters)
        return [
            value
            for value, repeat in zip(per_fmt, tiling_iters)
            for _ in range(repeat)
        ]

    param_list = []
    for idx, tiling_fmt in enumerate(tiling_fmts):
        param_list.append(
            generate_transfer_params(
                dma,
                memory_fmts[idx],
                tiling_fmt,
                bits_per_block=bits_per_elem,
                enable_padding=(dma.channel.dir == DmaDir.MM2S),
            )
        )

    lengths = expand([p.length_i(0) for p in param_list])
    offsets = expand([p.offset_i(0) for p in param_list])
    steps = expand([p.step_i(0) for p in param_list])
    wraps = expand([p.wrap_i(0) for p in param_list])
    paddings = expand([p.padding_i(0) for p in param_list])
    return TransferParams(
        dma,
        lengths,
        offset=offsets,
        step=steps,
        wrap=wraps,
        padding=paddings,
    )

def Yi_slice_stride(dims: TransposeKernelDims, col: int, n_iter: int, data_bits: int) -> Tuple[int, int, int]:
    """Return the Yi slice assigned to column ``col`` on iteration ``n_iter``.

    ``dims.Yip`` is divided into ``dims.Y_loop * dims.aie_cols`` slices using
    ``ceildiv``. The per-slice extent is then rounded up with ``iceil``:
    to ``32 // data_bits`` when ``dims.perm[3] == 1``, and to ``dims.Yi_gran``
    otherwise.

    Column ``col`` of iteration ``n_iter`` starts at
    ``col * split + n_iter * (aie_cols * split)``.

    Returns:
        ``(Yi_start, Yi_stop, Yi_size)``.

        - ``Yi_stop`` is not clipped. It equals ``Yi_start`` once
          ``Yi_start`` exceeds ``dims.Yip``.
        - ``Yi_size`` is the extent of ``[Yi_start, Yi_stop)`` clipped to
          ``[0, dims.Yip]``, so it is 0 for slices lying past the end.
    """
    Yi_split = ceildiv(dims.Yip, dims.Y_loop * dims.aie_cols)
    if dims.perm[3] == 1:
        # Granularity of 32 // data_bits elements (i.e. 32 bits worth of data).
        Yi_gran = 32 // data_bits
        Yi_split = iceil(Yi_split, Yi_gran)
    else:
        Yi_split = iceil(Yi_split, dims.Yi_gran)
    # One iteration covers aie_cols consecutive slices.
    Yi_stride = dims.aie_cols * Yi_split
    Yi_start = col * Yi_split + n_iter * Yi_stride
    Yi_stop = Yi_start + Yi_split if Yi_start <= dims.Yip else Yi_start
    Yi_size = max(0, min(Yi_stop, dims.Yip)) - max(0, min(Yi_start, dims.Yip))
    return (Yi_start, Yi_stop, Yi_size)

def Ni_slice_stride(dims: TransposeKernelDims, n: int) -> Tuple[int, int, int]:
    """Return ``(start, stop, size)`` of the ``n``-th N slice of width ``dims.Nis``.

    ``stop`` is not clipped. It collapses to ``start`` when ``start``
    exceeds ``dims.Nip``. ``size`` is the slice extent clipped to
    ``[0, dims.Nip]``.
    """
    width = dims.Nis
    begin = n * width
    end = begin if begin > dims.Nip else begin + width
    clipped_end = max(0, min(end, dims.Nip))
    clipped_begin = max(0, min(begin, dims.Nip))
    return (begin, end, clipped_end - clipped_begin)

def Ci_slice_stride(dims: TransposeKernelDims, n: int) -> Tuple[int, int, int]:
    """Return ``(start, stop, size)`` of the ``n``-th Ci chunk of width ``dims.Cim``.

    ``stop`` is capped at ``dims.Cip``. ``size`` is the chunk extent
    clipped to ``[0, dims.Cip]``, so it is 0 for chunks past the end.
    """
    chunk = dims.Cim
    limit = dims.Cip
    lo = n * chunk
    hi = min(lo + chunk, limit)
    lo_clip, hi_clip = (max(0, min(v, limit)) for v in (lo, hi))
    return (lo, hi, hi_clip - lo_clip)

def YiXi_slice(dims: TransposeKernelDims, n: int, is_Y: bool = True) -> Tuple[int, int, int]:
    """Return ``(start, stop, size)`` of input tile ``n`` along Y or X.

    With ``is_Y`` the tile width is ``dims.Yim`` and the bound is
    ``dims.Yip``. Otherwise ``dims.Xim`` / ``dims.Xip`` are used. ``stop``
    is capped at the bound, and ``size`` is clipped to ``[0, bound]``.
    """
    tile, bound = (dims.Yim, dims.Yip) if is_Y else (dims.Xim, dims.Xip)
    first = n * tile
    last = min(first + tile, bound)
    length = max(0, min(last, bound)) - max(0, min(first, bound))
    return (first, last, length)

def YoXo_slice(dims: TransposeKernelDims, n: int, is_Y: bool = True) -> Tuple[int, int, int]:
    """Return ``(start, stop, size)`` of output tile ``n`` along Y or X.

    With ``is_Y`` the tile width is ``dims.Yom`` and the bound is
    ``dims.Yop``. Otherwise ``dims.Xom`` / ``dims.Xop`` are used.

    On the final loop iteration (``n == dims.Y_loop - 1`` or
    ``n == dims.X_loop - 1``) the tile is extended by ``dims.num_padding``.
    ``stop`` is capped at the bound, and ``size`` is clipped to
    ``[0, bound]``.
    """
    final_iter = (dims.Y_loop if is_Y else dims.X_loop) - 1
    extra = dims.num_padding if n == final_iter else 0
    if is_Y:
        tile, bound = dims.Yom, dims.Yop
    else:
        tile, bound = dims.Xom, dims.Xop
    first = n * tile
    last = min(first + tile + extra, bound)
    length = max(0, min(last, bound)) - max(0, min(first, bound))
    return (first, last, length)


def Co_slice_stride(dims: TransposeKernelDims, n: int) -> Tuple[int, int, int]:
    """Return ``(start, stop, size)`` of the ``n``-th Co chunk.

    The chunk width is deliberately ``dims.Cim`` (not ``dims.Com``).
    ``stop`` is capped at ``dims.Cop``, and ``size`` is clipped to
    ``[0, dims.Cop]``.
    """
    width = dims.Cim  # NOTE: this is not a typo, use Cim instead of Com
    extent = dims.Cop
    begin = n * width
    end = min(begin + width, extent)
    size = max(0, min(end, extent)) - max(0, min(begin, extent))
    return begin, end, size

def YXC_slice_mt(
    dims: TransposeKernelDims,
    Y_size: int,
    C_size: int,
    row: int,
    s2mm=False,
    is_Y_split=True,
) -> Tuple[int, int, int, int, int, int, int, int, int, int, int, int]:
    """Compute the Y/X/C slice bounds for AIE row ``row``.

    Exactly one of Y or C is split across ``dims.aie_rows`` (chosen by
    ``is_Y_split``). The other dimension is taken whole: ``dims.Cim`` for C,
    or ``dims.Yim`` for Y. X is always taken whole from 0; when
    ``dims.perm[3] == 2`` its extent is rounded up with
    ``iceil(dims.Xip, 32 // dims.ifm_bits)``.

    Args:
        dims: kernel dimensions.
        Y_size: extent of Y available for this slice. Values <= 0 fall back
            to ``dims.Yi_gran``.
        C_size: extent of C available for this slice. Values <= 0 fall back
            to ``dims.Ci_gran``.
        row: AIE row index.
        s2mm: if True, the split window is never relocated when it starts
            past ``Y_size`` / ``C_size``.
        is_Y_split: split Y across rows when True, otherwise split C.

    Returns:
        ``(Y_start, Y_stop, X_start, X_stop, C_start, C_stop,
        Yo_size, Xo_size, Co_size, Yis, Xis, Cis)``
    """
    def Y_slice(dims: TransposeKernelDims,
                Y_size: int, row: int, s2mm=False, is_Y=True) -> Tuple[int, int, int, int]:
        # Splits Y (is_Y=True) or C (is_Y=False) across aie_rows.
        # Local names say "Y" for both cases.
        # Returns (start, stop, split * aie_rows, per-row split).
        if is_Y:
            Yis = ceildiv(dims.Yim, dims.aie_rows)
            Yis = iceil(Yis, 32 // dims.ifm_bits) if dims.perm[3] == 1 else Yis
        else:
            Yis = ceildiv(dims.Cim, dims.aie_rows)
        Y_split = Yis
        if not is_Y:
            Y_split = iceil(Y_split, dims.Ci_gran)
            Yis = Y_split
        Y_start = row * Y_split
        if s2mm:
            Y_stop = Y_start + Y_split
        elif Y_start >= Y_size:
            # This row starts beyond the available extent: restart at 0 and
            # use dims.Yis / dims.Cis as the stop (not the locally computed split).
            Y_start = 0
            Y_stop = dims.Yis if is_Y else dims.Cis
        else:
            Y_stop = Y_start + Y_split
        Yo_size = Y_split * dims.aie_rows
        return (Y_start, Y_stop, Yo_size, Yis)

    if Y_size <= 0:
        Y_size = dims.Yi_gran
    if C_size <= 0:
        C_size = dims.Ci_gran
    if is_Y_split:
        Y_start, Y_stop, Yo_size, Yis = Y_slice(dims, Y_size, row, s2mm)
        # C is taken whole.
        C_start = 0
        C_split = dims.Cim
        Co_size = C_split
        C_stop = C_split
        Cis = C_split
    else:
        # Y is taken whole; C is split across rows.
        Y_split = dims.Yim
        Y_start = 0
        Y_stop = Y_split
        Yo_size = Y_split
        C_start, C_stop, Co_size, Cis = Y_slice(dims, C_size, row, s2mm, is_Y=False)
        Yis = Yo_size
    X_start = 0
    if dims.perm[3] == 2:
        X_stop = iceil(dims.Xip, 32 // dims.ifm_bits)
    else:
        # Original note: "Y will be transposed to inner most and Y meet max(64, W8)".
        X_stop = dims.Xip
    Xo_size = X_stop
    Xis = Xo_size

    return (Y_start, Y_stop, X_start, X_stop, C_start, C_stop,
            Yo_size, Xo_size, Co_size, Yis, Xis, Cis)

def Xi_slice_stride_shim(dims: TransposeKernelDims, col: int, n_iter: int) -> Tuple[int, int, int]:
    """Return the Xi slice assigned to column ``col`` on iteration ``n_iter``.

    ``dims.Xip`` is divided into ``dims.X_loop * dims.aie_cols`` slices using
    ``ceildiv``. Column ``col`` of iteration ``n_iter`` starts at
    ``col * split + n_iter * (aie_cols * split)``.

    Returns:
        ``(Xi_start, Xi_stop, Xi_size)``.

        - ``Xi_stop`` is not clipped. It equals ``Xi_start`` once
          ``Xi_start`` exceeds ``dims.Xip``.
        - ``Xi_size`` is the slice extent clipped to ``[0, dims.Xip]``.
    """
    Xi_split = ceildiv(dims.Xip, dims.X_loop * dims.aie_cols)
    # One iteration covers aie_cols consecutive slices.
    Xi_stride = dims.aie_cols * Xi_split
    Xi_start = col * Xi_split + n_iter * Xi_stride
    Xi_stop = Xi_start + Xi_split if Xi_start <= dims.Xip else Xi_start
    Xi_size = max(0, min(Xi_stop, dims.Xip)) - max(0, min(Xi_start, dims.Xip))
    return (Xi_start, Xi_stop, Xi_size)

def Xi_slice_stride_mt(dims: TransposeKernelDims, col: int, row: int, n_iter: int, data_bits: int, s2mm=False) -> Tuple[int, int, int]:
    """Return the Xi window for AIE row ``row`` inside a column/iteration slice.

    The available extent ``X_size`` is the size of
    ``Xi_slice_stride_shim(dims, col, n_iter)``. A non-positive size falls
    back to ``dims.Xi_gran``. The window width is ``dims.Xis``; when
    ``dims.perm[3] == 2`` it is rounded up with ``iceil(.., 32 // data_bits)``.

    When ``s2mm`` is False and the row's window would start at or past
    ``X_size``, the window is shifted so that it ends exactly at ``X_size``.

    Returns:
        ``(Xi_start, Xi_stop, Xi_size)``, where ``Xi_stop == Xi_start + width``
        and ``Xi_size`` is the window clipped to ``[0, dims.Xip]``.
    """
    _, _, X_size = Xi_slice_stride_shim(dims, col, n_iter)
    if X_size <= 0:
        X_size = dims.Xi_gran
    Xi_split = dims.Xis
    if dims.perm[3] == 2:
        Xi_split = iceil(Xi_split, 32 // data_bits)
    else:
        # Granularity 1; kept as in the original call.
        Xi_split = iceil(Xi_split, 1)
    Xi_start = row * Xi_split
    if not s2mm and Xi_start >= X_size:
        # Shift the window back so it ends at X_size.
        # NOTE(review): this goes negative when Xi_split > X_size.
        Xi_start = X_size - Xi_split
    # The window is always Xi_split wide (the original's separate branch
    # assignments of Xi_stop were all equal to this and then overwritten).
    Xi_stop = Xi_start + Xi_split
    Xi_size = max(0, min(Xi_stop, dims.Xip)) - max(0, min(Xi_start, dims.Xip))
    return (Xi_start, Xi_stop, Xi_size)