import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..'))
from typing import Tuple, List

from conv_common_dwc import ceildiv, iceil, conv_input, conv_output


# Portion of the 65536-byte total subtracted from the data budget below.
# NOTE(review): presumably memory reserved for system use, judging by the
# name only -- confirm against the hardware documentation.
CORE_SYS_MEM_SIZE = 6144
# Upper bound (in bytes) that the summed per-core buffer sizes of a candidate
# subvolume must not exceed in enumerate_conv_subv_cases().
CORE_DATA_MEM_SIZE = 65536 - CORE_SYS_MEM_SIZE
# Granule passed to iceil() when sizing buffers (byte counts).
# NOTE(review): presumably the core vector-load alignment -- confirm.
CORE_VECTOR_ALIGN = 64

def filter_grid(grid: List[int], dim: int, split: int, require_divisible: bool = False) -> List[int]:
    """Prune candidate sizes in *grid* for a dimension of extent *dim*.

    A candidate ``g`` "covers" the dimension when ``g * split >= dim``.
    Candidates larger than the smallest covering one are dropped. When no
    candidate covers the dimension, *grid* itself is kept (the very same
    list object is returned if no further filtering is requested).

    With ``require_divisible=True`` a candidate additionally survives only
    if ``g * split`` divides *dim* evenly or already covers it.

    Args:
        grid: Candidate sizes, in the order they should be preserved.
        dim: Extent of the dimension being tiled.
        split: Number of ways the dimension is split across units.
        require_divisible: Enable the divisibility filter described above.

    Returns:
        The surviving candidates, in their original order.
    """
    covering = [point for point in grid if point * split >= dim]
    if covering:
        cap = min(covering)
        kept = [point for point in grid if point <= cap]
    else:
        kept = grid

    if not require_divisible:
        return kept

    survivors = []
    for point in kept:
        span = point * split
        # Modulo is evaluated first, exactly as in the combined condition.
        if (dim % span) == 0 or dim <= span:
            survivors.append(point)
    return survivors

