#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <assert.h>
#include <stdlib.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <limits>
#include <vector>

// QDQ parameter table shared by the helpers in this header.
// It holds num_qdq_nodes (6) nodes of num_qdq_prm_per_node (16) 32-bit words
// each, i.e. 64 bytes per node, and is zero-initialised.
// NOTE(review): qdq_params is a non-inline definition with external linkage
// living in a header; including this header from more than one translation
// unit would presumably yield duplicate-symbol link errors -- confirm that it
// is only ever included once per program.
int const num_qdq_nodes = 6;
int const num_qdq_prm_per_node = 16;
uint32_t qdq_params[num_qdq_nodes*num_qdq_prm_per_node]= {0};




/// Plain 16-bit storage wrapper for a bfloat16 value.
/// `value` is the raw bit pattern, i.e. the upper half of the corresponding
/// 32-bit float encoding (see float_to_bfloat16 / bfloat16_to_float).
struct bfloat16_t {
    uint16_t value;  ///< raw bfloat16 bits
};

/// Returns the object representation of `f` as a 32-bit unsigned integer.
/// The four bytes are copied one by one through char pointers, which is the
/// aliasing-safe way to reinterpret the value.
inline uint32_t float_to_uint(float f)
{
    uint32_t bits = 0;
    char const* const src = reinterpret_cast<char const*>(&f);
    char* const dst = reinterpret_cast<char*>(&bits);
    for (int byte = 0; byte < 4; ++byte) {
        dst[byte] = src[byte];
    }
    return bits;
}

/// Builds a float whose object representation equals the 32 bits of `i`.
/// Inverse of float_to_uint; bytes are copied individually through char
/// pointers so no aliasing rule is violated.
inline float uint_to_float(uint32_t i)
{
    float result = 0;
    char const* const src = reinterpret_cast<char const*>(&i);
    char* const dst = reinterpret_cast<char*>(&result);
    for (int byte = 0; byte < 4; ++byte) {
        dst[byte] = src[byte];
    }
    return result;
}

/// Converts a 32-bit float to bfloat16 using round-to-nearest-even.
///
/// @param fp  value to convert
/// @return    the bfloat16 whose bits are the rounded upper 16 bits of `fp`
///
/// NaN inputs are handled separately: adding the rounding bias to a NaN bit
/// pattern can carry out of the mantissa and yield an infinity or a zero, so
/// a NaN is returned as a quiet NaN that keeps the sign and the upper payload
/// bits of the input.
inline bfloat16_t float_to_bfloat16(float fp)
{
    uint32_t const bits = float_to_uint(fp);

    // Exponent all ones with a non-zero mantissa means NaN.
    if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
        // Force the quiet bit so the truncated mantissa can never become zero.
        return bfloat16_t{uint16_t((bits >> 16) | 0x0040u)};
    }

    // Round to nearest, ties to even: bias is 0x7FFF, plus one when the bit
    // that will become the result's LSB is already set.
    uint32_t const lsb = (bits >> 16) & 0x1;
    uint32_t const bias = 0x7FFF + lsb;
    uint32_t const rnd = bits + bias;
    return bfloat16_t{uint16_t(rnd >> 16)};
}

/// Expands a bfloat16 to a 32-bit float.
/// The 16 stored bits become the upper half of the float encoding and the
/// lower 16 bits are zero, so the conversion is exact.
inline float bfloat16_to_float(bfloat16_t bf)
{
    uint32_t const widened = static_cast<uint32_t>(bf.value) << 16;
    return uint_to_float(widened);
}

/// Linear offset of element (row, col) in a row-major num_rows x num_cols
/// matrix. Only the upper bounds of row and col are asserted.
inline int row_major_index(int row, int col, int num_rows, int num_cols)
{
    assert(row < num_rows);
    assert(col < num_cols);
    int const row_start = row * num_cols;
    return row_start + col;
}

/// Linear offset of element (row, col) in a column-major num_rows x num_cols
/// matrix. Only the upper bounds of row and col are asserted.
inline int col_major_index(int row, int col, int num_rows, int num_cols)
{
    assert(row < num_rows);
    assert(col < num_cols);
    int const col_start = col * num_rows;
    return col_start + row;
}

/// Linear offset of element (row, col) in the "w8" layout: columns are split
/// into groups of 8; each group stores all num_rows rows contiguously with 8
/// columns per row, and groups follow each other.
/// Only the upper bounds of row and col are asserted.
inline int w8_index(int row, int col, int num_rows, int num_cols)
{
    assert(row < num_rows);
    assert(col < num_cols);
    constexpr int group_width = 8;
    int const in_group = (row * group_width) + (col % group_width);
    int const group_start = (col / group_width) * (num_rows * group_width);
    return in_group + group_start;
}

