/*  (c) Copyright 2019 - 2024 Xilinx, Inc. All rights reserved.

    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.

    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.

    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.

    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */

#ifndef __BIASED_CONV_INT8X8_IMPL_H__
#define __BIASED_CONV_INT8X8_IMPL_H__

#include "aie_api/aie.hpp"
#include "aie_api/utils.hpp"
#include "common.hh"
// #include "biased_conv_int8x8.hpp"
#include "access_helpers.hpp"
// #include "kernel_helpers.h"
#ifdef DEBUG_KERNEL
#include "stdio.h"
#endif



/**
 * Biased int8 x int8 convolution kernel body.
 *
 * The function is organised as a set of small by-reference lambdas (weight
 * load, MAC, bias preload, output conversion) that are combined through
 * aie::pipelined_loops into larger stages, followed by the top-level
 * schedule that acquires/releases the four buffers and runs those stages.
 * All lambdas share the local pointers and accumulator indices declared at
 * the top of the function, so the order in which the stages are invoked
 * matters.
 *
 * Template parameters:
 *   has_relu6      when true, the output conversion additionally clamps with
 *                  quant->max_value (see output_body and
 *                  activate_output_body_64).
 *   has_lrelu      selects activate_output_body (uses quant->lrelu_alpha and
 *                  quant->shift_lrelu_in) instead of output_body.
 *   has_bias       NOTE(review): not referenced anywhere in this body.
 *   hardened_loop  < 0  : the inner loop bound is params.inner_loop (run time).
 *                  >= 0 : the inner loop is a "#pragma unroll" loop running
 *                         j = 2 .. hardened_loop - 1.
 *                  == 1 : additionally selects the single-step variants of
 *                         prefetch_weights and of the combined
 *                         MAC/output/bias stage, and skips the extra
 *                         "second half" stage calls.
 *   granYX         granularity; granYX / 2 is the trip count handed to the
 *                  pipelined stages.
 *                  NOTE(review): several places use the literals 64 and 16,
 *                  which looks like an assumption that granYX == 64 - confirm.
 *   Bi, Bw, Bb, Bo buffer types; this function calls acquire(), release()
 *                  and data() on them.
 *
 * Parameters:
 *   bufA_   input activation buffer (read through pA).
 *   bufW_   weight buffer (read through pW). The Quantization record is
 *           located params.wgt_size + params.bias_size bytes after data().
 *   bufB_   bias buffer; only acquired when params.norm_ch_g is non-zero.
 *   bufO_   output buffer (written through pOut).
 *   params  index/loop parameters. Fields read here: inner_loop,
 *           inner_time_iters, outer_time_iters, dims_A2, dims_A3,
 *           step_align, norm_ch_g, wgt_size, bias_size.
 */
