#include "./gemm/stx_gemm_int16xint16_w4_tdm.cc"

#include "./nonlinear/softmax_bf16x16/softmax_bf16x16_kernel.c"
#include "./qdq/wrapper_qdq.cc"
#include "./conv/direct_conv_int16x8_generic/direct_conv_int16x8_generic_kernel.c"

namespace MHA_2p1_qdq
{
    // Per-layer parameters read by run_qtxk / run_sfmx through KernelArgs::params_data.
    // The struct is packed, so field order and widths define the binary layout;
    // it must stay in sync with the producer of this record (not visible here).
    struct alignas(4) __attribute__ ((__packed__)) LayerParams{
        uint8_t mha_mode;       // 0: QKt + SM fusion; 1: QKt + SM + SMxV fusion; 2/3: K pre-processing
        uint8_t multi_core;     // > 0 selects the global (multi-core) softmax in run_sfmx
        uint8_t col_id;         // core column id (run_qtxk: >= 4 offsets Q by one Msubv x Ksubv tile)
        uint8_t row_id;         // core row id, forwarded to softmax_bf16x16
        uint16_t Msubv;         // sub-volume sizes: Q tile is Msubv x Ksubv, K tile is Ksubv x Nsubv
        uint16_t Ksubv;
        uint16_t Nsubv;
        uint16_t Sin_kv;        // presumably the key/value sequence length; 77 enables column masking in run_sfmx
        uint16_t tdm1_addr;     // buffer addresses below are turned into pointers with conv_to_local_ptr()
        uint16_t tdm2_addr;
        uint16_t qdq_prm;       // QDQ parameter block (see the byte offsets below)
        uint16_t act1_sum;      // OP_SUM result for Q
        uint16_t act2_sum;      // OP_SUM result for K
        uint16_t bufC0;         // C0 terms written by run_sum_c0
        uint16_t scratch_buf;
        KernelParams kernel_params;  // base params for direct_conv_int16x8_generic; copied and adjusted per call
    };

    // Byte offsets of the parameter records inside the qdq_prm buffer.
    // DQ_offset / Q_offset are used by run_qtxk / run_sfmx; the remaining
    // offsets are not referenced in the code visible alongside this namespace.
    int MUL_DQ_offset   =  0;
    int MUL_Q_offset    = 64;
    int SMxV_offset     = 192;
    int DQ_offset       = 256;
    int Q_offset        = 320;

    // Overwrites selected lanes of one 8-column block with mask_value, in place.
    //
    // Addressing assumes the matrix is stored as consecutive 8-column blocks of
    // `rows` rows each (assumption inferred from the pointer math — confirm), so the
    // block starting at column `step_cols` begins at element step_cols * rows.
    // Each iteration processes 32 elements, i.e. 4 rows x 8 columns. Per the AIE API
    // select() semantics, lanes whose bit in `mask` is 1 keep their loaded value and
    // lanes whose bit is 0 are replaced by mask_value.
    // rows is expected to be a multiple of 4 with at least 2 iterations (chess_loop_range(2,)).
    void mask_w8_cols( bfloat16 * ptr, int rows, int step_cols, bfloat16 mask_value, uint32_t mask) {
        bfloat16 * cursor = ptr + step_cols * rows;          // read and written through the same pointer
        const aie::mask<32> keep = aie::mask<32>::from_uint32(mask);
        const unsigned groups = rows / 4;                    // 4 rows x 8 columns = 32 lanes per step

        for ( unsigned g = 0; g < groups; g++ )
            chess_no_hw_loop
            chess_prepare_for_pipelining
            chess_loop_range(2,)
        {
            auto lanes = aie::load_v<32>( cursor );
            aie::store_v( cursor, aie::select( mask_value, lanes, keep ) );
            cursor += 32;
        }
    }
}



