#ifndef __POOLING_WRAPPER_CC__
#define __POOLING_WRAPPER_CC__
#include "maxpool_int16x16_impl.hpp"
#include "maxpool_int8x8_impl.hpp"
#include "qdq/wrapper_qdq.hpp"
#include "qdq/qdq_sum.hpp"
#include "qdq/qdq_int16_bfloat16.hpp"
#include "qdq/qdq.cc"



/**
 * Forwards a pooling request to the int16x16 max-pool kernel.
 *
 * The kernel is called with `int*` buffers, so the byte-typed pointers are
 * reinterpreted before the call; no data is copied or converted here.
 *
 * @param input          Source buffer handed to maxpool_int16x16.
 * @param output         Destination buffer handed to maxpool_int16x16.
 * @param kernel_params  Kernel configuration, passed through unchanged.
 */
void run_int16_pooling(int8_t* input, int8_t* output, MaxpoolInt16x16Params& kernel_params)
{
    int* src_words = reinterpret_cast<int*>(input);
    int* dst_words = reinterpret_cast<int*>(output);

    maxpool_int16x16(src_words, dst_words, kernel_params);
}

/**
 * Forwards a pooling request to the int8x8 max-pool kernel.
 *
 * The kernel is called with `int*` buffers, so the byte-typed pointers are
 * reinterpreted before the call; no data is copied or converted here.
 *
 * @param input          Source buffer handed to maxpool_int8x8.
 * @param output         Destination buffer handed to maxpool_int8x8.
 * @param kernel_params  Kernel configuration, passed through unchanged.
 */
void run_int8_pooling(int8_t* input, int8_t* output, MaxpoolInt8x8Params& kernel_params)
{
    int* src_words = reinterpret_cast<int*>(input);
    int* dst_words = reinterpret_cast<int*>(output);

    maxpool_int8x8(src_words, dst_words, kernel_params);
}

/**
 * Pooling entry point built on the int16x16 max-pool kernel, followed by
 * dequantize and quantize passes on the result.
 *
 * Buffers taken from `args`:
 *   - params_data    : LayerParams blob (see the local struct below)
 *   - s2mm_ch0_data  : pooling input
 *   - s2mm_ch1_data  : QDQ parameter block
 *   - mm2s_ch0_data  : pooling output; also the in-place buffer for QDQ
 *
 * Order of operations: pool (input -> output), then dequantize and quantize,
 * both reading from and writing to the output buffer.
 *
 * @param args  Kernel argument bundle providing the four buffers above.
 */
void run_pooling_a16o16_qdq(KernelArgs& args)
{
    // Overlay for the blob behind args.params_data. The field order and widths
    // define the memory layout, so they must stay exactly as they are. Only
    // dtype, subv_elems, is_signed and kernel_params are read in this function.
    struct LayerParams
    {
        uint16_t max_or_avg;
        uint16_t dtype;
        uint16_t subv_elems;
        uint16_t is_signed;
        uint16_t scratch_buf;
        uint16_t dummy;
        MaxpoolInt16x16Params kernel_params;
    };

    // Byte offsets of each entry inside the QDQ parameter block.
    // Every entry sits in its own 4-byte slot.
    constexpr int kSlotBytes      = 4;
    constexpr int kDqZeroPointOff = 0 * kSlotBytes;
    constexpr int kDqScaleOff     = 1 * kSlotBytes;
    constexpr int kQZeroPointOff  = 2 * kSlotBytes;
    constexpr int kQScaleOff      = 3 * kSlotBytes;
    constexpr int kDqEnableOff    = 4 * kSlotBytes;
    constexpr int kQEnableOff     = 5 * kSlotBytes;

    LayerParams* cfg = static_cast<LayerParams*>(args.params_data);
    MaxpoolInt16x16Params& pool_cfg = cfg->kernel_params;

    int8_t* src      = static_cast<int8_t*>(args.s2mm_ch0_data);
    int8_t* dst      = static_cast<int8_t*>(args.mm2s_ch0_data);
    int8_t* qdq_blob = static_cast<int8_t*>(args.s2mm_ch1_data);

    uint16_t* dq_zero_point = reinterpret_cast<uint16_t*>(byte_incr(qdq_blob, kDqZeroPointOff));
    bfloat16* dq_scale      = reinterpret_cast<bfloat16*>(byte_incr(qdq_blob, kDqScaleOff));
    uint16_t* q_zero_point  = reinterpret_cast<uint16_t*>(byte_incr(qdq_blob, kQZeroPointOff));
    bfloat16* q_scale       = reinterpret_cast<bfloat16*>(byte_incr(qdq_blob, kQScaleOff));
    uint16_t* dq_on         = reinterpret_cast<uint16_t*>(byte_incr(qdq_blob, kDqEnableOff));
    uint16_t* q_on          = reinterpret_cast<uint16_t*>(byte_incr(qdq_blob, kQEnableOff));

    bool wide_dtype   = (cfg->dtype == 16);
    bool signed_dtype = (cfg->is_signed == 1);

    // Max-pool first, then run dequantize -> quantize in place on the result.
    // (Original note: an avg-pool flow would instead be dq -> avgpool -> q;
    // that flow is not implemented in this function.)
    run_int16_pooling(src, dst, pool_cfg);
    dequant_int16_to_bf16(dst, dst, cfg->subv_elems, *dq_zero_point, *dq_scale, signed_dtype, *dq_on, wide_dtype);
    quant_bf16_to_int16(dst, dst, cfg->subv_elems, *q_zero_point, *q_scale, signed_dtype, *q_on, wide_dtype);
}


