#ifndef QDQ_INT8_BFLOAT16_HPP
#define QDQ_INT8_BFLOAT16_HPP

// Dequantize one vector of 8-bit lanes to bfloat16:
//     out = bfloat16(widen(vec) - zero_point) * scale
//
// vec        : packed 8-bit input lanes (carried as uint8; `sign` is passed
//              to unpack_sign() to select how they are widened).
// zero_point : quantization zero point, subtracted after widening to int16.
// scale      : multiplier applied to every lane.
// sign       : forwarded unchanged to unpack_sign().
//
// `convert` is not defined in this header; it must be provided by the
// including translation unit.
//
// NOTE(review): zero_point is declared uint8, so int16(zero_point) presumably
// lands in 0..255 even when sign is true. If callers pass a negative zero
// point as a two's-complement byte it would not be sign-extended -- confirm
// against the caller.
template<unsigned Elems>
aie::vector<bfloat16, Elems> __attribute__((always_inline)) dequant_vector(
    aie::vector<uint8, Elems> vec,
    uint8 zero_point, bfloat16 scale, bool sign)
{
    // Widen the zero point once so the subtraction happens on 16-bit lanes.
    const int16 zp16 = int16(zero_point);

    // 8-bit lanes -> 16-bit lanes, viewed as int16.
    aie::vector<int16, Elems> widened =
        vec.unpack_sign(sign).template cast_to<int16>();

    // Remove the zero-point offset.
    aie::vector<int16, Elems> centered = aie::sub(widened, zp16);

    // Integer -> bfloat16, then scale through a float accumulator.
    aie::vector<bfloat16, Elems> as_bf16 = convert<bfloat16>(centered);
    aie::accum<accfloat, Elems> scaled = aie::mul(as_bf16, scale);

    return scaled.template to_vector<bfloat16>();
}

// Quantize one vector of bfloat16 lanes to 8-bit:
//     out = narrow8(int16(vec * inv_scale) + zero_point)
//
// vec        : bfloat16 input lanes.
// zero_point : quantization zero point, added after conversion to int16.
// inv_scale  : multiplier applied to every lane before conversion (named as
//              the reciprocal of the dequantization scale).
// sign       : forwarded unchanged to to_vector_sign<uint8>().
//
// `convert` is not defined in this header; the float -> int16 rounding
// behavior is whatever that helper implements.
//
// NOTE(review): clamping of out-of-range values in the final narrowing
// presumably depends on the saturation mode; the drivers in this file call
// set_sat() before invoking this helper -- confirm for any other caller.
template<unsigned Elems>
aie::vector<uint8, Elems> __attribute__((always_inline)) quant_vector(
    aie::vector<bfloat16, Elems> vec,
    uint8 zero_point, bfloat16 inv_scale, bool sign)
{
    // Widen the zero point once so the addition happens on 16-bit lanes.
    const int16 zp16 = int16(zero_point);

    // Scale through a float accumulator and bring the product back to bfloat16.
    aie::accum<accfloat, Elems> product = aie::mul(vec, inv_scale);
    aie::vector<bfloat16, Elems> scaled = product.template to_vector<bfloat16>();

    // bfloat16 -> int16, then apply the zero-point offset.
    aie::vector<int16, Elems> as_int16 = convert<int16>(scaled);

    // Load the shifted lanes into a 32-bit accumulator so they can be
    // narrowed to 8 bits with the requested signedness.
    aie::accum<acc32, Elems> shifted(aie::add(as_int16, zp16));

    return shifted.template to_vector_sign<uint8>(sign);
}

// Dequantize a buffer of 8-bit values into bfloat16, 64 lanes per iteration.
//
// q_in       : input buffer, read as uint8 lanes.
// dq_out     : output buffer; declared int8_t* but written as bfloat16, so it
//              receives two bytes per processed element.
// num_elems  : element count; only num_elems / 64 full vectors are processed,
//              any remainder is left untouched.
// zero_point : zero point forwarded to dequant_vector().
// scale      : scale forwarded to dequant_vector().
// sign       : forwarded to dequant_vector().
// dq_enable  : when false the function returns without doing anything
//              (saturation mode is not touched either).
//
// NOTE(review): chess_loop_range(2,) presumably promises the compiler at
// least two iterations, i.e. num_elems >= 128 -- confirm callers honor this.
// NOTE(review): both buffers are presumably required to be aligned for
// 64-lane vector access -- confirm.
void __attribute__((noinline)) dequant_int8_to_bf16(
    int8_t* q_in, int8_t* dq_out, int num_elems,
    uint8 zero_point, bfloat16 scale, bool sign = true, bool dq_enable = true)
{
    if (!dq_enable) {
        return;
    }

    // Lanes handled per loop iteration (kept signed so the division below
    // stays a signed int division).
    constexpr int kLanes = 64;

    set_sat();
    int full_vectors = chess_copy(num_elems / kLanes);

    auto src = aie::cbegin_vector<kLanes>(reinterpret_cast<uint8*>(q_in));
    auto dst = aie::begin_vector<kLanes>(reinterpret_cast<bfloat16*>(dq_out));

    for (int iter = 0; iter < full_vectors; ++iter) chess_loop_range(2,) {
        *dst++ = dequant_vector<kLanes>(*src++, zero_point, scale, sign);
    }
}

// Quantize a buffer of bfloat16 values into 8-bit, 64 lanes per iteration.
//
// dq_in      : input buffer; declared int8_t* but read as bfloat16, so two
//              bytes are consumed per processed element.
// q_out      : output buffer, written as uint8 lanes.
// num_elems  : element count; only num_elems / 64 full vectors are processed,
//              any remainder is left untouched.
// zero_point : zero point forwarded to quant_vector().
// inv_scale  : multiplier forwarded to quant_vector().
// sign       : forwarded to quant_vector().
// q_enable   : when false the function returns without doing anything
//              (saturation mode is not touched either).
//
// NOTE(review): chess_loop_range(2,) presumably promises the compiler at
// least two iterations, i.e. num_elems >= 128 -- confirm callers honor this.
// NOTE(review): both buffers are presumably required to be aligned for
// 64-lane vector access -- confirm.
void __attribute__((noinline)) quant_bf16_to_int8(
    int8_t* dq_in, int8_t* q_out, int num_elems,
    uint8 zero_point, bfloat16 inv_scale, bool sign = true, bool q_enable = true)
{
    if (!q_enable) {
        return;
    }

    // Lanes handled per loop iteration (kept signed so the division below
    // stays a signed int division).
    constexpr int kLanes = 64;

    set_sat();
    int full_vectors = chess_copy(num_elems / kLanes);

    auto src = aie::cbegin_vector<kLanes>(reinterpret_cast<bfloat16*>(dq_in));
    auto dst = aie::begin_vector<kLanes>(reinterpret_cast<uint8*>(q_out));

    for (int iter = 0; iter < full_vectors; ++iter) chess_loop_range(2,) {
        *dst++ = quant_vector<kLanes>(*src++, zero_point, inv_scale, sign);
    }
}

#endif // QDQ_INT8_BFLOAT16_HPP

