import os
CURRDIR = os.path.dirname(os.path.abspath(__file__))
import sys
sys.path.append(os.path.join(CURRDIR, '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..'))
sys.path.append(os.path.join(CURRDIR, '..', '..', 'dmacompiler'))
sys.path.append(CURRDIR)
import json
import logging

from typing import Tuple, List, Optional
import shutil

from dmacompiler import BackEnd

from dataflow_common import clean_overlay, build_sim_overlay
from pooling_common import PoolingDims, iceil, pooling_preproc_directives
from pooling_tiling import pooling_subv_split_mode
import pooling_y8xc_dataflow
from OGOAT.src.L1_fusion.kernel_func_list import kernel_func_list

# Buffer configuration for the int8 path (ifm_bits != 16), keyed by qdq_mode:
#   qdq_mode -> (ifm_bits, ofm_bits, has_scratch_buf, scratch_buf_bits)
# NOTE(review): the "scratch buffer elements" notes below do not always agree
# with has_scratch_buf (modes 0 and 2); values are kept as they were -- confirm.
_INT8_QDQ_BUF_CONFIG = {
    # 0: dq only.
    #   1. transpose (8-bit in) -> 8-bit output buffer (2nd half);
    #   2. dq from the 8-bit output-buffer 2nd half to the 16-bit output buffer.
    #   Scratch buffer elements: 0.
    0: (8, 16, True, 8),
    # 1: q only.
    #   1. q (16-bit input buffer) -> 8-bit, written back to the same buffer;
    #   2. transpose from the 8-bit input buffer to the 8-bit output buffer.
    #   Scratch buffer elements: 0 (q output reuses the ifm buffer).
    1: (16, 8, False, 8),
    # 2: dq + q.
    #   1. dq (8-bit input buffer) -> 16-bit into the scratch buffer;
    #   2. q (16-bit scratch buffer) -> 8-bit into the scratch buffer;
    #   3. transpose from the 8-bit scratch buffer to the 8-bit output buffer.
    #   Scratch buffer elements: same as ifm.
    2: (8, 8, False, 8),
    # 3: no qdq.
    #   1. transpose from the 8-bit input buffer to the 8-bit output buffer.
    #   Scratch buffer elements: 0.
    3: (8, 8, False, 8),
}


