import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'kernels', 'pooling'))
from typing import List, Tuple, Optional, Type
from pooling_kernel_params import setup_kernel_params, PoolingSubvDims

from dmacompiler import \
    CoreInstr, Loop, AcqBuffer, RelBuffer, ConfigBuffer, CallKernel, \
    DmaChannel, DmaDir, BackEnd

from dataflow_common import ceildiv

class PoolingDims:
    """Shape and tiling parameters for a pooling layer mapped onto an AIE array.

    Field naming convention:
      * suffix ``i`` / ``o``: input (ifm) / output (ofm) tensor;
      * suffix ``s``: per-core subvolume; ``Nim`` / ``Nom`` presumably the
        intermediate (memtile-level) N block -- TODO confirm;
      * ``*_gran``: granularity that the matching subvolume extent must be a
        multiple of;
      * N, C, Y, X: presumably batch, channel, height and width;
      * Ky/Kx, Sy/Sx: pooling window size and stride;
      * Py_b/Px_b, Py_a/Px_a: padding before / after along Y / X.

    The constructor checks the tiling with ``assert`` (these checks are
    skipped under ``python -O``). It also derives loop trip counts and the
    per-core subvolume sizes in bytes.
    """

    def __init__(
        self,
        aie_cols: int, aie_rows: int,
        Ni: int, Nim: int, Nis: int, No: int, Nom: int, Nos: int, N_row_split: int,
        Ci: int, Cis: int, Ci_gran: int, Co: int, Cos: int, Co_gran: int, Co_split: int,
        Yi: int, Yis: int, Yo: int, Yos: int,
        Xi: int, Xis: int, Xo: int, Xos: int, X_gran: int, X_align: int, X_split: int,
        Ky: int, Kx: int,
        Sy: int, Sx: int,
        Py_b: int, Px_b: int, Py_a: int, Px_a: int,
        ifm_bits: int, ofm_bits: int,
        wgt_subv_size: int,
        has_scratch_buf: bool,
        scratch_buf_bits: int,
        spatial_split_mode: list,
        row_split_mode: list,
        ifm_use_hwc_format: bool = True,
        ofm_use_hwc_format: bool = True,
        is_X8_split: bool = False,
        max_or_avg: int = 0,  # 0: max, 1: avg
        qdq_mode: int = 3,  # 0: DEQUANT; 1: QUANT; 2: BOTH; 3: NONE
        is_signed: bool = False,
        is_int16: bool = True,
    ):
        """Validate the tiling and store/derive all layer parameters.

        Notes:
            * ``spatial_split_mode[0]`` divides the N trip count and
              ``spatial_split_mode[1]`` divides the Y trip count (the latter
              only when ``is_X8_split`` is False). ``row_split_mode`` is only
              stored here.
            * ``is_int16`` is accepted for backward compatibility but is NOT
              used: ``self.is_int16`` is always derived from ``ifm_bits == 16``.

        Raises:
            AssertionError: if the dimensions are inconsistent.
        """
        self.max_or_avg = max_or_avg
        assert max_or_avg == 0 or max_or_avg == 1
        Y_gran = 1
        # Bytes occupied by one Ci_gran channel group.
        Ci_block = (Ci_gran * ifm_bits) // 8
        assert Yo == pooling_output(Yi, Ky, Sy, Py_b, Py_a)
        assert Xo == pooling_output(Xi, Kx, Sx, Px_b, Px_a)
        assert Yos == pooling_output(Yis, Ky, Sy, 0, 0)
        # Xis is the input width needed for Xos outputs, with its byte size
        # (counted in Ci_block units) rounded up to X_align.
        assert Xis == iceil(pooling_input(Xos, Kx, Sx) * Ci_block, X_align) // Ci_block
        assert (Cis % Ci_gran) == 0
        assert (Cos % Co_gran) == 0
        assert (Xos % X_gran) == 0

        assert ((Yos % Y_gran) == 0)
        self.aie_cols = aie_cols
        self.aie_rows = aie_rows
        self.Ni = Ni
        self.Nim = Nim
        self.Nis = Nis
        self.No = No
        self.Nom = Nom
        self.Nos = Nos
        self.Ci = Ci
        self.Cis = Cis
        self.Ci_gran = Ci_gran
        self.Co = Co
        self.Cos = Cos
        self.Co_gran = Co_gran
        self.Co_split = Co_split
        self.Ci_loop = ceildiv(Ci, Cis)
        self.Co_loop = ceildiv(Co, (Co_split * Cos))
        assert ((Ci % Cis) == 0) or (Ci < Cis)
        assert (Co % (Co_split * Cos) == 0) or (Co < (Co_split * Cos))
        self.Yi = Yi
        self.Yis = Yis
        self.Yo = Yo
        self.Yos = Yos
        self.N_loop = ceildiv(Ni, (spatial_split_mode[0] * Nim))
        # With the X8 split the Y axis is not divided by spatial_split_mode.
        self.Y_loop = ceildiv(Yo, (spatial_split_mode[1] * Yos)) if not is_X8_split else ceildiv(Yo, Yos)
        self.X_loop = ceildiv(Xo, (self.aie_cols * Xos))
        self.is_X8_split = is_X8_split
        self.Xi = Xi
        self.Xis = Xis
        self.Xo = Xo
        self.Xos = Xos
        self.X_gran = X_gran
        self.X_align = X_align
        self.X_split = X_split
        self.N_row_split = N_row_split
        if not is_X8_split:
            # Without the X8 split, the X_split row groups must cover all of Xo.
            assert Xo <= X_split * Xos
        # Every AIE row is assigned exactly one (X, N, Co) split position,
        # for both the regular and the X8 split layouts.
        assert X_split * N_row_split * Co_split == self.aie_rows
        self.Ky = Ky
        self.Kx = Kx
        self.Sy = Sy
        self.Sx = Sx
        self.Py_b = Py_b
        self.Px_b = Px_b
        self.Py_a = Py_a
        self.Px_a = Px_a
        self.ifm_bits = ifm_bits
        self.ofm_bits = ofm_bits
        # Per-core subvolume sizes in bytes.
        self.ifm_subv_size = (Nis * Cis * Yis * Xis * ifm_bits) // 8
        self.wgt_subv_size = wgt_subv_size
        self.ofm_subv_size = (Nos * Cos * Yos * Xos * ofm_bits) // 8
        self.param_subv_size = 1024
        self.ifm_use_hwc_format = ifm_use_hwc_format
        self.ofm_use_hwc_format = ofm_use_hwc_format
        self.shim_BD_num = {'ifm': 10, 'wgt': 1, 'ofm': 4, 'prm': 1}
        self.qdq_mode = qdq_mode
        self.is_signed = is_signed
        # NOTE: derived from ifm_bits; the is_int16 argument is ignored.
        self.is_int16 = ifm_bits == 16
        self.has_scratch_buf = has_scratch_buf
        self.scratch_buf_bits = scratch_buf_bits
        self.scratch_buf_size = (Nos * Cos * Yos * Xos * scratch_buf_bits) // 8 if has_scratch_buf else 0
        self.spatial_split_mode = spatial_split_mode
        self.row_split_mode = row_split_mode
        # Element count of the subvolume the QDQ step operates on:
        #  - int16: always the ofm subvolume;
        #  - int8 with qdq_mode == 0 (dequant): the ifm subvolume (before pooling);
        #  - int8 otherwise: the ofm subvolume (after pooling).
        self.subv_elem = Cis * Yis * Xis if (ifm_bits == 8 and qdq_mode == 0) else Cos * Yos * Xos

