#ifndef __GLOBAL_REDUCEMAX_BF16X16_TEMPLATE_H__
#define __GLOBAL_REDUCEMAX_BF16X16_TEMPLATE_H__

#include <adf.h>
#include <aie_api/aie.hpp>
#include <aie_api/aie_adf.hpp>
#include <aie_api/utils.hpp>

#include "../../include/access_casc_stream.h"
#include "../../include/access_core_stream.h"

#include "../global_reduce.h"
#include "kernel_helpers.h"

/// Lane-wise max reduction of a 64-lane bfloat16 vector treated as an 8x8 tile.
///
/// Steps, exactly as performed below:
///   1. transpose the input as an 8x8 matrix;
///   2. take the element-wise max of the two 32-lane halves;
///   3. fold the result twice more with shuffle + max
///      (modes T256_2x2_hi, then T128_4x2_hi).
///
/// @param in  64 bfloat16 values.
/// @return    The folded 32-lane bfloat16 vector reinterpreted as v16accfloat.
///            Per the original author's note, the reduced 8x1 bfloat16 data
///            sits in the first 16 bytes and the rest is undefined.
///
/// Declared `inline` because the definition lives in a header: without it,
/// including this header from more than one translation unit that is linked
/// together yields duplicate definitions. Forced inlining (always_inline) was
/// disabled in the original and is deliberately NOT re-enabled here.
inline v16accfloat tree_max_8x8 (aie::vector<bfloat16,64> in)
{
    aie::vector<bfloat16, 64> in_transposed = aie::transpose( in, 8, 8);
    v64bfloat16 inT = in_transposed;  // convert to the native vector type for the extract intrinsics

    // Split into two 32-lane halves and combine them lane by lane.
    v32bfloat16 v4x8_0 = extract_v32bfloat16(inT, 0);
    v32bfloat16 v4x8_1 = extract_v32bfloat16(inT, 1);
    v32bfloat16 v4x8   = max(v4x8_0, v4x8_1);

    // Two further shuffle+max folds; the variable names suggest the live data
    // shrinks 4x8 -> 2x8 -> 1x8 (NOTE(review): confirm against the shuffle-mode docs).
    v32bfloat16 v2x8   = max( v4x8, shuffle( v4x8, T256_2x2_hi ));
    v32bfloat16 v1x8   = max( v2x8, shuffle( v2x8, T128_4x2_hi ));

    // Same-size cast: the bfloat16 lanes are carried inside an accfloat-typed vector.
    return v16accfloat( v1x8 );
}

/// Views a v16accfloat as 32 bfloat16 lanes and returns one 8-lane (16-byte) group.
///
/// @param v    Accumulator-typed vector that is cast to v32bfloat16.
/// @param idx  Index of the 8-lane group handed to extract_v8bfloat16
///             (presumably 0..3 for a 32-lane source; 0 would be the group the
///             function name refers to -- TODO confirm with callers).
/// @return     The selected group as v8bfloat16. The original comment said this
///             is "actually" a v4accfloat, i.e. the 16 bytes are presumably meant
///             to be read as four accfloat values -- confirm with callers.
///
/// Declared `inline` because the definition lives in a header: without it,
/// including this header from more than one translation unit that is linked
/// together yields duplicate definitions. Forced inlining (always_inline) was
/// disabled in the original and is deliberately NOT re-enabled here.
inline v8bfloat16 extract_LSB_16bytes(v16accfloat v, int idx)
{
    v32bfloat16 v32bf = v32bfloat16(v);
    return extract_v8bfloat16(v32bf, idx);
}


inline __attribute__((always_inline))
void global_reducemax_bf16x16_template
(
    int8_t * input,
    int8_t * restrict output,
    int subvolume_count
    ,v32bfloat16* negative_max,
    int num_cols,
    int col_id,
    int row_id
    , bool multi_core_sm
)
{
    int rowIdx = row_id;//(get_coreid() & 0xF);
    int colIdx = col_id;//(get_coreid() >> 16);

    v32bfloat16 ones = broadcast_bfloat16( 1.0 );
    v32bfloat16 * pIn1  = ( v32bfloat16 * ) chess_copy( input );
    v32bfloat16 * restrict pOut = ( v32bfloat16 * ) output;
    v32bfloat16 * pIn2;
    v32bfloat16 * pOut2;
    v32bfloat16 * pIn3  = ( v32bfloat16 * ) chess_copy( input );

    {
        aie::vector<bfloat16,64> max_8x8_0;
        aie::vector<bfloat16,64> max_8x8_1;

        pIn2  = chess_copy( pIn1 );
        pOut2 = chess_copy( pIn1 );
        uint16 neg_infinity = (uint16) 0xff80;
        bfloat16 *bf_neg_infinity = (bfloat16 *)&neg_infinity;

        v32bfloat16 * pI = pIn1;
        v32bfloat16 max_4x8[4];
        #pragma unroll
        for(int k = 0; k < 4; k++)
        {
            max_4x8[k] = broadcast_bfloat16(*bf_neg_infinity);
        }
        //#pragma nounroll
        for(int i = 0; i < (num_cols/8); i++)
            chess_no_hw_loop
            //chess_prepare_for_pipelining
        {
            max_4x8[0] = max(max_4x8[0], pI[0]);
            max_4x8[1] = max(max_4x8[1], pI[1]);
            max_4x8[2] = max(max_4x8[2], pI[2]);
            max_4x8[3] = max(max_4x8[3], pI[3]);
            pI = byte_incr(pI, 256 * subvolume_count);
        }

        max_8x8_0.insert(0, max_4x8[0]);
        max_8x8_0.insert(1, max_4x8[1]);
        max_8x8_1.insert( 0, max_4x8[2]);
        max_8x8_1.insert( 1, max_4x8[3]);

        v16accfloat max0 = tree_max_8x8( max_8x8_0 );
        v16accfloat max1 = tree_max_8x8( max_8x8_1 );

        v16accfloat localmax = shuffle( max0, max1, T128_2x4_lo );
        v16accfloat rx_glb_max = localmax;

        if (multi_core_sm){
            rx_glb_max = global_add_reduce( localmax, rowIdx, colIdx, 1 ); // 1 - use it for global reduce max
        }
        negative_max[0] = *((v32bfloat16*)(&rx_glb_max));

    }
}

#endif