#include "./gemm/stx_gemm_int16xint16_w4_tdm.cc"
#include "./nonlinear/softmax_bf16x16/softmax_bf16x16_kernel.c"
#include "./qdq/wrapper_qdq.cc"
#include "./conv/direct_conv_int16x8_generic/direct_conv_int16x8_generic_kernel.c"
#include "./nonlinear/glbsum.h"


namespace MHA_3p0_qdq
{
    // Per-layer parameter record. The run_* wrappers in this file obtain it by
    // casting KernelArgs::params_data. The layout is packed and 4-byte aligned,
    // so the field order and widths must not be changed.
    struct alignas(4) __attribute__ ((__packed__)) LayerParams
    {
        uint8_t ucast_is_s2mm1;  // non-zero: K/V buffer is s2mm_ch1_data and Q is s2mm_ch0_data; zero: the reverse (formerly mha_mode)
        uint8_t multi_core;      // forwarded to softmax_bf16x16 as the flag (multi_core > 0)
        uint8_t col_id;          // core column; used as col_id % 4 for softmax / global reduction, col_id >= 4 selects the second Q sub-volume
        uint8_t row_id;          // core row; used as row_id + 2 for softmax / global reduction
        uint16_t Msubv;          // sub-volume sizes, in elements: Q is Msubv x Ksubv,
        uint16_t Ksubv;          //   K is Ksubv x Nsubv,
        uint16_t Nsubv;          //   V is Nsubv x Lsubv
        uint16_t Lsubv;
        uint16_t Sin_kv;         // not referenced in this file
        uint16_t tdm1_addr;      // the fields from here to scratch_buf are 16-bit addresses that the
        uint16_t tdm2_addr;      //   wrappers turn into pointers with conv_to_local_ptr()
        uint16_t qdq_prm;        // base of the QDQ parameter block (indexed with the *_offset constants below)
        uint16_t act1_sum;       // activation-sum buffer of the GEMM left operand (Q or softmax output)
        uint16_t act2_sum;       // activation-sum buffer used by the K / V pre-processing
        uint16_t bufC0_K;        // C0 buffer written by the K pre-processing, read by the QKt GEMM
        uint16_t bufC0_V;        // C0 buffer written by the V pre-processing, read by the SMxV GEMM
        uint16_t scratch_buf;    // scratch-pad for the sum kernel and the global reduction
        KernelParams kernel_params;  // forwarded (by reference) to direct_conv_int16x8_generic
    };

