#ifndef __TXNRT__
#include "matrix.hpp"
#define RUN_ON_AIE_ARRAY 1

#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif

#if defined(__AIESIM__) || defined(__TXNRT__)
#include "dma.hpp"
#endif

#include <math.h>
#include "mha_validation.cpp"

// Two-level stringification: TO_STRING(x) lets the preprocessor expand the
// macro argument first, then STRINGIFY turns the expanded text into a string
// literal. main() uses it to turn TEST_BENCH_DIR into the data folder path.
#define STRINGIFY(x) #x
#define TO_STRING(x) STRINGIFY(x)

#ifndef __TXNRT__
// Graph instance that main() drives via init() / run_dma_layer_config() / end().
// It is not defined in __TXNRT__ builds.
ComputeGraph g_compute_graph;
#endif
/// Reads `size` bytes from the binary file `filename` into `data`.
///
/// Failures are reported on stdout and are deliberately non-fatal, so the
/// test bench keeps running when a data file is missing:
///  - if the file cannot be opened, `data` is left untouched;
///  - if the file holds fewer than `size` bytes, the bytes that were read are
///    stored at the start of `data`, the remainder is left untouched, and the
///    shortfall is reported.
///
/// @param filename  Path of the file to read.
/// @param data      Destination buffer; must have room for `size` bytes.
/// @param size      Number of bytes requested.
void read_bin_file(std::string filename, char* data, size_t size)
{
    std::fstream file;
    file.open(filename, std::ios::in | std::ios::binary);
    // Check if the file was opened successfully
    if (!file.is_open())
    {
        std::cout << "Error: Unable to open file " << filename << std::endl;
        // Nothing can be read from a stream that failed to open.
        return;
    }

    file.read(data, static_cast<std::streamsize>(size));

    // gcount() is the number of bytes the read() above actually delivered.
    const std::streamsize bytes_read = file.gcount();
    if (static_cast<size_t>(bytes_read) != size)
    {
        std::cout << "Error: Read only " << bytes_read << " of " << size
                  << " bytes from file " << filename << std::endl;
    }
}

/// Writes `size` bytes from `data` to the binary file `filename`, replacing
/// any previous content of that file.
///
/// Failures (file cannot be opened, or the write does not complete) are
/// reported on stdout and are non-fatal, matching read_bin_file().
///
/// @param filename  Path of the file to create or overwrite.
/// @param data      Source buffer holding at least `size` bytes.
/// @param size      Number of bytes to write.
void write_bin_file(std::string filename, char* data, size_t size)
{
    std::fstream file;
    file.open(filename, std::ios::out | std::ios::binary);
    if (!file.is_open())
    {
        std::cout << "Error: Unable to open file " << filename << " for writing" << std::endl;
        return;
    }

    file.write(data, static_cast<std::streamsize>(size));
    // Flush explicitly so a failure of the underlying write is visible here
    // rather than being swallowed by the destructor.
    file.flush();
    if (!file)
    {
        std::cout << "Error: Failed to write " << size << " bytes to file " << filename << std::endl;
    }
}

