#ifndef __WRAPPER_QDQ_HPP__
#define __WRAPPER_QDQ_HPP__
#include <type_traits>

#include <adf.h>
#include "qdq_helpers.hpp"

// qdq_WA_type<T>: element type used for the buffer ports of the QDQ wrappers.
//
// Mapping:
//   - bfloat16 -> int16 on every architecture.
//   - float    -> int32 only when __AIE_ARCH__ > 20; when __AIE_ARCH__ <= 20,
//                 float is passed through unchanged.
//   - any other T is passed through unchanged.
//
// Each preprocessor branch carries its own complete `template<typename T>`
// header so that either branch can be edited without silently detaching the
// alias from its template parameter list.
//
// NOTE(review): an undefined __AIE_ARCH__ evaluates to 0 inside #if, which
// selects the first branch — confirm that is intended for host/simulation
// builds where the macro might not be defined.
#if __AIE_ARCH__ <= 20
template<typename T>
using qdq_WA_type = std::conditional_t<std::is_same_v<T, bfloat16>, int16, T>;
#else
template<typename T>
using qdq_WA_type =
    std::conditional_t<std::is_same_v<T, bfloat16>, int16,
                       std::conditional_t<std::is_same_v<T, float>, int32, T>>;
#endif

// Port element types for the coefficient (C), input (I) and output (O)
// buffers, as used by the *_conv and *_wa wrappers declared in this header.
// qdq_C_t / qdq_I_t / qdq_O_t are not declared here; presumably they come from
// qdq_helpers.hpp or the build configuration — confirm.
using qdq_C_tt = qdq_WA_type<qdq_C_t>;
using qdq_I_tt = qdq_WA_type<qdq_I_t>;
using qdq_O_tt = qdq_WA_type<qdq_O_t>;


// sum_top: declared here, defined outside this header; marked noinline.
//
// Reads the adf input buffer `ifm` (elements of Ta) and writes the adf output
// buffer `ofm` (elements of Ts). Ta and Ts are deducible from the buffer
// arguments. Mgran, Kgran, Mtile and Ktile have no defaults and do not appear
// in the parameter list, so callers must supply them explicitly.
//
// Template parameters:
//   Mgran, Kgran, Mtile, Ktile - compile-time sizes; presumably granularity
//                                and tile extents along M and K — confirm
//                                against the definition.
//   has_transpose              - compile-time flag, defaults to false.
//   inner_lr_min, outer_lr_min - default to 6 and 3; presumably minimum
//                                loop-range hints — confirm.
// address_config: runtime argument, defaults to 0 (meaning not visible here).
template<typename Ta, typename Ts,
         unsigned Mgran, unsigned Kgran, unsigned Mtile, unsigned Ktile,
         bool has_transpose = false,
         unsigned inner_lr_min = 6, unsigned outer_lr_min = 3>
__attribute__((noinline)) void sum_top(adf::input_buffer<Ta> &ifm,
                                       adf::output_buffer<Ts> &ofm,
                                       unsigned address_config = 0);

// sum_top_scale: declared here, defined outside this header; marked noinline.
//
// Same signature shape as sum_top but without a has_transpose flag. Reads the
// adf input buffer `ifm` (Ta elements) and writes the adf output buffer `ofm`
// (Ts elements). Mgran/Kgran/Mtile/Ktile must be passed explicitly.
// inner_lr_min / outer_lr_min default to 6 / 3 (presumably loop-range
// minimums — confirm). address_config defaults to 0.
template<typename Ta, typename Ts,
         unsigned Mgran, unsigned Kgran,
         unsigned Mtile, unsigned Ktile,
         unsigned inner_lr_min = 6,
         unsigned outer_lr_min = 3>
__attribute__((noinline)) void sum_top_scale(adf::input_buffer<Ta> &ifm,
                                             adf::output_buffer<Ts> &ofm,
                                             unsigned address_config = 0);

// sum_top_int16: declared here, defined outside this header; marked noinline.
//
// Same template and parameter layout as sum_top_scale. Reads the adf input
// buffer `ifm` (Ta elements) and writes the adf output buffer `ofm` (Ts
// elements). Mgran/Kgran/Mtile/Ktile must be passed explicitly;
// inner_lr_min / outer_lr_min default to 6 / 3; address_config defaults to 0.
template<typename Ta, typename Ts,
         unsigned Mgran, unsigned Kgran,
         unsigned Mtile, unsigned Ktile,
         unsigned inner_lr_min = 6,
         unsigned outer_lr_min = 3>
__attribute__((noinline)) void sum_top_int16(adf::input_buffer<Ta> &ifm,
                                             adf::output_buffer<Ts> &ofm,
                                             unsigned address_config = 0);

