/*  (c) Copyright 2019 - 2022 Xilinx, Inc. All rights reserved.
    
    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.
    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.
    
    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.
    
    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */

#ifndef __KERNEL_SILU_EXP2_IMPL_H__
#define __KERNEL_SILU_EXP2_IMPL_H__

#include "aie_api/utils.hpp"
#include "aie_api/aie.hpp"
//#include "kernel_helpers.h"
//#include "ml_params.h"

#include "common.hh"
#include "SiLU_exp2.hpp"

using namespace aie;

/**
 * @brief Vectorised SiLU kernel body built around a hardware exp2 evaluation.
 *
 * Each loop iteration consumes 32 input elements and produces 32 output
 * elements; the loop runs params.outer_g times, so 32 * params.outer_g
 * elements are read from @p input and written to @p output.
 *
 * Per 32-lane vector x (converted to float) the visible data flow is:
 *   xa  = bor(x, -0.0f)              (presumably a bitwise OR, i.e. -|x|)
 *   xm  = max(x, -0.0f)              (relu(x))
 *   p   = polynomial in xa of degree poly_order, coefficients k0s..k3s
 *   out = xa * exp2_bf20_hw(p) + xm  (assuming mac_elem_32(a, b, acc) is a*b + acc)
 *
 * NOTE(review): this matches the identity
 *   silu(x) = relu(x) - |x| * sigmoid(-|x|)
 * with sigmoid(-|x|) approximated as 2^p(-|x|). That reading assumes
 * exp2_bf20_hw evaluates 2^v and that the coefficients supplied in params
 * encode such an approximation - confirm against the code that fills
 * KernelSiLUExp2Param.
 *
 * xa and xm are written to spill_buf_a / spill_buf_b and read back later in
 * the same iteration through separate pointers that cycle over fifo_size (8)
 * v32float slots.
 *
 * @tparam poly_order  Degree of the exponent polynomial. 3 and 2 select the
 *                     dedicated branches; any other value takes the
 *                     first-order branch (k1s * xa + k0s).
 * @tparam loop_range  Forwarded as the first argument of chess_loop_range;
 *                     presumably the guaranteed minimum trip count, in which
 *                     case callers must ensure params.outer_g >= loop_range
 *                     - TODO confirm.
 * @tparam Ti          Input element type.
 * @tparam To          Output element type.
 *
 * @param input        Source buffer; 32 elements of Ti are loaded per iteration.
 * @param output       Destination buffer; 32 elements of To are stored per iteration.
 * @param spill_buf_a  Scratch buffer accessed as v32float slots (holds xa).
 *                     Presumably must provide at least fifo_size (8) slots,
 *                     i.e. 8 * 32 floats, and be suitably aligned for v32float
 *                     access - TODO confirm against add_2d_ptr semantics.
 * @param spill_buf_b  Scratch buffer accessed as v32float slots (holds relu(x));
 *                     same size assumption as spill_buf_a.
 * @param params       Supplies the coefficients k0s, k1s, k2s, k3s and the
 *                     iteration count outer_g.
 */