void populate_qdq_params(const uint32_t Stq, const uint32_t Sq, const uint32_t Skv){
    // supported shapes currently
    // MuL DQ - dummy, not used in kernel
    qdq_params[(16*0) + 0] = 0;
    qdq_params[(16*0) + 1] = float_to_bfloat16(1.0).value;
    // MuL Q - dummy, not used in kernel
    qdq_params[(16*1) + 0] = 0;
    qdq_params[(16*1) + 1] = float_to_bfloat16(1.0).value;
    
    if(Stq == 151) {
        // QKT
        *(int64_t*)(&qdq_params[(16*2) + 0]) = 460673320082688; // c0
        qdq_params[(16*2) + 2] = -94243407; // c1
        qdq_params[(16*2) + 3] = 943872;  //c2
        qdq_params[(16*2) + 4] = -139833162;  //c3
        qdq_params[(16*2) + 5] = Sq;       // M
        qdq_params[(16*2) + 6] = Skv;   // N
        qdq_params[(16*2) + 7] = 1;  //SQb
        qdq_params[(16*2) + 8] = 27; //Sout
        qdq_params[(16*2) + 9] = 8;  //Stdm


        // SM *V
        *(int64_t*)(&qdq_params[(16*3) + 0]) = 8824815616000;
        qdq_params[(16*3) + 2] = -134639853;
        qdq_params[(16*3) + 3] = 2098688;
        qdq_params[(16*3) + 4] = 0;
        qdq_params[(16*3) + 5] = Sq;
        qdq_params[(16*3) + 6] = Skv;
        qdq_params[(16*3) + 7] = 0;
        qdq_params[(16*3) + 8] = 28;
        qdq_params[(16*3) + 9] = 9;

        // DQ before SM
        qdq_params[(16*4) + 0] = 23589; //23589
        qdq_params[(16*4) + 1] = float_to_bfloat16(0.00034349862835370004 * 1.442695041).value; //0.00034349862835370004
        qdq_params[(16*4) + 2] = 0;   // Disable DQ node in softmax wrapper
        // Q after SM
        qdq_params[(16*5) + 0] = 0;
        qdq_params[(16*5) + 1] = float_to_bfloat16(1.0 / 0.000015259021893143654).value;  //0.000015259021893143654
        qdq_params[(16*4) + 2] = 1;   // Eable    Q node in softmax wrapper
    }
    else if(Stq == 256) {
        // QKT
        *(int64_t*)(&qdq_params[(16*2) + 0]) = 681330421901376; // c0
        qdq_params[(16*2) + 2] = -322147399; // c1
        qdq_params[(16*2) + 3] = 1250944;  //c2
        qdq_params[(16*2) + 4] = -304829643;  //c3
        qdq_params[(16*2) + 5] = Sq;       // M
        qdq_params[(16*2) + 6] = Skv;   // N
        qdq_params[(16*2) + 7] = 0;  //SQb
        qdq_params[(16*2) + 8] = 30; //Sout
        qdq_params[(16*2) + 9] = 7;  //Stdm


        // SM *V
        *(int64_t*)(&qdq_params[(16*3) + 0]) = 8824815616000;
        qdq_params[(16*3) + 2] = -134639853;
        qdq_params[(16*3) + 3] = 2098688;
        qdq_params[(16*3) + 4] = 0;
        qdq_params[(16*3) + 5] = Sq;
        qdq_params[(16*3) + 6] = Skv;
        qdq_params[(16*3) + 7] = 0;
        qdq_params[(16*3) + 8] = 28;
        qdq_params[(16*3) + 9] = 9;

        // DQ before SM
        qdq_params[(16*4) + 0] = 35625;
        qdq_params[(16*4) + 1] = float_to_bfloat16(0.0003217620251234621 * 1.442695041).value;

        // Q after SM
        qdq_params[(16*5) + 0] = 0;
        qdq_params[(16*5) + 1] = float_to_bfloat16(1.0 / 0.000009475293).value;
    }
    else if( Stq == 1024){
        // QKT
        *(int64_t*)(&qdq_params[(16*2) + 0]) = 1821801205459072; // c0
        qdq_params[(16*2) + 2] = -830276535; // c1
        qdq_params[(16*2) + 3] = 3198080;  //c2
        qdq_params[(16*2) + 4] = -827852990;  //c3
        qdq_params[(16*2) + 5] = Sq;       // M
        qdq_params[(16*2) + 6] = Skv;   // N
        qdq_params[(16*2) + 7] = 0;  //SQb
        qdq_params[(16*2) + 8] = 31; //Sout
        qdq_params[(16*2) + 9] = 7;  //Stdm


        // SM *V
        *(int64_t*)(&qdq_params[(16*3) + 0]) = 139439636480;
        qdq_params[(16*3) + 2] = -2616954;
        qdq_params[(16*3) + 3] = 161792;
        qdq_params[(16*3) + 4] = 0;
        qdq_params[(16*3) + 5] = Sq;
        qdq_params[(16*3) + 6] = Skv;
        qdq_params[(16*3) + 7] = 0;
        qdq_params[(16*3) + 8] = 22;
        qdq_params[(16*3) + 9] = 11;

        // DQ before SM
        qdq_params[(16*4) + 0] = 28469;
        qdq_params[(16*4) + 1] = float_to_bfloat16(0.0003867986670229584 * 1.442695041).value;

        // Q after SM
        qdq_params[(16*5) + 0] = 0;
        qdq_params[(16*5) + 1] = float_to_bfloat16(1.0 / 0.000012578514).value;
    }
    else if(Stq == 4096){
        // QKT
        *(int64_t*)(&qdq_params[(16*2) + 0]) = 207973141910528; // c0
        qdq_params[(16*2) + 2] = -94930576; // c1
        qdq_params[(16*2) + 3] = 381056;  //c2
        qdq_params[(16*2) + 4] = -98303517;  //c3
        qdq_params[(16*2) + 5] = Sq;       // M
        qdq_params[(16*2) + 6] = Skv;   // N
        qdq_params[(16*2) + 7] = 0;  //SQb
        qdq_params[(16*2) + 8] = 28; //Sout
        qdq_params[(16*2) + 9] = 7;  //Stdm


        // SM *V
        *(int64_t*)(&qdq_params[(16*3) + 0]) = 38334730600448;
        qdq_params[(16*3) + 2] = -456481683;
        qdq_params[(16*3) + 3] = 110534656;
        qdq_params[(16*3) + 4] = 0;
        qdq_params[(16*3) + 5] = Sq;
        qdq_params[(16*3) + 6] = Skv;
        qdq_params[(16*3) + 7] = 0;
        qdq_params[(16*3) + 8] = 30;
        qdq_params[(16*3) + 9] = 13;

        // DQ before SM
        qdq_params[(16*4) + 0] = 27389;
        qdq_params[(16*4) + 1] = float_to_bfloat16(0.00043027219362556934 * 1.442695041).value;

        // Q after SM
        qdq_params[(16*5) + 0] = 0;
        qdq_params[(16*5) + 1] = float_to_bfloat16(1.0 / 0.0000075458247).value;
    }
    else{
        std::cerr<<"Unsupported Seq len";
    }
    qdq_params[(16*4) + 2] = 0; // DQ disabled  
    qdq_params[(16*5) + 2] = 1; //  Q  enabled  
}


