import os
import sys
import math
from typing import List, Dict, Tuple

# Directory containing this file; used as the anchor for the relative
# search-path entries added below.
CURRDIR = os.path.dirname(os.path.abspath(__file__))

# Make the dmacompiler package and the OGOAT L1_fusion sources importable.
# Both locations are addressed relative to this file, two directories up.
sys.path.extend(
    os.path.join(CURRDIR, '..', '..', *tail)
    for tail in (('dmacompiler',), ('OGOAT', 'src', 'L1_fusion'))
)

from dmacompiler import (
    set_dev_gen, DevGen, config
)

from dataflow_common import ceildiv, iceil, tiling_json_gen, overlay_stack_addr, ifloor
from pad_common import PadDims
from dataflow_utils import CommonDims
from kernel_func_list import kernel_func_list

# Import-time side effects on dmacompiler state:
# select the Aie2p device generation ...
set_dev_gen(DevGen.Aie2p)
# ... and set the ENABLE_BUSY_POLL flag on the imported dmacompiler config object.
config.ENABLE_BUSY_POLL = True

def sort_subv_cost(subv_splits: List) -> List:
    """Return the candidate splits ordered from most to least preferred.

    Each candidate is a 7-tuple whose third, fourth and fifth items are the
    per-dimension loop counts, the ping-pong flag and the overcompute ratio.
    Candidates are ranked by, in order:

    1. ping-pong capable candidates first,
    2. smaller product of the loop counts,
    3. smaller overcompute ratio.

    The sort is stable, so candidates with equal rank keep their input order.
    The input list is not modified.
    """
    def _rank(candidate):
        _core, _mt, loop_counts, has_ping_pong, overcompute_ratio, _spatial, _row = candidate
        # 0 sorts before 1, so ping-pong candidates get the higher priority.
        return (0 if has_ping_pong else 1, math.prod(loop_counts), overcompute_ratio)

    ranked = list(subv_splits)
    ranked.sort(key=_rank)
    return ranked