// wrapper_sum_to_c0: declared here, defined outside this header.
//
// Buffers: `sum_out` (adf input, Ti elements), `coeffs_in` (adf input, Tq
// elements), `coeffs_out` (adf output, Tq elements).
// Ti and Tq are deducible from the buffers; Tq0, Ngran and Ntile have no
// default and do not appear in the parameter list, so they must be given
// explicitly. coeff_step defaults to qdq_coeffs, a constant not declared in
// this header (presumably from qdq_helpers.hpp — confirm).
template<typename Ti, typename Tq, typename Tq0,
         unsigned Ngran, unsigned Ntile,
         unsigned vector_coeffs = 2,
         unsigned lr_min = 4,
         unsigned coeff_step = qdq_coeffs,
         unsigned coeff_skip = 2,
         unsigned fp_accuracy_mode = 1>
void wrapper_sum_to_c0(adf::input_buffer<Ti> &sum_out,
                       adf::input_buffer<Tq> &coeffs_in,
                       adf::output_buffer<Tq> &coeffs_out);

// wrapper_sym: declared here, defined outside this header.
//
// Buffers: `ifm` (adf input, Ti), `coeffs_c` (adf input, Tq), `ofm` (adf
// output, Tr). Tq0 and the four size parameters (Mgran, Ngran, Mtile, Ntile)
// cannot be deduced and must be supplied explicitly.
// Defaults: vector_coeffs = 2, coeff_step = qdq_coeffs (declared outside this
// header), fp_accuracy_mode = 1, fp_split_threshold = 1.
template<typename Ti, typename Tq, typename Tq0, typename Tr,
         unsigned Mgran, unsigned Ngran,
         unsigned Mtile, unsigned Ntile,
         unsigned vector_coeffs = 2,
         unsigned coeff_step = qdq_coeffs,
         unsigned fp_accuracy_mode = 1,
         unsigned fp_split_threshold = 1>
void wrapper_sym(adf::input_buffer<Ti> &ifm,
                 adf::input_buffer<Tq> &coeffs_c,
                 adf::output_buffer<Tr> &ofm);

// wrapper_asym (buffer-only overload): declared here, defined outside this
// header.
//
// Like wrapper_sym but with an extra `ifm_sum` input buffer of the same
// element type Ti as `ifm`, and an extra `terms` template parameter
// (default 3). Tq0, Mgran, Ngran, Mtile and Ntile must be supplied explicitly.
template<typename Ti, typename Tq, typename Tq0, typename Tr,
         unsigned Mgran, unsigned Ngran,
         unsigned Mtile, unsigned Ntile,
         unsigned vector_coeffs = 2,
         unsigned coeff_step = qdq_coeffs,
         unsigned fp_accuracy_mode = 1,
         unsigned fp_split_threshold = 1,
         unsigned terms = 3>
void wrapper_asym(adf::input_buffer<Ti> &ifm,
                  adf::input_buffer<Ti> &ifm_sum,
                  adf::input_buffer<Tq> &coeffs_c,
                  adf::output_buffer<Tr> &ofm);

// wrapper_asym (KernelArgs overload): declared here, defined outside this
// header.
//
// Same template parameters and buffers as the buffer-only wrapper_asym, plus
// a trailing `KernelArgs& args`. KernelArgs is not declared in this header
// (presumably provided by qdq_helpers.hpp — confirm).
template<typename Ti, typename Tq, typename Tq0, typename Tr,
         unsigned Mgran, unsigned Ngran,
         unsigned Mtile, unsigned Ntile,
         unsigned vector_coeffs = 2,
         unsigned coeff_step = qdq_coeffs,
         unsigned fp_accuracy_mode = 1,
         unsigned fp_split_threshold = 1,
         unsigned terms = 3>
void wrapper_asym(adf::input_buffer<Ti> &ifm,
                  adf::input_buffer<Ti> &ifm_sum,
                  adf::input_buffer<Tq> &coeffs_c,
                  adf::output_buffer<Tr> &ofm,
                  KernelArgs& args);


// wrapper_sym_conv: declared here, defined outside this header.
//
// Buffers: `ifm` (adf input, T1 elements; T1 defaults to int32),
// `coeffs` (adf input, qdq_C_tt elements), `ofm` (adf output, Tr).
// Ts and Tq do not appear in the parameter list and must be given explicitly,
// as must Mgran, Ngran, Mtile and Ntile.
// NOTE: fp_split_threshold precedes fp_accuracy_mode here, the reverse of
// wrapper_sym / wrapper_asym — keep this in mind when passing them
// positionally.
template<typename Ts, typename Tq, typename Tr,
         unsigned Mgran, unsigned Ngran,
         unsigned Mtile, unsigned Ntile,
         unsigned fp_split_threshold = 1,
         unsigned fp_accuracy_mode = 1,
         typename T1 = int32>
