import os
import sys
from typing import List, Dict, Tuple

# Absolute directory that contains this file; anchor for the relative paths below.
CURRDIR = os.path.dirname(os.path.abspath(__file__))
# Extend the module search path with two directories located relative to this
# file. Presumably these hold the project modules imported further down
# (dmacompiler, dataflow_common, ...) — verify against the repository layout.
sys.path.append(os.path.join(CURRDIR, "..", "..", "dmacompiler"))
sys.path.append(os.path.join(CURRDIR, "..", "..", "OGOAT", "src", "L1_fusion"))

from dmacompiler import set_dev_gen, DevGen, config

from dataflow_common import ceildiv, iceil, tiling_json_gen, overlay_stack_addr, ifloor
from slice_common import SliceDims, make_slice_dict
from dataflow_utils import CommonDims
from kernel_func_list import kernel_func_list

# Import-time side effects: select the Aie2p device generation in dmacompiler
# and switch on its ENABLE_BUSY_POLL configuration flag. Both take effect as
# soon as this module is imported.
set_dev_gen(DevGen.Aie2p)
config.ENABLE_BUSY_POLL = True


def sort_subv_cost(subv_splits: List) -> List:
    """Return a new list of candidate splits ordered from cheapest to costliest.

    Each candidate must be a 6-element sequence whose last two items are the
    over-compute ratio and the total loop count. Candidates are ranked by
    total loop count first and over-compute ratio second; the ordering is
    stable, so equal-cost candidates keep their input order. The input
    sequence itself is left untouched.
    """

    def _cost(candidate):
        # Strict 6-way unpack: a candidate of any other length raises ValueError.
        _yis, _xis, _cis, _cos, overhead, loops = candidate
        return (loops, overhead)

    ranked = list(subv_splits)
    ranked.sort(key=_cost)
    return ranked


