#ifndef GEMM_HOST_RUNTIME_HPP
#define GEMM_HOST_RUNTIME_HPP

#include <assert.h>
#include <stdlib.h>
#include <string.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <type_traits>
#include <vector>

// Plain 16-bit container for a bfloat16 value.
// The bit pattern is stored verbatim; the struct stays an aggregate so it can
// be brace-initialised from a uint16_t.
struct bfloat16_t {
    uint16_t value; // raw 16-bit pattern
};

// Rounds x up to a multiple of m.
//
// For x >= 0 and m > 0 the result is the smallest multiple of m that is >= x.
// Integer division truncates toward zero, so negative x is not rounded up in
// the mathematical sense, and m == 0 divides by zero.
//
// Declared inline because this definition lives in a header: without it,
// including the header from two translation units produces duplicate symbols.
inline int round_to_multiple(int x, int m)
{
    int const num_chunks = (x + m - 1) / m;
    return num_chunks * m;
}

// Extracts the idx-th 2-bit field (idx 0 = least significant bits) from in_val.
//
// For T == uint8_t the field is returned unchanged (0..3).
// For every other T, fields with bit 1 set (2 and 3) are OR-ed with 0xFC,
// which sign-extends the 2-bit value when T is an 8-bit signed type.
template<typename T>
T int8_to_int2(T in_val, int idx)
{
    int const bit_offset = idx * 2;
    T const field = (in_val >> bit_offset) & 0x03;

    if constexpr (!std::is_same_v<T, uint8_t>) {
        // 2 and 3 are the negative 2-bit values: fill the upper six bits.
        if (field > 1) {
            return field | 0xFC;
        }
    }
    return field;
}

// Extracts one 4-bit nibble from in_val.
//
// idx == 1 selects the high nibble (in_val >> 4); any other idx selects the
// low nibble. For T == uint8_t the low nibble is returned as 0..15; for other
// T a low nibble above 7 is OR-ed with 0xf0 so that it reads as a negative
// value in an 8-bit signed type.
template<typename T>
T int8_to_int4(T in_val, int idx)
{
    if (idx == 1) {
        return in_val >> 4;
    }

    if constexpr (std::is_same_v<T, uint8_t>) {
        return in_val & 0x0f;
    } else {
        auto const low_nibble = in_val & 0x0f;
        return (low_nibble > 7) ? (low_nibble | 0xf0) : low_nibble;
    }
}

// Packs one nibble from each input byte into a single byte.
//
// is_odd == 0: the low nibbles of in_l and in_h are used.
// otherwise  : the high nibbles of in_l and in_h are used.
// The nibble taken from in_l lands in bits 0..3 of the result and the nibble
// taken from in_h lands in bits 4..7.
//
// Declared inline because this definition lives in a header: without it,
// including the header from two translation units produces duplicate symbols.
inline uint8_t int4_pack_to_int8(uint8_t in_l, uint8_t in_h, int is_odd) {
  bool const use_low = (is_odd == 0);
  uint8_t const lo = use_low ? (in_l & 0x0F) : (in_l >> 4);
  uint8_t const hi = use_low ? (in_h & 0x0F) : (in_h >> 4);
  return static_cast<uint8_t>(lo | (hi << 4));
}

// Returns the four bytes of f reinterpreted as a uint32_t (byte-wise copy,
// no numeric conversion).
inline uint32_t float_to_uint(float f)
{
    uint32_t bits = 0;
    memcpy(&bits, &f, 4);
    return bits;
}

// Returns the four bytes of i reinterpreted as a float (byte-wise copy,
// no numeric conversion).
inline float uint_to_float(uint32_t i)
{
    float value = 0;
    memcpy(&value, &i, 4);
    return value;
}

// Converts a float to bfloat16 by keeping the upper 16 bits of its bit
// pattern, rounding to nearest with ties to even.
//
// NaN inputs are handled separately: adding the rounding bias to a NaN bit
// pattern can carry out of the mantissa and yield an infinity or a signed
// zero, so NaNs are returned as a NaN with the original sign and the top
// mantissa bit set.
inline bfloat16_t float_to_bfloat16(float fp)
{
    uint32_t const bits = float_to_uint(fp);

    // Exponent all ones with a non-zero mantissa means NaN.
    bool const is_nan = (bits & 0x7FFFFFFFu) > 0x7F800000u;
    if (is_nan) {
        return bfloat16_t{uint16_t((bits >> 16) | 0x0040u)};
    }

    // Bias is 0x7FFF plus the lowest kept bit, which rounds ties toward the
    // value whose lowest kept bit is zero.
    uint32_t const lsb = (bits >> 16) & 0x1u;
    uint32_t const rounded = bits + 0x7FFFu + lsb;
    return bfloat16_t{uint16_t(rounded >> 16)};
}

// Expands a bfloat16 to float: the stored 16 bits become the upper half of
// the 32-bit pattern and the lower half is zero.
inline float bfloat16_to_float(bfloat16_t bf)
{
    uint32_t const widened = uint32_t(bf.value);
    return uint_to_float(widened << 16);
}


// Zero-initialised QDQ parameter tables. Every table below is 64 bytes wide:
// 16 x int32_t for the first two, 32 x int16_t for the last two.
int32_t gemm_qdq_params[16]= {0};
int32_t silu_gelu_qdq_params[16]= {0};
int16_t rope_qdq_params[32]= {0}; // TODO: check why the two tables above are int32 and not int16 (aligned with collateral)
int16_t elew_qdq_params[32]= {0}; // TODO: check why the two tables above are int32 and not int16 (aligned with collateral)

// Linear offset of (row, col) in a row-major num_rows x num_cols layout.
// Upper bounds are checked with assert; num_rows is only used for that check.
inline int row_major_index(int row, int col, int num_rows, int num_cols)
{
    assert(row < num_rows);
    assert(col < num_cols);
    int const row_start = row * num_cols;
    return row_start + col;
}

// Linear offset of (row, col) in a column-major num_rows x num_cols layout.
// Upper bounds are checked with assert; num_cols is only used for that check.
inline int col_major_index(int row, int col, int num_rows, int num_cols)
{
    assert(row < num_rows);
    assert(col < num_cols);
    int const col_start = col * num_rows;
    return col_start + row;
}

// Linear offset of (row, col) in a layout made of 8-column-wide strips.
// Within a strip elements are row-major (8 per row); strips follow each other
// every num_rows * 8 elements. Upper bounds are checked with assert.
inline int w8_index(int row, int col, int num_rows, int num_cols)
{
    assert(row < num_rows);
    assert(col < num_cols);
    int constexpr strip_width = 8;
    int const strip = col / strip_width;
    int const col_in_strip = col % strip_width;
    return (row * strip_width) + col_in_strip + (strip * (num_rows * strip_width));
}

// Linear offset of (row, col) in a layout made of 4-row-high bands.
// Within a band elements are column-major (4 per column); bands follow each
// other every 4 * num_cols elements. No bounds are checked here and num_rows
// is not used.
inline int h4_index(int row, int col, int num_rows, int num_cols)
{
    int constexpr band_height = 4;
    int const band = row / band_height;
    int const row_in_band = row % band_height;
    return (col * band_height) + row_in_band + (band * (band_height * num_cols));
}

// Bundle of QDQ coefficients and shift amounts.
// Field order is part of the interface (aggregate initialisation).
struct QdqCoef {
    int c2;         // scalar coefficient C2
    int c1;         // scalar coefficient C1
    int64_t* c0;    // pointer to C0 values (not owned by this struct)
    int shift_Qb;   // shift amount "Qb"
    int shift_Qout; // shift amount "Qout"
};

// Non-owning row-major view over an activation buffer, optionally batched
// over heads. Element (head, row, col) lives at
//   head * num_rows * num_cols + row * num_cols + col.
// The subv_rows / subv_cols template parameters are not used by any member;
// divisibility by the sub-volume shape is not enforced.
template<typename T, int subv_rows, int subv_cols>
struct ActMatrix
{
    int const num_rows;
    int const num_cols;
    int const num_heads;
    T* const data;

    // Single-head view over `buffer`.
    ActMatrix(int rows, int cols, void* buffer)
        : ActMatrix(1, rows, cols, buffer)
    {}

    // Multi-head view over `buffer`.
    ActMatrix(int heads, int rows, int cols, void* buffer)
        : num_rows(rows)
        , num_cols(cols)
        , num_heads(heads)
        , data(static_cast<T*>(buffer))
    {}

    // Element access within the first head; upper bounds checked with assert.
    T& at(int row, int col)
    {
        assert(row < num_rows);
        assert(col < num_cols);
        int const offset = (row * num_cols) + col;
        assert(offset < num_rows * num_cols);
        return data[offset];
    }

    // Element access within `head`; upper bounds checked with assert.
    T& at(int head, int row, int col)
    {
        assert(row < num_rows);
        assert(col < num_cols);
        assert(head < num_heads);
        int const head_base = head * num_rows * num_cols;
        int const offset = head_base + (row * num_cols) + col;
        assert(offset < num_heads * num_rows * num_cols);
        return data[offset];
    }

    // Buffer size in bytes for a single-head matrix.
    static int size(int num_rows, int num_cols)
    {
        return static_cast<int>(num_rows * num_cols * sizeof(T));
    }

    // Buffer size in bytes for a multi-head matrix.
    static int size(int num_heads, int num_rows, int num_cols)
    {
        return static_cast<int>(num_heads * num_rows * num_cols * sizeof(T));
    }
};

// Non-owning view over a weight buffer stored in sub-volume tiles.
//
// The matrix is cut into subv_rows x subv_cols tiles. Tiles are laid out in
// column-major tile order (col_major_index) and elements inside a tile follow
// w8_index. Heads are stored back to back, num_rows * num_cols elements each.
// The constructors assert that the matrix shape is a multiple of the tile shape.
template<typename T>
struct WgtMatrix
{
    int const num_rows;
    int const num_cols;
    int const num_heads;
    int const subv_rows;
    int const subv_cols;
    T* const data;

    // Single-head view over `buffer`.
    WgtMatrix(int rows, int cols, int tile_rows, int tile_cols, void* buffer)
        : WgtMatrix(1, rows, cols, tile_rows, tile_cols, buffer)
    {}

    // Multi-head view over `buffer`.
    WgtMatrix(int heads, int rows, int cols, int tile_rows, int tile_cols, void* buffer)
        : num_rows(rows)
        , num_cols(cols)
        , num_heads(heads)
        , subv_rows(tile_rows)
        , subv_cols(tile_cols)
        , data(static_cast<T*>(buffer))
    {
        assert(num_rows % subv_rows == 0);
        assert(num_cols % subv_cols == 0);
    }

    // Offset of (row, col) inside one head: position inside the tile plus the
    // start of the tile.
    int tiled_offset(int row, int col) const
    {
        int const tile_elems = subv_rows * subv_cols;
        int const within_tile = w8_index(row % subv_rows, col % subv_cols, subv_rows, subv_cols);
        int const tile_number = col_major_index(row / subv_rows,
                                                col / subv_cols,
                                                (num_rows / subv_rows),
                                                (num_cols / subv_cols));
        return within_tile + (tile_number * tile_elems);
    }

    // Element access within the first head; upper bounds checked with assert.
    T& at(int row, int col)
    {
        assert(row < num_rows);
        assert(col < num_cols);
        int const offset = tiled_offset(row, col);
        assert(offset < num_rows * num_cols);
        return data[offset];
    }

    // Element access within `head`; upper bounds checked with assert.
    T& at(int head, int row, int col)
    {
        assert(row < num_rows);
        assert(col < num_cols);
        assert(head < num_heads);
        int const head_base = head * num_rows * num_cols;
        int const offset = head_base + tiled_offset(row, col);
        assert(offset < num_heads * num_rows * num_cols);
        return data[offset];
    }

    // Buffer size in bytes for a single-head matrix.
    static int size(int num_rows, int num_cols)
    {
        return static_cast<int>(num_rows * num_cols * sizeof(T));
    }

    // Buffer size in bytes for a multi-head matrix.
    static int size(int num_heads, int num_rows, int num_cols)
    {
        return static_cast<int>(num_heads * num_rows * num_cols * sizeof(T));
    }
};

// Non-owning view over a tiled weight buffer holding 4-bit values, two per
// byte of T.
//
// The tiling matches the unpacked weight layout (tiles in column-major order,
// w8_index inside a tile) with every offset halved because two values share a
// byte. Heads are stored back to back, num_rows * num_cols / 2 bytes each.
//
// Fixes relative to the previous revision:
//  * the per-head stride is now the packed size (rows * cols / 2); the
//    unpacked stride made every head >= 1 fall outside the buffer and trip
//    the bounds assert in block();
//  * size() now accounts for num_heads, matching the bound asserted in block().
template<typename T> 
struct int4_WgtMatrix
{
    int const num_rows;
    int const num_cols;
    int const num_heads;
    int const subv_rows;
    int const subv_cols;
    T* const data;

    int4_WgtMatrix(int num_heads, int num_rows, int num_cols, int subv_rows, int subv_cols, void* data)
        : num_rows(num_rows)
        , num_cols(num_cols)
        , num_heads(num_heads)
        , subv_rows(subv_rows)
        , subv_cols(subv_cols)
        , data(static_cast<T*>(data))
    {
        assert(num_rows % subv_rows == 0);
        assert(num_cols % subv_cols == 0);
    }