    // Byte offsets of the individual QDQ node records inside the qdq_prm block.
    // They are compile-time constants, so they are declared constexpr rather
    // than as mutable globals.
    constexpr int MUL_DQ_offset   =   0;  // not referenced in this file
    constexpr int MUL_Q_offset    =  64;  // not referenced in this file
    constexpr int QKt_offset      = 128;  // record used for the Q x Kt GEMM and for bufC0_K
    constexpr int SMxV_offset     = 192;  // record used for the softmax x V GEMM and for bufC0_V
    constexpr int DQ_offset       = 256;  // dequantization zero-point (+0) and bf16 scale (+4)
    constexpr int Q_offset        = 320;  // quantization zero-point (+0) and bf16 scale (+4)
    constexpr int None_offset     =  -1;  // placeholder for an offset that is not consumed
}
// Pre-processes one activation sub-volume ahead of the activation-to-activation GEMM:
//   1. runs direct_conv_int16x8_generic in OP_SUM mode on matAct, with act2_sum as its sum buffer,
//   2. optionally runs global_reduce_sum_int32xint32 on act2_sum (in place),
//   3. copies one v64uint8 (64 bytes) of QDQ parameters from qdq_prm + QdqPrm_offset into
//      g_qdq_kernel_params and calls run_sum_c0 to obtain bufC0 from act2_sum,
//   4. optionally block-transposes matAct in place.
// Side effect: sum_params is taken by reference and its Ci_g, step_Ci and X_g fields are overwritten.
void perform_activation_preprocess
(
    int8_t* matAct,                // Input activation (read as int16_t); the transposed result is written back here, if requested
    int8_t* qdq_prm,               // Base address of the QDQ parameter block
    int32_t* act2_sum,             // Activation-sum buffer: passed to the OP_SUM pass, reduced in place, then read by run_sum_c0
    int8_t* bufC0,                 // Destination buffer handed to run_sum_c0
    int8_t* scratch_buf,           // Scratch-pad passed to the sum kernel and to the global reduction
    bool transpose_act,            // flag to indicate whether to perform block transpose on activation subvolume
    bool spatial_reduce_common,    // flag to indicate whether the common dimension is spatially split
    int QdqPrm_offset,             // Qdq parameter byte-offset
    int num_elems_common_axis,     // common dimension , dimension reduced, K_SUBV in K, N_SUBV in V
    int num_elems_noncommon_axis,  // Non-common dimension, N_SUBV in K, L_SUBV in V
    int aie_col_id,                // Core column id; passed to the global reduction as aie_col_id % 4
    int aie_row_id,                // Core row id; passed to the global reduction as aie_row_id + 2
    KernelParams& sum_params       // Kernel parameters for the sum kernel; modified by this function (see above)
)
{
    const bool acc_init = true;
    const bool last_tdm_iter = true;
    using Ta = int16_t;

    // NOTE(review): ctrl is a local copy. The assignments below are never written back to
    // sum_params and ctrl is not passed to any call, so as written they have no effect.
    // Confirm whether `sum_params.ctrl = ctrl;` was intended.
    KernelControl ctrl = sum_params.ctrl;
    ctrl.zero_init = 1;
    ctrl.sign_A = 0;
    ctrl.sign_W = 0;
    ctrl.sign_O = 0;

    sum_params.Ci_g = num_elems_common_axis / 8;    // IFM: H C W C8;      X, Ci, Co have granularity of 8
    sum_params.step_Ci = num_elems_noncommon_axis * 8 * sizeof(Ta);
    sum_params.X_g = num_elems_noncommon_axis / 8;
    direct_conv_int16x8_generic((int16_t*)matAct, scratch_buf, scratch_buf, act2_sum, act2_sum, sum_params, acc_init, last_tdm_iter, OP_SUM);

    ////////////////////////////////////////////////////////////////////////////////////////////////////
    // Perform Spatial Reduction if needed, using scratch_buf as scratch-pad. (In-place)
    ////////////////////////////////////////////////////////////////////////////////////////////////////
    if(spatial_reduce_common)
    {
        // The globally reduced result is written back to act2_sum; scratch_buf is the scratch-pad.
        // num_elem(act2_sum) == num_elems_noncommon_axis
        // (the count argument is num_elems_noncommon_axis / 16 - presumably a number of 16-lane int32 vectors; TODO confirm)
        global_reduce_sum_int32xint32((int8_t*)act2_sum, (int8_t*)act2_sum, scratch_buf, (num_elems_noncommon_axis / 16), aie_col_id%4, aie_row_id+2);
    }

    ////////////////////////////////////////////////////////////////////////////////////////////////////
    // Calculate C0 from actsum
    ////////////////////////////////////////////////////////////////////////////////////////////////////
    {
        // Copy one 64-byte vector of QDQ parameters into the global g_qdq_kernel_params.
        // The fence keeps this store ordered before the run_sum_c0 call below.
        v64uint8* pI = (v64uint8*)byte_incr(qdq_prm, QdqPrm_offset);
        v64uint8* restrict pO = (v64uint8*)(&g_qdq_kernel_params);
        pO[0] = pI[0];
        chess_memory_fence();
    }
    run_sum_c0<int32, int64>(act2_sum, nullptr, bufC0, num_elems_noncommon_axis);

    ////////////////////////////////////////////////////////////////////////////////////////////////////
    // Perform Activation block transpose, if needed. (In-place)
    ////////////////////////////////////////////////////////////////////////////////////////////////////
    // Read and write pointers both start at matAct: the transpose is done in place.
    v32int16 chess_storage(DM_bankA)* ptrO = (v32int16 chess_storage(DM_bankA)*) matAct;
    v32int16 chess_storage(DM_bankA)* ptrI = (v32int16 chess_storage(DM_bankA)*) matAct;
    if(transpose_act)
    {
        // Each iteration consumes two v32int16 vectors (64 int16 elements), hence the "/ 64" trip count.
        // NOTE(review): chess_loop_range(4, ) presumably promises at least 4 iterations,
        // i.e. at least 256 elements in the sub-volume - confirm against the callers.
        for (int ind = 0; ind < num_elems_common_axis * num_elems_noncommon_axis / 64; ind++)
        chess_no_hw_loop
        chess_prepare_for_pipelining
        chess_loop_range(4, )
        {
            v32int16 w0 = *ptrI++;
            v32int16 w1 = *ptrI++;

            // Two shuffles of the same input pair, with modes T16_8x8_lo and T16_8x8_lo + 1,
            // are written back over the two vectors just read.
            int w_mode = T16_8x8_lo;
            *ptrO++ = shuffle(w0, w1, w_mode);
            *ptrO++ = shuffle(w0, w1, w_mode + 1);
        }
    }

}