def generate_subv(
    aie_cols,
    aie_rows,
    ifm_bits,
    ofm_bits,
    slice,
    Ni,
    Yi,
    Xi,
    Ci,
    No,
    Yo,
    Xo,
    Co,
    out_step_dims,
    wgt_subv,
    Cop,
    has_scratch_buf,
) -> Tuple[int, int, int, int]:
    """Choose the per-core sub-volume ``(Yis, Xis, Cis, Cos)`` for a slice op.

    Every ``Yis`` in ``1..Yo`` is combined with every X candidate in
    ``1..Xo``. A candidate is kept when its buffers fit the core and memtile
    budgets and it passes the padding and granularity checks below. Among the
    survivors, the one with the fewest loop iterations is returned; ties are
    broken by the lowest over-compute ratio (ordering done by
    ``sort_subv_cost``).

    Notes:
        1. C_in is not split as of now (``C_loop`` is fixed to 1).
        2. Y is spread over ``aie_cols`` and X over ``aie_rows``.
        3. A C split may have to be added here in the future.
        Memtile MM2S padding capability: D0: 64, D1: 32, D2: 16.

    Args:
        aie_cols: Number of columns; multiplies ``Yis`` when counting Y loops.
        aie_rows: Number of rows; multiplies ``Xis`` when counting X loops and
            when sizing the memtile buffer.
        ifm_bits: Input element width in bits.
        ofm_bits: Output element width in bits.
        slice: Mapping whose ``"C"`` entry holds ``[start, stop]``; together
            with ``out_step_dims[3]`` it defines the sliced channel count.
        Ni, Yi, Xi: Accepted for interface compatibility; not read here.
        Ci: Input channel count, used as ``Cis`` when the C step is not 1.
        No, Yo, Xo: Output extents.
        Co: Output channel count; only its parity is used (odd ``Co``
            aligns the X candidate to ``32 // ifm_bits``).
        out_step_dims: Per-dimension step; only index 3 (C) is read.
        wgt_subv: Bytes subtracted from both the core and memtile budgets.
        Cop: Output channel count used for the memtile OFM size.
        has_scratch_buf: Whether a scratch buffer must be budgeted per core.

    Returns:
        ``(Yis, Xis, Cis, Cos)`` of the cheapest valid candidate.

    Raises:
        IndexError: If no candidate satisfies all constraints. The exception
            type matches what the former unguarded ``sorted_subv[0]`` raised,
            but the message now describes the failing configuration.
    """
    usable_mt_size = (
        config.MAX_MEMTILE_ADDR - aie_rows * config.MAX_CORE_LAYER_PARAM_SIZE - wgt_subv
    )
    usable_core_size = overlay_stack_addr() - wgt_subv
    # Channel count after applying the C step to the [start, stop) range.
    new_Co = len(range(slice["C"][0], slice["C"][1], out_step_dims[3]))

    YXC_gran = 8  # minimum number of elements in one sub-volume
    wgt_subv_size = 256  # we will use this value no matter qdq or no-qdq
    total_elements = No * Yo * Xo * new_Co

    # Loop-invariant channel configuration: C is never split.
    Cis = new_Co if out_step_dims[3] == 1 else Ci
    Cos = new_Co
    C_loop = 1
    C_odd = Co % 2 == 1

    subv_splits = []
    for Yis in range(1, Yo + 1):
        Y_loop = ceildiv(Yo, Yis * aie_cols)
        # Y padding must stay below 16.
        # NOTE(review): when Yo is an exact multiple of Yis the left-hand side
        # equals Yis, so exact-fit candidates with Yis >= 16 are skipped too —
        # confirm this is intended.
        if Yis - find_split_remainder(Yo, Yis) >= 2**4:
            continue
        for Xis_candidate in range(1, Xo + 1):
            # With an odd channel count, align X to 32 // ifm_bits.
            Xis = iceil(Xis_candidate, 32 // ifm_bits) if C_odd else Xis_candidate
            # Grow X so that a tiny sub-volume can reach the minimum granularity.
            if Yis * Xis * Cos < YXC_gran:
                Xis = YXC_gran // Cos
            X_loop = ceildiv(Xo, Xis * aie_rows)

            # Per-core buffers.
            ifm_subv_size = iceil(Yis * Xis * Cis, 64) * ifm_bits // 8
            ofm_subv_size = Yis * Xis * do_padding(Cos) * ofm_bits // 8
            scratch_buffer_size = (
                iceil(Yis * Xis * Cis, 64) * 2 if has_scratch_buf else 0
            )

            # Memtile holds IFM and OFM for all rows of one column.
            mt_mem_size = (
                Yis * (Xis * aie_rows) * Cis * ifm_bits // 8
                + Yis * (Xis * aie_rows) * Cop * ofm_bits // 8
            )
            total_loop = No * Y_loop * X_loop * C_loop

            fits_memory = (
                ifm_subv_size + ofm_subv_size + scratch_buffer_size + wgt_subv_size
                <= usable_core_size
                and mt_mem_size <= usable_mt_size
            )

            # Ratio of elements processed (including padding) to useful elements.
            total_compute = total_loop * Yis * Xis * Cos * aie_cols * aie_rows
            over_compute_ratio = total_compute / total_elements

            # NOTE(review): as with Y, an exact multiple yields Xis rather than 0
            # here, so exact-fit candidates with Xis > 32 are rejected — confirm.
            Xis_padding_valid = (Xis - (Xo % Xis)) <= 32
            gran_valid = Cis * Xis * Yis >= YXC_gran
            if fits_memory and Xis_padding_valid and gran_valid:
                subv_splits.append((Yis, Xis, Cis, Cos, over_compute_ratio, total_loop))

    if not subv_splits:
        raise IndexError(
            "generate_subv: no valid sub-volume split for output "
            f"[N, Y, X, C] = [{No}, {Yo}, {Xo}, {new_Co}] on a "
            f"{aie_cols}x{aie_rows} (cols x rows) array; core budget "
            f"{usable_core_size} B, memtile budget {usable_mt_size} B"
        )

    Yis, Xis, Cis, Cos, _, _ = sort_subv_cost(subv_splits)[0]
    return Yis, Xis, Cis, Cos


def do_padding(inDim: int, padding_enable: bool = True):
    """Return ``inDim`` aligned to 8 via ``iceil``, or unchanged when padding is off.

    Args:
        inDim: Dimension size to pad.
        padding_enable: When false, ``inDim`` is returned as-is and ``iceil``
            is not called.

    Returns:
        ``iceil(inDim, 8)`` when padding is enabled (presumably the next
        multiple of 8 — confirm against ``dataflow_common.iceil``), otherwise
        ``inDim``.
    """
    if not padding_enable:
        return inDim
    alignment = 8
    return iceil(inDim, alignment)


def find_split_remainder(inDim: int, subv: int):
    """Return what is left of ``inDim`` after removing whole ``subv`` chunks.

    Equivalent to ``inDim % subv``; raises ``ZeroDivisionError`` when ``subv``
    is 0.
    """
    _whole_chunks, leftover = divmod(inDim, subv)
    return leftover


def run_tiler(
    aie_cols: int,
    aie_rows: int,
    input: List[int],
    slice: Dict,
    axis: int,
    ifm_bits: int,
    ofm_bits: int,
    fixed_point_bits: int,
    out_start: int,
    out_stop: int,
    out_step: int,
    is_qdq: bool,
    qdq_mode: int = 3,  # 0: DEQUANT; 1: QUANT; 2: BOTH; 3: NONE
    enable_padding_arg: bool = True,
    kernel_padding_arg: bool = False,
    is_input_max64w8_format: bool = True,
) -> SliceDims:
    """Compute the tiling parameters of a slice op and pack them into ``SliceDims``.

    The function derives the output extents from ``slice`` and ``out_step``,
    decides whether a slice kernel and/or QDQ kernel is needed, picks the
    padding mode, asks ``generate_subv`` for the per-core sub-volume and
    returns everything as a ``SliceDims``.

    Side effects:
        * ``slice`` may be modified in place (its ``"C"`` and ``"W"`` entries)
          before the output extents are computed.
        * ``tiling_json_gen`` is called with the path ``<cwd>/tiling.json``.

    Args:
        aie_cols: Column count, forwarded to ``generate_subv`` and ``SliceDims``.
        aie_rows: Row count; also scales the memtile parameter and QDQ sizes.
        input: Four-element sequence unpacked as ``(Ni, Yi, Xi, Ci_orig)``.
        slice: Mapping with keys ``"N"``, ``"H"``, ``"W"``, ``"C"``, each a
            mutable ``[start, stop]`` pair.
        axis: Index (0..3) of the dimension that ``out_step`` applies to;
            ``axis == 3`` enables the channel-slice handling.
        ifm_bits: Input element width in bits.
        ofm_bits: Output element width in bits.
        fixed_point_bits: Only forwarded to ``SliceDims``.
        out_start: Only forwarded to ``SliceDims``.
        out_stop: Only forwarded to ``SliceDims``.
        out_step: Step along ``axis``; every other dimension uses step 1.
        is_qdq: Whether a QDQ kernel is used; also reserves ``wgt_subv``.
        qdq_mode: Only forwarded to ``SliceDims``.
        enable_padding_arg: Not read by this function.
        kernel_padding_arg: Not read by this function.
        is_input_max64w8_format: When true, ``Ci`` is ``do_padding(Ci_orig)``
            and the ``"C"`` slice bounds are normalised/overwritten.

    Returns:
        A ``SliceDims`` populated with the values computed here.

    Raises:
        AssertionError: If the C-reconstruct ratio does not divide ``Ci``.
        IndexError: If a ``"C"`` bound is outside the indexable range when
            ``axis == 3`` with the max64w8 format, or (propagated from
            ``generate_subv``) if no valid sub-volume split exists.
    """

    Ni, Yi, Xi, Ci_orig = input
    # Step is out_step on the sliced axis and 1 on every other axis.
    out_step_dims = [out_step if i == axis else 1 for i in range(4)]

    if is_input_max64w8_format:
        Ci = do_padding(Ci_orig)
        if axis == 3:
            # Indexing the range validates the bounds (IndexError when out of
            # range) and maps a negative bound k to input[3] + 1 + k; in-range
            # non-negative bounds are returned unchanged.
            C_dim = range(0, input[3] + 1)
            slice["C"][0] = C_dim[slice["C"][0]]
            slice["C"][1] = C_dim[slice["C"][1]]
        else:
            # Not slicing along C: keep the start (self-assignment is a no-op)
            # and extend the stop to the padded channel count.
            slice["C"][0] = slice["C"][0]
            slice["C"][1] = Ci
    else:
        Ci = input[3]

    # Re-construct C dimension if it is too large to fit in L1
    do_dma_padding = True
    do_kernel_padding = True
    ifm_ofm_factor = 2
    ping_pong_factor = 2
    # Only considered when not slicing along C: does
    # ifm_ofm_factor * ping_pong_factor copies of one channel vector (in bytes)
    # exceed overlay_stack_addr()?
    is_C_reconstruct = (
        True
        if axis != 3
        and (
            ifm_ofm_factor * ping_pong_factor * Ci * ifm_bits // 8
            > overlay_stack_addr()
        )
        else False
    )
    if is_C_reconstruct:
        do_dma_padding = False
        do_kernel_padding = False
        # Double the split ratio until the reduced channel vector fits.
        # The expression is evaluated left to right:
        # (((factors * Ci) // ratio) * ifm_bits) // 8.
        C_split_ratio = 1
        while (
            ifm_ofm_factor * ping_pong_factor * Ci // C_split_ratio * ifm_bits // 8
            > overlay_stack_addr()
        ):
            C_split_ratio *= 2

        # NOTE(review): assert is stripped under `python -O`; this check would
        # then silently disappear.
        assert (
            Ci % C_split_ratio == 0
        ), "Can't find a valid reconstruct ratio, need C-dim split"
        # Fold the split factor from C into W: C shrinks by the ratio while Xi
        # and the W slice bounds grow by the same ratio.
        Ci = Ci // C_split_ratio
        Ci_orig = Ci
        slice["C"][0] = slice["C"][0]
        slice["C"][1] = Ci
        Xi = Xi * C_split_ratio
        slice["W"][0] = slice["W"][0] * C_split_ratio
        slice["W"][1] = slice["W"][1] * C_split_ratio

    # No = slice['N'][1] - slice['N'][0]
    # Yo = slice['H'][1] - slice['H'][0]
    # Xo = slice['W'][1] - slice['W'][0]
    # Co = slice['C'][1] - slice['C'][0]
    # Output extents: number of indices visited by each strided [start, stop) range.
    No = len(range(slice["N"][0], slice["N"][1], out_step_dims[0]))
    Yo = len(range(slice["H"][0], slice["H"][1], out_step_dims[1]))
    Xo = len(range(slice["W"][0], slice["W"][1], out_step_dims[2]))
    Co = len(range(slice["C"][0], slice["C"][1], out_step_dims[3]))

    # Record the original input/output shapes and hand them to tiling_json_gen
    # together with the path <cwd>/tiling.json.
    tiling = {}
    tiling["op_type"] = "slice"
    tiling["orig_input"] = input
    tiling["orig_output"] = [No, Yo, Xo, Co]
    tiling_json_gen(tiling, os.path.join(os.getcwd(), "tiling.json"))

    Ni_slice_start = slice["N"][0]
    Ni_slice_stop = slice["N"][1]
    innerC = None
    startC = None  # never read afterwards; SliceDims.startC gets slice_inner_start
    slice_inner_start = None

    # "Even" here means aligned to 32 // ifm_bits elements (i.e. a 32-bit boundary).
    C_start_even = slice["C"][0] % (32 // ifm_bits) == 0
    C_stop_even = slice["C"][1] % (32 // ifm_bits) == 0
    shard = None
    if axis != 3:
        is_slice_kernel = False
    else:
        if out_step_dims[3] == 1:
            # Contiguous C slice: a slice kernel is needed whenever either bound
            # is unaligned. C_start_adj / C_stop_adj are the aligned bounds that
            # enclose the requested range; slice_inner_start is the offset of
            # the requested start inside that aligned window.
            # NOTE: slice_inner_end is computed in several branches but is not
            # read anywhere in this function.
            if not C_start_even and C_stop_even:
                is_slice_kernel = True
                C_start_adj = ifloor(slice["C"][0], 32 // ifm_bits)
                C_stop_adj = slice["C"][1]
                innerC = slice["C"][1] - slice["C"][0]
                slice_inner_start = slice["C"][0] - C_start_adj
                slice_inner_end = slice["C"][1] - C_start_adj
            elif C_start_even and not C_stop_even:
                is_slice_kernel = True
                C_start_adj = slice["C"][0]
                C_stop_adj = iceil(slice["C"][1], 32 // ifm_bits)
                innerC = slice["C"][1] - slice["C"][0]
                slice_inner_start = 0
                slice_inner_end = slice["C"][1] - slice["C"][0]
            elif C_start_even and C_stop_even:
                # Both bounds aligned: no slice kernel required.
                is_slice_kernel = False
                C_start_adj = slice["C"][0]
                C_stop_adj = slice["C"][1]
                shard = None
            else:
                is_slice_kernel = True
                C_start_adj = ifloor(slice["C"][0], 32 // ifm_bits)
                C_stop_adj = iceil(slice["C"][1], 32 // ifm_bits)
                # Algebraically equal to slice["C"][1] - slice["C"][0].
                innerC = slice["C"][1] - C_start_adj - (slice["C"][0] - C_start_adj)
                slice_inner_start = slice["C"][0] - C_start_adj
                slice_inner_end = slice["C"][1] - C_start_adj
        else:
            # Strided C slice: always handled by the slice kernel.
            is_slice_kernel = True
            innerC = Ci
            slice_inner_start = slice["C"][0]
            C_start_adj = ifloor(slice["C"][0], 32 // ifm_bits)
            C_stop_adj = iceil(slice["C"][1], 32 // ifm_bits)
        # Rebind the local name to a new dict built from the aligned C bounds;
        # the extents No/Yo/Xo/Co above were computed from the previous dict.
        slice = make_slice_dict(input, 3, C_start_adj, C_stop_adj)

    is_kernel = is_slice_kernel or is_qdq  # Is either of the kernel being used

    enable_dma_padding = do_dma_padding
    enable_kernel_padding = do_kernel_padding

    # NOTE: Below block is temporary until kernel padding is supported because if Co is odd then MM2S can't pad
    if Co % (32 // ifm_bits) != 0:
        enable_dma_padding = False
        enable_kernel_padding = True if is_kernel else False

    # A kernel that is not the slice kernel (i.e. QDQ only) does no padding.
    if is_kernel and not is_slice_kernel:
        enable_kernel_padding = False

    # NOTE(review): Cop/Com are derived from do_dma_padding/do_kernel_padding,
    # not from the enable_* flags adjusted just above — confirm this is intended.
    Cop = do_padding(Co, do_dma_padding or do_kernel_padding)
    Com = do_padding(Co, do_dma_padding or do_kernel_padding) if axis == 3 else Co

    param_subv_size = config.MAX_CORE_LAYER_PARAM_SIZE
    memtile_param_size = param_subv_size * aie_rows  # not read afterwards
    CoreqdqPrmSize = 64
    wgt_subv_size = CoreqdqPrmSize * aie_rows
    # Space is only reserved for QDQ parameters when QDQ is in use.
    wgt_subv = wgt_subv_size if is_qdq else 0
    enable_padding = enable_dma_padding or enable_kernel_padding
    # A scratch buffer is budgeted unless both input and output are 16-bit.
    has_scratch_buf = False if ifm_bits == 16 and ofm_bits == 16 else True
    # when doing Y,X slice, using dma, so when ifm from shim will be doing the slice
    # when doing C slice, it might be kernel.
    Yis, Xis, Cis, Cos = generate_subv(
        aie_cols,
        aie_rows,
        ifm_bits,
        ofm_bits,
        slice,
        Ni,
        Yi,
        Xi,
        Ci,
        No,
        Yo,
        Xo,
        Co,
        out_step_dims,
        wgt_subv,
        Cop,
        has_scratch_buf,
    )

    Ni_gran = 1
    subv_elem = Ni_gran * Yis * Xis * Cis

    # Select the kernel entry point(s) and their wrapper sources.
    # NOTE: kernel_names and kernel_includes are built here but are not passed
    # to SliceDims or otherwise read in this function.
    kernel_names = {}
    kernel_includes = ["super.hh"]
    if not is_slice_kernel:
        if is_qdq:
            kernel_names["run_combined_qdq"] = kernel_func_list.index(
                "run_combined_qdq"
            )
            kernel_includes.append("qdq/wrapper_qdq.cc")
    else:
        # Dedicated kernel for the 64 -> 63 channel case, generic one otherwise.
        if Ci_orig == 64 and innerC == 63:
            kernel_names["run_slice_c6463"] = kernel_func_list.index("run_slice_c6463")
            kernel_includes.append("slice_c6463/wrapper_slice_c6463.cc")
        else:
            kernel_names["run_slice"] = kernel_func_list.index("run_slice")
            kernel_includes.append("slice/wrapper_slice.cc")

    # NOTE: 8 for W8 padding
    row_alignment = 8

    return SliceDims(
        aie_cols=aie_cols,
        aie_rows=aie_rows,
        Ni=Ni,
        Yi=Yi,
        Xi=Xi,
        Ci=Ci,
        No=No,
        Yo=Yo,
        Xo=Xo,
        Co=Co,
        Ni_gran=Ni_gran,
        Yis=Yis,
        Xis=Xis,
        Cis=Cis,
        Cos=Cos,
        Cop=Cop,
        ifm_bits=ifm_bits,
        ofm_bits=ofm_bits,
        fixed_point_bits=fixed_point_bits,
        slice=slice,
        Ci_orig=Ci_orig,
        Ni_slice_start=Ni_slice_start,
        Ni_slice_stop=Ni_slice_stop,
        axis=axis,
        out_start=out_start,
        out_stop=out_stop,
        wgt_subv_size=wgt_subv_size,
        is_qdq=is_qdq,
        qdq_mode=qdq_mode,
        is_kernel=is_kernel,
        enable_padding=enable_padding,
        kernel_padding=enable_kernel_padding,
        shard=shard,
        innerC=innerC,
        startC=slice_inner_start,
        Com=Com,
        row_alignment=row_alignment,
        param_subv_size=param_subv_size,
        subv_elem=subv_elem,
        has_scratch_buf=has_scratch_buf,
        out_step_dims=out_step_dims,
    )
