#ifndef __WRAPPER_QDQ_CC__
#define __WRAPPER_QDQ_CC__

#include <adf.h>
#include <aie_api/aie.hpp>
#include "constants.h"
#include "wrapper_qdq.hpp"
#include "qdq_sum.hpp"
#include "qdq_int16_bfloat16.hpp"
#include "qdq.cc"
#include "aie_api/utils.hpp"
// Fixed local-memory addresses; in this file they are only consumed by
// run_gemm_qdq(), which converts them to pointers with conv_to_local_ptr().
// Address of the IFM-sum buffer passed to the QDQ stage.
#define GEMM_A8W8_IFM_SUM_ADDR 54336
// Address of the TDM buffer passed as the first input of qdq_tdm32_to_int8().
#define GEMM_A8W8_TDM1_ADDR 8192
// NOTE(review): not referenced anywhere in this file - presumably a second TDM
// buffer; confirm against overlay_gemm.py.
#define GEMM_A8W8_TDM2_ADDR 24576

// NOTE: Refer to overlay_gemm.py to align the QDQ params address in L1
// File-level QDQ parameter block shared by all wrappers below.
// run_gemm_qdq() and run_gemm_qdq_asym() overwrite it with a single vector
// store (32 and 64 bytes respectively) before running the QDQ kernels, which
// then read c0..c3, M, N, shift_Qb and shift_Qout from it.
alignas(64) static QDQKernelParams g_qdq_kernel_params;

/**
 * Thin wrapper that runs qdq_sum_to_c0() using the file-level g_qdq_kernel_params.
 *
 * The scalar coefficients c0 and c3 as well as N, shift_Qb and shift_Qout are
 * read from g_qdq_kernel_params, so that block must be populated before this
 * wrapper is called.
 *
 * @tparam Ti, Tq, Tq0   element types of the sum input, the coefficient buffers and c0.
 * @tparam Ngran, Ntile, lr_min, fp_accuracy_mode   forwarded unchanged to qdq_sum_to_c0.
 * @tparam vector_coeffs, coeff_step, coeff_skip    not used by this wrapper.
 * @param sum_out     input buffer handed to qdq_sum_to_c0 (it is an input despite its name).
 * @param coeffs_in   not used by this wrapper.
 * @param coeffs_out  destination buffer; wrapped with local_buffer() and passed to qdq_sum_to_c0.
 */
template<typename Ti, typename Tq, typename Tq0, unsigned Ngran, unsigned Ntile, unsigned vector_coeffs, unsigned lr_min, unsigned coeff_step, unsigned coeff_skip, unsigned fp_accuracy_mode>
inline __attribute__(( always_inline )) void wrapper_sum_to_c0( adf::input_buffer<Ti> &sum_out, adf::input_buffer<Tq> &coeffs_in, adf::output_buffer<Tq> &coeffs_out ){
    auto cf_out = local_buffer(coeffs_out);

#if __DEBUG__
    printf("wrapper_sum_to_c0: sum_to_co: C0 %lld, C3 %d, M: %d, N: %d, shift_qb: %d, shift_qo: %d \n", g_qdq_kernel_params.c0,
        g_qdq_kernel_params.c3, g_qdq_kernel_params.M, g_qdq_kernel_params.N,
        g_qdq_kernel_params.shift_Qb, g_qdq_kernel_params.shift_Qout);
#endif

    // The kernel can operate on the same parameter set as the qdq kernel;
    // only the fields assigned below are set here.
    QDQParams param;
    param.N_g = g_qdq_kernel_params.N / Ngran;
    param.shift_Qb = g_qdq_kernel_params.shift_Qb;
    param.shift_Qout = g_qdq_kernel_params.shift_Qout;

    qdq_sum_to_c0<Ti, Tq, Tq0, Ngran, Ntile, lr_min, fp_accuracy_mode>( sum_out, (Tq0)g_qdq_kernel_params.c0, (Tq)g_qdq_kernel_params.c3, cf_out, param );
}

