#ifndef __AVGPOOL_BF16X16_3X3S2_IMPL_HPP__
#define __AVGPOOL_BF16X16_3X3S2_IMPL_HPP__

#include "aie_api/aie.hpp"
#include "aie_api/utils.hpp"
#include "api_loop_pipe_helper.hpp"
#include "avgpool_bf16x16_3x3s2.hpp"

using namespace aie;

/// 3x3, stride-2 average-pooling kernel on bfloat16 data.
/// NOTE(review): "bf16x16" in the name presumably means 16 channels per pixel
/// (32 bf16 per 2-pixel vector) -- TODO confirm against the producer of
/// AvgpoolBf16x16Params.
///
/// For each of param.outer_loop iterations the kernel:
///   1. loads 64 bf16 scale factors from wgt as two 32-lane vectors,
///   2. runs param.inner_loop iterations. Each one reads 160 consecutive bf16
///      values (five 32-lane vectors) at pI, rearranges them with shuffles, and
///      multiply-accumulates them against the scale factors into four accfloat
///      accumulators. It then advances pI with byte_incr by param.step_Ky,
///   3. steps pI with the 3-D address descriptor param.dimsA,
///   4. sums the accumulator pairs, converts the sums to bf16, and stores 64
///      bf16 values at pO, then advances pW and pO by 64 elements.
/// NOTE(review): the scale factors are named div_factor; presumably they hold
/// 1/(window size) so that the scaled MAC sum equals the average -- confirm
/// with the caller that prepares wgt.
///
/// @param ifm   input feature map; 160 bf16 are read per inner iteration.
/// @param wgt   scale factors; 64 bf16 are consumed per outer iteration.
/// @param ofm   output feature map; 64 bf16 are written per outer iteration.
/// @param param loop trip counts (outer_loop, inner_loop), input strides
///              (step_Ky, dimsA) and the element shift used by
///              shuffle_down_fill (shft_0).
inline void avgpool_bf16x16_3x3s2( bfloat16 * ifm, bfloat16 * wgt, bfloat16 * __restrict ofm, const AvgpoolBf16x16Params &param )
{
    // Working pointers; each is advanced independently below.
    bfloat16 * pI = ifm;
    bfloat16 * pW = wgt;
    bfloat16 * pO = ofm;

    // 3-D address-generator descriptor used to step pI after each inner loop.
    dims_3d_t dimsA = param.dimsA.instantiate( );

    // acc0/acc1 feed res0 (first 32 outputs); acc2/acc3 feed res1 (next 32).
    // NOTE(review): these are not initialized here; they are zeroed via
    // op_zero on the first inner iteration. If param.inner_loop were 0 they
    // would be read uninitialized at res0/res1.
    accum<accfloat, 32> acc0, acc1, acc2, acc3, res0, res1;
    // Nonzero only on the first inner iteration; passed to op_zero so the
    // accumulators start from zero without a separate clear step.
    int zero_acc;
    
    // chess_loop_range( 2, ) tells the Chess compiler the trip count is at
    // least 2; callers must ensure param.outer_loop >= 2.
    for ( unsigned j = 0; j < param.outer_loop; j++ )
        chess_prepare_for_pipelining
        chess_loop_range( 2, )
    {
        vector<bfloat16, 32> Xbuff0, Xbuff1, Xbuff2, Xbuff3, Xbuff4;
        vector<bfloat16, 32> interleave_0, interleave_1, interleave_2, interleave_3;
        vector<bfloat16, 32> shift_buff_0, shift_buff_1;
        vector<bfloat16, 32> Obuff0, Obuff1;
        
        // Scale factors for this group of 64 outputs: div_factor0 scales the
        // acc0/acc1 products, div_factor1 scales the acc2/acc3 products.
        vector<bfloat16, 32> div_factor0 = load_v<32>( pW );
        vector<bfloat16, 32> div_factor1 = load_v<32>( pW+32 );

        // Request an accumulator reset on the first inner iteration.
        zero_acc = 1;

        // NOTE(review): the lambda parameter `j` shadows the outer loop index
        // and is unused in the body.
        aie::pipelined_loop<3>( param.inner_loop, [&]( auto j ) __aie_inline {
            // Five consecutive 32-element vectors (160 bf16) starting at pI.
            Xbuff0 = load_v<32>( pI    );
            Xbuff1 = load_v<32>( pI+32 );
            Xbuff2 = load_v<32>( pI+64 );
            Xbuff3 = load_v<32>( pI+96 );
            Xbuff4 = load_v<32>( pI+128 );
            
            // 128-bit-granularity interleave of each vector pair into lo/hi
            // halves. NOTE(review): presumably this separates even- and
            // odd-indexed pixels for the stride-2 window -- confirm the
            // T128_4x2 mode against the intended pixel layout.
            interleave_0 = shuffle( Xbuff0, Xbuff1, T128_4x2_lo );
            interleave_1 = shuffle( Xbuff0, Xbuff1, T128_4x2_hi );
            interleave_2 = shuffle( Xbuff2, Xbuff3, T128_4x2_lo );
            interleave_3 = shuffle( Xbuff2, Xbuff3, T128_4x2_hi );

            // Shift the lo interleave down by param.shft_0 elements, filling
            // from the following vector, to form the third window tap.
            shift_buff_0 = shuffle_down_fill( interleave_0, Xbuff2, param.shft_0 );
            shift_buff_1 = shuffle_down_fill( interleave_2, Xbuff4, param.shft_0 );
            
            // op_zero(acc, zero_acc) is expected to yield a zeroed accumulator
            // when zero_acc is nonzero, and acc unchanged otherwise.
            // acc0 collects two taps (interleave_0, shift_buff_0) and acc1
            // one tap (interleave_1), all scaled by div_factor0.
            acc0 = mac( op_zero( acc0, zero_acc ), interleave_0, div_factor0 );
            acc1 = mac( op_zero( acc1, zero_acc ), interleave_1, div_factor0 );
            acc0 = mac( acc0,                      shift_buff_0, div_factor0 );
            // Same pattern for the second half (Xbuff2..Xbuff4) with div_factor1.
            acc2 = mac( op_zero( acc2, zero_acc ), interleave_2, div_factor1 );
            acc3 = mac( op_zero( acc3, zero_acc ), interleave_3, div_factor1 );
            acc2 = mac( acc2,                      shift_buff_1, div_factor1 );
            
            // Move to the next input row of the window (byte-based increment).
            pI = byte_incr( pI, param.step_Ky );
            // Subsequent inner iterations accumulate instead of resetting.
            zero_acc = 0;
        });

        // Step the input pointer to the next pooling position via the 3-D
        // address descriptor.
        pI = add_3d_byte( pI, dimsA );

        // Combine the accumulator pairs into the final 64 results.
        res0 = acc0 + acc1;
        res1 = acc2 + acc3;

        // Convert from accfloat accumulators to bf16 vectors.
        Obuff0 = to_v32bfloat16( res0 );
        Obuff1 = to_v32bfloat16( res1 );
        
        store_v( pO, Obuff0 );
        store_v( pO+32, Obuff1 );

        // 64 scale factors consumed and 64 outputs produced per outer iteration.
        pW += 64;
        pO += 64;
    }
}

#endif