// This file acts as the weight (constant buffer) formatter for uniop (unary operators).
// _USE_MATH_DEFINES has to be defined before the first (possibly transitive)
// inclusion of <cmath>, otherwise M_PI is not declared on MSVC.
#define _USE_MATH_DEFINES
#include <cmath>
#include <stdio.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <iostream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include "qdq_utils_aie4.hpp"

#define NUM_ELEMENTS_DQ_Q_BUF 96
#define NUM_RESERVE_FIELDS 25
//
// Quantize / dequantize control block of a unary op.
//
// Layout: 7 control words followed by NUM_RESERVE_FIELDS reserved words
// (32 unsigned ints in total), then the dq_buf and q_buf float arrays.
// The member order and types must not change.
//
struct uniop_qdq_param
{
    // Floating point formats selectable for the dequantize / quantize stages.
    enum FloatType {
        BFLOAT16,
        FLOAT16,
        FLOAT32
    };

    unsigned int dq_enable;      // 1 when dequantization is enabled
    unsigned int q_enable;       // 1 when quantization is enabled
    unsigned int nlf_enable;     // set to 1 by init()
    unsigned int dq_sign_A;      // sign flag supplied to set_uniop_dq()
    unsigned int dq_float_type;  // FloatType used for dequantization
    unsigned int q_sign_O;       // sign flag supplied to set_uniop_q()
    unsigned int q_float_type;   // FloatType used for quantization
    unsigned int reserved[NUM_RESERVE_FIELDS];

    // Data fields up to this point are 128 byte aligned

    float dq_buf[NUM_ELEMENTS_DQ_Q_BUF];
    float q_buf[NUM_ELEMENTS_DQ_Q_BUF];

    // Intentionally leaves every member untouched; call init() and
    // clear_reserve_fields() to obtain a defined state.
    uniop_qdq_param()
    {
    }

    // Reset the control words: dq and q disabled, nlf enabled,
    // both sign flags cleared and both float types set to FLOAT16.
    void init()
    {
        dq_enable = 0;  // false
        dq_sign_A = 0;
        dq_float_type = FLOAT16;

        q_enable = 0;   // false
        q_sign_O = 0;
        q_float_type = FLOAT16;

        nlf_enable = 1; // true
    }

    // Zero the reserved words as well as the complete dq_buf and q_buf arrays.
    void clear_reserve_fields()
    {
        for (unsigned int& word : reserved)
            word = 0;

        for (float& value : dq_buf)
            value = 0;

        for (float& value : q_buf)
            value = 0;
    }

    // Enable dequantization. The zero point is written to dq_buf[0..31] and
    // the scale to dq_buf[64..95]; dq_buf[32..63] is left untouched.
    void set_uniop_dq(float zp_i, float sc_i, int sign_A=0, int dq_f_type=FLOAT16)
    {
        dq_enable = 1;
        dq_sign_A = sign_A;
        dq_float_type = dq_f_type;

        const int lane_count = 32;
        const int scale_base = 64;
        for (int lane = 0; lane < lane_count; ++lane) {
            dq_buf[lane] = zp_i;
            dq_buf[scale_base + lane] = sc_i;
        }
    }

    // Enable quantization. The zero point is written to q_buf[0..31] and
    // the scale to q_buf[64..95]; q_buf[32..63] is left untouched.
    void set_uniop_q(float zp_o, float sc_o, int sign_O=0,  int q_f_type=FLOAT16)
    {
        q_enable = 1;
        q_sign_O = sign_O;
        q_float_type = q_f_type;

        const int lane_count = 32;
        const int scale_base = 64;
        for (int lane = 0; lane < lane_count; ++lane) {
            q_buf[lane] = zp_o;
            q_buf[scale_base + lane] = sc_o;
        }
    }
};