def generate_subv(aie_cols, aie_rows,
               input, output, pad_dims,
               in_gran, out_gran,
               pad_limit,
               ifm_bits,
               ofm_bits,
               fix_point_bits,
               wgt_subv_size,
               qdq_mode,
               ) -> Tuple[Tuple[int, ...], Tuple[int, ...], Tuple[int, ...], bool, List[int], List[int]]:
    """Enumerate candidate subvolume tilings and return the preferred one that fits.

    Every combination of per-dimension memtile tile sizes is tried. For each
    one the per-core tile, the loop counts, the buffer footprints and an
    overcompute ratio are derived; combinations whose buffers exceed the
    memtile or core budget are dropped. The survivors are ranked with
    ``sort_subv_cost`` and the first one is returned.

    Args:
        aie_cols: Number of columns; one dimension is split by this factor.
        aie_rows: Number of rows; when ``qdq_mode != 3`` one dimension of the
            memtile tile is split by this factor to get the core tile.
        input: Input extents ``(N, Y, X, C)``.
        output: Output extents. Accepted for interface compatibility; not used.
        pad_dims: Per-dimension padding amounts ``(N, Y, X, C)`` added to the
            memtile tile when sizing the padded output buffer.
        in_gran: Per-dimension tile granularity ``(N, Y, X, C)`` for the input.
        out_gran: Accepted for interface compatibility; not used.
        pad_limit: Accepted for interface compatibility; not used.
        ifm_bits: Bits per input element.
        ofm_bits: Bits per output element.
        fix_point_bits: When 16, no core scratch buffer is accounted for;
            otherwise two bytes per core-tile element are reserved.
        wgt_subv_size: Bytes subtracted from both the memtile and core budgets.
        qdq_mode: When 3, the core tile equals the memtile tile and the core
            budget check is skipped.

    Returns:
        ``(core_subv, mt_subv, loop, ping_pong, spatial_split_mode,
        row_split_mode)`` where the first three are 4-tuples ``(N, Y, X, C)``,
        ``ping_pong`` tells whether doubled buffers still fit, and the two
        split modes are 4-element lists holding the split factor at the
        chosen dimension and 1 elsewhere.

    Raises:
        IndexError: If no candidate fits within the memory budgets.
    """
    from itertools import product

    # Byte budgets left after reserving parameter and weight space.
    # NOTE(review): an assignment ``wgt_subv_size = 256`` used to follow these
    # two lines but was never read afterwards, so it had no effect and has been
    # removed. Confirm the caller-supplied value is the intended one.
    usable_mt_size = config.MAX_MEMTILE_ADDR - aie_rows * config.MAX_CORE_LAYER_PARAM_SIZE - wgt_subv_size
    usable_core_size = overlay_stack_addr() - wgt_subv_size

    Ni, Yi, Xi, Ci = input
    Nis_gran, Yis_gran, Xis_gran, Cis_gran = in_gran
    grans = [Nis_gran, Yis_gran, Xis_gran, Cis_gran]
    total_input = math.prod(input)
    no_row_split = qdq_mode == 3

    def choose_split_mode(dim_size, gran, split_factor):
        """Pick the one dimension to split by ``split_factor``.

        A dimension qualifies when it holds at least ``split_factor``
        granules; among those, the one whose granule count is closest to
        ``split_factor`` wins, ties going to the outermost (N > Y > X > C).
        If no dimension qualifies, every score is infinite and the
        tie-break selects dimension 0.
        """
        candidates = []
        for idx, (d, g) in enumerate(zip(dim_size, gran)):
            score = abs((d // g) - split_factor) if d / g >= split_factor else float("inf")
            candidates.append((score, idx))
        _, best_idx = min(candidates)
        split_mode = [1, 1, 1, 1]
        split_mode[best_idx] = split_factor
        return split_mode

    def tile_grid(extent, gran, split):
        """Candidate memtile tile sizes for one dimension."""
        if split == aie_cols:
            # The column-split dimension gets a single, evenly divided size.
            return [ceildiv(extent, aie_cols)]
        # Otherwise every positive multiple of the granularity up to the extent.
        return list(range(gran, extent + 1, gran))

    spatial_split_mode = choose_split_mode(input, grans, aie_cols)
    grids = [
        tile_grid(extent, gran, split)
        for extent, gran, split in zip((Ni, Yi, Xi, Ci), grans, spatial_split_mode)
    ]

    subv_splits = []
    # product() iterates with the last dimension varying fastest, i.e. in the
    # same order as four nested N/Y/X/C loops. The order matters because the
    # ranking below is stable and ties are resolved by enumeration order.
    for Nim, Yim, Xim, Cim in product(*grids):
        N_loop = ceildiv(Ni, Nim * spatial_split_mode[0])
        Y_loop = ceildiv(Yi, Yim * spatial_split_mode[1])
        X_loop = ceildiv(Xi, Xim * spatial_split_mode[2])
        C_loop = ceildiv(Ci, Cim * spatial_split_mode[3])

        mt_elems = Nim * Yim * Xim * Cim
        ifm_mt_subv_size = mt_elems * ifm_bits // 8

        if no_row_split:
            Nis, Yis, Xis, Cis = Nim, Yim, Xim, Cim
            row_split_mode = [1, 1, 1, 1]
        else:
            row_split_mode = choose_split_mode([Nim, Yim, Xim, Cim], grans, aie_rows)
            Nis, Yis, Xis, Cis = [
                ceildiv(m, r) for m, r in zip([Nim, Yim, Xim, Cim], row_split_mode)
            ]

        core_elems = Nis * Yis * Xis * Cis
        ifm_core_subv_size = core_elems * ifm_bits // 8
        ofm_core_subv_size = core_elems * ofm_bits // 8
        core_scratch_buf_size = 0 if fix_point_bits == 16 else core_elems * 2

        # The output memtile buffer is sized for the tile grown by the full
        # padding amount in every dimension.
        ofm_mt_pad_size = (Nim + pad_dims[0]) * (Yim + pad_dims[1]) * \
                          (Xim + pad_dims[2]) * (Cim + pad_dims[3]) * ofm_bits // 8

        mt_valid = ifm_mt_subv_size + ofm_mt_pad_size <= usable_mt_size
        # Core buffers are rounded with iceil(..., 64); the core check is
        # skipped entirely when qdq_mode == 3.
        core_valid = no_row_split or (
            iceil(ifm_core_subv_size, 64)
            + iceil(ofm_core_subv_size, 64)
            + iceil(core_scratch_buf_size, 64)
            <= usable_core_size
        )
        if not (mt_valid and core_valid):
            continue

        # Ping-pong needs two copies of every in/out buffer (scratch is single).
        ping_pong = (2 * ifm_mt_subv_size + 2 * ofm_mt_pad_size) <= usable_mt_size and (
            no_row_split or (
                2 * iceil(ifm_core_subv_size, 64)
                + 2 * iceil(ofm_core_subv_size, 64)
                + iceil(core_scratch_buf_size, 64)
                <= usable_core_size
            )
        )

        # Ratio of elements processed across all tiles and loops to real input elements.
        total_loops = N_loop * Y_loop * X_loop * C_loop
        if no_row_split:
            overcompute = total_loops * mt_elems * aie_cols / total_input
        else:
            overcompute = total_loops * core_elems * aie_cols * aie_rows / total_input

        subv_splits.append((
            (Nis, Yis, Xis, Cis),
            (Nim, Yim, Xim, Cim),
            (N_loop, Y_loop, X_loop, C_loop),
            ping_pong, overcompute,
            spatial_split_mode,
            row_split_mode,
        ))

    if not subv_splits:
        # Same exception type as indexing the empty result list, but with
        # enough context to diagnose the failing shape.
        raise IndexError(
            f"generate_subv: no subvolume split of input {list(input)} fits "
            f"(usable_mt_size={usable_mt_size}, usable_core_size={usable_core_size}, "
            f"aie_cols={aie_cols}, aie_rows={aie_rows}, qdq_mode={qdq_mode})"
        )

    core_subv, mt_subv, loop, ping_pong, _, spatial_split_mode, row_split_mode = \
        sort_subv_cost(subv_splits)[0]

    return core_subv, mt_subv, loop, ping_pong, spatial_split_mode, row_split_mode

def do_padding(inDim: int, padding_enable: bool = True):
    """Return ``inDim`` aligned via ``iceil(inDim, 8)``, or unchanged when disabled.

    Args:
        inDim: Dimension size to align.
        padding_enable: When false, ``inDim`` is returned as-is and ``iceil``
            is not called.

    Returns:
        ``iceil(inDim, 8)`` when padding is enabled (presumably ``inDim``
        rounded up to a multiple of 8 -- confirm against
        ``dataflow_common.iceil``), otherwise ``inDim``.
    """
    if not padding_enable:
        return inDim
    alignment = 8
    return iceil(inDim, alignment)


def run_tiler(
    aie_cols: int, aie_rows: int,
    input: List[int], output: List[int],
    ifm_bits: int, ofm_bits: int,
    qdq_mode: int = 3,   # 0: DEQUANT; 1: QUANT; 2: BOTH; 3: NONE
    fix_point_bits: int = 16, #16 or 8
    is_signed: bool = False,
    is_input_max64w8_format: bool = True,

    ) -> PadDims:
    """Pick a tiling for the pad operator and package it as a ``PadDims``.

    Args:
        aie_cols: Number of columns passed to ``generate_subv``.
        aie_rows: Number of rows passed to ``generate_subv``.
        input: Input extents ``[N, Y, X, C]``.
        output: Output extents ``[N, Y, X, C]``.
        ifm_bits: Bits per input element.
        ofm_bits: Bits per output element.
        qdq_mode: 0: DEQUANT; 1: QUANT; 2: BOTH; 3: NONE.
        fix_point_bits: 16 or 8; a scratch buffer is sized only when not 16.
        is_signed: Forwarded unchanged to ``PadDims``.
        is_input_max64w8_format: When true, the input and output channel
            counts are passed through ``do_padding`` before tiling.

    Returns:
        A ``PadDims`` holding the original shapes, the channel counts used
        for tiling, the fixed granularities and limits, the selected
        subvolumes, loop counts and split modes, and the derived buffer
        sizes in bytes.
    """
    Ni, Yi, Xi, Ci = input
    No, Yo, Xo, Co = output

    if is_input_max64w8_format:
        Cip, Cop = do_padding(Ci), do_padding(Co)
    else:
        Cip, Cop = Ci, Co

    # NOTE(review): the channel padding amount is taken from the unpadded
    # Co - Ci, while the tiler below is given Cip / Cop -- confirm this is intended.
    pad_dims = [No - Ni, Yo - Yi, Xo - Xi, Co - Ci]

    # Unit granularity on N/Y/X; channel granularity is 32 // element bits.
    in_gran = [1, 1, 1, 32 // ifm_bits]
    out_gran = [1, 1, 1, 32 // ofm_bits]
    pad_limit = [0, 16, 32, 64]  # N, Y, X, C
    param_subv_size = 1024

    wgt_subv_size = 0 if qdq_mode == 3 else 64  # no qdq for this version

    tiling = generate_subv(
        aie_cols, aie_rows,
        [Ni, Yi, Xi, Cip],
        [No, Yo, Xo, Cop],
        pad_dims,
        in_gran,
        out_gran,
        pad_limit,
        ifm_bits, ofm_bits, fix_point_bits,
        wgt_subv_size,
        qdq_mode,
    )
    core_subv, mt_subv, loop, ping_pong, spatial_split_mode, row_split_mode = tiling

    # Buffer sizes in bytes derived from the chosen subvolumes.
    mt_elems = math.prod(mt_subv)
    padded_mt_elems = math.prod([tile + pad for tile, pad in zip(mt_subv, pad_dims)])
    core_elems = math.prod(core_subv)

    ifm_memtile_size = mt_elems * ifm_bits // 8
    ofm_memtile_size = padded_mt_elems * ofm_bits // 8
    ifm_core_size = core_elems * ifm_bits // 8
    ofm_core_size = core_elems * ofm_bits // 8
    if fix_point_bits == 16:
        scratch_buf_size = 0
    else:
        scratch_buf_size = core_elems * 2

    return PadDims(
        aie_cols=aie_cols,
        aie_rows=aie_rows,
        input=input, Cip=Cip,
        output=output, Cop=Cop,
        pad_dims=pad_dims,
        in_gran=in_gran,
        out_gran=out_gran,
        pad_limit=pad_limit,
        ifm_bits=ifm_bits, ofm_bits=ofm_bits,
        wgt_subv_size=wgt_subv_size,
        qdq_mode=qdq_mode,
        fix_point_bits=fix_point_bits,
        is_signed=is_signed,
        param_subv_size=param_subv_size,
        ifm_memtile_size=ifm_memtile_size,
        ofm_memtile_size=ofm_memtile_size,
        ifm_core_size=ifm_core_size,
        ofm_core_size=ofm_core_size,
        scratch_buf_size=scratch_buf_size,
        core_subv=core_subv,
        mt_subv=mt_subv,
        loop=loop,
        ping_pong=ping_pong,
        spatial_split_mode=spatial_split_mode,
        row_split_mode=row_split_mode,
    )
