#include<cmath>
#include<limits>
#include "matrix.hpp"
#include <string.h> // memcpy
#define NUM_ROWS 4
#define NUM_COLS 2
#define ENABLE_BIAS_ADD 0
#define ENABLE_MASK_ADD 0  // only takes effect if ENABLE_BIAS_ADD is 1

// Flat offset of element (y, x) in a row-major buffer whose rows hold W elements.
// H is not used by the computation.
int get_HW_index(int y, int x, int H ,int W)
{
    const int row_start = y * W;
    return row_start + x;
}

// Flat offset of element (y, x) in the "W8" layout: columns are grouped into
// strips of 8, strips are stored one after another, and inside a strip the
// H rows of 8 elements are stored row by row. W is not used by the computation.
int get_HW8_index(int y, int x, int H, int W)
{
	const int strip         = x / 8;   // which 8-column strip contains x
	const int col_in_strip  = x % 8;   // position of x inside that strip
	const int strip_elems   = H * 8;   // elements stored per strip
	const int off_in_strip  = y * 8 + col_in_strip;
	return strip * strip_elems + off_in_strip;
}

// Convert layout from row-major to W8 aligned.
// Every (row, col) element of the H x W source is copied to the position
// get_HW8_index assigns to it in the destination.
template<typename T>  // T == int8 or int16 or int32
void from_HW_to_HW8
(
	T* HW_buf , // src
	T* HW8_buf, // dst
	int H,
	int W
)
{
	for(int row = 0; row < H; ++row) {
		for(int col = 0; col < W; ++col) {
			const T value = HW_buf[get_HW_index(row, col, H, W)];
			HW8_buf[get_HW8_index(row, col, H, W)] = value;
		}
	}
}


// Convert layout from W8 aligned to row-major.
// Inverse of from_HW_to_HW8: every (row, col) element is read from its W8
// position and written to its row-major position.
template<typename T>  // T == int8 or int16 or int32
void from_HW8_to_HW
(
	T* HW8_buf, // src
	T* HW_buf , // dst
	int H,
	int W
)
{
	for(int row = 0; row < H; ++row) {
		for(int col = 0; col < W; ++col) {
			const T value = HW8_buf[get_HW8_index(row, col, H, W)];
			HW_buf[get_HW_index(row, col, H, W)] = value;
		}
	}
}

// Fill a numrows x numcols row-major buffer with pseudo-random values in
// [-16, 15]. The generator is re-seeded with the same fixed seed on every
// call, so each call produces the same sequence.
void randomize(int16_t* buf, int numrows, int numcols)
{
	const int fixed_seed = 0x3785;//time(NULL);
	srand(fixed_seed);

	int16_t* cursor = buf;   // elements are visited in row-major order
	for(int row = 0; row < numrows; ++row) {
		for(int col = 0; col < numcols; ++col) {
			*cursor = (std::rand() % 32 - 16);
			++cursor;
		}
	}
}

void prepare_2_ifm_gmios
(
	int16_t* k_tensor_rmaj,    // kxn (K tensor)
	int16_t* v_tensor_rmaj,    // nxl (V tensor)
	int k_tsor_dim,
	int n_tsor_dim,
	int l_tsor_dim,
	int16_t* ifm_gmio_0,  // Nr x (Nc / 2)
	int16_t* ifm_gmio_1   // Nr x (Nc / 2)
)
{
	int16_t* k_tensor_w8 = (int16_t*) malloc(k_tsor_dim*n_tsor_dim*sizeof(int16_t));
	from_HW_to_HW8<int16_t>(k_tensor_rmaj, k_tensor_w8, k_tsor_dim, n_tsor_dim);

	int16_t* v_tensor_w8 = (int16_t*) malloc(l_tsor_dim*n_tsor_dim*sizeof(int16_t));
	from_HW_to_HW8<int16_t>(v_tensor_rmaj, v_tensor_w8, n_tsor_dim, l_tsor_dim);

	int n_core_dim    = n_tsor_dim / NUM_COLS / NUM_ROWS;
	int cpy_elems     = k_tsor_dim*n_core_dim;
	int cpy_bytes     = cpy_elems*sizeof(int16_t);
	int dst_off_delta = n_core_dim*(k_tsor_dim+l_tsor_dim);

	// For K:
	memcpy(ifm_gmio_0+dst_off_delta*0, k_tensor_w8+cpy_elems*0, cpy_bytes);
	memcpy(ifm_gmio_0+dst_off_delta*1, k_tensor_w8+cpy_elems*1, cpy_bytes);
	memcpy(ifm_gmio_0+dst_off_delta*2, k_tensor_w8+cpy_elems*2, cpy_bytes);
	memcpy(ifm_gmio_0+dst_off_delta*3, k_tensor_w8+cpy_elems*3, cpy_bytes);

	memcpy(ifm_gmio_1+dst_off_delta*0, k_tensor_w8+cpy_elems*4, cpy_bytes);
	memcpy(ifm_gmio_1+dst_off_delta*1, k_tensor_w8+cpy_elems*5, cpy_bytes);
	memcpy(ifm_gmio_1+dst_off_delta*2, k_tensor_w8+cpy_elems*6, cpy_bytes);
	memcpy(ifm_gmio_1+dst_off_delta*3, k_tensor_w8+cpy_elems*7, cpy_bytes);

	// For V:
	int init_offset = k_tsor_dim * n_core_dim;
	cpy_elems = n_core_dim * 8;
	cpy_bytes = cpy_elems*sizeof(int16_t);

	for(int block = 0; block < (l_tsor_dim / 8); block++)
	{
		int blkbase = n_tsor_dim * 8 * block;
		int jump    = n_core_dim * 8 * block;

		memcpy(ifm_gmio_0 + init_offset + dst_off_delta*0 + jump, v_tensor_w8 + blkbase + cpy_elems*0, cpy_bytes);
		memcpy(ifm_gmio_0 + init_offset + dst_off_delta*1 + jump, v_tensor_w8 + blkbase + cpy_elems*1, cpy_bytes);
		memcpy(ifm_gmio_0 + init_offset + dst_off_delta*2 + jump, v_tensor_w8 + blkbase + cpy_elems*2, cpy_bytes);
		memcpy(ifm_gmio_0 + init_offset + dst_off_delta*3 + jump, v_tensor_w8 + blkbase + cpy_elems*3, cpy_bytes);
		memcpy(ifm_gmio_1 + init_offset + dst_off_delta*0 + jump, v_tensor_w8 + blkbase + cpy_elems*4, cpy_bytes);
		memcpy(ifm_gmio_1 + init_offset + dst_off_delta*1 + jump, v_tensor_w8 + blkbase + cpy_elems*5, cpy_bytes);
		memcpy(ifm_gmio_1 + init_offset + dst_off_delta*2 + jump, v_tensor_w8 + blkbase + cpy_elems*6, cpy_bytes);
		memcpy(ifm_gmio_1 + init_offset + dst_off_delta*3 + jump, v_tensor_w8 + blkbase + cpy_elems*7, cpy_bytes);
	}
}

