import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..'))
from typing import Tuple, List

from pooling_common import ceildiv, iceil, pooling_input, pooling_output


# Per-core local memory budget, in bytes: 64 KiB in total, of which
# CORE_SYS_MEM_SIZE bytes are set aside (presumably for system/stack use --
# TODO confirm) and the remainder is what subvolume buffers may occupy.
CORE_SYS_MEM_SIZE = 6 * 1024
CORE_DATA_MEM_SIZE = 64 * 1024 - CORE_SYS_MEM_SIZE
# Byte alignment applied to the core input width (Xis * Ci_block).
CORE_VECTOR_ALIGN = 64

def filter_grid(grid: List[int], dim: int, split: int, require_divisible: bool = False) -> List[int]:
    """Prune candidate tile sizes for a dimension of size ``dim`` split ``split`` ways.

    A candidate ``g`` covers the dimension when ``g * split >= dim``. If at
    least one candidate covers it, only candidates no larger than the smallest
    covering one are kept (anything larger only adds padding). If none covers
    it, ``grid`` is returned unchanged (the same list object).

    With ``require_divisible``, a kept candidate must additionally either tile
    ``dim`` exactly (``dim % (g * split) == 0``) or cover it in one step
    (``dim <= g * split``).

    The relative order of ``grid`` is preserved.
    """
    covering = [g for g in grid if g * split >= dim]
    if covering:
        # Compute the bound once; calling min() inside the comprehension
        # re-scanned ``covering`` for every element of ``grid``.
        smallest_cover = min(covering)
        filtered_grid = [g for g in grid if g <= smallest_cover]
    else:
        filtered_grid = grid
    if require_divisible:
        filtered_grid = [g for g in filtered_grid
                         if ((dim % (g * split)) == 0) or (dim <= (g * split))]
    return filtered_grid