template<unsigned poly_order=2, unsigned loop_range=12, typename Ti = float16, typename To = float16>
ALWAYS_INLINE void SiLU_exp2
(
    Ti * restrict input,
    To * restrict output,
    float * restrict spill_buf_a,
    float * restrict spill_buf_b,
    const KernelSiLUExp2Param &params
)
{
    // Working copies of the I/O pointers, tagged with the __aie_dm_resource_a/b
    // qualifiers (input on resource a, output on resource b).
    To __aie_dm_resource_b * restrict pOut = (To __aie_dm_resource_b *) output;
    Ti __aie_dm_resource_a * restrict pIn32 = ( Ti __aie_dm_resource_a *) input;

    // Scalar polynomial coefficients; overwritten from params just below.
    float k0s = 0;
    float k1s = 0;
    float k2s = 0;
    float k3s = 0;

    // Earlier hard-coded coefficient sets, kept for reference only. The live
    // values now come from params.
    /*k0s = -1 + 895.001;
    if ( poly_order == 3 ) {
        k1s = 0.9273;
        k2s = 0;
        k3s = -0.0227;
    } else if ( poly_order == 2 ) {
        k1s = 0.8083;
        k2s = -0.1084;
    } else {
        k1s = -1.0848;
    }*/
    k0s = params.k0s;
    k1s = params.k1s;
    k2s = params.k2s;
    k3s = params.k3s;

    // Vector copies of the coefficients. k1 and k3 are used as multiplicands
    // in the loop and are pinned with chess_storage.
    // NOTE(review): k0 and k2 are not referenced again in this function - the
    // mac_elem_32 calls below take the scalars k0s / k2s as their addend.
    v32accfloat k0 = broadcast_to_v32accfloat( k0s );
    v32float chess_storage(cmal0) k1 = broadcast_to_v32float( k1s );
    v32accfloat k2 = broadcast_to_v32accfloat( k2s );
    v32float chess_storage(cmbl0) k3 = broadcast_to_v32float( k3s );

    // Presumably a trace/profiling marker for the start of the kernel; paired
    // with event1() at the end.
    event0();

    // The integer literal evaluates to the bit pattern 0x80000000 (sign bit
    // only), which as an IEEE-754 single is -0.0f - assuming as_float
    // reinterprets the bits rather than converting the value.
    float minus_zero = as_float( -0x80000000 );

    // NOTE(review): x_fix is declared but never used in this function.
    v32int32 chess_storage(cmbl0) x_fix;
    v32float x;

    v32float exp2poly3;
    v32float xa, xb;

    // Spill FIFOs: both scratch buffers are viewed as arrays of v32float.
    // Each buffer gets one write pointer (pFw*) and one read pointer (pFr*),
    // all starting at slot 0.
    // NOTE(review): pFra/pFwa (and pFrb/pFwb) are restrict-qualified yet point
    // at the same memory. This appears to be a deliberate round trip through
    // data memory for the scheduler's benefit - confirm before changing.
    const int fifo_size = 8;
    v32float __aie_dm_resource_a * restrict fifo_a = (v32float __aie_dm_resource_a * restrict) spill_buf_a;
    v32float __aie_dm_resource_b * restrict fifo_b = (v32float __aie_dm_resource_b * restrict) spill_buf_b;
    v32float __aie_dm_resource_a * restrict pFra   = (v32float __aie_dm_resource_a * restrict) fifo_a;
    v32float __aie_dm_resource_a * restrict pFwa   = (v32float __aie_dm_resource_a * restrict) fifo_a;
    v32float __aie_dm_resource_b * restrict pFrb   = (v32float __aie_dm_resource_b * restrict) fifo_b;
    v32float __aie_dm_resource_b * restrict pFwb   = (v32float __aie_dm_resource_b * restrict) fifo_b;
    
    // Address generators for the FIFO pointers, constructed with
    // (fifo_size - 1, 1, 1 - fifo_size) = (7, 1, -7). Presumably: step by +1
    // seven times, then by -7 to wrap back to slot 0 - TODO confirm against
    // the dims_2d_t / add_2d_ptr definitions.
    const int fifo_words = 1;
    dims_2d_t dFra( fifo_size / fifo_words - 1, 1, 1 - fifo_size );
    dims_2d_t dFwa = dFra;
    dims_2d_t dFrb( fifo_size / fifo_words - 1, 1, 1 - fifo_size );
    dims_2d_t dFwb = dFrb;
    
    v32accfloat xm_acc;
    v32accfloat p_0123;
    v32accfloat z;
    
    v32accfloat x2;
    v32accfloat p_23;
    v32accfloat p_01;

    // Main loop: one 32-lane vector per iteration, annotated for software
    // pipelining. Statement order and the chess_* annotations below are part
    // of the schedule; reorder with care.
    for ( int n=0; n < params.outer_g;  n++ )
    chess_prepare_for_pipelining
    chess_loop_range( loop_range, )        
    { 
        // Load 32 input elements and convert them to float via an accfloat
        // accumulator.
        vector<Ti,32> ld_v = load_v<32>(pIn32);
        x = accum<accfloat, 32>(ld_v).to_vector<float>();
        
        pIn32 = byte_incr(pIn32, 32*sizeof(Ti));

        // Compute xa = bor(x, -0.0f) and relu(x) = max(x, -0.0f), and push both
        // into the spill FIFOs for the final multiply-accumulate.
        // Assuming bor is a bitwise OR, xa has its sign bit forced on, i.e.
        // xa = -|x| (not -x).
        xa = locate_in_register<1>( bor( x, minus_zero ));
        *pFwa = xa;
        pFwa = add_2d_ptr( pFwa, dFwa );
        *pFwb = locate_in_register<7>( max( x, minus_zero ));
        pFwb = add_2d_ptr( pFwb, dFwb );

        // Partial polynomial terms (assuming mac_elem_32(a, b, c) = a*b + c):
        //   x2   = x * x            (equals xa * xa if xa = -|x|)
        //   p_23 = k3 * xa + k2s    (third-order variant only)
        //   p_01 = k1 * xa + k0s
        x2     = mul_elem_32( x, x);
        if constexpr ( poly_order == 3 )
            p_23   = mac_elem_32( k3, xa, k2s);
        p_01   = mac_elem_32( k1, xa, k0s);

        // Scheduler separator, then read xa and relu(x) back from the FIFOs.
        // The read pointers start at the same slot as the write pointers and
        // are advanced with identical dims, so they appear to fetch the values
        // stored earlier in this same iteration.
        chess_separator_scheduler_local();
        pFra = chess_copy( pFra );
        v32float chess_storage(cmal1) xb_cs = *pFra;
        xb = xb_cs;
        pFra = add_2d_ptr( pFra, dFra );
        pFrb = chess_copy( pFrb );
        xm_acc = v32accfloat( *pFrb );
        pFrb = add_2d_ptr( pFrb, dFrb );

        // Combine the partial terms into the full exponent polynomial:
        //   order 3: p_23 * x2 + p_01
        //   order 2: x2 * k2s + p_01
        //   other  : p_01
        // NOTE(review): this selection uses a plain `if` on the template
        // constant, whereas the p_23 computation above uses `if constexpr`;
        // p_23 is only assigned when poly_order == 3.
        if (poly_order == 3 ) {
            p_0123 = locate_in_register<3>( mac_elem_32( v32float(p_23), v32float(x2), locate_in_register<3>( p_01 )));
        } else if (poly_order == 2 ) {
            p_0123 = locate_in_register<3>( mac_elem_32( v32float(x2), k2s, locate_in_register<3>( p_01 )));
        } else {
            p_0123 = p_01;
        }

        /* 2^( ( (k3*xa + k2)*xa + k1 )*xa + k0 ) */
        // The code evaluates the same polynomial in the split form
        // (k3*xa + k2)*x2 + (k1*xa + k0) rather than the nested form shown
        // above. exp2_bf20_hw presumably raises 2 to that value - confirm.
        exp2poly3 = v32float(exp2_bf20_hw(p_0123));
        // z = xb * exp2poly3 + xm_acc. For poly_order > 1 the operands are
        // pinned to registers: the addend uses index 4 when Ti is not an
        // integral type and 5 when it is; the result uses index 5.
        if ( poly_order > 1 )
            z = locate_in_register<5>( mac_elem_32( xb, exp2poly3, locate_in_register<4+std::is_integral_v<Ti> >( xm_acc )));
        else
            z = mac_elem_32( xb, exp2poly3, xm_acc );

        // Convert the float accumulator to the output element type To (chosen
        // by the template parameter) and store 32 elements.
        aie::vector<To,32> out_v = aie::accum<accfloat,32>(z).to_vector<To>();
        store_v(pOut, out_v);
        pOut = byte_incr(pOut, 32*sizeof(To));
    }

    // Presumably the matching end-of-kernel trace/profiling marker.
    event1( );
}

#endif // __KERNEL_SILU_EXP2_IMPL_H__