/// Non-owning row-major view over a caller-provided buffer of
/// num_rows * num_cols elements of type T.
template<typename T>
struct RowMajorMatrix
{
    int const num_rows;
    int const num_cols;
    T* const data;  // not owned; never freed here

    RowMajorMatrix(int num_rows, int num_cols, void* data)
        : num_rows{num_rows}
        , num_cols{num_cols}
        , data{static_cast<T*>(data)}
    {}

    /// Reference to element (row, col); upper bounds are asserted.
    T& at(int row, int col)
    {
        assert(row < num_rows);
        assert(col < num_cols);
        int const offset = row_major_index(row, col, num_rows, num_cols);
        assert(offset < num_rows * num_cols);
        return *(data + offset);
    }

    /// Buffer size in bytes needed for a num_rows x num_cols matrix.
    static int size(int num_rows, int num_cols)
    {
        return static_cast<int>(num_rows * num_cols * sizeof(T));
    }
};

/// Non-owning view over a (num_heads, num_rows, num_cols) tensor stored head
/// by head. Within a head the data is tiled into subv_rows x subv_cols
/// sub-volumes; tiles are ordered column-major and the elements inside a tile
/// follow the w8 layout.
template<typename T, int subv_rows, int subv_cols>
struct ScaleTensor
{
    int const num_heads;
    int const num_rows;
    int const num_cols;
    T* const data;  // not owned

    ScaleTensor(int num_heads, int num_rows, int num_cols, void* data)
        : num_heads{num_heads}
        , num_rows{num_rows}
        , num_cols{num_cols}
        , data{static_cast<T*>(data)}
    {}

    /// Reference to element (head, row, col); upper bounds are asserted.
    T& at(int head, int row, int col)
    {
        assert(head < num_heads);
        assert(row < num_rows);
        assert(col < num_cols);

        constexpr int tile_elems = subv_rows * subv_cols;
        int const elems_per_head = num_rows * num_cols;

        // Offset of the element inside its tile.
        int const within_tile = w8_index(row % subv_rows, col % subv_cols, subv_rows, subv_cols);
        // Position of the tile among the tiles of one head.
        int const tile_no = col_major_index(row / subv_rows,
                                            col / subv_cols,
                                            (num_rows / subv_rows),
                                            (num_cols / subv_cols));

        int const idx = within_tile + (tile_no * tile_elems) + (head * elems_per_head);
        assert(idx < num_heads * elems_per_head);
        return data[idx];
    }

    /// Buffer size in bytes for the given tensor shape.
    static int size(int num_heads, int num_rows, int num_cols)
    {
        return static_cast<int>(num_heads * num_rows * num_cols * sizeof(T));
    }
};

/// Non-owning byte view over the per-core MHA parameter buffer.
///
/// The buffer holds num_cores consecutive records of core_elems bytes. Each
/// record is laid out as
///   [mask_cols mask bytes][gprb_rows*gprb_cols matrix bytes (w8 layout)]
///   [gprb_cols vector bytes][num_scalars scalar bytes]
/// where the first three scalar bytes are a, b and c.
struct MhaParams
{
    static int const attn_dim = 512;
    static int const num_cores = 8;
    static int const mask_cols = attn_dim / num_cores;
    static int const gprb_rows = 96;
    static int const gprb_cols = 8;
    // Only slots 0..2 (a, b, c) are accessed below; the remaining slots are
    // presumably padding for memory alignment -- TODO confirm.
    static int const num_scalars = 56;
    static int const core_elems = mask_cols + (gprb_rows * gprb_cols) + gprb_cols + num_scalars;

