#include "./gemm/stx_gemm_int16xint16_w4_tdm.cc"
#include "./norm/softmax.cc"
#include "./qdq/wrapper_qdq.cc"
#include "./conv/direct_conv_int16x8_generic/direct_conv_int16x8_generic_kernel.c"
#include "./matadd/matadd_kernel_wrapper.c"

namespace MHA_mini_3p0_qdq
{
    // Parameter block read from KernelArgs::params_data by run_mini_mha_preprocess
    // and run_presoftmax_dequant. The struct is packed, so the field order and
    // widths below define its byte layout and must not be changed casually.
    // The 16-bit *_addr / buffer fields are turned into pointers with
    // conv_to_local_ptr by the kernels that read them.
    struct alignas(4) __attribute__((__packed__)) LayerParams {
        uint8_t mha_mode;       // 1 selects OP_SUM_T (otherwise OP_SUM) in run_mini_mha_preprocess
        uint8_t multi_core;     // multi-core flag; NOTE(review): exact meaning not visible in this file
        uint8_t col_id;         // presumably the logical column of the core - confirm
        uint8_t row_id;         // presumably the logical row of the core - confirm
        uint16_t Msubv;         // M dimension of the subvolume
        uint16_t Ksubv;         // K dimension of the subvolume
        uint16_t Nsubv;         // N dimension of the subvolume
        uint16_t Sin_kv;        // not read by the kernels in this file
        uint16_t tdm1_addr;     // buffer dequantized in place by run_presoftmax_dequant
        uint16_t tdm2_addr;     // second TDM buffer; not read through this struct in this file
        uint16_t qdq_prm;       // base of the QDQ parameter buffer
        uint16_t act1_sum;      // not read through this struct in this file
        uint16_t act2_sum;      // int32 sum buffer filled by run_mini_mha_preprocess
        uint16_t bufC0;         // buffer handed to run_sum_c0
        uint16_t scratch_buf;   // scratch buffer handed to direct_conv_int16x8_generic
        uint16_t query_addr;    // presumably the Q operand - not read in this file
        uint16_t key_addr;      // K operand summed by run_mini_mha_preprocess
        uint16_t val_addr;      // presumably the V operand - not read in this file
        uint16_t msk_addr;      // presumably the mask operand - not read through this struct
        KernelParams kernel_params; // copied locally and adjusted by the kernels
    };

    // Parameter block read from KernelArgs::params_data by run_bcast_add_mini.
    // Packed: field order and widths define the byte layout.
    struct alignas(4) __attribute__((__packed__)) LayerParamsBcastAdd {
        uint16_t Msubv;             // number of rows
        uint16_t Nsubv;             // number of columns
        uint16_t Nlayer;            // currently unused by run_bcast_add_mini
        uint16_t dq_node_addr;      // base of the QDQ parameter buffer (DQ_offset is applied to it)
        uint16_t act_addr;          // activation buffer; dequantized in place when perform_dq == 1
        uint16_t mask_vector_addr;  // 1 x Nsubv vector that gets broadcast when attn_vector_exist == 1
        uint16_t bias_addr;         // Msubv x Nsubv addend; also the destination of the broadcast
        uint16_t out_addr;          // destination of the element-wise add
        uint16_t attn_vector_exist; // 1 -> broadcast the mask vector and add it
        uint16_t bias_exist;        // 1 -> add the buffer at bias_addr
        uint16_t perform_dq;        // 1 -> dequantize the activation before the add
        uint16_t dummy;             // presumably padding - confirm
    };

    // Byte offsets of the individual nodes inside the QDQ parameter buffer.
    // DQ_offset is the only one referenced by name in this file.
    // NOTE: consider moving these into the kernel or layer params.
    const int MUL_DQ_offset = 0;
    const int MUL_Q_offset = 64;
    const int SMxV_offset = 192;
    const int DQ_offset = 256;
    const int Q_offset = 320;

