#ifndef __SOFTMAX_BF16X16_TEMPLATE_H__
#define __SOFTMAX_BF16X16_TEMPLATE_H__

#include <adf.h>
#include <aie_api/aie.hpp>
#include <aie_api/aie_adf.hpp>
#include <aie_api/utils.hpp>

//#include "nlf_common.h"
//#include "tree_add.h"
#include <stdio.h>

// ---------------------------------------------------------------------------
// Build-time configuration presets.
//
// At most one of _4x4_SFMX, _4x2_SFMX or _4x2_MHA is expected to be defined by
// the build; when none is defined the final #else branch supplies the values
// used for the fused-context build.
//
// Switches that are consumed in this header:
//   APPLY_MAX_SHIFT   - when non-zero, the global reduce-max kernel header is
//                       included further down in this file.
//   EIGHT_BIT_OUTPUT  - when non-zero, the sfmx_convert_to_int8 /
//                       sfmx_convert_to_int16 helpers are compiled.
//
// SKIP_MAX_STAGE_3, APPLY_CAUSAL_MASK, OUTER_LOOP_CNT, LAST_COL_IDX and
// USE_INV_AIEAPI are defined here but are not referenced by the code in this
// header; presumably they are consumed by code that includes it - TODO confirm.
// ---------------------------------------------------------------------------
#ifdef _4x4_SFMX       // Preset for the standalone softmax test (4x4)
    #define APPLY_MAX_SHIFT 0
    #define SKIP_MAX_STAGE_3 0
    #define APPLY_CAUSAL_MASK 0
    #define EIGHT_BIT_OUTPUT 0  // 8-bit output mode has to be disabled for standalone
    #define OUTER_LOOP_CNT 2
    #define LAST_COL_IDX 3
    #define USE_INV_AIEAPI 0
#elif defined(_4x2_SFMX) // Preset for the standalone softmax test (4x2)
    #define APPLY_MAX_SHIFT 0
    #define SKIP_MAX_STAGE_3 0
    #define APPLY_CAUSAL_MASK 0
    #define EIGHT_BIT_OUTPUT 0  // 8-bit output mode has to be disabled for standalone
    #define OUTER_LOOP_CNT 2
    #define LAST_COL_IDX 1
    #define USE_INV_AIEAPI 0
#elif defined(_4x2_MHA)  // Preset selected by _4x2_MHA
    #define APPLY_MAX_SHIFT 1
    #define SKIP_MAX_STAGE_3 1
    #define APPLY_CAUSAL_MASK 0
    #define EIGHT_BIT_OUTPUT 0  // 8-bit output mode is disabled in this preset
    #define OUTER_LOOP_CNT 2
    #define LAST_COL_IDX 1
    #define USE_INV_AIEAPI 0
#else		           // Defaults: softmax run in a fused context
    #define APPLY_MAX_SHIFT 1    //1
    #define SKIP_MAX_STAGE_3 1  // only works if APPLY_MAX_SHIFT is set
    #define APPLY_CAUSAL_MASK 0
    #define EIGHT_BIT_OUTPUT 0
    #define OUTER_LOOP_CNT 2
    #define LAST_COL_IDX 1
    #define USE_INV_AIEAPI 0
#endif

// INPUT_FACTOR aliases INT16_BYTES for the _4x2_MHA build and INPUT_MULTIPLIER
// otherwise. Neither of those two names is defined in this header, and
// INPUT_FACTOR itself is not used by the code visible here; presumably they are
// supplied and consumed by the including translation unit - TODO confirm.
#ifdef _4x2_MHA
#define INPUT_FACTOR INT16_BYTES
#else
#define INPUT_FACTOR INPUT_MULTIPLIER
#endif