    uint8_t* const data;  // not owned

    MhaParams(void* data) : data(static_cast<uint8_t*>(data))
    {}

    /// Mask byte for attention column `col`; columns are split evenly across
    /// the cores, mask_cols per core.
    uint8_t& mask(int col)
    {
        assert(col < attn_dim);
        return elem(col / mask_cols, col % mask_cols);
    }

    /// Element (row, col) of the GPRB matrix stored for `core`.
    uint8_t& gprb_mat(int core, int row, int col)
    {
        assert(core < num_cores);
        assert(row < gprb_rows);
        assert(col < gprb_cols);
        return elem(core, mask_cols + w8_index(row, col, gprb_rows, gprb_cols));
    }

    /// Element `col` of the GPRB vector stored for `core`.
    uint8_t& gprb_vec(int core, int col)
    {
        assert(core < num_cores);
        assert(col < gprb_cols);
        return elem(core, mask_cols + (gprb_rows * gprb_cols) + col);
    }

    /// Scalar a of `core` (scalar slot 0).
    uint8_t& gprb_a(int core)
    {
        assert(core < num_cores);
        return scalar(core, 0);
    }

    /// Scalar b of `core` (scalar slot 1).
    uint8_t& gprb_b(int core)
    {
        assert(core < num_cores);
        return scalar(core, 1);
    }

    /// Scalar c of `core` (scalar slot 2).
    uint8_t& gprb_c(int core)
    {
        assert(core < num_cores);
        return scalar(core, 2);
    }

    /// Total buffer size in bytes.
    static int size()
    {
        return static_cast<int>(num_cores * core_elems * sizeof(uint8_t));
    }

private:
    /// Byte at `offset` inside the record of `core`, with the overall upper
    /// bound asserted.
    uint8_t& elem(int core, int offset)
    {
        int const idx = offset + (core * core_elems);
        assert(idx < num_cores * core_elems);
        return data[idx];
    }

    /// Scalar slot `slot` of `core`; scalars follow the mask, matrix and
    /// vector sections.
    uint8_t& scalar(int core, int slot)
    {
        return elem(core, mask_cols + (gprb_rows * gprb_cols) + gprb_cols + slot);
    }
};

/// Non-owning view over one buffer that stores a row-major key matrix
/// immediately followed by a row-major value matrix.
/// The constructor asserts that each dimension is a multiple of the matching
/// sub-volume template parameter.
template<typename T, int key_subv_rows, int key_subv_cols, int val_subv_rows, int val_subv_cols>
struct ActKVMatrix
{
    int const key_rows;
    int const key_cols;
    int const val_rows;
    int const val_cols;
    T* const data;  // not owned

    ActKVMatrix(int key_rows, int key_cols, int val_rows, int val_cols, void* data)
        : key_rows{key_rows}
        , key_cols{key_cols}
        , val_rows{val_rows}
        , val_cols{val_cols}
        , data{static_cast<T*>(data)}
    {
        assert(key_rows % key_subv_rows == 0);
        assert(key_cols % key_subv_cols == 0);
        assert(val_rows % val_subv_rows == 0);
        assert(val_cols % val_subv_cols == 0);
    }

    /// Reference to key element (row, col).
    T& atK(int row, int col)
    {
        int const offset = row_major_index(row, col, key_rows, key_cols);
        assert(offset < (key_rows * key_cols) + (val_rows * val_cols));
        return data[offset];
    }

    /// Reference to value element (row, col); values start right after the
    /// key_rows * key_cols key elements.
    T& atV(int row, int col)
    {
        int const key_elems = key_rows * key_cols;
        int const offset = key_elems + row_major_index(row, col, val_rows, val_cols);
        assert(offset < (key_rows * key_cols) + (val_rows * val_cols));
        return data[offset];
    }

    /// Combined buffer size in bytes for the key and value matrices.
    static int size(int key_rows, int key_cols, int val_rows, int val_cols)
    {
        return static_cast<int>((key_rows * key_cols * sizeof(T)) + (val_rows * val_cols * sizeof(T)));
    }
};

/// Non-owning row-major view over the query activation buffer.
/// The constructor asserts that both dimensions are multiples of the
/// sub-volume template parameters.
template<typename T, int subv_rows, int subv_cols>
struct ActQMatrix
{
    int const num_rows;
    int const num_cols;
    T* const data;  // not owned

    ActQMatrix(int num_rows, int num_cols, void* data)
        : num_rows{num_rows}
        , num_cols{num_cols}
        , data{static_cast<T*>(data)}
    {
        assert(num_rows % subv_rows == 0);
        assert(num_cols % subv_cols == 0);
    }

    /// Reference to element (row, col).
    T& at(int row, int col)
    {
        int const offset = row_major_index(row, col, num_rows, num_cols);
        assert(offset < num_rows * num_cols);
        return data[offset];
    }

