#ifndef DIRECT_CONV_INT16x8_GENERIC_GEMM_WRAPPER
#define DIRECT_CONV_INT16x8_GENERIC_GEMM_WRAPPER

#include "direct_conv_int16x8_generic_kernel.c"

void unpack_weights_int4_int8(int8_t* src, int8_t* restrict dst, int sign, int num_elems){
    auto in_ptr = (v64int4* ) src;
    auto out_ptr = (v64int8* ) dst;
    // NOTE: min no. of elems = 64
    //dims_2d_t dims( num_elems/64 - 1, 1, 1 - num_elems/64 );
    //for(int i = 0; i < std::max(8,(num_elems/64)); i++)
    for(int i = 0; i < (num_elems/64); i++)
    chess_prepare_for_pipelining
    chess_loop_range( 8, )
    {
        *out_ptr++ = v64int8(unpack(*in_ptr++, sign));
        //out_ptr = add_2d_ptr( out_ptr, dims );
    }
}

// Shared body of run_a16w8_gemm_tdm / run_a16w8_gemm_qdq. It:
//   1. decodes the layer parameters from args.params_data;
//   2. stages the QDQ parameters in a local static;
//   3. unpacks int4 weights to int8 when requested;
//   4. calls direct_conv_int16x8_generic.
//   args       : kernel buffers. s2mm_ch0_data / s2mm_ch1_data hold MatA / MatB (order
//                depends on LayerParams::mode), mm2s_ch0_data is the output, and
//                params_data is laid out as LayerParams below.
//   enable_qdq : forwarded unchanged to direct_conv_int16x8_generic.
// Side effects:
//   - sets the saturation and rounding modes;
//   - overwrites the function-local static gemm_a16w8_qdq_params;
//   - writes kernel_params.shift_res / shift_tdm back into params_data, because
//     kernel_params is a reference into that buffer.
void __attribute__((noinline)) run_a16w8_gemm_generic(KernelArgs& args, bool enable_qdq)
{
    // Enable saturation and convergent (round-half-to-even) rounding for the kernel.
    set_sat();
    set_rnd(rnd_conv_even);
    // View over args.params_data. Field order, sizes and alignas(4) padding must match
    // whatever produces params_data (presumably host-side packing — confirm).
    struct alignas(4) LayerParams
    {
        uint8_t zero_init;
        // 0: MatA from s2mm_ch0, MatB from s2mm_ch1. Any other value swaps them (see below).
        uint8_t mode;
        // NOTE wgt_size is no. of elements per subvol
        // NOTE(review): it is also passed unchanged to byte_incr() to locate c0 below.
        //               Confirm whether the producer writes elements or bytes here.
        uint16_t wgt_size;
        // Element count handed to unpack_weights_int4_int8 when int4_wgt == 1.
        uint16_t n_elems;
        // Buffer addresses; each *_addr field is converted with conv_to_local_ptr() below.
        uint16_t qdq_addr;
        uint16_t ifmsum_addr;
        uint16_t tdm1_addr;
        uint16_t tdm2_addr;
        KernelParams kernel_params;
        // 1 => matB holds packed int4 weights that are unpacked before the GEMM.
        uint16_t int4_wgt;
        // Scratch buffer for the unpacked int8 weights (only used when int4_wgt == 1).
        uint16_t weights_unpack_addr;
    };
    LayerParams* layer_params = static_cast<LayerParams*>(args.params_data);
    // Reference, not a copy: the shift fields assigned below are written into params_data.
    KernelParams& kernel_params = layer_params->kernel_params;
    int zero_init = layer_params->zero_init;
    void* qdq_param_addr = conv_to_local_ptr(layer_params->qdq_addr);
    // NOTE(review): an earlier note here said the weight element count is converted to
    // bytes to compute the bias offset. No such conversion is visible: wgt_size is passed
    // to byte_incr() unchanged when computing c0. For packed int4 weights the element
    // count and byte count differ, so confirm the units of wgt_size.
    bool int4_wgt = (layer_params->int4_wgt == 1);

    // Copy the QDQ parameters into a 64-byte-aligned static with one 64-byte vector copy.
    // NOTE(review): exactly one v64uint8 (64 bytes) is copied. If sizeof(QDQKernelParams)
    // exceeds 64, the remaining fields keep whatever value they had before — confirm.
    alignas(64) static QDQKernelParams gemm_a16w8_qdq_params;
    v64uint8 * pI = (v64uint8*)qdq_param_addr;
    v64uint8 * restrict pO = (v64uint8*)(&gemm_a16w8_qdq_params);
    pO[0] = pI[0];

    // Compiler memory fence. Presumably it keeps the restrict vector store above ordered
    // before the scalar field reads of gemm_a16w8_qdq_params below.
    chess_memory_fence();

    kernel_params.shift_res = gemm_a16w8_qdq_params.shift_Qout;
    kernel_params.shift_tdm = gemm_a16w8_qdq_params.shift_tdm;

    // NOTE: Mode0: layer_params->mode == 0 => MatA broadcast, MatB Unicast
    // NOTE: Mode1: layer_params->mode == 1 => MatA Unicast, MatB broadcast
    int16_t* matA = (layer_params->mode == 0) ? static_cast<int16_t*>(args.s2mm_ch0_data) : static_cast<int16_t*>(args.s2mm_ch1_data);
    int8_t* matB = (layer_params->mode == 0) ? static_cast<int8_t*>(args.s2mm_ch1_data) : static_cast<int8_t*>(args.s2mm_ch0_data);
    int16_t* output = static_cast<int16_t*>(args.mm2s_ch0_data);
    // c0 (the bias, per the earlier note) is located at byte_incr(matB, wgt_size) in the
    // received, still-packed matB buffer. It must be computed before matB is redirected
    // to the unpacked copy below.
    int64_t* c0 = reinterpret_cast<int64_t*>(byte_incr(matB, layer_params->wgt_size));
    // int4: unpack into the dedicated scratch buffer. int8: use matB in place.
    int8_t* weight_unpack = (int4_wgt) ? static_cast<int8_t*>(conv_to_local_ptr(layer_params->weights_unpack_addr)) : matB;
    int32_t * ifm_sum = static_cast<int32_t*>(conv_to_local_ptr(layer_params->ifmsum_addr));
    int32_t * tdm1 = static_cast<int32_t*>(conv_to_local_ptr(layer_params->tdm1_addr));
    int32_t * tdm2 = static_cast<int32_t*>(conv_to_local_ptr(layer_params->tdm2_addr));

    if(int4_wgt) {
        unpack_weights_int4_int8(matB, weight_unpack, kernel_params.ctrl.sign_W, layer_params->n_elems);
        // NOTE: Update matB pointer to the new unpacked weights
        matB = weight_unpack;
    }

    // At this point matB == weight_unpack in both the int4 and int8 paths.
    direct_conv_int16x8_generic(
        matA, matB, weight_unpack, tdm1, tdm2, ifm_sum, nullptr,
        c0, gemm_a16w8_qdq_params.c1, gemm_a16w8_qdq_params.c2,
        output, kernel_params, kernel_params.op_mode, zero_init, enable_qdq, 0, gemm_a16w8_qdq_params.Vec_coeffs);

}

// Kernel entry point: A16W8 GEMM with QDQ disabled.
// All of the work is delegated to run_a16w8_gemm_generic.
void run_a16w8_gemm_tdm(KernelArgs& args)
{
    constexpr bool kEnableQdq = false;
    run_a16w8_gemm_generic(args, kEnableQdq);
}

// Kernel entry point: A16W8 GEMM with QDQ enabled.
// All of the work is delegated to run_a16w8_gemm_generic.
void run_a16w8_gemm_qdq(KernelArgs& args)
{
    constexpr bool kEnableQdq = true;
    run_a16w8_gemm_generic(args, kEnableQdq);
}
#endif // DIRECT_CONV_INT16x8_GENERIC_GEMM_WRAPPER