// Kernel entry point: pre-process the K activation sub-volume.
// Resolves the K input buffer and the local work buffers from the layer
// parameters, then calls perform_activation_preprocess() with the block
// transpose enabled and the spatial reduction disabled. The C0 result goes
// to the bufC0_K buffer.
void run_act_K_preprocess(KernelArgs& args)
{
    auto* lp = static_cast<MHA_3p0_qdq::LayerParams*>(args.params_data);

    // K is assumed to arrive as uni-cast data; ucast_is_s2mm1 tells which s2mm channel carries it.
    int8_t* k_act;
    if (lp->ucast_is_s2mm1)
        k_act = static_cast<int8_t*>(args.s2mm_ch1_data);
    else
        k_act = static_cast<int8_t*>(args.s2mm_ch0_data);

    // Work buffers, addressed through the 16-bit local addresses stored in the layer parameters.
    auto* qdq_base  = static_cast<int8_t*> (conv_to_local_ptr(lp->qdq_prm));
    auto* c0_out    = static_cast<int8_t*> (conv_to_local_ptr(lp->bufC0_K));
    auto* scratch   = static_cast<int8_t*> (conv_to_local_ptr(lp->scratch_buf));
    auto* k_sum     = static_cast<int32_t*>(conv_to_local_ptr(lp->act2_sum));

    // The dataflow is assumed to deliver K non-transposed, so it is transposed here.
    const bool do_transpose = true;
    // The inner dimension of K (the axis shared with Q) is assumed NOT to be spatially split.
    const bool reduce_across_cores = false;

    perform_activation_preprocess(
        k_act,
        qdq_base,
        k_sum,
        c0_out,
        scratch,
        do_transpose,
        reduce_across_cores,
        MHA_3p0_qdq::QKt_offset,   // QDQ record of the Q x Kt node (3rd QDQ node), used for bufC0_K
        lp->Ksubv,                 // common (reduced) dimension, in elements
        lp->Nsubv,                 // non-common dimension, in elements
        lp->col_id,
        lp->row_id,
        lp->kernel_params);
}

// Kernel entry point: pre-process the V activation sub-volume.
// V is located right after the K sub-volume (Ksubv * Nsubv int16 elements)
// in the uni-cast input buffer. perform_activation_preprocess() is called
// without the block transpose and with the spatial reduction enabled. The C0
// result goes to the bufC0_V buffer.
void run_act_V_preprocess(KernelArgs& args)
{
    auto* lp = static_cast<MHA_3p0_qdq::LayerParams*>(args.params_data);

    // The uni-cast buffer holds K first; ucast_is_s2mm1 tells which s2mm channel carries it.
    int8_t* k_act;
    if (lp->ucast_is_s2mm1)
        k_act = static_cast<int8_t*>(args.s2mm_ch1_data);
    else
        k_act = static_cast<int8_t*>(args.s2mm_ch0_data);

    // Skip over the K sub-volume to reach V.
    int8_t* v_act = byte_incr(k_act, lp->Ksubv * lp->Nsubv * sizeof(int16_t));

    // Work buffers, addressed through the 16-bit local addresses stored in the layer parameters.
    auto* qdq_base  = static_cast<int8_t*> (conv_to_local_ptr(lp->qdq_prm));
    auto* c0_out    = static_cast<int8_t*> (conv_to_local_ptr(lp->bufC0_V));
    auto* scratch   = static_cast<int8_t*> (conv_to_local_ptr(lp->scratch_buf));
    auto* v_sum     = static_cast<int32_t*>(conv_to_local_ptr(lp->act2_sum));

    // The dataflow is assumed to deliver V non-transposed and V is used as is: no transpose.
    const bool do_transpose = false;
    // N (the axis shared with the softmax output) is assumed to be spatially split.
    const bool reduce_across_cores = true;

    perform_activation_preprocess(
        v_act,
        qdq_base,
        v_sum,
        c0_out,
        scratch,
        do_transpose,
        reduce_across_cores,
        MHA_3p0_qdq::SMxV_offset,  // QDQ record of the softmax x V node (4th QDQ node), used for bufC0_V
        lp->Nsubv,                 // common (reduced) dimension, in elements
        lp->Lsubv,                 // non-common dimension, in elements
        lp->col_id,
        lp->row_id,
        lp->kernel_params);
}