    // Returns the byte that contains element (head, row, col).
    T& block(int head, int row, int col)
    {
        assert(row < num_rows);
        assert(col < num_cols);
        assert(head < num_heads);
        // Two 4-bit values per byte: one head occupies rows * cols / 2 bytes.
        int const packed_head_size = (num_rows * num_cols) / 2;
        int const head_idx = head * packed_head_size;
        int const subv_size = subv_rows * subv_cols / 2;
        int const r = row % subv_rows;
        int const c = col % subv_cols;
        int const i = w8_index(r, c, subv_rows, subv_cols) / 2;
        int const rr = row / subv_rows;
        int const cc = col / subv_cols;
        int const ii = col_major_index(rr, cc, (num_rows / subv_rows), (num_cols / subv_cols));
        int const idx = head_idx + i + (ii * subv_size);
        assert(idx < (num_heads * num_rows * num_cols / 2));
        return data[idx];
    }

    // Unpacked 4-bit value at (row, col) of the first head; odd columns read
    // the high nibble.
    T at(int row, int col)
    {
        return (int8_to_int4(block(0/*head*/, row, col), (col%2) ));
    }

    // Unpacked 4-bit value at (head, row, col); odd columns read the high nibble.
    T at(int head, int row, int col)
    {
        return (int8_to_int4(block(head, row, col), (col%2) ));
    }

    // Buffer size in bytes for all heads (two values per byte).
    static int size(int num_heads, int num_rows, int num_cols)
    {
        return static_cast<int>(num_heads * num_rows * num_cols * sizeof(uint8_t) / 2);
    }
};

// Non-owning view over a tiled weight buffer holding 2-bit values, four per
// byte of T.
//
// The tiling matches the unpacked weight layout (tiles in column-major order,
// w8_index inside a tile) with every offset divided by four because four
// values share a byte. Heads are stored back to back, num_rows * num_cols / 4
// bytes each.
//
// Fix relative to the previous revision: the per-head stride is now the
// packed size (rows * cols / 4); the unpacked stride made every head >= 1
// fall outside the buffer and trip the bounds assert in block().
template<typename T> 
struct int2_WgtMatrix
{
    int const num_rows;
    int const num_cols;
    int const num_heads;
    int const subv_rows;
    int const subv_cols;
    T* const data;

    int2_WgtMatrix(int num_heads, int num_rows, int num_cols, int subv_rows, int subv_cols, void* data)
        : num_rows(num_rows)
        , num_cols(num_cols)
        , num_heads(num_heads)
        , subv_rows(subv_rows)
        , subv_cols(subv_cols)
        , data(static_cast<T*>(data))
    {
        assert(num_rows % subv_rows == 0);
        assert(num_cols % subv_cols == 0);
    }

    // Returns the byte that contains element (head, row, col).
    T& block(int head, int row, int col)
    {
        assert(row < num_rows);
        assert(col < num_cols);
        assert(head < num_heads);
        // Four 2-bit values per byte: one head occupies rows * cols / 4 bytes.
        int const packed_head_size = (num_rows * num_cols) / 4;
        int const head_idx = head * packed_head_size;
        int const subv_size = subv_rows * subv_cols / 4;
        int const r = row % subv_rows;
        int const c = col % subv_cols;
        int const i = w8_index(r, c, subv_rows, subv_cols) / 4;
        int const rr = row / subv_rows;
        int const cc = col / subv_cols;
        int const ii = col_major_index(rr, cc, (num_rows / subv_rows), (num_cols / subv_cols));
        int const idx = head_idx + i + (ii * subv_size);
        assert(idx < (num_heads * num_rows * num_cols / 4));
        return data[idx];
    }

    // Unpacked 2-bit value at (head, row, col); col % 4 selects the field
    // inside the byte.
    T at(int head, int row, int col)
    {
        return (int8_to_int2(block(head, row, col), (col%4) ));
    }

    // Buffer size in bytes for all heads (four values per byte).
    static int size(int num_heads, int num_rows, int num_cols)
    {
        return (num_heads * num_rows * num_cols * sizeof(uint8_t) / 4);
    }
};

// Non-owning row-major view over an output buffer, optionally batched over
// heads. Element (head, row, col) lives at
//   head * num_rows * num_cols + row * num_cols + col.
// The subv_* / aie_* template parameters are not used by any member;
// divisibility by the sub-volume shape is not enforced.
template<typename T, int subv_rows, int subv_cols, int aie_rows = 4, int aie_cols = 2>
struct OutMatrix
{
    int const num_rows;
    int const num_cols;
    int const num_heads;
    T* const data;

    // Single-head view over `buffer`.
    OutMatrix(int rows, int cols, void* buffer)
        : OutMatrix(1, rows, cols, buffer)
    {}

    // Multi-head view over `buffer`.
    OutMatrix(int heads, int rows, int cols, void* buffer)
        : num_rows(rows)
        , num_cols(cols)
        , num_heads(heads)
        , data(static_cast<T*>(buffer))
    {}

    // Element access within the first head; upper bounds checked with assert.
    T& at(int row, int col)
    {
        assert(row < num_rows);
        assert(col < num_cols);
        return data[(row * num_cols) + col];
    }

    // Element access within `head`; upper bounds checked with assert.
    T& at(int head, int row, int col)
    {
        assert(row < num_rows);
        assert(col < num_cols);
        assert(head < num_heads);
        int const head_base = head * num_rows * num_cols;
        return data[head_base + (row * num_cols) + col];
    }

    // Buffer size in bytes for a single-head matrix.
    static int size(int num_rows, int num_cols)
    {
        return static_cast<int>(num_rows * num_cols * sizeof(T));
    }

    // Buffer size in bytes for a multi-head matrix.
    static int size(int num_heads, int num_rows, int num_cols)
    {
        return static_cast<int>(num_heads * num_rows * num_cols * sizeof(T));
    }
};

// Non-owning view over a bias buffer with one row of num_cols values per
// head (num_rows is always 1). No bounds are checked.
// The subv_rows / subv_cols template parameters are not used by any member.
template<typename T, int subv_rows, int subv_cols>
struct BiasVector
{
    int const num_rows;
    int const num_cols;
    int const num_heads;
    T* const data;

    // Single-head view over `buffer`.
    BiasVector(int cols, void* buffer)
        : BiasVector(1, cols, buffer)
    {}

    // Multi-head view over `buffer`.
    BiasVector(int heads, int cols, void* buffer)
        : num_rows(1)
        , num_cols(cols)
        , num_heads(heads)
        , data(static_cast<T*>(buffer))
    {}

    // Element at col + row * num_cols.
    T& at(int row, int col)
    {
        return data[col + row*num_cols];
    }

    // Element at head * num_cols + col + row * num_cols.
    T& at(int head, int row, int col)
    {
        int const head_base = head * num_cols;
        return data[head_base + col + row*num_cols];
    }

    // Buffer size in bytes for a single head.
    static int size(int num_cols)
    {
        return static_cast<int>(num_cols * sizeof(T));
    }

    // Buffer size in bytes for all heads.
    static int size(int num_heads, int num_cols)
    {
        return static_cast<int>(num_heads * num_cols * sizeof(T));
    }

};

// Non-owning row-major view over a block-bias buffer, optionally batched over
// heads. No bounds are checked.
// The subv_rows / subv_cols template parameters are not used by any member.
template<typename T, int subv_rows, int subv_cols>
struct BlockBiasVector
{
    int const num_rows;
    int const num_cols;
    int const num_heads;
    T* const data;

    // Single-head view over `buffer`.
    BlockBiasVector(int rows, int cols, void* buffer)
        : BlockBiasVector(1, rows, cols, buffer)
    {}

    // Multi-head view over `buffer`.
    BlockBiasVector(int heads, int rows, int cols, void* buffer)
        : num_rows(rows)
        , num_cols(cols)
        , num_heads(heads)
        , data(static_cast<T*>(buffer))
    {}

    // Element at col + row * num_cols.
    T& at(int row, int col)
    {
        return data[col + row*num_cols];
    }

    // Element at head * num_rows * num_cols + col + row * num_cols.
    T& at(int head, int row, int col)
    {
        int const head_base = head * num_rows * num_cols;
        return data[head_base + col + row*num_cols];
    }

    // Buffer size in bytes for a single head.
    static int size(int num_rows, int num_cols)
    {
        return static_cast<int>(num_rows * num_cols * sizeof(T));
    }

    // Buffer size in bytes for all heads.
    static int size(int num_heads, int num_rows, int num_cols)
    {
        return static_cast<int>(num_heads * num_rows * num_cols * sizeof(T));
    }

};

// Non-owning row-major view over a block-bias buffer holding 2-bit values,
// four per element of T. The logical index
//   head * num_rows * num_cols + row * num_cols + col
// is divided by four to find the containing storage element.
// The subv_rows / subv_cols template parameters are not used by any member.
template<typename T, int subv_rows, int subv_cols>
struct BlockBiasVectorInt2
{
    int const num_rows;
    int const num_cols;
    int const num_heads;
    T* const data;

    // Single-head view over `buffer`.
    BlockBiasVectorInt2(int rows, int cols, void* buffer)
        : BlockBiasVectorInt2(1, rows, cols, buffer)
    {}

    // Multi-head view over `buffer`.
    BlockBiasVectorInt2(int heads, int rows, int cols, void* buffer)
        : num_rows(rows)
        , num_cols(cols)
        , num_heads(heads)
        , data(static_cast<T*>(buffer))
    {}

    // Returns the storage element containing (head, row, col); upper bounds
    // checked with assert.
    T& block(int head, int row, int col)
    {
        assert(row < num_rows);
        assert(col < num_cols);
        assert(head < num_heads);
        int const head_base = head * num_rows * num_cols;
        int const logical_idx = head_base + (row*num_cols) + col;
        assert(logical_idx < num_heads * num_rows * num_cols);
        return data[logical_idx/4];
    }

    // Unpacked 2-bit value at (head, row, col); col % 4 selects the field
    // inside the storage element.
    int at(int head, int row, int col)
    {
        int const field = col % 4;
        return int8_to_int2(block(head, row, col), field);
    }

    // Buffer size in bytes for a single head (four values per byte).
    static int size(int num_rows, int num_cols)
    {
        return static_cast<int>(num_rows * num_cols * sizeof(uint8_t) / 4);
    }

    // Buffer size in bytes for all heads (four values per byte).
    static int size(int num_heads, int num_rows, int num_cols)
    {
        return static_cast<int>(num_heads * num_rows * num_cols * sizeof(uint8_t) / 4);
    }

};

// Fills every (row, col) of mat with a rand()-based value in [min, max).
// mat is taken by value; it is expected to be a view that writes through
// at(row, col). max == min divides by zero as soon as an element is written.
template<typename T>
void init_random(T mat, int64_t min, int64_t max)
{
    for (int row = 0; row < mat.num_rows; ++row) {
        for (int col = 0; col < mat.num_cols; ++col) {
            int64_t const offset = rand() % (max - min);
            mat.at(row, col) = offset + min;
        }
    }
}

// Fills every (head, row, col) of mat with a rand()-based value in
// [min, max), except where the coordinate selected by perm[2] is at or beyond
// unpadded_dim: those elements are set to 0.
// perm[2] indexes into {head, row, col}; only that entry of perm is read.
template<typename T>
void init_bmm_random(T mat, int64_t min, int64_t max, int perm[], int unpadded_dim)
{
    for (int head = 0; head < mat.num_heads; ++head) {
        for (int row = 0; row < mat.num_rows; ++row) {
            for (int col = 0; col < mat.num_cols; ++col) {
                int const coords[3] = {head, row, col};
                bool const in_padding = coords[perm[2]] >= unpadded_dim;
                // rand() is only consumed for non-padding elements.
                mat.at(head, row, col) = in_padding ? 0 : (rand() % (max - min)) + min;
            }
        }
    }
}


// Fills every (head, row, col) of mat with a rand()-based value in [min, max).
// max == min divides by zero as soon as an element is written.
template<typename T>
void init_bmm_random(T mat, int64_t min, int64_t max)
{
    for (int head = 0; head < mat.num_heads; ++head) {
        for (int row = 0; row < mat.num_rows; ++row) {
            for (int col = 0; col < mat.num_cols; ++col) {
                int64_t const offset = rand() % (max - min);
                mat.at(head, row, col) = offset + min;
            }
        }
    }
}

// Writes a rand()-based value in [min, max) through mat.block(head, row, col)
// for every coordinate. For packed matrices several coordinates may map to
// the same storage element, in which case the last write wins.
template<typename T>
void init_wgt_random(T mat, int64_t min, int64_t max)
{
    for (int head = 0; head < mat.num_heads; ++head) {
        for (int row = 0; row < mat.num_rows; ++row) {
            for (int col = 0; col < mat.num_cols; ++col) {
                int64_t const offset = rand() % (max - min);
                mat.block(head, row, col) = offset + min;
            }
        }
    }
}

// Sets every (row, col) of mat to v (default 0) through at(row, col).
template<typename T>
void init_const(T mat, int v = 0)
{
    for (int row = 0; row < mat.num_rows; ++row) {
        for (int col = 0; col < mat.num_cols; ++col) {
            auto& cell = mat.at(row, col);
            cell = v;
        }
    }
}

