#ifndef __GEMM_QDQ_A16W2_WRAPPER_CC__
#define __GEMM_QDQ_A16W2_WRAPPER_CC__
#include "mmult_qdq_blocked_int16x2_impl.hpp"

/**
 * Kernel entry point: resolves every buffer described by the
 * GemmInt16x2Blocked parameter block and dispatches a single call to
 * mmult_qdq_blocked_int16x2<1, true, 2>.
 *
 * Stream selection:
 *   - mode == 0 : s2mm channel 0 carries the int16 input, channel 1 the
 *                 weight stream.
 *   - otherwise : the two channels are swapped.
 *   mm2s channel 0 is passed to the kernel as the int16 output buffer.
 *
 * Weight-stream layout (offsets applied with byte_incr from the start of the
 * weight stream, so presumably byte counts):
 *   [ weights : wgt_size | coeffs : bias_size | zp : zp_size | sw ... ]
 *
 * Scratch buffers (unpacked weights, tdm1/tdm2/tdm1s/tdm2s) and the QDQ
 * parameter struct are stored as addresses in the parameter block and are
 * translated with conv_to_local_ptr before use.
 *
 * @param args  kernel argument bundle; args.params_data is treated as a
 *              pointer to a GemmInt16x2Blocked.
 */
void run_gemm_int16x2(KernelArgs& args)
{
    GemmInt16x2Blocked* cfg = static_cast<GemmInt16x2Blocked*>(args.params_data);

    // Control flags, narrowed to uint8_t and forwarded unchanged to the kernel.
    uint8_t zero_init_flag = cfg->zero_init;
    uint8_t final_tdm_flag = cfg->final_tdm_iter;

    // Decide once which inbound channel holds activations and which holds weights.
    const bool act_on_ch0 = (cfg->mode == 0);
    int16_t* act = static_cast<int16_t*>(act_on_ch0 ? args.s2mm_ch0_data : args.s2mm_ch1_data);
    int8_t* wts = static_cast<int8_t*>(act_on_ch0 ? args.s2mm_ch1_data : args.s2mm_ch0_data);
    int16_t* out = static_cast<int16_t*>(args.mm2s_ch0_data);

    // Section offsets inside the weight stream; each section follows the previous one.
    const auto zp_off = cfg->wgt_size + cfg->bias_size;
    const auto sw_off = cfg->wgt_size + cfg->bias_size + cfg->zp_size;
    int64_t* coeff_ptr = reinterpret_cast<int64_t*>(byte_incr(wts, cfg->wgt_size));
    int8_t* zp_ptr = reinterpret_cast<int8_t*>(byte_incr(wts, zp_off));
    int32_t* sw_ptr = reinterpret_cast<int32_t*>(byte_incr(wts, sw_off));

    // Scratch areas and QDQ parameters are referenced by address in the
    // parameter block; translate each one to a locally usable pointer.
    int8_t* unpack_buf = static_cast<int8_t*>(conv_to_local_ptr(cfg->weights_unpack_addr));
    int32_t* tdm1_buf = static_cast<int32_t*>(conv_to_local_ptr(cfg->tdm1_addr));
    int32_t* tdm2_buf = static_cast<int32_t*>(conv_to_local_ptr(cfg->tdm2_addr));
    int32_t* tdm1s_buf = static_cast<int32_t*>(conv_to_local_ptr(cfg->tdm1s_addr));
    int32_t* tdm2s_buf = static_cast<int32_t*>(conv_to_local_ptr(cfg->tdm2s_addr));
    GemmInt16x2_QDQ_Params* qdq = static_cast<GemmInt16x2_QDQ_Params*>(conv_to_local_ptr(cfg->qdq_addr));

    // NOTE(review): the template arguments <1, true, 2> are fixed here; their
    // meaning is defined by mmult_qdq_blocked_int16x2 (not visible in this file).
    mmult_qdq_blocked_int16x2<1, true, 2>(act, wts, unpack_buf,
                                          tdm1_buf, tdm2_buf, tdm1s_buf, tdm2s_buf,
                                          coeff_ptr, zp_ptr, sw_ptr,
                                          out,
                                          zero_init_flag, final_tdm_flag,
                                          *cfg, *qdq);
}
#endif // __GEMM_QDQ_A16W2_WRAPPER_CC__