def pooling_build_qdq_shape(
    back_end: BackEnd,
    kernel_names: List[str],
    kernel_includes: List[str],
    aie_cols: int,
    aie_rows: int,
    input: Tuple[int, int, int, int],
    output: Tuple[int, int, int, int],
    kernel: Tuple[int, int],
    stride: Tuple[int, int],
    pad: Tuple[int, int, int, int],
    ifm_bits: int,
    qdq_mode: int, is_signed: bool,
    frontend_only: bool = False,
    out_folder: Optional[str] = None,
    padding_enable: bool = True,
    max_or_avg: int = 0, # 0: maxpool, 1:avgpool
):
    """Generate the pooling (+ optional QDQ) dataflow for one layer shape.

    Steps:
      1. optionally pad the channel dims (see below);
      2. pick buffer bit widths from ``ifm_bits`` / ``qdq_mode``;
      3. compute the subvolume split via ``pooling_subv_split_mode``;
      4. write ``tiling.json`` into this script's directory;
      5. compile the dataflow and, unless ``frontend_only``, build the
         simulation overlay;
      6. if ``out_folder`` is given, move the generated artifacts there.

    Channel padding (``padding_enable``, used by the WAIC flow) replaces the
    input channel count Ci with ``iceil(Ci, 8)`` and the output channel count
    Co with ``iceil(Co, 8)``.
    TODO(review): the original notes said the padding would also be annotated
    in tiling.json; that is not done here.

    Args:
        back_end: target back end; ``out_folder`` must be None for
            ``BackEnd.Adf``.
        kernel_names: kernel names forwarded to the dataflow compiler.
        kernel_includes: kernel source includes forwarded to the compiler.
        aie_cols, aie_rows: array shape; only (8, 4) and (4, 4) are accepted.
        input: input activation shape as (N, Y, X, C).
        output: output activation shape as (N, Y, X, C).
        kernel: pooling window (Ky, Kx).
        stride: (Sy, Sx).
        pad: (Py_before, Px_before, Py_after, Px_after).
        ifm_bits: 16 selects the int16 path; any other value the int8 path.
        qdq_mode: int8 path only -- 0: dq, 1: q, 2: dq+q, 3: none.
        is_signed: forwarded to ``PoolingDims``.
        frontend_only: when True, skip ``build_sim_overlay``.
        out_folder: destination for generated files (TxnHostPatch only).
        padding_enable: apply the channel padding described above.
        max_or_avg: 0 for max pooling, 1 for average pooling.

    Raises:
        AssertionError: on an unsupported array shape, back end / out_folder
            combination, or qdq_mode.
    """
    assert (back_end != BackEnd.Adf) or (out_folder is None)
    assert (aie_cols, aie_rows) in ((8, 4), (4, 4))

    if padding_enable:
        # Round each channel dim up with iceil(..., 8); every tensor keeps its
        # own batch dim (the output previously borrowed input[0] by mistake).
        input = (input[0], input[1], input[2], iceil(input[3], 8))
        output = (output[0], output[1], output[2], iceil(output[3], 8))

    Ci_gran = 8
    Co_gran = 8
    X_align = 64

    qdq_bits = 32
    CoreqdqPrmSize = 64
    # Presumably the QDQ parameter block size in bytes (bits -> bytes).
    wgt_subv_size = CoreqdqPrmSize * qdq_bits // 8

    is_int16 = ifm_bits == 16

    # int16 max-pool uses an X granularity of 4; every other case uses 8.
    X_gran = 4 if (is_int16 and not max_or_avg) else 8

    # Buffer widths: the int16 path is fixed; the int8 path depends on qdq_mode.
    qdq_mode = int(qdq_mode)
    if is_int16:
        ifm_bits, ofm_bits, has_scratch_buf, scratch_buf_bits = 16, 16, False, 16
    else:
        assert qdq_mode in _INT8_QDQ_BUF_CONFIG, \
            f"qdq_mode:{qdq_mode} is not in range(0..3) !"
        ifm_bits, ofm_bits, has_scratch_buf, scratch_buf_bits = \
            _INT8_QDQ_BUF_CONFIG[qdq_mode]

    subv_split = pooling_subv_split_mode(
        aie_cols, aie_rows,
        input,
        output,
        kernel,
        stride,
        pad,
        ifm_bits, wgt_subv_size, ofm_bits,
        has_scratch_buf, scratch_buf_bits,
        Ci_gran, Co_gran, X_gran,
    )
    (Ni, Yi, Xi, Ci) = input
    (No, Yo, Xo, Co) = output
    (Ky, Kx) = kernel
    (Sy, Sx) = stride
    (Py_b, Px_b, Py_a, Px_a) = pad

    ((Nis, Yis, Xis, Cis),
     (Nos, Yos, Xos, Cos),
     (Nim, Nom),
     X_split, N_row_split,
     Co_split,
     is_X8_split,
     ifm_streaming_mode,
     spatial_split_mode,
     row_split_mode,
     ) = subv_split

    tiling_json = {
        'input': (Ni, Yi, Xi, Ci),
        'output': (No, Yo, Xo, Co),
        'input_subv': (Nis, Yis, Xis, Cis),
        'output_subv': (Nos, Yos, Xos, Cos),
        'Y_split': aie_cols,
        'X_split': X_split,
        'Co_split': Co_split,
        'is_X8_split': is_X8_split,
        'is_ifm_streaming_mode': ifm_streaming_mode,
    }
    tiling_json_filename = os.path.join(CURRDIR, 'tiling.json')
    with open(tiling_json_filename, 'w') as f:
        f.write(json.dumps(tiling_json, sort_keys=True, indent=4))

    dims = PoolingDims(
        aie_cols, aie_rows,
        Ni, Nim, Nis, No, Nom, Nos, N_row_split,
        Ci, Cis, Ci_gran, Co, Cos, Co_gran, Co_split,
        Yi, Yis, Yo, Yos,
        Xi, Xis, Xo, Xos, X_gran, X_align, X_split,
        Ky, Kx,
        Sy, Sx,
        Py_b, Px_b, Py_a, Px_a,
        ifm_bits, ofm_bits, wgt_subv_size,
        has_scratch_buf, scratch_buf_bits,
        spatial_split_mode, row_split_mode,
        is_X8_split = is_X8_split,
        max_or_avg = max_or_avg,
        qdq_mode = qdq_mode,
        is_signed = is_signed,
    )

    clean_overlay()
    # Resolved against the current working directory, not CURRDIR.
    host_cpp = os.path.join(os.getcwd(), 'pooling_main.cpp')

    pooling_y8xc_dataflow.compile_dataflow(dims, back_end, kernel_names, kernel_includes)

    if not frontend_only:
        build_sim_overlay(back_end, host_cpp, pooling_preproc_directives(dims, back_end))

    if out_folder is not None:
        os.makedirs(out_folder, exist_ok=True)
        assert back_end == BackEnd.TxnHostPatch, \
            f"moving artifacts to out_folder is only supported for TxnHostPatch, got {back_end}"
        artifacts = ('ifm.bin', 'wgt.bin', 'ofm.bin',
                     'tiling.json', 'txn.bin', 'param.bin', 'ctrl.bin', 'patch.json')
        for name in artifacts:
            shutil.move(os.path.join(CURRDIR, name), os.path.join(out_folder, name))