    // Overwrites selected lanes of a run of bfloat16 data in place.
    //
    // Starting at ptr + step_cols * rows, processes rows / 4 vectors of 32
    // elements. Each vector is loaded, combined with mask_value through
    // aie::select under the 32-bit lane mask, and stored back to the same
    // address. Per the AIE API, select(a, b, m) yields a where the mask bit is
    // 0 and b where it is 1, so lanes with a 0 bit receive mask_value.
    // chess_loop_range(2,) tells the compiler there are at least two
    // iterations, i.e. rows must be >= 8.
    void mask_w8_cols(bfloat16* ptr, int rows, int step_cols, bfloat16 mask_value, uint32_t mask) {
        bfloat16* src = ptr + step_cols * rows;
        bfloat16* dst = src;   // read and write cursors walk the same memory
        const auto lane_mask = aie::mask<32>::from_uint32(mask);
        const unsigned num_vectors = rows / 4;
        for (unsigned v = 0; v < num_vectors; ++v)
            chess_no_hw_loop
            chess_prepare_for_pipelining
            chess_loop_range(2,)
        {
            aie::store_v(dst, aie::select(mask_value, aie::load_v<32>(src), lane_mask));
            src += 32;
            dst += 32;
        }
    }
};

void __attribute__((noinline)) broadcast_bfloat16xbfloat16
(
    int8_t* pIn,  // input  (    1     x num_cols )
    int8_t* pOut, // output ( num_rows x num_cols )
    int num_rows,
    int num_cols
)
{
    //v64acc32 __aie_dm_resource_a * restrict v_mat_out = ( v64acc32 __aie_dm_resource_a * )pOut;
    bfloat16* mat_out = reinterpret_cast<bfloat16*>(pOut);
    bfloat16* mat_in  = reinterpret_cast<bfloat16*>(pIn);

    int const outer_loop = num_cols / 8;
    int const inner_loop = num_rows / 8;

    // each inner iteration is working on a 8x8 block
    for (int col = 0; col < outer_loop; ++col) 
    {
        aie::vector<bfloat16,8> v8_bf16 = aie::load_v<8>( mat_in );         mat_in += 8;
        aie::vector<bfloat16,64> v64_bf16 = v8_bf16.grow_replicate<64>( );

        for (int row = 0; row < inner_loop; ++row) 
        {
            aie::store_v( mat_out, v64_bf16 );      mat_out += 64;
        }
    }
}


/*void __attribute__((noinline)) transpose(v32int16* ptrI, v32int16* ptrO, int num_elements) 
{
    for (int ind = 0; ind < num_elements / 64; ind++) 
        chess_no_hw_loop
        chess_prepare_for_pipelining
        chess_loop_range(8,)
    {
            v32int16 w0 = *ptrI++;
            v32int16 w1 = *ptrI++;
            int w_mode = T16_8x8_lo;
            *ptrO++ = shuffle(w0, w1, w_mode);
            *ptrO++ = shuffle(w0, w1, w_mode + 1);
    }
}*/
// Compile-time switch for the chess_report dump sections in the kernels below;
// 0 compiles them out, 1 compiles them in.
#define MINI_MHA_DEBUG 0
/**
 * Pre-processing step of the mini MHA: runs the sum pass over the K operand
 * and prepares the C0 buffer.
 *
 * Visible steps:
 *   1. Read MHA_mini_3p0_qdq::LayerParams from args.params_data and resolve the
 *      16-bit addresses it carries with conv_to_local_ptr.
 *   2. Call direct_conv_int16x8_generic on the K operand in OP_SUM_T mode when
 *      mha_mode == 1, otherwise OP_SUM, using act2_sum as its 32-bit buffer.
 *   3. Copy 64 bytes of QDQ parameters (byte offset 128 of qdq_prm) into
 *      g_qdq_kernel_params.
 *   4. Call run_sum_c0 on act2_sum / bufC0 for Nsubv entries.
 *
 * @param args  kernel arguments; args.params_data must point to a
 *              MHA_mini_3p0_qdq::LayerParams block.
 */
