/*  (c) Copyright 2019 - 2024 Xilinx, Inc. All rights reserved.
This file contains confidential and proprietary information
of Xilinx, Inc. and is protected under U.S. and
international copyright and other intellectual property
laws.
DISCLAIMER
This disclaimer is not a license and does not grant any
rights to the materials distributed herewith. Except as
otherwise provided in a valid license issued to you by
Xilinx, and to the maximum extent permitted by applicable
law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
(2) Xilinx shall not be liable (whether in contract or tort,
including negligence, or under any other theory of
liability) for any loss or damage of any kind or nature
related to, arising under or in connection with these
materials, including for any direct, or any indirect,
special, incidental, or consequential loss or damage
(including loss of data, profits, goodwill, or any type of
loss or damage suffered as a result of any action brought
by a third party) even if such damage or loss was
reasonably foreseeable or Xilinx had been advised of the
possibility of the same.
CRITICAL APPLICATIONS
Xilinx products are not designed or intended to be fail-
safe, or for use in any application requiring fail-safe
performance, such as life-support or safety devices or
systems, Class III medical devices, nuclear facilities,
applications related to the deployment of airbags, or any
other applications that could lead to death, personal
injury, or severe property or environmental damage
(individually and collectively, "Critical
Applications"). Customer assumes the sole risk and
liability of any use of Xilinx products in Critical
Applications, subject only to applicable laws and
regulations governing limitations on product liability.
THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
PART OF THIS FILE AT ALL TIMES.                       */

#ifndef __KERNEL_SOFTMAX_FP16X16_IMPL_HPP__
#define __KERNEL_SOFTMAX_FP16X16_IMPL_HPP__

#include "softmax_fp16x16.hpp"