def extract_fields(file_name):
    """Read ``file_name`` and return its parsed JSON content."""
    with open(file_name) as handle:
        return json.load(handle)

def _act_shape_to_nyxc(shape):
    """Return an activation shape from the layer JSON as an (N, Y, X, C) tuple.

    NOTE(review): the layer JSON schema is not visible here. A 4-element shape
    is assumed to already be (N, H, W, C); a 3-element shape is assumed to be
    (H, W, C) and gets N = 1 prepended -- confirm against the producer.

    Raises:
        ValueError: if the shape has neither 3 nor 4 dimensions.
    """
    shape = tuple(shape)
    if len(shape) == 4:
        return shape
    if len(shape) == 3:
        return (1,) + shape
    raise ValueError(f"unsupported activation shape {shape}; expected 3 or 4 dims")


def run_pooling_op(json_file, path, txn_mode, kernel_d, frontend_only,
                   ifm_bits=16, qdq_mode=2, is_signed=True, max_or_avg=0):
    """Build a pooling layer described by a layer JSON file.

    Args:
        json_file: layer JSON; read after ``os.chdir(path)``, so a relative
            path is resolved against ``path``.
        path: directory to run the build in.
        txn_mode: 0 selects ``BackEnd.Adf``; anything else ``TxnHostPatch``.
        kernel_d: optional dict with ``'kernel_list'`` / ``'kernel_include'``;
            when falsy, the int16 pooling kernel and includes are used.
        frontend_only: forwarded to ``pooling_build_qdq_shape``.
        ifm_bits, qdq_mode, is_signed: forwarded to ``pooling_build_qdq_shape``.
            NOTE(review): these are not read from the layer JSON. The defaults
            are assumptions: ifm_bits=16 matches the default
            'run_pooling_a16o16_qdq' kernel; qdq_mode=2 and is_signed=True
            mirror main(). Confirm against the layer metadata.
        max_or_avg: 0 for max pooling (previously hard-coded), 1 for average.

    For non-Adf back ends, artifacts are moved next to ``json_file``.
    """
    os.chdir(path)   # because the build system only works in the current dir (subprocess)
    _data = extract_fields(json_file)
    layer_info = _data['layer_info']
    attributes = layer_info['attributes']

    data = {}
    data['back_end'] = BackEnd.Adf if txn_mode == 0 else BackEnd.TxnHostPatch
    if not kernel_d:
        data['kernel_names'] = {}
        kernel_names = ['run_pooling_a16o16_qdq']
        for k in kernel_names:
            try:
                data['kernel_names'][k] = kernel_func_list.index(k)
            except ValueError:
                print(f"Error: '{k}' not found in the kernel func list!")
        data['kernel_includes'] = ['super.hh', 'pooling/pooling_int16x16_wrapper.cc', ]
    else:
        data['kernel_names'] = kernel_d['kernel_list']
        data['kernel_includes'] = kernel_d['kernel_include']
    data['aie_rows'] = _data['overlay_info']['shape']['row']
    data['aie_cols'] = _data['overlay_info']['shape']['col']
    # Previously the 4th element duplicated shape[2]; normalize to (N, Y, X, C).
    data['input'] = _act_shape_to_nyxc(layer_info['in_act_shape'])
    data['output'] = _act_shape_to_nyxc(layer_info['out_act_shape'])
    data['kernel'] = (attributes['kernel_shape'][0], attributes['kernel_shape'][1])
    data['strides'] = (attributes['strides'][0], attributes['strides'][1])
    pads = attributes['pads']
    data['pad'] = (pads[0], pads[1], pads[2], pads[3])
    data['ifm_bits'] = ifm_bits
    data['qdq_mode'] = qdq_mode
    data['is_signed'] = is_signed

    output_dir = os.path.dirname(os.path.realpath(json_file)) if data['back_end'] != BackEnd.Adf else None
    logging.info(f" Pooling input args: {data}")
    # Optional arguments are passed by keyword: positionally they previously
    # landed in the ifm_bits / qdq_mode slots and is_signed was missing.
    pooling_build_qdq_shape(
        data['back_end'], data['kernel_names'], data['kernel_includes'],
        data['aie_cols'], data['aie_rows'],
        data['input'], data['output'],
        data['kernel'], data['strides'],
        data['pad'],
        data['ifm_bits'], data['qdq_mode'], data['is_signed'],
        frontend_only=frontend_only,
        out_folder=output_dir,
        max_or_avg=max_or_avg,  # 0: max; 1: avg
    )


