#ifndef __PWL_NLF_KERNEL_C__
#define __PWL_NLF_KERNEL_C__
//#include "pwl_nlf_template.h"
#include "pwl_nlf_bf16_template.h"
#include "qdq/qdq_int8_bfloat16.hpp"
#include "qdq/qdq_int16_bfloat16.hpp"
#include "kernel_helpers.h"

#define DEBUG 0

#define IFM_SIGN_MASK 0x1
#define OFM_SIGN_MASK 0x2

#if DEBUG
// Reinterprets the 32-bit pattern of `i` as a float (no numeric conversion).
// The copy goes through char pointers, which may alias any object type.
inline float uint_to_float(uint32_t i)
{
    float result = 0;
    const char* src_bytes = reinterpret_cast<const char*>(&i);
    char* dst_bytes = reinterpret_cast<char*>(&result);
    for (int byte_idx = 0; byte_idx < 4; ++byte_idx)
        dst_bytes[byte_idx] = src_bytes[byte_idx];
    return result;
}

//only for debugging purpose
// Widens a bfloat16 to float: the 16-bit value is shifted into the upper half
// of a 32-bit word, and that word is reinterpreted as a float by uint_to_float.
// NOTE(review): this relies on uint32_t(bf) yielding bf's raw 16-bit pattern.
// If bfloat16's conversion to an integer goes through its numeric value instead
// (e.g. via an implicit float conversion), the result is not the widened float.
// Confirm against the bfloat16 type definition.
inline float bfloat16_to_float(bfloat16 bf)
{
    return uint_to_float(uint32_t(bf) << 16);
}
#endif

#if DEBUG
// Debug helper: passes consecutive v16bfloat16 vectors starting at buf to
// chess_report. At most 64 vectors are reported (the historical dump size),
// and never more than fit completely inside buf_bytes bytes.
static void debug_report_vectors(void* buf, int buf_bytes)
{
    v16bfloat16* vec_ptr = (v16bfloat16*) buf;
    int num_vecs = buf_bytes / (int) sizeof(v16bfloat16);
    if (num_vecs > 64)
        num_vecs = 64;
    for (int idx = 0; idx < num_vecs; idx++)
        chess_report(*(vec_ptr + idx));
}
#endif

/*
 * SiLU/GELU kernel with optional input dequantization and output quantization.
 *
 * Pipeline (every stage is given num_elements):
 *   dq_in --dequant_int16_to_bf16--> tdm1 --pwla_nlf_bf16--> tdm2 --quant_bf16_to_int16--> output
 *
 * args.params_data   : LayerParam block (below); the *_addr fields are turned into
 *                      pointers with conv_to_local_ptr.
 * args.s2mm_ch0_data : kernel input when fused_op != 1.
 * args.mm2s_ch0_data : kernel output; also the input when fused_op == 1, because the
 *                      fused producer (GEMM) has already written its result there.
 *
 * The QDQ block at core_addr holds six 16-bit words at a 4-byte stride:
 *   +0 dq zero point, +4 dq scale (bf16), +8 q zero point, +12 q scale (bf16),
 *   +16 dq enable,    +20 q enable.
 * When an enable word is zero, the corresponding scratch buffer is not used: the
 * stage is called with enable=false and its source and destination aliased
 * (tdm1 = dq_in for dequant, tdm2 = output for quant).
 *
 * Saturation (set_sat) and rnd_conv_even rounding are configured before any stage runs.
 */