// Row count of a base block: 64 / 8 = 8 (M,K,N,L 4d shapes for mha, N==64).
// The expansion is fully parenthesised so that the macro keeps its value when
// it appears next to operators of equal or higher precedence, e.g.
// `x / BASE_BLOCK_ROWCNT` or `x % BASE_BLOCK_ROWCNT`.
#define BASE_BLOCK_ROWCNT  ((64) / 8)

#if APPLY_MAX_SHIFT
#include "../glbmax_bf16x16/global_reducemax_bf16x16_kernel.h"
#endif

#define DEBUG_PRINT 0
//#define ONE_OFM_STREAM_PER_COLUMN 0

//#ifndef _FUSED_OP
//#define OUTER_LOOP_CNT 2
//#else
//#define OUTER_LOOP_CNT 4
//#endif

#include "../../include/access_casc_stream.h"
#include "../../include/access_core_stream.h"
//#include "../access_pkt_stream.h"
#include "../../include/kernel_helpers.h"
#include "../../include/tree_add.h"
#include "../global_reduce.h"

#if EIGHT_BIT_OUTPUT
// Quantise 64 bfloat16 lanes to 64 int8 lanes.
//
// Steps performed, exactly as coded:
//   1. each 16-lane quarter of `bf` is converted with bfloat16_to_int(.., 30)
//      and wrapped as a v16acc32;
//   2. the four quarters are concatenated and narrowed with
//      ssrs_conf(acc, 22, 0, 1, 0);
//   3. -128 is added to every resulting lane.
//
// NOTE(review): the arguments after the accumulator are presumably
// (shift = 22, sign = 0 i.e. unsigned, saturation mode = 1, rounding mode = 0).
// If so, the net scaling is 2^(30 - 22) = 2^8 = 256 and step 3 re-centres an
// unsigned [0..255] result onto [-128..127] - confirm against the intrinsic
// reference before relying on this.
inline __attribute__((always_inline))
v64int8 sfmx_convert_to_int8(v64bfloat16 bf)
{
    // bfloat16 -> 32-bit accumulator lanes, one quarter of the input at a time
    v16acc32 quarter0 = v16acc32(bfloat16_to_int(extract_v16bfloat16(bf, 0), 30));
    v16acc32 quarter1 = v16acc32(bfloat16_to_int(extract_v16bfloat16(bf, 1), 30));
    v16acc32 quarter2 = v16acc32(bfloat16_to_int(extract_v16bfloat16(bf, 2), 30));
    v16acc32 quarter3 = v16acc32(bfloat16_to_int(extract_v16bfloat16(bf, 3), 30));

    // Re-assemble the 64 lanes and narrow them to 8 bits
    v64acc32 wide_lanes   = concat(quarter0, quarter1, quarter2, quarter3);
    v64int8  narrow_lanes = ssrs_conf(wide_lanes, 22, 0, 1, 0);

    // Apply the constant -128 offset to every lane
    return add(narrow_lanes, broadcast_s8(-128));
}

//v32int16 lsrs (v32acc32 acc, int shft, int sign)
// Quantise 64 bfloat16 lanes to 64 int16 lanes.
//
// Steps performed, exactly as coded:
//   1. each 16-lane quarter of `bf` is converted with bfloat16_to_int(.., 30)
//      and wrapped as a v16acc32;
//   2. quarters 0/1 and quarters 2/3 are concatenated into two v32acc32 halves;
//   3. each half is narrowed to v32int16 with lsrs(acc, 8, 0);
//   4. -128 is added to every lane and the halves are placed at positions 0
//      and 1 of the v64int16 result.
//
// NOTE(review): lsrs is presumably lsrs(acc, shift, sign), i.e. a shift of 8
// with sign = 0 here; the -128 offset mirrors the int8 variant - confirm that
// this offset is intended for 16-bit output.
inline __attribute__((always_inline))
v64int16 sfmx_convert_to_int16(v64bfloat16 bf)
{
    // Lower half: input quarters 0 and 1
    v32acc32 lower_acc = concat(
        v16acc32(bfloat16_to_int(extract_v16bfloat16(bf, 0), 30)),
        v16acc32(bfloat16_to_int(extract_v16bfloat16(bf, 1), 30)));

    // Upper half: input quarters 2 and 3
    v32acc32 upper_acc = concat(
        v16acc32(bfloat16_to_int(extract_v16bfloat16(bf, 2), 30)),
        v16acc32(bfloat16_to_int(extract_v16bfloat16(bf, 3), 30)));

    // Narrow each half to 16-bit lanes
    v32int16 lower_i16 = lsrs(lower_acc, 8, 0);
    v32int16 upper_i16 = lsrs(upper_acc, 8, 0);

    // Offset every lane by -128 and assemble the 64-lane result;
    // both halves are written before the value is returned.
    v64int16 result;
    result = insert(result, 0, add(lower_i16, broadcast_s16(-128)));
    result = insert(result, 1, add(upper_i16, broadcast_s16(-128)));
    return result;
}
#endif

