#ifndef MATRIX_HPP
#define MATRIX_HPP

#include <assert.h>
#include <stdlib.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <limits>

// Enables the QK^T scaling step in cpu_mha(). Can be overridden from the build
// (e.g. -DSDXL_MUL_SCALE=0) because it is only defined when not already set.
#ifndef SDXL_MUL_SCALE
#define SDXL_MUL_SCALE 1
#endif

// QDQ parameter table: num_qdq_nodes nodes x num_qdq_prm_per_node 32-bit
// entries each (16 x 4 bytes = 64 bytes per node), zero-initialised.
int const num_qdq_nodes = 6;
int const num_qdq_prm_per_node = 16;
// `inline` (C++17) gives this mutable, header-defined table exactly one
// definition program-wide; without it, including this header from two .cpp
// files fails at link time with a multiple-definition error.
inline uint32_t qdq_params[num_qdq_nodes*num_qdq_prm_per_node] = {0};

/// Raw bfloat16 bit pattern: the upper 16 bits of an IEEE-754 float32
/// (see bfloat16_to_float, which widens it by shifting left 16).
struct bfloat16_t
{
    uint16_t value;
};

/// Reinterpret the bits of a float32 as a uint32_t (no numeric conversion).
inline uint32_t float_to_uint(float f)
{
    static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32-bit");
    uint32_t i = 0;
    std::memcpy(&i, &f, sizeof(i));   // well-defined type punning
    return i;
}

/// Reinterpret a uint32_t bit pattern as a float32 (no numeric conversion).
inline float uint_to_float(uint32_t i)
{
    static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32-bit");
    float f = 0;
    std::memcpy(&f, &i, sizeof(f));   // well-defined type punning
    return f;
}

/// Convert a float32 to bfloat16 using round-to-nearest, ties-to-even.
///
/// NaN inputs stay NaN: the sign and the upper payload bits are kept and the
/// quiet bit is set. Inf and finite values are rounded normally (finite values
/// may round up to Inf).
inline bfloat16_t float_to_bfloat16(float fp)
{
    uint32_t const bits = float_to_uint(fp);

    // NaN (exponent all ones, mantissa non-zero) must bypass rounding: the add
    // below could carry the mantissa into the exponent (NaN -> Inf) or wrap the
    // whole word (e.g. 0xFFFFFFFF -> +0.0).
    if ((bits & 0x7FFFFFFFu) > 0x7F800000u)
    {
        return bfloat16_t{uint16_t((bits >> 16) | 0x0040u)};
    }

    uint32_t const lsb = (bits >> 16) & 0x1u;   // ties round toward even
    uint32_t const rounded = bits + 0x7FFFu + lsb;
    return bfloat16_t{uint16_t(rounded >> 16)};
}

/// Widen a bfloat16 to float32 exactly (zero-fills the low 16 mantissa bits).
inline float bfloat16_to_float(bfloat16_t bf)
{
    return uint_to_float(uint32_t(bf.value) << 16);
}

/// Linear offset of element (row, col) in a num_rows x num_cols matrix stored
/// row by row (each row is num_cols consecutive elements).
/// Both indices are assert-checked against [0, extent).
inline int row_major_index(int row, int col, int num_rows, int num_cols)
{
    assert(row >= 0 && row < num_rows);
    assert(col >= 0 && col < num_cols);
    return (row * num_cols) + col;
}

/// Linear offset of element (row, col) in a num_rows x num_cols matrix stored
/// column by column (each column is num_rows consecutive elements).
/// Both indices are assert-checked against [0, extent).
inline int col_major_index(int row, int col, int num_rows, int num_cols)
{
    assert(row >= 0 && row < num_rows);
    assert(col >= 0 && col < num_cols);
    return (col * num_rows) + row;
}