template<bool enable_8x4_overlay>
void run_qtxk(KernelArgs& args)
{
    using Ta = int16_t;
    using Ts = int32_t;
    int QKt_offset = 128;

    MHA_2p1_qdq::LayerParams* layer_params = static_cast<MHA_2p1_qdq::LayerParams*>(args.params_data);
    KernelParams sum_params = layer_params->kernel_params;

    int Msubv = layer_params->Msubv;
    int Ksubv = layer_params->Ksubv;
    int Nsubv = layer_params->Nsubv;
    int mha_mode = layer_params->mha_mode;  // 0 - QKt + SM fusion; 1 - QK + SM + SMV fusion; 2/3 - corresponding K matrix pre-processing
    int colId = layer_params->col_id;

    int8_t* matQ;
    int8_t* matK;
    int8_t* tdm1    = static_cast<int8_t*>(conv_to_local_ptr(layer_params->tdm1_addr));
    int8_t* tdm2    = static_cast<int8_t*>(conv_to_local_ptr(layer_params->tdm2_addr));
    int8_t* qdq_prm = static_cast<int8_t*>(conv_to_local_ptr(layer_params->qdq_prm));
    int32_t* act1_sum  = (int32_t*)conv_to_local_ptr(layer_params->act1_sum);
    int32_t* act2_sum  = (int32_t*)conv_to_local_ptr(layer_params->act2_sum);
    int8_t* bufC0 = (int8_t*)conv_to_local_ptr(layer_params->bufC0);
    int8_t* scratch_buf  = (int8_t*)conv_to_local_ptr(layer_params->scratch_buf);

    if constexpr( enable_8x4_overlay == 0) {
        matQ = static_cast<int8_t*>(args.s2mm_ch0_data);
        matK = static_cast<int8_t*>(args.s2mm_ch1_data);
    }

    else{
        matQ = static_cast<int8_t*>(args.s2mm_ch1_data);
        matK = static_cast<int8_t*>(args.s2mm_ch0_data);
    }

    if(colId >= 4)
        matQ = byte_incr(matQ, Msubv * Ksubv * sizeof(Ta));  // 2 since 2 bytes per elem

    int8_t* qkt_output = tdm1;  // matQ; // QKt output in Q buf if fused MHA else compiler generated buf

    ////////////////////////////////////////////////////////////////////////////////////////////////////
    // Accumulate i1sum & i2sum
    ////////////////////////////////////////////////////////////////////////////////////////////////////

    // QDQ sums
    int sum_mode = OP_SUM;
    const bool transposeQ = false;
    const bool transposeK = true;
    const bool acc_init = true;
    const bool last_tdm_iter = true;

    KernelControl ctrl = sum_params.ctrl;
    ctrl.zero_init = 1;
    ctrl.sign_A = 0;
    ctrl.sign_W = 0;
    ctrl.sign_O = 0;
    sum_params.Ci_g = Ksubv / 8;    // IFM: H C W C8;      X, Ci, Co have granularity of 8

    if (mha_mode >= 2) {
        sum_params.step_Ci = Nsubv * 8 * sizeof(Ta);
        sum_params.X_g = Nsubv / 8;

        direct_conv_int16x8_generic((int16_t*)matK, scratch_buf, scratch_buf, act2_sum, act2_sum, nullptr, sum_params, acc_init, last_tdm_iter, sum_mode);

        ////////////////////////////////////////////////////////////////////////////////////////////////////
        // Calculate C0 from i2sum and perform QDQ
        ////////////////////////////////////////////////////////////////////////////////////////////////////
        {
            v64uint8* pI = (v64uint8*)byte_incr(qdq_prm, QKt_offset);
            v64uint8* restrict pO = (v64uint8*)(&g_qdq_kernel_params);
            pO[0] = pI[0];
            chess_memory_fence();
        }
        run_sum_c0<int32, int64>(act2_sum, nullptr, bufC0, Nsubv);

        v32int16 chess_storage(DM_bankA)* ptrO = (v32int16 chess_storage(DM_bankA)*) matK;
        v32int16 chess_storage(DM_bankA)* ptrI = (v32int16 chess_storage(DM_bankA)*) matK;

        for (int ind = 0; ind < Ksubv * Nsubv / 64; ind++)
        chess_no_hw_loop
        chess_prepare_for_pipelining
        chess_loop_range(8, )
        {
            v32int16 w0 = *ptrI++;
            v32int16 w1 = *ptrI++;

            int w_mode = T16_8x8_lo;
            *ptrO++ = shuffle(w0, w1, w_mode);
            *ptrO++ = shuffle(w0, w1, w_mode + 1);
        }
    }
    else {
        sum_params.step_Ci = Msubv * 8 * sizeof(Ta);
        sum_params.X_g = Msubv / 8;
        //sum_params.Ci_g = Ksubv / 8;    // IFM: H C W C8;      X, Ci, Co have granularity of 8

        direct_conv_int16x8_generic((int16_t*)matQ, scratch_buf, scratch_buf, act1_sum, act1_sum, nullptr, sum_params, acc_init, last_tdm_iter, sum_mode);

#if 0
        //print_mat1( (uint16_t*) matK, Nsubv, 8, "Kmat ");
        int aierow = 2 + 0;
        int aiecol = 0;
        if ((get_coreid() & 0xF) == aierow && (get_coreid() >> 16) == aiecol)
        {
            aie::print(aie::vector<int32_t, 32>(*((v32int32*)act1_sum)), true, "ifm1_sum ");
            //aie::print(aie::vector<int32_t, 32>(*((v32int32*) (act1_sum+32))), true, "ifm1_sum ");
            //aie::print(aie::vector<int32_t, 32>(*((v32int32*) act2_sum)), true, "ifm2_sum ");
            //aie::print(aie::vector<int32_t, 32>(*((v32int32*) (act2_sum+32))), true, "ifm2_sum ");
        }
#endif




        ////////////////////////////////////////////////////////////////////////////////////////////////////
        // Qt x K GeMM
        ////////////////////////////////////////////////////////////////////////////////////////////////////

        int qkt_shiftamt_acc64_int32 = *(int*)(byte_incr(qdq_prm, QKt_offset + 36));

        gemm_int16xint16(matQ, matK, tdm1, tdm1, tdm1, Msubv / 8, Ksubv / 8, Nsubv / 8,
            Msubv * Ksubv * sizeof(Ta), Ksubv * Nsubv * sizeof(Ta),
            qkt_shiftamt_acc64_int32, transposeK, acc_init, last_tdm_iter, transposeQ);

        if (mha_mode==1) { // QKt + SM + SMxV fusion case: need to re do sumn c0 as it got overwritten during SMxV.
                    ////////////////////////////////////////////////////////////////////////////////////////////////////
            // Calculate C0 from i2sum and perform QDQ
            ////////////////////////////////////////////////////////////////////////////////////////////////////
            {
                v64uint8* pI = (v64uint8*)byte_incr(qdq_prm, QKt_offset);
                v64uint8* restrict pO = (v64uint8*)(&g_qdq_kernel_params);
                pO[0] = pI[0];
                chess_memory_fence();
            }
            run_sum_c0<int32, int64>(act2_sum, nullptr, bufC0, Nsubv);
        }

        sum_params.Co_g = Nsubv / 8;
        sum_params.shift_res = g_qdq_kernel_params.shift_Qout;

        direct_conv_int16x8_generic(nullptr, nullptr, nullptr, (int32*)tdm1, nullptr, (Ts*)act1_sum, nullptr, (int64_t*)bufC0,
            g_qdq_kernel_params.c1, g_qdq_kernel_params.c2, (int16*)qkt_output, sum_params, OP_QDQ, 1, 1, 1);

        // DeQuant
        Ta* dq_zp = reinterpret_cast<Ta*>(byte_incr(qdq_prm, MHA_2p1_qdq::DQ_offset));
        bfloat16* dq_sc = reinterpret_cast<bfloat16*>(byte_incr(qdq_prm, MHA_2p1_qdq::DQ_offset + 4));
        dequant_int16_to_bf16(qkt_output, tdm2, Msubv * Nsubv, *dq_zp, *dq_sc, false);
    }
}