class CoreAlloc:
    """Core-local buffer addresses for one pooling kernel instance.

    Any ``*_pong_addr`` may be None, which means that buffer is
    single-buffered. The tdm, ifm-sum and tmp addresses may be None when the
    kernel does not use those buffers.
    """

    def __init__(
        self,
        ifm_ping_addr: int,
        ifm_pong_addr: Optional[int],
        wgt_ping_addr: int,
        wgt_pong_addr: Optional[int],
        ofm_ping_addr: int,
        ofm_pong_addr: Optional[int],
        tdm_ping_addr: Optional[int],
        tdm_pong_addr: Optional[int],
        ifm_sum_addr: Optional[int],
        scratch_buf: int,
        tmp_buf: Optional[int],
    ):
        """Store the given addresses unchanged; no validation is performed."""
        self.ifm_ping_addr = ifm_ping_addr
        self.ifm_pong_addr = ifm_pong_addr
        self.wgt_ping_addr = wgt_ping_addr
        self.wgt_pong_addr = wgt_pong_addr
        self.ofm_ping_addr = ofm_ping_addr
        self.ofm_pong_addr = ofm_pong_addr
        self.tdm_ping_addr = tdm_ping_addr
        self.tdm_pong_addr = tdm_pong_addr
        self.ifm_sum_addr = ifm_sum_addr
        self.scratch_buf = scratch_buf
        self.tmp_buf = tmp_buf

