#ifndef __PWL_NLF_TEMPLATE_H__
#define __PWL_NLF_TEMPLATE_H__
#include <adf.h>
#include <aie_api/aie.hpp>
#include <aie_api/aie_adf.hpp>
#include <aie_api/utils.hpp>
#define NLF_LUT_SHIFT 5

// LUT size parameter. pwla_nlf() derives its index bias (lut_sz/2) and its
// (currently unused) upper index bound (lut_sz*8-1) from it.
constexpr int lut_sz = 160*2;
// Fractional bits folded into pwla_nlf()'s output shift
// (shift_out = lut_frac_bits + input pre-shift). Presumably the fixed-point
// format of the LUT coefficients -- confirm against the table generator.
constexpr int lut_frac_bits = 14;

// Number of elements
//static int const CORE_IN_SIZE = 4096;
//static int const CORE_OUT_SIZE = CORE_IN_SIZE;

// Fractional bits of the fixed-point input, before any int8 pre-shift
// applied inside pwla_nlf().
static constexpr int frac_bits = 10;

// Element count (ints) of scratch_pad. pwla_nlf() also computes its scratch
// slot count as scratch_size / 32 - 1 with 32-byte slots.
// NOTE(review): that expression treats scratch_size like a byte count while
// scratch_pad is sized in ints -- confirm the intended units.
constexpr int scratch_size = 8*8;

// 32-byte-aligned scratch buffer. Declared `inline` (C++17) so that including
// this header from several translation units yields one shared definition
// instead of a multiple-definition link error.
alignas(32) inline int scratch_pad[scratch_size];

//inline v32acc64 lups_mul( v16int32 in, int shift ) {
//  #ifdef __ndl__ //WA for CRVO-4013
//    cint16 coeff = as_cint16( 0xFFFF & ( -1<<shift ));
//  #else
//    cint16 coeff;
//    coeff.real = -1<<shift;
//    coeff.imag = 0;
//  #endif
//    return ( v32acc64 )negmul_elem_8(( v8cint32 )in, broadcast_c16( coeff ));
//}
//inline v32acc64 lups_mac( v16int32 in, int shift, v32acc64 acc ) {
//  #ifdef __ndl__ //WA for CRVO-4013
//    cint16 coeff = as_cint16( 0xFFFF & ( -1<<shift ));
//  #else
//    cint16 coeff;
//    coeff.real = -1<<shift;
//    coeff.imag = 0;
//  #endif
//    return ( v32acc64 )msc_elem_8(( v8cint32 )in, broadcast_c16( coeff ), ( v16cacc64 )acc );
//}