void run_a16w8_silu_gelu_qdq(KernelArgs& args)
{
    // Layout of the 16-bit parameter words in args.params_data.
    struct LayerParam
    {
        uint16_t core_addr;     // local address of the QDQ parameter block
        uint16_t lutab_addr;    // local address of the NLF "ab" LUT
        uint16_t lutcd_addr;    // local address of the NLF "cd" LUT
        uint16_t num_elements;  // element count passed to every stage
        uint16_t tdm1_addr;     // scratch: dequant output / NLF input
        uint16_t tdm2_addr;     // scratch: NLF output / quant input
        uint16_t fused_op;      // 1 => input already lives in the mm2s0 (output) buffer
        uint16_t is_in_int16;   // forwarded to dequant as its int16-input flag
        uint16_t is_out_int16;  // forwarded to quant as its int16-output flag
        uint16_t sign_mask;     // IFM_SIGN_MASK / OFM_SIGN_MASK bits
    };

    set_sat();
    set_rnd(rnd_conv_even);

    uint16_t* args_params = (uint16_t*)args.params_data;
    const LayerParam* params = static_cast<const LayerParam*>(static_cast<const void*>(args_params));

    int8_t* output = static_cast<int8_t*>(args.mm2s_ch0_data);
    // NOTE: If it is a fused OP, the output of gemm is written to mm2s0 ptr
    int8_t* dq_in = (params->fused_op == 1) ? output : static_cast<int8_t*>(args.s2mm_ch0_data);

    const bool is_in_int16  = params->is_in_int16 != 0;
    const bool is_out_int16 = params->is_out_int16 != 0;
    const int  sign_mask    = params->sign_mask;

    int* lut_ab = static_cast<int*>(conv_to_local_ptr(params->lutab_addr));
    int* lut_cd = static_cast<int*>(conv_to_local_ptr(params->lutcd_addr));

    void* qdq_lut = static_cast<void*>(conv_to_local_ptr(params->core_addr));
    uint16_t* dq_zero_point = static_cast<uint16_t *>(qdq_lut);
    bfloat16* dq_scale      = static_cast<bfloat16 *>(byte_incr(qdq_lut,4));
    uint16_t* q_zero_point  = static_cast<uint16_t *>(byte_incr(qdq_lut,8));
    bfloat16* q_scale       = static_cast<bfloat16 *>(byte_incr(qdq_lut,12));
    uint16_t* dq_enable     = static_cast<uint16_t *>(byte_incr(qdq_lut,16));
    uint16_t* q_enable      = static_cast<uint16_t *>(byte_incr(qdq_lut,20));

    // Evaluate each enable word once, with the same "nonzero means enabled" rule
    // the helpers receive. Otherwise a value other than 0/1 would run an enabled
    // stage with aliased src/dst (e.g. dequantizing in place into dq_in).
    const bool dq_enabled = (*dq_enable != 0);
    const bool q_enabled  = (*q_enable != 0);

    int8_t* tdm1 = dq_enabled ? static_cast<int8_t*>(conv_to_local_ptr(params->tdm1_addr)) : dq_in;
    int8_t* tdm2 = q_enabled  ? static_cast<int8_t*>(conv_to_local_ptr(params->tdm2_addr)) : output;

    const bool is_input_signed  = sign_mask & IFM_SIGN_MASK;
    const bool is_output_signed = sign_mask & OFM_SIGN_MASK;

    dequant_int16_to_bf16(dq_in, tdm1, params->num_elements, *dq_zero_point, *dq_scale, is_input_signed, dq_enabled, is_in_int16);
    pwla_nlf_bf16(tdm1, tdm2, lut_ab, lut_cd, params->num_elements);

#if DEBUG
    // Only one core (column 0, row 2) emits debug output.
    const int col_id = get_coreid() >> 16;
    const int row_id = get_coreid() & 0xf;
    const bool is_debug_core = (col_id == 0 && row_id == 2);

    // Byte sizes used to bound the dumps.
    // NOTE(review): assumes int8 data takes 1 byte per element, int16/bf16 data
    // takes 2, and a disabled stage leaves 2-byte data in place. Confirm against
    // the qdq helpers.
    const int num_elements = params->num_elements;
    const int in_bytes     = num_elements * ((dq_enabled && !is_in_int16) ? 1 : 2);
    const int bf16_bytes   = num_elements * 2;
    const int out_bytes    = num_elements * ((q_enabled && !is_out_int16) ? 1 : 2);

    if (is_debug_core) {
        printf("core addr: %u\n", params->core_addr);
        printf("tdm1 addr: %u\n", params->tdm1_addr);
        printf("tdm2 addr: %u\n", params->tdm2_addr);
        printf("LUTAB_Addr: %u\n", params->lutab_addr);
        printf("LUTCD_Addr: %u\n", params->lutcd_addr);
        printf("num_elements: %u\n", params->num_elements);
        printf("q_zero_point: %u\n", *q_zero_point);
        printf("q_scale: %f\n", (float)*q_scale);
        printf("dq_zero_point: %u\n", *dq_zero_point);
        printf("dq_scale: %f\n", (float)*dq_scale);
        printf("fused_gelu: %d \n", params->fused_op);
        printf("dq_in addr: %p\n", (void*) dq_in);
        printf("dq_enable: %u\n", *dq_enable);
        printf("q_enable: %u\n", *q_enable);

        v16bfloat16* qdq_ptr = (v16bfloat16*) qdq_lut;
        chess_report(*(qdq_ptr));

        debug_report_vectors(dq_in, in_bytes);
        debug_report_vectors(tdm1, bf16_bytes);
        debug_report_vectors(tdm2, bf16_bytes);
    }
#endif

    quant_bf16_to_int16(tdm2, output, params->num_elements, *q_zero_point, *q_scale, is_output_signed, q_enabled, is_out_int16);

#if DEBUG
    if (is_debug_core)
        debug_report_vectors(output, out_bytes);
#endif
}
#endif
