/*
    Copyright (C) 2019 - 2022 Xilinx, Inc. All rights reserved.
    Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.

    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.

    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.

    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.

    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */

#ifndef __BIASED_CONV_INT8X8_TEMPLATE_H__
#define __BIASED_CONV_INT8X8_TEMPLATE_H__
#include "arch_kernel_helpers.h"
#include "super_kernel_types.h"

// The kernel below takes its runtime parameters through this alias of the
// shared super-kernel parameter struct.
using conv3d_params_t = super_kernel_params_t;

#include <optional>

// Default upper-clamp constants for the activation stage. Each one can be
// overridden by defining the macro before this header is included.
//
// ReLU_uMAX: bound used when the result is produced as unsigned (sign_res == 0
// in the kernel). The kernel only enables the clamp for this case when
// int8( ReLU_uMAX ) differs from -1, so the default leaves it disabled unless
// has_relu6 is set.
#ifndef ReLU_uMAX
#define ReLU_uMAX -1
#endif
// ReLU_sMAX: bound used when the result is produced as signed.
#ifndef ReLU_sMAX
#define ReLU_sMAX 127
#endif
// ReLU6_MAX: bound used when params.act_type_1 == ReluType::Relu6.
// NOTE(review): the value 96 presumably encodes "6" in the output fixed-point
// format — confirm against the quantisation scheme.
#ifndef ReLU6_MAX
#define ReLU6_MAX 96
#endif


/*
 * int8 x int8 convolution kernel body with optional bias injection, optional
 * partial-sum (TDM) input/output and an optional ReLU-style output stage.
 *
 * Each outer-loop iteration accumulates params.inner_loop matrix products
 * (aie::mmul<8, 8, 8, int8, int8>) into two 64-lane acc32 accumulators (D0/D1,
 * one per weight vector) and then either writes int8 results to `output` or
 * writes the raw accumulators to the two TDM buffers.
 *
 * Template parameters
 *   kc_in          Source of the initial accumulator value:
 *                  KC_TDM16 / KC_TDM32 load partial sums from tdm1_in/tdm2_in,
 *                  KC_ZERO starts from the bias product (or from a plain mul).
 *   kc_out         Destination: KC_RESULT8 stores int8 to `output`;
 *                  KC_TDM16 / KC_TDM32 store partial sums to the TDM buffers.
 *   ol_lr, il_lr   Values fed to the chess_loop_range hints of the outer and
 *                  inner loops. il_lr must be >= 8 (static_assert below).
 *                  NOTE(review): presumably guaranteed lower bounds on
 *                  params.outer_loop / params.inner_loop — confirm with callers.
 *   gate_bias      Compile-time enable for the bias path.
 *   has_stride,
 *   reduced_stride Select the input fetch variant and the output addressing.
 *   has_lrelu      Enables the leaky-ReLU branch in the int8 output stage.
 *   has_relu6      Forces the upper clamp (aie::min) in the int8 output stage.
 *
 * Function parameters
 *   input, weights, bias   Base pointers of the activation, weight and bias
 *                          buffers.
 *   tdm1_in, tdm2_in       Partial-sum buffers; used both as read source
 *                          (kc_in) and as write destination (kc_out).
 *   output                 int8 result buffer (used when kc_out == KC_RESULT8).
 *   params                 Runtime loop counts, address increments and shifts.
 *   zero_init              Passed to aie::op_zero on the loaded partial sums.
 *   bias_usage             Runtime enable for the bias path (multiplied into
 *                          norm_ch_g together with gate_bias).
 *   relu_params_opt        Not read by the current body; alpha is taken from
 *                          params.leaky_alpha instead.
 */