// Convert a row-major Nr x Nc weight matrix to the W8 layout and write the
// same converted data into both weight GMIO buffers.
void prepare_2_wgt_gmios
(
	int16_t* buf_rmaj,
	int Nr,
	int Nc,
	int16_t* wgt_gmio_0,  // Nr x Nc (broadcasted)
	int16_t* wgt_gmio_1   // Nr x Nc (broadcasted)
)
{
	int16_t* interim_buf = (int16_t*) malloc(Nr*Nc*sizeof(int16_t));
	from_HW_to_HW8<int16_t>(buf_rmaj, interim_buf, Nr, Nc);

	int cpy_amt = Nr*Nc*sizeof(int16_t);
	memcpy(wgt_gmio_0, interim_buf, cpy_amt);
	memcpy(wgt_gmio_1, interim_buf, cpy_amt);

	// The W8 copy is scratch data only; release it (previously leaked).
	free(interim_buf);
}


// Copy a row-major buffer element by element into wgt_gmio through an
// ActQMatrix accessor constructed over wgt_gmio.
//
// NOTE(review): the accessor is built with the fixed arguments (512, 1152);
// Nr and Nc are not used, so callers presumably must pass data of exactly
// that size — confirm.
void prepare_1_wgt_gmios
(
	int16_t* buf_rmaj,
	int Nr,
	int Nc,
	int16_t* wgt_gmio  // Nr x Nc (broadcasted)
)
{
	ActQMatrix<int16_t, 32, 96> matQ(512, 1152, wgt_gmio);

	for(int y = 0; y < matQ.num_rows; y++)
	{
		for(int x  = 0; x < matQ.num_cols; x++)
		{
			// Source is plain row-major; matQ.at() decides the destination position.
			matQ.at(y, x) = buf_rmaj[y * matQ.num_cols + x];
		}
	}
}


// Read every element of bufO_w8 through an OutMatrix accessor and store it in
// ofm_gmio in row-major order.
//
// NOTE(review): the accessor is built with the fixed arguments (512, 768);
// Nr and Nc are not used — confirm callers always pass data of that size.
void prepare_1_ofm_gmios
(
	uint32_t* bufO_w8,
	int Nr,
	int Nc,
	uint32_t* ofm_gmio  // Nr x Nc (broadcasted), assumed to be rmaj rlayout
)
{
	OutMatrix<uint32_t, 32, 64> matO(512, 768, bufO_w8);

	for(int y = 0; y < matO.num_rows; y++)
	{
		for(int x  = 0; x < matO.num_cols; x++)
		{
			// Destination is plain row-major; matO.at() decides the source position.
			ofm_gmio[y*matO.num_cols + x] = matO.at(y, x);
		}
	}
}