def iceil(x: int, d: int) -> int:
    """Round ``x`` up to the nearest multiple of ``d`` (via ``ceildiv``)."""
    num_units = ceildiv(x, d)
    return num_units * d

def pooling_output(input: int, kernel: int, stride: int, pad_before: int, pad_after: int) -> int:
    """Return the pooled output extent along one axis.

    Computes ``floor((pad_before + input + pad_after - kernel) / stride) + 1``.
    """
    padded_extent = pad_before + input + pad_after
    return (padded_extent - kernel) // stride + 1

def pooling_input(output: int, kernel: int, stride: int) -> int:
    """Return the input extent needed to produce ``output`` pooled values.

    This is the (unpadded) inverse of ``pooling_output``:
    ``(output - 1) * stride + kernel``.
    """
    return kernel + stride * (output - 1)

def X_index(dims: PoolingDims, row: int) -> int:
    """Return the X-split position (``row % dims.X_split``) of AIE ``row``.

    Raises:
        AssertionError: if ``row`` is outside ``[0, dims.aie_rows)``.
    """
    assert 0 <= row < dims.aie_rows
    _, x_position = divmod(row, dims.X_split)
    return x_position

def Co_index(dims: PoolingDims, row: int) -> int:
    """Return which Co-split group AIE ``row`` belongs to.

    The rows are divided into ``dims.Co_split`` contiguous groups of
    ``dims.aie_rows // dims.Co_split`` rows each.

    Raises:
        AssertionError: if ``row`` is outside ``[0, dims.aie_rows)``.
    """
    assert 0 <= row < dims.aie_rows
    rows_per_group = dims.aie_rows // dims.Co_split
    return row // rows_per_group

# def Xi_slice(dims: PoolingDims, row: int) -> Tuple[int, int, int, int]:
#     row_valid = row if dims.row_split_mode[2] == dims.aie_rows else 0
#     Xi_stride = dims.Xos * dims.Sx
#     Xi_start = (X_index(dims, row_valid) * Xi_stride) - dims.Px_b
#     if Xi_start >= dims.Xi:
#         Xi_start = 0
#     Xi_stop = (
#         min(Xi_start + pooling_input(dims.Xos, dims.Kx, dims.Sx),
#             dims.Xi + dims.Px_a) if Xi_start <= dims.Xi else
#         Xi_start
#     )
#     Xi_size = Xi_stop - Xi_start
#     return (Xi_start, Xi_stop, Xi_stride, Xi_size)

def Xi_slice(dims: PoolingDims, row: int) -> Tuple[int, int, int, int]:
    """Return ``(start, stop, stride, size)`` of the input-X window for ``row``.

    The window starts at the row's X-split position times
    ``Xos * Sx``, shifted left by ``Px_b``, so ``start`` may be negative.
    A start at or beyond ``dims.Xi`` is reset to 0. The window always spans
    ``pooling_input(Xos, Kx, Sx)`` columns. It is not clipped against
    ``Xi + Px_a``.
    """
    stride = dims.Sx * dims.Xos
    start = X_index(dims, row) * stride - dims.Px_b
    if start >= dims.Xi:
        start = 0
    size = pooling_input(dims.Xos, dims.Kx, dims.Sx)
    return (start, start + size, stride, size)

def Xo_slice(dims: PoolingDims, row: int) -> Tuple[int, int, int, int]:
    """Return ``(start, stop, stride, size)`` of the output-X range for ``row``.

    The range starts at ``X_index * Xos`` and is clipped to ``dims.Xo``.
    A row whose start is at or beyond ``Xo`` gets an empty range
    (``stop == start``).
    """
    stride = dims.Xos
    start = X_index(dims, row) * stride
    stop = start
    if start < dims.Xo:
        stop = min(start + stride, dims.Xo)
    return (start, stop, stride, stop - start)