void run_mini_mha_preprocess(KernelArgs& args) 
{
    using Ta = int16_t;
    set_rnd_wrapper();
    const int QKt_offset = 128; // byte offset of the QKt node inside qdq_prm

    MHA_mini_3p0_qdq::LayerParams* layer_params = static_cast<MHA_mini_3p0_qdq::LayerParams*>(args.params_data);
    KernelParams sum_params = layer_params->kernel_params; // local copy, adjusted below

    const int Ksubv = layer_params->Ksubv;
    const int Nsubv = layer_params->Nsubv;

    // All buffer locations come from the layer params.
    int8_t* matK = static_cast<int8_t*>(conv_to_local_ptr(layer_params->key_addr));
    int8_t* qdq_prm = static_cast<int8_t*>(conv_to_local_ptr(layer_params->qdq_prm));
    int32_t* act2_sum = static_cast<int32_t*>(conv_to_local_ptr(layer_params->act2_sum));
    int8_t* bufC0 = static_cast<int8_t*>(conv_to_local_ptr(layer_params->bufC0));
    int8_t* scratch_buf = static_cast<int8_t*>(conv_to_local_ptr(layer_params->scratch_buf));

    const int sum_mode = (layer_params->mha_mode == 1) ? OP_SUM_T : OP_SUM;
    const bool acc_init = true;       // single pass: initialise the accumulation ...
    const bool last_tdm_iter = true;  // ... and treat it as the final iteration

    // NOTE(review): earlier revisions made a local copy of sum_params.ctrl and
    // set zero_init = 1 and sign_A/sign_W/sign_O = 0 on that copy without ever
    // writing it back, so those settings never reached the kernel. If they are
    // required they must be applied to sum_params.ctrl itself - confirm.

    // K sum (transposed variant when sum_mode == OP_SUM_T)
    sum_params.step_Ci = Nsubv * 8 * sizeof(Ta);
    sum_params.X_g = Nsubv / 8;
    sum_params.Ci_g = Ksubv / 8;
    direct_conv_int16x8_generic(reinterpret_cast<int16_t*>(matK), scratch_buf, scratch_buf, act2_sum, act2_sum, 
                                    nullptr, sum_params, acc_init, last_tdm_iter, sum_mode);

    {
        // Load the 64-byte QKt QDQ node into the global kernel parameter block.
        v64uint8* pI = reinterpret_cast<v64uint8*>(byte_incr(qdq_prm, QKt_offset));
        v64uint8* pO = reinterpret_cast<v64uint8*>(&g_qdq_kernel_params);
        pO[0] = pI[0];
        chess_memory_fence(); // keep the parameter copy ordered before the call below
    }
    // NOTE(review): presumably derives the per-column C0 term from act2_sum
    // using g_qdq_kernel_params - confirm against run_sum_c0.
    run_sum_c0<int32, int64>(act2_sum, nullptr, bufC0, Nsubv);
}


// Parameter block read from KernelArgs::params_data by run_gemm_qdq_mini.
// The struct is packed, so field order and widths define its byte layout.
// The 16-bit address fields are turned into pointers with conv_to_local_ptr.
struct alignas(4) __attribute__((__packed__)) LayerParam{
        uint8_t mha_mode;          // 1 (together with col_id >= 4) makes the kernel skip the first A subvolume
        uint8_t multi_core;        // not read by run_gemm_qdq_mini
        uint8_t col_id;            // logical column id, compared against 4 in mha_mode 1
        uint8_t row_id;            // not read by run_gemm_qdq_mini
        uint8_t first_tdm_iter;    // forwarded to the sum and GeMM kernels
        uint8_t last_tdm_iter;     // forwarded to the kernels; non-zero also enables the post-GeMM QDQ pass
        uint16_t Msubv;            // M dimension of the subvolume (used in units of 8)
        uint16_t Ksubv;            // K dimension of the subvolume (used in units of 8)
        uint16_t Nsubv;            // N dimension of the subvolume (used in units of 8)
        uint16_t Sin_kv;           // not read by run_gemm_qdq_mini
        uint16_t tdm1_addr;        // GeMM accumulation/output buffer; QDQ runs in place on it
        uint16_t tdm2_addr;        // not used by run_gemm_qdq_mini's computation
        uint16_t qdq_prm;          // base of the QDQ parameter buffer
        uint16_t act1_sum;         // int32 buffer for the sum pass over A
        uint16_t act2_sum;         // int32 buffer handed to run_sum_c0
        uint16_t bufC0;            // buffer handed to run_sum_c0 and to the QDQ pass
        uint16_t scratch_buf;      // scratch buffer for the sum pass
        uint16_t Asubv_addr;       // A operand
        uint16_t Bsubv_addr;       // B operand
        uint16_t msk_addr;         // only dumped in MINI_MHA_DEBUG builds
        uint16_t out_addr;         // only dumped in MINI_MHA_DEBUG builds
        uint16_t qdq_node_offset;  // byte offset of this GeMM's node inside the QDQ parameter buffer
        uint16_t transpose_B;      // 1 -> GeMM is asked to transpose B
        KernelParams kernel_params; // copied locally and adjusted by the kernel
    };