def main():
    """Standalone entry point: build one hard-coded pooling configuration.

    Activation shapes are (N, Y, X, C). Only the selected preset is built.
    """
    # Alternative: BackEnd.TxnHostPatch
    back_end = BackEnd.Adf
    ifm_bits = 8

    if ifm_bits == 16:
        kernel_names = ['run_pooling_a16o16_qdq']
    else:
        kernel_names = ['run_pooling_a8_qdq']
    kernel_includes = ['super.hh', 'pooling/pooling_wrapper.cc']

    aie_cols, aie_rows = 8, 4

    # Known (input, output) shape presets, each shape as (N, Y, X, C).
    shape_presets = {
        'Maxpool-0': ((1, 128, 128, 384), (1, 64, 64, 384)),
        'Maxpool-1': ((1, 256, 256, 192), (1, 128, 128, 192)),
        'Maxpool-2': ((1, 64, 64, 768), (1, 32, 32, 768)),
        'Maxpool-3': ((1024, 4, 4, 384), (1024, 2, 2, 384)),
        'Maxpool-4': ((1024, 8, 8, 192), (1024, 4, 4, 192)),
        'Maxpool-5': ((25, 14, 14, 768), (25, 7, 7, 768)),
    }
    ifm_shape, ofm_shape = shape_presets['Maxpool-4']

    window = (2, 2)
    strides = (2, 2)
    # (pad_H_start, pad_W_start, pad_H_after, pad_W_after)
    padding = (0, 0, 0, 0)

    pooling_build_qdq_shape(
        back_end, kernel_names, kernel_includes,
        aie_cols, aie_rows,
        ifm_shape, ofm_shape, window, strides, padding,
        ifm_bits, 2, True,          # qdq_mode=2 (dq + q), is_signed=True
        frontend_only=False,
        max_or_avg=0,               # 0: max; 1: avg
    )

if __name__ == "__main__":
    # Build the hard-coded standalone configuration when run as a script.
    main()
