#define RUN_ON_AIE_ARRAY 1

#ifndef __TXNRT__
#include "matrix_di.hpp"

#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif

#if defined(__AIESIM__) || defined(__TXNRT__)
#include "dma.hpp"
#endif

#include <math.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <string>

#include "mha_validation_di.cpp"

// Two-level stringification: TO_STRING(x) lets macro arguments such as
// TEST_BENCH_DIR be expanded first, then STRINGIFY turns the result into a
// string literal.
#define STRINGIFY(x) #x
#define TO_STRING(x) STRINGIFY(x)

#ifndef __TXNRT__
// Global compute graph; only defined outside transaction-runtime (__TXNRT__)
// builds. main() initializes it, runs the DMA layer config on it and ends it
// on the __AIESIM__ path.
ComputeGraph g_compute_graph;
#endif
// Read `size` bytes from the binary file `filename` into `data`.
//
// Failures are reported on stderr but are not fatal (callers continue with
// whatever the buffer holds):
//   - if the file cannot be opened, `data` is left untouched;
//   - on a short read, only the bytes actually present in the file are stored.
// A zero-byte request is a no-op and does not open the file, so optional
// tensors whose size is 0 do not trigger spurious warnings.
void read_bin_file(std::string filename, char* data, size_t size)
{
    if (size == 0)
        return;

    std::ifstream file(filename, std::ios::in | std::ios::binary);
    if (!file.is_open())
    {
        fprintf(stderr, "read_bin_file: cannot open '%s'\n", filename.c_str());
        return;
    }

    file.read(data, static_cast<std::streamsize>(size));
    size_t const got = static_cast<size_t>(file.gcount());
    if (got != size)
    {
        fprintf(stderr, "read_bin_file: short read from '%s' (%zu of %zu bytes)\n",
                filename.c_str(), got, size);
    }
}

// Write `size` bytes from `data` to the binary file `filename`, creating it or
// truncating any existing content.
//
// Failure to open the file, or a failure while writing/flushing it, is reported
// on stderr instead of being silently ignored. Errors are not fatal.
void write_bin_file(std::string filename, char* data, size_t size)
{
    std::ofstream file(filename, std::ios::out | std::ios::binary);
    if (!file.is_open())
    {
        fprintf(stderr, "write_bin_file: cannot open '%s' for writing\n", filename.c_str());
        return;
    }

    file.write(data, static_cast<std::streamsize>(size));
    // close() flushes; some write errors only surface at this point.
    file.close();
    if (!file)
    {
        fprintf(stderr, "write_bin_file: failed writing %zu bytes to '%s'\n",
                size, filename.c_str());
    }
}

/**
 * Obtain a raw buffer of `num_bytes` bytes.
 *
 * Transaction-runtime builds (__TXNRT__) take it from the C heap; every other
 * build uses adf::GMIO::malloc. Release the buffer with deallocate().
 */
void* allocate(int num_bytes)
{
    void* buffer = nullptr;
#ifdef __TXNRT__
    buffer = malloc(num_bytes);
#else
    buffer = adf::GMIO::malloc(num_bytes);
#endif
    return buffer;
}

/**
 * Release a buffer obtained from allocate(), using the allocator that matches
 * the build: adf::GMIO::free normally, the C heap's free() under __TXNRT__.
 */
void deallocate(void* ptr)
{
#ifndef __TXNRT__
    adf::GMIO::free(ptr);
#else
    free(ptr);
#endif
}