/// Linear offset of element (row, col) in a layout that splits the columns
/// into 8-wide tiles. Each tile holds num_rows rows of 8 consecutive elements,
/// and tiles are stored one after another in column-tile order.
/// Both indices are assert-checked against [0, extent).
inline int w8_index(int row, int col, int num_rows, int num_cols)
{
    assert(row >= 0 && row < num_rows);
    assert(col >= 0 && col < num_cols);
    int constexpr zz = 8;   // tile width in elements
    return (row * zz) + (col % zz) + ((col / zz) * (num_rows * zz));
}

/// Overlay of a float32 and its raw bit pattern. In C++, reading the member
/// that was not most recently written is undefined behaviour, so
/// bfloat2float() below copies bits with std::memcpy instead of using this.
union Float32Bits {
  uint32_t u;
  float f;
};

/// Shift between the bfloat16 and float32 bit layouts (float32 mantissa has
/// 23 bits, bfloat16 has 7, so 23 - 7 = 16).
const uint32_t kF32BfMantiBitDiff = 16;

/// Widen a raw bfloat16 bit pattern to float32. The result is exact because
/// the low 16 bits are simply zero-filled.
///
/// `inline` because this non-template function is defined in a header;
/// otherwise every including translation unit emits a definition and linking
/// two of them fails.
inline float bfloat2float(uint16_t bfloatBits) {
  uint32_t const bits = static_cast<uint32_t>(bfloatBits) << kF32BfMantiBitDiff;
  float result;
  std::memcpy(&result, &bits, sizeof(result));
  return result;
}

/// Non-owning view of a num_rows x num_cols matrix of T laid out row by row in
/// a caller-supplied buffer. Copying the view copies the pointer, not the data.
template<typename T>
struct RowMajorMatrix
{
    int const num_rows;   // matrix height
    int const num_cols;   // matrix width (elements per row)
    T* const data;        // first element; not owned by the view

    RowMajorMatrix(int num_rows, int num_cols, void* data)
        : num_rows{num_rows}
        , num_cols{num_cols}
        , data{static_cast<T*>(data)}
    {}

    /// Mutable reference to element (row, col); bounds are assert-checked.
    T& at(int row, int col)
    {
        assert(row < num_rows);
        assert(col < num_cols);
        int const linear = row_major_index(row, col, num_rows, num_cols);
        assert(linear < num_rows * num_cols);
        return *(data + linear);
    }

    /// Size in bytes of a num_rows x num_cols buffer of T.
    static int size(int num_rows, int num_cols)
    {
        int const element_count = num_rows * num_cols;
        return element_count * sizeof(T);
    }
};

/// CPU reference bf16 matrix multiply: Y = Act1 * Act2, or Y = Act1 * Act2^T
/// when perf_act2_transpose is true. All matrices hold raw bf16 bit patterns.
/// Products are accumulated in float32, and each output element is rounded to
/// bf16 exactly once.
///
/// @param Act1                 M x K input.
/// @param Act2                 K x N input, or N x K when perf_act2_transpose.
/// @param Y                    M x N output; every element is overwritten.
/// @param perf_act2_transpose  read Act2 as Act2.at(c, k) instead of at(k, c).
///
/// `inline` because this non-template function is defined in a header;
/// otherwise including the header from two .cpp files is a link error.
inline void cpu_matmul_bf16(
    RowMajorMatrix<uint16_t> Act1,
    RowMajorMatrix<uint16_t> Act2,
    RowMajorMatrix<uint16_t> Y,
    bool perf_act2_transpose)
{
    int const inner = Act1.num_cols;   // k of m,k,n

    assert(Y.num_rows == Act1.num_rows);
    if (!perf_act2_transpose)
    {
        assert(inner      == Act2.num_rows);
        assert(Y.num_cols == Act2.num_cols);
    }
    else
    {
        assert(inner      == Act2.num_cols);
        assert(Y.num_cols == Act2.num_rows);
    }

    for (int r = 0; r < Y.num_rows; ++r)
    {
        for (int c = 0; c < Y.num_cols; ++c)
        {
            float acc = 0.0f;
            for (int k = 0; k < inner; ++k)
            {
                uint16_t const rhs = perf_act2_transpose ? Act2.at(c, k) : Act2.at(k, c);
                acc += bfloat2float(Act1.at(r, k)) * bfloat2float(rhs);
            }
            Y.at(r, c) = float_to_bfloat16(acc).value;
        }
    }
}