// Writes v (default 0) through mat.block(head, row, col) for every
// coordinate, i.e. the whole containing storage element is overwritten with v.
template<typename T>
void init_const_int2(T mat, int v = 0)
{
    for (int head = 0; head < mat.num_heads; ++head) {
        for (int row = 0; row < mat.num_rows; ++row) {
            for (int col = 0; col < mat.num_cols; ++col) {
                auto& storage = mat.block(head, row, col);
                storage = v;
            }
        }
    }
}

// Sets every (head, row, col) of mat to v (default 0) through at(head, row, col).
template<typename T>
void init_bmm_const(T mat, int v = 0)
{
    for (int head = 0; head < mat.num_heads; ++head) {
        for (int row = 0; row < mat.num_rows; ++row) {
            for (int col = 0; col < mat.num_cols; ++col) {
                auto& cell = mat.at(head, row, col);
                cell = v;
            }
        }
    }
}

// Writes v (default 0) on the main diagonal of mat and 0 everywhere else.
template<typename T>
void init_diagonalmaxtrix(T mat, int v = 0)
{
    for (int row = 0; row < mat.num_rows; ++row) {
        for (int col = 0; col < mat.num_cols; ++col) {
            bool const on_diagonal = (row == col);
            mat.at(row, col) = on_diagonal ? v : 0;
        }
    }
}


// Element-wise dequantisation: Out(i, j) = float(mat(i, j) - z) * s.
// The loop bounds come from mat; Out must be at least as large.
template<typename T1, typename T2, typename T3>
void dequant(T1 mat, T2 Out, float s, T3 z)
{
    for (int row = 0; row < mat.num_rows; ++row) {
        for (int col = 0; col < mat.num_cols; ++col) {
            auto const centered = mat.at(row, col) - z;
            Out.at(row, col) = static_cast<float>(centered) * s;
        }
    }
}

// Element-wise quantisation of a bfloat16 matrix to int16:
//   Out(i, j) = int16(round(float(mat(i, j)) * inv_s) + z)
//
// The result is saturated to [INT16_MIN, INT16_MAX] before the conversion.
// Converting an out-of-range (or NaN) floating-point value to an integer is
// undefined behaviour, so values that do not fit are clamped and NaN maps
// to 0. In-range values are unchanged.
template<typename T1, typename T2, typename T3>
void quant_bfloat16_to_int16(T1 mat, T2 Out, float inv_s, T3 z)
{
    auto const to_int16_saturated = [](auto q) -> int16_t {
        if (!(q == q)) { return 0; }                     // NaN
        if (q <= INT16_MIN) { return INT16_MIN; }
        if (q >= INT16_MAX) { return INT16_MAX; }
        return static_cast<int16_t>(q);
    };

    for(int i = 0; i < mat.num_rows; ++i) {
        for (int j = 0; j < mat.num_cols; ++j) {
            auto const q = std::round(bfloat16_to_float(mat.at(i,j)) * inv_s) + z;
            Out.at(i, j) = to_int16_saturated(q);
        }
    }
}

// Per-head element-wise quantisation of a bfloat16 matrix to int16:
//   Out(h, i, j) = int16(round(float(mat(h, i, j)) * inv_s) + z)
//
// The result is saturated to [INT16_MIN, INT16_MAX] before the conversion.
// Converting an out-of-range (or NaN) floating-point value to an integer is
// undefined behaviour, so values that do not fit are clamped and NaN maps
// to 0. In-range values are unchanged.
template<typename T1, typename T2, typename T3>
void quant_bmm_bfloat16_to_int16(T1 mat, T2 Out, float inv_s, T3 z)
{
    auto const to_int16_saturated = [](auto q) -> int16_t {
        if (!(q == q)) { return 0; }                     // NaN
        if (q <= INT16_MIN) { return INT16_MIN; }
        if (q >= INT16_MAX) { return INT16_MAX; }
        return static_cast<int16_t>(q);
    };

    for(int head = 0; head < mat.num_heads; head++){
        for(int i = 0; i < mat.num_rows; ++i) {
            for (int j = 0; j < mat.num_cols; ++j) {
                auto const q = std::round(bfloat16_to_float(mat.at(head,i,j)) * inv_s) + z;
                Out.at(head, i, j) = to_int16_saturated(q);
            }
        }
    }
}

// Per-head element-wise quantisation of a bfloat16 matrix to int8:
//   Out(h, i, j) = int8(round(float(mat(h, i, j)) * inv_s) + z)
//
// The result is saturated to [INT8_MIN, INT8_MAX] before the conversion.
// Converting an out-of-range (or NaN) floating-point value to an integer is
// undefined behaviour, so values that do not fit are clamped and NaN maps
// to 0. In-range values are unchanged.
template<typename T1, typename T2, typename T3>
void quant_bmm_bfloat16_to_int8(T1 mat, T2 Out, float inv_s, T3 z)
{
    auto const to_int8_saturated = [](auto q) -> int8_t {
        if (!(q == q)) { return 0; }                     // NaN
        if (q <= INT8_MIN) { return INT8_MIN; }
        if (q >= INT8_MAX) { return INT8_MAX; }
        return static_cast<int8_t>(q);
    };

    for(int head = 0; head < mat.num_heads; head++){
        for(int i = 0; i < mat.num_rows; ++i) {
            for (int j = 0; j < mat.num_cols; ++j) {
                auto const q = std::round(bfloat16_to_float(mat.at(head,i,j)) * inv_s) + z;
                Out.at(head, i, j) = to_int8_saturated(q);
            }
        }
    }
}

// Prints mat to std::cout, one row per line, each element converted to
// int64_t and followed by a space. An optional heading line is printed first
// when msg is not null.
template<typename T>
void print_matrix(T mat, const char* msg = nullptr)
{
    if (msg) {
        std::cout << msg << "\n";
    }
    for (int row = 0; row < mat.num_rows; ++row) {
        for (int col = 0; col < mat.num_cols; ++col) {
            int64_t const value = static_cast<int64_t>(mat.at(row, col));
            std::cout << value << " ";
        }
        std::cout << "\n";
    }
}

// Prints every head of mat to std::cout, one row per line, each element
// converted to int64_t and followed by a space. Heads are printed back to
// back without a separator. An optional heading line is printed first when
// msg is not null.
template<typename T>
void print_bmm_matrix(T mat, const char* msg = nullptr)
{
    if (msg) {
        std::cout << msg << "\n";
    }
    for (int head = 0; head < mat.num_heads; ++head) {
        for (int row = 0; row < mat.num_rows; ++row) {
            for (int col = 0; col < mat.num_cols; ++col) {
                int64_t const value = static_cast<int64_t>(mat.at(head, row, col));
                std::cout << value << " ";
            }
            std::cout << "\n";
        }
    }
}

// Prints mat to std::cout, one row per line, each element converted to float
// and followed by a space. An optional heading line is printed first when msg
// is not null.
template<typename T>
void print_matrix_float(T mat, const char* msg = nullptr)
{
    if (msg) {
        std::cout << msg << "\n";
    }
    for (int row = 0; row < mat.num_rows; ++row) {
        for (int col = 0; col < mat.num_cols; ++col) {
            float const value = static_cast<float>(mat.at(row, col));
            std::cout << value << " ";
        }
        std::cout << "\n";
    }
}

// Clamps val into [min, max]: the lower bound is applied first, then the
// upper bound.
template<typename T>
T saturate(T val, T min, T max) {
    T const lower_bounded = (val < min) ? min : val;
    return (max < lower_bounded) ? max : lower_bounded;
}

//int8_t srs_to_int8(int32_t x, int shift, bool sign = true) {
//    if(sign)
//        return static_cast<int8_t>(saturate<int32_t>(std::round(x >> shift), -128, 127));
//    else{
//        return static_cast<uint8_t>(saturate<uint32_t>(std::round(x >> shift), 0, 255));
//    }
//}

// Shift-round-saturate: scales x down by 2^shift, rounds, and saturates to
// an 8-bit range.
//
//   sign == true : round half up, saturate to [INT8_MIN, INT8_MAX].
//   sign == false: round half to even, saturate to [0, UINT8_MAX]; the
//                  unsigned result is returned through the int8_t return
//                  type with its bit pattern preserved.
//
// Fixes relative to the previous revision:
//  * shift == 0 no longer shifts by -1; the value is just saturated.
//  * Saturation is done in 64 bits; previously the value was truncated to
//    32 bits first, so large magnitudes wrapped instead of saturating and
//    negative inputs in unsigned mode saturated to the maximum instead of 0.
//  * The tie test uses 64-bit arithmetic (the old `1 << (shift - 1)` was a
//    32-bit shift) and no longer left-shifts a negative value.
//  * Declared inline because the definition lives in a header.
inline int8_t srs_to_int8(int64_t x, int shift, bool sign = true) {
  assert(shift >= 0 && shift < 64);

  int64_t const lo = sign ? int64_t{INT8_MIN} : int64_t{0};
  int64_t const hi = sign ? int64_t{INT8_MAX} : int64_t{UINT8_MAX};

  int64_t rounded = x;
  if (shift > 0) {
    int64_t const floor_val = x >> shift;
    // floor + the bit just below the cut equals ((x >> (shift-1)) + 1) >> 1
    // without the intermediate overflow risk.
    int64_t const half_bit = (x >> (shift - 1)) & 1;
    rounded = floor_val + half_bit;
    if (!sign) {
      uint64_t const frac = static_cast<uint64_t>(x) & ((uint64_t{1} << shift) - 1);
      bool const is_tie = (frac == (uint64_t{1} << (shift - 1)));
      if (is_tie && (floor_val % 2 == 0)) {
        rounded = floor_val; // exact half with an even floor: keep the floor
      }
    }
  }

  int64_t const clamped = (rounded < lo) ? lo : ((rounded > hi) ? hi : rounded);
  if (sign) {
    return static_cast<int8_t>(clamped);
  }
  return static_cast<int8_t>(static_cast<uint8_t>(clamped));
}

//int32_t srs_to_int32(int64_t x, int shift, bool sign = true) {

//int16_t srs_to_int16(int32_t x, int shift, bool sign = true) {
//    if(sign)
//        return static_cast<int16_t>(saturate<int32_t>(std::round(x >> shift), INT16_MIN, INT16_MAX));
//    else{
//        return static_cast<uint16_t>(saturate<uint32_t>(std::round(x >> shift), 0, UINT16_MAX));
//    }
//}

// Shift-round-saturate: scales x down by 2^shift, rounds, and saturates to
// a 16-bit range.
//
//   sign == true : round half up, saturate to [INT16_MIN, INT16_MAX].
//   sign == false: round half to even, saturate to [0, UINT16_MAX]; the
//                  unsigned result is returned through the int16_t return
//                  type with its bit pattern preserved.
//
// Fixes relative to the previous revision:
//  * shift == 0 no longer shifts by -1; the value is just saturated.
//  * Saturation is done in 64 bits; previously the value was truncated to
//    32 bits first, so large magnitudes wrapped instead of saturating and
//    negative inputs in unsigned mode saturated to the maximum instead of 0.
//  * The tie test uses 64-bit arithmetic (the old `1 << (shift - 1)` was a
//    32-bit shift) and no longer left-shifts a negative value.
//  * Declared inline because the definition lives in a header.
inline int16_t srs_to_int16(int64_t x, int shift, bool sign = true) {
  assert(shift >= 0 && shift < 64);

  int64_t const lo = sign ? int64_t{INT16_MIN} : int64_t{0};
  int64_t const hi = sign ? int64_t{INT16_MAX} : int64_t{UINT16_MAX};

  int64_t rounded = x;
  if (shift > 0) {
    int64_t const floor_val = x >> shift;
    // floor + the bit just below the cut equals ((x >> (shift-1)) + 1) >> 1
    // without the intermediate overflow risk.
    int64_t const half_bit = (x >> (shift - 1)) & 1;
    rounded = floor_val + half_bit;
    if (!sign) {
      uint64_t const frac = static_cast<uint64_t>(x) & ((uint64_t{1} << shift) - 1);
      bool const is_tie = (frac == (uint64_t{1} << (shift - 1)));
      if (is_tie && (floor_val % 2 == 0)) {
        rounded = floor_val; // exact half with an even floor: keep the floor
      }
    }
  }

  int64_t const clamped = (rounded < lo) ? lo : ((rounded > hi) ? hi : rounded);
  if (sign) {
    return static_cast<int16_t>(clamped);
  }
  return static_cast<int16_t>(static_cast<uint16_t>(clamped));
}

//int32_t srs_to_int32(int64_t x, int shift, bool sign = true) {
//    if(sign)
//        return static_cast<int32_t>(saturate<int32_t>(std::round(x >> shift), INT32_MIN, INT32_MAX));
//    else{
//        return static_cast<uint32_t>(saturate<uint32_t>(std::round(x >> shift), 0, UINT32_MAX));
//    }
//}