// Kernel entry point: softmax followed by quantization to 16-bit integers.
// Calls softmax_bf16x16() on the tdm1 / tdm2 buffers, then quant_bf16_to_int16()
// from tdm1 to tdm2 with the zero-point and scale stored in the Q_offset QDQ record.
void run_sfmx_i16_to_i16(KernelArgs& args)
{
    using ZeroPoint = uint16_t;

    auto* lp = static_cast<MHA_3p0_qdq::LayerParams*>(args.params_data);

    const int rows = lp->Msubv;
    const int cols = lp->Nsubv;

    auto* buf_tdm1 = static_cast<int8_t*>(conv_to_local_ptr(lp->tdm1_addr));
    auto* buf_tdm2 = static_cast<int8_t*>(conv_to_local_ptr(lp->tdm2_addr));
    auto* qdq_base = static_cast<int8_t*>(conv_to_local_ptr(lp->qdq_prm));

    // ---- Softmax: bf16 in, bf16 out. Per the original author's note it is in place,
    //      with both input and output at tdm1.
    const int  core_col      = (lp->col_id) % 4;
    const int  core_row      = (lp->row_id) + 2;
    const bool use_multicore = (lp->multi_core > 0);
    softmax_bf16x16(buf_tdm1, buf_tdm2, rows, cols, core_col, core_row, use_multicore);

    // ---- Quantization: bf16 at tdm1 -> i16 at tdm2.
    //      The zero-point sits at Q_offset + 0 and the bf16 scale at Q_offset + 4.
    auto* zero_point = reinterpret_cast<ZeroPoint*>(byte_incr(qdq_base, MHA_3p0_qdq::Q_offset));
    auto* scale      = reinterpret_cast<bfloat16*>(byte_incr(qdq_base, MHA_3p0_qdq::Q_offset+4));
    quant_bf16_to_int16(buf_tdm1, buf_tdm2, rows * cols, *zero_point, *scale, false);
}