// Debug call counter: the MINI_MHA_DEBUG dumps in run_gemm_qdq_mini only fire
// while it is 0; it is incremented by run_bcast_add_mini in debug builds only.
static int callcnt = 0;
/**
 * One TDM iteration of the mini-MHA quantized GeMM (A x B or A x B^T),
 * followed by the post-GeMM QDQ pass on the last iteration.
 *
 * Visible steps:
 *   1. Read the LayerParam block from args.params_data and resolve its 16-bit
 *      addresses with conv_to_local_ptr.
 *   2. In mha_mode 1 on columns >= 4, advance matA past the first
 *      Msubv x Ksubv subvolume.
 *   3. Call direct_conv_int16x8_generic in OP_SUM mode over matA with act1_sum
 *      as its 32-bit buffer.
 *   4. Call gemm_int16xint16 with tdm1 as accumulation/output buffer.
 *   5. If last_tdm_iter is set: copy the 64-byte QDQ node at qdq_node_offset
 *      into g_qdq_kernel_params, call run_sum_c0, then run the OP_QDQ pass in
 *      place on tdm1.
 *
 * With MINI_MHA_DEBUG set to 1 the operands and intermediate buffers are dumped
 * through chess_report on selected cores during the first call (callcnt == 0).
 * Each dump section owns its cursor variable, so the debug build no longer
 * trips over a redeclared `vtmp`.
 *
 * @param args  kernel arguments; args.params_data must point to a LayerParam.
 */