// Copy the row-major K and V buffers into ifm_gmio through an ActKVMatrix
// accessor: K elements go through atK(), V elements through atV().
// Nr and Nc are not used; the accessor is built with fixed arguments.
void prepare_1_ifm_gmios
(
	int16_t* bufK_rmaj,
	int16_t* bufV_rmaj,
	int Nr,
	int Nc,
	int16_t* ifm_gmio  // Nr x Nc (broadcasted)
)
{
	ActKVMatrix<int16_t, 96, 64, 64, 64> matKV(1152, 512, 512, 768, ifm_gmio);     //512 x (96 + 64) 512 x 160

	// K part
	for(int row = 0; row < matKV.key_rows; ++row)
	{
		for(int col = 0; col < matKV.key_cols; ++col)
		{
			const int16_t k_val = bufK_rmaj[row * matKV.key_cols + col];
			matKV.atK(row, col) = k_val;
		}
	}

	// V part
	for(int row = 0; row < matKV.val_rows; ++row)
	{
		for(int col = 0; col < matKV.val_cols; ++col)
		{
			const int16_t v_val = bufV_rmaj[row * matKV.val_cols + col];
			matKV.atV(row, col) = v_val;
		}
	}
}


// Rebuild a row-major Nr x Nc output matrix from the two output GMIO buffers.
// The two buffers, each holding Nr * (Nc / NUM_COLS) elements, are placed
// back to back in a scratch buffer, which is then converted from the W8
// layout to row-major into buf_rmaj.
void get_ofm_rmaj
(
	uint32_t* buf_rmaj,
	int Nr,
	int Nc,
	uint32_t* ofm_gmio_0,  // Nr x Nc / 2
	uint32_t* ofm_gmio_1   // Nr x Nc / 2
)
{
	uint32_t* interim_buf = (uint32_t*) malloc(Nr*Nc*sizeof(uint32_t));

	int cpy_amt = Nr*(Nc/NUM_COLS)*sizeof(uint32_t);
	memcpy(interim_buf				   , ofm_gmio_0, cpy_amt);
	memcpy(interim_buf+Nr*(Nc/NUM_COLS), ofm_gmio_1, cpy_amt);

	from_HW8_to_HW<uint32_t>(interim_buf, buf_rmaj, Nr, Nc);

	// The concatenated W8 copy is scratch data only; release it (previously leaked).
	free(interim_buf);
}

// Element-wise sum of eight partial Nr_core x Nc_core result planes.
// Planes 0..3 are stored back to back in ofm_gmio_0 and planes 4..7 back to
// back in ofm_gmio_1; the sum is written to buf_reduced.
void get_reduced_ofm
(
	uint32_t* buf_reduced,    // Nr x Nc
	int Nr_core,
	int Nc_core,
	uint32_t* ofm_gmio_0,  // Nr x Nc x 4
	uint32_t* ofm_gmio_1   // Nr x Nc x 4
)
{
	uint32_t* partial[8];
	for(int p = 0; p < 8; ++p)
	{
		if(p < 4)
			partial[p] = &(ofm_gmio_0[  p  *Nr_core*Nc_core]);
		else
			partial[p] = &(ofm_gmio_1[(p-4)*Nr_core*Nc_core]);
	}

	for(int row = 0; row < Nr_core; ++row) {
		for(int col = 0; col < Nc_core; ++col) {
			const int at = row * Nc_core + col;
			uint32_t& out = buf_reduced[at];
			out = 0;
			for(int p = 0; p < 8; ++p)
				out += partial[p][at];
		}
	}
}

// Reference matrix multiply: v_acc (m_subv x n_subv) = ifm (m_subv x k_subv) * W.
// With is_wgt_transposed == true, wgt is read as a k_subv x n_subv matrix
// (element [k, c]); otherwise it is read as an n_subv x k_subv matrix
// (element [c, k]). Each dot product is accumulated in an int and then
// stored into the uint32_t output.
template<typename T1, typename T2>
void ref_mac(T1* ifm, T2* wgt, uint32_t* v_acc, int m_subv, int k_subv, int n_subv, bool is_wgt_transposed=true)
{
	for(int row = 0; row < m_subv; ++row) {          // M
		for(int col = 0; col < n_subv; ++col) {      // N
			int dot = 0;
			for(int depth = 0; depth < k_subv; ++depth) {   // K
				const int lhs_at = row * k_subv + depth;    // ifm[row, depth]
				const int rhs_at = is_wgt_transposed
					? depth * n_subv + col                  // wgtT[depth, col]
					: col * k_subv + depth;                 // wgt [col, depth]
				dot += ifm[lhs_at] * wgt[rhs_at];
			}
			v_acc[row * n_subv + col] = dot;
		}
	}
}

/*
void int32_to_int16(int16_t* dst, uint32_t* src, int height, int width)
{
	for(int y = 0; y < height; y++){
		for(int x = 0; x < width; x++){
			int idx = y*width+x;
			dst[idx] = (int16_t)(src[idx] & 0x0000FFFF);
		}
	}
}
*/


// Truncate a float to bfloat16 precision by clearing the low 16 bits of its
// 32-bit representation (no rounding); the result is returned as a float.
float truncate_to_bf16(float in)
{
	static_assert(sizeof(float) == sizeof(uint32_t), "float is expected to be 32 bits wide");

	// Copy the bits with memcpy: reading a float through a uint32_t* (as the
	// previous version did) violates the strict-aliasing rule.
	uint32_t bits;
	memcpy(&bits, &in, sizeof(bits));
	bits &= 0xFFFF0000u;

	float out;
	memcpy(&out, &bits, sizeof(out));
	return out;
}