template<typename TG, typename TB, typename TLUT>    // Type of Gamma / Type of Beta
struct UnaryOpWgtTensor
{
    uniop_qdq_param* unary_op_params;  // Qdq and Silu Gelu coefficients
    int const inner_dim;               // "Padded" inner true dimension
    int const data_size;              // number of bytes of Constant bo buffer required for Unary op
    uint8_t* const data;              // Stores the base address of Constant bo buffer of Unary op
    uint32_t lut_len;                 // Stores length of the LUT table
    static uint8_t const num_lut_elements_each_lut_entry = 2;
    static uint8_t const num_lut_entry_copies_in_each_lut_table = 2;
    static uint8_t const num_lut_tables = 2;
    static uint8_t const num_copies_of_each_lut_table_entry = num_lut_elements_each_lut_entry * num_lut_entry_copies_in_each_lut_table * num_lut_tables;

    UnaryOpWgtTensor(int C, uint32_t lut_len, void* data)
    : inner_dim(C)
    , data_size(size(C, lut_len))
    , data(static_cast<uint8_t*>(data))
    , lut_len(lut_len * num_lut_elements_each_lut_entry * num_lut_entry_copies_in_each_lut_table) // each LUT entry has two values (slope and offset) and every 16 elements in the calcuated lut table is copied twice in the final lut table
    {
        unary_op_params = (uniop_qdq_param*)data;
        unary_op_params->init();
        unary_op_params->clear_reserve_fields();
    }

    // It is expected the runtime use this function to calculate the allocation size for constant buffer
    static int size(int C, uint32_t lut_len)
    {
        //
        // each lut table entry has two values (slope and offset) and every 16 elements in the calculated lut table is copied twice in the final lut table.
        //
        return sizeof(uniop_qdq_param) + (sizeof(TG) + sizeof(TB)) * C + sizeof(TLUT) * lut_len * num_copies_of_each_lut_table_entry;
    }

    TG& Gamma_at(int idx) // Assuming Gamma is 1 x C, which is a vector (one dimensional)
    {
        assert((0 <= idx) && (idx < inner_dim));
        auto ptr = reinterpret_cast<TG*>(data + sizeof(uniop_qdq_param) + inner_dim * sizeof(TB));
        return ptr[0];
    }

    TB& Beta_at(int idx) // Assuming Beta  is 1 x C, which is a vector (one dimensional)
    {
        assert((0 <= idx) && (idx < inner_dim));
        auto ptr = reinterpret_cast<TB*>(data + sizeof(uniop_qdq_param));
        return ptr[0];
    }

    //
    // Routine to return pointer to an element in lut ab
    //
    TLUT& Lut_ab_at(int idx)
    {
        assert((0 <= idx) && (idx < lut_len));
        auto ptr = reinterpret_cast<TLUT*>(data + sizeof(uniop_qdq_param) + inner_dim * sizeof(TB) + inner_dim * sizeof(TG));
        return ptr[idx];
    }

    //
    // Routine to return pointer to an element in lut cd
    //
    TLUT& Lut_cd_at(int idx)
    {
        assert((0 <= idx) && (idx < lut_len));
        auto ptr = reinterpret_cast<TLUT*>(data + sizeof(uniop_qdq_param) + inner_dim * sizeof(TB) + inner_dim * sizeof(TG) + lut_len * sizeof(TLUT));
        return ptr[idx];
    }

    void set_dq(float zp_i, float sc_i, int sign_A=0, int dq_f_type=uniop_qdq_param::FloatType::FLOAT16)
    {
        unary_op_params->set_uniop_dq(zp_i, sc_i, sign_A, dq_f_type);
    }

    void set_q(float zp_o, float sc_o, int sign_O=0, int q_f_type=uniop_qdq_param::FloatType::FLOAT16)
    {
        unary_op_params->set_uniop_q(zp_o, sc_o, sign_O, q_f_type);
    }
};