template<enum KernelConfig kc_in, enum KernelConfig kc_out, unsigned ol_lr=8, unsigned il_lr=8, bool gate_bias=1, bool has_stride=1, bool reduced_stride=1, bool has_lrelu=0, bool has_relu6=0>
ALWAYS_INLINE
void biased_conv_int8x8_template
(
        int8_t* input,
        int8_t * weights,
        int8_t * bias,
        int * restrict tdm1_in,
        int * restrict tdm2_in,
        int8_t *__restrict output,
        const conv3d_params_t &params,
        const bool zero_init = 0,
        const bool bias_usage = 1,
        //const leakyrelu_kernel_params_t &relu_params,
        const std::optional<leakyrelu_kernel_params_t> relu_params_opt = std::nullopt
) {

    // apply_sign wraps the activation operand before it is fed to the MAC.
    // With USE_RTP_SIGN the signedness comes from params.ifm_sign at run time;
    // otherwise the operand is passed through unchanged.
#ifdef USE_RTP_SIGN
    auto apply_sign = [sign = params.ifm_sign]( auto &&v ) __attribute__(( always_inline )) {
        return aie::op_sign( v, sign );
    };
#else
    auto apply_sign = []( auto &&v ) __attribute__(( always_inline )) { return v; };
#endif

    // bias_mul: uint16 x int16 product used to scale the bias into the accumulator.
    // mac888:   the int8 x int8 matrix multiply-accumulate of the main loop.
    using bias_mul = aie::mmul<8, 2, 8, uint16, int16>;
    using mac888 = aie::mmul<8, 8, 8, int8, int8>;

    // When ol_lr exceeds this threshold the outer loop is folded: the first
    // accumulation is peeled in front of the loop and each store is deferred
    // by one iteration (see the main flow at the end of the function).
    constexpr unsigned ol_fold_threshold = ol_lr > 3 ? 3 : 2;
    // The upper clamp is needed for ReLU6, or when ReLU_uMAX has been overridden
    // to a value whose int8 representation is not -1.
    constexpr bool requires_relu_min = has_relu6 || int8( ReLU_uMAX ) != -1;
    static_assert( il_lr >= 8, "Kernel written with the assumption of il_lr >= 8" );

    // norm_ch_g is 0 whenever the bias is disabled (at run time or compile time).
    int norm_ch_g = bias_usage*params.tile_ocg*2*gate_bias;
    // 1 when int8 results are produced and the incoming partial sums are zeroed.
    // It offsets the partial-sum read pointers below and makes incrT0 equal -1.
    int no_tdm_buffer = (( kc_out == KC_RESULT8 ) && zero_init );

    v64int8 chess_storage( DM_bankAC ) * restrict pIn = ( v64int8 chess_storage( DM_bankAC ) * ) input;
    int8 chess_storage( DM_bankB ) * pW  = ( int8 chess_storage( DM_bankB ) * ) weights;

    v16int16 chess_storage( DM_bankB ) * pBN0 = ( v16int16 chess_storage( DM_bankB ) * ) bias;
    // NOTE(review): pBN1 is computed here but not referenced anywhere else in this function.
    v16int16 chess_storage( DM_bankB ) * pBN1 = pBN0 + (( norm_ch_g )>1 )*2;
    constexpr aie_dm_resource output_resource = ((( kc_in == KC_ZERO ) && ( !has_stride || reduced_stride )) ? aie_dm_resource::ac : aie_dm_resource::none );
    int8 * restrict pOut8  = ( int8 * ) output;

    // Index 0 refers to tdm1_in, index 1 to tdm2_in.
    constexpr aie_dm_resource tdm_resource[2] = { aie_dm_resource::a, aie_dm_resource::c };
    // Partial-sum read pointers; they start 64 elements in when no_tdm_buffer is set.
    int32 * pTdm32[2] = { ( int32* ) tdm1_in + 16 * 4 * no_tdm_buffer,
                          ( int32* ) tdm2_in + 16 * 4 * no_tdm_buffer };
    int16 * pTdm16[2] = { ( int16 * )tdm1_in + 32 * 2 * no_tdm_buffer,
                          ( int16 * )tdm2_in + 32 * 2 * no_tdm_buffer };
    // Partial-sum write pointers (32-bit and 16-bit views of the same buffers).
    int32 * restrict pOut32[2] = { ( int32 * )tdm1_in, ( int32 * )tdm2_in };

    int16 * restrict pOut16a = ( int16 * )chess_copy( tdm1_in );
    int16 * restrict pOut16b = ( int16 * )chess_copy( tdm2_in );
    int16 * restrict pOut16[2] = { pOut16a, pOut16b };
    //int16 * restrict pOut16[2] = { (int16 *)chess_copy( tdm1_in ), (int16 *)chess_copy( tdm2_in ) };

    // accs: accumulators captured for the store stage; psum: loaded partial sums.
    aie::accum<acc32, 64> accs[2];
    aie::accum<acc32, 64> psum[2];

    // State of the input FIFO used by the fifo_ld_* helpers.
    fifo_state_t fA;
    fA.pos = 0;

    // Address counters handed to the multi-dimensional addressing helpers.
    // NOTE(review): cntAO1, cntAO2, cntS1 and cntS2 are declared but not
    // referenced further in this function (dimsAO / dimsO are used instead).
    addr_t cntAL1 = 0;
    addr_t cntAL2 = 0;
    addr_t chess_storage( dc2 ) cntAO1 = 0;
    addr_t chess_storage( dc6 ) cntAO2 = 0;
    addr_t chess_storage( dc0 ) cntS1  = 0;
    addr_t chess_storage( dc4 ) cntS2  = 0;
    addr_t cntB   = 0;
    addr_t cntBN  = 0;

    unsigned il_count = params.inner_loop;
    unsigned ol_count = params.outer_loop;
    // dimsO drives the output pointer, dimsAO the per-outer-iteration input pointer.
    dims_3d_t dimsO;
    dims_3d_t dimsAO;
    if ( has_stride && !reduced_stride ) {
        // Full-stride variant: the outer trip count and the second input
        // dimension are right-shifted by ( str_w - 1 ), and the input
        // increments grow by 64 bytes per extra stride step.
        ol_count = ol_count >> ( params.str_w - 1 );
        dimsAO = dims_3d_t( params.numAO1, params.incAO1, (( params.numAO2 + 1 ) >> ( params.str_w - 1 )) - 1, params.incAO2 + 64 * ( params.str_w - 1 ), params.incAO3 + 64 * ( params.str_w - 1 ));
    } else {
        dimsAO = dims_3d_t( params.numAO1, params.incAO1, params.numAO2, params.incAO2, params.incAO3 );
    }
    if ( has_stride && reduced_stride ) {
        dimsO  = dims_3d_t( params.numCS1, params.incCS1, params.numCS2, params.incCS2, params.incCS3 );
    } else {
        // NOTE(review): this branch is also taken when has_stride is false and
        // still shifts by ( str_w - 1 ); presumably str_w == 1 there — confirm.
        dimsO  = dims_3d_t( params.numCS1, params.incCS1 + params.incS0, (( params.numCS2 + 1 ) >> ( params.str_w - 1 )) - 1, params.incCS2 + 32, params.incCS3 + 32 );
    }

    // Hsigmoid and Hswish are not supported
    //const leakyrelu_kernel_params_t &relu_params = relu_params_opt.has_value( ) ? relu_params_opt.value( ) : leakyrelu_kernel_params_t( );
    // Results are converted as signed for NoRelu and Leaky_Prelu, unsigned otherwise.
    const bool sign_res = params.act_type_1 == ReluType::NoRelu || params.act_type_1 == ReluType::Leaky_Prelu;
    //uint16 alpha = has_lrelu && params.act_type_1 == ReluType::Leaky_Prelu ? relu_params.alpha : 0;
    // Leaky slope; 0 unless both the template flag and the runtime activation type select it.
    uint16 alpha = has_lrelu && params.act_type_1 == ReluType::Leaky_Prelu ? params.leaky_alpha : 0;

    // Upper clamp value broadcast to all lanes: ReLU6_MAX for Relu6, otherwise
    // ReLU_uMAX for unsigned results and ReLU_sMAX for signed results.
    v64int8 min_res = aie::broadcast<int8, 64>(( params.act_type_1 == ReluType::Relu6 ) ? ReLU6_MAX : ( sign_res == 0 ) ? ReLU_uMAX : ReLU_sMAX ); //needs to be properly tested
    aie::vector<int8, 64> min_res_v = min_res;

    // Multiplier applied to the bias in bias_mul. It is selected between two
    // candidates based on whether shift_bias_1 is zero; both candidates are 0
    // when norm_ch_g is 0 (bias disabled).
    // NOTE(review): the first candidate evaluates `<< ( shift_bias_1 - 1 )` even
    // when shift_bias_1 == 0; presumably harmless because the select discards
    // it in that case — confirm the scalar shift behaviour on the target.
    aie::vector<uint16, 16> bias_shift_v = aie::select( aie::broadcast<uint16, 32>((( norm_ch_g ) > 0 ) << ( params.shift_bias_1 - 1 )),
                                                     aie::broadcast<uint32, 16>(( norm_ch_g ) > 0 ).cast_to<uint16>( ),
                                                     aie::mask<32>( params.shift_bias_1 == 0 )).extract<16>( 0 );
    // Pointer-step selector for read_psum / store_psum:
    //   1  -> the chunks of one accumulator are visited at consecutive addresses;
    //   0  -> (has_stride && reduced_stride && str_w != 1) the second half reuses
    //         the address of the first half;
    //  -1  -> (no_tdm_buffer) the net pointer advance per call is zero.
    int incrT0 = ( has_stride == 0 ) | ( reduced_stride == 0 ) | ( params.str_w == 1 ) | -no_tdm_buffer;

    // xbuff0: current activation vector; ybuff0 / ybuff1: the two weight vectors.
    aie::vector<int8, 64> xbuff0;
    aie::vector<int8, 64> ybuff0;
    aie::vector<int8, 64> ybuff1;

    // One accumulator per weight vector.
    mac888 D0, D1;

    // Loads the next activation vector into xbuff0 and the next two 64-byte
    // weight vectors into ybuff0 / ybuff1 (pW advances by 128 bytes per call).
    auto fetch_data = [&]( ) __attribute__(( always_inline )) {
        if constexpr( reduced_stride || !has_stride ) {
            // Single FIFO load; with stride the vector is unzipped against itself
            // using params.shfl and the first half of the result is kept.
            aie::vector<int8, 64> __aie_register( x0 ) sbuff_s = fifo_ld_popx_3d_byte( pIn, fA, params.step_align, 63, params.incAL3-32, params.numAL1, cntAL1, params.incAL1-32, params.numAL2, cntAL2, params.incAL2-32 );
            if constexpr( has_stride )
                xbuff0 = aie::interleave_unzip( sbuff_s, sbuff_s, params.shfl ).first;
            else
                xbuff0 = sbuff_s;
        } else {
            // Full stride: two FIFO loads are unzipped together.
            aie::vector<int8, 64> __aie_register( x0 ) sbuff_s = fifo_ld_popx( pIn, fA, params.step_align, 63 );
            aie::vector<int8, 64> __aie_register( x1 ) sbuff_s_2 = fifo_ld_pop_3d_byte( pIn, fA, params.incAL3-96, params.numAL1, cntAL1, params.incAL1-96, params.numAL2, cntAL2, params.incAL2-96 );
            xbuff0 = aie::interleave_unzip( sbuff_s, sbuff_s_2, params.shfl ).first;
        }
        if constexpr( has_lrelu && has_stride ) {
            xbuff0 = locate_in_register_t<8>( xbuff0 );
        }

        ybuff0 = locate_in_register_t<2>( aie::load_v<64>( pW )); pW += 64;
        ybuff1 = locate_in_register_t<3>( aie::load_v<64>( pW )); pW += 64;
    };

    // Loads the partial sum for accumulator i from TDM buffer i into psum[i].
    // Only KC_TDM16 and KC_TDM32 perform loads; for any other kc_in psum[i] is
    // left default-constructed.
    auto read_psum = [&]<unsigned i>( ) __attribute__(( always_inline )) {
        psum[i] = aie::accum<acc32, 64>( );
        if constexpr ( kc_in == KC_TDM16 ) {
            // Two 32-lane int16 chunks, up-shifted by params.shift_psum_in.
            aie::accum<acc32, 32> tmp;
            tmp.from_vector_sign( aie::load_v<32, tdm_resource[i]>( pTdm16[i] ), true, params.shift_psum_in );
            psum[i].insert( 0, tmp ); pTdm16[i] += 32 * incrT0;
            tmp.from_vector_sign( aie::load_v<32, tdm_resource[i]>( pTdm16[i] ), true, params.shift_psum_in );
            psum[i].insert( 1, tmp ); pTdm16[i] += 32;
        } else if constexpr ( kc_in == KC_TDM32 ) {
            // Four 16-lane int32 chunks; the step after chunk 1 is +16, -16 or
            // -48 elements for incrT0 = 1, 0 or -1 respectively.
            aie::accum<acc32, 16> tmp;
            tmp.from_vector( aie::load_v<16, tdm_resource[i]>( pTdm32[i] ));  pTdm32[i] += 16;
            psum[i].insert( 0, tmp );
            tmp.from_vector( aie::load_v<16, tdm_resource[i]>( pTdm32[i] ));  pTdm32[i] += 16 * ( 2*incrT0 - 1 );
            psum[i].insert( 1, tmp );
            tmp.from_vector( aie::load_v<16, tdm_resource[i]>( pTdm32[i] ));  pTdm32[i] += 16;
            psum[i].insert( 2, tmp );
            tmp.from_vector( aie::load_v<16, tdm_resource[i]>( pTdm32[i] ));  pTdm32[i] += 16;
            psum[i].insert( 3, tmp );
        }
    };

    // Writes accs[i] to TDM buffer i as a partial sum. Does nothing unless
    // kc_out is KC_TDM16 or KC_TDM32. The pointer stepping mirrors read_psum.
    auto store_psum = [&]<unsigned i>( ) __attribute__(( always_inline )) {
        //accs[i] = locate_in_register<3*i>( accs[i] );
        if constexpr ( kc_out == KC_TDM16 ) {
            //aie::vector<int16, 32> __aie_register(x11) x;
            aie::vector<int16, 32> x;
            //accs[i].insert( 0, accs[i].extract<32>(0));
            //accs[i] = locate_in_register<2+i>( accs[i] );
            // Down-shift by params.shift_psum_out and store as two int16 halves.
            // With incrT0 == 0 the second store targets the same address as the first.
            aie::store_v<tdm_resource[i]>( pOut16[i], x = accs[i].extract<32>( 0 ).to_vector_sign<int16>( true, params.shift_psum_out )); pOut16[i] += 32 * incrT0;
            aie::store_v<tdm_resource[i]>( pOut16[i], x = accs[i].extract<32>( 1 ).to_vector_sign<int16>( true, params.shift_psum_out )); pOut16[i] += 32;


        } else if constexpr ( kc_out == KC_TDM32 ) {
            // Store the accumulator unshifted as four int32 quarters.
            aie::store_v<tdm_resource[i]>( pOut32[i], accs[i].extract<16>( 0 ).to_vector<int32>( ));    pOut32[i] += 16;
            aie::store_v<tdm_resource[i]>( pOut32[i], accs[i].extract<16>( 1 ).to_vector<int32>( ));    pOut32[i] += 16 * ( 2*incrT0 - 1 );
            aie::store_v<tdm_resource[i]>( pOut32[i], accs[i].extract<16>( 2 ).to_vector<int32>( ));    pOut32[i] += 16;
            aie::store_v<tdm_resource[i]>( pOut32[i], accs[i].extract<16>( 3 ).to_vector<int32>( ));    pOut32[i] += 16;
        }
    };

    // Initialises D0 / D1 for a new outer iteration and performs the first
    // multiply of that iteration using the data already fetched into
    // xbuff0 / ybuff0 / ybuff1.
    auto init_accum_and_first_iter = [&]( ) __attribute__(( always_inline )) {
        // NOTE(review): both bias vectors carry the same x9 storage annotation.
        aie::vector<int16, 32> chess_storage( x9 ) bias_v0;
        aie::vector<int16, 32> chess_storage( x9 ) bias_v1;
        if constexpr( gate_bias ) {
            if constexpr( kc_in != KC_ZERO && has_stride && reduced_stride ) {
                //int   cntB_r  = cntB;
                //int   offBN_r = ((norm_ch_g)>2)*2*64*cntB_r;
                //int   offBN = offBN_r;
                // cntBN is a byte offset into the bias buffer that advances by
                // 128 per call and wraps to 0 once it reaches 64 * norm_ch_g - 128.
                int offBN = cntBN;
                cntBN = cntBN >= 64 * norm_ch_g - 2*64 ? 0 : cntBN + 2*64;

                // Load two bias vectors located two v16int16 apart; pBN0 itself
                // is restored to its original value afterwards.
                bias_v0 = set_v32int16( 0, *byte_incr( pBN0, offBN ));
                pBN0 = chess_copy( pBN0 + 2 );
                bias_v1 = set_v32int16( 0, *byte_incr( pBN0, offBN ));
                pBN0 = chess_copy( chess_copy( pBN0 ) - 2 );
            } else {
                // Load two bias vectors, stepping pBN0 with the 2-D pointer helper
                // (cntBN is the helper's counter in this variant).
                bias_v0 = set_v32int16( 0, *pBN0 );
                pBN0 = add_2d_ptr( pBN0, 2*( 1 - norm_ch_g ), norm_ch_g - 1, cntBN, 2 );
                bias_v1 = set_v32int16( 0, *pBN0 );
                pBN0 = add_2d_ptr( pBN0, 2*( 1 - norm_ch_g ), norm_ch_g - 1, cntBN, 2 );
            }
        }

        if constexpr( kc_in != KC_ZERO ) {
            // Start from the loaded partial sums (passed through aie::op_zero
            // with zero_init), do the first MAC, then add the scaled bias.
            read_psum.template operator( )<0>( );
            read_psum.template operator( )<1>( );
            
            if constexpr( !gate_bias )
                psum[0] = locate_in_register<4>( psum[0] );

            D0 = mac888( aie::op_zero( psum[0], zero_init ));
            D0.mac( apply_sign( xbuff0 ), ybuff0 );
            D1 = mac888( aie::op_zero( psum[1], zero_init ));
            D1.mac( apply_sign( xbuff0 ), ybuff1 );

            if constexpr( gate_bias ) {
                bias_mul B;
                B.mul( bias_shift_v, bias_v0.extract<16>( 0 ));
                D0 = aie::add( D0.to_accum( ), locate_in_register<4>( B.to_accum( )));
                B.mul( bias_shift_v, bias_v1.extract<16>( 0 ));
                D1 = aie::add( D1.to_accum( ), locate_in_register<4>( B.to_accum( )));
            }
        } else {
            if constexpr( gate_bias ) {
                // No partial sum: seed the accumulators with the scaled bias,
                // then do the first MAC.
                bias_mul B;
                B.mul( bias_shift_v, bias_v0.extract<16>( 0 ));
                D0 = mac888( B.to_accum( ));
                B.mul( bias_shift_v, bias_v1.extract<16>( 0 ));
                D1 = mac888( B.to_accum( ));

                D0.mac( apply_sign( xbuff0 ), ybuff0 );
                D1.mac( apply_sign( xbuff0 ), ybuff1 );
            } else {
                // No partial sum and no bias: the first product overwrites the accumulators.
                D0.mul( apply_sign( xbuff0 ), ybuff0 );
                D1.mul( apply_sign( xbuff0 ), ybuff1 );
            }
        }
    };


    // One accumulation step on the currently fetched activation / weight vectors.
    auto process_data = [&]( ) __attribute__(( always_inline )) {
        D0.mac( apply_sign( xbuff0 ), ybuff0 );
        D1.mac( apply_sign( xbuff0 ), ybuff1 );
    };

    // Routes D0 / D1 through locate_in_register<0> / <1> without changing their values.
    // NOTE(review): presumably a register-allocation hint for the pipelined inner loop.
    auto pin_accum = [&]( ) __attribute__(( always_inline )) {
        aie::accum<acc32, 64> tmp0 = locate_in_register<0>( D0.to_accum( ));
        aie::accum<acc32, 64> tmp1 = locate_in_register<1>( D1.to_accum( ));
        D0 = tmp0;
        D1 = tmp1;
    };

    // Copies the finished accumulators into accs[] so D0 / D1 can be reused
    // while the store stage works on the copy.
    auto capture_output = [&]( ) __attribute__(( always_inline )) {
        accs[0] = locate_in_register<2>( D0.to_accum( ));
        accs[1] = locate_in_register<3>( D1.to_accum( ));
    };

    // Output stage operating on accs[]: int8 conversion + activation + store for
    // KC_RESULT8, followed by the partial-sum stores (which are no-ops unless
    // kc_out is KC_TDM16 / KC_TDM32).
    auto store_data = [&]( ) __attribute__(( always_inline )) {
        if constexpr( kc_out == KC_RESULT8 ) {
            // Shift-round-saturate to int8 using params.shift_out_1; sign_res
            // selects signed or unsigned conversion.
            aie::vector<int8, 64> __aie_register( x10 ) Obuff0, __aie_register( x11 ) Obuff1;
            Obuff0 = accs[0].to_vector_sign<int8>( sign_res, params.shift_out_1 );
            Obuff1 = accs[1].to_vector_sign<int8>( sign_res, params.shift_out_1 );

            if constexpr( has_lrelu ) {
                // Leaky branch: convert the accumulator to int16 (shift_out16),
                // multiply by alpha, convert to int8 (shift_leaky) and take the
                // element-wise max with the regular result.
                aie::vector<int16, 64> val_i16_0, val_i16_1;
                val_i16_0 = locate_in_register_t<5>( accs[0].to_vector<int16>( params.shift_out16 ));
                val_i16_1 = locate_in_register_t<5>( accs[1].to_vector<int16>( params.shift_out16 ));

                alpha = __aie_copy( alpha );
                aie::vector<uint16,64> alpha_v = locate_in_register_t<3>( aie::broadcast<uint16,64>( alpha ));
                aie::accum<acc32, 64> chess_storage( dm4 ) accum_leaky_0 = aie::mul( val_i16_0, alpha_v );
                aie::accum<acc32, 64> chess_storage( dm4 ) accum_leaky_1 = aie::mul( val_i16_1, alpha_v );

                Obuff0 = aie::max( Obuff0, locate_in_register_t<5>( accum_leaky_0.to_vector<int8>( params.shift_leaky )), sign_res );
                Obuff1 = aie::max( Obuff1, locate_in_register_t<7>( accum_leaky_1.to_vector<int8>( params.shift_leaky )), sign_res );
            }

            if constexpr( requires_relu_min ) {
                // Upper clamp against min_res_v.
                Obuff0 = aie::min( Obuff0, min_res_v, sign_res );
                Obuff1 = aie::min( Obuff1, min_res_v, sign_res );
            }

            if constexpr( reduced_stride && has_stride ) {
                // Each 64-byte result is written as two 32-byte halves, stepping
                // by incS0 between halves and by incCS1 between the two results.
                aie::store_v<output_resource>( pOut8, Obuff0.extract<32>( 0 ));   pOut8 = byte_incr( pOut8, params.incS0 );
                aie::store_v<output_resource>( pOut8, Obuff0.extract<32>( 1 ));   pOut8 = byte_incr( pOut8, params.incCS1 );
                aie::store_v<output_resource>( pOut8, Obuff1.extract<32>( 0 ));   pOut8 = byte_incr( pOut8, params.incS0 );
                aie::store_v<output_resource>( pOut8, Obuff1.extract<32>( 1 ));
            } else {
                // Each result is written as one full 64-byte vector.
                aie::store_v<output_resource>( pOut8, Obuff0 );     pOut8 = byte_incr( pOut8, params.incCS1 + params.incS0 );
                aie::store_v<output_resource>( pOut8, Obuff1 );
            }
            // Advance to the next output position via the 3-D addressing helper.
            pOut8 = add_3d_byte( pOut8, dimsO );
        }
        // Accumulator 1 is stored before accumulator 0.
        store_psum.template operator( )<1>( );
        store_psum.template operator( )<0>( );
    };

    // ---- Main flow -------------------------------------------------------
    // Folded variant (ol_lr > ol_fold_threshold): the first outer iteration is
    // accumulated in this prologue, every loop iteration below stores the
    // result captured by the previous one, and the epilogue after the loop
    // stores the last result. The loop therefore runs one iteration fewer.
    if constexpr( ol_lr > ol_fold_threshold ) {
        //read_psum.template operator()<0>( );
        //read_psum.template operator()<1>( );
        fetch_data( );
        init_accum_and_first_iter( );
        //read_psum.template operator()<0>( );

        // Remaining il_count - 1 accumulation steps of the first outer iteration.
        for ( unsigned i=0; i<il_count-1; i++ )
            chess_prepare_for_pipelining
            chess_loop_range( il_lr-1, )
        {
            fetch_data( );
            process_data( );
        }

        // Move input and weight pointers to the next outer iteration.
        pIn = add_3d_byte( pIn, dimsAO );
        pW = add_2d_byte( pW, params.incB2, params.numB, cntB, params.incB1 );
        capture_output( );
        //store_psum.template operator()<0>( );
    }

    for ( unsigned j=0; j<ol_count - ( ol_lr > ol_fold_threshold ); j++ )
        chess_prepare_for_pipelining
        chess_loop_range( std::min( ol_fold_threshold, ol_lr ), )
        //chess_allocate( X:12 )
        //chess_allocate( DM:5 )
        //chess_allocate( R:28 )
        //chess_allocate( P:8 )
        //chess_allocate( M:8 )
    {
        //if constexpr( ol_lr <= ol_fold_threshold )
        //    read_psum.template operator()<0>( );
        //read_psum.template operator()<1>( );
        fetch_data( );
        init_accum_and_first_iter( );

        // From here on the data for the next MAC is fetched one step ahead.
        fetch_data( );
        //if constexpr( ol_lr > ol_fold_threshold ) {
        //    read_psum.template operator()<0>( );
        //}
        // The il_count accumulation steps of one outer iteration are split as:
        // 1 (init) + ( il_count - il_peeled ) (pipelined loop)
        //   + ( il_peeled - 2 ) (unrolled tail) + 1 (final process_data).
        constexpr unsigned il_peeled = il_lr - 6;

        for ( unsigned i=0; i<il_count-il_peeled; i++ )
            chess_prepare_for_pipelining
            chess_loop_range( 6, )
            chess_peel_pipelined_loop( 2 )
        {
            process_data( );
            pin_accum( );
            #if __AIE_ARCH__ == 20
                chess_separator( );
            #endif
            fetch_data( );
        }

        if constexpr( il_peeled > 2 ) {
            #pragma unroll
            for ( unsigned i=0; i<il_peeled-2; i++ ) {
                process_data( );
                fetch_data( );
            }
        }

        // Pointer updates for the next outer iteration are issued before the
        // last MAC; that MAC only uses the vectors already held in registers.
        pIn = add_3d_byte( pIn, dimsAO );
        pW = add_2d_byte( pW, params.incB2, params.numB, cntB, params.incB1 );
        process_data( );

        // Unfolded variant: capture and store the result of this iteration.
        if constexpr( ol_lr <= ol_fold_threshold )
            capture_output( );

        // Folded variant: this stores the result captured by the previous
        // iteration (or by the prologue).
        store_data( );
        //store_psum.template operator()<1>( );

        // Folded variant: capture this iteration's result for the next store.
        if constexpr( ol_lr > ol_fold_threshold )
            capture_output( );

        //store_psum.template operator()<0>( );
    }
    // Folded variant epilogue: store the result of the last outer iteration.
    if constexpr( ol_lr > ol_fold_threshold ) {
        store_data( );
        //store_psum.template operator()<1>( );
    }
}

#endif // __BIASED_CONV_INT8X8_TEMPLATE_H__
