#ifndef __KERNEL_DIV_BF16X16_IMPL_HPP__
#define __KERNEL_DIV_BF16X16_IMPL_HPP__

#include "div_bf16x16.hpp"
#include "aie_api/aie.hpp"
#include "aie_api/utils.hpp"

using namespace aie;

/**
 * Non-template entry point for the bf16 division kernel.
 *
 * Maps the run-time flag onto the compile-time `has_scalar_broadcast`
 * template argument of div_bf16x16 and forwards all remaining arguments
 * unchanged.
 *
 * @param has_scalar_broadcast  Non-zero selects the <1, ...> instantiation,
 *                              zero selects the <0, ...> instantiation.
 * @param input    Forwarded as the first operand pointer (in0).
 * @param weights  Forwarded as the second operand pointer (in1).
 * @param output   Forwarded as the destination pointer (ofm).
 * @param params   Kernel parameters, forwarded as-is.
 */
void div_bf16x16_nontemplatized 
(
    unsigned has_scalar_broadcast,
    bfloat16 * input,
    bfloat16 * weights,
    bfloat16 * __restrict output,
    const KernelDivBf16x16Param &params
){
    /* Any non-zero flag selects the scalar-broadcast variant. */
    if ( has_scalar_broadcast != 0 ) {
        div_bf16x16<1, bfloat16, bfloat16, bfloat16>( input, weights, output, params );
        return;
    }

    /* Otherwise use the element-wise variant. */
    div_bf16x16<0, bfloat16, bfloat16, bfloat16>( input, weights, output, params );
}

template<unsigned has_scalar_broadcast = 0, typename Ti0 = bfloat16, typename Ti1 = bfloat16, typename To = bfloat16>
void div_bf16x16 (
    Ti0 * in0, 
    Ti1 * in1, 
    To  * __restrict ofm, 
    const KernelDivBf16x16Param &params
) {

    Ti0 * pI0 = in0;
    Ti1 * pI1 = in1;
    To  * pO  = ofm;
    constexpr unsigned loop_range = 6;
    constexpr unsigned gran = 64;
    dims_2d_t dims_x = params.dims_x.instantiate( );
    dims_2d_t dims_y = params.dims_y.instantiate( );

    for ( unsigned j = 0; j < params.loop_count; j++ )
        chess_prepare_for_pipelining
        chess_loop_range( loop_range, )
    {
        vector<Ti0, gran> Xbuff;
        vector<Ti1, gran> Ybuff;
        accum<accfloat, 32> acc0, acc1;
        vector<To, 32> out0, out1;

        /* Load input data */
        Xbuff = load_v<gran>( pI0 );

        if constexpr( has_scalar_broadcast == 1 ) {
            Ybuff = aie::broadcast<Ti1, gran>( *pI1 );
        } else {
            Ybuff = load_v<gran>( pI1 );
        }

        /* Compute division using reciprocal */
        Ybuff = aie::inv( Ybuff );
        
        /* Multiply: div(x, y) = x * inv(y) */
        acc0 = aie::mul( Xbuff.template extract<32>( 0 ), Ybuff.template extract<32>( 0 ) );
        acc1 = aie::mul( Xbuff.template extract<32>( 1 ), Ybuff.template extract<32>( 1 ) );

        out0 = acc0.template to_vector<To>( );
        out1 = acc1.template to_vector<To>( );

        /* Store output */
        store_v( pO + 32, out1 );
        store_v( pO, out0 );
        
        /* Increment pointers */
        pI0 = add_2d_byte( pI0, dims_x );
        pI1 = add_2d_byte( pI1, dims_y );
        pO += gran;
    }
}

#endif // __KERNEL_DIV_BF16X16_IMPL_HPP__