// Convert each uint32_t element to float, truncate it to bf16 precision and
// divide by 2^shift; the result is written to dst (height x width, row-major).
void uint32_to_int8_to_fp32(float* dst, uint32_t* src, int height, int width, int shift)
{
	for(int row = 0; row < height; ++row) {
		for(int col = 0; col < width; ++col) {
			const int at = row * width + col;
			const float as_bf16 = truncate_to_bf16((float)src[at]);
			dst[at] = as_bf16 / (float)(1<<shift);
		}
	}
}

// Quantize floats to int16: each value v is mapped to v * 65536 - 32768,
// truncated toward zero and saturated to [-32768, 32767]. For the intended
// input range [0.0, 1.0] (softmax output) this spans the full int16 range.
//
// The previous version converted the (usually negative) intermediate to
// uint32_t, which is undefined for negative values, and let an input of 1.0
// wrap around to -32768 instead of saturating at 32767.
void fp32_to_int16
(
	int16_t* dst,
	float* src,
	int height, int width
)
{
	for(int r = 0; r < height; r++) {
		for(int c = 0; c < width; c++) {

			const float fp_num = src[r*width+c];  // softmax output, range : [0.0, 1.0]
			const double scaled = fp_num * 65536.0 - 32768;

			int16_t quantized;
			if(!(scaled > -32768.0))        // also catches NaN
				quantized = -32768;
			else if(scaled >= 32767.0)
				quantized = 32767;
			else
				quantized = (int16_t)(int32_t)scaled;   // in range: truncates toward zero

			dst[r*width+c] = quantized;
		}
	}
}

// Quantize floats to uint8: each value v is mapped to v * 255, truncated
// toward zero and saturated to [0, 255]. For the intended input range
// [0.0, 1.0] (softmax output) this spans the full uint8 range.
//
// Compared with the previous version: the stray printf("\n") emitted once per
// row is gone (this function has nothing to print), and out-of-range inputs
// saturate instead of going through an undefined negative-to-unsigned
// conversion or wrapping modulo 256.
void fp32_to_uint8
(
	uint8_t* dst,
	float* src,
	int height, int width
)
{
	for(int r = 0; r < height; r++) {
		for(int c = 0; c < width; c++) {

			const float fp_num = src[r*width+c];  // softmax output, range : [0.0, 1.0]
			const double scaled = fp_num * 255.0;

			uint8_t quantized;
			if(!(scaled > 0.0))             // negative values and NaN
				quantized = 0;
			else if(scaled >= 255.0)
				quantized = 255;
			else
				quantized = (uint8_t)(uint32_t)scaled;  // in range: truncates toward zero

			dst[r*width+c] = quantized;
		}
	}
}

/*
void uint32_to_int16(int16_t* dst, uint32_t* src, int height, int width, int shift)
{
	for(int r = 0; r < height; r++) {        // M
		for(int c = 0; c < width; c++) {    // K
			uint32_t number = 0;//src[r*width+c];

			int16_t val16 = (int16_t)(number >> shift);
            //dst[r*width+c] = val16;

			printf("%d, %d\n", r, c);
		}
	}
}
*/

// Right-shift every uint32_t element by `shift` bits and narrow the result
// to int16_t (height x width, row-major).
void uint32_to_int16(int16_t* dst, uint32_t* src, int height, int width, int shift)
{
	for(int row = 0; row < height; ++row)
	{
		for(int col = 0; col < width; ++col)
		{
			const int at = row * width + col;
			const uint32_t shifted = src[at] >> shift;
			dst[at] = (int16_t)shifted;
		}
	}
}

// Right-shift every uint32_t element by `shift` bits and narrow the result
// to uint8_t (height x width, row-major).
void uint32_to_int8(uint8_t* dst, uint32_t* src, int height, int width, int shift)
{
	for(int row = 0; row < height; ++row)
	{
		for(int col = 0; col < width; ++col)
		{
			const int at = row * width + col;
			const uint32_t shifted = src[at] >> shift;
			dst[at] = (uint8_t)shifted;
		}
	}
}

// Dequantize a matrix: out(i, j) = float(in(i, j) - zero_point) * scale.
// The loop bounds come from in_mat; out_mat is written at the same indices.
template<typename T1, typename T2>
void dequant_int8_to_float(RowMajorMatrix<T1> in_mat, RowMajorMatrix<T2> out_mat, T1 zero_point, float scale)
{
    for(int row = 0; row < in_mat.num_rows; ++row) {
        for (int col = 0; col < in_mat.num_cols; ++col) {
            const float centered = float(in_mat.at(row, col) - zero_point);
            out_mat.at(row, col) = (centered * scale);
        }
    }
}

// Quantize a matrix: out(i, j) = T2(in(i, j) * (1.0 / scale)) + zero_point.
// The scaled value is converted to T2 by a plain cast before the zero point
// is added; loop bounds come from in_mat.
template<typename T1, typename T2>  // T1 float, T2 uint8
void quant_float_to_uint8(RowMajorMatrix<T1> in_mat, RowMajorMatrix<T2> out_mat, T2 zero_point, float scale)
{
    for(int row = 0; row < in_mat.num_rows; ++row) {
        for (int col = 0; col < in_mat.num_cols; ++col) {
            const T2 scaled = (T2)(in_mat.at(row, col) * (1.0/scale));
            out_mat.at(row, col) = (scaled + zero_point);
        }
    }
}