void run_gemm_qdq_mini(KernelArgs& args) 
{
    using Ta = uint16_t;
    set_rnd_wrapper();
    
    LayerParam* layer_params = static_cast<LayerParam*>(args.params_data);
    KernelParams sum_params = layer_params->kernel_params; // local copy, adjusted per pass

    int Msubv = layer_params->Msubv;
    int Ksubv = layer_params->Ksubv;
    int Nsubv = layer_params->Nsubv;
    int first_tdm_iter = layer_params->first_tdm_iter;
    int last_tdm_iter = layer_params->last_tdm_iter;
    int8_t* tdm1 = static_cast<int8_t*>(conv_to_local_ptr(layer_params->tdm1_addr));
    int8_t* matA = static_cast<int8_t*>(conv_to_local_ptr(layer_params->Asubv_addr));
    int8_t* matB = static_cast<int8_t*>(conv_to_local_ptr(layer_params->Bsubv_addr));
    int8_t* qdq_prm = static_cast<int8_t*>(conv_to_local_ptr(layer_params->qdq_prm));
    int32_t* act1_sum = static_cast<int32_t*>(conv_to_local_ptr(layer_params->act1_sum));
    int32_t* act2_sum = static_cast<int32_t*>(conv_to_local_ptr(layer_params->act2_sum));
    int8_t* bufC0 = static_cast<int8_t*>(conv_to_local_ptr(layer_params->bufC0));
    int8_t* scratch_buf = static_cast<int8_t*>(conv_to_local_ptr(layer_params->scratch_buf));

    const int sum_mode = OP_SUM;
    const int Qdq_Node_offset = layer_params->qdq_node_offset;

    // NOTE(review): earlier revisions made a local copy of sum_params.ctrl and
    // set zero_init = 1 and sign_A/sign_W/sign_O = 0 on that copy without ever
    // writing it back, so those settings never reached the kernels. If they are
    // required they must be applied to sum_params.ctrl itself - confirm.

    // If in mha_mode 1 (mha_2p1), consume the second half of Q subvolume streamed in
    if (layer_params->col_id >= 4 && layer_params->mha_mode == 1)
    {
        matA = byte_incr(matA, Msubv * Ksubv * sizeof(Ta));  // skip one subvolume of 2-byte elements
    }

#if MINI_MHA_DEBUG
    // State that only the dump sections need.
    const int rowIdx_reg = (get_coreid() & 0xF);
    const int colIdx_reg = (get_coreid() >> 16);
    int8_t* matO = static_cast<int8_t*>(conv_to_local_ptr(layer_params->out_addr));
    int8_t* matM = static_cast<int8_t*>(conv_to_local_ptr(layer_params->msk_addr));
    const bool dump_core = (rowIdx_reg == 2 && colIdx_reg == 0 && callcnt == 0);

    // Mask dump (core row 4 / col 2).
    if (rowIdx_reg == 4 && colIdx_reg == 2 && callcnt == 0)
    {
        v8uint16* vmask = reinterpret_cast<v8uint16*>(matM);
        chess_report(0xA0FFEE);
        for (int i = 0; i < (Nsubv/8); i++)
        {
            chess_report(*vmask);
            vmask++;
        }
        chess_report(0xA1FFEE);
    }

    // Operand dump: A followed by B.
    if (dump_core)
    {
        v8uint16* vop = reinterpret_cast<v8uint16*>(matA);
        chess_report(0xA0FFEE);
        for (int i = 0; i < (Msubv*Ksubv/8); i++)
        {
            chess_report(*vop);
            vop++;
        }
        chess_report(0xA1FFEE);

        vop = reinterpret_cast<v8uint16*>(matB);
        chess_report(0xB0FFEE); // 64x77 == 16x4*77
        for (int i = 0; i < (Nsubv*Ksubv/8); i++)
        {
            chess_report(*vop);
            vop++;
        }
        chess_report(0xB1FFEE);
    }
#endif

    // compute act1_sum 
    sum_params.step_Ci = Msubv * 8 * sizeof(Ta);
    sum_params.X_g = Msubv / 8;
    sum_params.Ci_g = Ksubv / 8;
    direct_conv_int16x8_generic(reinterpret_cast<int16_t*>(matA), scratch_buf, scratch_buf, act1_sum, act1_sum,
                                nullptr, sum_params, first_tdm_iter, last_tdm_iter, sum_mode); 

    // AxB or AxBt GeMM; the shift amount is read from byte 36 of this GeMM's QDQ node
    int shiftamt_acc64_int32 = *reinterpret_cast<int*>(byte_incr(qdq_prm, Qdq_Node_offset + 36));
    const bool transposeA = false;
    bool transposeB = (layer_params->transpose_B == 1);
    gemm_int16xint16(matA, matB, tdm1, tdm1, tdm1, Msubv / 8, Ksubv / 8, Nsubv / 8, Msubv * Ksubv * sizeof(Ta),
                        Ksubv * Nsubv * sizeof(Ta), shiftamt_acc64_int32, transposeB, first_tdm_iter, last_tdm_iter, transposeA);

#if MINI_MHA_DEBUG
    // GeMM accumulator dump plus the scalar parameters used above.
    if (dump_core)
    {
        v8int32* vacc = reinterpret_cast<v8int32*>(tdm1);
        chess_report(0xC0FFEE);
        for (int i = 0; i < (Msubv*Nsubv/8); i++)
        {
            chess_report(*vacc);
            vacc++;
        }
        chess_report(0xC1FFEE);
        chess_report(Qdq_Node_offset);
        chess_report(shiftamt_acc64_int32);
        chess_report(Msubv);
        chess_report(Ksubv);
        chess_report(Nsubv);
        chess_report(0xC1FFFF);
    }
#endif

    if (last_tdm_iter) 
    {
        {
            // Load this GeMM's 64-byte QDQ node into the global kernel parameter block.
            v64uint8* pI = reinterpret_cast<v64uint8*>(byte_incr(qdq_prm, Qdq_Node_offset));
            v64uint8* pO = reinterpret_cast<v64uint8*>(&g_qdq_kernel_params);
            pO[0] = pI[0];
            chess_memory_fence(); // keep the parameter copy ordered before its first use
        }
        run_sum_c0<int32, int64>(act2_sum, nullptr, bufC0, Nsubv);
    
        // Post-Gemm QDQ
        sum_params.shift_res = g_qdq_kernel_params.shift_Qout;
        sum_params.step_Ci = Msubv *   8    * sizeof(Ta);
        sum_params.X_g = Msubv / 8;
        sum_params.Ci_g = Ksubv / 8;    // IFM: H C W C8;      X, Ci, Co have granularity of 8
        sum_params.Co_g = Nsubv / 8;
#if MINI_MHA_DEBUG
        if (dump_core)
        {
            chess_report(sum_params.shift_res);   // 31
            chess_report(sum_params.step_Ci);     // 256
            chess_report(sum_params.X_g );        // 2  --> 16
            chess_report(sum_params.Ci_g);        // 8  --> 64
            chess_report(sum_params.Co_g);        // 10 --> 80
            chess_report(g_qdq_kernel_params.c1); // -362849262
            chess_report(g_qdq_kernel_params.c2); // 1401216
        }
#endif
        direct_conv_int16x8_generic(nullptr, nullptr, nullptr, reinterpret_cast<int32*>(tdm1), nullptr, act1_sum, nullptr, 
                    reinterpret_cast<int64_t*>(bufC0), g_qdq_kernel_params.c1, g_qdq_kernel_params.c2, 
                    reinterpret_cast<int16*>(tdm1), sum_params, OP_QDQ, 1, 1, 1);
#if MINI_MHA_DEBUG
        // Dump of the QDQ inputs: act1_sum, act2_sum and bufC0.
        if (dump_core)
        {
            v8int32* vsum1 = reinterpret_cast<v8int32*>(act1_sum);
            chess_report(0xC7FFEE);
            for (int i = 0; i < (Msubv/8); i++)
            {
                chess_report(*vsum1);
                vsum1++;
            }
            chess_report(0xC8FFEE);
            v8int32* vsum2 = reinterpret_cast<v8int32*>(act2_sum);
            for (int i = 0; i < (Nsubv/8); i++)
            {
                chess_report(*vsum2);
                vsum2++;
            }
            chess_report(0xC9FFEE);
            v8acc64* vc0 = reinterpret_cast<v8acc64*>(bufC0);
            for (int i = 0; i < (Nsubv/8); i++)
            {
                chess_report(*vc0);
                vc0++;
            }
            chess_report(0xCAFFEE);
        }
#endif

    }
#if MINI_MHA_DEBUG
    // Dump of the buffer at out_addr.
    if (dump_core)
    {
        v8uint16* vout = reinterpret_cast<v8uint16*>(matO);
        chess_report(0xC2FFEE);
        for (int i = 0; i < (Msubv*Nsubv/8); i++)
        {
            chess_report(*vout);
            vout++;
        }
        chess_report(0xC3FFEE);
    }
#endif
}


