#include <cstdint>
#include <cstring>
#include <cmath>
#include <limits>
#include <algorithm>

/// Derives affine quantization parameters that map the real interval
/// [rmin, rmax] onto the representable range of the integer type T.
///
/// scale is (rmax - rmin) / (qmax - qmin), where qmin/qmax are
/// std::numeric_limits<T>::min()/max(). A scale of exactly zero (which happens
/// when rmin == rmax) is replaced by 1.0f so that later divisions by scale are
/// defined. zero_point is round(qmin - rmin / scale), saturated to [qmin, qmax].
///
/// @tparam T            Quantized storage type. Its numeric limits are held in
///                      int32_t here, so they are expected to fit in that type.
/// @param rmin          Lower bound of the real-valued range.
/// @param rmax          Upper bound of the real-valued range.
/// @param[out] scale    Receives the real-valued step per quantized unit.
/// @param[out] zero_point Receives the quantized value assigned to real 0,
///                      always within [qmin, qmax].
template <typename T>
void compute_scale_and_zp(float rmin, float rmax, float &scale, int32_t &zero_point) {
    const int32_t qmin = std::numeric_limits<T>::min();
    const int32_t qmax = std::numeric_limits<T>::max();

    // Widen before subtracting: qmax - qmin does not fit in int32_t when T is
    // a 32-bit signed type, and signed overflow is undefined behaviour.
    const int64_t qspan = static_cast<int64_t>(qmax) - static_cast<int64_t>(qmin);
    scale = (rmax - rmin) / static_cast<float>(qspan);

    if (scale == 0.0f) {
        scale = 1.0f; // prevent divide by zero
    }

    const float zp_f = std::round(static_cast<float>(qmin) - rmin / scale);

    // Saturate while still in floating point. Converting a float that lies
    // outside the int32_t range (or is NaN/inf) to an integer is undefined
    // behaviour, so the cast only happens once the value is known to be
    // strictly inside (qmin, qmax). The negated comparison also routes NaN to
    // qmin.
    if (!(zp_f > static_cast<float>(qmin))) {
        zero_point = qmin;
    } else if (zp_f >= static_cast<float>(qmax)) {
        zero_point = qmax;
    } else {
        zero_point = static_cast<int32_t>(zp_f);
    }
}

/// Quantizes a real value to type T using affine parameters.
///
/// Computes round(x / scale) + zero_point and saturates the result to
/// [numeric_limits<T>::min(), numeric_limits<T>::max()].
///
/// @tparam T          Quantized storage type; its limits are held in int here.
/// @param x           Real value to quantize.
/// @param scale       Real-valued step per quantized unit.
/// @param zero_point  Quantized value that represents real 0.
/// @return The saturated quantized value. A NaN intermediate (e.g. x is NaN or
///         0 / 0) yields numeric_limits<T>::min().
template <typename T>
T quantize(float x, float scale, int zero_point) {
    const int qmin = std::numeric_limits<T>::min();
    const int qmax = std::numeric_limits<T>::max();

    const float shifted = std::round(x / scale) + zero_point;

    // Saturate before converting to an integer: a float outside the range of
    // int (large |x|, tiny or zero scale, inf, NaN) makes the conversion
    // undefined behaviour. The negated comparison also catches NaN.
    if (!(shifted > static_cast<float>(qmin))) {
        return static_cast<T>(qmin);
    }
    if (shifted >= static_cast<float>(qmax)) {
        return static_cast<T>(qmax);
    }
    return static_cast<T>(static_cast<int>(shifted));
}

/// Maps a quantized value back to a real value.
///
/// @tparam T          Quantized storage type.
/// @param q           Quantized value.
/// @param scale       Real-valued step per quantized unit.
/// @param zero_point  Quantized value that represents real 0.
/// @return scale * (q - zero_point), with the subtraction done in int.
template <typename T>
float dequantize(T q, float scale, int zero_point) {
    // Remove the zero-point offset in integer arithmetic first.
    const int centered = static_cast<int>(q) - zero_point;
    return scale * centered;
}

/// Returns the most significant bits of a float's 32-bit object representation.
///
/// @param x     Value whose bit pattern is read.
/// @param is16  When true the upper 16 bits are returned; otherwise only the
///              upper 8 bits are returned (in the low byte of the result).
/// @return The selected high bits, truncated (no rounding is applied).
inline uint16_t float_to_uint_bits(float x, bool is16) {
    uint32_t raw = 0;
    std::memcpy(&raw, &x, sizeof(float));

    // Discard the low 16 or low 24 bits depending on the requested width.
    const unsigned discard = is16 ? 16u : 24u;
    return static_cast<uint16_t>(raw >> discard);
}

/// Rebuilds a float from its most significant bits; all lower bits are zero.
///
/// @param bits  High bits of the float's 32-bit representation.
/// @param is16  When true, all 16 bits of `bits` become the upper half of the
///              word. Otherwise only the low 8 bits of `bits` are used and
///              become the top byte of the word.
/// @return The float whose object representation is the assembled word.
inline float uint_bits_to_float(uint16_t bits, bool is16) {
    const uint32_t word = is16
        ? (static_cast<uint32_t>(bits) << 16)
        : ((static_cast<uint32_t>(bits) & 0xFFu) << 24);

    float value;
    std::memcpy(&value, &word, sizeof(float));
    return value;
}

/// Converts a float to a 16-bit value holding the upper half of its 32-bit
/// representation, rounded to nearest with ties going to the even result.
///
/// NaN inputs are returned as a quiet NaN with the input's sign. They must not
/// go through the rounding add: the carry can clear the mantissa (turning the
/// NaN into infinity) or overflow the exponent (turning it into a zero).
///
/// @param x  Value to convert.
/// @return The rounded upper 16 bits of x's representation.
uint16_t float_to_bfloat16(float x)
{
    static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits wide");

    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));

    // NaN: exponent all ones and a non-zero mantissa. Force the top mantissa
    // bit so the truncated result is still a NaN.
    if ((bits & 0x7fffffffu) > 0x7f800000u) {
        return static_cast<uint16_t>((bits >> 16) | 0x0040u);
    }

    // Round to nearest even: add 0x7fff, plus one more when the bit that will
    // become the result's LSB is already set, so exact ties round to even.
    const uint32_t lsb = (bits >> 16) & 0x1u;
    bits += 0x7fffu + lsb;

    // Extract the upper half.
    return static_cast<uint16_t>(bits >> 16);
}

/// Reinterprets the four bytes of a 32-bit unsigned integer as a float.
///
/// @param i  Bit pattern to reinterpret.
/// @return The float whose object representation equals the bytes of i.
inline float uint_to_float(uint32_t i)
{
    float result = 0;
    // Byte-wise copy of the object representation; no numeric conversion.
    std::memcpy(&result, &i, sizeof(uint32_t));
    return result;
}

/// Expands a 16-bit value into a float by placing it in the upper half of a
/// 32-bit word (lower half zero) and reinterpreting that word via uint_to_float.
///
/// @param bf  Upper 16 bits of the desired float representation.
/// @return The float produced by uint_to_float for the widened word.
inline float bfloat16_to_float(uint16_t bf)
{
    const uint32_t widened = static_cast<uint32_t>(bf) << 16;
    return uint_to_float(widened);
}