#ifndef SM_CARF_HPP
#define SM_CARF_HPP

#include <assert.h>
#include <math.h>
#include <stdlib.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <iomanip>
#include <ios>
#include <iostream>
#include <limits>
#include <numeric>
#include <vector>

#include "common.hpp"
#include "carf_dc.hpp"


//======================== Custom Op Function ========================
/**
 * @brief Reference row-wise softmax on bfloat16 data stored as uint16_t.
 *
 * For every row the routine
 *   1. overwrites the padding columns [true_cols, num_cols) of X with
 *      mask_val (converted via float2bfloat),
 *   2. subtracts the row maximum and rounds the difference through bfloat16,
 *   3. approximates e^v as (1 + frac(t)) * 2^floor(t) with t = v * log2(e),
 *      applying an extra correction term when t is negative,
 *   4. multiplies each approximated exponential by the bfloat16-rounded
 *      reciprocal of the row sum and stores the bfloat16 result in Y.
 *
 * The order of floating-point operations is deliberately kept as-is so that
 * results stay bit-identical to earlier outputs of this reference.
 *
 * @param X         Input matrix. NOTE: modified in place -- its padding
 *                  columns are overwritten with mask_val.
 * @param Y         Output matrix; must have the same dimensions as X.
 * @param true_cols Number of leading columns holding real data; must lie in
 *                  [0, Y.num_cols].
 * @param mask_val  Value written into the padding columns of X before the
 *                  softmax is evaluated.
 * @param xscale    Unused by this implementation; kept so existing callers
 *                  continue to compile.
 *
 * Defined `inline` because this definition lives in a header: without it,
 * including the header from more than one translation unit violates the
 * one-definition rule and fails at link time.
 */
inline void softmax_newref
(
    RowMajorMatrix<uint16_t>& X,
    RowMajorMatrix<uint16_t>& Y,
    int true_cols,
    float mask_val,
    float xscale
)
{
    (void)xscale; // intentionally unused, see the parameter documentation

    assert(Y.num_rows == X.num_rows);
    assert(Y.num_cols == X.num_cols);
    assert(true_cols >= 0);
    assert(Y.num_cols >= true_cols);

    const int height = Y.num_rows;
    const int width  = Y.num_cols;

    // Element count and row offsets are computed in size_t so that
    // height * width cannot overflow a 32-bit int on large matrices.
    const std::size_t row_stride = static_cast<std::size_t>(width);
    const std::size_t total = static_cast<std::size_t>(height) * row_stride;

    // Round a float through bfloat16 and back to float.
    const auto round_to_bf16 = [](float v) -> float {
        return bfloat16_to_float(float_to_bfloat16(v).value);
    };

    // Mask the padding columns of the input.
    for (int y = 0; y < height; y++) {
        for (int x = true_cols; x < width; x++) {
            X.at(y, x) = float2bfloat(mask_val);
        }
    }

    // Step 1: subtract the row maximum from every element; the difference is
    // rounded through bfloat16 and kept as float in row-major order.
    std::vector<float> bf_values(total);
    for (int y = 0; y < height; y++) {
        const std::size_t row = static_cast<std::size_t>(y) * row_stride;

        float rowmax = std::numeric_limits<float>::lowest();
        for (int x = 0; x < width; x++) {
            rowmax = std::max((bfloat16_to_float(X.at(y, x))), rowmax);
        }
        for (int x = 0; x < width; x++) {
            const float f_sub = bfloat16_to_float(X.at(y, x)) - rowmax;
            bf_values[row + x] = round_to_bf16(f_sub);
        }
    }

    // Step 2: approximate e^(x - rowmax) for every element.
    // log2(e), rounded through bfloat16.
    const float const_base_exp = round_to_bf16(std::log2(std::exp(1.0f)));
    std::vector<float> exp_values(total);
    for (std::size_t i = 0; i < total; ++i) {
        const float val = bf_values[i] * const_base_exp;
        // Per the original author's note: dc_f2bf is a minimal
        // re-implementation of the Python helper dc.f2bf (rounding=False);
        // a plain bfloat16 round trip does not behave identically.
        // TODO: expand to rounding=True.
        const float inp = dc_f2bf(val);

        // 2^inp is approximated as (1 + fractional part) * 2^(integer part).
        const float exponent_part = std::floor(inp);
        float approx_exp = (1 + (inp - exponent_part)) * powf(2, exponent_part);
        if (inp < 0) {
            // Correction for negative inputs: subtract 2^(round(log2|approx|) - 23).
            // std::pow with mixed int/float arguments evaluates in double; that
            // is preserved here to keep the result unchanged.
            const float log2_value = std::log2(std::fabs(approx_exp));
            approx_exp -= std::pow(2, std::round(round_to_bf16(log2_value)) - 23);
        }
        exp_values[i] = dc_f2bf(approx_exp);
    }

    // Steps 3-5: per row, compute the reciprocal of the sum, round it through
    // bfloat16, and scale every exponential by it.
    for (int y = 0; y < height; y++) {
        const std::size_t row = static_cast<std::size_t>(y) * row_stride;

        // The accumulator is a float, so every partial sum is rounded back to
        // float despite the cast to double; kept unchanged for bit-exactness.
        float sum_exp = 0.0f;
        for (int x = 0; x < width; x++) {
            sum_exp += static_cast<double>(exp_values[row + x]);
        }

        const float inv_sum = 1.0f / sum_exp;
        // NOTE(review): this triggers when the reciprocal is exactly zero,
        // while the message mentions a division by zero -- confirm which
        // condition was intended. Left unchanged.
        if (inv_sum==0){ printf("\033[31mWarning: 1/sum div by 0!!!\033[0m\n");}
        const float inv_sum_bf = round_to_bf16(inv_sum);

        for (int x = 0; x < width; ++x) {
            const float final_fp32 = exp_values[row + x] * inv_sum_bf;
            const float final_bf16 = round_to_bf16(final_fp32);
            Y.at(y, x) = float_to_bfloat16(final_bf16).value; //match with aie_Y
        }
    }
}

#endif // SM_CARF_HPP