int main(void)
{
    bool read_io_data = true;

    int const B  = 1;
    int const AieRows = 4;
    int const AieCols = 4;

    int const H   = H_IN;
    int const Stq = SQ_IN;      // M
    int const St  = SK_IN;      // N
    int const Sq  = SQ_IN_SUBV; //
    int const Dh  = DH_IN;      // K
    int const G  = G_IN;      // number of groups


    int const Dt  = Dh * 1;  
    int const St_pad = (((St - 1) / 8) + 1) * 8;
    std::string folder = TO_STRING(TEST_BENCH_DIR);
    bool const hasAttnMask = (ATTN_MASK_EXIST==1)? true : false;
    int const Bias = BIAS_EXIST;
    bool const hasBias = (Bias != 0)? true : false;
    
    printf("Stq = %d, St = %d, Sq = %d, Dh = %d\n", Stq, St, Sq, Dh);

    
    int const qry_rows = Stq;
    int const key_rows = St;
    int const val_rows = St;
    int const val_cols = Dt;
    int const out_rows = Stq;

    int constexpr val_subv_cols = Dh;

    int const out_cols = Dt;
    int const qry_cols = Dt;
    int const key_cols = Dt;

    using Tqkv = uint16_t;
    using Tsm = uint16_t;

    int const qry_size = qry_rows * qry_cols * sizeof(Tqkv) * B;
    int const key_size = key_rows * key_cols * sizeof(Tqkv) * B;
    int const val_size = val_rows * val_cols * sizeof(Tqkv) * B;
    int const mask_size = hasAttnMask?  1     * St_pad * sizeof(Tqkv) * B : 0;
    int const bias_size = hasBias ? qry_rows * key_cols * sizeof(Tqkv) * B : 0;
    int const out_size = out_rows * out_cols * sizeof(Tqkv) * B;
    
    int const out_size1 = qry_rows  * (key_rows + 1) * sizeof(Tqkv) * B;
    int const out_size2 = qry_rows  * (key_rows) * sizeof(Tqkv) * B;

    int const qkv_size = qry_size + key_size + val_size; 
    int const qdq_param_size = sizeof(qdq_params) * H;
    printf("size of qdq_param_size= %d Bytes\n", qdq_param_size);


    void* aie_qkv     = allocate(qry_size*H + (key_size + val_size)*G + mask_size + bias_size);
    void* aie_out     = allocate(out_size*H);
    void* aie_qdq_prm = allocate(qdq_param_size);
    uint32_t* batch_qdq_params = (uint32_t*)(aie_qdq_prm);

    void* aie_qry = static_cast<void*>(aie_qkv);
    void* aie_key = static_cast<void*>(static_cast<uint8_t*>(aie_qkv) + qry_size*H);
    void* aie_val = static_cast<void*>(static_cast<uint8_t*>(aie_qkv) + qry_size*H + key_size*G);
    void* aie_bias = static_cast<void*>(static_cast<uint8_t*>(aie_qkv) + qry_size*H + key_size*G + val_size*G);;
    void* aie_msk = static_cast<void*>(static_cast<uint8_t*>(aie_qkv) + qry_size*H + key_size*G + val_size*G + bias_size*Bias);
    
    uint16_t* ptr_aie_msk = (uint16_t*)aie_msk;
    uint16_t* ptr_aie_bias = (uint16_t*)aie_bias;
    uint16_t* ptr_aie_qry  = (uint16_t*)aie_qry;
    uint16_t* ptr_aie_key  = (uint16_t*)aie_key;
    uint16_t* ptr_aie_val  = (uint16_t*)aie_val;

    void* cpu_out = malloc(out_size*H);
    void* mdl_out = malloc(std::max(out_size*H, AieCols*AieRows*16*Dh));   // max to ensure all cores process at least 16 rows required for SM

    void* mdl_out1 = malloc(std::max(out_size1*H, AieCols*AieRows*16*Dh));   // max to ensure all cores process at least 16 rows required for SM
    void* mdl_out2 = malloc(std::max(out_size2*H, AieCols*AieRows*16*Dh)); 

    RowMajorMatrix<Tqkv> aie_Y  (B*out_rows*H, out_cols, aie_out);
    RowMajorMatrix<Tqkv> cpu_Y  (B*out_rows*H, out_cols, cpu_out);
    RowMajorMatrix<Tqkv> model_Y(B*out_rows*H, out_cols, mdl_out);
    RowMajorMatrix<Tqkv> model_Y1(B*qry_rows*H, (key_rows+1), mdl_out1);
    RowMajorMatrix<Tqkv> model_Y2(B*qry_rows*H, (key_rows), mdl_out2);
    
    int64_t c0; 
    int32_t C1, C2, C3;
    Tqkv SQb, Sout;

    std::string Qry_filename = (!hasAttnMask)? "q_mat_uint16.bin"    : ((Stq==1)? "cpu_Q_uint16.bin" : "ST0_Q_uint16.bin");
    std::string Key_filename = (!hasAttnMask)? "k_mat_uint16.bin"    : ((Stq==1)? "cpu_Kt_uint16.bin" : "ST0_K_uint16.bin");
    std::string Val_filename = (!hasAttnMask)? "v_mat_uint16.bin"    : ((Stq==1)? "cpu_V_uint16.bin" : "ST0_V_uint16.bin");
    std::string Out_filename = (!hasAttnMask)? "SMxV_out_uint16.bin" : ((Stq==1)? "cpu_Out_uint16.bin" : "ST0_Out_uint16.bin");
    std::string M_filename   = (Stq==1)? "cpu_mask_bfloat16.bin" : "ST0_mask_bfloat16.bin";

    printf("Reading IO dataset \n");
    printf("folder: %s \n", folder.c_str());
#if ((!ATTN_MASK_EXIST) or (H_IN != 6) or (SQ_IN!=64))
    for (int ih = 0; ih < H; ih++) 
        read_bin_file(folder+Qry_filename, (char*)aie_qkv + ih*qry_size, qry_size); 
    
    for (int ih = 0; ih < G; ih++) 
        read_bin_file(folder+Key_filename, (char*)aie_qkv + H*qry_size + ih*key_size , key_size); 
    for (int ih = 0; ih < G; ih++) 
        read_bin_file(folder+Val_filename, (char*)aie_qkv + H*qry_size + G*key_size + ih*val_size  , val_size);
    
    for (int ih = 0; ih < H; ih++) 
        read_bin_file(folder+Out_filename, (char*)mdl_out + ih*out_size, out_size);  
#else //ATTN_MASK_EXIST and (H_IN == 6) and (SQ_IN==64)
    Qry_filename = "Q_batch6.bin";
    Key_filename = (G==H) ? "K_batch6.bin" : "K_batch3.bin";
    Val_filename = (G==H) ? "V_batch6.bin" : "V_batch3.bin";
    Out_filename = "O_batch6.bin";
    M_filename = "ST0_mask_bfloat16_batch6.bin";
    read_bin_file(folder+Qry_filename, (char*)aie_qry, qry_size*H); 
    read_bin_file(folder+Key_filename, (char*)aie_key, key_size*G); 
    read_bin_file(folder+Val_filename, (char*)aie_val, val_size*G); 
    read_bin_file(folder+Out_filename, (char*)mdl_out, out_size*H); 
#endif    
    if(hasAttnMask)
    {
        read_bin_file(folder+M_filename, (char*)aie_qkv + H*qry_size + G*key_size + G*val_size, mask_size);
    }
    
    if (hasBias)
    {
    Qry_filename = "Q.bin";
    Key_filename = "K.bin";
    Val_filename = "V.bin";
    Out_filename = "Out.bin";
    std::string Bias_filename = "B.bin";

    read_bin_file(folder+Qry_filename, (char*)aie_qry, qry_size*H); 
    read_bin_file(folder+Key_filename, (char*)aie_key, key_size*G); 
    read_bin_file(folder+Val_filename, (char*)aie_val, val_size*G); 
    read_bin_file(folder+Bias_filename, (char*)aie_bias, bias_size*Bias);
    read_bin_file(folder+Out_filename, (char*)mdl_out, out_size*H);
    }
    
if (Stq==151) {
    printf("Reading IO data from tensors\n");
    
    for (int ih = 0; ih < H; ih++) {
        read_bin_file(folder + "q_mat_uint16.bin", (char*)aie_qkv + ih * qry_size, qry_size);
    }
    
    for (int ih = 0; ih < H; ih++) {
        read_bin_file(folder + "k_mat_uint16.bin", (char*)aie_qkv + H * qry_size + ih * key_size, key_size);
    }
    
    read_bin_file(folder + "b_mat_uint16.bin", (char*)aie_qkv + H * qry_size + H * key_size, mask_size);

    for (int ih = 0; ih < H; ih++) {
        read_bin_file(folder + "SM_out_uint16.bin", (char*)mdl_out2 + ih * out_size2, out_size2);
    }
}

    for(int r = 0; r < model_Y2.num_rows; r++)
        for(int c = 0; c < model_Y2.num_cols; c++)
            model_Y1.at(r, c) = model_Y2.at(r,c); 

    // Common QKT parameter indices
    const int BASE0 = 16 * 0;
    const int BASE1 = 16 * 1;
    const int QKT_BASE = 16 * 2;
    const int SMV_BASE = 16 * 3;
    const int DQ_SM_BASE = 16 * 4;
    const int Q_SM_BASE = 16 * 5;
    
    // Helper lambda to set QKT parameters
    auto set_qkt_params = [&](int64_t c0, int32_t c1, int32_t c2, int32_t c3, int sout, int stdm, uint32_t* pbatch_qdq_params = nullptr) {
        uint32_t* ptr_qdq_params = (pbatch_qdq_params==nullptr)? qdq_params : pbatch_qdq_params;
        *(int64_t*)(&ptr_qdq_params[QKT_BASE + 0]) = c0;
        ptr_qdq_params[QKT_BASE + 2] = c1;
        ptr_qdq_params[QKT_BASE + 3] = c2;
        ptr_qdq_params[QKT_BASE + 4] = c3;
        ptr_qdq_params[QKT_BASE + 5] = Sq;
        ptr_qdq_params[QKT_BASE + 6] = St_pad;
        ptr_qdq_params[QKT_BASE + 7] = 0;  // SQb
        ptr_qdq_params[QKT_BASE + 8] = sout;
        ptr_qdq_params[QKT_BASE + 9] = stdm;
    };
    
    // Helper lambda to set SM*V parameters
    auto set_smv_params = [&](int64_t c0, int32_t c1, int32_t c2, int32_t c3, int sout, int stdm, uint32_t* pbatch_qdq_params = nullptr) {
        uint32_t* ptr_qdq_params = (pbatch_qdq_params==nullptr)? qdq_params : pbatch_qdq_params;
        *(int64_t*)(&ptr_qdq_params[SMV_BASE + 0]) = c0;
        ptr_qdq_params[SMV_BASE + 2] = c1;
        ptr_qdq_params[SMV_BASE + 3] = c2;
        ptr_qdq_params[SMV_BASE + 4] = c3;
        ptr_qdq_params[SMV_BASE + 5] = Sq;
        ptr_qdq_params[SMV_BASE + 6] = val_subv_cols;
        ptr_qdq_params[SMV_BASE + 7] = 0;
        ptr_qdq_params[SMV_BASE + 8] = sout;
        ptr_qdq_params[SMV_BASE + 9] = stdm;
    };
    
    // Helper lambda to set DQ before SM parameters
    auto set_dq_sm_params = [&](uint32_t zp, float scale, uint32_t* pbatch_qdq_params = nullptr) {
        uint32_t* ptr_qdq_params = (pbatch_qdq_params==nullptr)? qdq_params : pbatch_qdq_params;
        ptr_qdq_params[DQ_SM_BASE + 0] = zp;
        ptr_qdq_params[DQ_SM_BASE + 1] = float_to_bfloat16(scale * 1.442695041).value;
        ptr_qdq_params[DQ_SM_BASE + 2] = 0;
    };
    
    // Helper lambda to set Q after SM parameters
    auto set_q_sm_params = [&](uint32_t zp, float scale, uint32_t* pbatch_qdq_params = nullptr) {
        uint32_t* ptr_qdq_params = (pbatch_qdq_params==nullptr)? qdq_params : pbatch_qdq_params;
        ptr_qdq_params[Q_SM_BASE + 0] = zp;
        ptr_qdq_params[Q_SM_BASE + 1] = float_to_bfloat16(1.0 / scale).value;
        ptr_qdq_params[Q_SM_BASE + 2] = 1;
    };

// MuL DQ - dummy, not used in kernel
qdq_params[BASE0 + 0] = 0;
qdq_params[BASE0 + 1] = float_to_bfloat16(1.0).value;
// MuL Q - dummy, not used in kernel                        
qdq_params[BASE1 + 0] = 0;
qdq_params[BASE1 + 1] = float_to_bfloat16(1.0).value;

if (Stq == 256 and St == 77) {
        set_qkt_params(1889495152038784LL, -733530237, 3648128, -1100993630, 31, 7);
        set_smv_params(40433895866368LL, -953386690, 6949120, 0, 30, 8);
        set_dq_sm_params(35378, 0.00027649561525322497);
        set_q_sm_params(0, 0.000015220252); }
else if (Stq == 151) {
        // QKT
        *(int64_t*)(&qdq_params[(16*2) + 0]) = 460673320082688; // c0
        qdq_params[(16*2) + 2] = -94243407; // c1
        qdq_params[(16*2) + 3] = 943872;  //c2
        qdq_params[(16*2) + 4] = -139833162;  //c3
        qdq_params[(16*2) + 5] = Sq;       // M
        qdq_params[(16*2) + 6] = St_pad;   // N
        qdq_params[(16*2) + 7] = 1;  //SQb
        qdq_params[(16*2) + 8] = 27; //Sout
        qdq_params[(16*2) + 9] = 8;  //Stdm
        // SM *V
        *(int64_t*)(&qdq_params[(16*3) + 0]) = 8824815616000;
        qdq_params[(16*3) + 2] = -134639853;
        qdq_params[(16*3) + 3] = 2098688;
        qdq_params[(16*3) + 4] = 0;
        qdq_params[(16*3) + 5] = Sq;
        qdq_params[(16*3) + 6] = St_pad;
        qdq_params[(16*3) + 7] = 0;
        qdq_params[(16*3) + 8] = 28;
        qdq_params[(16*3) + 9] = 9;
        // DQ before SM
        qdq_params[(16*4) + 0] = 23589; //23589
        qdq_params[(16*4) + 1] = float_to_bfloat16(0.00034349862835370004 * 1.442695041).value; //0.00034349862835370004
        qdq_params[(16*4) + 2] = 0;   // Disable DQ node in softmax wrapper
        // Q after SM
        qdq_params[(16*5) + 0] = 0;
        qdq_params[(16*5) + 1] = float_to_bfloat16(1.0 / 0.000015259021893143654).value;  //0.000015259021893143654
        qdq_params[(16*4) + 2] = 1;   // Eable    Q node in softmax wrapper
    }
else if (Stq == 77 and St == 77 and hasBias) 
{
        set_qkt_params(582008778551424LL, -306138875, 1237120, -281193510, 29, 7);
        set_smv_params(13103408349184LL, -227772005, 2268416, 0, 29, 8);
        set_dq_sm_params(22302, 0.0004225457960274070);
        set_q_sm_params(0, 0.000015259021893143654);

}
else if (Stq == 131 and St == 77) {
        set_qkt_params(861536259127936LL, -362849262, 1401216, -355766553, 31, 7);
        set_smv_params(43320113889280LL, -838990851, 6674176, 0, 30, 8);
        set_dq_sm_params(49748, 0.0004402542836032808);
        set_q_sm_params(0, 0.0000149011885);}
 else if (Stq == 1024 && St == 77) {
        set_qkt_params(681946522523776LL, -359633814, 1091712, -245524323, 29, 7);
        set_smv_params(21994527522816LL, -792704643, 6464256, 0, 30, 8);
        set_dq_sm_params(36077, 0.0002566671755630523);
        set_q_sm_params(0, 0.000015172848);
    }
    else if (Stq == 4096 && St == 77) {
        set_qkt_params(1098450594258432LL, -500894562, 1568896, -406295036, 30, 7);
        set_smv_params(12106170630144LL, -212884998, 1452288, 0, 28, 8);
        set_dq_sm_params(33357, 0.00036885470035485923);
        set_q_sm_params(0, 0.000015185596);
    }
    else if (Stq == 64 && St == 77) {
        set_qkt_params(861536259127936LL, -362849262, 1401216, -355766553, 31, 7);
        set_smv_params(43320113889280LL, -838990851, 6674176, 0, 30, 8);
        set_dq_sm_params(49748, 0.0004402542836032808);
        set_q_sm_params(0, 0.0000149011885);
    }
else if (Stq == 64 and St == 64) 
{
#if !ATTN_MASK_EXIST 
    // QKT
        set_qkt_params(1754278565017600LL, -882267152, 3458176, -771983758, 32, 7);
        set_smv_params(76008036237312LL, -425752722, 1444224, 0, 31, 7);
        set_dq_sm_params(32793, 0.00005332884029485285);
        set_q_sm_params(0, 0.0000010263796);
#else

#if (H_IN != 6)

    // MuL DQ
    qdq_params[BASE0 + 0] = 22046;
    qdq_params[BASE0 + 1] = float_to_bfloat16(0.0012369307223707438 *  65535 * 0.00012207217514514923).value;
    // MuL Q
    qdq_params[BASE1 + 0] = 22046;
    qdq_params[BASE1 + 1] = float_to_bfloat16(1.0 / 0.00015461634029634297).value;

    set_qkt_params(604720101780224LL, -287107678, 1146496, -289006562, 29, 7);
    set_smv_params(36408437768192LL, -791173005, 2821760, 0, 30, 7);
    set_dq_sm_params(22046, 0.0012369307223707438 * 0.125);
    set_q_sm_params(0, 0.000015259021893143654);
#else
    
    // Use helper functions instead of manual parameter setting
    set_qkt_params(1420048154852096LL, -661706802, 2598272, -672424674, 30, 7, batch_qdq_params);
    set_smv_params(18688476446720LL, -302999004, 1271936, 0, 29, 7, batch_qdq_params);
    set_dq_sm_params(16007, 0.0015857701655477285 / (65535 * 0.00012207217514514923), batch_qdq_params);
    set_q_sm_params(0,  0.000015259021893143654, batch_qdq_params);
    //================ Q1 ===========================================
    batch_qdq_params += 96;

    set_qkt_params(1542775123213312LL, -738893750, 2800000, -702275000, 30, 7, batch_qdq_params);
    set_smv_params(38849052934144LL, -790364925, 3158400, 0, 30, 7, batch_qdq_params);
    set_dq_sm_params(22913, 0.0013778283027932048 / (65535 * 0.00012207217514514923), batch_qdq_params);
    set_q_sm_params(0, 0.000015259021893143654, batch_qdq_params);
    //================ Q2 ===========================================
    batch_qdq_params += 96;
   
    set_qkt_params(1862579270971776LL, -919981854, 3673728, -884823129, 31, 7, batch_qdq_params);
    set_smv_params(39701603942400LL, -984327603, 3510656, 0, 30, 7, batch_qdq_params);
    set_dq_sm_params(22074, 0.0015337113291025162 / (65535 * 0.00012207217514514923), batch_qdq_params);
    set_q_sm_params(0, 0.000015259021893143654, batch_qdq_params);
    //================ Q3 ===========================================
    batch_qdq_params += 96;

    set_qkt_params(1487099759052032LL, -689675886, 2708096, -697461662, 30, 7, batch_qdq_params);
    set_smv_params(4297920086016LL, -90835668, 381312, 0, 27, 7, batch_qdq_params);
    set_dq_sm_params(29807, 0.001417914405465126 / (65535 * 0.00012207217514514923), batch_qdq_params);
    set_q_sm_params(0, 0.000015259021893143654, batch_qdq_params);
    //================ Q4 ===========================================
    batch_qdq_params += 96;

    set_qkt_params(1200338539211392LL, -554330758, 2100608, -546272957, 30, 7, batch_qdq_params);
    set_smv_params(35268123951104LL, -730274769, 2918272, 0, 30, 7, batch_qdq_params);
    set_dq_sm_params(18077, 0.001493261894211173 / (65535 * 0.00012207217514514923), batch_qdq_params);
    set_q_sm_params(0, 0.000015259021893143654, batch_qdq_params);

    //================ Q5 ===========================================
    batch_qdq_params += 96;
    // DQ before SM
    set_qkt_params(604720101780224LL, -287107678, 1146496, -289006562, 29, 7, batch_qdq_params);
    set_smv_params(36408437768192LL, -791173005, 2821760, 0, 30, 7, batch_qdq_params);
    set_dq_sm_params(22046, 0.0012369307223707438 / (65535 * 0.00012207217514514923), batch_qdq_params);
    set_q_sm_params(0, 0.000015259021893143654, batch_qdq_params);

#endif
#endif
}
else if(Stq == 1 and St == 64) // PSMU_ST1
{
qdq_params[(16 * 0) + 0] = 37010;
qdq_params[(16 * 0) + 1] = float_to_bfloat16(0.0020485231652855873 * (65535 * 0.00012207217514514923)).value;
// MuL Q
qdq_params[(16 * 1) + 0] = 37010;
qdq_params[(16 * 1) + 1] = float_to_bfloat16(1.0 / 0.0002560653956606984).value;

set_qkt_params(1399989535580160LL, -628665160, 2462336, -650364496, 30, 7);
set_smv_params(2007957504LL, -30639, 128, 0, 16, 7);
set_dq_sm_params(37010, 0.0020485231652855873 * 0.125);
set_q_sm_params(0, 0.000015259021893143654);
}
    qdq_params[(16 * 4) + 2] = 0; // DQ disabled  
    qdq_params[(16 * 5) + 2] = 1; //  Q  enabled

    assert(qdq_params[(16*3) + 4] == 0); // kernel wrapper does not implment V sum as C3 is 0
    //memcpy((void*)(static_cast<int16_t*>(aie_qdq_prm)), (void*)qdq_params, qdq_param_size);
#if ((!ATTN_MASK_EXIST) or (H_IN != 6) or (SQ_IN!=64))
    for(int h = 0; h < H; h++)
        memcpy((void*)(static_cast<int8_t*>(aie_qdq_prm) + (h*sizeof(qdq_params))), (void*)qdq_params, sizeof(qdq_params));
#endif



    printf("graph init started \n");
#if defined(__AIESIM__) || defined(__TXNRT__)
    #ifdef __TXNRT__
            DmaBins bins = run_dma_layer_config();
            bins.save();
            write_bin_file("ifm.bin", static_cast<char*>(aie_qkv), qry_size*H + (key_size + val_size)*G + mask_size);
            write_bin_file("wgt.bin", static_cast<char*>(aie_qdq_prm), H*sizeof(qdq_params));
            if (Stq == 151)
            write_bin_file("ofm.bin", static_cast<char*>(mdl_out1), out_size1*H);
            else
            write_bin_file("ofm.bin", static_cast<char*>(mdl_out), out_size*H);
#else
        g_compute_graph.init();
        run_dma_layer_config(g_compute_graph, aie_out, aie_qkv, aie_qdq_prm);
        g_compute_graph.end();
#endif // TXN_MODE
#endif // __AIESIM__
    printf("graph run triggered \n");


#ifdef __TXNRT__
    return 0;
#else
    float const relative_error = 1.0;
    if (Stq == 151)
    {float err_cnt = check_result_rmse<RowMajorMatrix<Tqkv>, RowMajorMatrix<Tqkv>>(model_Y1, aie_Y, relative_error);}
    else
    {float err_cnt = check_result_rmse<RowMajorMatrix<Tqkv>, RowMajorMatrix<Tqkv>>(model_Y, aie_Y, relative_error);}
    float err_cnt_TH = 7.0;
    
    bool fail = err_cnt >= err_cnt_TH;
    printf(strcat("%d x %d x %d x %d DI: ",(fail?"FAIL\n":"PASS\n")), B*out_rows*H, key_cols, B*val_rows*H, val_cols);
    
    deallocate(aie_qkv);
    deallocate(aie_out);
    deallocate(aie_qdq_prm);
    free(cpu_out);
    
    assert(false);
    return fail;
#endif
}
