#ifndef __SOFTMAX_FP16x16_WRAPPER_CC__
#define __SOFTMAX_FP16x16_WRAPPER_CC__

#include "q/q.hpp"
#include "q/q_impl.hpp"
#include "dq/dq.hpp"
#include "dq/dq_impl.hpp"
#include "softmax_fp16x16.hpp"
#include "softmax_fp16x16_impl.hpp"
#include "uniop/uniop_qdq.hpp"

/**
 * Clear the mask bits of every column at or beyond `trim_cols` in a packed
 * 1-bit-per-element mask buffer (generateBinaryMatrix).
 *
 * Layout, as implied by the indexing below: columns are grouped into tiles of
 * 64; each tile stores `rows` rows of 8 bytes, so the byte holding
 * (row i, column j) is at  rows*8*(j/64) + (j%64)/8 + i*8, and column j maps
 * to bit (j%8), least-significant bit first.
 *
 * Precondition: every bit of the buffer must already be 1 (the caller in this
 * file fills it with 0xFF first). Only the bytes that contain columns
 * >= trim_cols are written: bytes with no valid column become 0x00, and the
 * single partially-valid byte (when trim_cols is not a multiple of 8) gets
 * its low (trim_cols % 8) bits set.
 *
 * @param buffer    packed mask buffer, pre-filled with 1s
 * @param rows      number of rows stored per 64-column tile (row stride)
 * @param cols      number of columns in the buffer
 * @param trim_rows number of leading rows to process (clamped to rows);
 *                  rows >= trim_rows are left untouched
 * @param trim_cols number of valid columns (clamped to [0, cols])
 */
void masking //generateBinaryMatrix
(
	uint8_t* buffer, // Must be pre-filled with 1s (0xFF) by the caller
	int rows, 
	int cols, 
	int trim_rows, 
	int trim_cols
) 
{
    // Clamp the requested valid region to the buffer so out-of-range
    // arguments can never index past the end of `buffer`.
    if (trim_rows > rows) trim_rows = rows;
    if (trim_cols < 0)    trim_cols = 0;
    if (trim_cols > cols) trim_cols = cols;

    const int rows_by_8      = rows << 3;          // bytes per 64-column tile (rows * 8)
    const int bytesPerRow    = (cols + 7) >> 3;    // column bytes per row, over all tiles
    const int remainingBits  = trim_cols & 0x7;    // valid bits in the partial byte
    const int partialByte    = trim_cols >> 3;     // byte holding column trim_cols
    const int firstClearByte = (trim_cols + 7) >> 3; // first byte holding no valid column
    // Low 'remainingBits' bits set: columns below trim_cols inside the partial byte stay valid.
    const uint8_t mask = static_cast<uint8_t>((1u << remainingBits) - 1u);

    for (int i = 0; i < trim_rows; ++i)
    {
        const int i_by_8 = i << 3;   // i * 8: row offset inside a tile
        // For row-byte b: b>>3 is the 64-column tile, b&7 the byte within that tile's row.
        for (int b = firstClearByte; b < bytesPerRow; ++b)
            buffer[rows_by_8 * (b >> 3) + (b & 0x7) + i_by_8] = 0x00;

        if (remainingBits > 0)
            buffer[rows_by_8 * (partialByte >> 3) + (partialByte & 0x7) + i_by_8] = mask;
    }
}

void softmax(int8_t* matA, int8_t* mask, int8_t* output, const KernelSoftmax_fp16x16Param& kernel_param)
{
    softmax_fp16x16<QDQFloatType, QDQFloatType>
    (
        (QDQFloatType*)matA,
        (int*)mask,
        (QDQFloatType*)output,
        kernel_param
    );
}

// Set to 1 to build for the 1x1 unit test. That build skips mask generation
// and the dequant/quant stages, so only the softmax runs.
// When set to 1, the test flag in build_uniop.py must be changed as well.
// Guarded so the build can also override it (e.g. -DUNIT_TEST_1x1=1).
#ifndef UNIT_TEST_1x1
#define UNIT_TEST_1x1 0
#endif