// Shift-round-saturate: scales x down by 2^shift, rounds, and saturates to
// a 32-bit range.
//
//   sign == true : round half up, saturate to [INT32_MIN, INT32_MAX].
//   sign == false: round half to even, saturate to [0, UINT32_MAX]; the
//                  unsigned result is returned through the int32_t return
//                  type with its bit pattern preserved.
//   shift == 0   : no rounding, the value is only saturated.
//
// Fixes relative to the previous revision:
//  * Saturation is done in 64 bits; previously the value was truncated to
//    32 bits before the clamp, which made the clamp a no-op and let
//    out-of-range values wrap (and negative inputs in unsigned mode come out
//    as large positive numbers).
//  * The tie test uses 64-bit arithmetic (the old `1 << (shift - 1)` was a
//    32-bit shift) and no longer left-shifts a negative value.
//  * Declared inline because the definition lives in a header.
inline int32_t srs_to_int32(int64_t x, int shift, bool sign = true) {
  assert(shift >= 0 && shift < 64);

  int64_t const lo = sign ? int64_t{INT32_MIN} : int64_t{0};
  int64_t const hi = sign ? int64_t{INT32_MAX} : int64_t{UINT32_MAX};

  int64_t rounded = x;
  if (shift > 0) {
    int64_t const floor_val = x >> shift;
    // floor + the bit just below the cut equals ((x >> (shift-1)) + 1) >> 1
    // without the intermediate overflow risk.
    int64_t const half_bit = (x >> (shift - 1)) & 1;
    rounded = floor_val + half_bit;
    if (!sign) {
      uint64_t const frac = static_cast<uint64_t>(x) & ((uint64_t{1} << shift) - 1);
      bool const is_tie = (frac == (uint64_t{1} << (shift - 1)));
      if (is_tie && (floor_val % 2 == 0)) {
        rounded = floor_val; // exact half with an even floor: keep the floor
      }
    }
  }

  int64_t const clamped = (rounded < lo) ? lo : ((rounded > hi) ? hi : rounded);
  if (sign) {
    return static_cast<int32_t>(clamped);
  }
  return static_cast<int32_t>(static_cast<uint32_t>(clamped));
}

// Shift-left: returns x * 2^shift with no rounding and no saturation.
//
// The `sign` flag does not change the result: both modes produce the same
// 64-bit pattern, returned as int64_t. The parameter is kept for interface
// compatibility.
//
// The shift is performed on the unsigned representation so that negative x
// does not go through a signed left shift; the resulting bit pattern is the
// same. Declared inline because the definition lives in a header.
inline int64_t sls_to_int64(int64_t x, int shift, bool sign = true) {
    // NOTE: No rounding when upshifted
    (void)sign;
    uint64_t const shifted = static_cast<uint64_t>(x) << shift;
    return static_cast<int64_t>(shifted);
}

// Reference QDQ for a single matrix with scalar C2/C1 and per-column C0:
//
//   ifmsum[r] = sum over c of A(r, c)                     (uint32 accumulator)
//   Y(r, c)   = srs_to_int16(X(r, c) * C2 + C1 * ifmsum[r] + (C0[c] << sqb),
//                            sout, /*sign=*/false)
//
// C0 must hold at least X.num_cols entries and A must have at least
// X.num_rows rows. Each row sum is printed to stdout.
//
// The row sums are held in a std::vector instead of a variable-length array,
// which is a compiler extension in C++ and lives on the stack.
template<typename Ti, typename Toi, typename Tout>
void qdq_golden(Ti A,
                Toi X, int32_t C2, int32_t C1,
                int64_t* C0, uint8_t sqb, uint8_t sout, Tout Y)
{
    std::vector<uint32_t> ifmsum(A.num_rows > 0 ? A.num_rows : 0);
    for (int r = 0; r < A.num_rows; ++r) {
        for (int c = 0; c < A.num_cols; ++c) {
            ifmsum[r] += A.at(r,c);
        }
        printf("r:%d, ifmsum[c] = %d\n", r, ifmsum[r]);
    }

    for (int c = 0; c < X.num_cols; ++c) {
        for (int r = 0; r < X.num_rows; ++r) {
            int64_t const acc = static_cast<int64_t>(X.at(r, c)) * static_cast<int64_t>(C2)
                              + static_cast<int64_t>(C1) * static_cast<int64_t>(ifmsum[r])
                              + (C0[c] << sqb);
            Y.at(r, c) = srs_to_int16(acc, sout, false);
        }
    }
}

// Reference QDQ per head with scalar C2/C1/C0:
//
//   ifmsum[r]  = sum over c of A(h, r, c)                 (uint32 accumulator)
//   Y(h, r, c) = srs_to_int16(X(h, r, c) * C2 + C1 * ifmsum[r] + (C0 << sqb),
//                             sout, /*sign=*/false)
//
// A must have at least X.num_rows rows. Each row sum is printed to stdout.
//
// Fixes relative to the previous revision:
//  * C0 is widened to 64 bits before the shift; `C0 << sqb` used to be
//    evaluated in 32 bits and could overflow before being added.
//  * The row sums are held in a std::vector instead of a variable-length
//    array, which is a compiler extension in C++ and lives on the stack.
template<typename Ti, typename Toi, typename Tout>
void qdq_golden_singleC0(Ti A,
                Toi X, int32_t C2, int32_t C1,
                int32_t C0, uint8_t sqb, uint8_t sout, Tout Y)
{
    int64_t const c0_term = static_cast<int64_t>(C0) << sqb;

    for(int head = 0; head < A.num_heads; ++head){
        std::vector<uint32_t> ifmsum(A.num_rows > 0 ? A.num_rows : 0);
        for (int r = 0; r < A.num_rows; ++r) {
            for (int c = 0; c < A.num_cols; ++c) {
                ifmsum[r] += A.at(head, r,c);
            }
            printf("r:%d, ifmsum[c] = %d\n", r, ifmsum[r]);
        }

        for (int c = 0; c < X.num_cols; ++c) {
            for (int r = 0; r < X.num_rows; ++r) {
                int64_t const acc = static_cast<int64_t>(X.at(head, r, c)) * static_cast<int64_t>(C2)
                                  + static_cast<int64_t>(C1) * static_cast<int64_t>(ifmsum[r])
                                  + c0_term;
                Y.at(head, r, c) = srs_to_int16(acc, sout, false);
            }
        }
    }
}

template<typename Ti, typename Toi, typename Tout, typename Tc0, typename Tc1, typename Tc2>
void qdq_golden(Ti A, Toi X, Tc2 C2, Tc1 C1,
                Tc0* C0_vec, Tc1* C1_vec, Tc2* C2_vec, uint8_t sqb, uint8_t sout, Tout Y, int Vec_coeffs=1)
{
   uint32_t ifmsum[A.num_rows];
    for (int r = 0; r < A.num_rows; ++r) {
        ifmsum[r] = 0;
        for (int c = 0; c < A.num_cols; ++c) {
            ifmsum[r] += A.at(r,c);
        }
    }

    for (int c = 0; c < X.num_cols; ++c) {
        for (int r = 0; r < X.num_rows; ++r) {
            if(Vec_coeffs > 1){
                Y.at(r, c) = srs_to_int16((int64_t)X.at(r, c) * (int64_t)C2_vec[c] + (int64_t)C1_vec[c] * (int64_t)ifmsum[r] +  (C0_vec[c] << sqb)  , sout, false );
            }else{
                Y.at(r, c) = srs_to_int16((int64_t)X.at(r, c) * (int64_t)C2 + (int64_t)C1 * (int64_t)ifmsum[r] +  (C0_vec[c] << sqb)  , sout, false );
            }
        }
    }
}

template<typename Ti, typename Toi, typename Tout, typename Tc0, typename Tc1, typename Tc2>
void qdq_bmm_golden(Ti A, Toi X, Tc2 C2, Tc1 C1,
                Tc0* C0_vec, Tc1* C1_vec, Tc2* C2_vec, uint8_t sqb, uint8_t sout, Tout Y, int Vec_coeffs=1)
{
    for(int head = 0; head < A.num_heads; head++) {
        uint32_t ifmsum[A.num_rows];
        for (int r = 0; r < A.num_rows; ++r) {
            ifmsum[r] = 0;
            for (int c = 0; c < A.num_cols; ++c) {
                ifmsum[r] += A.at(head, r,c);
            }
        }
    
        for (int c = 0; c < X.num_cols; ++c) {
            for (int r = 0; r < X.num_rows; ++r) {
                if(Vec_coeffs > 1){
                    Y.at(head, r, c) = srs_to_int16((int64_t)X.at(head, r, c) * (int64_t)C2_vec[c] + (int64_t)C1_vec[c] * (int64_t)ifmsum[r] +  (C0_vec[c] << sqb)  , sout, false );
                }else{
                    Y.at(head, r, c) = srs_to_int16((int64_t)X.at(head, r, c) * (int64_t)C2 + (int64_t)C1 * (int64_t)ifmsum[r] +  (C0_vec[c] << sqb)  , sout, false );
                }
            }
        }
    }
}

template<typename Ti, typename Toi, typename Tout, typename Tc0, typename Tc1, typename Tc2>
void qdq_bmm_golden_int8(Ti A, Toi X, Tc2 C2, Tc1 C1,
                Tc0* C0_vec, Tc1* C1_vec, Tc2* C2_vec, uint8_t sqb, uint8_t sout, Tout Y, int Vec_coeffs=1)
{
    for(int head = 0; head < A.num_heads; head++) {
        uint32_t ifmsum[A.num_rows];
        for (int r = 0; r < A.num_rows; ++r) {
            ifmsum[r] = 0;
            for (int c = 0; c < A.num_cols; ++c) {
                ifmsum[r] += A.at(head, r,c);
            }
        }
    
        for (int c = 0; c < X.num_cols; ++c) {
            for (int r = 0; r < X.num_rows; ++r) {
                if(Vec_coeffs > 1){
                    Y.at(head, r, c) = srs_to_int8((int64_t)X.at(head, r, c) * (int64_t)C2_vec[c] + (int64_t)C1_vec[c] * (int64_t)ifmsum[r] +  (C0_vec[c] << sqb)  , sout, false );
                }else{
                    Y.at(head, r, c) = srs_to_int8((int64_t)X.at(head, r, c) * (int64_t)C2 + (int64_t)C1 * (int64_t)ifmsum[r] +  (C0_vec[c] << sqb)  , sout, false );
                }
            }
        }
    }
}


// Reference asymmetric QDQ for a single matrix, int16 output:
//
//   ifmsum[r] = sum over c of A(r, c)                     (int32 accumulator)
//   wgtsum[c] = sum over r of B(r, c)                     (int32 accumulator)
//   Vec_coeffs > 1 : Y(r, c) = srs_to_int16(X(r, c) * C2_vec[c] + C1_vec[c] * ifmsum[r]
//                                + ((wgtsum[c] * C3_vec[c] + C0_vec[c]) << sqb), sout, false)
//   otherwise      : the scalar C0 / C1 / C2 / C3 are used instead.
//
// A must have at least X.num_rows rows and B at least X.num_cols columns.
// Every row sum and column sum is printed to stdout.
//
// The sums are held in std::vector instead of variable-length arrays, which
// are a compiler extension in C++ and live on the stack.
template<typename T1, typename T2, typename T3, typename T4, typename Tc0, typename Tc1, typename Tc2, typename Tc3>
void qdq_asym_golden(T1 A, T2 B, T3 X,  Tc0 C0, Tc1 C1, Tc2 C2, Tc3 C3, 
                     Tc0* C0_vec, Tc1* C1_vec, Tc2* C2_vec, Tc3* C3_vec, uint8_t sqb, uint8_t sout, T4 Y, int Vec_coeffs=1)
{
    std::vector<int32_t> ifmsum(A.num_rows > 0 ? A.num_rows : 0);
    for (int r = 0; r < A.num_rows; ++r) {
        for (int c = 0; c < A.num_cols; ++c) {
            ifmsum[r] += A.at(r,c);
        }
        printf("r:%d, ifmsum[c] = %d\n", r, ifmsum[r]);
    }
    std::vector<int32_t> wgtsum(B.num_cols > 0 ? B.num_cols : 0);
    for (int c = 0; c < B.num_cols; ++c) {
        for (int r = 0; r < B.num_rows; ++r) {
            wgtsum[c] += B.at(r,c);
        }
        printf("c:%d, wgtsum[r] = %d\n", c, wgtsum[c]);
    }

    for (int c = 0; c < X.num_cols; ++c) {
        for (int r = 0; r < X.num_rows; ++r) {
            if(Vec_coeffs > 1){
                Y.at(r, c) = srs_to_int16((int64_t)X.at(r, c) * (int64_t)C2_vec[c] + (int64_t)C1_vec[c] * (int64_t)ifmsum[r] +  (((int64_t)wgtsum[c]*(int64_t)C3_vec[c] + C0_vec[c]) << sqb)  , sout, false);
            }
            else{
                Y.at(r, c) = srs_to_int16((int64_t)X.at(r, c) * (int64_t)C2 + (int64_t)C1 * (int64_t)ifmsum[r] +  (((int64_t)wgtsum[c]*(int64_t)C3 + C0) << sqb)  , sout, false);
            }
        }
    }
}

