#ifndef RUN_CONV_WRAPPER_CC
#define RUN_CONV_WRAPPER_CC

//Included conv2d function from conv2d_wrapper.cc
#include <adf.h>
#include "aie_api/aie.hpp"
#include "biased_conv_int8x8_template.h"

#include "utils.h"         // Contains some utility functions and definitions related to intrinsics usage
#include "mllib_const.h"   // Constants shared across graphs like activation type

// ---- Build-time configuration ------------------------------------------------
// Each macro below may be supplied by the build; the fallbacks here apply only
// when it was not defined beforehand.

// Fallback for ADD_BIAS_IN_FIRST_ITERATION (consumed as a constexpr bool in conv2d).
#if !defined(ADD_BIAS_IN_FIRST_ITERATION)
#define ADD_BIAS_IN_FIRST_ITERATION false
#endif

// Fallback for PSUM_DATASIZE_BITS; the value 16 selects KC_TDM16 below,
// any other value selects KC_TDM32.
#if !defined(PSUM_DATASIZE_BITS)
#define PSUM_DATASIZE_BITS 32
#endif

// Mirrors whether STRIDE2_OPT was defined at build time. It provides the
// defaults for the has_S50 / has_both_S template parameters of conv2d.
#if defined(STRIDE2_OPT)
constexpr bool stride2_opt_preset = true;
#else
constexpr bool stride2_opt_preset = false;
#endif

// Default value of conv2d's KC_PREC template parameter, chosen from PSUM_DATASIZE_BITS.
static constexpr auto KC_DEFAULT = ( PSUM_DATASIZE_BITS == 16 ) ? KC_TDM16
                                                                 : KC_TDM32;

/**
 * Dispatch a single convolution call to one instantiation of
 * biased_conv_int8x8_template, bracketed by event0() / event1().
 *
 * Which instantiation runs depends on psum_end, psum_start, the template
 * booleans has_0_to_R / has_S50 / has_both_S and arch_params.str_w:
 *
 *  - psum_end == false: the first two template arguments are both KC_PREC,
 *    the has_lrelu / has_relu6 flags and relu_params are NOT forwarded, and
 *    psum_start is passed for both trailing runtime flags.
 *  - psum_end == true: the second template argument is KC_RESULT8, the relu
 *    flags and relu_params are forwarded, and the last runtime flag is 1.
 *    If additionally has_0_to_R && psum_start, the first template argument is
 *    KC_ZERO and both runtime flags are 1.
 *
 * NOTE(review): the meaning of the individual template / runtime flags is
 * defined by biased_conv_int8x8_template.h and is not visible here; the
 * arguments below are forwarded exactly as listed.
 *
 * @param input       forwarded unchanged to the selected instantiation
 * @param weights     forwarded unchanged
 * @param bias_ptr    forwarded unchanged
 * @param psum_0      forwarded unchanged
 * @param psum_1      forwarded unchanged
 * @param output      forwarded unchanged
 * @param arch_params forwarded unchanged; str_w is also inspected here when
 *                    has_S50 and has_both_S are both set
 * @param psum_start  selects the KC_ZERO variant (with psum_end and has_0_to_R)
 *                    and is otherwise forwarded as a runtime flag
 * @param psum_end    selects between the KC_RESULT8 and KC_PREC output variants
 * @param relu_params forwarded only on the psum_end == true paths
 */
template <KernelConfig KC_PREC = KC_DEFAULT,
          bool has_0_to_R=1, bool has_S50=!stride2_opt_preset, bool has_both_S=stride2_opt_preset,
          unsigned ol_lr=8, bool has_lrelu=1, bool has_relu6=0>