    /// Reference to the element at linear position `idx` (no bounds check).
    T& at(int idx)
    {
        return *(data + idx);
    }

    /// Buffer size in bytes for a num_rows x num_cols matrix.
    static int size(int num_rows, int num_cols)
    {
        return static_cast<int>(num_rows * num_cols * sizeof(T));
    }
};

/// Non-owning row-major view over the output buffer.
/// The constructor asserts that both dimensions are multiples of the
/// sub-volume template parameters.
template<typename T, int subv_rows, int subv_cols>
struct OutMatrix
{
    int const num_rows;
    int const num_cols;
    T* const data;  // not owned

    OutMatrix(int num_rows, int num_cols, void* data)
        : num_rows{num_rows}
        , num_cols{num_cols}
        , data{static_cast<T*>(data)}
    {
        assert(num_rows % subv_rows == 0);
        assert(num_cols % subv_cols == 0);
    }

    /// Reference to element (row, col).
    T& at(int row, int col)
    {
        int const offset = row_major_index(row, col, num_rows, num_cols);
        assert(offset < num_rows * num_cols);
        return data[offset];
    }

    /// Buffer size in bytes for a num_rows x num_cols matrix.
    static int size(int num_rows, int num_cols)
    {
        return static_cast<int>(num_rows * num_cols * sizeof(T));
    }
};

/// Fills every element of `mat` (visited row by row) with a rand()-based
/// integer in [min, max). `mat` must expose num_rows, num_cols and at(r, c).
template<typename T>
void init_random(T mat, int min = -128, int max = 128)
{
    int const span = max - min;
    for (int r = 0; r < mat.num_rows; ++r) {
        for (int c = 0; c < mat.num_cols; ++c) {
            int const draw = rand() % span;
            mat.at(r, c) = draw + min;
        }
    }
}

/// Fills every element of `mat` (visited row by row) with the bfloat16 bit
/// pattern of a rand()-based integer in [min, max).
template<typename T>
void init_random_bfloat16(T mat, int min = -128, int max = 128)
{
    int const span = max - min;
    for (int r = 0; r < mat.num_rows; ++r) {
        for (int c = 0; c < mat.num_cols; ++c) {
            float const draw = float((rand() % span) + min);
            mat.at(r, c) = float_to_bfloat16(draw).value;
        }
    }
}

/// Fills the key part and then the value part of `mat` with rand()-based
/// integers in [min, max), each part visited row by row.
template<typename T>
void init_random_KV(T mat, int min = -128, int max = 128)
{
    int const span = max - min;

    for (int r = 0; r < mat.key_rows; ++r) {
        for (int c = 0; c < mat.key_cols; ++c) {
            int const draw = rand() % span;
            mat.atK(r, c) = draw + min;
        }
    }

    for (int r = 0; r < mat.val_rows; ++r) {
        for (int c = 0; c < mat.val_cols; ++c) {
            int const draw = rand() % span;
            mat.atV(r, c) = draw + min;
        }
    }
}

/// Fills every element of a (heads, rows, cols) tensor with a rand()-based
/// integer in [min, max); heads are the outermost loop, columns the innermost.
template<typename T>
void init_random_scale_tensor(T mat, int min = -128, int max = 128)
{
    int const span = max - min;
    for (int head = 0; head < mat.num_heads; ++head) {
        for (int r = 0; r < mat.num_rows; ++r) {
            for (int c = 0; c < mat.num_cols; ++c) {
                int const draw = rand() % span;
                mat.at(head, r, c) = draw + min;
            }
        }
    }
}

/// Fills an MhaParams buffer with test data.
///
/// - every mask byte gets its own rand()-based value in [min, max);
/// - each GPRB matrix element and each GPRB vector element gets one random
///   value that is replicated to all cores;
/// - the scalars a, b, c of every core are set to 8, 1 and 2.
///
/// Values are narrowed to uint8_t on store. Declared `inline` because the
/// definition lives in a header and would otherwise break linking as soon as
/// the header is included from more than one translation unit.
///
/// @param prm  view over the parameter buffer to fill
/// @param min  inclusive lower bound of the random values
/// @param max  exclusive upper bound of the random values (must differ from min)
inline void init_random_mha_params(MhaParams prm, int min, int max)
{
    int const span = max - min;

    for (int col = 0; col < MhaParams::attn_dim; ++col) {
        prm.mask(col) = (rand() % span) + min;
    }

    for (int row = 0; row < MhaParams::gprb_rows; ++row) {
        for (int col = 0; col < MhaParams::gprb_cols; ++col) {
            // One draw per element, shared by all cores.
            int16_t const val = (rand() % span) + min;
            for (int core = 0; core < MhaParams::num_cores; ++core) {
                prm.gprb_mat(core, row, col) = val;
            }
        }
    }

    for (int col = 0; col < MhaParams::gprb_cols; ++col) {
        int16_t const val = (rand() % span) + min;
        for (int core = 0; core < MhaParams::num_cores; ++core) {
            prm.gprb_vec(core, col) = val;
        }
    }

    for (int core = 0; core < MhaParams::num_cores; ++core) {
        prm.gprb_a(core) = 8;
        prm.gprb_b(core) = 1;
        prm.gprb_c(core) = 2;
    }
}