def Co_slice(dims: PoolingDims, row: int) -> Tuple[int, int, int, int]:
    """Return ``(start, stop, stride, size)`` of the output channels for ``row``.

    The row's Co-split group ``g`` covers channels ``[g * Cos, g * Cos + Cos)``,
    clipped to ``dims.Co``. ``stride`` is the channel step between
    successive Co iterations (``Cos * Co_split``).

    When ``Co < Co_split * Cos`` (permitted by ``PoolingDims``), a group can
    start at or beyond ``Co``. It then gets an empty slice (``stop == start``,
    ``size == 0``), which matches ``Xo_slice``. Previously ``stop`` was
    clamped below ``start``, giving a negative size.
    """
    Co_stride = dims.Cos * dims.Co_split
    Co_start = Co_index(dims, row) * dims.Cos
    Co_stop = (
        min(Co_start + dims.Cos, dims.Co) if Co_start < dims.Co else
        Co_start
    )
    Co_size = Co_stop - Co_start
    return (Co_start, Co_stop, Co_stride, Co_size)

def ifm_core_memory(dims: PoolingDims) -> str:
    """Describe the ifm buffer layout in core memory.

    Dimensions are listed outermost first: Yi, Ci (per subvolume), Xi, then
    an innermost Ci block of ``Ci_gran`` channels.
    """
    y_extent, c_extent, x_extent, c_block = dims.Yis, dims.Cis, dims.Xis, dims.Ci_gran
    return f'Yi:{y_extent} Ci:{c_extent} Xi:{x_extent} Ci:{c_block}'

def ifm_core_s2mm(dims: PoolingDims, row: int) -> str:
    """Build the core-side S2MM access pattern for the ifm of AIE ``row``.

    Fields look like ``name:start:stop[:step]`` (presumably -- confirm with
    the dmacompiler pattern syntax). The Xi extent is the row's input-window
    size from ``Xi_slice``.
    """
    Xi_size = Xi_slice(dims, row)[3]
    return f'Yi:0:{dims.Yis} Ci:0:{dims.Cis}:{dims.Ci_gran} Xi:0:{Xi_size} Ci:0:{dims.Ci_gran}'

def ofm_core_memory(dims: PoolingDims) -> str:
    """Describe the ofm buffer layout in core memory.

    Dimensions are listed outermost first: Yo, Co (per subvolume), Xo, then
    an innermost Co block of ``Co_gran`` channels.
    """
    y_extent, c_extent, x_extent, c_block = dims.Yos, dims.Cos, dims.Xos, dims.Co_gran
    return f'Yo:{y_extent} Co:{c_extent} Xo:{x_extent} Co:{c_block}'

def ofm_core_mm2s(dims: PoolingDims, row: int) -> str:
    """Build the core-side MM2S access pattern for the ofm of AIE ``row``.

    The Xo and Co extents are the row's valid sizes from ``Xo_slice`` and
    ``Co_slice``. The Yo extent is the full ``Yos`` subvolume height.
    """
    Xo_size = Xo_slice(dims, row)[-1]
    Co_size = Co_slice(dims, row)[-1]
    return f'Yo:0:{dims.Yos} Co:0:{Co_size}:{dims.Co_gran} Xo:0:{Xo_size} Co:0:{dims.Co_gran}'

def pooling_preproc_directives(dims: PoolingDims, back_end: BackEnd) -> List[str]:
    """Return the preprocessor ``-D`` defines that describe this pooling layer.

    For ``BackEnd.Adf`` each define is wrapped as ``--Xpreproc="-DNAME=VALUE"``.
    Any other back end gets the bare ``-DNAME=VALUE`` form. ``TXN_MODE`` is 0
    for ADF and 1 otherwise.
    """
    use_adf_syntax = back_end == BackEnd.Adf
    txn_mode = int(back_end != BackEnd.Adf)
    defines = (
        ('AIE_ROWS', dims.aie_rows),
        ('AIE_COLS', dims.aie_cols),
        ('N_IN', dims.Ni),
        ('C_IN', dims.Ci),
        ('C_IN_SUBV', dims.Cis),
        ('Y_IN', dims.Yi),
        ('X_IN', dims.Xi),
        ('N_OUT', dims.No),
        ('C_OUT', dims.Co),
        ('C_OUT_SUBV', dims.Cos),
        ('C_OUT_SPLIT', dims.Co_split),
        ('Y_OUT', dims.Yo),
        ('X_OUT', dims.Xo),
        ('KERNEL_Y', dims.Ky),
        ('KERNEL_X', dims.Kx),
        ('STRIDE_Y', dims.Sy),
        ('STRIDE_X', dims.Sx),
        ('PAD_Y_BEFORE', dims.Py_b),
        ('PAD_X_BEFORE', dims.Px_b),
        ('PAD_Y_AFTER', dims.Py_a),
        ('PAD_X_AFTER', dims.Px_a),
        ('MAX_OR_AVG', dims.max_or_avg),
        ('QDQ_SIZE', dims.wgt_subv_size),
        ('QDQ_MODE', dims.qdq_mode),
        ('IS_INT16', int(dims.ifm_bits == 16)),
        ('IS_SIGNED', int(dims.is_signed)),
        ('IFM_IS_HWC', int(dims.ifm_use_hwc_format)),
        ('OFM_IS_HWC', int(dims.ofm_use_hwc_format)),
        ('TXN_MODE', txn_mode),
    )
    if use_adf_syntax:
        return [f'--Xpreproc="-D{ident}={val}"' for ident, val in defines]
    return [f"-D{ident}={val}" for ident, val in defines]