__attribute__(( always_inline ))
void conv2d
(
    int8 * input,
    int8 * weights,
    int8 * bias_ptr,
    int * psum_0,
    int * psum_1,
    int8 * output,
    const conv3d_params_t &arch_params,
    bool psum_start,
    bool psum_end,
    const std::optional<leakyrelu_kernel_params_t> relu_params = std::nullopt
) {

    event0( );

    // Compile-time flags that recur in the template argument lists below.
    constexpr bool bias_in_first_pass = ADD_BIAS_IN_FIRST_ITERATION;
    constexpr bool bias_in_final_pass = !bias_in_first_pass;
    constexpr bool not_both_S         = !has_both_S;

    // Chooses between the two trailing-flag variants (<..., 1, 0> versus
    // <..., !has_both_S, 1>) of the template. arch_params.str_w is read
    // only when both has_S50 and has_both_S are set.
    const bool use_non_s50_variant =
        !has_S50 || ( has_both_S && arch_params.str_w == 2 );

    if ( !psum_end ) {
        // Not the last pass: no relu arguments, psum_start used for both runtime flags.
        if ( use_non_s50_variant ) {
            biased_conv_int8x8_template<KC_PREC, KC_PREC, ol_lr, 8, bias_in_first_pass, 1, 0>(
                input, weights, bias_ptr,
                psum_0, psum_1, output,
                arch_params,
                psum_start, psum_start );
        } else {
            biased_conv_int8x8_template<KC_PREC, KC_PREC, ol_lr, 8, bias_in_first_pass, not_both_S, 1>(
                input, weights, bias_ptr,
                psum_0, psum_1, output,
                arch_params,
                psum_start, psum_start );
        }
    } else if ( has_0_to_R && psum_start ) {
        // First and last pass at once: KC_ZERO variant, both runtime flags fixed to 1.
        biased_conv_int8x8_template<KC_ZERO, KC_RESULT8, ol_lr, 8, 1, 1, 1, has_lrelu, has_relu6>(
            input, weights, bias_ptr,
            psum_0, psum_1, output,
            arch_params,
            1, 1,
            relu_params );
    } else if ( use_non_s50_variant ) {
        biased_conv_int8x8_template<KC_PREC, KC_RESULT8, ol_lr, 8, bias_in_final_pass, 1, 0, has_lrelu, has_relu6>(
            input, weights, bias_ptr,
            psum_0, psum_1, output,
            arch_params,
            psum_start, 1,
            relu_params );
    } else {
        biased_conv_int8x8_template<KC_PREC, KC_RESULT8, ol_lr, 8, bias_in_final_pass, not_both_S, 1, has_lrelu, has_relu6>(
            input, weights, bias_ptr,
            psum_0, psum_1, output,
            arch_params,
            psum_start, 1,
            relu_params );
    }

    event1( );
}

/**
 * Unpack the runtime parameter words in args.params_data, resolve the buffer
 * pointers, configure saturation / rounding, and invoke conv2d with its
 * default template arguments.
 *
 * Parameter words (uint16_t), as read below:
 *   [0] flag passed to conv2d as psum_start
 *   [1] flag passed to conv2d as psum_end
 *   [2] value resolved via conv_to_local_ptr -> psum_0
 *   [3] value resolved via conv_to_local_ptr -> psum_1
 *   [4] value resolved via conv_to_local_ptr -> conv3d_params_t
 *   [5] weights size; the bias pointer is byte_incr(weights, size)
 *   [6] value resolved via conv_to_local_ptr -> output
 *   [7] non-zero: take the input from word [8]; zero: use args.s2mm_ch0_data
 *   [8] value resolved via conv_to_local_ptr -> input (only read when [7] != 0)
 *
 * @param args kernel arguments; params_data, s2mm_ch0_data and s2mm_ch1_data are read
 */
void run_conv_xint8(KernelArgs& args) {

    const uint16_t* rtp = static_cast<const uint16_t*>(args.params_data);

    // Pass flags and the two psum buffers.
    const bool first_pass = rtp[0];
    const bool last_pass  = rtp[1];
    int * psum_a = (int *) conv_to_local_ptr(rtp[2]);
    int * psum_b = (int *) conv_to_local_ptr(rtp[3]);

    // Weights come from channel 1; the output and layer parameters are resolved from parameter words.
    int8_t* weights_ptr = static_cast<int8_t*>(args.s2mm_ch1_data);
    int8_t* out_ptr     = static_cast<int8_t*>(conv_to_local_ptr(rtp[6]));
    conv3d_params_t& layer_params = *((conv3d_params_t *) conv_to_local_ptr(rtp[4]));
    int weights_size = rtp[5];

    // Input source: the pointer from word [8] when word [7] is set
    // (original note: "For Tk = 1", flagged to check), else channel 0.
    const int use_param_ifm = rtp[7];
    int8_t* in_ptr = use_param_ifm
        ? static_cast<int8_t*>(conv_to_local_ptr(rtp[8]))
        : static_cast<int8_t*>(args.s2mm_ch0_data);

    // The bias pointer is derived from the weights pointer advanced by weights_size
    // (presumably a byte offset, per byte_incr; confirm against its definition).
    int8_t* bias_ptr = (int8_t*)(byte_incr(weights_ptr, weights_size));

    set_sat();
    set_rnd(rnd_conv_even);

    conv2d(
        in_ptr,
        weights_ptr,
        bias_ptr,
        psum_a,
        psum_b,
        out_ptr,
        layer_params,
        first_pass,
        last_pass
    );

}

#endif