/**
 * Dequantizes the QKt result from int16 to bfloat16 in place, ahead of softmax.
 *
 * Reads MHA_mini_3p0_qdq::LayerParams from args.params_data, takes the
 * zero-point (int16 at DQ_offset) and the scale (bfloat16 at DQ_offset + 4)
 * from the QDQ parameter buffer, and calls dequant_int16_to_bf16 over
 * Msubv * Nsubv elements with tdm1 as both source and destination.
 *
 * @param args  kernel arguments; args.params_data must point to a
 *              MHA_mini_3p0_qdq::LayerParams block.
 */
void run_presoftmax_dequant(KernelArgs& args) 
{
    using Ta = int16_t;
    set_rnd_wrapper();

    MHA_mini_3p0_qdq::LayerParams* layer_params = static_cast<MHA_mini_3p0_qdq::LayerParams*>(args.params_data);

    const int Msubv = layer_params->Msubv;
    const int Nsubv = layer_params->Nsubv;
    int8_t* tdm1 = static_cast<int8_t*>(conv_to_local_ptr(layer_params->tdm1_addr));
    int8_t* qdq_prm = static_cast<int8_t*>(conv_to_local_ptr(layer_params->qdq_prm));

    // DeQuant QKt from int16 to bf16 (in place: input and output are both tdm1)
    Ta* dq_zp = reinterpret_cast<Ta*>(byte_incr(qdq_prm, MHA_mini_3p0_qdq::DQ_offset));
    bfloat16* dq_sc = reinterpret_cast<bfloat16*>(byte_incr(qdq_prm, MHA_mini_3p0_qdq::DQ_offset + 4));
    dequant_int16_to_bf16(tdm1, tdm1, Msubv * Nsubv, *dq_zp, *dq_sc, false);
}

/**
 * Optional dequantize + broadcast + element-wise add stage of the mini MHA.
 *
 * Visible steps:
 *   1. Read MHA_mini_3p0_qdq::LayerParamsBcastAdd from args.params_data and
 *      resolve its 16-bit addresses with conv_to_local_ptr.
 *   2. If perform_dq == 1, dequantize the activation buffer from int16 to
 *      bfloat16 in place, using the zero-point / scale at DQ_offset.
 *   3. If attn_vector_exist == 1, broadcast the 1 x Nsubv mask vector into the
 *      Msubv x Nsubv buffer at bias_addr.
 *   4. If attn_vector_exist == 1 or bias_exist == 1, add that buffer to the
 *      activation element-wise and write the result to out_addr.
 *
 * NOTE(review): when neither flag is set nothing is written to out_addr;
 * presumably callers make out_addr alias act_addr in that case - confirm.
 *
 * With MINI_MHA_DEBUG set to 1 the operands and the result are dumped through
 * chess_report on core row 2 / col 0, and callcnt is incremented. The debug
 * state is declared once, so the debug build no longer fails on redeclared
 * rowIdx_reg / colIdx_reg / vtmp.
 *
 * @param args  kernel arguments; args.params_data must point to a
 *              MHA_mini_3p0_qdq::LayerParamsBcastAdd block.
 */