// Keep only the lowest byte of every uint32_t element (height x width,
// row-major). The `shift` argument is accepted but not used.
void uint32_to_uint8(uint8_t* dst, uint32_t* src, int height, int width, int shift)
{
	for(int row = 0; row < height; ++row)
	{
		for(int col = 0; col < width; ++col)
		{
			const int at = row * width + col;
			const uint32_t low_byte = src[at] & 0x000000FF;
			dst[at] = (uint8_t)low_byte;
		}
	}
}

// Piecewise-linear approximation of 2^x evaluated on bf16-truncated input:
//   (1 + frac(x)) * 2^floor(x)
// For negative x, 2^(floor(log2(y)) - 23) is subtracted from the result y
// before it is truncated to bf16 again.
float approx_exp2(float x)
{
	x = truncate_to_bf16(x);

	const auto whole = floor(x);                        // integer part of the exponent
	float result = (1 + ( x-whole)) * pow(2.0f, whole);

	if(x < 0)
		result -= pow(2.0f, floor( log2( result ))-23);

	return truncate_to_bf16(result);
}

// Row-wise base-2 softmax over a height x width row-major matrix:
//   dst[y, x] = approx_exp2(src[y, x] - rowmax) / sum_x approx_exp2(src[y, x] - rowmax)
// where rowmax is the largest value in row y.
void softmax(float* dst, float* src, int height, int width)
{
	for(int y = 0; y < height; y++) {
		const int row_base = y*width;

		// Start from the most negative float. numeric_limits<float>::min() is
		// the smallest POSITIVE value, so the previous seed produced a wrong
		// maximum for rows containing only negative entries.
		float rowmax = std::numeric_limits<float>::lowest();
		for(int x = 0; x < width; x++) {
			rowmax = std::max(src[row_base+x], rowmax);
		}

		// Compute each exponential once, park it in dst and accumulate the row sum.
		float rowsum = 0.0f;
		for(int x = 0; x < width; x++) {
			const float pow2 = approx_exp2(src[row_base+x]-rowmax);
			dst[row_base+x] = pow2;
			rowsum += pow2;
		}

		// Normalize the row.
		for(int x = 0; x < width; x++) {
			dst[row_base+x] = dst[row_base+x] / rowsum;
		}
	}
}

// Element-wise widening copy: every element of psrc is cast to uint32_t and
// stored at the same position in pdst (height x width, row-major).
template<typename T>
void elemw_ups(uint32_t* pdst, T* psrc, int height, int width)
{
	for(int row = 0; row < height; ++row)
	{
		for(int col = 0; col < width; ++col)
		{
			const int at = row * width + col;
			const uint32_t widened = (uint32_t)psrc[at];
			pdst[at] = widened;
		}
	}
}

// Element-wise sum of two height x width row-major matrices:
// pdst[i] = psrc1[i] + psrc2[i].
template<typename T>
void elemw_add(T* pdst, T* psrc1, T* psrc2, int height, int width)
{
	for(int row = 0; row < height; ++row)
	{
		for(int col = 0; col < width; ++col)
		{
			const int at = row * width + col;
			pdst[at] = psrc1[at] + psrc2[at];
		}
	}
}

// Fill a height x width row-major matrix by repeating the width-element
// array bcast_array in every row.
template<typename T>
void broadcast_vertically(T* pdst, T* bcast_array, int height, int width)
{
	for(int row = 0; row < height; ++row)
	{
		for(int col = 0; col < width; ++col)
		{
			pdst[row * width + col] = bcast_array[col];
		}
	}
}


