#ifndef DIRECT_CONV_INT8x8_GENERIC_WRAPPER
#define DIRECT_CONV_INT8x8_GENERIC_WRAPPER

#include "direct_conv_int8x8_generic_impl.hpp"

namespace A8W8 {
    // Quantize/de-quantize parameter block for the A8W8 GEMM wrapper.
    //
    // run_gemm_a8w8() in this file does not fill this struct member by member:
    // it populates it with a single raw vector (v64uint8) copy from local
    // memory. The member order and types below are therefore a binary layout
    // contract with whatever produces that memory image -- do not reorder,
    // insert or resize members without updating the producer.
    //
    // NOTE(review): with an 8-byte int64 and 4-byte int32 the members total
    // 44 bytes, which is smaller than the 64-byte vector used to load them --
    // confirm the storage this is loaded into is large enough.
    struct QDQKernelParams {
        int64 c0;         // forwarded to direct_conv_int8x8_generic by run_gemm_a8w8
        int32 c1;         // forwarded to direct_conv_int8x8_generic by run_gemm_a8w8
        int32 c2;         // forwarded to direct_conv_int8x8_generic by run_gemm_a8w8
        int32 c3;         // forwarded to direct_conv_int8x8_generic by run_gemm_a8w8
        int32 M;          // not read in this file; presumably a GEMM dimension -- TODO confirm
        int32 N;          // not read in this file; presumably a GEMM dimension -- TODO confirm
        int32 shift_Qb;   // not read in this file
        int32 shift_res;  // forwarded to direct_conv_int8x8_generic by run_gemm_a8w8
        int32 shift_tdm;  // not read in this file
        int32 Vec_coeffs; // not read in this file
    };
}

/**
 * Entry wrapper that unpacks the per-layer parameter blob and invokes
 * direct_conv_int8x8_generic<false, false, true, true> for an A8W8 GEMM.
 *
 * Steps performed:
 *   1. Calls set_sat() and set_rnd(rnd_conv_even) before any computation.
 *   2. Interprets args.params_data as a LayerParams record.
 *   3. Picks input/weight buffers from the two s2mm channels according to
 *      LayerParams::mode, and the output buffer from mm2s channel 0.
 *   4. Resolves the 16-bit addresses in LayerParams through
 *      conv_to_local_ptr().
 *   5. Copies one 64-byte vector of QDQ parameters from qdq_addr into a
 *      staging object and forwards c0/c1/c2/c3/shift_res to the kernel.
 *
 * @param args  Kernel argument bundle; this function reads params_data,
 *              s2mm_ch0_data, s2mm_ch1_data and mm2s_ch0_data from it.
 */
