import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
from typing import Tuple, List, Union, Optional
from dataclasses import dataclass

from dmacompiler import \
    OverlayShape, BackEnd, \
    DataTransfer, SyncStrategy, \
    AieTile, TileType, \
    AieDma, DmaDir, TransferParams, core_dma, memtile_dma, shim_dma, \
    compute_buffer_size, \
    generate_core_buffer_config, \
    generate_transfer_params, \
    generate_shim_data_transfer, \
    run_layer_compilation, \
    set_dev_gen, DevGen, config
from conv.conv_common import \
    conv_input

"""Note:
    1. All shared common variables are defined here.
    2. Each dims class should include these variables; in particular, the K/S/P
       (kernel/stride/padding) fields should default to:
       1) Ky = Kx = 1
       2) Sy = Sx = 1
       3) Py_b = Px_b = Py_a = Px_a = 0
       4) is_X8_split = False
       5) etc.
"""
@dataclass
class CommonDims:
    """Bundle of dimension parameters shared by the slicing helpers below.

    Every field defaults to ``None``; callers fill in the values they need.
    Fields annotated ``Union[int, List[int]]`` accept either a single value or
    a list of values.

    NOTE(review): helper functions in this file also read ``dims.Y_loop`` and
    ``dims.Co_loop``; neither is declared here, so the object passed in must
    provide them some other way (e.g. a subclass) -- confirm against callers.
    """

    # --- AIE array shape ---
    aie_cols: Optional[int] = None
    aie_rows: Optional[int] = None

    # --- element width, in bits ---
    ifm_bits: Optional[int] = None
    ofm_bits: Optional[int] = None

    # --- original input size ---
    Ni: Optional[Union[int, List[int]]] = None
    Yi: Optional[Union[int, List[int]]] = None
    Xi: Optional[Union[int, List[int]]] = None
    Ci: Optional[Union[int, List[int]]] = None

    # --- original output size ---
    No: Optional[Union[int, List[int]]] = None
    Yo: Optional[Union[int, List[int]]] = None
    Xo: Optional[Union[int, List[int]]] = None
    Co: Optional[Union[int, List[int]]] = None

    # --- padded input size ---
    Nip: Optional[Union[int, List[int]]] = None
    Yip: Optional[Union[int, List[int]]] = None
    Xip: Optional[Union[int, List[int]]] = None
    Cip: Optional[Union[int, List[int]]] = None

    # --- padded output size ---
    Nop: Optional[Union[int, List[int]]] = None
    Yop: Optional[Union[int, List[int]]] = None
    Xop: Optional[Union[int, List[int]]] = None
    Cop: Optional[Union[int, List[int]]] = None

    # --- input subvolume size ---
    Nis: Optional[int] = None
    Yis: Optional[int] = None
    Xis: Optional[int] = None
    Cis: Optional[int] = None

    # --- output subvolume size ---
    Nos: Optional[int] = None
    Yos: Optional[int] = None
    Xos: Optional[int] = None
    Cos: Optional[int] = None

    # --- input subvolume granularity ---
    Ni_gran: Optional[int] = None
    Yi_gran: Optional[int] = None
    Xi_gran: Optional[int] = None
    Ci_gran: Optional[int] = None

    # --- output subvolume granularity ---
    No_gran: Optional[int] = None
    Yo_gran: Optional[int] = None
    Xo_gran: Optional[int] = None
    Co_gran: Optional[int] = None

    # --- kernel / filter size ---
    Ky: Optional[int] = None
    Kx: Optional[int] = None

    # --- stride ---
    Sy: Optional[int] = None
    Sx: Optional[int] = None

    # --- padding (the "_b" values are subtracted from slice starts,
    #     "Px_a" extends the upper X bound, in the slice helpers) ---
    Py_b: Optional[int] = None
    Px_b: Optional[int] = None
    Py_a: Optional[int] = None
    Px_a: Optional[int] = None

    # --- split mode across columns ---
    is_Y8_split: Optional[bool] = None
    is_X8_split: Optional[bool] = None

    