/**
 * Selects and runs one of the qdq kernel variants at compile time.
 *
 * Sizes and shifts (M, N, shift_Qb, shift_Qout) and, for the scalar-coefficient
 * variant, c1/c2 are read from the file-level g_qdq_kernel_params.
 *
 * Dispatch (all decided with `if constexpr`):
 *  - Nothing but set_sat() runs unless Tq is integral or HAS_FLOAT is set.
 *  - vector_coeffs > 1:
 *      - non-integral Tq, fp_split_threshold > 1 and sizeof(Tq) <= sizeof(Ti): qdq_split
 *        (asserts that the address config is ADDRCFG_linear or ADDRCFG_YCXC8);
 *      - otherwise: the qdq overload taking the coefficient buffer.
 *  - vector_coeffs == 1 or terms > 2: the qdq overload taking scalar c1/c2, with the
 *    coefficient buffer reinterpreted as Tq0.
 *  - otherwise: only static_assert(Ngran == Ntile); no kernel is invoked (the call is
 *    commented out in the source).
 *
 * @param ifm                input tile; accessed through a raw Ti pointer or local_buffer_cast.
 * @param ifm_sum            input-sum buffer forwarded to the kernel.
 * @param coeffs_c           coefficient buffer; wrapped with local_buffer().
 * @param ofm                output buffer forwarded to the kernel.
 * @param tdm_buffer_offset  forwarded as the last argument of address_setup_qdq().
 * @param addr_config        addressing scheme forwarded to address_setup_qdq().
 */
template<typename Ti, typename Tq, typename Tq0, typename Tr, unsigned Mgran, unsigned Ngran, unsigned Mtile, unsigned Ntile, unsigned vector_coeffs, unsigned coeff_step, unsigned fp_accuracy_mode, unsigned fp_split_threshold, unsigned terms_all>
inline __attribute__((always_inline)) void wrapper_asym( adf::input_buffer<Ti> &ifm, adf::input_buffer<Ti> &ifm_sum, adf::input_buffer<Tq> &coeffs_c, adf::output_buffer<Tr> &ofm,
                   int tdm_buffer_offset, AddressConfig addr_config)
{
    constexpr unsigned lr_min = 6;
    // At most three terms are handled here, whatever terms_all requests.
    constexpr unsigned terms = std::min( 3u, terms_all );
    auto ifm_Ti = (Ti*) ifm.data( );

    auto coeffs = local_buffer(coeffs_c);

    auto address_config = addr_config;

    QDQParams param;
    param.Y_g = 1; //( address_config >= ADDRCFG_YCXC8 ? conv_Y : 1 );
    param.M_g = g_qdq_kernel_params.M / param.Y_g / Mgran; // = Mgemm (256) / 4 rows / 8 rows_per_iter = 8
    param.N_g = g_qdq_kernel_params.N / Ngran;
    param.shift_Qb = g_qdq_kernel_params.shift_Qb;
    param.shift_Qout = g_qdq_kernel_params.shift_Qout;
    address_setup_qdq( param, Mgran, Ngran, Mtile, Ntile, address_config, tdm_buffer_offset);
#if 0 //__DEBUG__
    int rowIdx = (get_coreid() & 0xF);
    int colIdx = (get_coreid() >> 16);
    if(rowIdx == 2 && colIdx == 0)
    {
    printf("wrapper_asym: QDQKparams: C1 %d, C2 %d, M: %d, N: %d, shift_qb: %d, shift_qo: %d \n", g_qdq_kernel_params.c1, g_qdq_kernel_params.c2, g_qdq_kernel_params.M, g_qdq_kernel_params.N, g_qdq_kernel_params.shift_Qb, g_qdq_kernel_params.shift_Qout);
    }
#endif
    set_sat( );
    if constexpr( std::is_integral_v<Tq> || HAS_FLOAT ) {
        if constexpr( vector_coeffs > 1 ) {//|| __AIE_ARCH__ >= 21 ) {
            if constexpr( !std::is_integral_v<Tq> && fp_split_threshold > 1 && sizeof( Tq ) <= sizeof( Ti )) {
                assert( address_config == ADDRCFG_linear || address_config == ADDRCFG_YCXC8 );
                // The split kernel takes a buffer view rather than the raw Ti pointer,
                // so it gets its own distinctly named local.
                auto ifm_buf = local_buffer_cast<Ti, adf::direction::in>( ifm );
                qdq_split<Ti, Tq, Tr, Mgran, Ngran, Mtile, Ntile, 32, terms, lr_min, fp_split_threshold, fp_accuracy_mode>( ifm_buf, ifm_sum, coeffs, ofm, param );
            } else
                qdq<Ti, Tq, Tq0, Tr, Mgran, Ngran, Mtile, Ntile, terms, lr_min, coeff_step, 0, fp_accuracy_mode>( ifm_Ti, ifm_sum, coeffs, ofm, param );
        } else if constexpr( vector_coeffs == 1 || terms > 2 ) {
            // Coefficients are read as Tq0 here, so the step is rescaled by sizeof(Tq) / sizeof(Tq0).
            auto coeffs_Tq0 = local_buffer_cast<Tq0,adf::direction::in>( coeffs );
            qdq<Mgran, Ngran, Mtile, Ntile, terms, lr_min, coeff_step * sizeof( Tq ) / sizeof( Tq0 ), fp_accuracy_mode>( ifm_Ti, ifm_sum, coeffs_Tq0, g_qdq_kernel_params.c1, g_qdq_kernel_params.c2, ofm, param );
        } else {
            // No kernel is invoked in this configuration; the scalar-only call is disabled.
            static_assert( Ngran == Ntile );
            //qdq<Mgran, Mtile, Ntile, lr_min, fp_accuracy_mode>( ifm_Ti, c0, kparams.c2, ofm, param );
        }
    }
}

