#ifndef __ADD2D_BF16x16_IMPL_H__
#define __ADD2D_BF16x16_IMPL_H__

#include "aie_api/aie.hpp"
#include "broadcast/add2d_bf16x16.hpp"

using namespace aie;


/*
 * add2d_bf16x16_nontemplatized
 *
 * Runtime entry point for the bf16 2D add kernel. It maps the runtime
 * `has_scalar_broadcast` flag onto the matching compile-time specialization
 * of add2d_bf16x16 (QDQFloatType in and out). Both variants are therefore
 * compiled, and the selection happens once per call instead of inside the
 * kernel loop.
 *
 *   has_scalar_broadcast - non-zero: the single value at *weights is broadcast
 *                          to every lane; zero: weights is read element-wise.
 *   input, weights, output, params - forwarded unchanged to add2d_bf16x16.
 */
void add2d_bf16x16_nontemplatized
(
    unsigned has_scalar_broadcast,
    QDQFloatType * input,
    QDQFloatType * weights,
    QDQFloatType * __restrict output,
    const Add2dBf16x16Params &params
){
    /* Value used as add2d_bf16x16's loop_range (its chess_loop_range lower bound). */
    constexpr unsigned min_iterations = 6;

    if (!has_scalar_broadcast){
        /* Full weight tensor: 64 weights are loaded per iteration. */
        add2d_bf16x16<min_iterations, 0, QDQFloatType, QDQFloatType>(input, weights, output, params);
        return;
    }

    /* Scalar weight: *weights is broadcast across all lanes. */
    add2d_bf16x16<min_iterations, 1, QDQFloatType, QDQFloatType>(input, weights, output, params);
}


/*
 * add2d_bf16x16
 *
 * Element-wise combination of `weights` into `input`, written to `output`.
 * It is an addition, or a subtraction if params.is_sub selects one (see the
 * NOTE at the MAC below). Each outer-loop iteration processes 64 elements.
 *
 * Template parameters:
 *   loop_range           - first argument of chess_loop_range on the outer loop.
 *                          This is a compiler hint that the loop runs at least
 *                          this many times. NOTE(review): presumably callers
 *                          must guarantee params.outer_loop >= loop_range; confirm.
 *   has_scalar_broadcast - 0: load 64 weights from pW each iteration;
 *                          otherwise: broadcast the single value *pW to all 64 lanes.
 *   Ti                   - element type of input and weights.
 *   To                   - element type of output.
 *
 * Parameters:
 *   input   - read as two 32-element vectors per iteration (at pA and pA + 32).
 *             pA is then advanced with add_2d_byte using params.dims_x.
 *   weights - read as 64 elements, or one scalar when broadcasting.
 *             pW is then advanced with add_2d_byte using params.dims_y.
 *   output  - written densely, 64 elements per iteration (pO += 64).
 *             Declared __restrict, so it must not alias input or weights.
 *   params  - outer_loop: number of iterations;
 *             dims_x / dims_y: pointer-stepping descriptors for input / weights;
 *             is_sub: forwarded to mac_elem_32_conf.
 */
template<unsigned loop_range=6, unsigned has_scalar_broadcast = 0, typename Ti = QDQFloatType, typename To = QDQFloatType>
ALWAYS_INLINE void add2d_bf16x16 
(
    Ti * input,
    Ti * weights,
    To * __restrict output,
    const Add2dBf16x16Params &params
){
    Ti * pA = input;
    Ti * pW = weights;
    To * pO = output;

    /* Elements handled per iteration: two 32-lane halves. */
    constexpr unsigned gran = 64;
    /* Multiplier of 1, so the multiply-accumulate below reduces to combining
       Ybuff directly with the accumulator. */
    Ti one = 1.0f;
    /* Instantiate the 2D stepping descriptors consumed by add_2d_byte below. */
    dims_2d_t dims_x = params.dims_x.instantiate( );
    dims_2d_t dims_y = params.dims_y.instantiate( );

    /* One iteration per 64-element chunk. The pragmas must stay directly after
       the for-header: software-pipeline this loop, with a minimum trip count of
       loop_range and no maximum given. */
    for ( unsigned j = 0; j < params.outer_loop; j++ )
        chess_prepare_for_pipelining
        chess_loop_range( loop_range, )
    {


        vector<Ti, gran> Ybuff;
        accum<accfloat, 32> acc0, acc1;
        vector<To, 32> out0, out1;

        /* Load data: 64 input elements seed the two accumulators. Ybuff gets
           64 weights, or one weight broadcast to all lanes. */
        acc0.from_vector( load_v<32>( pA ));
        acc1.from_vector( load_v<32>( pA + 32 ));
        if constexpr (has_scalar_broadcast == 0){
            Ybuff = load_v<gran>( pW );
        } else {
            Ybuff = aie::broadcast<Ti,gran>( *pW );
        }
        
        /* Addition: combine one * (lower / upper 32 lanes of Ybuff) into acc0 / acc1.
           NOTE(review): params.is_sub is passed among the intrinsic's configuration
           arguments, presumably selecting subtraction instead of addition. Confirm
           against the mac_elem_32_conf reference. */
        acc0 = mac_elem_32_conf( one, Ybuff.template extract<32>( 0 ), acc0, 0 , params.is_sub , 0);
        acc1 = mac_elem_32_conf( one, Ybuff.template extract<32>( 1 ), acc1, 0 , params.is_sub , 0);

        /* Convert the float accumulators back to the output element type. */
        out0 = acc0.template to_vector<To>( );
        out1 = acc1.template to_vector<To>( );


        /* Store output: upper half at pO + 32, lower half at pO. */
        store_v( pO + 32, out1 );
        store_v( pO, out0 );

        /* Increment pointers: input and weights follow their 2D descriptors;
           output advances contiguously by one chunk. */
        pA = add_2d_byte( pA, dims_x );
        pW = add_2d_byte( pW, dims_y );

        pO += gran;
    }
}

#endif