/// Parameter bundle consumed by qdq_asym_golden.
/// T is the quantised element type (uint8_t unless stated otherwise).
template<typename T=uint8_t>
struct gemm_qdq_param
{
    // --- q params ---
    uint32_t C0;  // constant added to the weight-sum term before the sqb shift
    uint32_t C1;  // multiplier of the per-row input sum
    uint32_t C2;  // multiplier of the accumulator value
    uint32_t C3;  // multiplier of the weight sum
    T sqb;        // right shift applied to (wgtsum * C3 + C0)
    T sout;       // final right shift applied before saturation

    // --- dq params ---
    T zero_point;
    float scale;
};

/// Clamps `val` to the closed range [min, max].
/// NOTE: std::max / std::min need both arguments to have the same type, so
/// this template only instantiates when T is uint32_t.
template<typename T>
T saturate(uint32_t val, T min, T max) {
    auto const raised = std::max(val, min);  // apply the lower bound first
    return std::min(raised, max);            // then the upper bound
}

/// Right-shifts `x` by `shift` bits and saturates the result to the range of T.
///
/// @param x      unsigned 32-bit input
/// @param shift  number of bits to shift right; must be non-negative, and a
///               shift of 32 or more yields 0 instead of undefined behaviour
/// @return       min(x >> shift, numeric_limits<T>::max()) converted to T
///
/// Because `x` is unsigned the shifted value is never negative, so for a
/// signed T the lower clamp is 0. (Converting a negative lower limit to
/// uint32_t, as was done before, wrapped it to a huge value and made every
/// input saturate to the maximum.)
/// NOTE(review): despite the name no rounding takes place -- the shift
/// truncates, and the former std::round call on the already-integral value
/// was a no-op. Confirm whether round-half-up is expected by the kernel.
template<typename T>
T srs_to_T(uint32_t x, int shift) {
    assert(shift >= 0);
    uint32_t const shifted = (shift >= 32) ? 0u : (x >> shift);

    uint32_t const upper = static_cast<uint32_t>(std::numeric_limits<T>::max());
    T const lowest = std::numeric_limits<T>::min();
    uint32_t const lower = (lowest < T(0)) ? 0u : static_cast<uint32_t>(lowest);

    return static_cast<T>(std::min(std::max(shifted, lower), upper));
}

/// Reference ("golden") requantisation of a GEMM result.
///
/// For every element of X it computes
///   X(r,c) * C2 + C1 * ifmsum[r] + ((wgtsum[c] * C3 + C0) >> sqb)
/// in 64-bit arithmetic, passes the value (narrowed to 32 bits, as before) to
/// srs_to_T with shift `sout`, and stores the result in Y(r,c).
///
/// @param A                input matrix; ifmsum[r] is the sum of row r
/// @param B                weight matrix; wgtsum is the per-column sum, or
///                         the per-row sum when B_is_transposed is true
/// @param X                accumulator matrix to requantise
/// @param param            C0..C3, sqb and sout
/// @param Y                output matrix, written at the same (r, c) as X
/// @param B_is_transposed  selects how wgtsum is reduced (see B)
///
/// Each weight sum is printed to stdout. The sum buffers are std::vector
/// rather than variable-length arrays, which are not standard C++.
template<typename T1, typename T2, typename T3, typename T4>
void qdq_asym_golden(RowMajorMatrix<T1> A, RowMajorMatrix<T2> B, RowMajorMatrix<T3> X, gemm_qdq_param<T4> &param, RowMajorMatrix<T4> Y, bool B_is_transposed)
{
    // Per-row sums of the input matrix.
    std::vector<uint32_t> ifmsum(static_cast<size_t>(A.num_rows), 0u);
    for (int r = 0; r < A.num_rows; ++r) {
        for (int c = 0; c < A.num_cols; ++c) {
            ifmsum[r] += A.at(r,c);
        }
    }

    // Weight sums along the reduction dimension of B.
    int const num_wgt_sums = B_is_transposed ? B.num_rows : B.num_cols;
    std::vector<uint32_t> wgtsum(static_cast<size_t>(num_wgt_sums), 0u);
    if (!B_is_transposed) {
        for (int c = 0; c < B.num_cols; ++c) {
            for (int r = 0; r < B.num_rows; ++r) {
                wgtsum[c] += B.at(r,c);
            }
            printf("c:%d, wgtsum[r] = %u\n", c, wgtsum[c]);
        }
    } else {
        for (int r = 0; r < B.num_rows; ++r) {
            for (int c = 0; c < B.num_cols; ++c) {
                wgtsum[r] += B.at(r,c);
            }
            printf("c:%d, wgtsum[r] = %u\n", r, wgtsum[r]);
        }
    }

    // The loops below index ifmsum by X's rows and wgtsum by X's columns.
    assert(X.num_rows <= A.num_rows);
    assert(X.num_cols <= num_wgt_sums);

    for (int c = 0; c < X.num_cols; ++c) {
        for (int r = 0; r < X.num_rows; ++r) {
            int64_t const acc_term = static_cast<int64_t>(X.at(r, c)) * param.C2;
            int64_t const ifm_term = static_cast<int64_t>(param.C1) * ifmsum[r];
            int64_t const wgt_term = (static_cast<int64_t>(wgtsum[c]) * param.C3 + param.C0) >> param.sqb;
            // srs_to_T takes a uint32_t, so only the low 32 bits are kept.
            Y.at(r, c) = srs_to_T<T4>(static_cast<uint32_t>(acc_term + ifm_term + wgt_term), param.sout);
        }
    }
}