// print utility
/*
const int aiecol = 0;
const int aierow = 2+0;
const int Srow = 44;
float t32( bfloat16 v ) { return v; }
int32 t32( uint16 v ) { return v; }
int32 t32( int32 v ) { return v; }
*/
// Debug matrix dump.
//
// The whole body sits behind `#if 0`, so as compiled this function does
// nothing and ignores all of its arguments. The disabled code refers to
// aierow, aiecol, Srow and t32(), which only exist inside the commented-out
// "print utility" section above; they must be restored before the body can be
// re-enabled.
//
//   mat  - matrix storage to dump
//   M, N - dimensions used by the disabled indexing code
//   name - label passed to the print calls
//   fmt  - printf conversion character patched into the scalar format strings
template<typename T>
void print_mat( T * mat, int M, int N, const char * name, char fmt='i' )
{
#if 0   // master switch: debug output disabled
    const bool on_selected_core =
        ((get_coreid() & 0xF) == aierow) && ((get_coreid() >> 16) == aiecol);
    if (on_selected_core) {
#if 1
        // Vector dump: print one 8-lane bfloat16 vector per group of 8 columns
        auto vec_cursor = (v8bfloat16*)mat + Srow;
        for (int group = 0; group < N/8; group++) {
            aie::print(aie::vector<bfloat16, 8>(*vec_cursor), true, name);
            vec_cursor += M;
        }
#else
        // Scalar dump: one printf per element, newline after the last of a row
        printf( "%s werweqrqw retfgsf cq mat:\n", name );
        char fmt_eol[] = "%i\n";
        char fmt_sep[] = "%i ";
        fmt_eol[1] = fmt;
        fmt_sep[1] = fmt;
        for (int r = 0; r < M; r++) {
            for (int g = 0; g < N/8; g++) {
                for (int lane = 0; lane < 8; lane++) {
                    auto v = t32( mat[8*r + M*8*g + lane] );
                    const bool last_in_row = (g == N/8-1) && (lane == 7);
                    if (last_in_row)
                        printf( fmt_eol, v );
                    else
                        printf( fmt_sep, v );
                }
            }
        }
#endif
    }
#endif
}