template<typename Ti, typename Tq, typename Tq0, typename Tr>
void qdq_tdm32_to_int8(
    void* tdm_data, int tdm_size, int tdm_buffer_offset, AddressConfig addr_config,
    void* ifm_sum_data, int ifm_sum_size,
    void* coeffs_data, int coeffs_size,
    void* ofm_data, int ofm_size
    )
{
    int const Mgran = 4;
    int const Ngran = 8;
    int const Mtile = 4;
    int const Ntile = 8;
    int const vector_coeffs = 1;
//     memcpy((void*)&g_qdq_kernel_params, (void*)(byte_incr(coeffs_data, coeffs_size)), sizeof(QDQKernelParams));
#if __DEBUG__
    printf("qdq_tdm32_to_int8: C0: %lld, C1: %d, C2: %d, C3: %d \n", g_qdq_kernel_params.c0, g_qdq_kernel_params.c1, g_qdq_kernel_params.c2, g_qdq_kernel_params.c3);
#endif

    adf::input_buffer<Ti> tdm(static_cast<Ti*>(tdm_data), tdm_size, 0, tdm_size);
    adf::input_buffer<Ti> ifm_sum(static_cast<Ti*>(ifm_sum_data), ifm_sum_size, 0, ifm_sum_size);
    adf::input_buffer<Tq> coeffs(static_cast<Tq*>(coeffs_data), coeffs_size, 0, coeffs_size);
    adf::output_buffer<Tr> ofm(static_cast<Tr*>(ofm_data), ofm_size, 0, ofm_size);

    wrapper_asym<Ti, Tq, Tq0, Tr,
                 Mgran, Ngran, Mtile, Ntile, vector_coeffs,
                 2, 1, 1, 3>(
        tdm, ifm_sum, coeffs, ofm, tdm_buffer_offset, addr_config
    );
}

/**
 * Writes C0 + C1 * in[i] for the int32 values at ifm_sum_data, 16 lanes at a
 * time, as 64-bit accumulator vectors to coeff_out.
 *
 * @tparam N            number of input elements; N / 16 vectors are processed, so any
 *                      remainder when N is not a multiple of 16 is not processed.
 * @param ifm_sum_data  input, read as v16int32 vectors.
 * @param C0            64-bit constant added to every lane.
 * @param C1            32-bit factor multiplied with every input lane.
 * @param coeff_out     output, written as v16acc64 vectors.
 */
