import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..'))
from typing import Tuple, List

from conv_common_dwc import ceildiv, iceil, conv_input, conv_output

# Per-core memory budget, in bytes (subvolume sizes are computed as bits // 8 and
# compared against CORE_DATA_MEM_SIZE). 65536 bytes is treated as the total; the
# CORE_SYS_MEM_SIZE slice is held back (presumably for system use — confirm) and the
# remainder is what subvolume buffers may occupy.
CORE_SYS_MEM_SIZE = 6 * 1024
CORE_DATA_MEM_SIZE = 64 * 1024 - CORE_SYS_MEM_SIZE
# Rounding granularity handed to iceil() when aligning buffer sizes and row widths.
CORE_VECTOR_ALIGN = 2 ** 6

def filter_grid(grid: List[int], dim: int, split: int, require_divisible: bool = False) -> List[int]:
    """Prune candidate tile sizes that are larger than needed to cover ``dim``.

    A candidate ``g`` "covers" the dimension when ``g * split >= dim``. Among the
    covering candidates only the smallest is kept, together with every candidate not
    larger than it. If no candidate covers ``dim``, all candidates are kept.

    Args:
        grid: Candidate sizes; their relative order is preserved in the result.
        dim: Full extent of the dimension being tiled.
        split: Number of parallel partitions each candidate is multiplied by.
        require_divisible: When True, additionally drop candidates whose
            ``g * split`` neither covers ``dim`` nor divides it exactly.

    Returns:
        A new list holding the surviving candidates, in input order. It is never the
        same object as ``grid``.
    """
    covering = [g for g in grid if g * split >= dim]
    if covering:
        # Computed once; evaluating min() inside the comprehension was O(n^2).
        limit = min(covering)
        filtered = [g for g in grid if g <= limit]
    else:
        # Copy so callers never alias (and accidentally mutate) the input list.
        filtered = list(grid)
    if require_divisible:
        # Keep sizes that finish the dimension in one step or tile it exactly.
        filtered = [g for g in filtered
                    if (g * split >= dim) or (dim % (g * split) == 0)]
    return filtered