//
// Compare a reference tensor (cpu_Y) with a device result tensor (aie_Y) and
// print error statistics (RMSE, max error, MAE, relative error percentages).
//
// Both tensor types must provide the extents Y, X, C and an element accessor
// at(c, y, x). Only the extents of cpu_Y are used for iteration.
//
// max_relative_err_percentage_tolerance : an element whose relative error
//     percentage exceeds this value is counted in "Error Count".
// output_fp16 : true  -> errors are computed on the values converted to float;
//               false -> errors are computed on the raw 16 bit integer values.
// verbose     : print one line per element.
//
// Returns the average relative error percentage. DI_FAIL is printed when that
// average exceeds 0.8, DI_PASS otherwise.
//
template<typename Ta, typename Tb>
float check_result_rmse(Ta cpu_Y, Tb aie_Y, float max_relative_err_percentage_tolerance, bool output_fp16 = true, bool verbose = true)
{
    //
    // The "is it already float?" decision has to look at the ELEMENT type
    // returned by at(). Ta / Tb are tensor types (they expose Y, X, C and
    // at()), so comparing them with float directly can never be true.
    //
    using cpu_elem_t = typename std::decay<decltype(cpu_Y.at(0, 0, 0))>::type;
    using aie_elem_t = typename std::decay<decltype(aie_Y.at(0, 0, 0))>::type;
    const bool cpu_elem_is_float = std::is_same<cpu_elem_t, float>::value;
    const bool aie_elem_is_float = std::is_same<aie_elem_t, float>::value;

    int err_count = 0;
    double SumErSq = 0.0;
    float SumAbsErr = 0.0;
    float relative_error = 0.0;
    float max_error = 0.0;
    float maxEP = 0.0;

    //
    // Converter for 16 bit encoded elements and the magnitude below which the
    // relative error of an element is ignored.
    //
#if __IS_QDQ_FP16__
    float (*funcPtr)(uint16_t) = fp16_to_fp32;
    double val_threshold = 5*1e-5;
#else
    float (*funcPtr)(uint16_t) = bfloat2float;
    double val_threshold = 5*1e-8;
#endif

    for (int b = 0; b < cpu_Y.Y; ++b) {
        for (int r = 0; r < cpu_Y.X; ++r) {
            for (int c = 0; c < cpu_Y.C; ++c) {

                // fp16_to_fp32 or bfloat2float, unless the element already is a float
                float cpu_val = cpu_elem_is_float ? cpu_Y.at(c, b, r) : funcPtr(cpu_Y.at(c, b, r));
                float aie_val = aie_elem_is_float ? aie_Y.at(c, b, r) : funcPtr(aie_Y.at(c, b, r));

                // Raw values, used when the outputs are integers (output_fp16 == false)
                uint16_t cpu_val_u16 = cpu_Y.at(c, b, r);
                uint16_t aie_val_u16 = aie_Y.at(c, b, r);

                float err = (output_fp16)? std::abs(cpu_val - aie_val) : std::abs(cpu_val_u16 - aie_val_u16);
                SumAbsErr += err;
                SumErSq += static_cast<double>(err) * static_cast<double>(err);
                double abs_cpu_val = (output_fp16)? std::abs(cpu_val) : std::abs(cpu_val_u16);
                double relative_err_percentage;
                relative_err_percentage = (abs_cpu_val == 0.0) ? 0.0 : (err / abs_cpu_val) * 100;

                // Reference values close to zero would blow up the relative error
                if(abs_cpu_val < val_threshold)
                    relative_err_percentage = 0;
                relative_error += relative_err_percentage;

                if (err > max_error)
                    max_error = err;

                if (relative_err_percentage > maxEP)
                    maxEP = relative_err_percentage;

                if (relative_err_percentage > max_relative_err_percentage_tolerance)
                {
                    err_count += 1;
                }

                if (verbose)
                {
                    if(!output_fp16)
                        printf("Y[%4d, %4d]: Expected: %5hu, Received: %5hu, Pct Diff: %4.4f, Error : %4.8e\n", r,c, cpu_val_u16, aie_val_u16, relative_err_percentage, err);
                    else
                        printf("Y[%4d, %4d]: Expected: %4.8e, Received: %4.8e, Pct Diff: %4.4f, Error : %4.8e\n", r,c, cpu_val, aie_val, relative_err_percentage, err);
                }
            }
        }
    }
    int total_number_elements = (cpu_Y.Y * cpu_Y.X * cpu_Y.C);
    std::cout << "total_number_elements: "  << total_number_elements << "\n";

    // NOTE: for an empty tensor the averages below are 0/0 (NaN) and DI_PASS is printed.
    const float element_count = (float)(cpu_Y.Y * cpu_Y.X * cpu_Y.C);
    float RMSE = std::sqrt((float)SumErSq / element_count);
    float MAE = SumAbsErr / element_count;
    float average_relative_error = relative_error / element_count;
    std::cout << "Root Mean square Error = " << RMSE << "\n";
    std::cout << "Max Error = " << max_error << "\n";
    std::cout << "Mean Absolute Error = " << MAE << "\n";
    std::cout << "Average Relative Error Percentage = " << average_relative_error << "\n";
    std::cout << "Error Count = " << err_count << "\n";
    std::cout << "Max Relative Error Percentage = " << maxEP << "\n";

    if (average_relative_error > 0.8) {
        printf("DI_FAIL: Y=%d X=%d C=%d\n", cpu_Y.Y, cpu_Y.X, cpu_Y.C);
    } else {
        printf("DI_PASS: Y=%d X=%d C=%d\n", cpu_Y.Y, cpu_Y.X, cpu_Y.C);
    }
    return average_relative_error;
}