// Reference model for the "channel" variant of the quantized Q*K -> softmax -> *V
// pipeline. Steps, in order:
//   1. transpose Q (m_subv x k_subv) into Qt (k_subv x m_subv);
//   2. accumulate Qt * K into qkt_32 with ref_mac and requantize it with
//      qdq_asym_golden (qdq_params[0]);
//   3. dequantize (qdq_params[2]), print the float matrix, apply softmax and
//      quantize the result (qdq_params[3]);
//   4. accumulate V with the quantized softmax into z32 with ref_mac and
//      requantize into o8 with qdq_asym_golden (qdq_params[1]).
// o8 and z32 are outputs owned by the caller; every other buffer is scratch.
//
// NOTE(review): K, V, o8 and z32 are all wrapped as m_subv x k_subv matrices
// and the intermediate products as k_subv x k_subv, regardless of the shape
// comments on the parameters below — confirm the intended shapes.
template<typename T>
void ref_qxkxv_channel
(
	T* q_mat,           // (mxk)
	T* k_mat,           // (nxk)
	T* v_mat,           // (nxl)
	T* o8  ,       // (mxl)
	uint32_t* z32  ,     // (mxl)
	int m_subv, int k_subv,
	gemm_qdq_param<T> qdq_params[4]
)
{
	// Scratch buffers; all of them are released at the end of the function.
	T* qt_mat     = (T *)malloc(k_subv * m_subv * sizeof(T));
	uint32_t* qkt_32      = (uint32_t*)malloc(k_subv * k_subv * sizeof(uint32_t));
	T* qkt_qtzed     = (T *)malloc(k_subv * k_subv * sizeof(T));
	float* qkt_fp32      = (float  *)malloc(k_subv * k_subv * sizeof(float  ));
	float* sfmx_fp32     = (float  *)malloc(k_subv * k_subv * sizeof(float  ));
	T* sfmx_qtzed  = (T*)malloc(k_subv * k_subv * sizeof(T));

	// Matrix wrappers over the caller's buffers and the scratch buffers.
	RowMajorMatrix<T>   Q           (m_subv, k_subv, q_mat	   );
	RowMajorMatrix<T>   Qt          (k_subv, m_subv, qt_mat	   );
	RowMajorMatrix<T>   K		    (m_subv, k_subv, k_mat	   );
	RowMajorMatrix<T>   V		    (m_subv, k_subv, v_mat	   );
	RowMajorMatrix<uint32_t> QKT     (k_subv, k_subv, qkt_32    );
	RowMajorMatrix<T> QKT_QTZED(k_subv, k_subv, qkt_qtzed  );
	RowMajorMatrix<float> QKT_FP32  (k_subv, k_subv, qkt_fp32  );
	RowMajorMatrix<T> SM_U8   (k_subv, k_subv, sfmx_qtzed);
	RowMajorMatrix<float> SM_FP32   (k_subv, k_subv, sfmx_fp32);
	RowMajorMatrix<T> O_S8     (m_subv, k_subv, o8        );
	RowMajorMatrix<uint32_t> Z_S32   (m_subv, k_subv, z32       );

	// Step 1: Qt = transpose(Q)
	for (int i = 0; i < Q.num_rows; ++i) {
		for (int j = 0; j < Q.num_cols; ++j) {
			Qt.at(j, i) = Q.at(i , j);
		}
	}

	// Step 2: integer accumulation followed by requantization.
	ref_mac<T, T>(qt_mat, k_mat, qkt_32, k_subv, m_subv, k_subv);
	qdq_asym_golden<T, T, uint32_t, T>(Qt, K, QKT, qdq_params[0], QKT_QTZED, false);

	// Step 3: dequantize -> softmax -> quantize.
	dequant_int8_to_float(QKT_QTZED, QKT_FP32, T(qdq_params[2].zero_point), qdq_params[2].scale);
	print_matrix(QKT_FP32, "QKT_FP32 = ");
	softmax(sfmx_fp32, QKT_FP32.data, k_subv, k_subv);
	quant_float_to_uint8( SM_FP32, SM_U8, T(qdq_params[3].zero_point), (qdq_params[3].scale));

	// Step 4: second accumulation and final requantization into o8.
	ref_mac<T, T>(v_mat, sfmx_qtzed, z32, m_subv, k_subv, k_subv, false);
	qdq_asym_golden<T, T, uint32_t, T>(V, SM_U8, Z_S32, qdq_params[1], O_S8, true);

	// Release the scratch buffers (previously leaked on every call).
	free(qt_mat);
	free(qkt_32);
	free(qkt_qtzed);
	free(qkt_fp32);
	free(sfmx_fp32);
	free(sfmx_qtzed);
}