/// Test-bench entry point.
///
/// Steps:
///  1. Derive buffer sizes from the build-time macros H_IN, SQ_IN, SK_IN,
///     SQ_IN_SUBV and DH_IN.
///  2. Allocate the input (aie_qkv), output (aie_out) and parameter
///     (aie_qdq_prm) buffers plus host-side reference buffers.
///  3. Load Q, K, the vector input and the reference output from
///     TEST_BENCH_DIR, and fill the per-head parameter block.
///  4. __TXNRT__ builds: generate the DMA bins and dump ifm/wgt/ofm .bin files.
///     __AIESIM__ builds: run the graph.
///  5. Non-__TXNRT__ builds: compare the AIE output against the reference and
///     print PASS/FAIL.
///
/// @return 0 in __TXNRT__ builds; otherwise the fail flag (see the review note
///         at the end of the function); 1 if a buffer allocation fails.
int main(void)
{
    bool read_io_data = true;

    int const H  = H_IN;
    int const AieRows = 4;
    int const AieCols = 4;

    std::string folder = TO_STRING(TEST_BENCH_DIR);
    int Stq = SQ_IN;
    int const St = SK_IN;

    int const St_pad = (((St - 1) / ( 8 )) + 1) * ( 8 );  // St rounded up to a multiple of 8
    int const St_128 = (((St - 1) / (128)) + 1) * (128);  // 128 as 128 cols == 16 cores * 8 (cols/core)
    int const Skv = St_128 / (AieRows * AieCols);
    int const Sq  = SQ_IN_SUBV;  //std::min(16, std::max(16, Stq / (AieCols * AieRows)));
    int const Dh  = DH_IN;       //sub-volume K per head
    int const Dt  = Dh * H;      //total volume K across all heads

    int const qry_rows = Stq;
    int const key_rows = St;
    int const out_rows = Stq;

    int const out_cols = St;

    int const qry_cols = Dh;
    int const key_cols = Dh;
    using Tqkv = uint16_t;

    // Per-head byte sizes of the four regions packed into aie_qkv.
    int const qry_size  = qry_rows * qry_cols * sizeof(Tqkv);
    int const key_size  = key_rows * key_cols * sizeof(Tqkv);
    int const vec_size  =    1     * St_pad   * sizeof(Tqkv);
    int const mat_size  = qry_rows * key_rows * sizeof(Tqkv);

    int const qkv_size  = qry_size + key_size + vec_size + mat_size;

    int const qdq_param_size = sizeof(qdq_params) * H;

    int const out_size      = out_rows * out_cols * sizeof(Tqkv);
    int const out_pad_size  = out_rows *  St_pad  * sizeof(Tqkv);

    // max to ensure all cores process at least 16 rows required for SM
    int const model_bytes     = std::max(out_size*H,     AieCols*AieRows*16*Dh);
    int const model_pad_bytes = std::max(out_pad_size*H, AieCols*AieRows*16*Dh);

    printf("size of qdq_param_size= %d Bytes\n", qdq_param_size);

#ifdef __TXNRT__
    void* aie_qkv = malloc(qkv_size*H);
    // Sized with the padded width so the buffer covers the St_pad-wide aie_Y
    // view created below (same size as the GMIO allocation in the other branch).
    void* aie_out = malloc(out_pad_size*H);
    void* aie_qdq_prm = malloc(qdq_param_size);
#else
    void* aie_qkv = adf::GMIO::malloc(qkv_size*H);
    void* aie_out = adf::GMIO::malloc(out_pad_size*H);
    void* aie_qdq_prm = adf::GMIO::malloc(qdq_param_size);
#endif
    void* cpu_out       = malloc(out_size*H);
    void* cpu_pad_out   = malloc(out_pad_size*H);
    void* model_out     = malloc(model_bytes);
    void* model_pad_out = malloc(model_pad_bytes);

    if (!aie_qkv || !aie_out || !aie_qdq_prm ||
        !cpu_out || !cpu_pad_out || !model_out || !model_pad_out) {
        printf("Error: buffer allocation failed\n");
        return 1;
    }

    // Zero the buffers that are only partially filled below, so that the bytes
    // dumped to ifm.bin / ofm.bin and compared against the AIE output are
    // deterministic:
    //  - aie_qkv: only vec slot 0 of the H vec slots is loaded from file;
    //  - model_pad_out: columns [out_cols, St_pad) are never written;
    //  - model_out: stays unfilled if the reference file is missing or short.
    memset(aie_qkv, 0, qkv_size*H);
    memset(model_out, 0, model_bytes);
    memset(model_pad_out, 0, model_pad_bytes);

    RowMajorMatrix<Tqkv> aie_Y(      out_rows*H,  St_pad , aie_out);
    RowMajorMatrix<Tqkv> cpu_Y(      out_rows*H, out_cols, cpu_out);
    RowMajorMatrix<Tqkv> model_Y(    out_rows*H, out_cols, model_out);
    RowMajorMatrix<Tqkv> model_pad_Y(out_rows*H,  St_pad , model_pad_out);

    /*  Read IO dataset from model  */
    // aie_qkv layout (byte offsets):
    //   [0, H*qry_size)                   Q, one copy per head
    //   [H*qry_size, +H*key_size)         K, one copy per head
    //   [.., +H*vec_size)                 vec slots
    //   [.., +H*mat_size)                 matrix, one per head (zero-filled)
    // The same Q / K / reference files are loaded for every head.
    printf("Reading IO dataset \n");
    char* const qkv_bytes = (char*)aie_qkv;
    char* const key_base  = qkv_bytes + H*qry_size;
    char* const vec_base  = key_base  + H*key_size;
    char* const mat_base  = vec_base  + H*vec_size;

    for (int ih = 0; ih < H; ih++) {
        read_bin_file(folder+"q_mat_uint16.bin", qkv_bytes + ih*qry_size, qry_size);
    }
    for (int ih = 0; ih < H; ih++) {
        read_bin_file(folder+"k_mat_uint16.bin", key_base + ih*key_size, key_size);
    }
    // NOTE(review): only vec slot 0 is loaded; slots 1..H-1 keep the zero fill
    // from above. Confirm whether the vector is meant to be shared by all heads.
    read_bin_file(folder+"b_mat_uint16.bin", vec_base + 0*vec_size, vec_size);

    for (int ih = 0; ih < H; ih++) {
        uint16_t* ptr_matrix = (uint16_t*)(mat_base + ih*mat_size);
        for(int i = 0; i < qry_rows; i++)
            for(int j = 0; j < key_rows; j++)
                ptr_matrix[i*key_rows+j] = 0;// i*key_rows+j;
    }

    for (int ih = 0; ih < H; ih++) {
        read_bin_file(folder+"SM_out_uint16.bin", (char*)model_out + ih*out_size, out_size);
    }

    // Copy the unpadded reference into the St_pad-wide layout used by aie_Y.
    for(int r = 0; r < model_pad_Y.num_rows; r++)
        for(int c = 0; c < model_Y.num_cols; c++)
            model_pad_Y.at(r, c) = model_Y.at(r,c);

    // Reset Stq to the value supported by model data, for populating the qdq data structure
    Stq = SQ_IN;
    if(SQ_IN <= 256)
        Stq = (SQ_IN==151)? 151 : 256;
    else if(SQ_IN <= 1024)
        Stq = 1024;
    else if(SQ_IN <= 4096)
        Stq = 4096;
    // For SQ_IN <= 4096, Stq is now one of the PSR / PSO3 shapes 151/256/1024/4096.
    // Larger SQ_IN values are passed through unchanged.
    populate_qdq_params(Stq, Sq, Skv);


    assert(qdq_params[(16*3) + 4] == 0); // kernel wrapper does not implement V sum as C3 is 0

    // Replicate the same parameter block once per head.
    for(int h = 0; h < H; h++)
        memcpy((void*)(static_cast<int8_t*>(aie_qdq_prm) + (h*sizeof(qdq_params))), (void*)qdq_params, sizeof(qdq_params));

#if defined(__AIESIM__) || defined(__TXNRT__)
    #ifdef __TXNRT__
        DmaBins bins = run_dma_layer_config();
        bins.save();
        write_bin_file("ifm.bin", static_cast<char*>(aie_qkv), qkv_size*H);
        write_bin_file("wgt.bin", static_cast<char*>(aie_qdq_prm), qdq_param_size);
        write_bin_file("ofm.bin", static_cast<char*>(model_pad_out), out_pad_size*H);
    #else
        g_compute_graph.init();
        run_dma_layer_config(g_compute_graph, aie_out, aie_qkv, aie_qdq_prm);
        g_compute_graph.end();
    #endif // TXN_MODE
#endif // __AIESIM__

    printf("graph run triggered \n"); //print_matrix(model_Y, "Model Y = ");

#ifndef __TXNRT__
    float const max_pct_diff = 1.0;
    int err_cnt = 0;
    if (!read_io_data)
        err_cnt = check_result_rmse<RowMajorMatrix<Tqkv>, RowMajorMatrix<Tqkv>>(cpu_Y, aie_Y, max_pct_diff);
    else
        err_cnt = check_result_rmse<RowMajorMatrix<Tqkv>, RowMajorMatrix<Tqkv>>(model_pad_Y, aie_Y, max_pct_diff);

    // The run fails once the mismatch count reaches 12% of the unpadded output elements.
    int err_cnt_TH = int(H*out_rows * out_cols * 0.12);
    bool fail = err_cnt >= err_cnt_TH;
    printf("%d x %d x %d x %d DI: %s\n", H*qry_rows, qry_cols, H*key_rows, key_cols, fail ? "FAIL" : "PASS");
#endif

#ifdef __TXNRT__
    free(aie_qkv);
    free(aie_out);
    free(aie_qdq_prm);
#else
    adf::GMIO::free(aie_qkv);
    adf::GMIO::free(aie_out);
    adf::GMIO::free(aie_qdq_prm);
#endif
    free(cpu_out);
    free(cpu_pad_out);
    free(model_out);
    free(model_pad_out);


#ifdef __TXNRT__
    return 0;
#else
    // NOTE(review): unless NDEBUG is defined this aborts unconditionally, so the
    // return below is never reached. Presumably intentional to force the run to
    // terminate - confirm before removing.
    assert(false);
    return fail;
#endif
}