// Fold a 32-lane float accumulator down to 4 floats.
//
// Active (FP32) path, in order:
//   1. the two 16-lane halves of `acc` are combined with shuffle modes
//      T32_4x8_lo and T32_4x8_hi and the two results are added;
//   2. the partial sum is added to itself shuffled with T256_2x2_hi;
//   3. then again to itself shuffled with T128_4x2_hi;
//   4. the first v4float of the result is returned.
//
// NOTE(review): going by the name, the input is presumably a 4x8 tile and the
// result holds one sum per row - confirm against the shuffle-mode definitions.
inline v4float tree_add_4x8_local(v32accfloat acc) {
#if 1
        // FP32 variant (active)
        v16accfloat half_lo = extract_v16accfloat(acc, 0);
        v16accfloat half_hi = extract_v16accfloat(acc, 1);

        v16accfloat partial;
        partial  = shuffle(half_lo, half_hi, T32_4x8_lo);
        partial += shuffle(half_lo, half_hi, T32_4x8_hi);
        partial += shuffle(partial, T256_2x2_hi);
        partial += shuffle(partial, T128_4x2_hi);

        return extract_v4float(v16float(partial), 0);
#else
        // BF16 variant (disabled, kept for reference)
        v32bfloat16 ones = broadcast_bfloat16(1.0);
        v32accfloat sum;
        v32bfloat16 bf;
        bf = to_v32bfloat16(acc);
        sum = to_v32accfloat(shuffle(bf, T16_8x8_lo));
        sum = mac_elem_32(shuffle(bf, T16_8x8_hi), ones, sum);
        sum += set_v32accfloat(0, extract_v16accfloat(sum, 1));
        sum += set_v32accfloat(0, shuffle(extract_v16accfloat(sum, 0), T256_2x2_hi));
        return extract_v8accfloat(sum, 0);
#endif
}