/// Row-wise softmax on bf16 data, computed in float32. For each row y and each
/// column x < true_cols:
///   Y[y][x] = exp(X[y][x] - m) / sum_{j < true_cols} exp(X[y][j] - m),
/// where m is the row maximum over the first true_cols columns. Each result is
/// rounded to bf16. Columns >= true_cols (padding) are neither read nor written.
///
/// @param X          input, raw bf16 bit patterns.
/// @param Y          output, same shape as X.
/// @param true_cols  number of leading columns holding real data (<= X.num_cols).
///
/// `inline` because this non-template function is defined in a header.
inline void softmax
(
    RowMajorMatrix<uint16_t> X,
    RowMajorMatrix<uint16_t> Y,
    int true_cols
)
{
    assert(Y.num_rows == X.num_rows);
    assert(Y.num_cols == X.num_cols);
    assert(true_cols >= 0 && true_cols <= X.num_cols);

    int const height = Y.num_rows;
    int const width  = true_cols;

    for (int y = 0; y < height; ++y)
    {
        // Seed with lowest(), not min(): min() is the smallest *positive*
        // float, so an all-negative row would be shifted by ~0 instead of its
        // real max. exp() could then underflow to 0 for every entry, giving 0/0.
        float rowmax = std::numeric_limits<float>::lowest();
        for (int x = 0; x < width; ++x)
            rowmax = std::max(bfloat2float(X.at(y, x)), rowmax);

        float rowsum = 0.0f;
        for (int x = 0; x < width; ++x)
            rowsum += std::exp(bfloat2float(X.at(y, x)) - rowmax);

        for (int x = 0; x < width; ++x)
        {
            float const pow_e = std::exp(bfloat2float(X.at(y, x)) - rowmax);
            Y.at(y, x) = float_to_bfloat16(pow_e / rowsum).value;
        }
    }
}

void cpu_mha
(
    RowMajorMatrix<uint16_t> Q,     // MxK
    RowMajorMatrix<uint16_t> K,     // NxK
    //RowMajorMatrix<uint16_t> V,   // NxL
    RowMajorMatrix<uint16_t> VT,    // LxN

    RowMajorMatrix<uint16_t> QKt,   // intermediates
    RowMajorMatrix<uint16_t> SM,    // intermediates

    RowMajorMatrix<uint16_t> Y,
    int true_dim
)
{
    assert(  Q.num_cols == K.num_cols);   // K
    //assert(  V.num_rows == K.num_rows);   // N
    assert( VT.num_cols == K.num_rows);
    assert(  Y.num_rows == Q.num_rows);   // M
    //assert(  Y.num_cols == V.num_cols);   // L
    assert(  Y.num_cols ==VT.num_rows);

    assert(QKt.num_rows == Q.num_rows);   // M
    assert(QKt.num_cols == K.num_rows);   // N
    assert( SM.num_rows == Q.num_rows);   // M
    assert( SM.num_cols == K.num_rows);   // N

    cpu_matmul_bf16( Q, K, QKt,  true);
    
    #if SDXL_MUL_SCALE
    for(int r = 0; r < QKt.num_rows; r++)
       for(int c = 0; c < QKt.num_cols; c++)
           //QKt.at(r, c) = QKt.at(r, c)*(float_to_bfloat16(0.125f)).value;
           QKt.at(r, c) = (float_to_bfloat16((bfloat2float(QKt.at(r,c))) * 0.125f)).value;
    #endif     

    //compute softmax on true col dim.
    softmax(QKt, SM, true_dim);


    cpu_matmul_bf16(SM, VT, Y, true); //cpu_matmul_bf16(QKt, V,  Y, false);

    /*for (int r = 0; r < Y.num_rows; ++r) {
        for (int c = 0; c < Y.num_cols; ++c) {

            float acc = 0;
            for (int k = 0; k < 64; ++k)
                acc += bfloat2float(QKt.at(r, k+768)) * \
                       bfloat2float(VT.at(c,  k+768));

            Y.at(r, c) = (float_to_bfloat16(acc)).value;
        }
    }*/

    //for(int r = 0; r < Y.num_rows; r++)
    //    for(int c = 0; c < Y.num_cols; c++)
    //        Y.at(r, c) = VT.at(r, c+768);

    //Validate QKt
    //for(int r = 0; r < Y.num_rows; r++)
    //    for(int c = 0; c < Y.num_cols; c++)
    //        Y.at(r, c) = QKt.at(r, c+768);

    //Validate VT
    //for(int r = 0; r < Y.num_rows; r++)
    //    for(int c = 0; c < Y.num_cols; c++)
    //        Y.at(r, c) = VT.at(r, c+768);
    /*for(int r = 0; r < Y.num_rows; r++)
    {
        // for Q
        //for(int c = 0; c < Y.num_cols; c++)
        //    Y.at(r, c) = Q.at(r, c);

        // for K and V
        for(int c = 0; c < Y.num_cols; c++)
            Y.at(r, c) = V.at(r+768, c);
    }*/
}