template<bool has_relu6, bool has_lrelu, bool has_bias, int hardened_loop, unsigned granYX, class Bi, class Bw, class Bb, class Bo>
//ALWAYS_INLINE void
//void __attribute__(( noinline ))
void
biased_conv_int8x8
(
        Bi & restrict bufA_,
        Bw & restrict bufW_,
        Bb & restrict bufB_,
        Bo & restrict bufO_,
        const BiasedConvInt8x8IdxParams &params
) {
    // Passed to the final write of each output pass (see output_body /
    // activate_output_body). Always true in this implementation.
    bool has_tlast = 1;
    // Local aliases of the buffer arguments.
    Bi &bufA = bufA_;
    Bw &bufW = bufW_;
    Bb &bufB = bufB_;
    Bo &bufO = bufO_;

    // Element types are fixed here; the buffer-derived variants are commented out.
    using Ti = int8; //buffer_element_t<Bi>;
    using Tw = int8; //buffer_element_t<Bw>;
    using To = int8; //buffer_element_t<Bo>;
    using Tb = int16;//buffer_element_t<Bb>;
    using Tba = AccumulatorType_t<Tb>;

    // NOTE(review): Ti and Tw are the hard-coded int8 aliases above, so this
    // trait is evaluated on int8 rather than on Bi/Bw - confirm that is intended.
    constexpr bool stream_kernel = is_stream_type_v<Ti> && is_stream_type_v<Tw>;

    // Both pointers are set from the weight buffer after bufW.acquire() below.
    Quantization* quant;
    Quantization::Control* ctrl;

    // Trip count used by the pipelined stages.
    int granYX2 = granYX/2;
    int il_bound = params.inner_loop;
    int ol_bound = params.inner_time_iters;
    if constexpr( stream_kernel ) {
        // Stream variant: fold the "ol" loop into the inner loop bound.
        il_bound *= ol_bound;
        ol_bound = 1;
    }
    int iters = params.outer_time_iters;

    // bias[] is filled by fetch_bias but only read by commented-out code.
    aie::accum<Tba,granYX> bias[2];
    Tb * pB;
    // Raw bias vector; converted into accumulator initial values by bias_body.
    aie::vector<Tb,granYX> bias_raw;
    // NOTE(review): b[] is not referenced anywhere below.
    v32acc32 b[2];

    /* dimensions */
    dims_2d_t dims_HWi = params.dims_A2.instantiate( );
    dims_3d_t dims_KCi = params.dims_A3.instantiate( );
    /* fifo actvs */
    fifo_state_t fA;
    /* weights */
    Tw * pW;
    // Staging index, post-incremented by every weights_body call.
    uint5_t iw=0;
    /* outputs actvs */
    // Passed to mac_conf; set by fetch_bias and cleared to 0 after the
    // fetch/compute stages.
    int zero_acc;
    // Accumulator indices: ima/imb are advanced by mac_body, isa by bias_body,
    // isb by the output bodies.
    uint5_t ima=0;
    uint5_t imb=0;
    uint5_t isa=0;
    uint5_t isb=0;
    Ti * pA;
    To * pOut;

    // Two accumulator banks; each mac_body call updates one entry in each.
    m32x64acc32 chess_storage(EM0) accsA = chess_dont_care(m32x64acc32);
    m32x64acc32 chess_storage(EM1) accsB = chess_dont_care(m32x64acc32);

    // Reads 128 bytes of weights from pW into staging slot iw, then advances
    // both iw and pW. The idx argument is unused.
    auto weights_body = [&]( auto idx ) __aie_inline {
        insert_staging( read_v<128>( pW ), iw++, MMAC_INT8 );
        pW = add_byte( pW, 128 );
    };
    // Two consecutive weights_body steps per call.
    auto weights_body2 = [&]( auto idx ) __aie_inline {
        weights_body( 2 * idx );
        weights_body( 2 * idx + 1 );
    };
    
    // Only used by the CRVO-11896 workaround inside mac_body.
    void * dummy = (void*)0;

    // Loads two 64-byte activation vectors and accumulates them into
    // accsA[ima] and accsB[imb], then advances both indices.
    auto mac_body = [&]( auto idx ) __aie_inline {
        // True on the iteration with idx == granYX2 - 1.
        bool is_last = chess_manifest( idx == granYX2 - 1 );
        v64int8 Xbuff0, Xbuff1;
        if constexpr( !is_stream_type_v<Ti> ) {
#ifdef __chess__
            Xbuff0 = fifo_ld_popx_2d_byte(( v64int8 *& ) pA, fA, params.step_align, 63, dims_HWi );
            if (!is_last) {
                Xbuff1 = fifo_ld_popx_2d_byte(( v64int8 *& ) pA, fA, params.step_align, 63, dims_HWi );
            }
            else {
                // Last iteration: the second load steps pA with the 3D
                // descriptor (dims_KCi) instead of the 2D one.
                //dims_HWi.count1 = 0;
                dummy = add_2d_byte( dummy, dims_HWi );  // This is workaround for CRVO-11896. Can be kept as it saves a few bytes of PM
                Xbuff1 = fifo_ld_popx_3d_byte(( v64int8 *& ) pA, fA, params.step_align, 63, dims_KCi );
                dims_KCi.count1 = locate_in_register<2>( dims_KCi.count1 );
            }
#else
            // NOTE(review): dimsAl is not declared anywhere in this function;
            // confirm it is provided elsewhere or that this branch still builds.
            // This branch also has no is_last / 3D-step handling.
            Xbuff0 = read_v<64>( pA );
            pA = add_2d_byte( pA + 64, dimsAl );
            Xbuff1 = read_v<64>( pA );
            pA = add_2d_byte( pA + 64, dimsAl );
#endif
        }
        else {
            Xbuff0 = read_v<64>( pA );
            Xbuff1 = read_v<64>( pA );
        }
        accsA[ima] = mac_conf( Xbuff0, ctrl->sign_A, ctrl->sign_W, accsA[ima], zero_acc );
        accsB[imb] = mac_conf( Xbuff1, ctrl->sign_A, ctrl->sign_W, accsB[imb], zero_acc );
        ima++; imb++;
    };

    // Non-leaky output path: converts accsA[isb] and accsB[isb] to int8 and
    // writes 64 bytes each to pOut, then advances isb.
    auto output_body = [&]( auto idx ) __aie_inline {
        bool is_last = chess_manifest( idx == granYX2 - 1 );
        // Accumulator -> int8 conversion using quant->shift_out and
        // ctrl->sign_O, with an optional clamp to quant->max_value.
        auto convert = [&]( v64acc32 acc ) __aie_inline {
            aie::vector vec = to_v64int8( acc, quant->shift_out, ctrl->sign_O );
            if constexpr( has_relu6 ) {
                vec = min( vec, quant->max_value, ctrl->sign_O );
            }
            return vec;
        };
        write_v( pOut, convert( accsA[isb] ));
        pOut = add_byte( pOut, 64 );
        if (!is_last) {
            write_v( pOut, convert( accsB[isb] ));
        } else {
            // Final write of the pass carries the has_tlast flag.
            write_v( pOut, convert( accsB[isb] ), has_tlast );
        }
        pOut = add_byte( pOut, 64 );
        isb++;
    };

    // Leaky-ReLU style conversion of one 64-lane accumulator acc[ai] into two
    // v32int8 halves, which are either pushed to the stream or stored at pOut.
    auto activate_output_body_64 = [&]( auto & acc, auto ai, bool last ) __aie_inline
    {
        //v64uint16 tp = aie::utils::locate_in_register<4>( to_v64uint16( acc[ai], quant->shift_lrelu_in - 1 ));
        //v64int16  tn = aie::utils::locate_in_register<5>( min( aie::utils::locate_in_register<6>( to_v64int16( acc[ai], quant->shift_lrelu_in ) ), broadcast_to_v64int16( 0 )));

        // tp: unsigned 16-bit conversion of the accumulator (shift_lrelu_in - 1).
        v64uint16 tp = ( to_v64uint16( acc[ai], quant->shift_lrelu_in - 1 ));
        if constexpr( has_relu6 ) {
            tp = min( tp, quant->max_value << ( quant->shift_out - 15 ));
        }
        // tn: signed 16-bit conversion limited to values <= 0.
        v64int16  tn = ( min( ( to_v64int16( acc[ai], quant->shift_lrelu_in ) ), broadcast_to_v64int16( 0 ) ) );

        // Each half: multiply tn by quant->lrelu_alpha, accumulate onto tp
        // (re-expanded to an accumulator with 15), then convert to int8.
        v32int8 out0 = to_v32int8( ( mac_elem_32( extract_v32int16( tn, 0 ), quant->lrelu_alpha, (to_v32acc32( extract_v32uint16( tp, 0 ), 15 )) )), quant->shift_out, ctrl->sign_O );
        v32int8 out1 = to_v32int8( ( mac_elem_32( extract_v32int16( tn, 1 ), quant->lrelu_alpha, (to_v32acc32( extract_v32uint16( tp, 1 ), 15 )) )), quant->shift_out, ctrl->sign_O );

        // NOTE(review): To is the hard-coded int8 alias, so the branch taken
        // does not depend on Bo - confirm.
        if constexpr( is_stream_type_v<To> ) {
            put_ms( out0 );
            put_ms( out1, last );
        } else if constexpr( !is_stream_type_v<To> ) {
            aie::store_v( pOut, aie::vector( out0 ));     pOut += 32;
            aie::store_v( pOut, aie::vector( out1 ));     pOut += 32;
        }
    };
    // Leaky path counterpart of output_body: emits accsA[isb] then accsB[isb]
    // and advances isb. Only the accsB half of the last iteration gets "last".
    auto activate_output_body = [&]( auto idx ) __aie_inline
    {
        bool is_last = chess_manifest(idx == granYX2-1) && has_tlast;

        activate_output_body_64(accsA, isb, false);
        activate_output_body_64(accsB, isb, is_last);
        isb++;
    };

    // NOTE(review): not invoked anywhere in this function.
    auto mac_output_body = [&](auto idx) __aie_inline {
        mac_body(idx);
        output_body(idx);
    };

    // Re-initialises accsA[isa] and accsB[isa] from bias_raw (converted with
    // quant->shift_bias) and advances isa. The idx argument is unused.
    auto bias_body = [&](auto idx) __aie_inline {

        //if constexpr( !has_lrelu ) {
        //    accsA[isa] = bias[0];
        //    accsB[isa] = bias[1];
        //} else {
            accsA[isa] = chess_duplicate( to_v64acc32( bias_raw, quant->shift_bias ));
            accsB[isa] = chess_duplicate( to_v64acc32( bias_raw, quant->shift_bias ));
            //accsB[isa] = chess_duplicate( aie::accum<Tba,64>( bias_raw, params.shift_bias ));
        //}
        accsA = chess_copy( accsA );
        accsB = chess_copy( accsB );

        isa++;
    };

    // One bias preload followed by two MAC steps.
    auto mac_body_with_bias = [&](auto idx) __aie_inline {
        bias_body( idx );
        mac_body( 2 * idx );
        mac_body( 2 * idx + 1 );
    };

    // Initial stage: stage weights while preloading accumulators with bias,
    // then commit the staged weights with staging_to_matrix_m64x64int8().
    auto prefetch_weights = [&]() __aie_inline {
        if constexpr( hardened_loop == 1 ) {
            aie::pipelined_loops<8>(granYX2, weights_body, bias_body);
        } else {
            aie::pipelined_loops<8>(granYX2 / 2, weights_body2, bias_body);
        }
        staging_to_matrix_m64x64int8();
    };

    // Steady-state stage: stage the next weights while running the MACs,
    // commit the staged weights and clear zero_acc.
    auto fetch_and_compute = [&]() __aie_inline {
        aie::pipelined_loops<8, aie::LoopOptions{.peel_front = 1, .peel_back = 2}, aie::LoopOptions{.peel_front = 2, .peel_back = 1}>(granYX2, weights_body, mac_body);
        staging_to_matrix_m64x64int8();
        zero_acc = 0;
    };

    // Same as fetch_and_compute but over granYX2 / 2 iterations, with each
    // iteration staging two weight blocks and doing bias preload + two MACs.
    auto fetch_and_compute_with_bias = [&]() __aie_inline {
        aie::pipelined_loops<8, aie::LoopOptions{.peel_front = 1, .peel_back = 1}, aie::LoopOptions{.peel_front = 1, .peel_back = 1}>(granYX2 / 2, weights_body2, mac_body_with_bias);
        staging_to_matrix_m64x64int8();
        zero_acc = 0;
    };

    // Final stage (no weight fetch): pairs the output conversion selected by
    // has_lrelu with mac_body over granYX2 iterations.
    auto compute_activate_and_output = [&]() __aie_inline {
        if constexpr( !has_lrelu ) {
            #ifdef DEBUG_KERNEL
                //printf( "Executed ReLU with bias %i, %i, %i, %i, ...\n", extract<int>( bias[0], 0 ), extract<int>( bias[0], 1 ), extract<int>( bias[0], 2 ), extract<int>( bias[0], 3 ));
            #endif
            aie::pipelined_loops<4, aie::LoopOptions{.peel_front = 0, .peel_back = 2}, aie::LoopOptions{.peel_front = 1, .peel_back = 1}>(granYX2, output_body, mac_body);
        } else {
            // aie::pipelined_loop<4, aie::LoopOptions{.peel_front = 0, .peel_back = 0}>(granYX2, mac_body);
            // aie::pipelined_loop<4, aie::LoopOptions{.peel_front = 0, .peel_back = 0}>(granYX2, activate_output_body);
            aie::pipelined_loops<4, aie::LoopOptions{.peel_front = 0, .peel_back = 2}, aie::LoopOptions{.peel_front = 1, .peel_back = 1}>(granYX2, activate_output_body, mac_body);
            #ifdef DEBUG_KERNEL
                //chess_report( accs[0] );
                //chess_report( to_v64int16( accs[0], quant->shift_lrelu_in ));
                // printf( "Executed LReLU with alpha=%i and bias %i, %i, %i, %i, ...\n", alpha, extract<int>( bias[0], 0 ), extract<int>( bias[0], 1 ), extract<int>( bias[0], 2 ), extract<int>( bias[0], 3 ));
            #endif
        }
    };
    
    // Combined stage used between outer iterations: stages the next weights
    // while finishing the MACs, emitting outputs and re-seeding accumulators
    // with bias. "half" (0 or 1) selects the first or second call when the
    // stage is split in two (hardened_loop != 1).
    auto fetch_and_compute_activate_with_bias_and_output = [&]( bool half ) __aie_inline {
        auto body = [&]( auto idx ) __aie_inline {
            if constexpr( hardened_loop == 1 ) {
                // Single-call variant: MAC, output, then bias reload.
                mac_body( idx );
                if constexpr( !has_lrelu ) {
                    output_body( idx );
                } else {
                    activate_output_body( idx );
                }
                bias_body( idx );
            } else if constexpr( !has_lrelu ) {
                // Non-leaky split variant: the first call does MAC + output +
                // bias reload; the second call only does the MAC.
                if ( !half ) {
                    mac_body( idx );
                    output_body( idx );
                    bias_body( idx );
                } else {
                    mac_body( idx );
                }
            } else {
                // Leaky split variant: two MACs per iteration, placed before
                // the output/bias step in the first call and after it in the
                // second call.
                if ( !half ) {
                    mac_body( 2 * idx );
                    mac_body( 2 * idx + 1 );
                }
                // NOTE(review): the literal 16 equals granYX2 / 2 only when
                // granYX == 64 - confirm.
                activate_output_body( 16 * half + idx );
                bias_body( 16 * half + idx );
                if ( half ) {
                    mac_body( 2 * idx );
                    mac_body( 2 * idx + 1 );
                }
            }
        };
        
        // Leaky split variant runs half as many iterations, each staging two
        // weight blocks (see wb below).
        int bound = !has_lrelu || ( hardened_loop == 1 ) ? granYX2 : granYX2 / 2;
        auto wb = [&]( auto idx ) __aie_inline {
            if constexpr( !has_lrelu || ( hardened_loop == 1 ))
                weights_body( idx );
            else
                weights_body2( idx );
        };
            
        aie::pipelined_loops<8, aie::LoopOptions{.peel_front = 0, .peel_back = 2}, aie::LoopOptions{.peel_front = 1, .peel_back = 1}>( bound, wb, body );
        staging_to_matrix_m64x64int8();
        zero_acc = 0;
    };

    // Loads the raw bias vector when params.norm_ch_g is non-zero and sets
    // zero_acc to 1 when it is zero (no bias loaded), 0 otherwise.
    // NOTE(review): when norm_ch_g == 0, bias_raw is left untouched, yet
    // bias_body still reads it; presumably zero_acc makes mac_conf discard the
    // accumulator contents - confirm.
    auto fetch_bias = [&]() __aie_inline {
        if ( params.norm_ch_g )
        {
            /* mesimulator */
            bufB.acquire( );
            pB = (Tb*)bufB.data( );
            bias_raw = read_v<64>( pB );
            pB = add_elem( pB, 64 );

            bias[0] = aie::accum<Tba, 64>( bias_raw, quant->shift_bias );
            //bias[0] = to_v64acc32( bias_raw, quant->shift_bias );
            bias[1] = bias[0];

            bufB.release( );
        }
        zero_acc = params.norm_ch_g == 0;
    };
    
    // ---- Top-level schedule ----

    // Acquire the first weight buffer; the Quantization record sits
    // wgt_size + bias_size bytes after the start of that buffer.
    bufW.acquire( );
    pW = (Tw*)bufW.data( );
    quant = (Quantization*)byte_incr( pW, (params.wgt_size + params.bias_size ) );
    ctrl = &(quant->ctrl);

    fetch_bias();

    accsA = chess_dont_care( m32x64acc32 );
    accsB = chess_dont_care( m32x64acc32 );

    prefetch_weights();
    
    fifo_reset( fA );
    
    bufA.acquire( );
    pA = (Ti*)bufA.data( );
    
    if constexpr( hardened_loop != 1 ) {
        fetch_and_compute_with_bias( );
    }

    // Outer loop over params.outer_time_iters. The last iteration exits via
    // the "break" below, so the combined output stage runs iters - 1 times and
    // the final outputs are produced by compute_activate_and_output() after
    // the loop.
    for ( unsigned chess_storage( r0 ) it=0; it<iters; it++ )
    #if !defined( DEBUG_KERNEL ) && !defined( DEBUG_TESTBENCH )
        //chess_prepare_for_pipelining
    #endif
        chess_loop_range( 1, )
        #ifdef __chess__
        chess_allocate( R:28 )
        chess_allocate( P:8 )
        chess_allocate( DC:8 )
        #endif
    {

        // Loop over weight buffers within one outer iteration.
        [[using chess: min_loop_count( 1 ), allocate( R:24 )]]
        for (int ol=0; ol<ol_bound; ol++)
        {

            // Inner steps start at index 2 in both variants; presumably the
            // first two steps are the stages executed outside this loop -
            // confirm.
            if constexpr( hardened_loop < 0 ) {
                #ifndef DEBUG_KERNEL
                [[using chess: prepare_for_pipelining, min_loop_count( 4 - 2 ), allocate( R:15 )]]
                #endif
                for (int il=2; il<il_bound; il++)
                {
                    fetch_and_compute();
                }
            } else {
                #pragma unroll
                for ( unsigned j = 2; j < hardened_loop; j++ ) {
                    fetch_and_compute();
                }
            }

            bufW.release( );

            // On the last weight buffer, leave without acquiring another one.
            if ( ol == ol_bound - 1 )
                break;

            bufW.acquire( );
            pW = (Tw*)bufW.data( );

            fetch_and_compute();

            // Move on to the next activation buffer.
            bufA.release( );
            bufA.acquire( );
            pA = (Ti*)bufA.data( );
            
            if constexpr( hardened_loop != 1 ) {
                fetch_and_compute( );
            }
        }

        // Last outer iteration: skip the combined stage; the tail after the
        // loop emits the final outputs.
        if ( it == iters - 1 )
            break;

        bufW.acquire( );
        pW = (Tw*)bufW.data( );
        bufO.acquire( );
        pOut = (To*)bufO.data( );
        
        fetch_bias( );

        fetch_and_compute_activate_with_bias_and_output( 0 );

        bufA.release( );
        bufA.acquire( );
        pA = (Ti*)bufA.data( );
        
        if constexpr( hardened_loop != 1 ) {
            fetch_and_compute_activate_with_bias_and_output( 1 );
        }
        
        bufO.release( );
    }

    // Tail: finish the remaining MACs and write the last block of outputs.
    bufO.acquire( );
    pOut = (To*)bufO.data( );

    compute_activate_and_output();
    
    bufO.release( );
    bufA.release( );
    event1( );
}


#endif // __BIASED_CONV_INT8X8_IMPL_H__