// Reference model for the quantized Q*K -> softmax -> *V pipeline. Steps, in order:
//   1. accumulate Q (m x k) with K into qkt_32 (m x n) with ref_mac and
//      requantize it with qdq_asym_golden (qdq_params[0]);
//   2. dequantize (qdq_params[2]), print the float matrix, apply softmax and
//      quantize the result (qdq_params[3]);
//   3. accumulate the quantized softmax (m x n) with V (n x l) into z32 with
//      ref_mac and requantize into o8 with qdq_asym_golden (qdq_params[1]).
// k_mat_transposed is forwarded to ref_mac as its is_wgt_transposed flag, and
// its negation is passed as the last argument of the first qdq_asym_golden call.
// o8 and z32 are outputs owned by the caller; every other buffer is scratch.
template<typename T>
void ref_qxkxv
(
	T* q_mat,           // (mxk)
	T* k_mat,           // (nxk)
	T* v_mat,           // (nxl)
	T* o8  ,       // (mxl)
	uint32_t* z32  ,     // (mxl)
	int m_subv, int k_subv, int n_subv, int l_subv, //(m, k, n, l)
	gemm_qdq_param<T> qdq_params[4],
	bool k_mat_transposed = true
)
{
	// Scratch buffers; all of them are released at the end of the function.
	uint32_t* qkt_32      = (uint32_t*)malloc(m_subv * n_subv * sizeof(uint32_t));
	T* qkt_qtzed     = (T *)malloc(m_subv * n_subv * sizeof(T));
	float* qkt_fp32      = (float  *)malloc(m_subv * n_subv * sizeof(float  ));
	float* sfmx_fp32     = (float  *)malloc(m_subv * n_subv * sizeof(float  ));
	T* sfmx_qtzed  = (T*)malloc(m_subv * n_subv * sizeof(T));

	// Matrix wrappers over the caller's buffers and the scratch buffers.
	RowMajorMatrix<T>   Q           (m_subv, k_subv, q_mat	   );
	RowMajorMatrix<T>   K		    (n_subv, k_subv, k_mat	   );
	RowMajorMatrix<T>   V		    (n_subv, l_subv, v_mat	   );
	RowMajorMatrix<uint32_t> QKT     (m_subv, n_subv, qkt_32    );
	RowMajorMatrix<T> QKT_QTZED(m_subv, n_subv, qkt_qtzed  );
	RowMajorMatrix<float> QKT_FP32  (m_subv, n_subv, qkt_fp32  );
	RowMajorMatrix<T> SM_U8   (m_subv, n_subv, sfmx_qtzed);
	RowMajorMatrix<float> SM_FP32   (m_subv, n_subv, sfmx_fp32);
	RowMajorMatrix<T> O_S8     (m_subv, l_subv, o8        );
	RowMajorMatrix<uint32_t> Z_S32   (m_subv, l_subv, z32       );

	// Step 1: integer accumulation followed by requantization.
	bool perf_tranpose_k = !(k_mat_transposed);
	ref_mac<T, T>(q_mat, k_mat, qkt_32, m_subv, k_subv, n_subv, k_mat_transposed);
	qdq_asym_golden<T, T, uint32_t, T>(Q, K, QKT, qdq_params[0], QKT_QTZED, perf_tranpose_k);

	// Step 2: dequantize -> softmax -> quantize.
	dequant_int8_to_float(QKT_QTZED, QKT_FP32, T(qdq_params[2].zero_point), qdq_params[2].scale);
	print_matrix(QKT_FP32, "QKT_FP32 = ");
	softmax(sfmx_fp32, QKT_FP32.data, m_subv, n_subv);
	quant_float_to_uint8( SM_FP32, SM_U8, T(qdq_params[3].zero_point), (qdq_params[3].scale));

	// Step 3: second accumulation and final requantization into o8.
	ref_mac<T, T>(sfmx_qtzed, v_mat, z32, m_subv, n_subv, l_subv);
	qdq_asym_golden<T, T, uint32_t, T>(SM_U8, V, Z_S32, qdq_params[1], O_S8, false);

	// Release the scratch buffers (previously leaked on every call).
	free(qkt_32);
	free(qkt_qtzed);
	free(qkt_fp32);
	free(sfmx_fp32);
	free(sfmx_qtzed);
}

#if 0
// Copy `numcols` consecutive columns, starting at column `startcol`, out of a
// row-major src_numrows x src_numcols matrix into a densely packed
// src_numrows x numcols destination.
template<class T>
void extract_columns(T* srcMat, int src_numrows, int src_numcols, int startcol, int numcols, T* dstMat)
{
	const int row_bytes = numcols * sizeof(T);
	for(int row = 0; row < src_numrows; ++row)
	{
		T* to         = dstMat + row*numcols;
		const T* from = srcMat + startcol + row*src_numcols;
		memcpy(to, from, row_bytes);
	}
}

// Copy the first `numcols` columns of every row of a row-major
// src_numrows x src_numcols matrix into the destination matrix (row length
// dst_numcols), starting at destination column `startcol`.
// dst_numrows is not used.
template<class T>
void paste_columns(T* srcMat, int src_numrows, int src_numcols,
				   int startcol,								  // startcol in dstMat to paste
				   int numcols, 								  // numcols to paste
				   T* dstMat, int dst_numrows, int dst_numcols)
{
	const int row_bytes = numcols * sizeof(T);
	for(int row = 0; row < src_numrows; ++row)
	{
		T* to         = dstMat + startcol + row*dst_numcols;
		const T* from = srcMat + row*src_numcols;
		memcpy(to, from, row_bytes);
	}
}

// Copy `numrows` consecutive full rows, starting at row `startrow`, out of a
// row-major matrix with src_numcols columns into dstMat.
// src_numrows is not used.
template<class T>
void extract_rows(T* srcMat, int src_numrows, int src_numcols, int startrow, int numrows, T* dstMat)
{
	const int row_bytes = src_numcols * sizeof(T);
	const T* first_row  = srcMat + startrow * src_numcols;
	for(int row = 0; row < numrows; ++row)
	{
		const int row_off = row*src_numcols;   // same offset in source window and destination
		memcpy(dstMat + row_off, first_row + row_off, row_bytes);
	}
}