def pooling_core_alloc(dims: PoolingDims, stack_addr: int) -> CoreAlloc:
    """Assign core-local addresses to the pooling kernel's buffers.

    Buffers are placed in this order, each start aligned to 64 bytes:
        wgt (ping only) -> scratch -> ifm ping -> ofm ping -> ifm pong -> ofm pong

    Three layouts are tried in turn. The first whose last buffer ends at or
    below ``stack_addr`` is used:
      1. double-buffered, and ofm ping / ifm pong / ofm pong each start no
         lower than the 1st / 2nd / 3rd 16 KiB bank boundary;
      2. double-buffered, packed back to back with no bank padding;
      3. single-buffered (ifm/ofm pong addresses are None).

    Args:
        dims: Layer dimensions. Reads the wgt, scratch, ifm and ofm buffer sizes.
        stack_addr: Highest address the buffers may end at.

    Returns:
        A CoreAlloc whose wgt pong, tdm, ifm-sum and tmp addresses are None.

    Raises:
        AssertionError: if even the single-buffered layout does not fit.
    """
    core_bank_size = 16384
    align = 64

    def fits(end_addr: int) -> bool:
        return end_addr <= stack_addr

    wgt_ping_addr = 0
    scratch_buf = iceil(wgt_ping_addr + dims.wgt_subv_size, align)
    # ifm ping directly follows scratch in every layout.
    ifm_ping_addr = iceil(scratch_buf + dims.scratch_buf_size, align)

    # Layout 1: double-buffered, each later buffer pushed to its own bank.
    ofm_ping_addr = max(1 * core_bank_size, iceil(ifm_ping_addr + dims.ifm_subv_size, align))
    ifm_pong_addr = max(2 * core_bank_size, iceil(ofm_ping_addr + dims.ofm_subv_size, align))
    ofm_pong_addr = max(3 * core_bank_size, iceil(ifm_pong_addr + dims.ifm_subv_size, align))

    if not fits(ofm_pong_addr + dims.ofm_subv_size):
        # Layout 2: double-buffered, packed with no bank padding.
        ofm_ping_addr = iceil(ifm_ping_addr + dims.ifm_subv_size, align)
        ifm_pong_addr = iceil(ofm_ping_addr + dims.ofm_subv_size, align)
        ofm_pong_addr = iceil(ifm_pong_addr + dims.ifm_subv_size, align)

        if not fits(ofm_pong_addr + dims.ofm_subv_size):
            # Layout 3: single-buffered; keeps the packed ping addresses.
            ifm_pong_addr = None
            ofm_pong_addr = None
            assert fits(ofm_ping_addr + dims.ofm_subv_size), (
                f'pooling buffers end at {ofm_ping_addr + dims.ofm_subv_size}, '
                f'beyond stack_addr {stack_addr}'
            )

    return CoreAlloc(
        ifm_ping_addr=ifm_ping_addr,
        ifm_pong_addr=ifm_pong_addr,
        wgt_ping_addr=wgt_ping_addr,
        wgt_pong_addr=None,
        ofm_ping_addr=ofm_ping_addr,
        ofm_pong_addr=ofm_pong_addr,
        tdm_ping_addr=None,
        tdm_pong_addr=None,
        ifm_sum_addr=None,
        scratch_buf=scratch_buf,
        tmp_buf=None,
    )