// Softmax over a bfloat16 tile, processed in blocks of 16 rows x 8 columns.
//
// For every block of 16 rows the function
//   1. calls global_reducemax_bf16x16() to obtain the per-row shift values,
//   2. computes exp2((x - shift) * sfmx_in_scale) for every element, stores the
//      result back over the input, and accumulates the per-row sums,
//   3. optionally combines the sums across cores (global_add_reduce),
//   4. multiplies the stored exponentials by the reciprocal of the row sums and
//      writes the result to `output`.
//
// Parameters
//   input          Source data, reinterpreted as v32bfloat16 vectors. It is
//                  OVERWRITTEN in place with the un-normalised exponentials.
//   output         Destination for the normalised values; it is also handed to
//                  global_reducemax_bf16x16() (its use there is not visible
//                  from this file).
//   num_rows       num_rows / 16 gives the outer iteration count. The outer
//                  loop is annotated chess_loop_range(1, 1), i.e. the compiler
//                  is told it runs exactly once.
//   num_cols       num_cols / 8 gives the inner iteration count; the stage-1
//                  inner loop is annotated with a minimum of one iteration.
//   col_id, row_id Forwarded to global_reducemax_bf16x16() and
//                  global_add_reduce().
//   multi_core_sm  Forwarded to global_reducemax_bf16x16(); when true the row
//                  sums are additionally passed through global_add_reduce().
//   sfmx_in_scale  Factor applied (after conversion to bfloat16) to the shifted
//                  input before exp2.
//
// Memory layout as used by the code: consecutive groups of 8 columns are
// 256 * num_outer_loops bytes apart; within a group each 16-row block occupies
// four consecutive v32bfloat16 vectors.
//
// NOTE(review): pIn1/pOut1 (and their bases) are both `restrict` yet point at
// the same memory; the in-place update relies on the compiler not exploiting
// that. Left unchanged here because it affects scheduling - confirm intent.
inline __attribute__((always_inline))
void softmax_bf16x16_template
(
    int8_t * input,
    int8_t * restrict output,
    int num_rows, int num_cols,
    int col_id, int row_id
    , bool multi_core_sm,
    float sfmx_in_scale = 1.0000000000000f //0.180336880111120
)
{
    //print_mat<bfloat16>( (bfloat16 *)input, 64, 56, "SM_input ");

    const int num_outer_loops = num_rows / 16;
    const int num_inner_loops = num_cols /  8;

    // Stage 1 reads from and writes to the input buffer (in-place exponentials).
    v32bfloat16 * restrict pIn1_base;
    v32bfloat16 * restrict pOut1_base;
    pIn1_base = (v32bfloat16*) input;
    pOut1_base = (v32bfloat16*) input;

    v32accfloat acc[4];
    // Multiplier of 1.0 used to accumulate the exponentials through the MAC.
    v32bfloat16 ones = broadcast_bfloat16( 1.0 );

    for ( int j=0; j<num_outer_loops; j+=1 )
        chess_prepare_for_pipelining
        // chess_flatten_loop
        chess_loop_range( 1, 1)
    {
        int8_t* pIn_RedMax = (int8_t*)(pIn1_base);
        v32bfloat16 neg_max[4];

        // Fills neg_max; only neg_max[0] is consumed below.
        global_reducemax_bf16x16
        (
            pIn_RedMax,
            output,
            num_outer_loops,
            neg_max,
            num_cols,
            col_id, row_id,
            multi_core_sm
        );

#define PRINT 0

        // 1 => the next MAC clears the accumulator instead of adding to it.
        int zero_init = 1;
        v4float tree_sums[4];

        // ----------------------------------------------------------------------------------
        // Stage 1 : Critical section. Iterate over 16xMxbf16
        //           k selects one of the four v32bfloat16 vectors of the 16-row block.
        // ----------------------------------------------------------------------------------

        for(int k = 0; k < 4; k++)
        chess_prepare_for_pipelining
        {
            v32bfloat16* restrict pIn1  = pIn1_base+k;
            v32bfloat16* restrict pOut1 = pOut1_base+k;

            // Expand the k-th group of four shift values to a full 32-lane vector.
            // NOTE(review): the value is SUBTRACTED from the input below although the
            // array is named neg_max - confirm the sign convention of the reduce-max kernel.
            v32accfloat global_max = to_v32accfloat(shuffle(broadcast_to_v32bfloat16(extract_v4bfloat16(neg_max[0], k)), T16_8x4));

            //////////////////////////////////////////////////////////////////////////
            ////   Remove loop_range(2, ) to support 128 shape (Nsubv==8) of SDXL MHA
            //////////////////////////////////////////////////////////////////////////
            for(int i = 0; i < num_inner_loops; i++)
            chess_prepare_for_pipelining
            chess_loop_range( 1, )
            chess_no_hw_loop
            {
                v32accfloat infp0 = sub(to_v32accfloat(pIn1[0]), global_max);

#if 1
                aie::vector<bfloat16, 32> inbf16;
                inbf16 = to_v32bfloat16(infp0);

                auto in_exp = aie::mul(inbf16, bfloat16(sfmx_in_scale)); //For SDXL MUL by 0.125*log2e
                v32bfloat16 p2 = set_v32bfloat16( 0, aie::exp2( in_exp.extract<16>(0).to_vector()));
                p2 = insert( p2,      1, aie::exp2( in_exp.extract<16>(1).to_vector()));
#else
                v32bfloat16 p2 = set_v32bfloat16( 0, exp2( extract_v16accfloat( infp0, 0 )));
                p2 = insert( p2,      1, exp2( extract_v16accfloat( infp0, 1 )));
#endif
                // Keep the exponentials for stage 3 and add them to the running sum.
                pOut1[0] = p2;
                acc[k] = mac_elem_32_conf( p2, ones, acc[k], zero_init, 0, 0 );

                zero_init = 0;
                // Step to the next group of 8 columns.
                pIn1 = byte_incr(pIn1, 256 * num_outer_loops);
                pOut1 = byte_incr(pOut1, 256 * num_outer_loops);
            }
            // Re-arm the accumulator clear for the next k.
            zero_init = 1;
            tree_sums[k] = tree_add_4x8_local(acc[k]);
        }
        // View the four v4float partial results as one 16-lane vector of sums.
        v16accfloat localsum = *((v16accfloat*) tree_sums);

        if (multi_core_sm){
            // ----------------------------------------------------------------------------------
            // Stage 2 divergent section: combine the sums across cores.
            // ----------------------------------------------------------------------------------
            localsum = global_add_reduce( localsum, row_id, col_id);
        }

        // ----------------------------------------------------------------------------------
        // Stage 3 : Critical section:
        //     After exiting divergent section:
        //     localsum holds the (optionally globally reduced) sums of exponents,
        //     --> Derive four 4x8 reciprocal operands for (16x8)
        //          which are inv_sum0, inv_sum1, inv_sum2, inv_sum3
        // ----------------------------------------------------------------------------------
        v16accfloat inv_sum;
        v32bfloat16 inv_sum0, inv_sum1, inv_sum2, inv_sum3;
        v16float tmp;

        // Scalar reciprocal of all 16 sums, two lanes per iteration.
        #pragma nounroll
        for(int i = 0; i < 8; i++)
        chess_no_hw_loop
        {
            v2float in = extract_v2float( v16float(localsum), i );
            v2float out = set_v2float( 0, inv( extract_elem( in, 0 )));
            out = insert( out, 1, inv( extract_elem( in, 1 )));
            tmp = insert( tmp, i, out );
        }

        inv_sum = v16accfloat(tmp);
        // inv_sum3 temporarily holds all 16 reciprocals in its lower half; it is
        // read four times below and overwritten last.
        inv_sum3 = set_v32bfloat16( 0, to_v16bfloat16( inv_sum ));


        inv_sum0 = shuffle( broadcast_to_v32bfloat16( extract_v4bfloat16( inv_sum3, 0 )), T16_8x4 );
        inv_sum1 = shuffle( broadcast_to_v32bfloat16( extract_v4bfloat16( inv_sum3, 1 )), T16_8x4 );
        inv_sum2 = shuffle( broadcast_to_v32bfloat16( extract_v4bfloat16( inv_sum3, 2 )), T16_8x4 );
        inv_sum3 = shuffle( broadcast_to_v32bfloat16( extract_v4bfloat16( inv_sum3, 3 )), T16_8x4 );
#if PRINT
    if((get_coreid() & 0xF)==aierow && (get_coreid() >> 16)==aiecol){
        aie::print(aie::vector<bfloat16, 32>(inv_sum0), true, "inv_sum0");
        aie::print(aie::vector<bfloat16, 32>(inv_sum1), true, "inv_sum1");
        aie::print(aie::vector<bfloat16, 32>(inv_sum2), true, "inv_sum2");
        aie::print(aie::vector<bfloat16, 32>(inv_sum3), true, "inv_sum3");
    }
#endif
        v32bfloat16* restrict pIn2;
        v32bfloat16* restrict pOut2;

        // Read back the exponentials written in stage 1.
        pIn2  = (v32bfloat16*)pIn1_base;
        // Offset the destination by the same four vectors per 16-row block that
        // the input base advances by, so a block with j > 0 cannot overwrite the
        // results of block 0. For j == 0 (the only iteration under the current
        // chess_loop_range(1, 1)) this is exactly `output`.
        pOut2 = (v32bfloat16*)output + 4 * j;
        //////////////////////////////////////////////////////////////////////////
        ////   Remove loop_range(2, ) to support 128 shape (Nsubv==8) of SDXL MHA
        //////////////////////////////////////////////////////////////////////////
        for (int i = 0; i < num_inner_loops; i++)
            chess_no_hw_loop
        {
            pOut2[0] = to_v32bfloat16(mul_elem_32(pIn2[0], inv_sum0));
            pOut2[1] = to_v32bfloat16(mul_elem_32(pIn2[1], inv_sum1));
            pOut2[2] = to_v32bfloat16(mul_elem_32(pIn2[2], inv_sum2));
            pOut2[3] = to_v32bfloat16(mul_elem_32(pIn2[3], inv_sum3));

            pOut2 = byte_incr(pOut2, 256 * num_outer_loops);
            pIn2 = byte_incr(pIn2, 256 * num_outer_loops);
        }

        // Advance to the next block of 16 rows (four v32bfloat16 vectors).
        pIn1_base += 4;
        pOut1_base += 4;
    }
}

#endif // __SOFTMAX_BF16X16_TEMPLATE_H__
