using namespace std;

#include <iostream>
#include <iomanip> 
#include <vector>
#include <cstring>
#include <cassert>
#include <string>
#include <cmath>
#include "carf_dc.hpp"
#include "common.hpp"


// Element-type aliases. They are not referenced by the functions in this part of
// the file; NOTE(review): presumably used by code elsewhere — confirm before changing.
using Telem = uint16_t;    // 16-bit element type
using Telem_8b = uint8_t;  // 8-bit element type
using Welem = int32_t;     // NOTE(review): "W" presumably stands for weights — confirm


// ================== DQ ====================

// The input element type (T1) can be uint8_t or uint16_t.
template<typename T1, typename T3>
void dequant(T1* in_data, uint16_t* out_data, int h_in, int w_in, int c_in, float s, T3 z)
{
    // Dequantizes an h_in x w_in x c_in buffer (channel index varies fastest,
    // then width, then height) into bfloat16 bit patterns:
    //     out = float_to_bfloat16(bf16(in - z) * s).value
    // The centred value (in - z) is itself passed through bfloat16 before it is
    // scaled; the original code notes that this intermediate rounding matches carf.
    const float zero_point = static_cast<float>(z);

    for (int row = 0; row < h_in; ++row) {
        const int row_base = row * w_in * c_in;
        for (int col = 0; col < w_in; ++col) {
            const int pixel_base = row_base + col * c_in;
            for (int ch = 0; ch < c_in; ++ch) {
                const int idx = pixel_base + ch;
                const float centered = in_data[idx] - zero_point;
                // Round-trip through bfloat16 before scaling (matches carf).
                const float centered_bf16 = bfloat16_to_float(float_to_bfloat16(centered).value);
                out_data[idx] = float_to_bfloat16(centered_bf16 * s).value;
            }
        }
    }
}

// ================== Q ====================

// Rounds to the nearest integer, breaking exact .5 ties toward the odd neighbour
// (2.5 -> 3, 3.5 -> 3, -0.5 -> -1, -1.5 -> -1, -2.5 -> -3).
// Non-tie inputs go through std::round, and NaN / +-inf are returned unchanged.
// std::round breaks ties away from zero, but ties never reach that path.
float round_half_to_odd(float inp) {
    const float nearest = std::round(inp);
    // inp - round(inp) is computed exactly, so it equals +-0.5 only for true ties.
    // For NaN/inf the difference is NaN, the comparison is false, and the input
    // falls through unchanged.
    if (std::abs(inp - nearest) != 0.5f) {
        return nearest;
    }

    // Tie: choose whichever of floor(inp) and floor(inp) + 1 is odd.
    const float lower = std::floor(inp);
    // Parity is tested on the float itself, with no integer cast. std::fmod keeps
    // the sign of its first argument (fmod(-3, 2) == -1), so the test must be
    // "!= 0" rather than "== 1"; the latter misclassifies negative odd values.
    const bool lower_is_odd = std::fmod(lower, 2.0f) != 0.0f;
    return lower_is_odd ? lower : lower + 1.0f;
}

// Rounds to the nearest integer; an exact .5 tie goes to the even neighbour
// (2.5 -> 2, 3.5 -> 4, -2.5 -> -2, -3.5 -> -4). Everything else uses std::round.
float round_half_to_even(float value) {
    float whole = 0.0f;
    const float fraction = std::modf(value, &whole);  // whole is truncated toward zero

    if (std::fabs(fraction) != 0.5f) {
        return std::round(value);
    }

    // A tie is only possible for |value| < 2^23, so `whole` fits in an int here.
    if (static_cast<int>(whole) % 2 == 0) {
        return whole;
    }
    // `whole` is odd: move one step further from zero to reach the even neighbour.
    return (value > 0) ? whole + 1.0f : whole - 1.0f;
}