void run_gemm_a8w8(KernelArgs& args)
{
    // Layout of the parameter blob behind args.params_data. The field order
    // is a binary contract with the producer of that blob; do not reorder.
    struct alignas(4) LayerParams
    {
        uint8_t zero_init;       // forwarded to the kernel as an int
        uint8_t final_tdm_iter;  // forwarded to the kernel as an int
        uint16_t wgt_size;       // byte offset from `weights` to the c0 data (op_mode 9 / 11)
        uint16_t tdm1_addr;      // resolved with conv_to_local_ptr()
        uint16_t tdm2_addr;      // resolved with conv_to_local_ptr()
        uint16_t ifmsum_addr;    // resolved with conv_to_local_ptr()
        uint16_t scratch_buf;    // resolved with conv_to_local_ptr(); unpacked-weights buffer
        uint16_t qdq_addr;       // resolved with conv_to_local_ptr(); source of the QDQ block
        uint16_t bufC0;          // resolved with conv_to_local_ptr(); c0 data for other op_modes
        uint16_t op_mode;        // forwarded to the kernel; also selects the c0 source
        uint16_t mode;           // 0: ch0 = input, ch1 = weights; otherwise swapped
        uint16_t reserved;       // not read here
        DirectConvInt8x8GenericKernelParams  kernel_params;
    };

    // Staging storage for the QDQ parameter block.
    //
    // The block is loaded with a single v64uint8 store. Putting alignas(64) on
    // a *variable* only aligns its address; it does not enlarge the object, so
    // storing a whole vector into a bare QDQKernelParams would write past its
    // end. Putting alignas(64) on the wrapping *type* rounds sizeof up to a
    // multiple of 64, which keeps the vector store inside the object.
    struct alignas(64) QdqStaging
    {
        A8W8::QDQKernelParams params;
    };
    static_assert(sizeof(QdqStaging) >= sizeof(v64uint8),
                  "QDQ staging storage must hold one full v64uint8 store");

    set_sat();
    set_rnd(rnd_conv_even);

    LayerParams* layer_params = static_cast<LayerParams*>(args.params_data);
    DirectConvInt8x8GenericKernelParams& kernel_params = layer_params->kernel_params;

    const int zero_init = layer_params->zero_init;
    const int final_tdm_iter = layer_params->final_tdm_iter;
    const int op_mode = layer_params->op_mode;

    // mode == 0 maps channel 0 to the input and channel 1 to the weights;
    // any other value swaps the two channels.
    const bool input_on_ch0 = (layer_params->mode == 0);
    int8_t* input = input_on_ch0 ? static_cast<int8_t*>(args.s2mm_ch0_data)
                                 : static_cast<int8_t*>(args.s2mm_ch1_data);
    int8_t* weights = input_on_ch0 ? static_cast<int8_t*>(args.s2mm_ch1_data)
                                   : static_cast<int8_t*>(args.s2mm_ch0_data);
    int8_t* output = static_cast<int8_t*>(args.mm2s_ch0_data);

    void* qdq_addr = conv_to_local_ptr(layer_params->qdq_addr);
    int8_t* weights_unpack = static_cast<int8_t*>(conv_to_local_ptr(layer_params->scratch_buf));
    int32_t* tdm1 = static_cast<int32_t*>(conv_to_local_ptr(layer_params->tdm1_addr));
    int32_t* tdm2 = static_cast<int32_t*>(conv_to_local_ptr(layer_params->tdm2_addr));
    int32_t* ifm_sum = static_cast<int32_t*>(conv_to_local_ptr(layer_params->ifmsum_addr));
    int64_t* bufC0 = static_cast<int64_t*>(conv_to_local_ptr(layer_params->bufC0));

    // Kept static as in the original implementation (storage lives outside
    // the stack frame). One vector copy brings in the whole QDQ block.
    static QdqStaging gemm_a8w8_qdq_staging;
    v64uint8 * pI = static_cast<v64uint8*>(qdq_addr);
    v64uint8 * restrict pO = reinterpret_cast<v64uint8*>(&gemm_a8w8_qdq_staging);
    pO[0] = pI[0];
    const A8W8::QDQKernelParams& gemm_a8w8_qdq_params = gemm_a8w8_qdq_staging.params;

    // For op_mode 9 and 11 the c0 data sits wgt_size bytes past the start of
    // the weights buffer; every other op_mode reads it from bufC0.
    const bool c0_follows_weights = (op_mode == 11 || op_mode == 9);
    int64_t* qdq_c0 = c0_follows_weights
        ? reinterpret_cast<int64_t*>(byte_incr(weights, layer_params->wgt_size))
        : bufC0;

    direct_conv_int8x8_generic<false, false, true, true>
    (
        input,
        weights,
        weights_unpack,
        tdm1,
        tdm2,
        ifm_sum,
        qdq_c0,
        output,
        zero_init,
        final_tdm_iter,
        op_mode,
        gemm_a8w8_qdq_params.c0,
        gemm_a8w8_qdq_params.c1,
        gemm_a8w8_qdq_params.c2,
        gemm_a8w8_qdq_params.c3,
        gemm_a8w8_qdq_params.shift_res,
        kernel_params
    );

}
#endif// DIRECT_CONV_INT8x8_GENERIC_WRAPPER