/**
 * Pooling entry point built on the int8x8 max-pool kernel, combined with
 * dequantize/quantize passes whose placement depends on the two enable
 * flags stored in the QDQ parameter block.
 *
 * Buffers taken from `args`:
 *   - params_data    : LayerParams blob (see the local struct below)
 *   - s2mm_ch0_data  : pooling input (also reused as a work buffer)
 *   - s2mm_ch1_data  : QDQ parameter block
 *   - mm2s_ch0_data  : pooling output
 * A further scratch buffer is obtained from LayerParams::scratch_buf via
 * conv_to_local_ptr.
 *
 * Data flow per flag combination (mode numbers follow the original notes):
 *   dq on,  q on  (mode 2): pool in -> out, dq out -> in, q in -> out
 *   dq on,  q off (mode 0): pool in -> scratch, dq scratch -> out, q out -> out
 *   dq off, q on  (mode 1): dq in -> in, q in -> in, pool in -> out
 *   dq off, q off (mode 3): pool in -> out, dq out -> out, q out -> out
 *
 * NOTE(review): the original comments call the dq/q calls made with a cleared
 * enable flag "redundant"/"bypassed", i.e. they are presumably pass-through.
 * That behaviour lives in the qdq helpers and is not visible here — confirm.
 *
 * Original sizing note for the scratch buffer, kept as written: for every
 * qdq mode, scratch_size = min(2x output, maxpool output).
 *
 * @param args  Kernel argument bundle providing the buffers above.
 */