/// Writes `mat` to std::cout, one row per line, every element converted to
/// int64_t and followed by a space. When `msg` is non-null it is printed on
/// its own line first.
template<typename T>
void print_matrix(T mat, const char* msg = nullptr)
{
    if (msg) {
        std::cout << msg << "\n";
    }
    for (int r = 0; r < mat.num_rows; ++r) {
        for (int c = 0; c < mat.num_cols; ++c) {
            int64_t const shown = static_cast<int64_t>(mat.at(r, c));
            std::cout << shown << " ";
        }
        std::cout << "\n";
    }
}

/// Writes the key matrix, a blank line and then the value matrix of `mat` to
/// std::cout. Elements are converted to int64_t and followed by a space; each
/// row ends with a newline. When `msg` is non-null it is printed first.
template<typename T>
void print_KV(T mat, const char* msg = nullptr)
{
    if (msg) {
        std::cout << msg << "\n";
    }

    for (int r = 0; r < mat.key_rows; ++r) {
        for (int c = 0; c < mat.key_cols; ++c) {
            int64_t const shown = static_cast<int64_t>(mat.atK(r, c));
            std::cout << shown << " ";
        }
        std::cout << "\n";
    }

    std::cout << "\n";

    for (int r = 0; r < mat.val_rows; ++r) {
        for (int c = 0; c < mat.val_cols; ++c) {
            int64_t const shown = static_cast<int64_t>(mat.atV(r, c));
            std::cout << shown << " ";
        }
        std::cout << "\n";
    }
}

/// Prints the attention mask and the GPRB matrix, vector and scalars of
/// core 0 to std::cout.
///
/// The stored values are uint8_t; they are converted to int before streaming
/// so they appear as numbers instead of raw characters. All GPRB values are
/// read from core 0 (b and c used to be read from cores 1 and 2, which only
/// matched because init_random_mha_params writes the same scalars to every
/// core). Declared `inline` because the definition lives in a header.
inline void print_mha_params(MhaParams prm)
{
    std::cout << "Attn Mask =\n";
    for (int col = 0; col < MhaParams::attn_dim; ++col) {
        std::cout << static_cast<int>(prm.mask(col)) << " ";
    }
    std::cout << "\n";

    std::cout << "GPRB Mat =\n";
    for (int row = 0; row < MhaParams::gprb_rows; ++row) {
        for (int col = 0; col < MhaParams::gprb_cols; ++col) {
            std::cout << static_cast<int>(prm.gprb_mat(0, row, col)) << " ";
        }
        std::cout << "\n";
    }

    std::cout << "GPRB Vec =\n";
    for (int col = 0; col < MhaParams::gprb_cols; ++col) {
        std::cout << static_cast<int>(prm.gprb_vec(0, col)) << " ";
    }
    std::cout << "\n";

    std::cout << "GPRB a = " << static_cast<int>(prm.gprb_a(0)) << "\n";
    std::cout << "GPRB b = " << static_cast<int>(prm.gprb_b(0)) << "\n";
    std::cout << "GPRB c = " << static_cast<int>(prm.gprb_c(0)) << "\n";
}