// Activation-by-activation GEMM followed by the post-GEMM QDQ step. In order:
//   1. direct_conv_int16x8_generic in OP_SUM mode on matA, with actA_sum as its sum buffer,
//   2. gemm_int16xint16(matA, matB) with tdm1 as output buffer,
//   3. optional global_reduce_sum_int32xint32 of actA_sum and tdm1 (in place),
//   4. load of one v64uint8 (64 bytes) of QDQ parameters from qdq_prm + Qdq_node_offset_1
//      into g_qdq_kernel_params,
//   5. direct_conv_int16x8_generic in OP_QDQ mode on tdm1, then (if bf16_output)
//      dequant_int16_to_bf16 into act2act_gemm_postqdq_out.
// Side effect: sum_params is taken by reference; step_Ci and X_g are always overwritten,
// Co_g and shift_res are overwritten when the QDQ step runs.
// NOTE(review): sum_params.Ci_g is NOT set here (the assignment is commented out below), so
// the sum pass uses whatever Ci_g the caller's KernelParams holds - confirm this is intended.
void act2act_gemm_qdq
(
    int8_t* matA,                      // left GEMM operand, read as int16_t
    int8_t* matB,                      // right GEMM operand
    int8_t* actA_sum,                  // activation-sum buffer of matA, used as int32_t
    int8_t* bufC0_B,                   // C0 buffer for matB, read as int64_t by the QDQ step
    int8_t* qdq_prm,                   // base address of the QDQ parameter block
    int8_t* act2act_gemm_postqdq_out,  // final output buffer
    int8_t* tdm1,                      // GEMM output buffer; read as int32 by the QDQ step
    int8_t* tdm2,                      // intermediate int16 QDQ output, used only when bf16_output is set
    int8_t* scratch_buf,               // scratch-pad for the sum kernel and the global reduction
    int Qdq_node_offset_1,             // byte offset of the QDQ record for this GEMM (an int shift amount is read at +36)
    int Qdq_node_offset_2,             // byte offset of the dequant zero-point (+0) and bf16 scale (+4); only read when bf16_output is set
    bool spatial_reduct_flag,  // 1: common axis between matA and matB is spatially split, 0 otherwise
    bool bf16_output,                  // true: dequantize the int16 QDQ result to bf16
    int Asubv_rows,                    // rows of matA, in elements
    int ABsubv_common,                 // common dimension of matA and matB, in elements
    int Bsubv_cols,                    // columns of matB, in elements
    int aie_col_id,                    // core column id (used as aie_col_id % 4)
    int aie_row_id,                    // core row id (passed to the reduction as aie_row_id + 2)
    KernelParams& sum_params,          // kernel parameters for the sum / QDQ kernel; modified (see above)
    bool transposeA,                   // forwarded to gemm_int16xint16
    bool transposeB                    // forwarded to gemm_int16xint16
)
{
    const bool acc_init = true;
    const bool last_tdm_iter = true;

    using Ta = int16_t;
    using Ts = int32_t;

    ////////////////////////////////////////////////////////////////////////////////////////////////////
    // Perform Activation Sum on matA:
    ////////////////////////////////////////////////////////////////////////////////////////////////////
    //sum_params.Ci_g = ABsubv_common / 8;    // IFM: H C W C8;      X, Ci, Co have granularity of 8
    sum_params.step_Ci = Asubv_rows * 8 * sizeof(Ta);
    sum_params.X_g = Asubv_rows / 8;
    direct_conv_int16x8_generic((int16_t*)matA, scratch_buf, scratch_buf, (int32_t*)actA_sum, (int32_t*)actA_sum, \
                            sum_params, acc_init, last_tdm_iter, OP_SUM);

    ////////////////////////////////////////////////////////////////////////////////////////////////////
    // Perform Act to Act Gemm:
    ////////////////////////////////////////////////////////////////////////////////////////////////////
    // The shift amount is read directly from the QDQ record in memory (int at byte offset +36).
    int shiftamt_acc64_int32 = *(int*)(byte_incr(qdq_prm, Qdq_node_offset_1 + 36));

    // Dimensions are passed in units of 8 elements; the two byte sizes are those of the A and B sub-volumes.
    gemm_int16xint16(matA, matB, tdm1, tdm1, tdm1, Asubv_rows / 8, ABsubv_common / 8, Bsubv_cols / 8,
            Asubv_rows * ABsubv_common * sizeof(Ta), ABsubv_common * Bsubv_cols * sizeof(Ta),
            shiftamt_acc64_int32, transposeB, acc_init, last_tdm_iter, transposeA);


    ////////////////////////////////////////////////////////////////////////////////////////////////////
    // Perform Spatial Reduction if needed, using scratch_buf as scratch-pad. (In-place)
    ////////////////////////////////////////////////////////////////////////////////////////////////////
    if(spatial_reduct_flag)
    {
        // The globally reduced results are written back to actA_sum and tdm1; scratch_buf is the scratch-pad.
        // num_elem(actA_sum) == Asubv_rows
        global_reduce_sum_int32xint32((int8_t*)actA_sum, (int8_t*)actA_sum, scratch_buf, (Asubv_rows / 16),              aie_col_id%4, aie_row_id+2);
        global_reduce_sum_int32xint32((int8_t*)tdm1    , (int8_t*)tdm1    , scratch_buf, (Asubv_rows *  Bsubv_cols/ 16), aie_col_id%4, aie_row_id+2);
    }

    ////////////////////////////////////////////////////////////////////////////////////////////////////
    // Perform PostGemm Qdq, result at: act2act_gemm_postqdq_out
    //         If the common axis is spatially split, the QDQ is performed only on the core with
    //         (aie_col_id % 4) == 3 and aie_row_id == 0 (the "bottom-right" core per the original note)
    ////////////////////////////////////////////////////////////////////////////////////////////////////
    {
        // Copy one 64-byte vector of QDQ parameters into the global g_qdq_kernel_params.
        // The fence keeps this store ordered before the reads of g_qdq_kernel_params below.
        v64uint8* pI = (v64uint8*)byte_incr(qdq_prm, Qdq_node_offset_1);
        v64uint8* restrict pO = (v64uint8*)(&g_qdq_kernel_params);
        pO[0] = pI[0];
        chess_memory_fence();
    }

    if((spatial_reduct_flag && (aie_col_id%4) == 3 && aie_row_id == 0) || (!spatial_reduct_flag))
    {
        sum_params.Co_g = Bsubv_cols / 8;
        sum_params.shift_res = g_qdq_kernel_params.shift_Qout;

        // With bf16 output the int16 QDQ result is staged in tdm2 and dequantized afterwards;
        // otherwise it is written directly to the final output buffer.
        int16* postQdqOut = (!bf16_output)? (int16*)act2act_gemm_postqdq_out : (int16*)tdm2;
        direct_conv_int16x8_generic(nullptr, nullptr, nullptr, (int32*)tdm1, nullptr, (Ts*)actA_sum, (int64_t*)bufC0_B,
            g_qdq_kernel_params.c1, g_qdq_kernel_params.c2, postQdqOut, sum_params, OP_QDQ, 1, 1, 1);

        ////////////////////////////////////////////////////////////////////////////////////////////////////
        // DeQuantization:  QdqOutput (i16) ----> QdqOutput (bf16) (Optional): tdm2 -> act2act_gemm_postqdq_out
        ////////////////////////////////////////////////////////////////////////////////////////////////////
        if(bf16_output)
        {
            // NOTE(review): the zero-point is read as int16_t here (Ta), whereas run_sfmx_i16_to_i16
            // reads its zero-point as uint16_t - confirm the signedness is intended.
            Ta*       dq_zp = reinterpret_cast<      Ta*>(byte_incr(qdq_prm, Qdq_node_offset_2  ));
            bfloat16* dq_sc = reinterpret_cast<bfloat16*>(byte_incr(qdq_prm, Qdq_node_offset_2+4));
            dequant_int16_to_bf16((int8_t*)postQdqOut, act2act_gemm_postqdq_out, Asubv_rows * Bsubv_cols, *dq_zp, *dq_sc, false);
        }
    }
}
// Number of run_qkt_gemm_qdq() invocations so far; used for the trace print
// below and by the disabled debug dumps.
int callcnt = 0;