def pooling_a16o16_qdq_kernel_name() -> str:
    """Return the kernel symbol used when the ifm is 16-bit."""
    symbol = 'run_pooling_a16o16_qdq'
    return symbol

def pooling_a8_qdq_kernel_name() -> str:
    """Return the kernel symbol used when the ifm is not 16-bit."""
    symbol = 'run_pooling_a8_qdq'
    return symbol

def pooling_call_kernel(
    dims: PoolingDims,
    core_alloc: CoreAlloc,
) -> CallKernel:
    """Build the CallKernel instruction that runs one pooling subvolume.

    Selects the 16-bit or 8-bit kernel from ``dims.is_int16``. The kernel
    parameter blob comes from ``setup_kernel_params``, which receives the
    subvolume dims and the core's scratch buffer address.
    """
    kernel_name = (
        pooling_a16o16_qdq_kernel_name() if dims.is_int16
        else pooling_a8_qdq_kernel_name()
    )
    subv_dims = PoolingSubvDims(
        dims.Yos, dims.Xos, dims.Cos,
        dims.Ky, dims.Kx,
        dims.Sy, dims.Sx,
        dims.X_gran, dims.Co_gran,
        dims.ifm_bits,
        dims.subv_elem,
        dims.max_or_avg,
        dims.is_signed,
        core_alloc.scratch_buf,
    )
    params_blob = setup_kernel_params(subv_dims, dims.is_int16, dims.is_signed)
    return CallKernel(kernel_name, kernel_params=params_blob)

def pooling_core_instrs(
    dims: PoolingDims,
    core_alloc: CoreAlloc,
    outer_loop: int,
    inner_loop: int,
    ifm_config: Optional[ConfigBuffer] = None,
    wgt_config: Optional[ConfigBuffer] = None,
    ofm_config: Optional[ConfigBuffer] = None,
) -> List[CoreInstr]:
    """Build the core instruction program for the pooling layer.

    Channels used: ifm on S2MM 0, wgt on S2MM 1, ofm on MM2S 0. Any buffer
    config not supplied is built from ``core_alloc`` and the subvolume sizes
    in ``dims``.

    The wgt buffer is acquired and released once, up front. Then
    ``outer_loop * inner_loop`` compute steps run. Each step acquires the ofm
    and ifm buffers, calls the kernel, and releases them.

    When ``outer_loop * inner_loop > 1024``, the ifm and ofm configs are
    re-issued at the start of every outer iteration instead of once at the
    top. NOTE(review): this looks like a per-config iteration limit -- confirm.

    Returns:
        The list of core instruction objects.
    """
    if ifm_config is None:
        ifm_config = ConfigBuffer(
            DmaChannel(DmaDir.S2MM, 0),
            core_alloc.ifm_ping_addr, core_alloc.ifm_pong_addr, dims.ifm_subv_size
        )
    if wgt_config is None:
        wgt_config = ConfigBuffer(
            DmaChannel(DmaDir.S2MM, 1),
            core_alloc.wgt_ping_addr, core_alloc.wgt_pong_addr, dims.wgt_subv_size
        )
    if ofm_config is None:
        ofm_config = ConfigBuffer(
            DmaChannel(DmaDir.MM2S, 0),
            core_alloc.ofm_ping_addr, core_alloc.ofm_pong_addr, dims.ofm_subv_size
        )

    def _compute_step() -> list:
        # One subvolume: acquire ofm + ifm, run the kernel, release both.
        return [
            AcqBuffer(DmaChannel(DmaDir.MM2S, 0)),
            AcqBuffer(DmaChannel(DmaDir.S2MM, 0)),
            pooling_call_kernel(dims, core_alloc),
            RelBuffer(DmaChannel(DmaDir.S2MM, 0)),
            RelBuffer(DmaChannel(DmaDir.MM2S, 0)),
        ]

    reconfig_threshold = 1024

    if outer_loop * inner_loop > reconfig_threshold:
        return [
            wgt_config,
            AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
            RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
            Loop(outer_loop, [
                ifm_config,
                ofm_config,
                Loop(inner_loop, _compute_step()),
            ]),
        ]

    return [
        ifm_config,
        wgt_config,
        ofm_config,
        AcqBuffer(DmaChannel(DmaDir.S2MM, 1)),
        RelBuffer(DmaChannel(DmaDir.S2MM, 1)),
        Loop(outer_loop, [
            Loop(inner_loop, _compute_step()),
        ]),
    ]
