#ifndef __L2NORM_FP16x16_WRAPPER_CC__
#define __L2NORM_FP16x16_WRAPPER_CC__

#include "q/q.hpp"
#include "q/q_impl.hpp"
#include "dq/dq.hpp"
#include "dq/dq_impl.hpp"
#include "l2norm_fp16x16.hpp"
#include "l2norm_fp16x16_impl.hpp"
#include "uniop/uniop_qdq.hpp"

#define DEBUG_L2NORM 0

void l2norm(int8_t* input, int8_t* weight, int8_t* output, const KernelL2Norm_fp16x16_Param& kernel_param)
             //uint16_t Nlayer, uint16_t Msubv, uint16_t Nsubv, bool multi_core_sm, 
             //uint16_t colIdx, uint16_t rowIdx, float softmax_scalefactor)
{
    l2norm_fp16x16<QDQFloatType, QDQFloatType>
    (
        (QDQFloatType*)input,
        (QDQFloatType*)weight,
        (QDQFloatType*)output,
        kernel_param
    );
}

/**
 * Layer entry point: runs the dq stage, the L2-norm kernel and the q stage,
 * in that order, on the buffers described by the layer parameters.
 *
 * Buffer routing (with the norm stage always enabled, see nlf_en below):
 *   dq stage : reads matA,   writes matA   (in place)
 *   l2norm   : reads matA,   writes output
 *   q stage  : reads output, writes output (in place)
 *
 * The dq and q stages are always called; the dq_enable / q_enable flags read
 * from the qdq parameter block are passed to them as their last argument.
 * NOTE(review): presumably each stage is a pass-through when its flag is
 * false -- confirm in dq_impl.hpp / q_impl.hpp.
 *
 * @param args Kernel arguments; args.params_data is interpreted as an
 *             l2norm_layer_param describing buffer addresses and kernel
 *             configuration.
 */
void run_l2norm_fp16x16(KernelArgs& args)
{
    l2norm_layer_param* layer_params = (l2norm_layer_param*)args.params_data;
    const KernelL2Norm_fp16x16_Param& kernel_params = layer_params->krn_param;

    int8_t* matA   = reinterpret_cast<int8_t*>(layer_params->input_addr);
    int8_t* matW   = reinterpret_cast<int8_t*>(layer_params->weight_addr);
    int8_t* output = reinterpret_cast<int8_t*>(layer_params->output_addr);

#if DEBUG_L2NORM
    // Restrict debug output to a single core: row index (low 4 bits of the
    // core id) == 2 and column index (core id >> 16) == 0.
    const int rowIdx_reg = (get_coreid() & 0xF);
    const int colIdx_reg = (get_coreid() >> 16);
    const bool report_core = (colIdx_reg == 0 && rowIdx_reg == 2);

    if (report_core)
    {
        chess_report(layer_params->input_addr);
        chess_report(layer_params->weight_addr);
        chess_report(layer_params->output_addr);
        chess_report(layer_params->qdq_param_addr);
        chess_report(layer_params->qdq_buffer_addr);

        chess_report(kernel_params.order_64);
        chess_report(kernel_params.inner_g);
        chess_report(kernel_params.X_g);
        chess_report(kernel_params.outer_g);

        chess_report(kernel_params.dimsI_ol.num0);
        chess_report(kernel_params.dimsI_ol.inc0);
        chess_report(kernel_params.dimsI_ol.inc1);

        chess_report(kernel_params.dimsO_ol.num0);
        chess_report(kernel_params.dimsO_ol.inc0);
        chess_report(kernel_params.dimsO_ol.inc1);

        chess_report(kernel_params.dimsI_il.num0);
        chess_report(kernel_params.dimsI_il.inc0);
        chess_report(kernel_params.dimsI_il.inc1);

        chess_report(kernel_params.dimsO_il.num0);
        chess_report(kernel_params.dimsO_il.inc0);
        chess_report(kernel_params.dimsO_il.inc1);
    }
#endif

    uniop_qdq_param* qdqprm = reinterpret_cast<uniop_qdq_param*>(layer_params->qdq_param_addr);

    const bool dequant_en = (qdqprm->dq_enable == 1);
    const bool quant_en   = (qdqprm->q_enable == 1);
    // The norm stage is unconditionally enabled; the nlf_enable field of the
    // qdq parameter block is currently not consulted.
    const bool nlf_en     = true; //(qdqprm->nlf_enable == 1);

    // Both the dq and q stages process num_elem_subv elements per call.
    KernelDqParam dq_krn_param;
    dq_krn_param.inner_g = layer_params->num_elem_subv;
    dq_krn_param.sign_A  = layer_params->sign_A;

    KernelQParam q_krn_param;
    q_krn_param.inner_g = layer_params->num_elem_subv;
    q_krn_param.sign_O  = layer_params->sign_O;

    v32accfloat* dq_buf = (v32accfloat*)(qdqprm->dq_buf);
    v32accfloat* q_buf  = (v32accfloat*)(qdqprm->q_buf);

    // Stage routing. When the norm stage follows, the dq stage works in place
    // on the input so the norm can read it from there; the q stage reads the
    // norm result from the output buffer and overwrites it in place.
    int8_t* dQ_pIn   = matA;
    int8_t* dQ_pOut  = nlf_en ? matA : output;
    int8_t* Nlf_pIn  = matA;
    int8_t* Nlf_pOut = output;
    int8_t* Q_pIn    = (!dequant_en && !nlf_en) ? matA : output;
    int8_t* Q_pOut   = output;

#if DEBUG_L2NORM
    if (report_core)
    {
        chess_report(dq_buf[0]);
        chess_report(dq_buf[2]);
        chess_report(dequant_en);
        chess_report(q_buf[0]);
        chess_report(q_buf[2]);
        chess_report(quant_en);
    }
#endif

    dq_float16_v32(dQ_pIn,
                   reinterpret_cast<float*>(dq_buf),
                   reinterpret_cast<QDQFloatType*>(dQ_pOut),
                   dq_krn_param,
                   dequant_en);

    l2norm(Nlf_pIn, matW, Nlf_pOut, kernel_params);

    q_float16_to_int16_v32(reinterpret_cast<QDQFloatType*>(Q_pIn),
                           reinterpret_cast<float*>(q_buf),
                           reinterpret_cast<int16*>(Q_pOut),
                           q_krn_param,
                           quant_en);
}
#endif