// Per-head reference asymmetric QDQ, int16 output:
//
//   ifmsum[r] = sum over c of A(h, r, c)                  (int32 accumulator)
//   wgtsum[c] = sum over r of B(h, r, c)                  (int32 accumulator)
//   Vec_coeffs > 1 : Y(h, r, c) = srs_to_int16(X(h, r, c) * C2_vec[c] + C1_vec[c] * ifmsum[r]
//                                   + ((wgtsum[c] * C3_vec[c] + C0_vec[c]) << sqb), sout, false)
//   otherwise      : the scalar C0 / C1 / C2 / C3 are used instead.
//
// Only elements with r < A.num_rows and c < B.num_cols are written; the rest
// of Y is left untouched. The same coefficients are used for every head.
// Every row sum and column sum is printed to stdout.
//
// The sums are held in std::vector instead of variable-length arrays, which
// are a compiler extension in C++ and live on the stack.
template<typename T1, typename T2, typename T3, typename T4, typename Tc0, typename Tc1, typename Tc2, typename Tc3>
void qdq_asym_bmm_golden_int16(T1 A, T2 B, T3 X,  Tc0 C0, Tc1 C1, Tc2 C2, Tc3 C3, 
                     Tc0* C0_vec, Tc1* C1_vec, Tc2* C2_vec, Tc3* C3_vec, uint8_t sqb, uint8_t sout, T4 Y, int Vec_coeffs=1)
{
    for(int head = 0; head < A.num_heads; head++) {
        std::vector<int32_t> ifmsum(A.num_rows > 0 ? A.num_rows : 0);
        for (int r = 0; r < A.num_rows; ++r) {
            for (int c = 0; c < A.num_cols; ++c) {
                ifmsum[r] += A.at(head, r,c);
            }
            printf("r:%d, ifmsum[c] = %d\n", r, ifmsum[r]);
        }
        std::vector<int32_t> wgtsum(B.num_cols > 0 ? B.num_cols : 0);
        for (int c = 0; c < B.num_cols; ++c) {
            for (int r = 0; r < B.num_rows; ++r) {
                wgtsum[c] += B.at(head,r,c);
            }
            printf("c:%d, wgtsum[r] = %d\n", c, wgtsum[c]);
        }
    
        for (int c = 0; c < X.num_cols; ++c) {
            for (int r = 0; r < X.num_rows; ++r) {
                // Skip output positions that have no matching row/column sum.
                if (r < A.num_rows && c < B.num_cols) {
                    if(Vec_coeffs > 1){
                        Y.at(head, r, c) = srs_to_int16((int64_t)X.at(head, r, c) * (int64_t)C2_vec[c] + (int64_t)C1_vec[c] * (int64_t)ifmsum[r] +  (((int64_t)wgtsum[c]*(int64_t)C3_vec[c] + C0_vec[c]) << sqb)  , sout, false);
                    }
                    else{
                        Y.at(head, r, c) = srs_to_int16((int64_t)X.at(head, r, c) * (int64_t)C2 + (int64_t)C1 * (int64_t)ifmsum[r] +  (((int64_t)wgtsum[c]*(int64_t)C3 + C0) << sqb)  , sout, false);
                    }
                }
            }
        }
    }
}

/*
 * Golden model of the asymmetric QDQ stage for batched matmul, int8 output.
 *
 * Same computation as the int16 variant, but the combined value is passed to
 * srs_to_int8(..., sout, false):
 *   X*C2 + C1*ifmsum[r] + ((wgtsum[c]*C3 + C0) << sqb)
 * where ifmsum[r] is the row sum of A and wgtsum[c] the column sum of B for
 * the current head (both printed to stdout).
 *
 * Only positions with r < A.num_rows and c < B.num_cols are written. With
 * Vec_coeffs > 1 the coefficients come from C*_vec[c]; otherwise the scalar
 * C0..C3 are used and the vector pointers are never dereferenced.
 */
template<typename T1, typename T2, typename T3, typename T4, typename Tc0, typename Tc1, typename Tc2, typename Tc3>
void qdq_asym_bmm_golden_int8(T1 A, T2 B, T3 X,  Tc0 C0, Tc1 C1, Tc2 C2, Tc3 C3, 
                     Tc0* C0_vec, Tc1* C1_vec, Tc2* C2_vec, Tc3* C3_vec, uint8_t sqb, uint8_t sout, T4 Y, int Vec_coeffs=1)
{
    const bool use_vectors = (Vec_coeffs > 1);

    for (int h = 0; h < A.num_heads; h++) {
        // Per-row activation sums.
        int32_t ifmsum[A.num_rows];
        for (int i = 0; i < A.num_rows; ++i) {
            int32_t running = 0;
            for (int j = 0; j < A.num_cols; ++j) {
                running += A.at(h, i, j);
            }
            ifmsum[i] = running;
            printf("r:%d, ifmsum[c] = %d\n", i, ifmsum[i]);
        }

        // Per-column weight sums.
        int32_t wgtsum[B.num_cols];
        for (int j = 0; j < B.num_cols; ++j) {
            int32_t running = 0;
            for (int i = 0; i < B.num_rows; ++i) {
                running += B.at(h, i, j);
            }
            wgtsum[j] = running;
            printf("c:%d, wgtsum[r] = %d\n", j, wgtsum[j]);
        }

        for (int j = 0; j < X.num_cols; ++j) {
            for (int i = 0; i < X.num_rows; ++i) {
                const bool inside = (i < A.num_rows) && (j < B.num_cols);
                if (!inside) {
                    continue;
                }

                // Selected after the bounds check so the vector tables are
                // only indexed for columns that are actually processed.
                const Tc0 k0 = use_vectors ? C0_vec[j] : C0;
                const Tc1 k1 = use_vectors ? C1_vec[j] : C1;
                const Tc2 k2 = use_vectors ? C2_vec[j] : C2;
                const Tc3 k3 = use_vectors ? C3_vec[j] : C3;

                Y.at(h, i, j) = srs_to_int8(
                    (int64_t)X.at(h, i, j) * (int64_t)k2
                        + (int64_t)k1 * (int64_t)ifmsum[i]
                        + (((int64_t)wgtsum[j] * (int64_t)k3 + k0) << sqb),
                    sout, false);
            }
        }
    }
}


/*
 * CPU reference GEMM: Y = X * W.
 *
 * Each output element is the dot product of row r of X with column c of W,
 * accumulated in a 64-bit integer and then assigned to Y.at(r, c). The
 * iteration space is Y.num_rows x Y.num_cols, with X.num_cols as the inner
 * dimension. The per-term product is evaluated in the operands' own
 * (promoted) type before being added to the accumulator.
 */
template<typename Ta, typename Tw, typename To>
void cpu_matmul(Ta X, Tw W, To Y)
{
    for (int row = 0; row < Y.num_rows; ++row) {
        for (int col = 0; col < Y.num_cols; ++col) {
            int64_t dot = 0;
            int k = 0;
            while (k < X.num_cols) {
                dot += X.at(row, k) * W.at(k, col);
                ++k;
            }
            Y.at(row, col) = dot;
        }
    }
}

/*
 * CPU reference GEMM with per-shard shifting: shift(X * W).
 *
 * The inner dimension is walked in X.num_cols / Ksubv shards of Ksubv
 * elements each (integer division, so a trailing partial shard is not
 * visited). Msubv and Nsubv are accepted but not used here.
 */
template<typename Ta, typename Tw, typename To>
void cpu_matmul(Ta X, Tw W, To Y, int shift, int Msubv, int Ksubv, int Nsubv)
{
    for (int row = 0; row < Y.num_rows; ++row) {
        for (int col = 0; col < Y.num_cols; ++col) {
            int64_t acc = 0;
            for (int shard = 0; shard < (X.num_cols) / Ksubv; ++shard) {
                /*
                 *  The AIE GEMM kernel currently up-shifts before loading the
                 *  accumulator and down-shifts when writing it back to TDM.
                 *  That happens at the subvolume boundary (32 x 128 x 64),
                 *  which is why the inner dimension is processed shard by
                 *  shard with an up/down shift wrapped around each one.
                 */
                const int k_base = shard * Ksubv;

                acc = sls_to_int64(acc, shift, false);
                for (int k = k_base; k < k_base + Ksubv; ++k) {
                    acc += (X.at(row, k) * W.at(k, col));
                }
                acc = srs_to_int32(acc, shift, false);
            }
            Y.at(row, col) = acc;
        }
    }
}


/*
 * CPU reference batched GEMM with per-shard shifting: shift(X * W) per head.
 *
 * The inner dimension is taken from W.num_rows and split into
 * ceil(W.num_rows / Ksubv) shards; the last shard may be shorter than Ksubv.
 * Output positions with r >= X.num_rows or c >= W.num_cols are written as 0.
 * Msubv and Nsubv are accepted but not used here.
 */
template<typename Ta, typename Tw, typename To>
void cpu_bmm_matmul(Ta X, Tw W, To Y, int shift, int Msubv, int Ksubv, int Nsubv)
{
    assert(X.num_heads == W.num_heads);
    const double shard_max = std::ceil(static_cast<double>(W.num_rows) / Ksubv);

    for (int h = 0; h < X.num_heads; ++h) {
        for (int row = 0; row < Y.num_rows; ++row) {
            for (int col = 0; col < Y.num_cols; ++col) {
                int64_t acc = 0;

                // Outside the X/W extents: store the zero accumulator as-is.
                if (row >= X.num_rows || col >= W.num_cols) {
                    Y.at(h, row, col) = acc;
                    continue;
                }

                // NOTE: the shard count follows W.num_rows, which may differ
                // from X.num_cols (flagged for review in the original code).
                for (int shard = 0; shard < shard_max; ++shard) {
                    /*
                     *  The AIE GEMM kernel currently up-shifts before loading
                     *  the accumulator and down-shifts when writing it back to
                     *  TDM, at the subvolume boundary (32 x 128 x 64). The
                     *  inner dimension is therefore processed shard by shard
                     *  with an up/down shift wrapped around each one.
                     */
                    const int k_base = shard * Ksubv;

                    acc = sls_to_int64(acc, shift, false);
                    for (int k = 0; k < std::min(W.num_rows - k_base, Ksubv); ++k) {
                        const int kk = k_base + k;
                        acc += (X.at(h, row, kk) * W.at(h, kk, col));
                    }
                    acc = srs_to_int32(acc, shift, false);
                }
                Y.at(h, row, col) = acc;
            }
        }
    }
}

/*
 * CPU model of a block-wise quantized matmul.
 *
 * The inner dimension K (ifm.num_cols) is cut into K / K_block blocks. For
 * each block the sum of a * (w - zp) is formed in 64 bits, reduced with
 * srs_to_int32(..., shift_tdm, true), multiplied by the block scale and added
 * to a 64-bit accumulator. After all blocks the bias is added and the result
 * goes through srs_to_int16(..., shift_res, sgn_out) into ofm.
 *
 * Activations, weights and zero points are each cast to uint32_t before the
 * subtraction; the difference is then stored in an int32_t.
 */
template<typename Ta, typename Tw, typename To, typename Tc0, typename Tc1, typename Tc2>
void cpu_matmul_block_qdq(
    Ta ifm, Tw wgt, To ofm,
    int K_block,
    int shift_tdm, int shift_res,
    Tc0 bias,         // read as bias.at(head, c)
    Tc2 sc,           // read as sc.at(head, block, c)
    Tc1 zp,           // read as zp.at(head, block, c)
    bool sgn_out = false
) {
    assert(ifm.num_heads == wgt.num_heads);
    const int num_heads = ifm.num_heads;
    const int rows = ifm.num_rows;
    const int depth = ifm.num_cols;
    const int cols = wgt.num_cols;
    const int num_blocks = depth / K_block;

    for (int h = 0; h < num_heads; ++h) {
        for (int row = 0; row < rows; ++row) {
            for (int col = 0; col < cols; ++col) {
                int64_t total = 0;

                for (int blk = 0; blk < num_blocks; ++blk) {
                    const int k_begin = blk * K_block;
                    const int k_count = std::min(depth - k_begin, K_block);
                    const uint32_t zero_point = static_cast<uint32_t>(zp.at(h, blk, col));

                    // Zero-point-corrected partial sum for this block.
                    int64_t block_sum = 0;
                    for (int k = 0; k < k_count; ++k) {
                        const int kk = k_begin + k;
                        const uint32_t act = static_cast<uint32_t>(ifm.at(h, row, kk));
                        const uint32_t raw_w = static_cast<uint32_t>(wgt.at(h, kk, col));
                        const int32_t centered_w = (raw_w - zero_point);
                        block_sum += static_cast<int64_t>(act) * centered_w;
                    }

                    // Reduce to 32 bits, then apply the per-block scale.
                    const int32_t block_sum32 = srs_to_int32(block_sum, shift_tdm, true);
                    total += static_cast<int64_t>(block_sum32) * static_cast<int64_t>(sc.at(h, blk, col));
                }

                total += static_cast<int64_t>(bias.at(h, col));
                ofm.at(h, row, col) = srs_to_int16(total, shift_res, sgn_out);
            }
        }
    }
}

/*
 * CPU reference GEMM with bias: (X * W) + B.
 *
 * The 64-bit dot product is passed through srs_to_int8 with a fixed shift of
 * 12, and the bias value B.at(0, c) for the output column is added to that
 * result before it is stored in Y.at(r, c).
 */