// Multi-head reference: for each of num_heads heads, slice the per-head Q
// columns, K rows and V columns out of the full tensors, run ref_qxkxv on the
// slices and paste the per-head result columns into z_tensor.
//
// NOTE(review): this function sits inside a disabled `#if 0` region. The
// ref_qxkxv calls below do not pass the qdq_params argument that the
// ref_qxkxv template in this file takes, and b_mat / m_mat (used when
// ENABLE_BIAS_ADD / ENABLE_MASK_ADD are set) are not declared in this
// function — both need updating before the region is re-enabled.
void ref_mha_qxkxv
(
	int16_t* q_tensor,  // (mxk)
	int16_t* k_tensor,  // (kxn)
	int16_t* v_tensor,  // (nxl)
#if ENABLE_BIAS_ADD
	int16_t* b_tensor,
#if ENABLE_MASK_ADD
	int16_t* m_tensor,
#endif
#endif
	uint32_t* z_tensor,  // (mxl)
	int m_tsor, int k_tsor, int n_tsor, int l_tsor
)
{
	int num_heads = 12;
	int k_head = k_tsor / num_heads; // per head k dimension
	int l_head = l_tsor / num_heads; // per head l dimension

	// Per-head scratch buffers, reused for every head and released at the end.
	int16_t* q_mat   = (int16_t*)malloc( m_tsor * k_head * sizeof(int16_t));
	int16_t* k_mat   = (int16_t*)malloc( k_head * n_tsor * sizeof(int16_t));
	int16_t* v_mat   = (int16_t*)malloc( n_tsor * l_head * sizeof(int16_t));
	uint32_t* o_mat   = (uint32_t*)malloc( m_tsor * l_head * sizeof(uint32_t));

	int16_t* o_mat16 = (int16_t*)malloc( m_tsor * l_head * sizeof(int16_t));

	for(int h = 0; h < num_heads; h++) // h == head index
	{
		extract_columns<int16_t>(q_tensor,  m_tsor, k_tsor, h*k_head,  k_head, q_mat);
		extract_rows   <int16_t>(k_tensor,  k_tsor, n_tsor, h*k_head,  k_head, k_mat);
		extract_columns<int16_t>(v_tensor,  n_tsor, l_tsor, h*l_head,  l_head, v_mat);

#if ENABLE_BIAS_ADD
		extract_rows   <int16_t>(b_tensor,  num_heads * m_tsor, n_tsor, h*m_tsor,  m_tsor, b_mat);
#if ENABLE_MASK_ADD
		ref_qxkxv(q_mat, k_mat, v_mat, b_mat, m_mat, o_mat16, o_mat, m_tsor, k_head, n_tsor, l_head);
#else
		ref_qxkxv(q_mat, k_mat, v_mat, b_mat, o_mat16, o_mat, m_tsor, k_head, n_tsor, l_head);
#endif
#else
		ref_qxkxv(q_mat, k_mat, v_mat, o_mat16, o_mat, m_tsor, k_head, n_tsor, l_head);
#endif

		paste_columns<uint32_t>(o_mat   , m_tsor, l_head,
				   			   h*l_head,							  // startcol in dstMat to paste
				   			   l_head  , 							  // numcols to paste
				   			   z_tensor, m_tsor, l_tsor);
	}

	// Release the per-head scratch buffers (previously leaked).
	free(q_mat);
	free(k_mat);
	free(v_mat);
	free(o_mat);
	free(o_mat16);

	return;
}
#endif

// Compare a height x width result buffer against a reference, element by
// element. Prints one line per element (position, both values, relative error
// in percent) and the average relative error, and returns true when that
// average is below 0.2 %.
//
// Elements that match exactly contribute a relative error of 0. Previously a
// matching pair with ref == 0 evaluated 0/0 = NaN, which poisoned the average
// and made identical buffers fail the check. An empty matrix is reported as
// consistent instead of dividing by zero.
//
// NOTE(review): for a mismatch where ref == 0 the relative error is still
// infinite, and for unsigned T the difference ref - buf wraps around when
// buf > ref — confirm whether either case matters for the callers.
template<class T>
bool check_consistent(T* ref, T* buf, int height, int width)
{
	float max_rel_error = 0.0f;
	float sum_rel_error = 0.0f;
	float avg_rel_error = 0.0f;
	for(int y = 0; y < height; y++)
	{

		for(int x = 0; x < width; x++)
		{
			int i = y * width + x;

			float rel_error = 0.0f;
			if(ref[i] != buf[i])
				rel_error = std::abs((float)(ref[i]-buf[i])/(float)abs(ref[i]));

			max_rel_error = fmax(rel_error, max_rel_error);
			sum_rel_error += rel_error;
			printf("@(%3d, %3d)-->(%10d, %10d). Error %4.4f\n", y, x, ref[i], buf[i], rel_error*100);
		}
		printf("\n");
	}

	const int num_elems = height * width;
	if(num_elems > 0)
		avg_rel_error = (sum_rel_error / (float)num_elems);

	//printf("Maximum relative error: %4.4f %% \n", max_rel_error*100);
	// "%%" prints a literal percent sign; the previous "% " was not a valid conversion.
	printf("Average relative error: %4.4f %% \n", avg_rel_error*100);
	return avg_rel_error * 100 < 0.2;
}

// Print a height x width row-major buffer to stdout, one matrix row per line,
// with a dashed separator line before every 32nd row (including row 0).
void vis_buf(uint32_t* buf, int height, int width)
{
	for(int row = 0; row < height; ++row)
	{
		const bool starts_group = (row % 32 == 0);
		if(starts_group)
			printf("\n-------------------------------\n");

		for(int col = 0; col < width; ++col)
		{
			printf("%7d ", buf[row*width+col]);
		}
		printf("\n");
	}
}