#ifndef __QDQADD_KERNEL_WRAPPER_C__
#define __QDQADD_KERNEL_WRAPPER_C__

#include "kernel_helpers.h"
#include "qdqadd.c"

// Per-layer settings for the qdqadd operation.
//
// run_qdqadd_a16a16 does not build this struct itself: it reinterprets the
// memory returned by conv_to_local_ptr(KernelParam_t::qdq_addr) as a
// qdqadd_layer_params_t. The field order, field widths and any padding are
// therefore a binary contract with whatever writes that buffer (presumably
// the host/compiler side -- confirm). Do not reorder, resize or insert fields.
// NOTE(review): depending on the alignment of `int64` on the target, there
// may be padding between shft_o and c0; verify it matches the producer.
struct qdqadd_layer_params_t{
    int32_t shft_o; // output shift; forwarded unchanged to qdqadd_params::shft_o
    int64 c0;       // forwarded unchanged to qdqadd_params::c0 (meaning defined by qdqadd() -- TODO document)
    int32_t c1;     // forwarded unchanged to qdqadd_params::c1
    int32_t c2;     // forwarded unchanged to qdqadd_params::c2
    int8_t channel_v; // is in1 a channel/column vector; only the exact value 1 selects channel-vector mode in convert_to_qdqadd_params
    int8_t ifm_sign; // ifms are signed or unsigned; forwarded unchanged to qdqadd_params::ifm_sign
};

// Shape descriptor for a 3-D sub-volume whose column axis is split into
// C groups of c_inner columns each, with R rows. The name suggests a
// C / R / c storage order -- presumably; confirm against the kernels.
// All extents default to 1.
struct tensor_CRc{
    uint32_t c_inner{1}; // innermost column-group width (typically 8 in AIE kernels)
    uint32_t R{1};       // number of rows
    uint32_t C{1};       // number of column groups, i.e. total columns / c_inner

    // Rank of the descriptor (C, R and c_inner).
    static constexpr unsigned dims() { return 3; }

    // Total number of columns: column groups times group width.
    constexpr uint32_t cols() const { return C * c_inner; }

    // Total number of elements: rows times total columns.
    constexpr uint32_t size() const { return R * cols(); }
};


// Build the runtime parameters for the qdqadd() kernel from the input
// sub-volume shape and the per-layer settings.
//
// ifmsv_dim    - input sub-volume shape (c_inner x R x C).
// layer_params - per-layer settings; shft_o, c0, c1, c2 and ifm_sign are
//                forwarded unchanged.
// Returns a zero-initialised qdqadd_params with the fields below filled in.
//
// layer_params.channel_v selects between two parameterisations:
//   channel_v == 1 : num_ox = R/oxg, num_oc = cols/ocg,
//                    step_ci = ocg*2*x_replication, step_xi = 0
//   otherwise      : num_ox = inner_loop, num_oc = 1,
//                    step_ci = 0, step_xi = ocg*oxg*2
// The strides are multiples of sizeof(int16_t), so presumably byte strides
// over 16-bit elements -- confirm against qdqadd().
qdqadd_params convert_to_qdqadd_params (tensor_CRc ifmsv_dim, qdqadd_layer_params_t layer_params) {
    qdqadd_params param = { 0 };

    // Kernel tiling granularity: ocg columns by oxg rows.
    const unsigned ocg = 8;
    const unsigned oxg = 4;
    // Replication factor folded into the channel-vector stride.
    const unsigned x_replication = 4;

    // Only the exact value 1 enables channel-vector mode.
    const bool channel_vec = (layer_params.channel_v == 1);
    const unsigned row_groups = ifmsv_dim.R / oxg;

    param.shft_o = layer_params.shft_o;
    param.c0 = layer_params.c0; // already int64, no cast required
    param.c1 = layer_params.c1;
    param.c2 = layer_params.c2;

    // Evaluated left to right: ((R / oxg) * cols) / ocg.
    param.inner_loop = row_groups * ifmsv_dim.cols() / ocg;
    param.num_ox = channel_vec ? row_groups : param.inner_loop;
    param.num_oc = channel_vec ? (ifmsv_dim.cols() / ocg) : 1;
    param.step_ci = channel_vec ? ocg * sizeof(int16_t) * x_replication : 0;
    param.step_xi = channel_vec ? 0 : ocg * oxg * sizeof(int16_t);
    // Both modes use a zero reset stride.
    // NOTE(review): confirm channel-vector mode really needs no reset stride.
    param.step_reset = 0;
    param.channel_v = channel_vec;
    param.ifm_sign = layer_params.ifm_sign;

    return param;
}


