#ifndef __SILU_EXP2_WRAPPER_CC__
#define __SILU_EXP2_WRAPPER_CC__

#include "q/q.hpp"
#include "q/q_impl.hpp"
#include "dq/dq.hpp"
#include "dq/dq_impl.hpp"
#include "SiLU_exp2.hpp"
#include "SiLU_exp2_impl.hpp"
#include "uniop/uniop_qdq.hpp"

#define DEBUG_SILU 0

void SiLU(int8_t* input, int8_t* output, int8_t* spill_buf_a, int8_t* spill_buf_b, const KernelSiLUExp2Param& kernel_param)
{
    SiLU_exp2<2, 12, bfloat16, bfloat16> //<unsigned poly_order=2, unsigned loop_range=12, typename Ti = bfloat16, typename To = bfloat16>
    (
        (bfloat16*)input,
        (bfloat16*)output,
        (float*)spill_buf_a,
        (float*)spill_buf_b,
        kernel_param
    );
}

/**
 * Layer entry point: runs the dequantize, SiLU and quantize stages in order.
 *
 * All buffer addresses and settings are read from the sigelu_layer_param
 * structure found at args.params_data. The dequantize and quantize stages are
 * each called unconditionally and receive an enable flag taken from the
 * uniop_qdq_param block (dq_enable / q_enable); the SiLU stage always runs.
 *
 * Buffer routing (with nlf_en fixed to true):
 *   - dequantize : input buffer  -> input buffer (in place)
 *   - SiLU       : input buffer  -> output buffer
 *   - quantize   : output buffer -> output buffer (in place)
 *
 * @param args  Kernel arguments; only args.params_data is used here.
 */
void run_silu(KernelArgs& args)
{
    sigelu_layer_param* layer_params = (sigelu_layer_param*)args.params_data;

    int8_t* matA      = reinterpret_cast<int8_t*>(layer_params->input_addr);
    int8_t* output    = reinterpret_cast<int8_t*>(layer_params->output_addr);
    int8_t* matSpillA = reinterpret_cast<int8_t*>(layer_params->spill_a_addr);
    int8_t* matSpillB = reinterpret_cast<int8_t*>(layer_params->spill_b_addr);

    // The SiLU kernel parameters are read in place, starting at the address of
    // the num_iters field of the layer parameters.
    // NOTE(review): assumes the fields from num_iters onward are laid out
    // exactly like KernelSiLUExp2Param - confirm against the struct definitions.
    KernelSiLUExp2Param* pkernel_param = (KernelSiLUExp2Param*)&(layer_params->num_iters);

#if DEBUG_SILU
    // Debug dump of the raw layer parameters, restricted to one core.
    if((get_coreid() & 0xF)==2 && (get_coreid() >> 16)==0)
    {
        chess_report(layer_params->input_addr);
        chess_report(layer_params->output_addr);
        chess_report(layer_params->spill_a_addr);
        chess_report(layer_params->spill_b_addr);
        chess_report(layer_params->qdq_param_addr);
        chess_report(layer_params->qdq_buffer_addr);
        chess_report(pkernel_param->outer_g);
    }
#endif
    uniop_qdq_param* qdqprm = reinterpret_cast<uniop_qdq_param*>(layer_params->qdq_param_addr);

    bool dequant_en  = (qdqprm->dq_enable == 1);
    bool quant_en    = (qdqprm->q_enable == 1);
    bool nlf_en      = true; // the non-linear (SiLU) stage is always executed

    KernelDqParam dq_krn_param;
    KernelQParam q_krn_param;

    // Both the dequantize and quantize stages take their inner_g from the
    // SiLU kernel's outer_g.
    dq_krn_param.inner_g = pkernel_param->outer_g;
    dq_krn_param.sign_A = layer_params->sign_A;
    q_krn_param.inner_g = pkernel_param->outer_g;
    q_krn_param.sign_O = layer_params->sign_O;

    v32accfloat *dq_buf = (v32accfloat*)(qdqprm->dq_buf);
    v32accfloat *q_buf = (v32accfloat*)(qdqprm->q_buf);

    // Stage input/output routing. Because nlf_en is true, dequantize writes
    // back into the input buffer and quantize reads from the output buffer.
    int8_t *dQ_pIn = matA;
    int8_t *dQ_pOut = (nlf_en)? matA : output;
    int8_t *Nlf_pIn = matA;
    int8_t *Nlf_pOut = output;
    int8_t *Q_pIn = ((!dequant_en) && (!nlf_en))? matA : output;
    int8_t *Q_pOut = output;

#if DEBUG_SILU
    // Debug dump of the quantization buffers and enable flags, restricted to
    // the same core as the dump above.
    int rowIdx_reg = (get_coreid() & 0xF);
    int colIdx_reg = (get_coreid() >> 16);
    if(colIdx_reg == 0 && rowIdx_reg == 2)
    {
        chess_report(dq_buf[0]);
        chess_report(dq_buf[2]);
        chess_report(dequant_en);
        chess_report(q_buf[0]);
        chess_report(q_buf[2]);
        chess_report(quant_en);
    }
#endif

    // Stage 1: dequantize (enable flag forwarded to the stage).
    dq_bfloat16_v32((int8_t*) dQ_pIn, (float*) dq_buf, (bfloat16*) dQ_pOut, dq_krn_param, dequant_en);

    // Stage 2: SiLU, using the two spill buffers as scratch space.
    SiLU(Nlf_pIn, Nlf_pOut, matSpillA, matSpillB, *pkernel_param);

    // Stage 3: quantize (enable flag forwarded to the stage).
    q_bfloat16_to_int16_v32((bfloat16*) Q_pIn, (float*) q_buf, (int16*) Q_pOut, q_krn_param, quant_en);

}
#endif