def X8_exception(
    X_split_grid: list,
    Xo: int, Xi: int,
    X_gran: int, Ci_gran: int,
    aie_cols: int,
    Kx: int, Sx: int,
    ) -> list[bool]:
    """Flag the X-split candidates that must be excluded from the X8 split.

    Author's rationale (kept from the original note; the hardware details
    are not verifiable from this file): for some X shapes the X8 split may
    violate the memtile dimension limits. With
    ``X_remainder = Xi % Xos + Xi_overlap``:

      1) if the whole D0 and D1 (X_remainder * 8) are combined as D0, the
         padding lands in D0 as ``8 * (Xos - X_remainder)`` and only
         6 bits (64) are available;
      2) if D0 and part of D1 (e.g. X_remainder * 2) are combined as D0,
         the padding lands in D1 as
         ``8 * (Xos - X_remainder) / (X_remainder * 2)`` and the available
         5 bits (32) might fit.

    What the code checks: a candidate is flagged when
    ``Xi % Xos + (conv_input(Xos, Kx, Sx) - Xos) >= 64 // Ci_gran``, where
    ``Xos`` is derived from the per-column output width for that candidate.

    Args:
        X_split_grid: Candidate X split factors.
        Xo: Full output width.
        Xi: Full input width.
        X_gran: Granule passed to ``iceil`` for X sizes.
        Ci_gran: Input-channel granule; sets the threshold ``64 // Ci_gran``.
        aie_cols: Number of columns the output width is divided across.
        Kx: Kernel width.
        Sx: Stride along X.

    Returns:
        One bool per entry of *X_split_grid*; ``True`` means "exclude".
    """
    if not X_split_grid:
        return []

    # Pad the output width, then take the share handled by one column.
    padded_width = iceil(iceil(Xo, aie_cols), X_gran)
    width_per_col = padded_width // aie_cols

    flags = []
    for X_split in X_split_grid:
        Xos = iceil(ceildiv(width_per_col, X_split), X_gran)
        overlap = conv_input(Xos, Kx, Sx) - Xos
        flags.append((Xi % Xos + overlap) >= (64 // Ci_gran))
    return flags

def enumerate_conv_subv_cases(
    aie_cols: int, aie_rows: int,
    input: Tuple[int, int, int],
    output: Tuple[int, int, int],
    kernel: Tuple[int, int],
    stride: Tuple[int, int],
    pad: Tuple[int, int, int, int],
    ifm_bits: int, wgt_bits: int, bias_bits: int, ofm_bits: int, tdm_bits: int,
    Ci_gran: int, Co_gran: int, X_gran: int,
    inner_loop_range: int, outer_loop_range: int,
    is_X8_split: bool = False,
) -> List[Tuple[Tuple[int, int, int], Tuple[int, int, int], int, int, bool, bool]]:
    """Enumerate every subvolume tiling of a convolution that passes the size checks.

    Candidate subvolume sizes are generated for the Y, Ci and Co dimensions
    and for each X split factor. A candidate is kept when its per-core
    buffers fit in ``CORE_DATA_MEM_SIZE`` and the mode-specific checks
    (memtile footprint / loop limits for the X8 split, shim iteration steps
    otherwise) pass.

    Args:
        aie_cols: Number of array columns.
        aie_rows: Number of array rows.
        input: ``(Ci, Yi, Xi)`` input shape.
        output: ``(Co, Yo, Xo)`` output shape.
        kernel: ``(Ky, Kx)`` kernel size.
        stride: ``(Sy, Sx)`` strides.
        pad: ``(Py_b, Px_b, Py_a, Px_a)`` padding before/after in Y and X.
        ifm_bits, wgt_bits, bias_bits, ofm_bits, tdm_bits: Element widths in
            bits of the respective buffers.
        Ci_gran, Co_gran, X_gran: Granules for the Ci, Co and X sizes.
        inner_loop_range: Used to derive the minimum Ci subvolume size.
        outer_loop_range: Used to derive the minimum Co subvolume size.
        is_X8_split: Accepted for interface compatibility only; the value is
            always recomputed from the output shape and the argument is
            ignored.

    Returns:
        A list of tuples
        ``((Cis, Yis, Xis), (Cos, Yos, Xos), X_split, Co_split,
        is_X8_split, ifm_streaming_mode)``, one per valid candidate, in
        enumeration order.

    Raises:
        AssertionError: If ``Yo``/``Xo`` disagree with ``conv_output`` for
            the given input, kernel, stride and padding.
    """
    Ci, Yi, Xi = input
    Co, Yo, Xo = output
    Ky, Kx = kernel
    Sy, Sx = stride
    Py_b, Px_b, Py_a, Px_a = pad
    assert Yo == conv_output(Yi, Ky, Sy, Py_b, Py_a)
    assert Xo == conv_output(Xi, Kx, Sx, Px_b, Px_a)
    Ci_block = (Ci_gran * ifm_bits) // 8
    Cis_min = ceildiv(inner_loop_range, Ky * Kx) * Ci_gran
    Cos_min = outer_loop_range * Co_gran

    # The X8 split is attempted only for wide outputs (or moderately wide
    # outputs with many output channels); this overrides the parameter.
    if (Xo >= 128) or (Xo >= 64 and Co >= 1024):
        is_X8_split = True
        X_split_grid = [1, 2]
    else:
        is_X8_split = False
        X_split_grid = []

    X8_split_exception = X8_exception(X_split_grid,
                                      Xo, Xi, X_gran, Ci_gran,
                                      aie_cols,
                                      Kx, Sx,
                                      )
    if is_X8_split:
        # Drop the split factors flagged by X8_exception; if none survive,
        # fall back to the regular (non-X8) enumeration.
        filtered_X_split_grid = [x for x, flag in zip(X_split_grid, X8_split_exception) if not flag]
        if filtered_X_split_grid:
            X_split_grid = filtered_X_split_grid
        else:
            is_X8_split = False
    if is_X8_split:
        # From here on Xo is the padded output width handled by one column.
        Wo = iceil(iceil(Xo, aie_cols), X_gran)
        Xo = Wo // aie_cols
    else:
        X_split_grid = [d for d in range(1, aie_rows + 1) if (aie_rows % d) == 0]
    Yos_grid = filter_grid(list(range(ceildiv(Yo, aie_cols), 0, -1)), Yo, aie_cols)
    Cis_grid = filter_grid(list(range(max(Cis_min, iceil(Ci, Ci_gran)), Cis_min - 1, -Ci_gran)), Ci, 1,
                           require_divisible=True)

    # NOTE: We have a special case for a16w8 with QdQ
    # to account for extra parameters and IFM sum scratchpad space
    is_a16w8_qdq = (
        (ifm_bits == 16) and
        (wgt_bits == 8) and
        (bias_bits == 64) and
        (ofm_bits == 16) and
        (tdm_bits == 32)
    )
    # Defined for every configuration: the weight-subvolume size below reads
    # it unconditionally, so leaving it unset outside the QdQ case raised
    # UnboundLocalError. Non-QdQ configurations carry no extra parameters.
    qdq_param_bytes = 5 * 4 if is_a16w8_qdq else 0

    subv_splits = []
    for X_split in X_split_grid:
        Co_split = aie_rows if is_X8_split else aie_rows // X_split
        Cos_grid = filter_grid(list(range(max(Cos_min, iceil(Co, Co_gran)), Cos_min - 1, -Co_gran)), Co, Co_split,
                               require_divisible=True)
        for Yos in Yos_grid:
            for Cos in Cos_grid:
                for Cis in Cis_grid:
                    Yis = conv_input(Yos, Ky, Sy)
                    Xos = iceil(ceildiv(Xo, X_split), X_gran)
                    Xis = iceil(conv_input(Xos, Kx, Sx) * Ci_block, CORE_VECTOR_ALIGN) // Ci_block
                    ifm_size = (Cis * Yis * Xis * ifm_bits) // 8
                    wgt_size = (Cos * Cis * Ky * Kx * wgt_bits) // 8
                    bias_size = (Cos * bias_bits) // 8
                    ofm_size = (Cos * Yos * Xos * ofm_bits) // 8
                    # TDM is not needed when Ci_loop=1
                    tdm_size = (Cos * Yos * Xos * tdm_bits) // 8
                    ifm_sum_size = 0
                    tmp_buf_size = 0

                    if is_a16w8_qdq:
                        wgt_size = iceil((Cos * Cis * Ky * Kx * wgt_bits) // 8, CORE_VECTOR_ALIGN)
                        bias_size = iceil(((Cos * bias_bits) // 8) + qdq_param_bytes, CORE_VECTOR_ALIGN)
                        ifm_sum_size = iceil(max(128, iceil(Xis * Yis, 64), Yis * 8), 64) * tdm_bits // 8
                        # NOTE(review): unlike ifm_sum_size above, `Yis * 8`
                        # is passed here as the iceil granule rather than as
                        # a max() operand -- confirm this is intended.
                        tmp_buf_size = iceil(max(128, iceil(Xis * Yis, CORE_VECTOR_ALIGN)), Yis * 8) * tdm_bits // 8
                        tmp_buf_size = iceil(tmp_buf_size, CORE_VECTOR_ALIGN)

                    # ifm, wgt, bias and tdm are counted twice, ofm once.
                    total_size = (
                        (ifm_size * 2) +
                        (wgt_size * 2) +
                        (bias_size * 2) +
                        ofm_size +
                        (tdm_size * 2) +
                        ifm_sum_size +
                        tmp_buf_size
                    )

                    # Check shim.mm2s transfer for wgt:
                    #   1. each wgt subv is defined as wgt_subv_size
                    #   2. the subv loop in the Ci dim is Ci_wgtsubv_loop
                    #   3. the subv loop in the Co dim is Co_wgtsubv_loop
                    #   4. the split across columns is on Co_wgtsubv_loop
                    #      1) so the basic shim transfer size is Co_iter
                    #         (in 32-bit words)
                    #      2) the column size will be 4 (for 8x4 and 4x4)
                    #      3) the iteration step, if used, is shim_iter
                    #      4) or the last step size is shim_iter when
                    #         iteration is not used
                    #      5) in any case shim_iter must fit in 20 bits
                    #         (2^20 = 1M)
                    wgt_subv_size = (
                        iceil((Cos * Cis * Ky * Kx * wgt_bits) // 8, 64) +
                        iceil(((Cos * bias_bits) // 8) + qdq_param_bytes, 64)
                    )

                    Co_mt_pack = ceildiv(Co, (Co_split * Cos)) if is_X8_split else 1
                    mt_Xi_size = Xis if is_X8_split else Xi
                    total_ifm_reuse_memtile_size = (
                            4096 +                                # param size
                            (Yis * mt_Xi_size * Ci * ifm_bits ) // 8 +   # ifm shard size
                            (wgt_subv_size * 2) +                 # wgt size
                            Co_mt_pack * (Yos * Xos * X_split * Cos * Co_split * ofm_bits) // 8 # ofm size
                        )

                    # 512 KiB is the maximum memtile size (per the original
                    # author's note); larger footprints need ifm streaming.
                    ifm_streaming_mode = total_ifm_reuse_memtile_size > 512 * 1024

                    # For the X8 split we currently only use the ifm
                    # non-streaming mode, to maximize ifm reuse in memtile.
                    if is_X8_split:
                        Y_loop_valid = ceildiv(Yo, Yos) <= 64 * 8
                        Co_loop_valid = Co_mt_pack <= 4 # the max queue limit
                        is_X8_split_valid = not ifm_streaming_mode and Y_loop_valid and Co_loop_valid
                        shim_iter_valid = True
                    else:
                        Ci_wgtsubv_loop = ceildiv(Ci, Cis)
                        Co_wgtsubv_loop = ceildiv(Co, (Co_split * Cos)) * Co_split

                        Co_iter = Ci_wgtsubv_loop * wgt_subv_size // 4 # words
                        avail_broadcast_cols = aie_cols // 2 if aie_cols > 4 else aie_cols
                        if Co_wgtsubv_loop > avail_broadcast_cols:
                            shim_iter = Co_iter * avail_broadcast_cols
                        else:
                            shim_iter = 0
                        shim_iter_valid = (shim_iter <= 2**20)
                        shim_ofm_iter_step = Yos * Xo * Co * ofm_bits * aie_cols // 8 // 4
                        shim_ofm_iter_valid = shim_ofm_iter_step < 2 ** 20
                        shim_iter_valid = shim_iter_valid and shim_ofm_iter_valid

                        is_X8_split_valid = True
                    # With an X stride of 2, Cos must be a multiple of
                    # twice the Co granule.
                    Cos_valid = (Cos % (Co_gran * 2) == 0) if Sx == 2 else True
                    is_valid = (total_size <= CORE_DATA_MEM_SIZE) and shim_iter_valid and is_X8_split_valid and Cos_valid
                    if is_valid:
                        subv_splits.append((
                            (Cis, Yis, Xis),
                            (Cos, Yos, Xos),
                            X_split,
                            Co_split,
                            is_X8_split,
                            ifm_streaming_mode,
                        ))
    return subv_splits

def sort_conv_subv_cost(
    aie_cols: int, aie_rows:int,
    input: Tuple[int, int, int],
    output: Tuple[int, int, int],
    kernel: Tuple[int, int],
    stride: Tuple[int, int],
    pad: Tuple[int, int, int, int],
    subv_splits: List[Tuple[Tuple[int, int, int], Tuple[int, int, int], int, int, bool, bool]],
) -> List[Tuple[Tuple[int, int, int], Tuple[int, int, int], int, int, bool, bool]]:
    """Return *subv_splits* ordered from most to least preferred.

    Each element is the 6-field tuple
    ``((Cis, Yis, Xis), (Cos, Yos, Xos), X_split, Co_split, is_X8_split,
    ifm_streaming_mode)`` and is passed through unchanged; only the order
    differs. The sort is stable, so ties keep their input order.

    Ordering (ascending, compared left to right):
      * X8 split entries: ``(X_split, ifm_streaming_mode, loop_count,
        overcompute_ratio, Ci_loop)``
      * other entries: ``(ifm_streaming_mode, loop_count,
        overcompute_ratio, Ci_loop)``

    Args:
        aie_cols: Number of array columns (multiplies the Y tiling).
        aie_rows: Number of array rows (not used by the cost itself).
        input: ``(Ci, Yi, Xi)`` input shape.
        output: ``(Co, Yo, Xo)`` output shape.
        kernel: ``(Ky, Kx)`` kernel size.
        stride: ``(Sy, Sx)`` strides.
        pad: ``(Py_b, Px_b, Py_a, Px_a)`` padding.
        subv_splits: Candidate tilings to rank.

    Returns:
        A new list with the same elements as *subv_splits*, sorted.

    Raises:
        AssertionError: If ``Yo``/``Xo`` disagree with ``conv_output``.
    """
    Ci, Yi, Xi = input
    Co, Yo, Xo = output
    Ky, Kx = kernel
    Sy, Sx = stride
    Py_b, Px_b, Py_a, Px_a = pad
    assert Yo == conv_output(Yi, Ky, Sy, Py_b, Py_a)
    assert Xo == conv_output(Xi, Kx, Sx, Px_b, Px_a)

    # Multiply-accumulate count of the exact problem shape; the baseline
    # against which padded (tiled) work is compared.
    shape_ops = Ci * Ky * Kx * Co * Yo * Xo

    def split_key(split):
        (Cis, _, _), (Cos, Yos, Xos), X_split, Co_split, is_X8_split, ifm_streaming_mode = split
        Y_loop = ceildiv(Yo, (Yos * aie_cols))
        Co_loop = ceildiv(Co, (Cos * Co_split))
        Ci_loop = ceildiv(Ci, Cis)
        # Work actually performed once every dimension is rounded up to a
        # whole number of subvolumes.
        compute_ops = (
            (Ci_loop * Cis) * Ky * Kx *
            (Co_loop * Co_split * Cos) *
            (Y_loop * aie_cols * Yos) *
            (X_split * Xos)
        )
        loop_count = Y_loop * Co_loop * Ci_loop
        overcompute_ratio = compute_ops / shape_ops
        if is_X8_split:
            return (X_split, ifm_streaming_mode, loop_count, overcompute_ratio, Ci_loop)
        return (ifm_streaming_mode, loop_count, overcompute_ratio, Ci_loop)

    return sorted(subv_splits, key=split_key)

def conv_subv_split_mode(
    aie_cols:int, aie_rows:int,
    input: Tuple[int, int, int],
    output: Tuple[int, int, int],
    kernel: Tuple[int, int],
    stride: Tuple[int, int],
    pad: Tuple[int, int, int, int],
    ifm_bits: int, wgt_bits: int, bias_bits: int, ofm_bits: int, tdm_bits: int,
    Ci_gran: int, Co_gran: int, X_gran: int,
    inner_loop_range: int, outer_loop_range: int,
):
    """Pick the preferred subvolume tiling for a convolution.

    Enumerates every valid tiling with ``enumerate_conv_subv_cases`` and
    returns the first entry after ranking with ``sort_conv_subv_cost``.

    Args:
        aie_cols, aie_rows: Array dimensions.
        input: ``(Ci, Yi, Xi)`` input shape.
        output: ``(Co, Yo, Xo)`` output shape.
        kernel: ``(Ky, Kx)`` kernel size.
        stride: ``(Sy, Sx)`` strides.
        pad: ``(Py_b, Px_b, Py_a, Px_a)`` padding.
        ifm_bits, wgt_bits, bias_bits, ofm_bits, tdm_bits: Element widths
            in bits.
        Ci_gran, Co_gran, X_gran: Size granules.
        inner_loop_range, outer_loop_range: Forwarded to the enumeration.

    Returns:
        The best-ranked tuple
        ``((Cis, Yis, Xis), (Cos, Yos, Xos), X_split, Co_split,
        is_X8_split, ifm_streaming_mode)``.

    Raises:
        IndexError: If no candidate tiling passes the validity checks.
    """
    subv_splits = enumerate_conv_subv_cases(
        aie_cols, aie_rows,
        input, output, kernel, stride, pad,
        ifm_bits, wgt_bits, bias_bits, ofm_bits, tdm_bits,
        Ci_gran, Co_gran, X_gran,
        inner_loop_range, outer_loop_range,
    )
    sorted_subv_splits = sort_conv_subv_cost(
        aie_cols, aie_rows,
        input, output, kernel, stride, pad,
        subv_splits,
    )
    if not sorted_subv_splits:
        # Same exception type as indexing an empty list, but with a message
        # that identifies the failing problem instead of "index out of range".
        raise IndexError(
            f"no valid conv subvolume split for input={input}, output={output}, "
            f"kernel={kernel}, stride={stride}, pad={pad}"
        )
    return sorted_subv_splits[0]

def main():
    """Run the split selection on a sample 3x3 convolution and print the result."""
    # conv_subv_split_mode returns one 6-field tuple; it must be bound as a
    # whole, since unpacking it into two names raises ValueError.
    subv_split = conv_subv_split_mode(
        8, 4,
        (128, 32, 32),
        (256, 32, 32),
        (3, 3),
        (1, 1),
        (1, 1, 1, 1),
        16, 8, 16, 16, 32,
        8, 16, 8,
        8, 1,
    )
    print(subv_split)

if __name__ == '__main__':
    main()
