#ifndef __LINEAR_APPROX_BF16_WRAPPER_CC__
#define __LINEAR_APPROX_BF16_WRAPPER_CC__

#include "q/q.hpp"
#include "q/q_impl.hpp"
#include "dq/dq.hpp"
#include "dq/dq_impl.hpp"
#include "linear_approx_bf16.hpp"
#include "linear_approx_bf16_impl.hpp"
#include "uniop/uniop_qdq.hpp"

#define DEBUG_LINEAR_APPROX 0

void lut_lookup(int8_t* ifm, int8_t* lut_ab, int8_t* lut_cd, int8_t* ofm, int8_t* spill_buf, 
        const LinearApproxBF16Params<QDQFloatType>& kernel_params)  
{
    linear_approx_bf16<QDQFloatType>
    (
        (QDQFloatType*) ifm,
        (float*) lut_ab,
        (float*) lut_cd,
        (QDQFloatType*) ofm, 
        (int*) spill_buf,
        kernel_params
    );
}

/**
 * @brief Kernel entry point: dequantize -> LUT linear approximation -> quantize.
 *
 * All buffer addresses and kernel parameters are taken from the
 * linear_approx_layer_param<QDQFloatType> block at args.params_data; the
 * dq/q enable flags come from the uniop_qdq_param block at qdq_param_addr.
 *
 * Buffer routing (chosen so each stage consumes the previous stage's result):
 *   - dequant : matA -> matA (same buffer) when the NLF stage runs,
 *               otherwise matA -> output.
 *   - NLF/LUT : matA -> output, only when nlf_en.
 *   - quant   : output -> output, except when both dequant and NLF are
 *               disabled, in which case it reads matA directly.
 * The dq/q helpers receive their enable flag and decide internally what to
 * do when disabled (NOTE(review): presumably a pass-through — confirm).
 *
 * @param args kernel arguments; args.params_data must point to a
 *             linear_approx_layer_param<QDQFloatType>.
 */
void run_lut_fp16x16(KernelArgs& args) 
{
    linear_approx_layer_param<QDQFloatType>* layer_params = (linear_approx_layer_param<QDQFloatType>*)args.params_data;
    const LinearApproxBF16Params<QDQFloatType>& kernel_params = layer_params->krn_param; 

    int8_t* matA       = reinterpret_cast<int8_t*>(layer_params->input_addr); 
    int8_t* mat_LUT_ab = reinterpret_cast<int8_t*>(layer_params->lut_ab_addr);
    int8_t* mat_LUT_cd = reinterpret_cast<int8_t*>(layer_params->lut_cd_addr); 
    int8_t* mat_Spill  = reinterpret_cast<int8_t*>(layer_params->spill_buf_addr); 
    int8_t* output     = reinterpret_cast<int8_t*>(layer_params->output_addr);

#if DEBUG_LINEAR_APPROX
    // Report buffer addresses from a single core (row 2, column 0) only.
    if((get_coreid() & 0xF)==2 && (get_coreid() >> 16)==0)
    {
        chess_report(layer_params->input_addr);
        chess_report(layer_params->lut_ab_addr);
        chess_report(layer_params->lut_cd_addr);
        // Same field used for mat_Spill above (was "spill_addr").
        chess_report(layer_params->spill_buf_addr);
        chess_report(layer_params->output_addr);
        chess_report(layer_params->qdq_param_addr);
    }
#endif
    uniop_qdq_param* qdqprm = reinterpret_cast<uniop_qdq_param*>(layer_params->qdq_param_addr);
    
    bool dequant_en  = (qdqprm->dq_enable == 1);
    bool quant_en    = (qdqprm->q_enable == 1);
    // The NLF (LUT) stage is currently forced on; the qdq-param flag is not consulted.
    bool nlf_en      = true; //(qdqprm->nlf_enable == 1);

    KernelDqParam dq_krn_param;
    KernelQParam q_krn_param;

    dq_krn_param.inner_g = kernel_params.loop; 
    dq_krn_param.sign_A = layer_params->sign_A;
    
    q_krn_param.inner_g  = kernel_params.loop; 
    q_krn_param.sign_O = layer_params->sign_O;
    
    v32accfloat *dq_buf = (v32accfloat*)(qdqprm->dq_buf);
    v32accfloat *q_buf = (v32accfloat*)(qdqprm->q_buf);

    // Stage I/O routing: see the function header for the data flow.
    int8_t *dQ_pIn = matA; 
    int8_t *dQ_pOut = (nlf_en)? matA : output;
    int8_t *Nlf_pIn = matA; 
    int8_t *Nlf_pOut = output;
    int8_t *Q_pIn = (!dequant_en && !nlf_en)? matA : output; 
    int8_t *Q_pOut = output;

#if DEBUG_LINEAR_APPROX
    int rowIdx_reg = (get_coreid() & 0xF);
    int colIdx_reg = (get_coreid() >> 16);
    if(colIdx_reg == 0 && rowIdx_reg == 2)
    {
        chess_report(dq_buf[0]);
        chess_report(dq_buf[2]);
        chess_report(dequant_en);
        chess_report(q_buf[0]);
        chess_report(q_buf[2]);
        chess_report(quant_en);
    }
#endif

    dq_float16_v32((int8_t*) dQ_pIn, (float*) dq_buf, (QDQFloatType*) dQ_pOut, dq_krn_param, dequant_en);

    // Only run the LUT when enabled: the routing above sends the dequant
    // result straight to `output` when NLF is off, and an unconditional
    // lookup (reading matA) would overwrite it.
    if (nlf_en)
    {
        lut_lookup(Nlf_pIn, mat_LUT_ab, mat_LUT_cd, Nlf_pOut, mat_Spill, kernel_params);
    }

    q_float16_to_int16_v32((QDQFloatType*) Q_pIn, (float*) q_buf, (int16*) Q_pOut, q_krn_param, quant_en);
}
#endif