#ifndef DIRECT_CONV_INT16x8_GENERIC_WRAPPER
#define DIRECT_CONV_INT16x8_GENERIC_WRAPPER

#include "direct_conv_int16x8_generic_kernel.c"

/*
 * Per-layer parameter block that run_conv_a16w8_qdq() reads from
 * args.params_data via a static_cast.
 *
 * Because the struct is overlaid directly on a raw buffer, its byte layout is
 * its contract with whatever fills that buffer: do not reorder, resize or
 * remove fields. The fields before kernel_params occupy exactly 16 bytes.
 * NOTE(review): presumably the buffer is produced by host-side tooling with a
 * mirrored layout; confirm before changing anything here.
 */
struct alignas(4) LayerParams
{
    uint8_t zero_init;    // copied into kernel_params.ctrl.zero_init and passed to the kernel
    uint8_t run_qdq;      // passed to the kernel as its run_qdq argument
    uint16_t tdm1_addr;   // local address, mapped to the int32 tdm1 buffer via conv_to_local_ptr()
    uint16_t tdm2_addr;   // local address, mapped to the int32 tdm2 buffer via conv_to_local_ptr()
    uint16_t ifmsum_addr; // local address, mapped to the int32 ifm_sum buffer via conv_to_local_ptr()
    uint16_t scratch_buf; // local address of the int8 buffer the packed weights are unpacked into
    uint16_t tmp_buf;     // local address, mapped to the int32 temporary work buffer
    uint16_t dummy1;      // not read by run_conv_a16w8_qdq; presumably padding - confirm
    uint16_t dummy2;      // not read by run_conv_a16w8_qdq; presumably padding - confirm
    // Kernel configuration (Co_g, Ky_g, Kx_g, Ci_g, op_mode, ... are read).
    // run_conv_a16w8_qdq overwrites ctrl.*, shift_tdm, shift_res and zp_wght.
    KernelParams kernel_params;
};

/*
 * Round x up to the nearest multiple of m, i.e. ceil(x / m) * m.
 *
 * Correct for negative x too (iceil(-5, 4) == -4, iceil(-4, 4) == -4).
 * When x is already a multiple of m it is returned as-is, so no intermediate
 * sum can overflow in that case.
 *
 * Precondition: m > 0. The rounded result itself must fit in an int.
 */
int iceil(int x, int m)
{
    int const rem = x % m;
    if (rem == 0) {
        return x;
    }
    // Integer remainder takes the sign of x: for positive x step up to the
    // next multiple; for negative x dropping the (negative) remainder moves
    // toward zero, which is the ceiling.
    return (rem > 0) ? x + (m - rem) : x - rem;
}

/*
 * Kernel entry point: 16-bit activation x 8-bit weight direct convolution
 * with optional QDQ post-processing.
 *
 * The function:
 *   - reads the LayerParams block from args.params_data;
 *   - turns each stored local-memory address into a pointer with
 *     conv_to_local_ptr();
 *   - locates the QDQ coefficients that follow the padded weights in the
 *     s2mm_ch1 stream;
 *   - forwards everything to direct_conv_int16x8_generic().
 *
 * Side effect: kparams is a reference into the params buffer, so the ctrl
 * flags and the shift_tdm / shift_res / zp_wght values written below persist
 * in that buffer.
 */
void run_conv_a16w8_qdq(KernelArgs& args)
{
    // Saturating arithmetic and the rnd_conv_even rounding mode for all
    // subsequent kernel work.
    set_sat();
    set_rnd(rnd_conv_even);

    LayerParams* layer = static_cast<LayerParams*>(args.params_data);
    KernelParams& kparams = layer->kernel_params;

    // Header fields are widened into plain ints before use.
    int zero_init = layer->zero_init;
    int run_qdq = layer->run_qdq;
    int tdm1_loc = layer->tdm1_addr;
    int tdm2_loc = layer->tdm2_addr;
    int ifmsum_loc = layer->ifmsum_addr;
    int scratch_loc = layer->scratch_buf;
    int tmp_loc = layer->tmp_buf;

    kparams.ctrl.zero_init = zero_init;
    kparams.ctrl.sign_A = 0;
    kparams.ctrl.sign_W = 0;
    kparams.ctrl.sign_O = 0;

    // Streamed buffers: activations and packed weights in, results out.
    int16_t* ifm = static_cast<int16_t*>(args.s2mm_ch0_data);
    int8_t* wts = static_cast<int8_t*>(args.s2mm_ch1_data);
    int16_t* ofm = static_cast<int16_t*>(args.mm2s_ch0_data);

    // Local work buffers.
    int8_t* wts_unpacked = static_cast<int8_t*>(conv_to_local_ptr(scratch_loc));
    int32_t* tdm_a = static_cast<int32_t*>(conv_to_local_ptr(tdm1_loc));
    int32_t* tdm_b = static_cast<int32_t*>(conv_to_local_ptr(tdm2_loc));
    int32_t* ifm_sums = static_cast<int32_t*>(conv_to_local_ptr(ifmsum_loc));
    int32_t* tmp_work = static_cast<int32_t*>(conv_to_local_ptr(tmp_loc));

    // Weight-stream layout:
    //   [packed weights, size rounded up to a multiple of 64]
    //   [Co_g*8 int64 qdq_c0 values]
    //   [int32 words: qdq_c1, qdq_c2, shift_tdm, shift_res, zp_wght]
    // Only OP_CONV_ASYM scales the weight count by Ci_g*8.
    auto const ci_factor = (kparams.op_mode == OP_CONV_ASYM) ? (kparams.Ci_g * 8) : 1;
    int const weights_bytes = iceil(kparams.Co_g * 8 * kparams.Ky_g * kparams.Kx_g * ci_factor, 64);
    int const c0_bytes = kparams.Co_g * 8 * sizeof(int64_t);

    int64_t* qdq_c0 = reinterpret_cast<int64_t*>(byte_incr(wts, weights_bytes));
    int32_t* qdq_words = reinterpret_cast<int32_t*>(byte_incr(wts, weights_bytes + c0_bytes));
    int32_t qdq_c1 = qdq_words[0];
    int32_t qdq_c2 = qdq_words[1];
    kparams.shift_tdm = qdq_words[2];
    kparams.shift_res = qdq_words[3];
    kparams.zp_wght = qdq_words[4];

    // NOTE(review): the weight sizing above honours kparams.op_mode, but the
    // OP_* argument below is fixed to OP_NONE. An OP_CONV_ASYM alternative was
    // previously left commented out at this call; confirm this is intentional.
    direct_conv_int16x8_generic(
        ifm, wts, wts_unpacked,
        tdm_a, tdm_b, ifm_sums, tmp_work,
        qdq_c0, qdq_c1, qdq_c2,
        ofm,
        kparams,
        OP_NONE,
        zero_init,
        run_qdq
    );
}

#endif // DIRECT_CONV_INT16x8_GENERIC_WRAPPER