//
// Routine to calculate the piecewise linear LUT (one offset / slope pair per
// segment) for the swish, tanh, sigmoid, silu, gelu and elu operators and to
// store it in the lut ab and lut cd regions of the given wgt tensor.
//
// function      : operator name; for any other name all samples stay 0.
// lut_start_val : start of the approximated input range.
// lut_end_val   : end of the approximated input range.
// lut_len       : number of LUT segments. The resulting 2 * lut_len values
//                 must be a non-zero multiple of the LUT block size (32) and
//                 must fit into the LUT regions of the wgt tensor.
// ptr_uniop_wgt_tensor : destination wgt tensor, must not be null.
//
// Throws std::invalid_argument when one of the above requirements is violated
// (previously these cases read or wrote out of bounds).
//
template <typename TGDQ, typename TBDQ, typename TLUT>
void uniop_calculate_lut(std::string function,
                        int32_t lut_start_val,
                        int32_t lut_end_val,
                        uint32_t lut_len,
                        UnaryOpWgtTensor<TGDQ, TBDQ, TLUT> *ptr_uniop_wgt_tensor)
{
    const uint8_t LUT_BLOCK_SIZE = 32;

    if (ptr_uniop_wgt_tensor == nullptr) {
        throw std::invalid_argument("uniop_calculate_lut: wgt tensor pointer is null");
    }

    auto cfg = load_json("uniop_cfg.json");
    float alpha = 0.0f;     // only meaningful for swish and elu
    //
    // Calculate the LUT linear space for the given LUT range and LUT length
    //
    std::vector<float> LUT_indx_vec = linspace(lut_start_val, lut_end_val, lut_len + 1);

    //
    // Validate the LUT geometry before anything is allocated or written.
    //
    if (lut_len == 0 || LUT_indx_vec.size() < 2) {
        throw std::invalid_argument("uniop_calculate_lut: lut_len must be greater than 0");
    }
    const size_t num_lut_values = (LUT_indx_vec.size() - 1) * 2;
    if (num_lut_values % LUT_BLOCK_SIZE != 0) {
        throw std::invalid_argument("uniop_calculate_lut: number of LUT values must be a multiple of the LUT block size (32)");
    }
    //
    // Every block is stored twice, so each LUT region receives 2 * num_lut_values values.
    //
    if (2 * num_lut_values > ptr_uniop_wgt_tensor->lut_len) {
        throw std::invalid_argument("uniop_calculate_lut: LUT does not fit into the wgt tensor");
    }

    //
    // Calculate LUT step size
    //
    float step_size = ((float)lut_end_val - (float)lut_start_val) / (float)lut_len;

    //
    // Allocate vector for LUT samples
    //
    std::vector<float> nlf_samples(LUT_indx_vec.size());
    //
    // Allocate a vector to store LUT values
    // the vector size is multiplied by 2 to store both slope and offset
    //
    std::vector<TLUT> nlf_LUT(num_lut_values);

    if (function == "swish" || function == "elu") {
        //
        // Retrieve alpha value from config json file
        //
        alpha = extract_json_float(cfg, "PWLA_ALPHA");
    }
    //
    // Calculate LUT samples based on optype
    //
    for (size_t i = 0; i < LUT_indx_vec.size(); ++i) {
        const float x = LUT_indx_vec[i];
        if (function == "swish") {
            nlf_samples[i] = x * (1 / (1 + std::exp(-alpha * x)));
        } else if (function == "tanh") {
            float e_pow_x = std::exp(x);
            float e_pow_minus_x = std::exp(-x);
            nlf_samples[i] = (e_pow_x - e_pow_minus_x) / (e_pow_x + e_pow_minus_x);
        } else if (function == "sigmoid") {
            nlf_samples[i] = 1 / (1 + std::exp(-x));
        } else if (function == "silu") {
            nlf_samples[i] = x * (1 / (1 + std::exp(-x)));
        } else if (function == "gelu") {
            // tanh approximation of gelu
            float tanh_x = (std::sqrt(2.0f / (float)M_PI) * (x + (0.044715f * std::pow(x, 3.0f))));
            float e_pow_x = std::exp(tanh_x);
            float e_pow_minus_x = std::exp(-tanh_x);
            float tanh_val = (e_pow_x - e_pow_minus_x) / (e_pow_x + e_pow_minus_x);
            nlf_samples[i] = 0.5f * x * (1 + tanh_val);
        } else if (function == "elu") {
            if (x < 0) {
                nlf_samples[i] = alpha * (std::exp(x) - 1);
            }
            else {
                nlf_samples[i] = x;
            }
        }
    }

    //
    // Calculate slope and offset and copy the values to LUT table.
    // Each segment is stored as the pair (offset, slope).
    //
    for (size_t i = 0, j = 0; i < nlf_samples.size() - 1; i++, j += 2) {
        float slope = (nlf_samples[i + 1] - nlf_samples[i]) / step_size;
        float offset = (nlf_samples[i] - slope * LUT_indx_vec[i]);
        //
        // check if TLUT size is uint16 if so convert float to either fp16 or bfloat16
        //

        if (sizeof(TLUT) == sizeof(uint16_t)) {
#if (__IS_QDQ_FP16__)
            nlf_LUT[j] = float32_to_float16(offset);
            nlf_LUT[j + 1] = float32_to_float16(slope);
#else
            nlf_LUT[j] = float_to_bfloat16(offset).value;
            nlf_LUT[j + 1] = float_to_bfloat16(slope).value;
#endif
        } else {
            nlf_LUT[j] = offset;
            nlf_LUT[j + 1] = slope;
        }
    }

    //
    // Adjust first and last slope and offset values to reflect the graph shape below and above LUT range
    //
    //
    // LUT is two times the size of nlf samples - 1.
    // E.g. for lut_len = 512 the nlf_samples size is lut_len + 1 = 513.
    // LUT size = (513 - 1) * 2 = 1024.
    // Last entry has both slope and offset values so -2.
    //
    size_t last_entry = num_lut_values - 2;
    if (function == "silu" || function == "gelu") {
        // first segment: offset 0, slope 0; last segment: offset 0, slope 1
        nlf_LUT[0] = 0;
        nlf_LUT[1] = 0;
        nlf_LUT[last_entry] = 0;
        if (sizeof(TLUT) == sizeof(uint16_t)) {
#if (__IS_QDQ_FP16__)
            nlf_LUT[last_entry + 1] = float32_to_float16(1.0f);
#else
            nlf_LUT[last_entry + 1] = float_to_bfloat16(1.0f).value;
#endif
        }
        else {
            nlf_LUT[last_entry + 1] = 1;
        }
    }

    //
    // Copy the calculated LUT values to wgt buffer
    // Every 16 entries (32 values) of the calculated LUT is copied twice into the wgt buffer
    //
    TLUT *lnr_lutab = &ptr_uniop_wgt_tensor->Lut_ab_at(0);
    TLUT *lnr_lutcd = &ptr_uniop_wgt_tensor->Lut_cd_at(0);
    for (size_t i = 0; i < nlf_LUT.size(); i += LUT_BLOCK_SIZE) {
        //
        // Copy entries indicated in LUT BLOCK SIZE in the current LUT BLOCK in lut ab.
        //
        std::copy(nlf_LUT.begin() + i, nlf_LUT.begin() + i + LUT_BLOCK_SIZE, lnr_lutab + 2 * i);
        //
        // Copy the same entries indicated in LUT BLOCK SIZE to next LUT BLOCK in lut ab.
        //
        std::copy(nlf_LUT.begin() + i, nlf_LUT.begin() + i + LUT_BLOCK_SIZE, lnr_lutab + 2 * i + LUT_BLOCK_SIZE);
        //
        // Copy entries indicated in LUT BLOCK SIZE in the current LUT BLOCK in lut cd.
        //
        std::copy(nlf_LUT.begin() + i, nlf_LUT.begin() + i + LUT_BLOCK_SIZE, lnr_lutcd + 2 * i);
        //
        // Copy the same entries indicated in LUT BLOCK SIZE to next LUT BLOCK in lut cd.
        //
        std::copy(nlf_LUT.begin() + i, nlf_LUT.begin() + i + LUT_BLOCK_SIZE, lnr_lutcd + 2 * i + LUT_BLOCK_SIZE);
    }
}