/// Compares a reference matrix with a device result and prints summary
/// statistics to std::cout.
///
/// An element counts as an error when its relative difference (in percent of
/// the reference magnitude) exceeds `max_pct_diff` AND its absolute
/// difference exceeds 2.
///
/// @param cpu_Y           reference values
/// @param aie_Y           values under test; indexed with cpu_Y's shape
/// @param max_num_cols    only columns c < max_num_cols contribute to the
///                        squared / absolute error sums
/// @param max_pct_diff    relative tolerance in percent
/// @param enable_logging  when true (the default) one line per element is
///                        printed; the flag used to be ignored
/// @return                number of failing elements
///
/// The squared error is accumulated in 64 bits so large differences cannot
/// overflow, and averages over an empty set are reported as 0 instead of NaN.
/// NOTE(review): MSE / MAE are divided by the full element count even though
/// only columns below max_num_cols are summed -- confirm that is intended.
template<typename Ta, typename Tb>
int check_result(Ta cpu_Y, Tb aie_Y, float max_num_cols, float max_pct_diff = 0.0, bool enable_logging = true)
{
    int err_count = 0;
    int max_err = 0;
    float sum_pct_diff = 0.0;
    float num_valid_pct_diff = 0.0;
    long long sum_squares = 0;
    long long sum_abs_diff = 0;

    for (int r = 0; r < cpu_Y.num_rows; ++r) {
        for (int c = 0; c < cpu_Y.num_cols; ++c) {
            auto const raw_abs_diff = std::abs(cpu_Y.at(r, c) - aie_Y.at(r, c));
            int const diff = static_cast<int>(raw_abs_diff);

            float const abs_ref = (float) std::abs(cpu_Y.at(r, c));
            // A tiny denominator stands in for a zero reference value.
            float const denominator = (abs_ref == 0.0f) ? abs_ref + 0.00001 : abs_ref;
            float const pct_diff = 100.0 * (diff / denominator);
            bool const is_fail = (pct_diff > max_pct_diff) && (raw_abs_diff > 2);

            // Zero references are excluded from the average relative error.
            if (abs_ref != 0.0f) {
                sum_pct_diff += pct_diff;
                num_valid_pct_diff += 1.0f;
            }

            if (c < max_num_cols) {
                sum_squares += static_cast<long long>(diff) * diff;
                sum_abs_diff += diff;
            }
            if (is_fail) {
                err_count += 1;
            }
            if (enable_logging) {
                std::cout << "Y[" << r << ", " << c << "]: "
                          << "Expected: " << (int)(cpu_Y.at(r, c)) << ", "
                          << "Received: " << (int)(aie_Y.at(r, c)) << ", "
                          << "Pct Diff: " << pct_diff << "%\n";
            }
            max_err = (diff > max_err) ? diff : max_err;
        }
    }

    int const num_elems = cpu_Y.num_rows * cpu_Y.num_cols;
    float const mean_sq_error = (num_elems > 0) ? (float)(sum_squares) / num_elems : 0.0f;
    float const mean_abs_error = (num_elems > 0) ? (float)(sum_abs_diff) / num_elems : 0.0f;
    float const avg_pct_diff = (num_valid_pct_diff > 0.0f) ? sum_pct_diff / num_valid_pct_diff : 0.0f;

    std::cout << "Average Relative Error = " << avg_pct_diff << "%\n";
    std::cout << "Error Count = " << err_count << "\n";
    std::cout << "Max error = " << max_err << "\n";

    std::cout << "MSE : = " << mean_sq_error << "\n";
    std::cout << "MAE : = " << mean_abs_error << "\n";
    return err_count;
}

/// Compares a reference matrix with a device result, prints RMSE, maximum
/// error, MAE and relative-error statistics to std::cout, and returns the
/// average relative error in percent.
///
/// An element fails when its relative error (percent of the reference
/// magnitude; 0 when the reference is 0) exceeds
/// `max_relative_err_percentage_tolerance`. With `verbose` set, one
/// "Pass"/"Fail" line is printed per element.
template<typename Ta, typename Tb>
float check_result_rmse(Ta cpu_Y, Tb aie_Y, float max_relative_err_percentage_tolerance, bool verbose = true)
{
    int fail_count = 0;
    double sq_err_sum = 0.0;
    float abs_err_sum = 0.0;
    float rel_err_sum = 0.0;
    float worst_abs_err = 0.0;
    float worst_rel_pct = 0.0;

    for (int r = 0; r < cpu_Y.num_rows; ++r) {
        for (int c = 0; c < cpu_Y.num_cols; ++c) {
            float const err = std::abs(static_cast<float>(cpu_Y.at(r, c)) - static_cast<float>(aie_Y.at(r, c)));
            abs_err_sum += err;
            sq_err_sum += std::pow((double)err, 2);

            double const ref_mag = std::abs(cpu_Y.at(r, c));
            double rel_pct = 0.0;
            if (!(ref_mag == 0.0)) {
                rel_pct = (std::abs(cpu_Y.at(r, c) - aie_Y.at(r, c)) / ref_mag) * 100;
            }
            rel_err_sum += rel_pct;

            if (err > worst_abs_err) {
                worst_abs_err = err;
            }
            if (rel_pct > worst_rel_pct) {
                worst_rel_pct = rel_pct;
            }

            bool const failed = rel_pct > max_relative_err_percentage_tolerance;
            if (failed) {
                fail_count += 1;
            }
            if (verbose) {
                std::cout << (failed ? "Fail Y[" : "Pass Y[") << r << ", " << c << "]: "
                          << "Expected: " << (int)(cpu_Y.at(r, c)) << ", "
                          << "Received: " << (int)(aie_Y.at(r, c)) << ", "
                          << "Pct Diff: " << rel_pct << "%\n";
            }
        }
    }

    int const num_elems = cpu_Y.num_rows * cpu_Y.num_cols;
    float const rmse = std::sqrt(sq_err_sum / num_elems);
    float const mae = abs_err_sum / num_elems;
    float const avg_rel_err = rel_err_sum / num_elems;

    std::cout << "Root Mean square Error = " << rmse << "\n";
    std::cout << "Max Error = " << worst_abs_err << "\n";
    std::cout << "Mean Absolute Error = " << mae << "\n";
    std::cout << "Average Relative Error Percentage = " << avg_rel_err << "\n";
    std::cout << "Error Count = " << fail_count << "\n";
    std::cout << "Max Relative Error Percentage = " << worst_rel_pct << "\n";
    return avg_rel_err;
}

#endif // MATRIX_HPP