void run_pooling_a8_qdq(KernelArgs& args)
{
    // Overlay for the blob behind args.params_data. The field order and widths
    // define the memory layout, so they must stay exactly as they are.
    struct LayerParams
    {
        uint16_t max_or_avg;
        uint16_t dtype;
        uint16_t subv_elems;
        uint16_t is_signed;
        uint16_t scratch_buf;
        uint16_t dummy;

        MaxpoolInt8x8Params kernel_params;
    };

    // Byte offsets of each entry inside the QDQ parameter block.
    // Every entry sits in its own 4-byte slot.
    constexpr int kSlotBytes      = 4;
    constexpr int kDqZeroPointOff = 0 * kSlotBytes;
    constexpr int kDqScaleOff     = 1 * kSlotBytes;
    constexpr int kQZeroPointOff  = 2 * kSlotBytes;
    constexpr int kQScaleOff      = 3 * kSlotBytes;
    constexpr int kDqEnableOff    = 4 * kSlotBytes;
    constexpr int kQEnableOff     = 5 * kSlotBytes;

    LayerParams* cfg = static_cast<LayerParams*>(args.params_data);
    MaxpoolInt8x8Params& pool_cfg = cfg->kernel_params;

    int8_t* src      = static_cast<int8_t*>(args.s2mm_ch0_data);
    int8_t* dst      = static_cast<int8_t*>(args.mm2s_ch0_data);
    int8_t* qdq_blob = static_cast<int8_t*>(args.s2mm_ch1_data);

    uint16_t* dq_zero_point = reinterpret_cast<uint16_t*>(byte_incr(qdq_blob, kDqZeroPointOff));
    bfloat16* dq_scale      = reinterpret_cast<bfloat16*>(byte_incr(qdq_blob, kDqScaleOff));
    uint16_t* q_zero_point  = reinterpret_cast<uint16_t*>(byte_incr(qdq_blob, kQZeroPointOff));
    bfloat16* q_scale       = reinterpret_cast<bfloat16*>(byte_incr(qdq_blob, kQScaleOff));
    uint16_t* dq_on         = reinterpret_cast<uint16_t*>(byte_incr(qdq_blob, kDqEnableOff));
    uint16_t* q_on          = reinterpret_cast<uint16_t*>(byte_incr(qdq_blob, kQEnableOff));

    bool wide_dtype   = (cfg->dtype == 16);
    bool signed_dtype = (cfg->is_signed == 1);
    int8_t* scratch   = static_cast<int8_t*>(conv_to_local_ptr(cfg->scratch_buf));

    // Disabled debug dump of the flags and kernel parameters for one core.
    // int col_id = get_coreid() >> 16;
    // int row_id = get_coreid() & 0xf;
    // if(col_id == 0 && row_id == 2){
    //     printf("------------COL0::ROW0: start------------\n");
    //     printf("is_int16: %d\n", is_int16);
    //     printf("is_signed: %d\n", is_signed);
    //     printf("dq_enable: %d\n", *dq_enable);
    //     printf("q_enable: %d\n", *q_enable);
    //     printf("kernel_params.outer_loop: %d\n", kernel_params.outer_loop);
    //     printf("kernel_params.inner_loop: %d\n", kernel_params.inner_loop);
    //     printf("kernel_params.step_Ky: %d\n", kernel_params.step_Ky);
    //     printf("kernel_params.shfl_0: %d\n", kernel_params.shfl_0);
    //     printf("kernel_params.shfl_1: %d\n", kernel_params.shfl_1);
    //     printf("kernel_params.shft_0: %d\n", kernel_params.shft_0);
    //     printf("kernel_params.shft_1: %d\n", kernel_params.shft_1);
    //     printf("kernel_params.shft_2: %d\n", kernel_params.shft_2);
    //     printf("kernel_params.min_value: %d\n", kernel_params.min_value);
    //     printf("kernel_params.ctrl.sign: %d\n", kernel_params.ctrl.sign);
    //     printf("kernel_params.dimsA.num0: %d\n", kernel_params.dimsA.num0);
    //     printf("kernel_params.dimsA.num1: %d\n", kernel_params.dimsA.num1);
    //     printf("kernel_params.dimsA.inc0: %d\n", kernel_params.dimsA.inc0);
    //     printf("kernel_params.dimsA.inc1: %d\n", kernel_params.dimsA.inc1);
    //     printf("kernel_params.dimsA.inc2: %d\n", kernel_params.dimsA.inc2);
    //     printf("------------COL0::ROW0: start------------\n");
    // }

    // Small wrappers so each mode below reads as a plain list of steps.
    // The parameter values are dereferenced inside the wrappers, so they are
    // still read at the time of each call.
    auto dequantize = [&](int8_t* from, int8_t* to) {
        dequant_int16_to_bf16(from, to, cfg->subv_elems, *dq_zero_point, *dq_scale, signed_dtype, *dq_on, wide_dtype);
    };
    auto quantize = [&](int8_t* from, int8_t* to) {
        quant_bf16_to_int16(from, to, cfg->subv_elems, *q_zero_point, *q_scale, signed_dtype, *q_on, wide_dtype);
    };

    const bool dq_requested = (*dq_on != 0);
    const bool q_requested  = (*q_on != 0);

    if (dq_requested && q_requested) {
        // Mode 2 (dq + q): the input buffer is reused to hold the dequantized
        // data. Per the original note this relies on the input buffer being
        // at least twice the size of the output buffer.
        run_int8_pooling(src, dst, pool_cfg);
        dequantize(dst, src);
        quantize(src, dst);
        return;
    }

    if (dq_requested) {
        // Mode 0 (dq only): pool into scratch, dequantize into the output.
        // The trailing quantize call runs with its enable flag cleared.
        run_int8_pooling(src, scratch, pool_cfg);
        dequantize(scratch, dst);
        quantize(dst, dst);
        return;
    }

    if (q_requested) {
        // Mode 1 (q only): quantize the input in place, then pool.
        // The leading dequantize call runs with its enable flag cleared.
        dequantize(src, src);
        quantize(src, src);
        run_int8_pooling(src, dst, pool_cfg);
        return;
    }

    // Mode 3 (neither requested): pool, then issue both calls with their
    // enable flags cleared.
    run_int8_pooling(src, dst, pool_cfg);
    dequantize(dst, dst);
    quantize(dst, dst);
}

#endif