template <typename TG, typename TB, typename TGDQ, typename TBDQ, typename TLUT>
inline void uniop_init_gamma_beta_model_data(
    std::string function,
    std::string const_path,
    std::string node_name,
    UnaryOpWgtTensor<TGDQ, TBDQ, TLUT> *ptr_uniop_wgt_tensor,
    bool dequant_enable,
    uint32_t nopad_gamma_beta_dim,
    uint32_t padded_gamma_beta_dim)
{
    // get scale, zp
    std::vector<nlohmann::json> scale_zp = waic_runtime_aie4::get_scale_zp_vector(const_path, node_name);

    //
    // gamma scale value will be at index 2 if dequant is enabled else at 0 in config file.
    //
    int gamma_idx_start = (dequant_enable) ? 2 : 0;
    //
    // read gamma (scale) and beta (bias) params from B.bin and Scale.bin
    //
    std::vector<float> gamma_dq_s;
    std::vector<int64_t> gamma_dq_zp;
    if (scale_zp[2].is_array())
    {
        gamma_dq_s = std::vector<float>(scale_zp[gamma_idx_start]);
        gamma_dq_zp = std::vector<int64_t>(scale_zp[gamma_idx_start + 1]);
    }
    else
    {
        gamma_dq_s.push_back(scale_zp[gamma_idx_start]);
        gamma_dq_zp.push_back(scale_zp[gamma_idx_start + 1]);
    }
    
    std::vector<float> beta_dq_s;
    std::vector<int64_t> beta_dq_zp;

    beta_dq_s.push_back(scale_zp[gamma_idx_start + 2]);
    beta_dq_zp.push_back(scale_zp[gamma_idx_start + 3]);

    std::vector<TGDQ> gamma_dq;
    std::vector<TBDQ> beta_dq;

    gamma_dq.resize(nopad_gamma_beta_dim);
    beta_dq.resize(nopad_gamma_beta_dim);

    uint32_t gamma_size = nopad_gamma_beta_dim;
    uint32_t beta_size = nopad_gamma_beta_dim;
    waic_runtime_aie4::read_scale_bias_bins<TG, TB, TGDQ, TBDQ>(function,
                                                                const_path,
                                                                node_name,
                                                                gamma_dq_s,
                                                                gamma_dq_zp,
                                                                beta_dq_s,
                                                                beta_dq_zp,
                                                                gamma_dq,
                                                                gamma_size,
                                                                beta_dq,
                                                                beta_size);
    //
    // copy gamma (scale) and beta (bias) values to wgt tensor
    //
    TBDQ* bias_ptr = (TBDQ*)(&(ptr_uniop_wgt_tensor->Beta_at(0)));
    TGDQ* scale_ptr = (TGDQ*)(&(ptr_uniop_wgt_tensor->Gamma_at(0)));

    //
    // Check if gamma and beta bin file sizes are the same.
    // If not print error message.
    //
    if (gamma_size != beta_size)
    {
        std::cout << "Error: Gamma Beta bin file size mismatch" << std::endl;
    }

    //
    // Copy dequantized values to wgt buffer.
    //
    for (uint32_t index = 0; index < gamma_size; index++)
    {
        *scale_ptr = gamma_dq[index];
        *bias_ptr = beta_dq[index];
        scale_ptr++;
        bias_ptr++;
    }
    //
    // Set 0 to padded gamma and beta dimensions
    // 
    for (uint32_t index = gamma_size; index < padded_gamma_beta_dim; index++)
    {
        *scale_ptr = 0;
        *bias_ptr = 0;
        scale_ptr++;
        bias_ptr++;
    }
}