// Kernel entry point: Q x Kt GEMM followed by QDQ, with bf16 output.
// Resolves the Q and K buffers and the local work buffers from the layer
// parameters and calls act2act_gemm_qdq() with:
//   - A = Q (Msubv x Ksubv), B = K (Ksubv x Nsubv),
//   - no spatial reduction on the common axis,
//   - the final bf16 result written to the tdm1 buffer.
// layer_params->kernel_params is passed by reference, so act2act_gemm_qdq()
// modifies it in place.
void run_qkt_gemm_qdq(KernelArgs& args)
{
    MHA_3p0_qdq::LayerParams* layer_params = static_cast<MHA_3p0_qdq::LayerParams*>(args.params_data);
    int8_t* matQ    = (layer_params->ucast_is_s2mm1)? static_cast<int8_t*>(args.s2mm_ch0_data) : static_cast<int8_t*>(args.s2mm_ch1_data);  // assuming Q from b-cast data from s2mm_0
    int8_t* matKt   = (layer_params->ucast_is_s2mm1)? static_cast<int8_t*>(args.s2mm_ch1_data) : static_cast<int8_t*>(args.s2mm_ch0_data) ;  // assuming K from u-cast data from s2mm_1,
                                                    // assuming K is already transposed earlier by calling preprocess wrapper
    int8_t* matQKt  = static_cast<int8_t*>(conv_to_local_ptr(layer_params->tdm1_addr));  // final output; same address as tdm1
    int8_t* actQsum = static_cast<int8_t*>(conv_to_local_ptr(layer_params->act1_sum));
    int8_t* bufC0_K = static_cast<int8_t*>(conv_to_local_ptr(layer_params->bufC0_K));
    int8_t* qdq_prm = static_cast<int8_t*>(conv_to_local_ptr(layer_params->qdq_prm));

    int8_t * tdm1   = static_cast<int8_t*>(conv_to_local_ptr(layer_params->tdm1_addr));
    int8_t * tdm2   = static_cast<int8_t*>(conv_to_local_ptr(layer_params->tdm2_addr));
    int8_t * scratch_buf  = static_cast<int8_t*>(conv_to_local_ptr(layer_params->scratch_buf));
    // The former local copy of layer_params->kernel_params was never read (the call below
    // passes layer_params->kernel_params itself), so it has been removed.
    bool K_is_spatial_reduced = false;  // K as common dimension between Q tensor and K tensor
    bool QKt_bf16_output = true;

    // Columns 4 and above use the second Q sub-volume in the broadcast buffer.
    if(layer_params->col_id >= 4)
        matQ = byte_incr(matQ, layer_params->Msubv * layer_params->Ksubv * sizeof(int16_t));

#if 0
    // Debug: dump the Q sub-volume, 8 values per line.
    int block8_cnt = 0;
    if(layer_params->col_id == 0 && layer_params->row_id == 0 && (callcnt == 0 or callcnt == 8))
    {
        /*for(int r = 0; r < layer_params->Msubv; r++)
        {
            for(int c = 0; c < layer_params->Ksubv; c+=8)
            {
                printf("%3d ", matQ[r*layer_params->Ksubv + c + 0]);
                printf("%3d ", matQ[r*layer_params->Ksubv + c + 1]);
                printf("%3d ", matQ[r*layer_params->Ksubv + c + 2]);
                printf("%3d ", matQ[r*layer_params->Ksubv + c + 3]);
                printf("%3d ", matQ[r*layer_params->Ksubv + c + 4]);
                printf("%3d ", matQ[r*layer_params->Ksubv + c + 5]);
                printf("%3d ", matQ[r*layer_params->Ksubv + c + 6]);
                printf("%3d ", matQ[r*layer_params->Ksubv + c + 7]);
                printf("\n");
                block8_cnt++;
                if(block8_cnt % layer_params->Msubv == 0)
                    printf("----------------------------------------\n");
            }
        }*/

        int16_t* matQ_16 = (int16_t*)matQ;
        for(int b = 0; b < layer_params->Msubv * layer_params->Ksubv; b+=8)
        {
            printf("%3d ", matQ_16[b + 0]);
            printf("%3d ", matQ_16[b + 1]);
            printf("%3d ", matQ_16[b + 2]);
            printf("%3d ", matQ_16[b + 3]);
            printf("%3d ", matQ_16[b + 4]);
            printf("%3d ", matQ_16[b + 5]);
            printf("%3d ", matQ_16[b + 6]);
            printf("%3d ", matQ_16[b + 7]);
            printf("\n");
        }
        printf("----------------------------------------\n");
    }
#endif
#if 0
    // Debug: dump the K sub-volume, 8 values per line.
    int block8_cnt = 0;
    if(layer_params->col_id == 3 && layer_params->row_id == 3 && (callcnt == 0 or callcnt == 8))
    {
        int16_t* matK_16 = (int16_t*)matKt;
        for(int b = 0; b < layer_params->Ksubv * layer_params->Nsubv; b+=8)
        {
            printf("%3d ", matK_16[b + 0]);
            printf("%3d ", matK_16[b + 1]);
            printf("%3d ", matK_16[b + 2]);
            printf("%3d ", matK_16[b + 3]);
            printf("%3d ", matK_16[b + 4]);
            printf("%3d ", matK_16[b + 5]);
            printf("%3d ", matK_16[b + 6]);
            printf("%3d ", matK_16[b + 7]);
            printf("\n");
        }
        printf("----------------------------------------\n");
    }
#endif
#if 0
    // Debug: dump the V sub-volume (located right after K), 8 values per line.
    int block8_cnt = 0;
    if(layer_params->col_id == 2 && layer_params->row_id == 0 && (callcnt == 0 or callcnt == 8))
    {
        int8_t* matV = byte_incr(matKt, layer_params->Ksubv * layer_params->Nsubv * sizeof(int16_t));
        int16_t* matV_16 = (int16_t*)matV;
        for(int b = 0; b < layer_params->Nsubv * layer_params->Lsubv; b+=8)
        {
            printf("%3d ", matV_16[b + 0]);
            printf("%3d ", matV_16[b + 1]);
            printf("%3d ", matV_16[b + 2]);
            printf("%3d ", matV_16[b + 3]);
            printf("%3d ", matV_16[b + 4]);
            printf("%3d ", matV_16[b + 5]);
            printf("%3d ", matV_16[b + 6]);
            printf("%3d ", matV_16[b + 7]);
            printf("\n");
        }
        printf("----------------------------------------\n");
    }
#endif
    act2act_gemm_qdq
    (
        matQ,
        matKt,
        actQsum,
        bufC0_K,
        qdq_prm,
        matQKt,                   //int8_t* act2act_gemm_postqdq_out,
        tdm1,                     //int8_t* tdm1,
        tdm2,                     //int8_t* tdm2,
        scratch_buf,              //int8_t* scratch_buf,
        MHA_3p0_qdq::QKt_offset,  //int Qdq_node_offset_1, 128 being 3rd qdq node
        MHA_3p0_qdq::DQ_offset,   //int Qdq_node_offset_2, dequant zero-point / scale
        K_is_spatial_reduced,     // 1: common axis between matA and matB is spatially split, 0 otherwise
        QKt_bf16_output,          // QKt_qdq_int16 -> QKt_qdq_bf16 conversion enable / disable
        layer_params->Msubv,      //int Asubv_rows,
        layer_params->Ksubv,      //int ABsubv_common,
        layer_params->Nsubv,      //int Bsubv_cols,
        layer_params->col_id,     //int aie_col_id,
        layer_params->row_id,     //int aie_row_id,
        layer_params->kernel_params,
        false,                    //bool transposeA
        true                      //bool transposeB
    );
    callcnt++;
    // NOTE(review): this trace runs on every call on core (col 0, row 0); it looks like a
    // debugging aid - confirm whether it should stay in production builds.
    if(layer_params->col_id == 0 && layer_params->row_id == 0)
        printf("subvol_count : %d\n", callcnt);
}