template<typename Ta, typename Tw,typename To, typename Tb>
void cpu_matmul(Ta X, Tw W, Tb B, To Y)
{
    for (int row = 0; row < Y.num_rows; ++row) {
        for (int col = 0; col < Y.num_cols; ++col) {
            int64_t dot = 0;
            int k = 0;
            while (k < X.num_cols) {
                dot += X.at(row, k) * W.at(k, col);
                ++k;
            }
            Y.at(row, col) = srs_to_int8(dot, 12) + B.at(0,col);
        }
    }
}

/*
 * Row-wise sum of X.
 *
 * For every row r of X, Y.at(r, 0) is reset to 0 and then each element of
 * that row is added to it in place, so the accumulation happens in the
 * element type of Y.
 */
template<typename Ta, typename To>
void ifm_sum(Ta X, To Y)
{
    int row = 0;
    while (row < X.num_rows) {
        Y.at(row, 0) = 0;
        int col = 0;
        while (col < X.num_cols) {
            Y.at(row, 0) += X.at(row, col);
            ++col;
        }
        ++row;
    }
}

/*
 * Element-wise comparison of a CPU reference matrix against the AIE output.
 *
 * Every element is reported on stdout: PASS for an exact match, WARNING for
 * an absolute difference of exactly 1, ERROR for anything larger. The largest
 * difference seen is printed at the end.
 *
 * Returns 1 if at least one element differs by more than 1, otherwise 0.
 */
template<typename To>
int check_result(To cpu_Y, To aie_Y)
{
    int fail = 0;
    int max_diff = 0;
    for (int row = 0; row < cpu_Y.num_rows; ++row) {
        for (int col = 0; col < cpu_Y.num_cols; ++col) {
            const int32_t diff = std::abs(cpu_Y.at(row, col) - aie_Y.at(row, col));

            // Classify the element; only a difference above 1 is a failure.
            const char* prefix = "PASS: Y[";
            if (diff == 1) {
                prefix = "WARNING: Y[";
            } else if (diff > 1) {
                prefix = "ERROR: Y[";
                fail = 1;
            }

            std::cout << prefix << row << ", " << col << "]: "
                      << "CPU: " << int(cpu_Y.at(row, col)) << ", "
                      << "AIE: " << int(aie_Y.at(row, col)) << "\n";

            if (diff > max_diff) {
                max_diff = diff;
            }
        }
    }
    printf("Max diff = %d \n", max_diff);
    return fail;
}

/*
 * Element-wise comparison of batched (per-head) CPU and AIE outputs.
 *
 * Every element is reported on stdout with its head/row/column index: PASS
 * for an exact match, WARNING for an absolute difference of exactly 1, ERROR
 * for anything larger. The largest difference seen is printed at the end.
 *
 * Returns 1 if at least one element differs by more than 1, otherwise 0.
 */
template<typename To>
int check_bmm_result(To cpu_Y, To aie_Y)
{
    int fail = 0;
    int max_diff = 0;
    for (int h = 0; h < cpu_Y.num_heads; ++h) {
        for (int row = 0; row < cpu_Y.num_rows; ++row) {
            for (int col = 0; col < cpu_Y.num_cols; ++col) {
                const int32_t diff = std::abs(cpu_Y.at(h, row, col) - aie_Y.at(h, row, col));

                // Classify the element; only a difference above 1 is a failure.
                const char* prefix = "PASS: Y[";
                if (diff == 1) {
                    prefix = "WARNING: Y[";
                } else if (diff > 1) {
                    prefix = "ERROR: Y[";
                    fail = 1;
                }

                std::cout << prefix << h << ", " << row << ", " << col << "]: "
                          << "CPU: " << int(cpu_Y.at(h, row, col)) << ", "
                          << "AIE: " << int(aie_Y.at(h, row, col)) << "\n";

                if (diff > max_diff) {
                    max_diff = diff;
                }
            }
        }
    }
    printf("Max diff = %d \n", max_diff);
    return fail;
}

/*
 * Element-wise comparison of batched (per-head) bfloat16 CPU and AIE outputs.
 *
 * Both values are converted with bfloat16_to_float and their absolute
 * difference is compared against epsilon: below epsilon is reported as PASS,
 * exactly equal to epsilon as WARNING, above epsilon as ERROR. Each report
 * line carries the full [head, row, col] index (matching check_bmm_result) so
 * that lines from different heads can be told apart. The largest difference
 * seen is printed at the end.
 *
 * Returns 1 if at least one element differs by more than epsilon, else 0.
 */
template<typename To>
int check_bmm_result_bf16(To cpu_Y, To aie_Y, float epsilon=1.0)
{
    int fail = 0;
    float max_diff = 0;
    for (int head = 0; head < cpu_Y.num_heads; ++head) {
        for (int r = 0; r < cpu_Y.num_rows; ++r) {
            for (int c = 0; c < cpu_Y.num_cols; ++c) {
                const auto cpu_val = bfloat16_to_float(cpu_Y.at(head, r, c));
                const auto aie_val = bfloat16_to_float(aie_Y.at(head, r, c));
                const float diff = fabsf(cpu_val - aie_val);

                const char* verdict = "PASS";
                if (diff == epsilon) {
                    verdict = "WARNING";
                } else if (diff > epsilon) {
                    verdict = "ERROR";
                    fail = 1;
                }

                std::cout << verdict << ": Y[" << head << ", " << r << ", " << c << "]: "
                          << "CPU: " << cpu_val << ", "
                          << "AIE: " << aie_val << "\n";

                if (diff > max_diff) {
                    max_diff = diff;
                }
            }
        }
    }
    printf("Max diff = %f \n", max_diff);
    return fail;
}

/*
 * Element-wise comparison of bfloat16 CPU and AIE output matrices.
 *
 * Both values are converted with bfloat16_to_float and their absolute
 * difference is compared against epsilon: below epsilon is reported as PASS,
 * exactly equal to epsilon as WARNING, above epsilon as ERROR. The largest
 * difference seen is printed at the end.
 *
 * Returns 1 if at least one element differs by more than epsilon, else 0.
 */
template<typename To>
int check_result_bf16(To cpu_Y, To aie_Y, float epsilon=1.0)
{
    int fail = 0;
    float max_diff = 0;
    for (int row = 0; row < cpu_Y.num_rows; ++row) {
        for (int col = 0; col < cpu_Y.num_cols; ++col) {
            const auto cpu_val = bfloat16_to_float(cpu_Y.at(row, col));
            const auto aie_val = bfloat16_to_float(aie_Y.at(row, col));
            const float diff = fabsf(cpu_val - aie_val);

            // Classify the element; only a difference above epsilon fails.
            const char* prefix = "PASS: Y[";
            if (diff == epsilon) {
                prefix = "WARNING: Y[";
            } else if (diff > epsilon) {
                prefix = "ERROR: Y[";
                fail = 1;
            }

            std::cout << prefix << row << ", " << col << "]: "
                      << "CPU: " << cpu_val << ", "
                      << "AIE: " << aie_val << "\n";

            if (diff > max_diff) {
                max_diff = diff;
            }
        }
    }
    printf("Max diff = %f \n", max_diff);
    return fail;
}

/*
 * Golden GELU reference in its erf form: 0.5 * x * (1 + erf(x / sqrt(2))).
 *
 * Declared inline because this is a non-template function defined in a
 * header; without it, including the header from more than one translation
 * unit produces multiple-definition link errors.
 */
inline float gelu_golden(float in)
{
    // std::sqrt(2) takes an int and yields a double, so the division is done
    // in double precision before being narrowed back to float.
    float xr2 = in/(std::sqrt(2));
    float t = std::erf(xr2);
    // The literals 0.5 and 1.0 are doubles, so this product is also double.
    float g = in*0.5*(1.0 + t);

    return g;
}

/*
 * Golden SiLU reference: x / (1 + exp(-x)).
 *
 * Declared inline because this is a non-template function defined in a
 * header; without it, including the header from more than one translation
 * unit produces multiple-definition link errors.
 */
inline float silu_golden(float in)
{
    float exp_x = std::exp(-in);
    float sg  = in/(1+exp_x);
    return sg;
}

/*
 * Golden tanh reference: forwards to std::tanh.
 *
 * Declared inline because this is a non-template function defined in a
 * header; without it, including the header from more than one translation
 * unit produces multiple-definition link errors.
 */
inline float tanh_golden(float in)
{
    return std::tanh(in);
}

/*
 * Golden reference for one RoPE output element.
 *
 * Returns in_for_cos * cos - in_for_sin * sin when do_sub is true, and
 * in_for_cos * cos + in_for_sin * sin otherwise.
 *
 * Declared inline because this is a non-template function defined in a
 * header; without it, including the header from more than one translation
 * unit produces multiple-definition link errors.
 */
inline float rope_golden(float in_for_cos, float in_for_sin, float sin, float cos, bool do_sub)
{
    float output = 0;
    if (do_sub) {
        output = in_for_cos * cos - in_for_sin * sin;
    } else {
        output = in_for_cos * cos + in_for_sin * sin;
    }
    return output;
}

/*
 * Golden reference for the element-wise op: returns ifmA + ifmB.
 *
 * Declared inline because this is a non-template function defined in a
 * header; without it, including the header from more than one translation
 * unit produces multiple-definition link errors.
 */
inline float elew_golden(float ifmA, float ifmB)
{
    return ifmA + ifmB;
}

/*
 * Reads up to `size` bytes from the start of a binary file into `data`.
 *
 * filename: path of the file to read.
 * data:     destination buffer; the caller must provide at least `size` bytes.
 * size:     number of bytes requested.
 *
 * If the file cannot be opened, `data` is left untouched and a message is
 * written to stderr (previously this failure was silent). A warning is also
 * written when fewer than `size` bytes could be read. The function never
 * terminates the process.
 *
 * Declared inline because this is a non-template function defined in a
 * header.
 */
inline void read_bin_file(std::string filename, char* data, size_t size)
{
    std::ifstream file(filename, std::ios::in | std::ios::binary);
    if (!file) {
        std::cerr << "Unable to open file: " << filename << " for read\n";
        return;
    }

    file.read(data, size);
    const size_t bytes_read = static_cast<size_t>(file.gcount());
    if (bytes_read != size) {
        std::cerr << "Short read from " << filename << ": requested " << size
                  << " bytes, got " << bytes_read << "\n";
    }
}

/*
 * Reads an entire binary file into `data`.
 *
 * filename: path of the file to read.
 * data:     destination buffer. The whole file is copied, so the caller must
 *           make sure the buffer is at least as large as the file; no bound
 *           is checked here.
 *
 * Progress is logged to stdout. If the file cannot be opened, or its size
 * cannot be determined, the process exits with EXIT_FAILURE (previously the
 * open failure exited with status 0, which reports success to the caller's
 * shell or test harness).
 *
 * Declared inline because this is a non-template function defined in a
 * header.
 */
inline void read_bin_file(std::string filename, char* data)
{
    printf("Reading bin file: %s \n", filename.c_str());
    std::ifstream infile(filename, std::ios::binary);
    if(!infile){
        printf("Unable to open file: %s !!\n", filename.c_str());
        printf("Aborting...");
        exit(EXIT_FAILURE);
    }

    // Determine the file size by seeking to the end.
    infile.seekg(0, std::ios::end);
    const std::streamoff end_pos = infile.tellg();
    if (end_pos < 0) {
        // tellg() reports failure as -1; converting that to size_t would
        // request an enormous read.
        printf("Unable to determine size of file: %s !!\n", filename.c_str());
        printf("Aborting...");
        exit(EXIT_FAILURE);
    }
    const size_t size = static_cast<size_t>(end_pos);

    infile.seekg(0, std::ios::beg);
    infile.read(data, size);
    infile.close();
    printf("Read %zu Bytes \n", size);
}

/*
 * Reads whitespace-separated integers from a text file into in_ptr.
 *
 * filename: path of the text file.
 * in_ptr:   destination array; every parsed value is cast to inT and stored
 *           consecutively starting at in_ptr[0]. The caller must provide
 *           enough room for all values in the file; no bound is checked here.
 *
 * Each line is parsed independently as a sequence of int64_t values; parsing
 * of a line stops at the first token that is not an integer. The number of
 * values read is logged to stdout. If the file cannot be opened the process
 * is aborted.
 */
template<typename inT>
void read_data_file(std::string filename, inT *in_ptr){
    std::fstream file;
    file.open(filename, std::ios::in);
    if (!file.is_open()) {
        printf("Unable to open file: %s for read \n", filename.c_str());
        abort();
    }

    printf("Opened file: %s for read \n", filename.c_str());
    int count = 0;
    std::string line;
    // Loop on the result of getline itself. Testing the stream before the
    // read lets one extra iteration run after the final line when the file
    // has no trailing newline; that failed getline leaves `line` unchanged,
    // so the last line would be parsed and stored a second time.
    while (std::getline(file, line)) {
        std::istringstream ss(line);
        int64_t num;
        while (ss >> num) {
            in_ptr[count] = (inT)(num);
            count++;
        }
    }
    printf("Read %d values \n", count);
}

/*
 * Writes bufSize elements of buf to a text file, one decimal value per line.
 *
 * fileName: path of the file to create/overwrite.
 * buf:      source array.
 * bufSize:  number of elements to write.
 *
 * Integer element types are printed through integral promotion, so 8/16-bit
 * values appear as numbers (not characters) and 64-bit or unsigned 32-bit
 * values keep their full range instead of being truncated to int. Any other
 * element type is converted with int(), as before.
 *
 * On failure to open the file an error is written to stderr and the process
 * exits with EXIT_FAILURE.
 */