// Quantizes one float to an unsigned 16-bit code:
//     clamp(round(inp * inv_scale + zero_pt), 0, 65535)
// Ties are broken toward the parity of the zero point: round-half-to-even when
// static_cast<int32_t>(zero_pt) is even, round-half-to-odd otherwise.
//
// @param inp        value to quantize
// @param inv_scale  reciprocal of the quantization scale (multiplied in)
// @param zero_pt    zero point added after scaling
// @return the uint16 code in [0, 65535] carried in an int16_t; codes above 32767
//         come back negative. Read the result through uint16_t to recover the code.
//         The int16_t return type is kept for compatibility with existing callers.
int16_t quant_linear_uint16_cstm(float inp, float inv_scale, float zero_pt) {

    const float scaled_in_plus_zp = inp * inv_scale + zero_pt;
    float res;
    if (static_cast<int32_t>(zero_pt) % 2 == 0) {
        // Even zero point: ties round to even.
        res = round_half_to_even(scaled_in_plus_zp);
    } else {
        // Odd zero point: ties round to odd.
        res = round_half_to_odd(scaled_in_plus_zp);
    }
    // Clamp to the uint16 range. A NaN ends up as 0 here: std::min(NaN, 65535)
    // yields NaN and std::max(0, NaN) then yields 0.
    res = std::max(0.0f, std::min(res, 65535.0f));
    // Converting a float outside int16_t's range straight to int16_t is undefined
    // behaviour. Go through uint16_t instead (always in range after the clamp),
    // then do a modular integral conversion to int16_t.
    return static_cast<int16_t>(static_cast<uint16_t>(res));
}
int16_t quant_linear_uint8_cstm(float inp, float inv_scale, float zero_pt) {

    float scaled_in_plus_zp = inp * inv_scale + zero_pt;
    float res;
    if (static_cast<int32_t>(zero_pt) % 2 == 0) {
        res = std::nearbyint(scaled_in_plus_zp);
    } else {
        res = round_half_to_odd(scaled_in_plus_zp);
    }
    res = std::max(0.0f, std::min(res, 255.0f)); //for int8
    return static_cast<int16_t>(res);
}

template<typename T3>
void quant_bfloat16_to_uint16(uint16_t* in_data, uint16_t* out_data, int h_in, int w_in, int c_in, float inv_s, T3 z)
{
    // Quantizes an h_in x w_in x c_in buffer of bfloat16 bit patterns into uint16
    // codes, element by element, with quant_linear_uint16_cstm. The buffer is laid
    // out with the channel index varying fastest, so the flat position simply counts
    // up. The int16_t returned by the quantizer is stored back as uint16_t.
    const float zero_point = static_cast<float>(z);
    int flat = 0;
    for (int row = 0; row < h_in; ++row) {
        for (int col = 0; col < w_in; ++col) {
            for (int ch = 0; ch < c_in; ++ch, ++flat) {
                const float as_float = bfloat16_to_float(in_data[flat]);
                out_data[flat] = quant_linear_uint16_cstm(as_float, inv_s, zero_point);
            }
        }
    }
}
template<typename T3>
void quant_bfloat16_to_uint8(uint16_t* in_data, uint16_t* out_data, int h_in, int w_in, int c_in, float inv_s, T3 z)
{
    // Quantizes an h_in x w_in x c_in buffer of bfloat16 bit patterns into 8-bit
    // codes (stored one per uint16_t slot) with quant_linear_uint8_cstm. The buffer
    // is laid out with the channel index varying fastest, so the flat position
    // simply counts up.
    const float zero_point = static_cast<float>(z);
    int flat = 0;
    for (int row = 0; row < h_in; ++row) {
        for (int col = 0; col < w_in; ++col) {
            for (int ch = 0; ch < c_in; ++ch, ++flat) {
                const float as_float = bfloat16_to_float(in_data[flat]);
                out_data[flat] = quant_linear_uint8_cstm(as_float, inv_s, zero_point);
            }
        }
    }
}