// Print every field of a qdqadd_params to stdout for debugging.
// Only referenced from a disabled (#if 0) block in run_qdqadd_a16a16.
void debug_qdqadd_params(qdqadd_params param) {
    printf("param debug helper \n");
    printf("Shift      = %d \n", param.shft_o);
    // %lld requires a long long argument; `int64` is a project typedef that
    // is not guaranteed to be exactly long long, so convert explicitly.
    printf("C0         = %lld \n", (long long) param.c0);
    printf("C1         = %d \n", param.c1);
    printf("C2         = %d \n", param.c2);

    printf("inner_loop = %d \n", param.inner_loop);
    printf("step_ci    = %d \n", param.step_ci);
    printf("step_xi    = %d \n", param.step_xi);
    printf("step_reset = %d \n", param.step_reset);
    printf("num_ox     = %d \n", param.num_ox);
    printf("num_oc     = %d \n", param.num_oc);
    printf("channel_v  = %d \n", param.channel_v);
    // Set by convert_to_qdqadd_params but previously never reported.
    printf("ifm_sign   = %d \n", param.ifm_sign);
}

// Runtime parameters for run_qdqadd_a16a16, read by casting args.params_data
// to KernelParam_t*. The field order and 16-bit widths form a binary layout
// shared with whatever fills params_data (presumably the host -- confirm);
// do not reorder or resize fields.
struct KernelParam_t 
{ 
   uint16_t op_select; // not read in this file
	uint16_t Msubv; // sub-volume rows; becomes tensor_CRc::R
	uint16_t Nsubv; // sub-volume columns; divided by 8 to form tensor_CRc::C
   uint16_t qdq_addr; // passed to conv_to_local_ptr to locate the qdqadd_layer_params_t
   uint16_t scratch; // not read in this file
};

// Kernel entry point for the a16a16 qdqadd operation.
//
// args.params_data   - points at a KernelParam_t (sub-volume shape and the
//                      local address of the qdqadd_layer_params_t).
// args.s2mm_ch0_data - input A (Msubv * Nsubv 2-byte elements) followed
//                      immediately by input B in the same buffer.
// args.mm2s_ch0_data - output buffer handed to qdqadd().
//
// Configures rounding (round-to-even) and saturation, converts the layer
// settings with convert_to_qdqadd_params, then runs qdqadd<9, false>
// (template arguments are defined by qdqadd() in qdqadd.c).
void run_qdqadd_a16a16(KernelArgs& args){
#if 0
    printf("starting kernel!!\n");
#endif
    set_rnd(rnd_conv_even);
    set_sat();

    // Both parameter blocks are only read here.
    const KernelParam_t* kernelParamPtr = static_cast<const KernelParam_t*>(args.params_data);
    const qdqadd_layer_params_t* layerParam =
        static_cast<const qdqadd_layer_params_t*>(conv_to_local_ptr(kernelParamPtr->qdq_addr));

    tensor_CRc crc;
    crc.c_inner = 8;
    // NOTE(review): Nsubv is assumed to be a multiple of c_inner; any
    // remainder columns are dropped from the kernel's shape but still count
    // towards the A/B buffer offset below.
    crc.C = kernelParamPtr->Nsubv / crc.c_inner;
    crc.R = kernelParamPtr->Msubv;

    qdqadd_params kp = convert_to_qdqadd_params(crc, *layerParam);

    // Multiply in uint32_t: uint16_t * uint16_t promotes to signed int, and
    // the product of two large uint16_t values overflows int (UB).
    const uint32_t matadd_num_elems = uint32_t(kernelParamPtr->Msubv) * kernelParamPtr->Nsubv;
    // a16 inputs: each element occupies two bytes.
    const uint32_t bytes_per_elem = sizeof(int16_t);

    int8_t* matA = static_cast<int8_t*>(args.s2mm_ch0_data);
    // Input B is packed directly after input A.
    int8_t* matB = byte_incr(matA, matadd_num_elems * bytes_per_elem);
    int8_t* output = static_cast<int8_t*>(args.mm2s_ch0_data);
#if 0
    int const col_idx = (get_coreid() >> 16);
    int const row_idx = (get_coreid() & 0xF);
    if (col_idx == 0 && row_idx == 2) {
      printf("Running qdqadd!!\n");
      debug_qdqadd_params(kp);
    }
#endif
    qdqadd<9, false>((int*)matA, (int*)matB, kp, (int*)output);
}

#endif