def enumerate_dwc_subv_cases(
    aie_cols: int, aie_rows: int,
    input: Tuple[int, int, int],
    output: Tuple[int, int, int],
    kernel: Tuple[int, int],
    stride: Tuple[int, int],
    pad: Tuple[int, int],
    ifm_bits: int, wgt_bits: int, bias_bits: int, ofm_bits: int, tdm_bits: int,
    Ci_gran: int, Co_gran: int, X_gran: int,
    inner_loop_range: int, outer_loop_range: int,
) -> List[Tuple[Tuple[int, int, int], Tuple[int, int, int], int, int, bool, bool]]:
    """Enumerate depthwise-conv subvolume splits whose buffers fit in core data memory.

    Candidates are the cross product of:
      * ``X_split``: every divisor of ``aie_rows``; ``Co_split = aie_rows // X_split``.
      * ``Yos``: output-row tile heights from ``filter_grid`` over
        ``ceildiv(Yo, aie_cols) .. 1`` (Y is spread across ``aie_cols``).
      * ``Cis``: channel tile depths stepping down by ``Ci_gran`` from
        ``max(Cis_min, iceil(Ci, Ci_gran))`` to ``Cis_min``, keeping only sizes that
        divide ``Ci`` or cover it in one step.
    Output channels per tile always equal input channels per tile (``Cos = Cis``).

    For every candidate the byte sizes of the IFM, weight, bias, OFM and TDM buffers are
    summed. IFM, weight, bias and TDM are counted twice, presumably for double
    buffering (confirm). The candidate is kept only if the total is
    ``<= CORE_DATA_MEM_SIZE``. When the bit widths are exactly 16/8/64/16/32
    (a16w8 with QdQ), the weight and bias sizes are recomputed with
    ``CORE_VECTOR_ALIGN`` rounding plus 20 bytes of QdQ parameters, and IFM-sum and
    temporary scratch buffers are added to the total.

    Args:
        aie_cols, aie_rows: AIE array dimensions.
        input: ``(Ci, Yi, Xi)``.
        output: ``(Co, Yo, Xo)``; must agree with ``conv_output`` for the given
            kernel/stride/pad (asserted).
        kernel, stride, pad: ``(Ky, Kx)``, ``(Sy, Sx)``, ``(Py, Px)``.
        ifm_bits, wgt_bits, bias_bits, ofm_bits, tdm_bits: Element widths in bits.
        Ci_gran: Channel granularity of ``Cis``; also sets the IFM row alignment unit.
        Co_gran, outer_loop_range: Accepted for interface compatibility; they do not
            currently influence the result.
        X_gran: Granularity ``Xos`` is rounded up to.
        inner_loop_range: Sets the minimum channel tile,
            ``ceildiv(inner_loop_range, Ky * Kx) * Ci_gran``.

    Returns:
        A possibly empty list of
        ``((Cis, Yis, Xis), (Cos, Yos, Xos), X_split, Co_split, False, False)`` in
        enumeration order (X_split, then Yos, then Cis). The two trailing flags are
        always ``False`` here; their meaning is defined by the consumer.
    """
    Ci, Yi, Xi = input
    Co, Yo, Xo = output
    Ky, Kx = kernel
    Sy, Sx = stride
    Py, Px = pad
    assert Yo == conv_output(Yi, Ky, Sy, Py)
    assert Xo == conv_output(Xi, Kx, Sx, Px)

    # Bytes occupied by Ci_gran channels of one IFM pixel; Xis rows are padded so their
    # byte length is a multiple of CORE_VECTOR_ALIGN.
    Ci_block = (Ci_gran * ifm_bits) // 8
    Cis_min = ceildiv(inner_loop_range, Ky * Kx) * Ci_gran
    # NOTE(review): no lower bound derived from outer_loop_range * Co_gran is enforced
    # on Cos. Confirm whether such a constraint was intended.

    X_split_grid = [d for d in range(1, aie_rows + 1) if (aie_rows % d) == 0]
    Yos_grid = filter_grid(list(range(ceildiv(Yo, aie_cols), 0, -1)), Yo, aie_cols)
    Cis_grid = filter_grid(
        list(range(max(Cis_min, iceil(Ci, Ci_gran)), Cis_min - 1, -Ci_gran)),
        Ci, 1, require_divisible=True,
    )

    # Special case for a16w8 with QdQ: it accounts for extra parameters and IFM-sum
    # scratchpad space. It depends only on the bit widths, so it is decided once here.
    is_a16w8_qdq = (
        (ifm_bits == 16) and
        (wgt_bits == 8) and
        (bias_bits == 64) and
        (ofm_bits == 16) and
        (tdm_bits == 32)
    )
    qdq_param_bytes = 5 * 4

    subv_splits = []
    for X_split in X_split_grid:
        Co_split = aie_rows // X_split
        # The X tile shape depends only on X_split.
        Xos = iceil(ceildiv(Xo, X_split), X_gran)
        Xis = iceil(conv_input(Xos, Kx, Sx) * Ci_block, CORE_VECTOR_ALIGN) // Ci_block
        for Yos in Yos_grid:
            Yis = conv_input(Yos, Ky, Sy)
            for Cis in Cis_grid:
                Cos = Cis
                ifm_size = (Cis * Yis * Xis * ifm_bits) // 8
                wgt_size = (Cis * Ky * Kx * wgt_bits) // 8
                bias_size = (Cos * bias_bits) // 8
                ofm_size = (Cos * Yos * Xos * ofm_bits) // 8
                tdm_size = (Cos * Yos * Xos * tdm_bits) // 8
                ifm_sum_size = 0
                tmp_buf_size = 0

                if is_a16w8_qdq:
                    # NOTE(review): this weight size uses Cos * Cis * Ky * Kx, whereas
                    # the generic path above uses Cis * Ky * Kx. Confirm which one
                    # matches the a16w8 QdQ kernel's weight layout.
                    wgt_size = iceil((Cos * Cis * Ky * Kx * wgt_bits) // 8, CORE_VECTOR_ALIGN)
                    bias_size = iceil(((Cos * bias_bits) // 8) + qdq_param_bytes, CORE_VECTOR_ALIGN)
                    Xi_g = ((Xos // X_gran) * Sx) + (Kx > 1)
                    Yi_g = ((Yos - 1) * Sy) + Ky
                    ifm_sum_size = (iceil(Xi_g * X_gran * Yi_g, CORE_VECTOR_ALIGN) * tdm_bits) // 8
                    tmp_buf_size = iceil(max(128, iceil(Xis * Yis, CORE_VECTOR_ALIGN)), Yis * 64) * tdm_bits // 8

                total_size = (
                    (ifm_size * 2) +
                    (wgt_size * 2) +
                    (bias_size * 2) +
                    ofm_size +
                    (tdm_size * 2) +
                    ifm_sum_size +
                    tmp_buf_size
                )
                if total_size <= CORE_DATA_MEM_SIZE:
                    subv_splits.append((
                        (Cis, Yis, Xis),
                        (Cos, Yos, Xos),
                        X_split,
                        Co_split,
                        False,
                        False,
                    ))

    return subv_splits

def sort_dwc_subv_cost(
    aie_cols: int, aie_rows: int,
    input: Tuple[int, int, int],
    output: Tuple[int, int, int],
    kernel: Tuple[int, int],
    stride: Tuple[int, int],
    pad: Tuple[int, int],
    subv_splits: List[Tuple[Tuple[int, int, int], Tuple[int, int, int], int, int, bool, bool]],
) -> List[Tuple[Tuple[int, int, int], Tuple[int, int, int], int, int, bool, bool]]:
    """Return the candidate splits ordered from cheapest to most expensive.

    The cost key is compared lexicographically:
      1. ``loop_count = Y_loop * Ci_loop``, where
         ``Y_loop = ceildiv(Yo, Yos * aie_cols)`` and ``Ci_loop = ceildiv(Ci, Cis)``.
         Fewer iterations sort first.
      2. Overcompute ratio: the padded work
         ``(Ci_loop * Cis) * Ky * Kx * (Y_loop * aie_cols * Yos) * (X_split * Xos)``
         divided by the exact work ``Ci * Ky * Kx * Yo * Xo``. Lower sorts first.
    ``sorted`` is stable, so ties keep their input order. ``aie_rows``, ``Co`` and
    ``Co_split`` do not enter the key.

    Args:
        aie_cols, aie_rows: AIE array dimensions.
        input, output, kernel, stride, pad: Layer geometry; ``output`` must agree with
            ``conv_output`` (asserted).
        subv_splits: Candidates shaped like the ones produced by
            ``enumerate_dwc_subv_cases``.

    Returns:
        A new sorted list with the same elements as ``subv_splits``.
    """
    Ci, Yi, Xi = input
    Co, Yo, Xo = output
    Ky, Kx = kernel
    Sy, Sx = stride
    Py, Px = pad

    assert Yo == conv_output(Yi, Ky, Sy, Py)
    assert Xo == conv_output(Xi, Kx, Sx, Px)

    # Exact work for the whole layer; identical for every candidate.
    shape_ops = Ci * Ky * Kx * Yo * Xo

    def split_key(split):
        (Cis, _, _), (_, Yos, Xos), X_split, _, _, _ = split
        Y_loop = ceildiv(Yo, Yos * aie_cols)
        Ci_loop = ceildiv(Ci, Cis)
        # Work actually performed once every tile is padded up to full size.
        compute_ops = (
            (Ci_loop * Cis) * Ky * Kx *
            (Y_loop * aie_cols * Yos) *
            (X_split * Xos)
        )
        loop_count = Y_loop * Ci_loop
        overcompute_ratio = compute_ops / shape_ops
        return (loop_count, overcompute_ratio)

    return sorted(subv_splits, key=split_key)

def dwc_subv_split_mode(
    aie_cols: int, aie_rows: int,
    input: Tuple[int, int, int],
    output: Tuple[int, int, int],
    kernel: Tuple[int, int],
    stride: Tuple[int, int],
    pad: Tuple[int, int],
    ifm_bits: int, wgt_bits: int, bias_bits: int, ofm_bits: int, tdm_bits: int,
    Ci_gran: int, Co_gran: int, X_gran: int,
    inner_loop_range: int, outer_loop_range: int,
) -> Tuple[Tuple[int, int, int], Tuple[int, int, int], int, int, bool, bool]:
    """Choose the lowest-cost depthwise-conv subvolume split for one layer.

    Enumerates memory-feasible candidates with ``enumerate_dwc_subv_cases``, orders
    them with ``sort_dwc_subv_cost`` and returns the first one. All parameters are
    forwarded unchanged to those two functions.

    Returns:
        ``((Cis, Yis, Xis), (Cos, Yos, Xos), X_split, Co_split, flag, flag)``.

    Raises:
        IndexError: If no candidate fits in core data memory. The type matches the
            error an empty-list lookup raises, so existing handlers keep working; the
            message now names the offending layer.
    """
    subv_splits = enumerate_dwc_subv_cases(
        aie_cols, aie_rows,
        input, output, kernel, stride, pad,
        ifm_bits, wgt_bits, bias_bits, ofm_bits, tdm_bits,
        Ci_gran, Co_gran, X_gran,
        inner_loop_range, outer_loop_range
    )
    if not subv_splits:
        raise IndexError(
            'no depthwise-conv subvolume split fits in core data memory '
            f'(input={input}, output={output}, kernel={kernel}, stride={stride}, '
            f'pad={pad}, bits ifm/wgt/bias/ofm/tdm='
            f'{ifm_bits}/{wgt_bits}/{bias_bits}/{ofm_bits}/{tdm_bits})'
        )
    sorted_subv_splits = sort_dwc_subv_cost(
        aie_cols, aie_rows,
        input, output, kernel, stride, pad,
        subv_splits,
    )
    return sorted_subv_splits[0]

def main():
    """Search splits for one sample layer on an 8 x 4 array and print the winner."""
    best_split = dwc_subv_split_mode(
        aie_cols=8, aie_rows=4,
        input=(128, 32, 32),
        output=(256, 32, 32),
        kernel=(3, 3),
        stride=(1, 1),
        pad=(1, 1),
        ifm_bits=16, wgt_bits=8, bias_bits=16, ofm_bits=16, tdm_bits=32,
        Ci_gran=8, Co_gran=16, X_gran=8,
        inner_loop_range=8, outer_loop_range=1,
    )
    print(best_split)


# Only run the sample search when executed as a script, not on import.
if __name__ == '__main__':
    main()