template<int N=64>
void sum_c0(void* ifm_sum_data, const int64_t C0, int32_t C1, void* coeff_out)
{
    const int Ntile = 16;
    alignas(64) int64_t C0_4[4];
    v16int32* pIn = (v16int32*)ifm_sum_data;
    v16int32 c1_b = broadcast_to_v16int32(C1);
    v8acc64 c0_8 = undef_v8acc64();
    v16acc64 c0_16 = undef_v16acc64();
    v16acc64* pOut = (v16acc64*)coeff_out;
    // Stage C0 in memory four times and reload it as a vector
    // (did not find a suitable broadcast intrinsic).
    for(int i =0; i < 4; i++){
       C0_4[i] = C0;
    }
    // Widen 4 -> 8 -> 16 lanes so every accumulator lane starts at C0.
    c0_8 = concat(*(v4acc64*)C0_4, *(v4acc64*)C0_4);
    c0_16 = insert(c0_16, 0, c0_8);
    c0_16 = insert(c0_16, 1, c0_8);
    for(int n=0; n< (N/Ntile); n++){
        v16int32 pIn_0 = *pIn++;
        auto out = aie::mac(aie::accum<acc64,16>(c0_16), aie::vector<int32, 16>(c1_b),aie::vector<int32, 16>(pIn_0));
        *pOut++ = out;

    }
    // Guarded with `#if` like the other debug sections of this file, so that a
    // build defining __DEBUG__ as 0 does not enable the prints here.
#if __DEBUG__
    int rowIdx = (get_coreid() & 0xF);
    int colIdx = (get_coreid() >> 16);
    if(rowIdx == 2 && colIdx ==0){
        print_buf<int64_t>(C0_4, 4, 4);
        aie::print(aie::accum<acc64, 16>(c0_16), true, "c0_16");
    }
#endif
}

template<typename Ti, typename Tq0>
void __attribute__((noinline)) run_sum_c0(void* ifm_sum_data, const void* coeff_data, void* coeff_out_data, int N = 64)
{
    int const Ngran = 8;
    int const Ntile = 8;
    int const coeff_size = 64;
    adf::input_buffer<Ti>ifm_sum(static_cast<Ti*>(ifm_sum_data), (N*sizeof(Tq0)), 0, (N*sizeof(Tq0)) );
    adf::input_buffer<Ti>coeff ((Ti*)(coeff_data), (coeff_size*sizeof(Tq0)), 0, (coeff_size*sizeof(Tq0)) );
    adf::output_buffer<Ti>ofm(static_cast<Ti*>(coeff_out_data), (N*sizeof(Tq0)), 0, (N*sizeof(Tq0)) );
#if __DEBUG__
    //printf("run_sum_c0: ifm_sum_data %d ifm_sum_size %d coeffs_data %d coeffs_size %d ofm_data %d , ofm_size %d \n", *(int32_t*)ifm_sum_data , (N*sizeof(Tq0)), coeff_data, (coeff_size*sizeof(Tq0)),  coeff_out_data, (N*sizeof(Tq0)) );
    aie::print(aie::vector<int32, 32>(*(v32int32*)ifm_sum_data), true, "sum to c0: ");
#endif

   //Kernel can operate on same parameter set as qdq kernel.
    QDQParams param;
    param.N_g = g_qdq_kernel_params.N / Ngran;//1;
    param.shift_Qb = g_qdq_kernel_params.shift_Qb;
    param.shift_Qout = g_qdq_kernel_params.shift_Qout;
    //address_setup_qdq( param, Mgran, Ngran, Mtile, Ntile, address_config, ifm.size( ) / 2 );


    qdq_sum_to_c0<Ti, Ti, Tq0, Ngran, Ntile, 4, 1>( ifm_sum, (Tq0)g_qdq_kernel_params.c0, (Ti)g_qdq_kernel_params.c3, ofm, param );
    //int const vector_coeffs = 1;
    //chess_separator_scheduler(4);
    //wrapper_sum_to_c0<int32, int32, int64_t, Ngran, Ntile, vector_coeffs, 4, 1, 0, 1>( ifm_sum, coeff, ofm);//, qdqparams);
    //sum_c0<64>(ifm_sum_data, (int64_t)g_qdq_kernel_params.c0, g_qdq_kernel_params.c3, coeff_out_data);
    //chess_separator_scheduler(4);
}