template<typename inT>
void write_mat_to_file(std::string fileName, inT* buf, const size_t bufSize)
{
    std::ofstream ofs(fileName.c_str());
    if (!ofs) {
        std::cerr << "Error writing file " << fileName << std::endl;
        exit(EXIT_FAILURE);
    }
    for (size_t i=0; i<bufSize; ++i) {
        // '\n' instead of std::endl: no flush per element; close() flushes.
        if constexpr (std::is_integral_v<inT>) {
            ofs << +buf[i] << '\n';
        } else {
            ofs << int(buf[i]) << '\n';
        }
    }
    ofs.close();
}

/*
 * Computes the byte size of the constant buffer for a (possibly fused) GEMM.
 *
 * The size is made of:
 *   - per weight subvolume: the weight bytes (Ksubv * Nsubv * wgt_bits / 8,
 *     rounded up to a multiple of 64) plus the coefficient bytes
 *     (Nsubv entries of Tc0, and additionally of Tc1 and Tc2 when
 *     Vec_coeffs > 1), times the number of subvolumes and heads;
 *   - sizeof(gemm_qdq_params);
 *   - exactly one of the following, chosen in this order:
 *       lut_ab_size > 0   : both LUT sizes + sizeof(silu_gelu_qdq_params)
 *       rope_sin_size > 0 : sin/cos table sizes + sizeof(rope_qdq_params)
 *       otherwise         : sizeof(rope_qdq_params) if is_rope, and
 *                           sizeof(elew_qdq_params) if is_elew.
 *
 * Mgemm and Msubv are accepted but do not influence the result.
 */
template<typename Tc0, typename Tc1, typename Tc2>
size_t calc_const_buffer_size(
    int Mgemm,
    int Kgemm,
    int Ngemm,
    int Msubv,
    int Ksubv,
    int Nsubv,
    int num_heads,
    int Vec_coeffs=1,
    int wgt_bits=8,
    int lut_ab_size=0,
    int lut_cd_size=0,
    int rope_sin_size=0,
    int rope_cos_size=0,
    bool is_rope = false,
    bool is_elew = false
){
    const int no_wgt_subvols = (Kgemm*Ngemm)/(Ksubv*Nsubv);

    // Weight bytes per subvolume, padded to a 64-byte multiple.
    size_t raw_wgt_subv_size = ( Ksubv * Nsubv * wgt_bits ) / 8;
    if (raw_wgt_subv_size % 64 != 0) {
        raw_wgt_subv_size = round_to_multiple(raw_wgt_subv_size, 64);
    }

    // Coefficient bytes per subvolume.
    size_t bias_subv_size = (Nsubv * sizeof(Tc0));
    if (Vec_coeffs > 1) {
        bias_subv_size += (Nsubv * sizeof(Tc1)) + (Nsubv * sizeof(Tc2));
    }

    // Part shared by every configuration: weights + coefficients + GEMM params.
    size_t total_bo_size =
        ((no_wgt_subvols * raw_wgt_subv_size) + (no_wgt_subvols * bias_subv_size))*num_heads
        + sizeof(gemm_qdq_params);

    if (lut_ab_size > 0) {
        total_bo_size += lut_ab_size;
        total_bo_size += lut_cd_size;
        total_bo_size += sizeof(silu_gelu_qdq_params);
    } else if (rope_sin_size > 0) {
        // Packing for GEMM + RoPE
        total_bo_size += rope_sin_size;
        total_bo_size += rope_cos_size;
        total_bo_size += sizeof(rope_qdq_params);
    } else {
        if (is_rope) {
            total_bo_size += sizeof(rope_qdq_params);
        }
        if (is_elew) {
            total_bo_size += sizeof(elew_qdq_params);
        }
    }

    return total_bo_size;
}

/*
 * Packs weights, QDQ coefficients and fused-op parameters into one constant
 * buffer, advancing a running byte offset (write_offset) after every copy.
 *
 * Layout written to const_bo_ptr, in order:
 *   for each head, N_shard, K_shard:
 *     - one weight subvolume: (Ksubv * Nsubv * wgt_bits) / 8 bytes are copied,
 *       and the offset advances by that size rounded up to a multiple of 64;
 *     - coefficients for the Nsubv columns of this N_shard:
 *         Vec_coeffs > 1 : interleaved per group of Ngran columns as
 *                          [Ngran x Tc0][Ngran x Tc1][Ngran x Tc2];
 *         otherwise      : Nsubv x Tc0 from C0_vec_ptr only.
 *   then gemm_qdq_params,
 *   then lut_ab (if lut_ab_size > 0),
 *   then lut_cd followed by silu_gelu_qdq_params (if lut_cd_size > 0),
 *   then rope_qdq_params (if is_rope_fused),
 *   then the rope sin table (if rope_sin_size > 0),
 *   then the rope cos table (if rope_cos_size > 0),
 *   then elew_qdq_params (if is_elew_fused).
 *
 * No bound on const_bo_ptr is checked here; the caller must size the buffer.
 * Progress offsets are printed to stdout.
 *
 * NOTE(review): in the Vec_coeffs <= 1 path the offset advances by
 * Nsubv * sizeof(uint64_t), while calc_const_buffer_size budgets
 * Nsubv * sizeof(Tc0) for the same slot; these agree only when Tc0 is
 * 8 bytes wide -- confirm against callers.
 * NOTE(review): calc_const_buffer_size does not add rope/elew parameter sizes
 * when lut_ab_size > 0, whereas this function writes them whenever the
 * corresponding flag is set -- confirm those combinations are never used.
 */
template<typename Tc0, typename Tc1, typename Tc2>
void pack_wgt_bias_qdq_fused_ops(
    void* const_bo_ptr,
    void* wgt_ptr,
    void* C0_vec_ptr,
    void* C1_vec_ptr,
    void* C2_vec_ptr,
    int Kgemm,
    int Ksubv,
    int Ngemm,
    int Nsubv,
    int num_heads,
    int wgt_bits=8,
    int Vec_coeffs=1,
    int Ngran=8,
    int lut_ab_size=0,
    int lut_cd_size=0,
    void* lut_ab_ptr=nullptr,
    void* lut_cd_ptr=nullptr,
    int rope_sin_size=0,
    int rope_cos_size=0,
    void* rope_sin_ptr=nullptr,
    void* rope_cos_ptr=nullptr,
		bool is_rope_fused=false,
		bool is_elew_fused=false
){
    // Destination stride of one weight subvolume: payload rounded up to 64 bytes.
    int raw_wgt_subv_size = (Ksubv * Nsubv * wgt_bits) / 8;
    if(raw_wgt_subv_size % 64 != 0) raw_wgt_subv_size = round_to_multiple(raw_wgt_subv_size, 64);
    printf("raw_wgt_subv_size = %d\n", raw_wgt_subv_size);
    // Running byte offset into const_bo_ptr.
    int write_offset=0;
    for(int head = 0; head < num_heads; head++){
        // Byte offsets of this head's data inside each source array.
        int head_c0_ptr = head * Ngemm * sizeof(Tc0);
        int head_c1_ptr = head * Ngemm * sizeof(Tc1);
        int head_c2_ptr = head * Ngemm * sizeof(Tc2);
        int head_wgt_ptr = head * Kgemm * Ngemm * wgt_bits / 8;
        for(int N_shard = 0; N_shard < (Ngemm)/(Nsubv); N_shard++) {
            for(int K_shard = 0; K_shard < (Kgemm)/(Ksubv); K_shard++) {
                // Weight subvolume (N_shard, K_shard): only the unpadded payload is
                // copied, but the offset advances by the padded stride.
                memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), (void*)(static_cast<uint8_t*>(wgt_ptr)+(head_wgt_ptr)+(N_shard*Kgemm*Nsubv*wgt_bits/8)+(K_shard*Ksubv*Nsubv*wgt_bits/8)), (Ksubv * Nsubv * wgt_bits) / 8);
                write_offset+=raw_wgt_subv_size;
                if(Vec_coeffs > 1){
                    // Per-column coefficients, interleaved in groups of Ngran columns:
                    // C0 group, then C1 group, then C2 group.
                    for(int Nchunk = 0; Nchunk < Nsubv; Nchunk+=Ngran) {
                        memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), (void*)(static_cast<uint8_t*>(C0_vec_ptr)+(head_c0_ptr)+(N_shard*Nsubv*sizeof(Tc0))+(Nchunk*sizeof(Tc0))), (Ngran * sizeof(Tc0)));
                        write_offset+=(Ngran * sizeof(Tc0));
                        memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), (void*)(static_cast<uint8_t*>(C1_vec_ptr)+(head_c1_ptr)+(N_shard*Nsubv*sizeof(Tc1))+(Nchunk*sizeof(Tc1))), (Ngran * sizeof(Tc1)));
                        write_offset+=(Ngran * sizeof(Tc1));
                        memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), (void*)(static_cast<uint8_t*>(C2_vec_ptr)+(head_c2_ptr)+(N_shard*Nsubv*sizeof(Tc2))+(Nchunk*sizeof(Tc2))), (Ngran * sizeof(Tc2)));
                        write_offset+=(Ngran * sizeof(Tc2));
                    }
                } else {
                    // Only C0 is packed: Nsubv * sizeof(Tc0) bytes are copied, but the
                    // offset advances by Nsubv * sizeof(uint64_t) (see NOTE(review) above).
                    memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), (void*)(static_cast<uint8_t*>(C0_vec_ptr)+(head_c0_ptr)+(N_shard*Nsubv*sizeof(Tc0))), (Nsubv * sizeof(Tc0)));
                    write_offset+=Nsubv*sizeof(uint64_t);
                }
            }
        }
    }
    printf("write_offset:%d before gemm_qdq_params\n",write_offset);
    // Copy qdq params
    memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), (void*)gemm_qdq_params, sizeof(gemm_qdq_params));
    write_offset+=sizeof(gemm_qdq_params);
    // Conditional copy luts
    if (lut_ab_size > 0){
        memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), lut_ab_ptr, lut_ab_size);
        write_offset+=lut_ab_size;
    }
    if (lut_cd_size > 0){
        memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), lut_cd_ptr, lut_cd_size);
        write_offset+=lut_cd_size;
        // silu_gelu_qdq_params follows the second LUT and is written only in this branch.
        memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), (void*)silu_gelu_qdq_params, sizeof(silu_gelu_qdq_params));
        write_offset+=sizeof(silu_gelu_qdq_params);
    }
    printf("write_offset:%d before rope_qdq_params\n",write_offset);
    // Conditional copy of rope qdq params
    if (is_rope_fused){
        memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), (void*)rope_qdq_params, sizeof(rope_qdq_params));
        write_offset+=sizeof(rope_qdq_params);
    }
    printf("write_offset:%d before rope_sin\n",write_offset);
    // Conditional copy of the rope sin table
    if (rope_sin_size > 0){
        memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), rope_sin_ptr, rope_sin_size);
        write_offset+=rope_sin_size;
    }
    printf("write_offset:%d before rope_cos\n",write_offset);
    // Conditional copy of the rope cos table
    if (rope_cos_size > 0){
        memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), rope_cos_ptr, rope_cos_size);
        write_offset+=rope_cos_size;
    }
		printf("write_offset:%d before elew_qdq_params\n",write_offset);
		// Conditional copy of elew qdq params
    if (is_elew_fused){
        memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), (void*)elew_qdq_params, sizeof(elew_qdq_params));
        write_offset+=sizeof(elew_qdq_params);
		}
}

/*
 * Computes the byte size of the constant buffer for a block-quantized GEMM.
 *
 * The size is made of:
 *   - per weight subvolume: the weight bytes (Ksubv * Nsubv * wgt_bits / 8,
 *     rounded up to a multiple of 64) plus the coefficient bytes. With
 *     Vec_coeffs > 1 the coefficients are Nsubv entries of Tc0, a zero-point
 *     area of (Ksubv / block_size) * Nsubv * wgt_bits / 8 bytes rounded up to
 *     a multiple of 64, and (Ksubv / block_size) * Nsubv entries of Tc2;
 *     otherwise only the Nsubv entries of Tc0. All of this is multiplied by
 *     the number of subvolumes and heads;
 *   - sizeof(gemm_qdq_params);
 *   - exactly one of the following, chosen in this order:
 *       lut_ab_size > 0   : both LUT sizes + sizeof(silu_gelu_qdq_params)
 *       rope_sin_size > 0 : sin/cos table sizes + sizeof(rope_qdq_params)
 *       otherwise         : sizeof(rope_qdq_params) if is_rope, and
 *                           sizeof(elew_qdq_params) if is_elew.
 *
 * A summary of the computed sizes is printed to stdout. Mgemm and Msubv are
 * accepted but do not influence the result.
 */
