#ifndef CARF_DC_HPP
#define CARF_DC_HPP

#include <iostream>
#include <cstdint>
#include <cstring>
// class Converter {
// public:
    
// Re-implementation of dc.f2bf, intended only for the bfloat (softmax) path with rounding = False.
// Original signature: f2bf(self, f, bits=16, expo_bits=8, rounding=False, filter_denorm=True)
// TODO: what about the rounding = true cases?

/// @brief Shift a 32-bit integer right (positive @p shift) or left (negative
///        @p shift), optionally adding half of the dropped range before a
///        right shift.
///
/// @param data     Value to shift.
/// @param shift    Number of bit positions: > 0 shifts right, < 0 shifts left,
///                 0 returns @p data unchanged.
/// @param rounding When true and @p shift > 0, adds 2^(shift-1) before
///                 shifting right; when false the low bits are simply dropped.
/// @param bfloat   Currently unused; kept so existing callers keep compiling.
/// @return The shifted value, or 0 when |shift| is at least the width of
///         int32_t (every bit would be shifted out).
///
/// Defined `inline` because this lives in a header and would otherwise produce
/// duplicate-symbol errors when included from several translation units.
inline int32_t shift_round(int32_t data, int shift, bool rounding, bool bfloat) {
    (void)bfloat;

    constexpr int kBits = static_cast<int>(sizeof(int32_t) * 8);

    // Shifting a 32-bit operand by >= 32 positions is undefined behaviour, so
    // every out-of-range amount (in either direction) yields 0.
    if (shift >= kBits || shift <= -kBits) {
        return 0;
    }

    if (shift > 0) {
        // Widen before adding the rounding increment: data + 2^(shift-1) may
        // not fit in int32_t, but the shifted result always does.
        int64_t wide = data;
        if (rounding) {
            wide += int64_t{1} << (shift - 1);
        }
        return static_cast<int32_t>(wide >> shift);
    }

    if (shift < 0) {
        // Shift the unsigned bit pattern: left-shifting a negative signed
        // value is undefined before C++20.
        return static_cast<int32_t>(static_cast<uint32_t>(data) << -shift);
    }

    return data;
}
/// @brief Reduce a 32-bit float to a narrower format by keeping only the top
///        `bits + 8 - expo_bits` bits of its bit pattern and clearing the rest.
///
/// With the defaults (bits = 16, expo_bits = 8) the upper 16 bits are kept.
///
/// @param f             Input value.
/// @param bits          Total bit width of the target format.
/// @param expo_bits     Exponent width of the target format; together with
///                      @p bits it determines how many leading bits of the
///                      32-bit pattern are kept.
/// @param rounding      When true, half of the dropped range is added to the
///                      bit pattern before the low bits are cleared; when
///                      false the low bits are simply cleared.
/// @param filter_denorm Currently ignored: values whose exponent field is zero
///                      are always flushed to a signed zero.
///                      NOTE(review): the reference signature quoted in this
///                      file defaults this to True — confirm the intended
///                      default before wiring the flag up.
/// @return The reduced value, still stored as a 32-bit float.
///
/// Values whose exponent field is 255 skip the bit reduction. When
/// @p expo_bits < 8, values with a zero exponent field also skip it (they are
/// then flushed to signed zero by the final step).
///
/// All bit manipulation is done on an unsigned pattern, which is equivalent
/// (modulo 2^32) to `shift_round(bits, drop, rounding, true) << drop` but
/// avoids left-shifting negative signed integers. Defined `inline` because it
/// lives in a header.
inline float dc_f2bf(float f, int bits = 16, int expo_bits = 8, bool rounding = false, bool filter_denorm = false) {
    (void)filter_denorm;

    static_assert(sizeof(float) == sizeof(uint32_t), "dc_f2bf requires a 32-bit float");

    constexpr int kExpoBias = 127;
    constexpr uint32_t kSignMask = 0x80000000u;
    constexpr uint32_t kExpoMask = 0x7F800000u;

    uint32_t raw;
    std::memcpy(&raw, &f, sizeof(raw));

    const int bits_fp = bits + 8 - expo_bits;
    if (bits_fp < 32) {
        const int expo = static_cast<int>((raw >> 23) & 0xFFu);

        if ((expo < 255) && ((expo_bits >= 8) || (expo > 127 - kExpoBias))) {
            const int drop = 32 - bits_fp;  // number of low bits to discard (>= 1)
            if (drop >= 32) {
                // No bits are kept at all; shifting by >= 32 would be undefined.
                raw = 0;
            } else {
                if (rounding) {
                    raw += uint32_t{1} << (drop - 1);
                }
                raw &= ~((uint32_t{1} << drop) - 1u);
            }
        }
    }

    // Zero exponent field: keep only the sign bit (flush to +/-0).
    if ((raw & kExpoMask) == 0) {
        raw &= kSignMask;
    }

    float result;
    std::memcpy(&result, &raw, sizeof(result));
    return result;
}

// private:
// int32_t shift_round(int32_t data, int shift, bool rounding, bool bfloat) {
//     // if (rounding) {
//     //     if (bfloat) {
//     //         // Implement custom half_inf rounding here if applicable
//     //     }
//     // } else {
//     //     // Default sym_floor behavior, no rounding in this example
//     // }

//     int bits = sizeof(int32_t) * 8;
    
//     if (shift >= 63 || shift <= -bits) {
//         return 0;
//     }
    
//     int rshift = std::max(0, shift);
//     int lshift = std::max(0, -shift);

//     if (rounding && rshift > 0) {
//         data += (1 << (rshift - 1));
//     }
    
//     if (rshift != 0) {
//         data >>= rshift;
//     }

//     if (lshift != 0) {
//         data <<= lshift;
//     }

//     return data;
// }
// };

#endif // CARF_DC_HPP