// function for 3-term actxact qdq
template<typename Ti=int32_t, typename Tq=int32_t, typename Tq0=int64_t, typename Tr=uint8_t, int N=64>
void run_gemm_qdq_asym(void* tdm, const int32_t tdm_size,
                        void* ifm_sum, const int32_t ifm_sum_size,
                        void* wgt_sum, const int32_t wgt_sum_size,
                        void* qdq_param_addr,
                        void* out, const int32_t out_size)
{

    // qdq params are at the end of coeffs
    //memcpy(static_cast<void*>(&g_qdq_kernel_params), qdq_param_addr, sizeof(QDQKernelParams));
    v64uint8 * pI = (v64uint8*)qdq_param_addr;
    v64uint8 * restrict pO = (v64uint8*)(&g_qdq_kernel_params);
    pO[0] = pI[0];

    chess_memory_fence( );
    //sum_c0<64>(wgt_sum, (int64_t)g_qdq_kernel_params.c0, g_qdq_kernel_params.c3, conv_to_local_ptr(C0_ADDR));
    run_sum_c0<Ti,Tq0>(wgt_sum, NULL, conv_to_local_ptr(C0_ADDR), 64);
#if 0 //__DEBUG__
    int rowIdx = (get_coreid() & 0xF);
    int colIdx = (get_coreid() >> 16);
    v32acc64* vtmp3  = (v32acc64*)conv_to_local_ptr(C0_ADDR);
    if(rowIdx == 2 && colIdx == 0){
        aie::print(aie::vector<int32, 32>(*(v32int32*)wgt_sum), true, "sum to c0: ");
        print_buf<int64_t>((int64_t*)vtmp3, 32, 8);
    }
#endif

    // call 3term QDQ
    qdq_tdm32_to_int8<Ti, Tq, Tq0, Tr>(
        tdm, tdm_size, 0, ADDRCFG_linear,
        ifm_sum, ifm_sum_size,
        conv_to_local_ptr(C0_ADDR), wgt_sum_size,
        out,  out_size
    );
}

/**
 * 3-term QDQ for the activation x weight GEMM.
 *
 * Loads the QDQ parameter block into g_qdq_kernel_params, runs
 * qdq_tdm32_to_int8() on the fixed GEMM buffers and finally clears the
 * SRS/UPS overflow flags.
 *
 * @param args  kernel arguments; params_data[0] (uint16) is the local address of
 *              the QDQ parameter block, s2mm_ch0_data holds the coefficients at byte
 *              offset 128 * 64, and mm2s_ch0_data receives the output.
 */
void run_gemm_qdq(KernelArgs& args)
{
    uint16_t* args_params = (uint16_t*)args.params_data;
    // NOTE: args_params: 2 bytes of qdq_param_addr. Refer to overlay.py
    void* qdq_param_addr = static_cast<void*>(conv_to_local_ptr(args_params[0]));
    // NOTE: Refer to overlay_gemm.py and mmultint8x8_oloh_kernel.h
    void* ifm_sum_addr = static_cast<void*>(conv_to_local_ptr(GEMM_A8W8_IFM_SUM_ADDR));

    // Refresh the global parameter block with a single 32-byte vector copy.
    // NOTE(review): run_gemm_qdq_asym() copies 64 bytes for the same struct -
    // confirm that 32 bytes cover every QDQKernelParams field read on this path.
    v32uint8 * pI = (v32uint8*)qdq_param_addr;
    v32uint8 * restrict pO = (v32uint8*)(&g_qdq_kernel_params);
    pO[0] = pI[0];

    qdq_tdm32_to_int8<int32_t, int32_t, int64_t, uint8_t>(
        conv_to_local_ptr(GEMM_A8W8_TDM1_ADDR), 16384, (16384 / 4), ADDRCFG_C2r1Cl1RC8,
        ifm_sum_addr, (64 * sizeof(int32)),
        byte_incr(args.s2mm_ch0_data, (128 * 64)), (64 * sizeof(int64)),
        args.mm2s_ch0_data, 4096
    );

    clr_srs_of ();
    clr_ups_of ();
}

/*
 * Standalone op for DEQUANT.
 *
 * Layer parameters (uint16 words): [0] element count, [1] local address of the
 * parameter block. The zero point is the uint16 at the start of that block and
 * the bfloat16 scale sits 4 bytes after it.
 */
void run_int16_bf16_dequant(KernelArgs& args)
{
    uint16_t const* layer_words = static_cast<uint16_t const*>(args.params_data);

    int8_t* src = static_cast<int8_t*>(args.s2mm_ch0_data);
    int8_t* dst = static_cast<int8_t*>(args.mm2s_ch0_data);

    // NOTE: the element count must be a multiple of 128
    int subv_elements = layer_words[0];

    int8_t* prm_block = static_cast<int8_t*>(conv_to_local_ptr(layer_words[1]));
    uint16_t* zero_point = reinterpret_cast<uint16_t*>(prm_block);
    bfloat16* scale = reinterpret_cast<bfloat16*>(byte_incr(prm_block, 4));

    // The trailing `false` fills the slot that run_all_qdq_modes() passes is_signed in.
    dequant_int16_to_bf16(src, dst, subv_elements, *zero_point, *scale, false);
}