// Softmax stage of the MHA pipeline: optional padding-column mask, softmax,
// then quantization of the bf16 result back to 16-bit integers.
// The bf16 scores are expected in tdm2 (run_qtxk dequantizes its QKt output into tdm2).
void run_sfmx(KernelArgs& args)
{
    using Ta = uint16_t;    // zero-point type of the output quantization record

    MHA_2p1_qdq::LayerParams* layer_params = static_cast<MHA_2p1_qdq::LayerParams*>(args.params_data);
    int Msubv      = layer_params->Msubv;
    int Nsubv      = layer_params->Nsubv;
    int N_IN       = layer_params->Sin_kv;
    int multi_core = layer_params->multi_core;
    int colId      = layer_params->col_id;
    int rowId      = layer_params->row_id;

    int8_t* tdm1    = static_cast<int8_t*>(conv_to_local_ptr(layer_params->tdm1_addr));
    int8_t* tdm2    = static_cast<int8_t*>(conv_to_local_ptr(layer_params->tdm2_addr));
    int8_t* qdq_prm = static_cast<int8_t*>(conv_to_local_ptr(layer_params->qdq_prm));

    ////////////////////////////////////////////////////////////////////////////////////////////////////
    // MASK + SM + Q   (the DeQuant step now happens at the end of run_qtxk)
    ////////////////////////////////////////////////////////////////////////////////////////////////////
#define PRINT 0
#if PRINT
    print_mat1<bfloat16,v8bfloat16>( (bfloat16*) tdm2, Msubv, 8, "SM_input ");
#endif

    // Sin_kv == 77: in the 8-column block that starts at column 72, keep lanes 0..4
    // (columns 72..76, mask 0x1f per 8-lane row) and set columns 77..79 to -10000.
    if (N_IN == 77) {
        MHA_2p1_qdq::mask_w8_cols( (bfloat16 *) tdm2, Msubv, 72, -10000.0, 0x1f1f1f1f);
    }

#if PRINT
    print_mat1<bfloat16,v8bfloat16>( (bfloat16*) tdm2, Msubv, 8, "SM_input_masked ");
#endif

    bool global_sm = multi_core > 0;
    softmax_bf16x16(tdm2, tdm1, Msubv, Nsubv, colId, rowId, global_sm); // rowId should be in range [2, 6)

    // Quant
    // NOTE(review): quantization reads tdm2, so this assumes softmax_bf16x16 leaves its
    // result in tdm2 (using tdm1 only as scratch). If softmax writes its output to tdm1
    // instead, the source buffer here is wrong — confirm against the softmax kernel.
    Ta*       q_zp  = reinterpret_cast<      Ta*>(byte_incr(qdq_prm,  MHA_2p1_qdq::Q_offset  ));
    bfloat16* q_sc  = reinterpret_cast<bfloat16*>(byte_incr(qdq_prm,  MHA_2p1_qdq::Q_offset+4));
    quant_bf16_to_int16(tdm2, tdm1, Msubv * Nsubv,  *q_zp,  *q_sc, false);
}


// Non-template entry point for the Q x K^T stage: always uses the 8x4 overlay
// stream assignment (Q on s2mm_ch1, K on s2mm_ch0).
void run_qtxk(KernelArgs& args)
{
    constexpr bool use_8x4_overlay = true;
    run_qtxk<use_8x4_overlay>(args);
}