def Yi_slice_iter(dims: CommonDims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    """Compute the input-row (Y) slice for one column at one iteration.

    Args:
        dims: Object providing Yos, Sy, Yis, Yi, Py_b, aie_cols, is_X8_split.
        col: Column index.
        start_iter: Iteration index.

    Returns:
        ``(Yi_start, Yi_stop, Yi_stride, Yi_size)`` where

        * start  = col * (Yos * Sy) + start_iter * stride - Py_b (may be negative),
        * stop   = start + Yis while start <= Yi, otherwise equal to start,
        * stride = Yos * Sy, multiplied by aie_cols unless is_X8_split is set,
        * size   = number of rows of [start, stop) that lie inside [0, Yi].
    """
    rows_per_step = dims.Yos * dims.Sy
    if dims.is_X8_split:
        stride = rows_per_step
    else:
        stride = dims.aie_cols * rows_per_step

    start = (col * rows_per_step) + (start_iter * stride) - dims.Py_b
    if start <= dims.Yi:
        stop = start + dims.Yis
    else:
        # Slice begins past the input: collapse it to an empty range.
        stop = start

    def clamp(row: int) -> int:
        # Restrict a row coordinate to the valid input range [0, Yi].
        return max(0, min(row, dims.Yi))

    size = clamp(stop) - clamp(start)
    return (start, stop, stride, size)

def Xi_slice_iter(dims: CommonDims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    """Compute the input-column (X) slice for one column at one iteration.

    Args:
        dims: Object providing Xi and is_X8_split; in X8-split mode also
            Xos, Sx, Kx, Px_b, Px_a and aie_cols.
        col: Column index.
        start_iter: Iteration index.

    Returns:
        ``(Xi_start, Xi_stop, Xi_stride, Xi_size)``. Without X8 split the
        slice is the full width ``(0, Xi, Xi, Xi)``. With X8 split the start
        is offset by Px_b, the stop is limited to ``Xi + Px_a``, and the size
        counts only the part between 0 and Xi.
    """
    if not dims.is_X8_split:
        # No split along X: every column sees the whole input width.
        start = 0
        stop = dims.Xi
        return (start, stop, dims.Xi, stop - start)

    cols_per_step = dims.Xos * dims.Sx
    stride = dims.aie_cols * cols_per_step
    start = (col * cols_per_step) + (start_iter * stride) - dims.Px_b

    if start <= dims.Xi:
        # Window width comes from the project helper conv_input(Xos, Kx, Sx).
        window_end = start + conv_input(dims.Xos, dims.Kx, dims.Sx)
        stop = min(window_end, dims.Xi + dims.Px_a)
    else:
        stop = start

    # NOTE(review): when start lies beyond a non-negative Xi this evaluates to
    # Xi - start, i.e. a negative size, whereas Yi_slice_iter yields 0 in the
    # equivalent case -- confirm callers expect this.
    size = min(dims.Xi, stop) - max(0, start)
    return (start, stop, stride, size)

def Yo_slice_iter(dims: CommonDims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    """Compute the output-row (Y) slice for one column at one iteration.

    Args:
        dims: Object providing Yos, Yo, aie_cols and is_X8_split.
        col: Column index.
        start_iter: Iteration index.

    Returns:
        ``(Yo_start, Yo_stop, Yo_stride, Yo_size)``. The stride is Yos,
        multiplied by aie_cols unless is_X8_split is set. The stop is
        ``start + Yos`` limited to Yo; once start exceeds Yo the slice
        collapses to an empty range at start.
    """
    if dims.is_X8_split:
        stride = dims.Yos
    else:
        stride = dims.aie_cols * dims.Yos

    start = (col * dims.Yos) + (start_iter * stride)

    if start <= dims.Yo:
        stop = min(start + dims.Yos, dims.Yo)
    else:
        stop = start

    return (start, stop, stride, stop - start)

def Xo_slice_iter(dims: CommonDims, col: int, start_iter: int) -> Tuple[int, int, int, int]:
    """Compute the output-column (X) slice for one column at one iteration.

    Args:
        dims: Object providing Xo and is_X8_split; in X8-split mode also
            Xos and aie_cols.
        col: Column index.
        start_iter: Iteration index.

    Returns:
        ``(Xo_start, Xo_stop, Xo_stride, Xo_size)``. Without X8 split the
        slice is the full width ``(0, Xo, Xo, Xo)``. With X8 split the stride
        is ``aie_cols * Xos``, the stop is ``start + Xos`` limited to Xo, and
        a start beyond Xo yields an empty range at start.
    """
    if dims.is_X8_split:
        # The split flag is already known to be set in this branch, so it is
        # not re-tested on each expression.
        Xo_stride = dims.aie_cols * dims.Xos
        Xo_start = (col * dims.Xos) + (start_iter * Xo_stride)
        if Xo_start <= dims.Xo:
            Xo_stop = min(Xo_start + dims.Xos, dims.Xo)
        else:
            Xo_stop = Xo_start
    else:
        Xo_start = 0
        Xo_stop = dims.Xo
        Xo_stride = dims.Xo
    Xo_size = Xo_stop - Xo_start
    return (Xo_start, Xo_stop, Xo_stride, Xo_size)

def Yi_split_iters(dims: CommonDims, col: int, max_iter: int = 0) -> List[Tuple[int, int]]:
    """Group the Y iterations of one column into runs of consecutive iterations.

    An iteration whose input slice (from ``Yi_slice_iter``) starts below 0 or
    ends above ``dims.Yi`` always forms a run of length 1. Consecutive
    iterations whose slices stay inside ``[0, Yi]`` are merged into one run,
    limited to ``max_iter`` iterations per run when ``max_iter`` is non-zero.

    Args:
        dims: Object providing ``Y_loop`` and ``Yi`` plus everything
            ``Yi_slice_iter`` reads. ``Y_loop`` is not a declared field of
            ``CommonDims``, so the caller must supply it.
        col: Column index.
        max_iter: Maximum run length; 0 means no limit.

    Returns:
        List of ``(start_iter, num_iters)`` tuples in ascending order,
        beginning at iteration 0.
    """
    def fits_inside_input(iteration: int) -> bool:
        first_row, end_row, _, _ = Yi_slice_iter(dims, col, iteration)
        return (first_row >= 0) and (end_row <= dims.Yi)

    def run_may_have(first: int, count: int) -> bool:
        # Judged on the last iteration the run would contain.
        inside = fits_inside_input(first + count - 1)
        exceeds_cap = (max_iter != 0) and (count == max_iter + 1)
        return inside and not exceeds_cap

    runs = []
    first = 0
    while first < dims.Y_loop:
        count = 1
        if run_may_have(first, count):
            # NOTE(review): growth is not bounded by Y_loop here; it stops
            # only when a slice leaves [0, Yi] or the cap is reached.
            while run_may_have(first, count + 1):
                count += 1
        runs.append((first, count))
        first += count
    return runs

def Yo_split_iters(dims: CommonDims, col: int) -> List[Tuple[int, int]]:
    """Group the Y iterations of one column into runs of full-sized output slices.

    An iteration whose output slice (from ``Yo_slice_iter``) is not exactly
    ``dims.Yos`` rows tall forms a run of length 1. Consecutive iterations
    with full-sized slices are merged into one run.

    Args:
        dims: Object providing ``Y_loop`` and ``Yos`` plus everything
            ``Yo_slice_iter`` reads. ``Y_loop`` is not a declared field of
            ``CommonDims``, so the caller must supply it.
        col: Column index.

    Returns:
        List of ``(start_iter, num_iters)`` tuples in ascending order,
        beginning at iteration 0.
    """
    def slice_is_full(iteration: int) -> bool:
        _, _, _, rows = Yo_slice_iter(dims, col, iteration)
        return rows == dims.Yos

    runs = []
    first = 0
    while first < dims.Y_loop:
        count = 1
        if slice_is_full(first):
            # Absorb following iterations for as long as they are full-sized.
            # NOTE(review): this growth is not bounded by Y_loop.
            while slice_is_full(first + count):
                count += 1
        runs.append((first, count))
        first += count
    return runs


def ifm_chain_length(dims: CommonDims) -> int:
    """Pick the shortest IFM chain length that satisfies the reuse limit.

    Tries lengths 1 through 4 and returns the first one that divides
    ``dims.Co_loop`` evenly and leaves ``Co_loop // length`` no larger than
    ``63 // dims.aie_rows``.

    Args:
        dims: Object providing ``aie_rows`` and ``Co_loop``. ``Co_loop`` is
            not a declared field of ``CommonDims``, so the caller must
            supply it.

    Returns:
        The chosen chain length, between 1 and 4 inclusive.

    Raises:
        RuntimeError: If no length from 1 to 4 qualifies.
    """
    # NOTE(review): 2**6 - 1 is presumably a 6-bit hardware repeat limit
    # shared across the rows -- confirm against the DMA specification.
    max_ifm_reuse = (2**6 - 1) // dims.aie_rows
    max_chain_length = 4
    for length in range(1, max_chain_length + 1):
        divides_evenly = (dims.Co_loop % length == 0)
        within_limit = (dims.Co_loop // length <= max_ifm_reuse)
        if divides_evenly and within_limit:
            return length
    raise RuntimeError('Failed to allocate IFM chain!')

def wgt_chain_length(wgt_reuse: int) -> int:
    """Pick the shortest WGT chain length that satisfies the reuse limit.

    Tries lengths 1 through 4 and returns the first one that divides
    ``wgt_reuse`` evenly and leaves ``wgt_reuse // length`` no larger than 63.

    Args:
        wgt_reuse: Total weight reuse count to distribute over the chain.

    Returns:
        The chosen chain length, between 1 and 4 inclusive.

    Raises:
        RuntimeError: If no length from 1 to 4 qualifies.
    """
    reuse_limit = (1 << 6) - 1
    longest_chain = 4
    # Lazily enumerate qualifying lengths in ascending order; only the first
    # one is consumed.
    qualifying = (
        candidate
        for candidate in range(1, longest_chain + 1)
        if wgt_reuse % candidate == 0 and wgt_reuse // candidate <= reuse_limit
    )
    chosen = next(qualifying, None)
    if chosen is None:
        raise RuntimeError('Failed to allocate WGT chain!')
    return chosen