template<typename Ti, typename To>
requires(( std::is_same_v<Ti, float16> || std::is_same_v<Ti, bfloat16> ) && ( std::is_same_v<To, float16> || std::is_same_v<To, bfloat16> || std::is_same_v<To, float8> ))
ALWAYS_INLINE void softmax_fp16x16
(
        Ti * input,
        int * mask,
        To * restrict output,
        const KernelSoftmax_fp16x16Param &params
) {
    /*
     * Softmax kernel body: fp16/bf16 input, fp16/bf16/fp8 output.
     *
     * Each outer-loop iteration handles `params.Co_g` (inner_g) rows in three
     * passes. Every row is processed as granX = 4 independent groups (index
     * x) of 32 lanes. Group x starts at element offset 64 * x within the row.
     * The max and the sum are reduced to one scalar per group, over all 32
     * lanes and all inner_g rows.
     *
     *   Pass 1: scale by inv_ln2 and optionally force masked lanes to -inf.
     *           Store the scaled values back into `input` (as Ti) and track
     *           each group's running maximum.
     *   Pass 2: p = exp2_bf20_hw(scaled + (895.001 - max)). Store p back into
     *           `input` as float16 and accumulate each group's sum.
     *   Pass 3: multiply the stored values by 1/sum of their group. Write the
     *           result to `output`, converted to To.
     *
     * `input` is therefore used as scratch space and is overwritten.
     *
     * @param input   Input tensor; overwritten with intermediate results.
     * @param mask    Mask words. pM[2*x] is the lane-select operand of `sel`
     *                in pass 1. It is advanced with dimsM_il / dimsM_ol.
     * @param output  Destination tensor, advanced by step_Co and dimsO_ol.
     * @param params  Trip counts (Co_g, outer_g), 2D address increments for
     *                input/output/mask, and byte steps (step_Ci, step_Co) for
     *                pass 3.
     *
     * NOTE(review): pass 3 advances two rows per iteration. It is annotated
     * chess_loop_range(2,), so it assumes inner_g >= 3. An odd inner_g would
     * presumably process one row past the group. Confirm that the parameter
     * generator only emits even inner_g >= 4.
     */

    // Number of 32-lane column groups handled per row (unrolled below).
    constexpr unsigned granX = 4;

    Ti * pIn1 = input;
    To * pOut = output;
    int * pM = mask;

    int inner_g = params.Co_g;     // rows per pass (inner-loop trip count)
    int outer_g = params.outer_g;  // outer loop runs ceil(outer_g / inner_g) times

    auto dimsI_il =    params.dimsI_il.instantiate() ;
    auto dimsI_ol =    params.dimsI_ol.instantiate() ;
    auto dimsO_il =    params.dimsO_il.instantiate() ;
    auto dimsO_ol =    params.dimsO_ol.instantiate() ;
    auto dimsM_il =    params.dimsM_il.instantiate() ;
    auto dimsM_ol =    params.dimsM_ol.instantiate() ;

    for ( int j=0; j<outer_g; j+=inner_g )
        // Pipelining of the outer loop is kept disabled (commented out).
        //chess_prepare_for_pipelining
        chess_loop_range( 1, )
    {
        // Every pass starts at the current position in `input`. Passes 1 and 2
        // rewrite it in place; pass 3 reads what pass 2 stored.
        // NOTE(review): pOut1/pOut2 are declared restrict while aliasing
        // pIn1/pIn2. This relies on each element being read before it is
        // overwritten within the same iteration.
        auto * restrict pOut1 = chess_copy( pIn1 );
        auto * restrict pOut2 = ( v32float16 * ) chess_copy( pIn1 );
        auto * pIn2  = chess_copy( pIn1 );
        auto * pIn3  = ( v64float16 * ) chess_copy( pIn1 );

        Ti inv_ln2 = Ti( 1.4423828125f );         // log2(e) rounded to fp16 precision
        Ti one = Ti( 1.0f );
        float neg_inf = as_float( 0xFF800000 );   // IEEE-754 binary32 bit pattern of -inf

        // Running maximum per column group, seeded with -inf.
        float mx[granX] = { [ 0 ... ( granX - 1 )] = neg_inf };

        // Pass 1: scale to the log2 domain, apply the mask, store in place
        // and reduce the per-group maximum.
        for ( int i=0; i<inner_g; i++ )
            chess_prepare_for_pipelining
            chess_loop_range( 2, )
        {
            #pragma unroll
            for ( int x=granX-1; x>=0; x-- ) {
                pIn1  = chess_copy( pIn1 );
                pOut1 = chess_copy( pOut1 );
                pM    = chess_copy( pM );

                // use following line if Mask is supported by dataflow :
                // Each lane is chosen between -inf and the scaled value by the
                // bits of pM[2*x]. Presumably a set bit keeps the scaled value;
                // confirm against the `sel` intrinsic semantics.
                v32float x2 = sel( broadcast_to_v32float( neg_inf ), ( v32float ) locate_in_register<0>( mul_elem_32( locate_in_register<7>( read_v<32>( pIn1 + 64 * x )), inv_ln2 ) ), pM[2*x] );

                // use following line if Mask is NOT supported by dataflow :
                //v32float x2 = ( v32float ) locate_in_register<0>( mul_elem_32( locate_in_register<7>( read_v<32>( pIn1 + 64 * x )), inv_ln2 ) );

                write_v( pOut1 + 64 * x, aie::accum<accfloat, 32>(( v32accfloat ) x2 ).to_vector<Ti>( ));
                mx[x] = max_reduce( x2, mx[x] );
            }
            pIn1  = add_2d_byte( pIn1, dimsI_il );
            pOut1 = add_2d_byte( pOut1, dimsO_il );
            pM    = add_2d_byte( pM, dimsM_il );
        }

        // Turn each maximum into the additive offset used in pass 2.
        // NOTE(review): the 895.001 bias presumably matches the input encoding
        // expected by exp2_bf20_hw. Its derivation is not visible here.
        #pragma unroll
        for ( int x=granX-1; x>=0; x-- ) {
            mx[x] = sub( 895.001, mx[x] );
        }

        // Running sum of exponentials per column group.
        float sm[granX] = { [ 0 ... ( granX - 1 )] = 0 };

        // Pass 2: exponentiate (value * 1 + offset, fed to exp2_bf20_hw),
        // store as float16 in place, and reduce the per-group sum.
        for ( int i=0; i<inner_g; i++ )
            chess_prepare_for_pipelining
            chess_loop_range( 2, )
        {
            #pragma unroll
            for ( int x=granX-1; x>=0; x-- ) {
                pIn2  = chess_copy( pIn2 );
                pOut2 = chess_copy( pOut2 );
                aie::vector in0 = read_v<32>( pIn2 + 64 * x );
                v32accfloat chess_storage( y0 ) in2 = mac_elem_32( in0, aie::broadcast<Ti, 32>( one ), mx[x] );
                v32accfloat p2 =  exp2_bf20_hw( in2 );
                // pOut2 is a v32float16 pointer, so index 2*x is element offset 64*x.
                pOut2[2*x] = to_v32float16( p2 );

                sm[x] = add_reduce(( v32float ) p2, sm[x] );
            }
            pIn2  = add_2d_byte( pIn2,  dimsI_il );
            pOut2 = add_2d_byte( pOut2, dimsO_il );
        }

        // Reciprocal of each group's sum, broadcast to 64 float16 lanes so one
        // multiply covers a full 64-element slot in pass 3.
        v64float16 inv_sm[granX];
        #pragma unroll
        for ( unsigned x=0; x<granX; x++ ) {
            v32float16 tmp = to_v32float16(( v32accfloat ) broadcast_to_v32float( inv( sm[x] )));
            inv_sm[x] = concat( tmp, tmp );
        }

        // Pass 3: normalize and write the output, two rows per iteration.
        // Each pIn3[x] load reads 64 float16 values starting at element 64*x.
        for ( int i=0; i<inner_g; i+=2 )
            chess_prepare_for_pipelining
            chess_loop_range( 2, )
        {
            // Fully unrolled over the granX groups, highest group first.
            aie::unroll_times<granX>( [&]( auto xc ) __attribute__(( always_inline ))
            {
                constexpr unsigned x = granX - 1 - xc;
                pIn3 = chess_copy( pIn3 );
                pOut = chess_copy( pOut );
                if constexpr( std::is_same_v<To, bfloat16> ) {
                    v64bfloat16 out = mul_elem_64b( locate_in_register<3>( pIn3[x] ), inv_sm[x] );
                    write_v( pOut + 64 * x, aie::vector( out ));
                } else {
                    v64float16 out = mul_elem_64( locate_in_register<3>( pIn3[x] ), inv_sm[x] );
                    if constexpr( std::is_same_v<To, float8> ) {
                        write_v( pOut + 64 * x, aie::vector( to_v64float8( out )));
                    } else {
                        write_v( pOut + 64 * x, aie::vector( out ));
                    }
                }
            });
            pIn3 = byte_incr( pIn3, params.step_Ci );
            pOut = byte_incr( pOut, params.step_Co );
        }
        // Advance to the next outer block. pIn1 has already moved inner_g rows
        // during pass 1; pOut has moved during pass 3.
        pM   = add_2d_byte( pM,   dimsM_ol );
        pIn1 = add_2d_byte( pIn1, dimsI_ol );
        pOut = add_2d_byte( pOut, dimsO_ol );
    }
}


#endif // __KERNEL_SOFTMAX_FP16X16_IMPL_HPP__
    