/// Fill every element of `mat` (any type exposing num_rows, num_cols and
/// at(row, col)) with a pseudo-random integer in [min, max) drawn from rand().
/// Requires max > min; otherwise `rand() % 0` would be undefined behaviour.
template<typename T>
void init_random(T mat, int min = -128, int max = 128)
{
    assert(max > min);
    int const span = max - min;
    for (int i = 0; i < mat.num_rows; ++i) {
        for (int j = 0; j < mat.num_cols; ++j) {
            mat.at(i, j) = (rand() % span) + min;
        }
    }
}

/*template<typename T>
void init_random_bfloat16(T mat, int min = -128, int max = 128)
{
    for(int i = 0; i < mat.num_rows; ++i) {
        for (int j = 0; j < mat.num_cols; ++j) {
            mat.at(i, j) = float_to_bfloat16(float((rand() % (max - min)) + min)).value;
        }
    }
}*/

/// Fill the leading true_rows x true_cols region of X with uniform random
/// values in [min, max] (stored as bf16 bit patterns). Every other element
/// (padding) is set to 0.
///
/// One rand() value is drawn per element, padded or not, which preserves the
/// random stream consumed for a given seed.
///
/// `inline` because this non-template function is defined in a header.
inline void init_random_bfloat16(RowMajorMatrix<uint16_t> X, int true_rows, int true_cols, float min = -1.0, float max = 1.0)
{
    for (int r = 0; r < X.num_rows; ++r) {
        for (int c = 0; c < X.num_cols; ++c) {
            float const val = ((max - min) * (rand() / static_cast<float>(RAND_MAX))) + min;
            // true_rows / true_cols are counts (softmax() uses true_cols the
            // same way), so index true_rows / true_cols is already padding.
            bool const is_padding = (r >= true_rows) || (c >= true_cols);
            if (is_padding)
            {
                X.at(r, c) = 0;
            }
            else
            {
                X.at(r, c) = float_to_bfloat16(val).value;
            }
        }
    }
}

template<typename T>
void print_matrix(T mat, const char* msg = nullptr)
{
    if (msg != nullptr) {
        std::cout << msg << "\n";
    }
    for(int i = 0; i < mat.num_rows; ++i) {
        for (int j = 0; j < mat.num_cols; ++j) {
            std::cout << bfloat16_to_float(bfloat16_t{mat.at(i, j)}) << " ";
        }
        std::cout << "\n";
    }
}