def X8_exception(
    X_split_grid: list,
    Xo: int, Xi: int,
    X_gran: int, Ci_gran: int,
    aie_cols: int,
    Kx: int, Sx: int,
    ) -> list[bool]:
    """Flag X split factors that must be excluded from the X8 split path.

    For every candidate in ``X_split_grid`` the per-core output width ``Xos``
    is derived from the per-column output width (``Xo`` padded up to a
    multiple of ``aie_cols`` and then of ``X_gran``, divided by ``aie_cols``).
    The leftover input width is ``Xi % Xos`` plus the kernel overlap
    ``pooling_input(Xos, Kx, Sx) - Xos``; the candidate is flagged when that
    leftover reaches ``64 // Ci_gran``.

    Background: the X8 split may violate the memtile dims for some X shapes.
    With ``X_remainder = Xi % Xos + Xi_overlap``:
      1) if all of D0 and D1 (X_remainder * 8) are combined into D0, the
         padding lands in D0 as 8 * (Xos - X_remainder), and only 6 bits
         (64) are available;
      2) if D0 is combined with only part of D1 (e.g. X_remainder * 2), the
         padding lands in D1 as 8 * (Xos - X_remainder) / (X_remainder * 2),
         and the available 5 bits (32) might fit.

    Returns:
        A list of bools aligned with ``X_split_grid``; True marks a split to
        exclude. An empty grid yields an empty list.
    """
    if not X_split_grid:
        return []
    # Output width handled by a single column.
    Xo_col = iceil(iceil(Xo, aie_cols), X_gran) // aie_cols
    flags = []
    for X_split in X_split_grid:
        Xos = iceil(ceildiv(Xo_col, X_split), X_gran)
        halo = pooling_input(Xos, Kx, Sx) - Xos
        leftover = Xi % Xos + halo
        flags.append(leftover >= 64 // Ci_gran)
    return flags

def Yi_padding_verify(aie_cols: int,
                     Yi: int,
                     Sy: int,
                     Yos: int,
                     Yis: int,
                     Py_b: int) -> bool:
    """Check that no per-column input row window needs 16 or more padded rows.

    Output rows are tiled ``Yos`` per column across ``aie_cols`` columns.
    Column ``col`` of pass ``k`` reads ``Yis`` input rows starting at
    ``col * Yos * Sy + k * aie_cols * Yos * Sy - Py_b``. The rows of a window
    that fall outside ``[0, Yi)`` count as padding.

    Windows are visited pass by pass until, after a full pass, the last
    column's window starts at or beyond ``Yi``. The function returns False as
    soon as any window has ``>= 16`` padded rows.

    NOTE(review): because the walk only stops after a column window starting
    at or past ``Yi`` (zero valid rows), this returns False whenever
    ``Yis >= 16`` and ``Yi > 0``, and True whenever ``Yis < 16``. Confirm
    this is intended, or bound the walk by the real tile count derived from
    Yo.

    Returns:
        True if every visited window has fewer than 16 padded rows.

    Raises:
        ValueError: if ``Yi > 0`` but ``aie_cols`` or ``Yos * Sy`` is not
            positive (the walk would never terminate).
    """
    if Yi <= 0:
        return True
    Yi_split = Yos * Sy              # input rows advanced per column
    Yi_stride = aie_cols * Yi_split  # input rows advanced per full pass
    if aie_cols <= 0 or Yi_split <= 0:
        raise ValueError(
            f"Yi_padding_verify needs aie_cols > 0 and Yos * Sy > 0, "
            f"got aie_cols={aie_cols}, Yos={Yos}, Sy={Sy}"
        )
    Yi_start = 0
    start_iter = 0
    while Yi_start < Yi:
        for col in range(aie_cols):
            Yi_start = (col * Yi_split) + (start_iter * Yi_stride) - Py_b
            Yi_stop = Yi_start + Yis if Yi_start <= Yi else Yi_start
            Yi_size = max(0, min(Yi_stop, Yi)) - max(0, min(Yi_start, Yi))
            if Yis - Yi_size >= 16:
                return False
        # Advance once per pass over all columns. Incrementing per column
        # double-counted the column offset already folded into Yi_start.
        start_iter += 1
    return True

def enumerate_pooling_subv_cases(
    aie_cols: int, aie_rows: int,
    input: Tuple[int, int, int, int],
    output: Tuple[int, int, int, int],
    kernel: Tuple[int, int],
    stride: Tuple[int, int],
    pad: Tuple[int, int, int, int],
    ifm_bits: int, wgt_subv_size: int, ofm_bits: int,
    has_scratch_buf: int, scratch_buf_bits: int,
    Ci_gran: int, Co_gran: int, X_gran: int,
    is_X8_split: bool = False,
) -> List[Tuple[Tuple[int, int, int, int], Tuple[int, int, int, int], Tuple[int, int],
                int, int, int, bool, bool, List[int], List[int]]]:
    """Enumerate pooling subvolume splits that fit core and memtile limits.

    Args:
        aie_cols, aie_rows: AIE array dimensions.
        input: (Ni, Yi, Xi, Ci) input shape.
        output: (No, Yo, Xo, Co) output shape. It must equal the pooling
            output of ``input`` (asserted), and Ci must equal Co.
        kernel: (Ky, Kx). stride: (Sy, Sx). pad: (Py_b, Px_b, Py_a, Px_a).
        ifm_bits, ofm_bits: element widths in bits.
        wgt_subv_size: size in bytes reserved for the weight/qdq subvolume
            (added directly to the byte footprints below).
        has_scratch_buf, scratch_buf_bits: whether a per-core scratch buffer
            shaped like the ofm subvolume is allocated, and its element width.
        Ci_gran, Co_gran, X_gran: channel and width granularities.
        is_X8_split: currently ignored. The X8 split path is disabled and
            the flag is forced to False below.

    Returns:
        A list of candidate splits, each a 10-tuple:
        ((Nis, Yis, Xis, Cis), (Nos, Yos, Xos, Cos), (Nim, Nom),
         X_split, N_row_split, Co_split, is_X8_split, ifm_streaming_mode,
         spatial_split_mode, row_split_mode). The last two are 4-entry lists.
    """
    Ni, Yi, Xi, Ci = input
    No, Yo, Xo, Co = output
    Ky, Kx = kernel
    Sy, Sx = stride
    Py_b, Px_b, Py_a, Px_a = pad
    assert Yo == pooling_output(Yi, Ky, Sy, Py_b, Py_a)
    assert Xo == pooling_output(Xi, Kx, Sx, Px_b, Px_a)
    # Bytes occupied by one Ci_gran-wide channel group of a single input pixel.
    Ci_block = (Ci_gran * ifm_bits) // 8
    Cos_min = 1 * Co_gran

    # Apply the 8-way spatial split to N when the batch exceeds the height,
    # otherwise to Y.
    # NOTE(review): the factor 8 is hard-coded, not derived from aie_cols.
    spatial_split_mode = [8, 1, 1, 1] if Ni > Yi else [1, 8, 1, 1]
    N_sptial_split = ceildiv(Ni, spatial_split_mode[0])
    # Split N across rows too when each N shard is larger than the width.
    # NOTE(review): 4 is hard-coded; N_row_split_grid below is only
    # non-trivial when aie_rows == 4.
    row_split_mode = [4, 1, 1, 1] if N_sptial_split > Xi else [1, 1, 1, 1]

    # Core subvolumes always hold a single image.
    Nis = 1
    Nos = 1

    assert Ci == Co, f"the input chaanel:{Ci} and output channel: {Co} should be same"

    # The X8 split path is disabled. The caller's is_X8_split is overridden
    # and no X8 split candidates are proposed. The former enabling rule was
    # (Xo >= 128) or (Xo >= 64 and Co >= 1024), with X_split_grid = [1, 2].
    is_X8_split = False
    X_split_grid = []

    X8_split_exception = X8_exception(X_split_grid,
                                      Xo, Xi, X_gran, Ci_gran,
                                      aie_cols,
                                      Kx, Sx,
                                      )
    if is_X8_split:
        # Drop X splits rejected by X8_exception; fall back to the regular
        # path if none survive.
        filtered_X_split_grid = [x for x, flag in zip(X_split_grid, X8_split_exception) if not flag]
        if filtered_X_split_grid:
            X_split_grid = filtered_X_split_grid
        else:
            is_X8_split = False
    if is_X8_split:
        # Work per column: pad Xo to aie_cols and X_gran, then divide.
        Wo = iceil(iceil(Xo, aie_cols), X_gran)
        Xo = Wo // aie_cols
        # N is not split across rows on the X8 path. Defined explicitly so
        # the loop below does not hit a NameError when this path is enabled.
        N_row_split_grid = [1]
    else:
        row_divisors = [d for d in range(1, aie_rows + 1) if (aie_rows % d) == 0]
        X_split_grid = row_divisors if row_split_mode[0] == 1 else [1]
        N_row_split_grid = row_divisors if row_split_mode[0] == aie_rows else [1]

    # Candidate output-row and N tile sizes, largest first.
    Yos_grid = filter_grid(list(range(ceildiv(Yo, spatial_split_mode[1]), 0, -1)), Yo, spatial_split_mode[1])
    Nim_grid = filter_grid(list(range(ceildiv(Ni, spatial_split_mode[0]), 0, -1)), Ni, spatial_split_mode[0])

    subv_splits = []
    for X_split in X_split_grid:
        # Core output width for this X split. The input width is the one
        # pooling_input needs for Xos outputs, padded so that
        # Xis * Ci_block is a multiple of CORE_VECTOR_ALIGN bytes.
        Xos = iceil(ceildiv(Xo, X_split), X_gran)
        Xis = iceil(pooling_input(Xos, Kx, Sx) * Ci_block, CORE_VECTOR_ALIGN) // Ci_block
        for N_row_split in N_row_split_grid:
            # Rows not consumed by the X/N split divide the output channels.
            Co_split = aie_rows if is_X8_split else aie_rows // (X_split * N_row_split)
            Cos_grid = filter_grid(list(range(max(Cos_min, iceil(Co, Co_gran)), Cos_min - 1, -Co_gran)), Co, Co_split,
                                   require_divisible=True)
            for Nim in Nim_grid:
                Nom = ceildiv(Nim, N_row_split) * N_row_split
                # ifm and ofm must need the same number of N iterations.
                if ceildiv(Ni, Nim * N_row_split) != ceildiv(No, Nom * N_row_split):
                    continue
                for Yos in Yos_grid:
                    Yis = pooling_input(Yos, Ky, Sy)
                    Yi_padding_valid = Yi_padding_verify(aie_cols, Yi, Sy, Yos, Yis, Py_b)
                    for Cos in Cos_grid:
                        # Pooling preserves channels: a core reads the channels it writes.
                        Cis = Cos
                        ifm_size = (Nis * Cis * Yis * Xis * ifm_bits) // 8
                        wgt_size = wgt_subv_size
                        ofm_size = (Nos * Cos * Yos * Xos * ofm_bits) // 8
                        scratch_buf_size = (Nos * Cos * Yos * Xos * scratch_buf_bits) // 8 if has_scratch_buf else 0

                        # Core data-memory footprint in bytes, with double-buffered
                        # and with single-buffered ifm/ofm.
                        total_size_pingpong = (
                            (ifm_size * 2) +
                            wgt_size +
                            (ofm_size * 2) +
                            scratch_buf_size
                        )
                        total_size_ping = (
                            (ifm_size * 1) +
                            wgt_size +
                            ofm_size +
                            scratch_buf_size
                        )

                        # Notes on the shim.mm2s transfer for wgt. These are not enforced
                        # here, and the names Ci_wgtsubv_loop, Co_wgtsubv_loop, Co_iter and
                        # shim_iter are not defined in this function:
                        #   1. each wgt subvolume is wgt_subv_size;
                        #   2. subvolumes loop over Ci per Ci_wgtsubv_loop;
                        #   3. subvolumes loop over Co per Co_wgtsubv_loop;
                        #   4. the split across columns is on Co_wgtsubv_loop, so
                        #      a) the basic shim transfer size is Co_iter (in 32-bit words),
                        #      b) the column size is 4 (for 8x4 and 4x4),
                        #      c) the iteration step, if used, is shim_iter,
                        #      d) otherwise the last step size is shim_iter,
                        #      e) either way shim_iter must stay below 20 bits (2^20 = 1M).

                        # Memtile footprint when the whole ifm shard is kept resident.
                        Co_mt_pack = ceildiv(Co, (Co_split * Cos)) if is_X8_split else 1
                        mt_Xi_size = Xis if is_X8_split else Xi
                        total_ifm_reuse_memtile_size = (
                            4096 +                                                  # param size
                            (Nim * Yis * mt_Xi_size * Ci * ifm_bits) // 8 +         # ifm shard size
                            (wgt_subv_size) +                                       # wgt size
                            Co_mt_pack * (Nom * Yos * Xos * X_split * Cos * Co_split * ofm_bits) // 8  # ofm size
                        )
                        # Stream the ifm when the resident layout exceeds the 512 KiB memtile.
                        ifm_streaming_mode = total_ifm_reuse_memtile_size > 512 * 1024

                        # The X8 split uses the non-streaming ifm mode to maximize
                        # ifm reuse in the memtile.
                        if is_X8_split:
                            Y_loop_valid = ceildiv(Yo, Yos) <= 64 * 8
                            Co_loop_valid = Co_mt_pack <= 4  # the max queue limit
                            is_X8_split_valid = not ifm_streaming_mode and Y_loop_valid and Co_loop_valid
                        else:
                            is_X8_split_valid = True
                        # For non-negative sizes total_size_ping <= total_size_pingpong,
                        # so fitting single-buffered is sufficient.
                        fits_core = (total_size_pingpong <= CORE_DATA_MEM_SIZE
                                     or total_size_ping <= CORE_DATA_MEM_SIZE)
                        if fits_core and is_X8_split_valid and Yi_padding_valid:
                            subv_splits.append((
                                (Nis, Yis, Xis, Cis),
                                (Nos, Yos, Xos, Cos),
                                (Nim, Nom),
                                X_split, N_row_split,
                                Co_split,
                                is_X8_split,
                                ifm_streaming_mode,
                                spatial_split_mode,
                                row_split_mode,
                            ))
    return subv_splits

def sort_pooling_subv_cost(
    aie_cols: int, aie_rows:int,
    input: Tuple[int, int, int, int],
    output: Tuple[int, int, int, int],
    kernel: Tuple[int, int],
    stride: Tuple[int, int],
    pad: Tuple[int, int, int, int],
    subv_splits: List[Tuple[Tuple[int, int, int, int], Tuple[int, int, int, int], Tuple[int, int],
                            int, int, int, bool, bool, Tuple[int, int, int, int], Tuple[int, int, int, int],
                            ]],
) -> List[Tuple[Tuple[int, int, int, int], Tuple[int, int, int, int], Tuple[int, int],
                            int, int, int, bool, bool, Tuple[int, int, int, int], Tuple[int, int, int, int],
                            ]]:
    """Order candidate pooling splits from cheapest to most expensive.

    Candidates are ranked by these criteria, in priority order:
      1. whether the ifm must be streamed through the memtile (non-streaming first);
      2. the number of outer loop iterations (N * Y * Co * Ci);
      3. the over-compute ratio from rounding each dimension up to whole iterations;
      4. the Ci loop count.
    On the X8 split path, the X split factor takes precedence over all of these.

    Returns a new sorted list. ``subv_splits`` itself is not modified.
    """
    Ni, Yi, Xi, Ci = input
    No, Yo, Xo, Co = output
    Ky, Kx = kernel
    Sy, Sx = stride
    Py_b, Px_b, Py_a, Px_a = pad
    assert Yo == pooling_output(Yi, Ky, Sy, Py_b, Py_a)
    assert Xo == pooling_output(Xi, Kx, Sx, Px_b, Px_a)

    # Work the layer strictly needs; denominator of the over-compute ratio.
    required_ops = Ci * Ky * Kx * Co * Yo * Xo * No

    def cost_key(split):
        (ifm_subv, ofm_subv, mt_n, X_split, _N_row_split, Co_split,
         is_X8_split, ifm_streaming_mode, spatial_mode, row_mode) = split
        _, _, _, Cis = ifm_subv
        _, Yos, Xos, Cos = ofm_subv
        _, Nom = mt_n
        N_spatial, Y_spatial, _, _ = spatial_mode
        _, _, _, _ = row_mode  # unpacked only to enforce its 4-entry shape

        N_loop = ceildiv(No, Nom * N_spatial)
        Y_loop = ceildiv(Yo, Yos * Y_spatial)
        Co_loop = ceildiv(Co, Cos * Co_split)
        Ci_loop = ceildiv(Ci, Cis)

        # Work performed once every dimension is rounded up to whole iterations.
        padded_ops = ((Ci_loop * Cis) * Ky * Kx
                      * (Co_loop * Co_split * Cos)
                      * (Y_loop * Y_spatial * Yos)
                      * (N_loop * N_spatial * Nom)
                      * (X_split * Xos))
        ranking = (
            ifm_streaming_mode,
            N_loop * Y_loop * Co_loop * Ci_loop,
            padded_ops / required_ops,
            Ci_loop,
        )
        return (X_split,) + ranking if is_X8_split else ranking

    return sorted(subv_splits, key=cost_key)

def pooling_subv_split_mode(
    aie_cols:int, aie_rows:int,
    input: Tuple[int, int, int, int],
    output: Tuple[int, int, int, int],
    kernel: Tuple[int, int],
    stride: Tuple[int, int],
    pad: Tuple[int, int, int, int],
    ifm_bits: int, qdq_bits: int, ofm_bits: int,
    has_scratch_buf: int, scratch_buf_bits: int,
    Ci_gran: int, Co_gran: int, X_gran: int,
):
    """Return the lowest-cost valid pooling subvolume split.

    Candidates come from enumerate_pooling_subv_cases and are ranked by
    sort_pooling_subv_cost. The first (cheapest) candidate is returned as the
    10-tuple produced by enumerate_pooling_subv_cases.

    Note: ``qdq_bits`` is forwarded as ``wgt_subv_size``, which
    enumerate_pooling_subv_cases adds to byte counts. Despite its name it is
    therefore a size in bytes.

    Raises:
        IndexError: if no candidate split fits. IndexError is kept for
            compatibility with the earlier bare ``[0]`` lookup, now with a
            descriptive message.
    """
    subv_splits = enumerate_pooling_subv_cases(
        aie_cols, aie_rows,
        input, output, kernel, stride, pad,
        ifm_bits, qdq_bits, ofm_bits,
        has_scratch_buf, scratch_buf_bits,
        Ci_gran, Co_gran, X_gran,
    )
    if not subv_splits:
        raise IndexError(
            f"no valid pooling subvolume split for input={input}, output={output}, "
            f"kernel={kernel}, stride={stride}, pad={pad} on a {aie_cols}x{aie_rows} array"
        )
    sorted_subv_splits = sort_pooling_subv_cost(
        aie_cols, aie_rows,
        input, output, kernel, stride, pad,
        subv_splits,
    )
    return sorted_subv_splits[0]

def main():
    """Print the best pooling subvolume split for a sample 3x3, stride-1 layer.

    The previous example passed 3-tuple shapes where 4-tuples are unpacked,
    used different input/output channel counts (pooling asserts Ci == Co),
    omitted Co_gran/X_gran, and unpacked the 10-tuple result into two names.
    """
    subv_split = pooling_subv_split_mode(
        8, 4,                # aie_cols, aie_rows
        (1, 32, 32, 128),    # input  (Ni, Yi, Xi, Ci)
        (1, 32, 32, 128),    # output (No, Yo, Xo, Co); pooling keeps Ci == Co
        (3, 3),              # kernel (Ky, Kx)
        (1, 1),              # stride (Sy, Sx)
        (1, 1, 1, 1),        # pad (Py_b, Px_b, Py_a, Px_a)
        16, 256, 16,         # ifm_bits, qdq size (256 bytes = 64 * 32 bits), ofm_bits
        0, 0,                # has_scratch_buf, scratch_buf_bits
        8, 8, 4,             # Ci_gran, Co_gran, X_gran
    )
    print(subv_split)

if __name__ == '__main__':
    main()