template<typename Tc0, typename Tc1, typename Tc2>
size_t calc_const_buffer_size_block_qdq(
    int Mgemm,
    int Kgemm,
    int Ngemm,
    int Msubv,
    int Ksubv,
    int Nsubv,
    int num_heads,
    int Vec_coeffs=1,
    int wgt_bits=8,
    int block_size=64,
    int lut_ab_size=0,
    int lut_cd_size=0,
    int rope_sin_size=0,
    int rope_cos_size=0,
    bool is_rope = false,
    bool is_elew = false
){
    const int no_wgt_subvols = (Kgemm*Ngemm)/(Ksubv*Nsubv);

    // Weight bytes per subvolume, padded to a 64-byte multiple.
    size_t raw_wgt_subv_size = ( Ksubv * Nsubv * wgt_bits ) / 8;
    if (raw_wgt_subv_size % 64 != 0) {
        raw_wgt_subv_size = round_to_multiple(raw_wgt_subv_size, 64);
    }

    // Coefficient bytes per subvolume.
    size_t bias_subv_size = (Nsubv * sizeof(Tc0));
    if (Vec_coeffs > 1) {
        const size_t zp_subv_size = round_to_multiple(((Ksubv/block_size) * Nsubv * wgt_bits)/8, 64);
        const size_t sc_subv_size = ((Ksubv/block_size) * Nsubv * sizeof(Tc2));
        bias_subv_size = (Nsubv * sizeof(Tc0)) +  zp_subv_size + sc_subv_size;
    }

    // Part shared by every configuration: weights + coefficients + GEMM params.
    size_t total_bo_size =
        ((no_wgt_subvols * raw_wgt_subv_size) + (no_wgt_subvols * bias_subv_size))*num_heads
        + sizeof(gemm_qdq_params);

    if (lut_ab_size > 0) {
        total_bo_size += lut_ab_size;
        total_bo_size += lut_cd_size;
        total_bo_size += sizeof(silu_gelu_qdq_params);
    } else if (rope_sin_size > 0) {
        // Packing for GEMM + RoPE
        total_bo_size += rope_sin_size;
        total_bo_size += rope_cos_size;
        total_bo_size += sizeof(rope_qdq_params);
    } else {
        if (is_rope) {
            total_bo_size += sizeof(rope_qdq_params);
        }
        if (is_elew) {
            total_bo_size += sizeof(elew_qdq_params);
        }
    }

    printf("total_bo_size: %zd, raw_wgt_subv_size: %zd, no_wgt_subvols: %d, bias_subv_size: %zd, sizeof(gemm_qdq_params): %zd, is_rope: %d, sizeof(rope_qdq_params): %zd, rope_sin_size: %d, rope_cos_size: %d, is_elew: %d, sizeof(elew_qdq_params): %zd\n", total_bo_size, raw_wgt_subv_size, no_wgt_subvols, bias_subv_size, sizeof(gemm_qdq_params),  is_rope, sizeof(rope_qdq_params), rope_sin_size, rope_cos_size, is_elew, sizeof(elew_qdq_params));

    return total_bo_size;
}

/*
 * Packs weights, block-wise QDQ coefficients and fused-op parameters into one
 * constant buffer, advancing a running byte offset (write_offset) after every
 * copy.
 *
 * Layout written to const_bo_ptr, in order:
 *   for each head, N_shard, K_shard:
 *     - one weight subvolume: (Ksubv * Nsubv * wgt_bits) / 8 bytes are copied,
 *       and the offset advances by that size rounded up to a multiple of 64;
 *     - coefficients:
 *         Vec_coeffs > 1 :
 *           * Nsubv x Tc0 from C0_vec_ptr;
 *           * zero points from C1_vec_ptr, read at wgt_bits per entry: for each
 *             group of Ngran columns, one Ngran-wide slice per K block of this
 *             subvolume; the area is then padded up to zp_subv_size;
 *           * scales from C2_vec_ptr (Tc2 entries) in the same
 *             column-group-major, block-minor order.
 *         otherwise      : Nsubv x Tc0 from C0_vec_ptr only.
 *   then gemm_qdq_params,
 *   then lut_ab (if lut_ab_size > 0),
 *   then lut_cd followed by silu_gelu_qdq_params (if lut_cd_size > 0),
 *   then rope_qdq_params (if is_rope_fused),
 *   then the rope sin table (if rope_sin_size > 0),
 *   then the rope cos table (if rope_cos_size > 0),
 *   then elew_qdq_params (if is_elew_fused).
 *
 * No bound on const_bo_ptr is checked here; the caller must size the buffer.
 * Progress offsets are printed to stdout.
 *
 * NOTE(review): in the Vec_coeffs <= 1 path the offset advances by
 * Nsubv * sizeof(uint64_t), while calc_const_buffer_size_block_qdq budgets
 * Nsubv * sizeof(Tc0) for the same slot; these agree only when Tc0 is
 * 8 bytes wide -- confirm against callers.
 */
template<typename Tc0, typename Tc1, typename Tc2>
void pack_wgt_bias_block_qdq(
    void* const_bo_ptr,
    void* wgt_ptr,
    void* C0_vec_ptr,
    void* C1_vec_ptr,
    void* C2_vec_ptr,
    int Kgemm,
    int Ksubv,
    int Ngemm,
    int Nsubv,
    int num_heads,
    int wgt_bits=8,
    int block_size=64,
    int Vec_coeffs=1,
    int Ngran=8,
    int lut_ab_size=0,
    int lut_cd_size=0,
    void* lut_ab_ptr=nullptr,
    void* lut_cd_ptr=nullptr,
    int rope_sin_size=0,
    int rope_cos_size=0,
    void* rope_sin_ptr=nullptr,
    void* rope_cos_ptr=nullptr,
		bool is_rope_fused=false,
		bool is_elew_fused=false
){
    // Destination stride of one weight subvolume: payload rounded up to 64 bytes.
    int raw_wgt_subv_size = (Ksubv * Nsubv * wgt_bits) / 8;
    if(raw_wgt_subv_size % 64 != 0) raw_wgt_subv_size = round_to_multiple(raw_wgt_subv_size, 64);
    // Size of the zero-point area of one subvolume, rounded up to 64 bytes.
    int zp_subv_size = round_to_multiple(((Ksubv/block_size) * Nsubv * wgt_bits)/8, 64);
    printf("raw_wgt_subv_size = %d\n", raw_wgt_subv_size);
    
    // Calculate number of blocks per subvolume in K dimension
    int num_blocks_per_subv = Ksubv / block_size;
    int num_blocks_total = Kgemm / block_size;
    
    // Running byte offset into const_bo_ptr.
    int write_offset=0;
    for(int head = 0; head < num_heads; head++){
        // Byte offsets of this head's data inside each source array. C1 and C2
        // hold one row of Ngemm entries per K block, hence the num_blocks_total factor.
        int head_c0_ptr = head *  Ngemm * sizeof(Tc0);
        int head_c1_ptr = head * num_blocks_total * (Ngemm * wgt_bits/8);
        int head_c2_ptr = head * num_blocks_total * Ngemm * sizeof(Tc2);
        int head_wgt_ptr = head * Kgemm * Ngemm * wgt_bits / 8;
        for(int N_shard = 0; N_shard < (Ngemm)/(Nsubv); N_shard++) {
            for(int K_shard = 0; K_shard < (Kgemm)/(Ksubv); K_shard++) {
                // Weight subvolume (N_shard, K_shard): only the unpadded payload is
                // copied, but the offset advances by the padded stride.
                memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), (void*)(static_cast<uint8_t*>(wgt_ptr)+(head_wgt_ptr)+(N_shard*Kgemm*Nsubv*wgt_bits/8)+(K_shard*Ksubv*Nsubv*wgt_bits/8)), (Ksubv * Nsubv * wgt_bits) / 8);
                write_offset+=raw_wgt_subv_size;
                if(Vec_coeffs > 1){
                    // Copy C0 once per subvolume
                    memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), (void*)(static_cast<uint8_t*>(C0_vec_ptr)+(head_c0_ptr)+(N_shard*Nsubv*sizeof(Tc0))), (Nsubv * sizeof(Tc0)));
                    write_offset+=(Nsubv * sizeof(Tc0));
                    
                    // Copy ALL C1 (zero points) for all blocks within this subvolume
                    // Destination order: column group (Ngran wide) outer, K block inner.
                    for(int Nchunk = 0; Nchunk < Nsubv/Ngran; Nchunk++){
                        for(int block_idx = 0; block_idx < num_blocks_per_subv; block_idx++) {
                            // Index of this block along the full K dimension.
                            int block_offset = K_shard * num_blocks_per_subv + block_idx;
                            
                            memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), 
                                (void*)(static_cast<uint8_t*>(C1_vec_ptr)+(head_c1_ptr)+(block_offset * Ngemm * wgt_bits/8)+(N_shard * Nsubv * wgt_bits/8) + ((Nchunk * Ngran * wgt_bits)/8)), 
                                (Ngran * wgt_bits/8));
                            write_offset+=(Ngran * wgt_bits/8);  
                        }
                    }
                    // Skip the padding that rounds the zero-point area up to zp_subv_size.
                    write_offset+=(zp_subv_size - num_blocks_per_subv*(Nsubv * wgt_bits/8));

                    
                    // Copy ALL C2 (scales) for all blocks within this subvolume
                    // Same ordering as the zero points: column group outer, K block inner.
                    for(int Nchunk = 0; Nchunk < Nsubv/Ngran; Nchunk++){
                        for(int block_idx = 0; block_idx < num_blocks_per_subv; block_idx++) {
                            // Index of this block along the full K dimension.
                            int block_offset = K_shard * num_blocks_per_subv + block_idx;
                            
                            memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), 
                                (void*)(static_cast<uint8_t*>(C2_vec_ptr)+(head_c2_ptr)+(block_offset * Ngemm * sizeof(Tc2))+(N_shard * Nsubv * sizeof(Tc2)) + (Nchunk * Ngran * sizeof(Tc2))), 
                                (Ngran * sizeof(Tc2)));
                            write_offset+=(Ngran * sizeof(Tc2));
                        }
                    }
                } else {
                    // Only C0 is packed: Nsubv * sizeof(Tc0) bytes are copied, but the
                    // offset advances by Nsubv * sizeof(uint64_t) (see NOTE(review) above).
                    memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), (void*)(static_cast<uint8_t*>(C0_vec_ptr)+(head_c0_ptr)+(N_shard*Nsubv*sizeof(Tc0))), (Nsubv * sizeof(Tc0)));
                    write_offset+=Nsubv*sizeof(uint64_t);
                }
            }
        }
    }
    printf("write_offset:%d before gemm_qdq_params\n",write_offset);
    // Copy qdq params
    memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), (void*)gemm_qdq_params, sizeof(gemm_qdq_params));
    write_offset+=sizeof(gemm_qdq_params);
    // Conditional copy luts
    if (lut_ab_size > 0){
        memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), lut_ab_ptr, lut_ab_size);
        write_offset+=lut_ab_size;
    }
    if (lut_cd_size > 0){
        memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), lut_cd_ptr, lut_cd_size);
        write_offset+=lut_cd_size;
        // silu_gelu_qdq_params follows the second LUT and is written only in this branch.
        memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), (void*)silu_gelu_qdq_params, sizeof(silu_gelu_qdq_params));
        write_offset+=sizeof(silu_gelu_qdq_params);
    }
    printf("write_offset:%d before rope_qdq_params\n",write_offset);
    // Conditional copy of rope qdq params
    if (is_rope_fused){
        memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), (void*)rope_qdq_params, sizeof(rope_qdq_params));
        write_offset+=sizeof(rope_qdq_params);
    }
    printf("write_offset:%d before rope_sin\n",write_offset);
    // Conditional copy of the rope sin table
    if (rope_sin_size > 0){
        memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), rope_sin_ptr, rope_sin_size);
        write_offset+=rope_sin_size;
    }
    printf("write_offset:%d before rope_cos\n",write_offset);
    // Conditional copy of the rope cos table
    if (rope_cos_size > 0){
        memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), rope_cos_ptr, rope_cos_size);
        write_offset+=rope_cos_size;
    }
		printf("write_offset:%d before elew_qdq_params\n",write_offset);
		// Conditional copy of elew qdq params
    if (is_elew_fused){
        memcpy((void*)(static_cast<uint8_t*>(const_bo_ptr)+write_offset), (void*)elew_qdq_params, sizeof(elew_qdq_params));
        write_offset+=sizeof(elew_qdq_params);
		}

}

/*
 * Permutes the three axes (head, row, col) of Mat into Mat_T.
 *
 * perm maps each source axis to its destination axis: source axis 0 (head)
 * lands on destination axis perm[0], source axis 1 (row) on perm[1] and
 * source axis 2 (col) on perm[2]. Every element of Mat is copied to the
 * correspondingly permuted position of Mat_T. perm is not validated here.
 */
template<typename Ta, typename Tb>
void transpose_matix (Ta Mat, Tb Mat_T, int const perm[])
{
    int dims[3];
    printf ("Transposing Matrix !!!\n");
    for (int h = 0; h < Mat.num_heads; h++) {
        for (int row = 0; row < Mat.num_rows; row++) {
            for (int col = 0; col < Mat.num_cols; col++) {
                // Scatter the source coordinates to their destination axes,
                // in axis order 0, 1, 2.
                const int src_index[3] = {h, row, col};
                for (int axis = 0; axis < 3; axis++) {
                    dims[perm[axis]] = src_index[axis];
                }

                Mat_T.at(dims[0], dims[1], dims[2]) = Mat.at(h, row, col);
            }
        }
    }
}

#endif // GEMM_HOST_RUNTIME_HPP