// ---------------------------------------------------------------------------
// Disabled code: everything up to the matching #endif is compiled out by
// `#if 0` and is kept for reference only.
// NOTE(review): init_fp() calls a three-argument dest.at(h, i, j) that none of
// the matrix types in this section provide, so re-enabling it likely needs
// more than removing the #if 0.
// ---------------------------------------------------------------------------
#if 0
// Fill the K part (via atK) and then the V part (via atV) of a KV matrix with
// pseudo-random integers in [min, max) drawn from rand().
template<typename T>
void init_random_KV(T mat, int min = -128, int max = 128)
{
    for(int i = 0; i < mat.key_rows; ++i) {
        for (int j = 0; j < mat.key_cols; ++j) {
            mat.atK(i, j) = (rand() % (max - min)) + min;
        }
    }
    for(int i = 0; i < mat.val_rows; ++i) {
        for (int j = 0; j < mat.val_cols; ++j) {
            mat.atV(i, j) = (rand() % (max - min)) + min;
        }
    }
}


// Print the optional heading, then the K part, a blank line, and the V part of
// a KV matrix, each element cast to int64_t.
template<typename T>
void print_KV(T mat, const char* msg = nullptr)
{
    if (msg != nullptr) {
        std::cout << msg << "\n";
    }
    for(int i = 0; i < mat.key_rows; ++i) {
        for (int j = 0; j < mat.key_cols; ++j) {
            std::cout << static_cast<int64_t>(mat.atK(i, j)) << " ";
        }
        std::cout << "\n";
    }
    std::cout << "\n";
    for(int i = 0; i < mat.val_rows; ++i) {
        for (int j = 0; j < mat.val_cols; ++j) {
            std::cout << static_cast<int64_t>(mat.atV(i, j)) << " ";
        }
        std::cout << "\n";
    }
}
// Compare cpu_Y (reference) with aie_Y element by element and return the
// number of mismatches. An element fails only when its percentage difference
// exceeds max_pct_diff AND the raw values differ by more than 2. Failing
// elements are always printed; all elements are printed if enable_logging.
// NOTE(review): avg_pct_diff is computed but unused (its print is commented out).
template<typename Ta, typename Tb>
int check_result(Ta cpu_Y, Tb aie_Y, float max_pct_diff = 0.0, bool enable_logging = true)
{
    int err_count = 0;
    float sum_pct_diff = 0.0;
    for (int r = 0; r < cpu_Y.num_rows; ++r) {
        for (int c = 0; c < cpu_Y.num_cols; ++c) {
            int diff = std::abs(cpu_Y.at(r, c) - aie_Y.at(r, c));
            float abs_ref = (float) std::abs(cpu_Y.at(r,c));
	    float denominator = (abs_ref == 0.0f)? abs_ref + 0.00001 : abs_ref;
	    float pct_diff = 100.0 * (diff / denominator);
	    //float pct_diff = 100.0 * (diff / (float) std::abs(cpu_Y.at(r, c)));
        //bool is_fail = (pct_diff > max_pct_diff) && (std::abs((cpu_Y.at(r,c)>>8) - (aie_Y.at(r,c)>>8)) > 2);
        bool is_fail = (pct_diff > max_pct_diff) && (std::abs((cpu_Y.at(r,c)) - (aie_Y.at(r,c))) > 2);

	    //if(std::isfinite(pct_diff))
	    sum_pct_diff += pct_diff;
	    if (is_fail)
	    {
                err_count += 1;
            }
            if (is_fail || enable_logging) {
                std::cout << "Y[" << r << ", " << c << "]: "
                          << "Expected: " << (int)(cpu_Y.at(r, c)) << ", "
                          << "Received: " << (int)(aie_Y.at(r, c)) << ", "
                          << "Pct Diff: " << pct_diff << "%\n";
            }
        }
    }
    float avg_pct_diff = sum_pct_diff / (cpu_Y.num_rows * cpu_Y.num_cols);
    //std::cout << "Average Relative Error = " << avg_pct_diff << "%\n";
    std::cout << "Error Count = " << err_count << "\n";
    return err_count;
}
// Copy H consecutive rows x cols slices from the flat array src into
// dest.at(h, i, j).
template<typename Tin, typename Ts>
void init_fp(Tin dest, Ts* src, int H,  int rows, int cols)
{

    for(int h = 0; h < H; ++h) {
    for(int i = 0; i < rows; ++i) {
        for (int j = 0; j < cols; ++j) {
            dest.at(h, i, j) = src[h*rows*cols+ i * cols + j];
        }
    }
    }
}



