#include "./gemm/stx_gemm_int16xint16_w4_tdm.cc"

#include "./qdq/wrapper_qdq.cc"
#include "conv/direct_conv_int16x8_generic/direct_conv_int16x8_generic_kernel.c"


// Kernel entry point for the "SM x V" stage.
//
// Steps, in order:
//   1. Decode the LayerParams blob in args.params_data and resolve the 16-bit
//      buffer handles it carries into local pointers.
//   2. Optionally widen matV from int8 to int16 in place.
//   3. Run the sum kernel over matSM (result in act1_sum) and over matV
//      (result in act2_sum).
//   4. Run the int16 x int16 GEMM on matSM and matV using the tdm1/tdm2 buffers.
//   5. Only when last_tdm_iter is non-zero: load the QDQ parameters, compute C0
//      from act2_sum and issue the OP_QDQ call that is handed matO.
//
// args: uses params_data (LayerParams blob), s2mm_ch0_data / s2mm_ch1_data
//       (the two input operands) and mm2s_ch0_data (passed on as matO).
void run_vxsmt(KernelArgs& args)
{
    // Layout of the blob found at args.params_data. The uint16_t *_addr / buffer
    // fields are not raw pointers: each one is passed through conv_to_local_ptr()
    // below before being dereferenced.
    struct alignas(4) LayerParams{
        uint8_t mha_mode;           // 0: SM on s2mm ch1, V on s2mm ch0; otherwise swapped
        uint8_t transpose_mode;     // bit 0 selects the SM layout, bit 1 the V layout
        uint8_t col_id;             // not read in this function
        uint8_t row_id;             // not read in this function
        uint8_t first_tdm_iter;     // forwarded as a flag to the sum and GEMM kernels
        uint8_t last_tdm_iter;      // non-zero enables the final QDQ step
        uint16_t Msubv;
        uint16_t Ksubv;
        uint16_t Nsubv;
        uint16_t unpack_V;          // non-zero: widen matV int8 -> int16 before use
        uint16_t tdm1_addr;
        uint16_t tdm2_addr;
        uint16_t qdq_prm;
        uint16_t act1_sum;
        uint16_t act2_sum;
        uint16_t bufC0;
        uint16_t scratch_buf;
        KernelParams kernel_params; // template copied into sum_params below
    };

    using Ta = uint16_t;    // element type used only for the sizeof(Ta) byte-size computations
    using Ts = int32_t;     // NOTE(review): not referenced anywhere in this function
    set_rnd_wrapper();      // presumably configures rounding for the kernels below - confirm


    LayerParams* layer_params = static_cast<LayerParams*>(args.params_data);
    // Working copy: it is modified several times below, the blob itself is left untouched.
    KernelParams sum_params = layer_params->kernel_params;

    int Msubv = layer_params->Msubv;
    int Ksubv = layer_params->Ksubv;
    int Nsubv = layer_params->Nsubv;
    int first_tdm_iter = layer_params->first_tdm_iter;
    int last_tdm_iter = layer_params->last_tdm_iter;
    int mode = layer_params->mha_mode;
    int unpack_matV = layer_params->unpack_V;

    int8_t *matSM, *matV, *matO, *tdm1, *tdm2, *bufC0, *scratch_buf, *qdq_prm;
    int32_t *act1_sum, *act2_sum;

    // Resolve the 16-bit handles from the parameter blob into usable pointers.
    qdq_prm = static_cast<int8_t*>(conv_to_local_ptr(layer_params->qdq_prm));
    tdm1   = static_cast<int8_t*>(conv_to_local_ptr(layer_params->tdm1_addr));
    tdm2   = static_cast<int8_t*>(conv_to_local_ptr(layer_params->tdm2_addr));
    act1_sum= (int32_t*)(conv_to_local_ptr(layer_params->act1_sum));
    act2_sum= (int32_t*)(conv_to_local_ptr(layer_params->act2_sum));
    bufC0   = ( int8_t*)(conv_to_local_ptr(layer_params->bufC0));
    scratch_buf = ( int8_t*)(conv_to_local_ptr(layer_params->scratch_buf));

    // mha_mode decides which input channel carries which operand.
    matSM   = (mode == 0) ? static_cast<int8_t*>(args.s2mm_ch1_data) : static_cast<int8_t*>(args.s2mm_ch0_data);
    matV    = (mode == 0) ? static_cast<int8_t*>(args.s2mm_ch0_data) : static_cast<int8_t*>(args.s2mm_ch1_data);
    matO = static_cast<int8_t*>(args.mm2s_ch0_data);

    // SM sum for SM x V QDQ
    // NOTE(review): this initial value is always overwritten by the if/else below
    // before sum_mode is first read.
    int sum_mode = OP_SUM;

    // KernelParams sum_params = get_ifm_sum_params();
    // Force zero_init on for every kernel call that uses sum_params. `ctrl` is
    // also kept as the reference control word: it is restored after the V sum
    // and supplies sign_A / sign_W to the unpack step and the GEMM.
    KernelControl ctrl = sum_params.ctrl;
    ctrl.zero_init = 1;
    sum_params.ctrl = ctrl;

    // Optional in-place widening of matV: each step loads 32 int8 values at
    // element offset i and stores them as 32 int16 values at the same element
    // offset of the same buffer, with the unpack signedness taken from
    // ctrl.sign_W. The loop walks from the end of the buffer towards its start,
    // presumably so that the wider stores never overwrite int8 input that has
    // not been read yet - confirm.
    // NOTE(review): the first offset is Ksubv*Nsubv - 1 rather than
    // Ksubv*Nsubv - 32, so taken literally the first 32-element access starts at
    // the last element. This looks like it relies on the target's vector
    // load/store address alignment behaviour (or on buffer padding) - confirm
    // before changing the bounds.
    if ( unpack_matV ) {
        [[ using chess: min_loop_count( 1 ), no_hw_loop ]]
        for ( int i = Ksubv * Nsubv - 1; i >= 0; i -= 32 ) {
            aie::store_v( (int16_t * )matV + i, aie::load_v<32>((( int8_t * ) matV ) + i ).unpack_sign( ctrl.sign_W ));
        }
    }
    
    // compute SM Sum
    // transpose_mode bit 0 picks the sum variant and swaps which of Msubv / Ksubv
    // drives X_g and step_Ci versus Ci_g.
    if (layer_params->transpose_mode & 1) {
        sum_mode = OP_SUM_T;
        sum_params.step_Ci = Ksubv *   8    * sizeof(Ta);
        sum_params.X_g = Ksubv / 8;
        sum_params.Ci_g = Msubv / 8;    // IFM: H C W C8;      X, Ci, Co have granularity of 8
    } else {
        sum_mode = OP_SUM;
        sum_params.step_Ci = Msubv *   8    * sizeof(Ta);
        sum_params.X_g = Msubv / 8;
        sum_params.Ci_g = Ksubv / 8;    // IFM: H C W C8;      X, Ci, Co have granularity of 8
    }
    // matSM is the input; scratch_buf and act1_sum are each passed for two
    // parameters of the kernel (their exact roles are defined by
    // direct_conv_int16x8_generic, which is not visible here).
    direct_conv_int16x8_generic( (int16_t*)matSM, scratch_buf, scratch_buf, act1_sum, act1_sum, nullptr, sum_params, first_tdm_iter, last_tdm_iter, sum_mode);

    // compute V sum
    // transpose_mode bit 1 selects the variant. Note the polarity is the opposite
    // of the SM case: bit set -> OP_SUM, bit clear -> OP_SUM_T.
    if (layer_params -> transpose_mode & 2){
        sum_mode = OP_SUM;
        sum_params.step_Ci = Nsubv *   8    * sizeof(Ta);
        sum_params.X_g = Nsubv / 8;
        sum_params.Ci_g = Ksubv / 8;
    }
    else{
        sum_mode = OP_SUM_T;
        sum_params.step_Ci = Ksubv *   8    * sizeof(Ta);
        sum_params.X_g = Ksubv / 8;
        sum_params.Ci_g = Nsubv / 8;    // IFM: H C W C8;      X, Ci, Co have granularity of 8
    }
    // For this call only, matV is fed through the kernel's first (A) input, so
    // the A-side signedness is temporarily set to V's signedness (sign_W).
    KernelControl ctrl2 = sum_params.ctrl;
    ctrl2.sign_A = ctrl.sign_W;
    sum_params.ctrl = ctrl2;
    direct_conv_int16x8_generic( (int16_t*)matV, scratch_buf, scratch_buf, act2_sum, act2_sum, nullptr, sum_params, first_tdm_iter, last_tdm_iter, sum_mode);
    // Restore the original control word (with zero_init = 1) for later use.
    sum_params.ctrl = ctrl;

    // SMxV GeMM
    // The shift amount is read as an int from byte offset 36 of the QDQ
    // parameter block.
    int smv_shiftamt_acc64_int32 = *(int*)(byte_incr(qdq_prm,36));

    const bool transposeV = (layer_params->transpose_mode & 2);    
    const bool transposeSM = (layer_params->transpose_mode & 1);
    // tdm1 is passed twice (third and fifth argument); the later OP_QDQ call
    // reads tdm1 as int32 data. Sub-volume dimensions are passed in units of 8
    // and the operand sizes in bytes.
    gemm_int16xint16(matSM, matV, tdm1, tdm2, tdm1, Msubv / 8, Ksubv / 8, Nsubv / 8,
                      Msubv * Ksubv * sizeof(Ta), Ksubv * Nsubv * sizeof(Ta),
                     smv_shiftamt_acc64_int32, transposeV, first_tdm_iter, last_tdm_iter, transposeSM, ctrl.sign_A, ctrl.sign_W);

    // NOTE(review): col_idx and row_idx are computed but never used in the rest
    // of this function.
    int const col_idx = (get_coreid() >> 16);
    int const row_idx = (get_coreid() & 0xF);

    // The QDQ output step runs only when last_tdm_iter is non-zero.
    if (last_tdm_iter) {    
        ////////////////////////////////////////////////////////////////////////////////////////////////////
        // Calculate C0 from i2sum
        ////////////////////////////////////////////////////////////////////////////////////////////////////
        {
            // Copy the first v64uint8 vector of the QDQ parameter block into the
            // global g_qdq_kernel_params. The fence precedes the reads of
            // g_qdq_kernel_params further down (shift_Qout, c1, c2); presumably
            // it keeps the compiler from reordering those reads ahead of this
            // store - confirm.
            v64uint8 * pI = (v64uint8*)(qdq_prm);
            v64uint8 * restrict pO = (v64uint8*)(&g_qdq_kernel_params);
            pO[0] = pI[0];
            chess_memory_fence();
        }
        // Derive C0 into bufC0 from the V sums (act2_sum) over Nsubv entries.
        run_sum_c0<int32, int64>( act2_sum, nullptr, bufC0, Nsubv);
        // Re-program sum_params for the QDQ call: M-based X_g / step_Ci, K-based
        // Ci_g, N-based Co_g, and the output shift taken from the QDQ parameters
        // loaded just above.
        sum_params.step_Ci = Msubv *   8    * sizeof(Ta);
        sum_params.X_g = Msubv / 8;
        sum_params.Ci_g = Ksubv / 8;    // IFM: H C W C8;      X, Ci, Co have granularity of 8
        sum_params.Co_g = Nsubv / 8;
        sum_params.shift_res = g_qdq_kernel_params.shift_Qout;

        // OP_QDQ call: it receives the GEMM buffer tdm1 (as int32), the SM sums
        // (act1_sum), C0 (bufC0 as int64), the c1 / c2 coefficients and matO
        // (as int16). This is a different, longer argument list than the two sum
        // calls above; the meaning of the trailing 1, 1, 1, 1 flags is defined by
        // the callee - TODO confirm.
        direct_conv_int16x8_generic(nullptr, nullptr, nullptr, (int32*)tdm1, nullptr, act1_sum, nullptr, ( int64_t* )bufC0, g_qdq_kernel_params.c1, g_qdq_kernel_params.c2, ( int16* )matO, sum_params, OP_QDQ, 1, 1, 1, 1);

    }
}
