#ifndef __PWL_NLF_BF16_TEMPLATE_H__
#define __PWL_NLF_BF16_TEMPLATE_H__
#include <adf.h>
#include <aie_api/aie.hpp>
#include <aie_api/aie_adf.hpp>
#include <aie_api/utils.hpp>
#include "qdq/qdq_kernel_helpers.h"
// Shift constant for the NLF lookup table.
// NOTE(review): not referenced by the code visible in this header - confirm
// which translation units still rely on it.
#define NLF_LUT_SHIFT 5

// Table-size constant (2 * 160 = 320); pwla_nlf_bf16 derives its index bias
// from lut_sz/2.
constexpr int lut_sz = 2 * 160;

/**
 * Convert a vector of floating-point lanes to a vector of integer lanes.
 *
 * The input is loaded into a float accumulator, a constant whose lanes carry
 * the bit pattern 0x4b01 (viewed as bfloat16) is added, both the sum and the
 * constant are reinterpreted as acc32 and subtracted, and the difference is
 * extracted as a vector of To.
 * NOTE(review): this looks like the usual "add a large constant" float-to-int
 * trick, in which case the rounding is decided by the float addition - confirm.
 *
 * The tile saturation mode is switched to saturate for the duration of the
 * call and the previously active mode is put back before returning.
 *
 * @tparam To  integral element type of the result
 * @tparam Ti  non-integral element type of the input
 * @tparam V   lane count of input and result; the accumulators use
 *             max(V, 32) lanes
 * @param  in  values to convert
 * @return     vector of V converted values
 */
template<typename To, typename Ti, unsigned V>
requires( !std::is_integral_v<Ti> && std::is_integral_v<To> )
inline aie::vector<To,V> convert_pwl( aie::vector<Ti,V> in)
{
    constexpr unsigned Vop = std::max( V, 32u );
    using fp_acc_t  = aie::accum<accfloat, Vop>;
    using int_acc_t = aie::accum<acc32, Vop>;

    // Remember the caller's saturation mode so it can be reinstated on exit.
    const aie::saturation_mode saved_mode = aie::tile::current().get_saturation();
    aie::tile::current().set_saturation(aie::saturation_mode::saturate);

    // Constant built from the raw int16 pattern 0x4b01 reinterpreted as bfloat16.
    const fp_acc_t magic( aie::broadcast<int16, V>(0x4b01).template cast_to<bfloat16>( ));

    // Float domain: input + constant.
    fp_acc_t biased(in);
    biased = biased + magic;

    // Integer domain: subtract the constant's bit pattern from the sum's bit pattern.
    int_acc_t delta = aie::sub(biased.template cast_to<acc32>( ), magic.template cast_to<acc32>( ));

    aie::vector<To, V> result = delta.template to_vector<To>( );

    aie::tile::current().set_saturation(saved_mode);
    return result;
}
//int subv = 0;
//int col_id = get_coreid() >> 16;
//int row_id = get_coreid() & 0xf;
template<typename in_el_type = int16, typename out_el_type = int16>
__attribute__((noinline)) void pwla_nlf_bf16(
                                        int8_t* matIn
                                        ,int8_t* matOut
                                        ,int* lnr_lutab
                                        ,int* lnr_lutcd
                                        ,int num_elem
        ){
    aie::tile::current().set_saturation(aie::saturation_mode::saturate);
    aie::set_rounding(aie::rounding_mode::symmetric_zero);


    v32bfloat16 * restrict pIn16  = ( v32bfloat16 * ) matIn;
    v32bfloat16 * restrict pIn16_1  = ( v32bfloat16 * ) matIn;

    auto outIter = aie::begin_vector<32>(( bfloat16 * ) matOut);

    const bfloat16  shift_idx = 256.0;    //    (1 << shift_addr);   // TODO: convert to bf16 properly
    const v32bfloat16 shift_index = broadcast_to_v32bfloat16(shift_idx);

    const bfloat16  bias_s    = float(lut_sz/2) * 8.0;   // (lut_sz/2 << shift_bias);   // TODO: convert to bf16 properly
    v32accfloat bias    = ups_to_v32accfloat( broadcast_to_v32bfloat16( bias_s ));

    const bfloat16 idx_max_s =  4.96875;     //4.98;      //lut_sz*8.0-1.0;
    const bfloat16 idx_min_s =  -5.0;
    const v32bfloat16 idx_max  = broadcast_to_v32bfloat16(idx_max_s);
    const v32bfloat16 idx_min  = broadcast_to_v32bfloat16(idx_min_s);

    for ( int j=0; j<chess_copy( num_elem/32 ); j++ )
        //chess_prepare_for_pipelining
        chess_loop_range(1, )
    {
        v16int32 index1, index2;
        v32int16 coeff00, coeff10, slope_offset0, remainder;
        v32int16 coeff01, coeff11, slope_offset1;
        v32bfloat16 slope, offset;
        v32accfloat acc;
        v32bfloat16 inp_bf16;

        inp_bf16 = v32bfloat16(*pIn16++);
        inp_bf16 = max(inp_bf16, idx_min);
        inp_bf16 = min(inp_bf16, idx_max);
        //if(col_id == 0 && row_id == 4 && subv == 10){
        //    printf("\n Contents of Xbuff buffer \n");
        //    for(int i = 0; i < 32; i++){
        //        aie::detail::print_elem<bfloat16, 32>(inp_bf16, i);
        //    }
        //}

#if 0
        aie::vector<bfloat16, 32> index_fl = to_v32bfloat16(mac_elem_32( inp_bf16, shift_index, bias));
        // aie::vector<int16, 32> index_int16 = convert<int16>(index_fl);
        v32int16 index_int16 = aie::to_fixed<int16>(index_fl);
#else
        aie::vector<float, 32> index_fl = v32float(mac_elem_32( inp_bf16, shift_index, bias));
        v32int32 index_int = (convert_pwl<int32>(index_fl)).to_native();
#endif

        index1 = extract_v16int32(index_int,0);
        index2 = extract_v16int32(index_int,1);

        load_lut_2x_int16( (int*)lnr_lutab, (int*)lnr_lutcd, index1, coeff00, coeff10 );
        load_lut_2x_int16( (int*)lnr_lutab, (int*)lnr_lutcd, index2, coeff01, coeff11 );
        slope_offset0 = shuffle( coeff00, coeff10, T16_16x4_lo );
        slope_offset1 = shuffle( coeff01, coeff11, T16_16x4_lo );

        slope = v32bfloat16(shuffle( slope_offset0, slope_offset1, T256_2x2_hi));
        offset = v32bfloat16(shuffle( slope_offset0, slope_offset1, T256_2x2_lo));

        v32accfloat acc1 = mac_elem_32( slope,  *pIn16_1++ /*inp_bf16*/, ups_to_v32accfloat(offset));
        *outIter++ = to_v32bfloat16(acc1);
    }
    //subv++;
}
#endif