// Kernel entry point: softmax-output x V GEMM followed by QDQ, with int16 output.
// Resolves the operand, output and work buffers from the layer parameters and
// calls act2act_gemm_qdq() with:
//   - A = softmax output at tdm2 (Msubv x Nsubv), B = V (Nsubv x Lsubv),
//   - spatial reduction enabled on the common axis N,
//   - the int16 result written to the mm2s_ch0 output buffer.
// layer_params->kernel_params is passed by reference, so act2act_gemm_qdq()
// modifies it in place.
void run_smxv_gemm_qdq(KernelArgs& args)
{
    auto* layer_params = static_cast<MHA_3p0_qdq::LayerParams*>(args.params_data);

    // Left operand: the softmax output, which lives at tdm2.
    auto* sm_out = static_cast<int8_t*>(conv_to_local_ptr(layer_params->tdm2_addr));

    // Right operand: V, stored right after the K sub-volume in the uni-cast input buffer.
    int8_t* k_base;
    if (layer_params->ucast_is_s2mm1)
        k_base = (int8_t*)args.s2mm_ch1_data;
    else
        k_base = (int8_t*)args.s2mm_ch0_data;
    int8_t* matV = byte_incr(k_base, layer_params->Ksubv * layer_params->Nsubv * sizeof(int16_t));

    // Destination: assumed to be the uni-cast output on mm2s_0.
    auto* out_buf = static_cast<int8_t*>(args.mm2s_ch0_data);

    auto* sm_sum   = static_cast<int8_t*>(conv_to_local_ptr(layer_params->act1_sum));
    auto* c0_v     = static_cast<int8_t*>(conv_to_local_ptr(layer_params->bufC0_V));
    auto* qdq_base = static_cast<int8_t*>(conv_to_local_ptr(layer_params->qdq_prm));

    auto* buf_tdm1 = static_cast<int8_t*>(conv_to_local_ptr(layer_params->tdm1_addr));
    auto* buf_tdm2 = static_cast<int8_t*>(conv_to_local_ptr(layer_params->tdm2_addr));
    auto* scratch  = static_cast<int8_t*>(conv_to_local_ptr(layer_params->scratch_buf));

    const bool reduce_over_N = true;   // N is the common dimension between the QKt tensor and the V tensor
    const bool want_bf16_out = false;  // keep the QDQ result as int16
#if 0
    // Debug: dump the V sub-volume, 8 values per line.
    int block8_cnt = 0;
    if(layer_params->col_id == 0 && layer_params->row_id == 0 && callcnt == 0)
    {
        int16_t* matV_16 = (int16_t*)matV;
        for(int b = 0; b < layer_params->Nsubv * layer_params->Lsubv; b+=8)
        {
            printf("%3d ", matV_16[b + 0]);
            printf("%3d ", matV_16[b + 1]);
            printf("%3d ", matV_16[b + 2]);
            printf("%3d ", matV_16[b + 3]);
            printf("%3d ", matV_16[b + 4]);
            printf("%3d ", matV_16[b + 5]);
            printf("%3d ", matV_16[b + 6]);
            printf("%3d ", matV_16[b + 7]);
            printf("\n");
        }
    }
#endif
    act2act_gemm_qdq(
        sm_out,                      // matA
        matV,                        // matB
        sm_sum,                      // actA_sum
        c0_v,                        // bufC0_B
        qdq_base,                    // qdq_prm
        out_buf,                     // act2act_gemm_postqdq_out
        buf_tdm1,                    // tdm1
        buf_tdm2,                    // tdm2
        scratch,                     // scratch_buf
        MHA_3p0_qdq::SMxV_offset,    // Qdq_node_offset_1: 4th qdq node
        MHA_3p0_qdq::None_offset,    // Qdq_node_offset_2: not consumed (no bf16 output)
        reduce_over_N,               // spatial_reduct_flag
        want_bf16_out,               // bf16_output
        layer_params->Msubv,         // Asubv_rows
        layer_params->Nsubv,         // ABsubv_common
        layer_params->Lsubv,         // Bsubv_cols
        layer_params->col_id,        // aie_col_id
        layer_params->row_id,        // aie_row_id
        layer_params->kernel_params, // sum_params
        false,                       // transposeA
        false);                      // transposeB
    //callcnt++;
}

#if 0
/*
    callcnt++;
    uint32_t* qkt_out_u32 = (uint32_t*)tdm1;
    if((aie_col_id==0) && (aie_row_id==0) && (callcnt == 1))
    {
        for(int i = 0; i < 8; i++)
            printf("%3d \n", qkt_out_u32[i]);
        printf("shiftamt:%d\n", shiftamt_acc64_int32);
    }
    */

    /*if(transposeB == true)
        {
            printf("c1   = %d", g_qdq_kernel_params.c1);
            printf("c2   = %d", g_qdq_kernel_params.c2);
            printf("Qout = %d", g_qdq_kernel_params.shift_Qout);
        }*/
#endif