/*
 * Standalone op for QUANT.
 *
 * Layer parameters (uint16 words): [0] element count, [1] local address of the
 * parameter block. The zero point is the uint16 at the start of that block and
 * the bfloat16 scale sits 4 bytes after it.
 */
void run_bf16_int16_quant(KernelArgs& args)
{
    uint16_t const* layer_words = static_cast<uint16_t const*>(args.params_data);

    int8_t* src = static_cast<int8_t*>(args.s2mm_ch0_data);
    int8_t* dst = static_cast<int8_t*>(args.mm2s_ch0_data);

    // NOTE: the element count must be a multiple of 128
    int subv_elements = layer_words[0];

    int8_t* prm_block = static_cast<int8_t*>(conv_to_local_ptr(layer_words[1]));
    uint16_t* zero_point = reinterpret_cast<uint16_t*>(prm_block);
    bfloat16* scale = reinterpret_cast<bfloat16*>(byte_incr(prm_block, 4));

    // The trailing `false` fills the slot that run_all_qdq_modes() passes is_signed in.
    quant_bf16_to_int16(src, dst, subv_elements, *zero_point,  *scale, false);
}

/*
 * Standalone op that runs neg_int16() from the s2mm channel-0 buffer into the
 * mm2s channel-0 buffer. Layer-parameter word [0] is the element count.
 */
void run_int16_negative(KernelArgs& args)
{
    int8_t* src = static_cast<int8_t*>(args.s2mm_ch0_data);
    int8_t* dst = static_cast<int8_t*>(args.mm2s_ch0_data);

    // NOTE: the element count must be a multiple of 128
    uint16_t const* layer_words = static_cast<uint16_t const*>(args.params_data);
    int subv_elements = layer_words[0];

    neg_int16(src, dst, subv_elements);
}

/*
 * Runs the dequantize stage followed by the quantize stage, routing the
 * intermediate data according to which stages are enabled.
 *
 *  - Dequant always reads input_addr. It writes to scratch_addr only when both
 *    stages are enabled and the data is not int16; otherwise it writes to output_addr.
 *  - Quant reads whatever dequant wrote when dequant is enabled, else input_addr,
 *    and always writes to output_addr.
 *
 * Both stage functions are always called; the enable flags are forwarded to them.
 */
void run_all_qdq_modes(
    int8_t* input_addr,
    int8_t* output_addr,
    int8_t* scratch_addr,
    int subv_elements,
    uint16_t dq_enable,
    uint16_t dq_zp,
    bfloat16 dq_sc,
    uint16_t q_enable,
    uint16_t q_zp,
    bfloat16 q_sc,
    bool is_int16,
    bool is_signed
) {
    // Destination of the dequant stage.
    int8_t* dequant_dst = output_addr;
    if (dq_enable && q_enable && !is_int16) {
        dequant_dst = scratch_addr;
    }

    // Source of the quant stage.
    int8_t* quant_src = input_addr;
    if (dq_enable) {
        quant_src = dequant_dst;
    }

    dequant_int16_to_bf16(input_addr, dequant_dst, subv_elements, dq_zp, dq_sc, is_signed, dq_enable, is_int16);
    quant_bf16_to_int16(quant_src, output_addr, subv_elements, q_zp, q_sc, is_signed, q_enable, is_int16);
}

/*
 * Layer-parameter record that run_combined_qdq_a8() overlays on
 * KernelArgs::params_data. The *_idx fields index the QDQ parameter buffer
 * (s2mm channel 1) in units of int32 words.
 */
struct QDQLayerParams {
    uint16_t subv_elems;           // element count forwarded as subv_elements
    uint16_t scratch_buffer_addr;  // local address converted with conv_to_local_ptr()
    uint16_t is_int16;             // forwarded as the is_int16 flag
    uint16_t dq_zp_elem_idx;       // word index of the dequant zero point
    uint16_t dq_sc_elem_idx;       // word index of the dequant scale
    uint16_t q_zp_elem_idx;        // word index of the quant zero point
    uint16_t q_sc_elem_idx;        // word index of the quant scale
    uint16_t dq_enable_idx;        // word index of the dequant enable flag
    uint16_t q_enable_idx;         // word index of the quant enable flag
    uint16_t is_signed;            // forwarded as the is_signed flag
};

// The struct is read directly from raw parameter memory, so its layout must
// stay exactly ten packed 16-bit words.
static_assert( sizeof( QDQLayerParams ) == 10 * sizeof( uint16_t ) );