template <typename TGDQ, typename TBDQ, typename TLUT>
inline void uniop_init_model_data(
    std::string const_path,
    std::string node_name,
    UnaryOpWgtTensor<TGDQ, TBDQ, TLUT> *ptr_uniop_wgt_tensor,
    bool dequant_enable,
    bool quant_enable,
    int sign_A,
    int sign_O,
    std::string dq_float_str="float16",
    std::string q_float_str="float16"
) {
    // get scale, zp
    std::vector<nlohmann::json> scale_zp = waic_runtime_aie4::get_scale_zp_vector(const_path, node_name);
    
	// get scale/zp size
    uint16_t scale_zp_size = scale_zp.size();

    if(dequant_enable)
    {
        int dq_float_type = uniop_qdq_param::FloatType::FLOAT16;
        if(dq_float_str == "float32")
           dq_float_type = uniop_qdq_param::FloatType::FLOAT32;
        else if(dq_float_str == "bfloat16")
           dq_float_type = uniop_qdq_param::FloatType::BFLOAT16;

        float in_s = scale_zp[0];
        float in_zp = scale_zp[1];
        ptr_uniop_wgt_tensor->set_dq(in_zp, in_s, sign_A, dq_float_type);
    }

    if(quant_enable)
    {
        float o_s;
        float o_zp;

        int q_float_type = uniop_qdq_param::FloatType::FLOAT16;
        if(q_float_str == "float32")
           q_float_type = uniop_qdq_param::FloatType::FLOAT32;
        else if(q_float_str == "bfloat16")
           q_float_type = uniop_qdq_param::FloatType::BFLOAT16;

        //
        // out scale and zp are always last to elements in scale zp vector
        //
        o_s = scale_zp[scale_zp_size - 2];
        o_zp = scale_zp[scale_zp_size - 1];

        ptr_uniop_wgt_tensor->set_q(o_zp, 1.0f/o_s, sign_O, q_float_type);
    }
}