void run_bcast_add_mini(KernelArgs& args) 
{
    using Ta = int16_t;
    set_rnd_wrapper();
    MHA_mini_3p0_qdq::LayerParamsBcastAdd* layer_params = static_cast<MHA_mini_3p0_qdq::LayerParamsBcastAdd*>(args.params_data);

    uint16_t Msubv             = layer_params->Msubv; // number of rows
    uint16_t Nsubv             = layer_params->Nsubv; // number of columns
    uint16_t attn_vector_exist = layer_params->attn_vector_exist; // 1 if there is an attention vector, 0 otherwise
    uint16_t bias_exist        = layer_params->bias_exist; // 1 if there is a bias matrix, 0 otherwise
    
    int8_t* actv = static_cast<int8_t*>(conv_to_local_ptr(layer_params->act_addr));        
    int8_t* tdm2 = static_cast<int8_t*>(conv_to_local_ptr(layer_params->bias_addr));       
    int8_t* matM = static_cast<int8_t*>(conv_to_local_ptr(layer_params->mask_vector_addr));
    int8_t* out  = static_cast<int8_t*>(conv_to_local_ptr(layer_params->out_addr));
    int8_t* qdq_prm = static_cast<int8_t*>(conv_to_local_ptr(layer_params->dq_node_addr)); 
    
    // DeQuant the activation from int16 to bf16, in place
    if (layer_params->perform_dq == 1)
    {
        Ta* dq_zp = reinterpret_cast<Ta*>(byte_incr(qdq_prm, MHA_mini_3p0_qdq::DQ_offset));
        bfloat16* dq_sc = reinterpret_cast<bfloat16*>(byte_incr(qdq_prm, MHA_mini_3p0_qdq::DQ_offset + 4));
        dequant_int16_to_bf16(actv, actv, Msubv * Nsubv, *dq_zp, *dq_sc, false);
    }

    // if there is an attention vector, broadcast it to populate tdm2
    if (attn_vector_exist == 1) 
    {
        broadcast_bfloat16xbfloat16
        (
            matM,  // input (    1     x num_cols  ==   1   x Nsubv )
            tdm2,  // output ( num_rows x num_cols  == Msubv x Nsubv )
            Msubv,
            Nsubv
        );
    }

#if MINI_MHA_DEBUG
    // State shared by both dump sections of this function.
    const int rowIdx_reg = (get_coreid() & 0xF);
    const int colIdx_reg = (get_coreid() >> 16);
    const bool dump_core = (rowIdx_reg == 2 && colIdx_reg == 0);

    // Dump both addends before the add.
    if (dump_core)
    {
        v16bfloat16* vin = reinterpret_cast<v16bfloat16*>(tdm2);
        chess_report(0xC8FFEE);
        for (int i = 0; i < Msubv*Nsubv/16; i++)
        {
            chess_report(*vin);
            vin++;
        }
        chess_report(0xC9FFEE);

        vin = reinterpret_cast<v16bfloat16*>(actv);
        chess_report(0xCAFFEE);
        for (int i = 0; i < Msubv*Nsubv/16; i++)
        {
            chess_report(*vin);
            vin++;
        }
        chess_report(0xCBFFEE);
    }
#endif

    if ((attn_vector_exist == 1) || (bias_exist == 1))
    {
        // out = tdm2 + actv, element-wise over Msubv * Nsubv values
        matadd_bf16_bf16_bf16
        (
            tdm2, actv, 
            out,
            ELW_ADD,
            Msubv * Nsubv, 1, Msubv*Nsubv
        );
    }

#if MINI_MHA_DEBUG
    // Dump the result (only when a bias was added) and count the call.
    if (dump_core && (bias_exist == 1))
    {
        v16bfloat16* vres = reinterpret_cast<v16bfloat16*>(out);
        chess_report(0xCCFFEE);
        for (int i = 0; i < Msubv*Nsubv/16; i++)
        {
            chess_report(*vres);
            vres++;
        }
        chess_report(0xCDFFEE);
    }
    callcnt++;
#endif
    
}