/*
 * Combined dequant/quant op driven by a QDQLayerParams record.
 *
 * The record is read from args.params_data; the zero points, scales and enable
 * flags are fetched from the int32-word parameter buffer on s2mm channel 1 at
 * the indices the record provides, then everything is handed to run_all_qdq_modes().
 */
void run_combined_qdq_a8(KernelArgs& args)
{
    QDQLayerParams* layer = static_cast<QDQLayerParams*>(args.params_data);

    int8_t* src = static_cast<int8_t*>(args.s2mm_ch0_data);
    int8_t* dst = static_cast<int8_t*>(args.mm2s_ch0_data);
    int8_t* scratch = static_cast<int8_t*>(conv_to_local_ptr(layer->scratch_buffer_addr));
    int32_t* prm_words = static_cast<int32_t*>(args.s2mm_ch1_data);

    // Each value occupies the leading bytes of its own 32-bit parameter word.
    uint16_t dq_zp     = *reinterpret_cast<uint16_t*>(prm_words + layer->dq_zp_elem_idx);
    bfloat16 dq_sc     = *reinterpret_cast<bfloat16*>(prm_words + layer->dq_sc_elem_idx);
    uint16_t q_zp      = *reinterpret_cast<uint16_t*>(prm_words + layer->q_zp_elem_idx);
    bfloat16 q_sc      = *reinterpret_cast<bfloat16*>(prm_words + layer->q_sc_elem_idx);
    uint16_t dq_enable = *reinterpret_cast<uint16_t*>(prm_words + layer->dq_enable_idx);
    uint16_t q_enable  = *reinterpret_cast<uint16_t*>(prm_words + layer->q_enable_idx);

    bool is_signed = static_cast<bool>(layer->is_signed);

    run_all_qdq_modes(src, dst, scratch, layer->subv_elems, dq_enable, dq_zp, dq_sc, q_enable, q_zp, q_sc, layer->is_int16, is_signed);
}

/*
 * Combined dequant/quant op working in place on the output buffer.
 *
 * Layer parameters (uint16 words):
 *   [0] element count
 *   [1] dequant zero-point index   [2] dequant scale index
 *   [3] quant zero-point index     [4] quant scale index
 *   [5] dequant enable index       [6] quant enable index
 *   [7] signedness flag
 * The indices address the parameter buffer on s2mm channel 1 in int32 words.
 *
 * Dequant goes from the input buffer to the output buffer. Quant writes the
 * output buffer and reads it too when dq_enable == 1, otherwise it reads the
 * input buffer.
 */
void run_combined_qdq(KernelArgs& args)
{
    uint16_t const* layer_words = static_cast<uint16_t const*>(args.params_data);
    // NOTE: the element count must be a multiple of 64
    int subv_elements = layer_words[0];
    int is_signed = layer_words[7];

    int8_t* src = static_cast<int8_t*>(args.s2mm_ch0_data);
    int8_t* dst = static_cast<int8_t*>(args.mm2s_ch0_data);
    int32_t* prm_words = static_cast<int32_t*>(args.s2mm_ch1_data);

    // Each value occupies the leading bytes of its own 32-bit parameter word.
    uint16_t dq_zp     = *reinterpret_cast<uint16_t*>(prm_words + layer_words[1]);
    bfloat16 dq_sc     = *reinterpret_cast<bfloat16*>(prm_words + layer_words[2]);
    uint16_t q_zp      = *reinterpret_cast<uint16_t*>(prm_words + layer_words[3]);
    bfloat16 q_sc      = *reinterpret_cast<bfloat16*>(prm_words + layer_words[4]);
    uint16_t dq_enable = *reinterpret_cast<uint16_t*>(prm_words + layer_words[5]);
    uint16_t q_enable  = *reinterpret_cast<uint16_t*>(prm_words + layer_words[6]);

    // NOTE(review): this compares against exactly 1, whereas run_all_qdq_modes()
    // treats any non-zero dq_enable as enabled - confirm which is intended.
    int8_t* quant_src = src;
    if (dq_enable == 1) {
        quant_src = dst;
    }

    dequant_int16_to_bf16(src, dst, subv_elements, dq_zp, dq_sc, is_signed, dq_enable);
    quant_bf16_to_int16(quant_src, dst, subv_elements, q_zp, q_sc, is_signed, q_enable);
}


#endif