// pwla_nlf: evaluates a LUT-based piecewise-linear approximation of a
// non-linear function over num_elem fixed-point elements, 16 per iteration.
// For each element a slope/offset pair is fetched with load_lut_2x_int16 and
// the result (offset << shift_offset) + slope * remainder is rounded and
// saturated back down by shift_out.
//
// Template parameters (only sizeof() of each is inspected):
//   in_el_type  - sizeof()==1 selects int8 input (matIn read as v16int8 and
//                 pre-shifted left by NLF_LUT_SHIFT); otherwise matIn is read
//                 as v16int16 with no pre-shift.
//   out_el_type - sizeof()==1 selects int8 output (v16int8 stores); otherwise
//                 v16int16 values are stored to matOut.
//
// Parameters:
//   matIn     - input buffer, reinterpreted as int8 or int16 vectors.
//   matOut    - output buffer, reinterpreted likewise.
//   lnr_lutab - first table passed to load_lut_2x_int16.
//   lnr_lutcd - second table passed to load_lut_2x_int16.
//               NOTE(review): presumably two banks/copies of the same
//               slope/offset table -- confirm against the table generator.
//   nlf_mode  - 1 = gelu, 2 = sigmoid (per the comment below). It only
//               selects the remainder mask: mode 1 keeps all 16 bits of the
//               scaled input; any other mode keeps its low rem_bits bits.
//   num_elem  - element count. The loop runs num_elem/16 times, so a
//               num_elem % 16 tail is not processed.
//
// NOTE(review): chess_loop_range(2, ) tells the compiler the loop runs at
// least twice, i.e. presumably callers must pass num_elem >= 32.
// NOTE(review): scratch storage is taken from the fixed local address 0x2000,
// not from the file-level scratch_pad array -- confirm that region is
// reserved for this kernel.
template<typename in_el_type = int16, typename out_el_type = int16>
__attribute__((noinline)) void pwla_nlf(
                                        int8_t* matIn
                                        ,int8_t* matOut
                                        ,int* lnr_lutab
                                        ,int* lnr_lutcd
                                        ,int nlf_mode
                                        ,int num_elem
        ){

    //uint64_t start, end;
    //start = get_cycles();
    // Disabled reference implementation using aie::linear_approx.
    // NOTE(review): it references names not defined in this file
    // (NLF_LUT_GELU_SHIFT, index, out, CORE_IN_SIZE), so it would not compile
    // if re-enabled as-is.
#if 0
    const aie::lut<4, int16> my_lut(lut_sz,lnr_lutab,lnr_lutcd);
 
    //calling linear_approx with my_lut, step_bits=3, bias=0, shift_offset=0
    aie::linear_approx<int16, aie::lut<4, int16, int16>> linear_ap(my_lut, frac_bits-NLF_LUT_GELU_SHIFT, (lut_sz/2),frac_bits);
    auto it=aie::begin_vector<16>(index);
    auto ot=aie::begin_vector<16>(out);
    for(int i=0;i<CORE_IN_SIZE/16;i++){ 
       aie::vector<int16,16> vin=*it++;
       *ot++ = linear_ap.compute(vin).to_vector<int16>(0);
    }
#else

    // Scratch area used to round-trip the scaled input through memory (see
    // pDin/pDout below). Hard-coded local address.
    int * restrict scratch = (int*)conv_to_local_ptr(0x2000);
   
    // Input/output views; only the pair matching the template element sizes
    // is used.
    v16int16 * restrict pIn16  = ( v16int16 * ) matIn;
    v16int16 * restrict pOut16 = ( v16int16 * ) matOut;
    v16int8 * restrict pIn8    = ( v16int8 * ) matIn;
    v16int8 * restrict pOut8   = ( v16int8 * ) matOut;

    // Write (pDout) and read-back (pDin) cursors into the scratch area. They
    // start at the same address and advance with the same pattern, so each
    // iteration reads back exactly what it wrote.
    v16int16 * pDin  = ( v16int16 * ) scratch;
    v16int16 * pDout = ( v16int16 * ) scratch;

    // NOTE(review): rnd is not used by the active code; it only appears in
    // the commented-out expressions below.
    #ifdef RND_FLOOR
        int rnd = 0;
    #else
        int rnd = 1;
    #endif

    // With int16 input there is no pre-shift, so shift_addr/rem_bits below
    // need frac_bits >= NLF_LUT_SHIFT to stay non-negative.
    static_assert(frac_bits >= NLF_LUT_SHIFT || sizeof(in_el_type)==1);
    // int8 input is shifted left by NLF_LUT_SHIFT to gain precision; int16
    // input is used unshifted.
    int8_t shift_in = 0;
    if (chess_manifest(sizeof(in_el_type)==1)) {
        shift_in     = NLF_LUT_SHIFT;
    }

    // Fractional bits of the (possibly pre-shifted) input.
    int in_frac_bits = frac_bits + shift_in;
    // Final right shift. Assuming LUT coefficients carry lut_frac_bits
    // fractional bits, this removes them plus the input pre-shift.
    int8_t shift_out    = lut_frac_bits + shift_in;    //params.shift_res;
    // Right shift that turns the biased, scaled input into the LUT index.
    // NOTE(review): this is 3 less than rem_bits (and idx_max_s is lut_sz*8-1),
    // so the index presumably carries 3 extra low bits, e.g. a byte/sub-entry
    // granularity expected by load_lut_2x_int16 -- confirm.
    const int8_t shift_addr   = in_frac_bits-NLF_LUT_SHIFT-3;    // params.shift_norm;
    // bias_s is pre-shifted by shift_addr + 3, so after the >> shift_addr it
    // offsets the index by about bias_s * 8 (= lut_sz * 4). Presumably this
    // centres signed inputs within the table.
    int8_t shift_bias   = shift_addr + 3;  //shift_addr-rnd;
    int    bias_s       = (lut_sz/2);  //( params.step_Kx << rnd ) - rnd;
    // Index bounds. NOTE(review): unused -- the clamping lines in the loop are
    // commented out, so out-of-range inputs are not bounded to the table.
    int    idx_max_s    =  lut_sz*8-1;  //params.step_Ky - 1;
    int    idx_min_s    =  0;
    // Left shift applied to the offset half of the LUT data so it lines up
    // with the slope * remainder product.
    int8_t shift_offset = in_frac_bits;   // params.shift_bias;

    // Number of low input bits kept as the remainder (clamped at 0).
    int8_t rem_bits     = in_frac_bits-NLF_LUT_SHIFT;
    rem_bits = rem_bits < 0 ? 0 : rem_bits;

    // nlf_mode = 1 - gelu, 2 - sigmoid
    // gelu keeps the full 16-bit scaled input; other modes keep only the low
    // rem_bits bits.
    int16_t rem_mask_s = ( nlf_mode == 1 ) ? 0xffff : ( 1 << rem_bits ) - 1;

    // 2-D address pattern for the scratch cursors: step incD1 (32 bytes, one
    // v16int16) numD times, then step incD2 to wrap back. With scratch_size
    // == 64 this gives numD == 1, i.e. two alternating 32-byte slots;
    // presumably so overlapping pipelined iterations use separate slots.
    addr_t cntDin  = 0;
    addr_t cntDout = 0;
    int numD  = scratch_size / 32 - 1;
    int incD1 = 32;
    int incD2 = -incD1 * numD;

    // Loop-invariant vector constants.
    v32acc64  bias     = set_v32acc64( 0, lups( broadcast_s32( bias_s ), shift_bias ));
    v16int32 idx_max  = broadcast_s32( idx_max_s );
    v16int32 idx_min  = broadcast_s32( idx_min_s );
    v32int16 rem_mask = broadcast_s16( rem_mask_s );

    for ( int j=0; j<num_elem/16; j++ )
        chess_prepare_for_pipelining
        //chess_unroll_loop(2)
        chess_loop_range( 2, )
        //chess_modulo_scheduling_budget_ratio( 500 )
    {
        v16int32 index;
        v32int16 coeff0, coeff1, slope_offset, remainder;
        v32acc64 acc;


        // Load 16 input elements and apply the pre-shift into accumulator
        // lanes. int8: the 16 bytes are duplicated to 32 lanes before
        // unpacking. int16: data goes into accumulator half 0.
        if (chess_manifest(sizeof(in_el_type)==1)) {
            v16int8 inp = *pIn8++;
            acc = lups( unpack(concat(inp,inp)), shift_in);
        } else {  // input datatype int16
            acc    = set_v32acc64( 0, lups( *pIn16++, shift_in )); 
        }

        // Save the scaled (pre-bias) input to scratch as int16; it is read
        // back below as the remainder operand.
        *pDout = ssrs( extract_v16acc64( acc, 0 ), 0);       pDout = add_2d_byte( pDout, incD2, numD, cntDout, incD1 );
        // Bias, then shift down to the LUT index.
        acc   += bias;
        index = lsrs( extract_v16acc64( acc, 0 ), shift_addr );
        //index = min( index, idx_max );
        //index = max( index, idx_min );

        // Fetch LUT data for the 16 indices and rearrange it. Per its use
        // below, half 0 of slope_offset is treated as the offsets and half 1
        // as the slopes (presumably).
        load_lut_2x_int16( (int*)lnr_lutab, (int*)lnr_lutcd, index, coeff0, coeff1 );
        slope_offset = shuffle( coeff0, coeff1, T16_16x4_lo );

        // Read the saved input back into half 1 and mask it down to the
        // remainder selected by nlf_mode.
        remainder = set_v32int16( 1, ( v16int16 )*pDin );        pDin = add_2d_byte( pDin, incD2, numD, cntDin, incD1 );
        remainder = band( remainder, rem_mask );

        // Half 1 of acc = (half 0 of slope_offset) << shift_offset, then add
        // the lane-wise product slope_offset * remainder. Only half 1 of the
        // result is written out below.
        acc = set_v32acc64( 1, lups( extract_v16int16( slope_offset, 0 ), shift_offset ));
        acc = mac_elem_32( slope_offset, remainder, acc );

        // Round/saturate by shift_out and store 16 results from half 1.
        if (chess_manifest(sizeof(out_el_type)==1)) {
            v32int8 tmp = pack(ssrs( acc, shift_out ));
            *pOut8++ = extract_v16int8( tmp, 1 );
        } else {  // output datatype int16
            *pOut16++ = ssrs( extract_v16acc64( acc, 1 ), shift_out );
        }

        //*pOut16++ = extract_v16int16( remainder, 1 );
        //*pOut16++ = ssrs( extract_v16acc64( acc, 0 ), 0 ); 
        //*pOut16++ = lsrs( v16acc32(index), 0 ); 
        //*pOut16++ =  extract_v16int16( slope_offset, 1 );
        //*pOut16++ =  extract_v16int16( coeff1, 0);
    }


#endif


    //end = get_cycles();
    //*((uint16_t*) out.data()) = (uint16_t)(end - start);
}
#endif