void wrapper_sym_conv(adf::input_buffer<T1> &ifm,
                      adf::input_buffer<qdq_C_tt> &coeffs,
                      adf::output_buffer<Tr> &ofm);

// wrapper_asym_conv: declared here, defined outside this header.
//
// Buffers: `ifm` (adf input, fixed to int32 — unlike wrapper_sym_conv, there
// is no T1 parameter), `ifm_sum` (adf input, Ts), `coeffs` (adf input,
// qdq_C_tt), `ofm` (adf output, Tr). Tq and the four size parameters must be
// given explicitly.
// NOTE: fp_split_threshold precedes fp_accuracy_mode here, the reverse of
// wrapper_sym / wrapper_asym.
template<typename Ts, typename Tq, typename Tr,
         unsigned Mgran, unsigned Ngran,
         unsigned Mtile, unsigned Ntile,
         unsigned fp_split_threshold = 1,
         unsigned fp_accuracy_mode = 1>
void wrapper_asym_conv(adf::input_buffer<int32> &ifm,
                       adf::input_buffer<Ts> &ifm_sum,
                       adf::input_buffer<qdq_C_tt> &coeffs,
                       adf::output_buffer<Tr> &ofm);

// wrapper_sym_wa: declared here, defined outside this header.
//
// Variant of wrapper_sym whose buffers use the fixed port types qdq_I_tt,
// qdq_C_tt and qdq_O_tt instead of Ti/Tq/Tr. Because none of the template
// parameters appear in the function parameter list, Ti, Tq, Tq0, Tr, Mgran,
// Ngran, Mtile and Ntile must all be supplied explicitly.
template<typename Ti, typename Tq, typename Tq0, typename Tr,
         unsigned Mgran, unsigned Ngran,
         unsigned Mtile, unsigned Ntile,
         unsigned vector_coeffs = 2,
         unsigned coeff_step = qdq_coeffs,
         unsigned fp_accuracy_mode = 1,
         unsigned fp_split_threshold = 1>
void wrapper_sym_wa(adf::input_buffer<qdq_I_tt> &ifm,
                    adf::input_buffer<qdq_C_tt> &coeffs_c,
                    adf::output_buffer<qdq_O_tt> &ofm);

// wrapper_asym_wa (buffer-only overload): declared here, defined outside this
// header.
//
// Variant of wrapper_asym whose ifm / coeffs_c / ofm buffers use the fixed
// port types qdq_I_tt / qdq_C_tt / qdq_O_tt; only `ifm_sum` uses Ti, so Ti is
// the one deducible template parameter. Tq, Tq0, Tr and the four size
// parameters must be supplied explicitly.
template<typename Ti, typename Tq, typename Tq0, typename Tr,
         unsigned Mgran, unsigned Ngran,
         unsigned Mtile, unsigned Ntile,
         unsigned vector_coeffs = 2,
         unsigned coeff_step = qdq_coeffs,
         unsigned fp_accuracy_mode = 1,
         unsigned fp_split_threshold = 1,
         unsigned terms = 3>
void wrapper_asym_wa(adf::input_buffer<qdq_I_tt> &ifm,
                     adf::input_buffer<Ti> &ifm_sum,
                     adf::input_buffer<qdq_C_tt> &coeffs_c,
                     adf::output_buffer<qdq_O_tt> &ofm);

// wrapper_asym_wa (KernelArgs overload): declared here, defined outside this
// header.
//
// Same template parameters and buffers as the buffer-only wrapper_asym_wa,
// plus a trailing `KernelArgs& args` (type declared outside this header,
// presumably in qdq_helpers.hpp — confirm).
template<typename Ti, typename Tq, typename Tq0, typename Tr,
         unsigned Mgran, unsigned Ngran,
         unsigned Mtile, unsigned Ntile,
         unsigned vector_coeffs = 2,
         unsigned coeff_step = qdq_coeffs,
         unsigned fp_accuracy_mode = 1,
         unsigned fp_split_threshold = 1,
         unsigned terms = 3>
void wrapper_asym_wa(adf::input_buffer<qdq_I_tt> &ifm,
                     adf::input_buffer<Ti> &ifm_sum,
                     adf::input_buffer<qdq_C_tt> &coeffs_c,
                     adf::output_buffer<qdq_O_tt> &ofm,
                     KernelArgs& args);

#endif //__WRAPPER_QDQ_HPP__