// Packed K/V view over one buffer: the key_rows x key_cols K matrix (row-major)
// is followed immediately by the val_rows x val_cols V matrix (row-major).
// The constructor asserts each dimension is a multiple of its sub-volume size.
template<typename T, int key_subv_rows, int key_subv_cols, int val_subv_rows, int val_subv_cols>
struct ActKVMatrix
{
    int const key_rows;
    int const key_cols;
    int const val_rows;
    int const val_cols;
    T* const data;

    ActKVMatrix(int key_rows, int key_cols, int val_rows, int val_cols, void* data)
        : key_rows(key_rows)
        , key_cols(key_cols)
        , val_rows(val_rows)
        , val_cols(val_cols)
        , data(static_cast<T*>(data))
    {
        assert(key_rows % key_subv_rows == 0);
        assert(key_cols % key_subv_cols == 0);
        assert(val_rows % val_subv_rows == 0);
        assert(val_cols % val_subv_cols == 0);
    }

    // Element (row, col) of the K part.
    T& atK(int row, int col)
    {
        int const idx = row_major_index(row, col, key_rows, key_cols);
        assert(idx < (key_rows * key_cols) + (val_rows * val_cols));
        return data[idx];
    }

    // Element (row, col) of the V part, which starts right after K.
    T& atV(int row, int col)
    {
        int const idx = (key_rows * key_cols) + row_major_index(row, col, val_rows, val_cols);
        assert(idx < (key_rows * key_cols) + (val_rows * val_cols));
        return data[idx];
    }

    // Total byte size of the K and V parts together.
    static int size(int key_rows, int key_cols, int val_rows, int val_cols)
    {
        return (key_rows * key_cols * sizeof(T)) + (val_rows * val_cols * sizeof(T));
    }
};

// Row-major view for the Q activation. The constructor asserts the dimensions
// are multiples of subv_rows / subv_cols; at(idx) is unchecked linear access.
template<typename T, int subv_rows, int subv_cols>
struct ActQMatrix
{
    int const num_rows;
    int const num_cols;
    T* const data;

    ActQMatrix(int num_rows, int num_cols, void* data)
        : num_rows(num_rows)
        , num_cols(num_cols)
        , data(static_cast<T*>(data))
    {
        assert(num_rows % subv_rows == 0);
        assert(num_cols % subv_cols == 0);
    }

    T& at(int row, int col)
    {
        int const idx = row_major_index(row, col, num_rows, num_cols);
        assert(idx < num_rows * num_cols);
        return data[idx];
    }

    T& at(int idx)
    {
        return data[idx];
    }

    static int size(int num_rows, int num_cols)
    {
        return num_rows * num_cols * sizeof(T);
    }
};

// Row-major view for the output. The constructor asserts the dimensions are
// multiples of subv_rows / subv_cols.
template<typename T, int subv_rows, int subv_cols>
struct OutMatrix
{
    int const num_rows;
    int const num_cols;
    T* const data;

    OutMatrix(int num_rows, int num_cols, void* data)
        : num_rows(num_rows)
        , num_cols(num_cols)
        , data(static_cast<T*>(data))
    {
        assert(num_rows % subv_rows == 0);
        assert(num_cols % subv_cols == 0);
    }

    T& at(int row, int col)
    {
        int const idx = row_major_index(row, col, num_rows, num_cols);
        assert(idx < num_rows * num_cols);
        return data[idx];
    }

    static int size(int num_rows, int num_cols)
    {
        return num_rows * num_cols * sizeof(T);
    }
};
#endif
#endif // MATRIX_HPP