/**
 * Layer entry point: optional dequantize -> masked softmax -> optional quantize.
 *
 * All buffer addresses and sizes come from the softmax_layer_param pointed to
 * by args.params_data. Dequant/quant enables are read from the
 * uniop_qdq_param at qdq_param_addr. The non-linear (softmax) stage is
 * currently always enabled.
 *
 * Data flow with nlf_en == true:
 *   dequant : matA   -> matA    (written back in place)
 *   softmax : matA   -> output  (using the generated mask in matM)
 *   quant   : output -> output  (written back in place)
 * NOTE(review): the in-place dequant assumes an input element is at least as
 * wide as QDQFloatType -- TODO confirm against the layer configuration.
 */
void run_softmax_fp16x16(KernelArgs& args) 
{
    softmax_layer_param* layer_params = (softmax_layer_param*)args.params_data;
    const KernelSoftmax_fp16x16Param& kernel_params = layer_params->krn_param; 

    int8_t* matA   = reinterpret_cast<int8_t*>(layer_params->input_addr); 
    int8_t* matM   = reinterpret_cast<int8_t*>(layer_params->mask_addr); 
    int8_t* output = reinterpret_cast<int8_t*>(layer_params->output_addr);

    uniop_qdq_param* qdqprm = reinterpret_cast<uniop_qdq_param*>(layer_params->qdq_param_addr);
    
    bool dequant_en  = (qdqprm->dq_enable == 1);
    bool quant_en    = (qdqprm->q_enable == 1);
    bool nlf_en      = true; //(qdqprm->nlf_enable == 1);  softmax stage is always on for now

    int true_num_cols = layer_params->true_num_cols;

    // Value-initialise so any field not assigned below is zero, not indeterminate.
    KernelDqParam dq_krn_param{};
    KernelQParam q_krn_param{};
    dq_krn_param.inner_g = layer_params->num_elem_subv;
    dq_krn_param.sign_A = layer_params->sign_A;
    q_krn_param.inner_g = layer_params->num_elem_subv;
    q_krn_param.sign_O = layer_params->sign_O;
    
    v32accfloat *dq_buf = (v32accfloat*)(qdqprm->dq_buf);
    v32accfloat *q_buf = (v32accfloat*)(qdqprm->q_buf);

    // Stage buffers. With the softmax enabled, dequant writes back into matA
    // (which the softmax then reads), and quant rewrites the softmax output in place.
    int8_t *dQ_pIn = matA; 
    int8_t *dQ_pOut = (nlf_en)? matA : output;
    int8_t *Nlf_pIn = matA; 
    int8_t *Nlf_pOut = output;
    int8_t *Q_pIn = (!dequant_en and !nlf_en)? matA : output; 
    int8_t *Q_pOut = output;

#if DEBUG_SOFTMAX
    int rowIdx_reg = (get_coreid() & 0xF);
    int colIdx_reg = (get_coreid() >> 16);
    if(colIdx_reg == 0 && rowIdx_reg == 2)
    {
        chess_report(dq_buf[0]);
        chess_report(dq_buf[2]);
        chess_report(dequant_en);
        chess_report(q_buf[0]);
        chess_report(q_buf[2]);   // was q_buf[0] twice; mirror the dq_buf reports above
        chess_report(quant_en);
    }
#endif

#if !UNIT_TEST_1x1
    // Mark every mask bit valid (0xFF written through an unsigned view, avoiding
    // an out-of-range int8_t conversion), then clear the columns beyond true_num_cols.
    uint8_t* mask_bytes = reinterpret_cast<uint8_t*>(matM);
    for(int i = 0; i < layer_params->msk_num_bytes; i++)
        mask_bytes[i] = 0xFF;
    masking(mask_bytes, layer_params->Msubv, layer_params->Nsubv, layer_params->Msubv, true_num_cols);

    dq_float16_v32((int8_t*) dQ_pIn, (float*) dq_buf, (QDQFloatType*) dQ_pOut, dq_krn_param, dequant_en);
#endif

    softmax(Nlf_pIn, matM, Nlf_pOut, kernel_params);

#if !UNIT_TEST_1x1
    q_float16_to_int16_v32((QDQFloatType*) Q_pIn, (float*) q_buf, (int16*) Q_pOut, q_krn_param, quant_en);
#endif

}

#endif
