/*  (c) Copyright 2019 - 2021 Xilinx, Inc. All rights reserved.

    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.

    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.

    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.

    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */

#ifndef __DIRECT_CONV_INT8X8_GENERIC_IMPL_HPP__
#define __DIRECT_CONV_INT8X8_GENERIC_IMPL_HPP__

#include <stdint.h>
#include "direct_conv_int8x8_generic.hpp"
#include "aie_api/aie.hpp"
#include "access_helpers.hpp"
#include "kernel_helpers.h"
#include "../direct_conv_int16x8_generic/direct_conv_int16x8_generic_impl.hpp"


// Core int8 x int8 multiply-accumulate loop shared by the direct-convolution
// entry points in this file.
//
// For each of the params.outer_g outer iterations the function:
//   1. loads Na (= 2) 64-lane accumulators from the tdm buffers. Each
//      accumulator is assembled from four 16-lane pieces: quarters 0/1 come
//      from pTdm1 and pTdm1 + 16, quarters 2/3 from pTdm2 and pTdm2 + 16.
//      aie::op_zero( tdm, zero_init ) is applied to the loaded value.
//   2. performs params.inner_g MAC steps. Every step pops two input vectors
//      from the fifo, shuffles them into one operand with params.shfl_0 and
//      multiplies it with two weight vectors (one per accumulator).
//   3. stores both accumulators back using the same quarter layout.
//
// Template parameter:
//   has_in1_T_wa  only used when __AIE_MODEL_VERSION__ <= 11300: when true the
//                 weight vectors are shuffled (T8_8x8 if ctrl.is_in1_T, else
//                 T512_1x2_lo) before mac_8x8_8x8_conf is called. On newer
//                 models ctrl.is_in1_T is passed to mac_8x8_8x8 directly.
//
// Parameters:
//   input      start of the input data, read through the fifo helpers.
//   weights    start of the weight data, read as 64-byte vectors.
//   tdm1,tdm2  accumulator buffers; read and written in place.
//   zero_init  forwarded to aie::op_zero and to the first MAC of each outer
//              iteration; later MACs pass 0 instead.
//   params     low-level parameters (outer_g, inner_g, step_align, incW,
//              shfl_0, dimsA, ctrl).
//   dimsAO     3D stepping applied to the input pointer after each outer
//              iteration.
//   dimsW      3D stepping applied to the weight pointer after each pair of
//              weight loads.
//   incT_0     byte increment of the tdm pointers after the first accumulator.
//   incT_1     byte increment of the tdm pointers after the last accumulator.
//   is_dwc     not used by this function.
//   is_sub     forwarded unchanged to the MAC intrinsic
//              (presumably selects subtraction - TODO confirm).
template<bool has_in1_T_wa=1>
__aie_inline void direct_conv_int8x8_generic_base
(
        int8 * input,
        int8  * weights,
        int32 * tdm1,
        int32 * tdm2,
        bool zero_init,
        const DirectConvInt8x8Generic::LowParams &params,
        dims_3d_param_s16 dimsAO,
        dims_3d_param_s16 dimsW,
        int incT_0,
        int incT_1,
        bool is_dwc,
        bool is_sub
) {
    using mm_t = aie::mmul<8,8,8,int8,int8>;
    using acc_t = mm_t::accum_type;
    constexpr unsigned Va = 64;   // lanes per accumulator
    constexpr unsigned Na = 2;    // accumulators handled per outer iteration

    v64int8 * pA = ( v64int8 * ) input;
    int8  * pW = weights;

    // Separate read (pTdm*) and write (pOut*) cursors over the same buffers;
    // both are advanced by the same incT values below.
    int32 * pTdm1 = tdm1;
    int32 * pTdm2 = tdm2;
    int32 * restrict pOut1 = chess_copy( tdm1 );
    int32 * restrict pOut2 = chess_copy( tdm2 );

    fifo_state_t fA;
    fA.pos = 0;

    dims_3d_t dimsAI = params.dimsA.instantiate( 8 );
    dims_3d_t dimsAO_i = dimsAO.instantiate( 8 );
    dims_3d_t dimsW_i  = dimsW.instantiate( 8 );
    DirectConvInt8x8Generic::BaseParams::Control ctrl = params.ctrl;

    [[ using chess: prepare_for_pipelining, min_loop_count( 2 )]]
    for (unsigned j=0; j<params.outer_g; j++)
    {
        int zero_acc = zero_init;
        mm_t m[Na];

        // Load the Na accumulators, highest quarter first.
        #pragma unroll
        for ( int l=0; l<Na; l++ ) {
            int incT = l == Na - 1 ? incT_1 : incT_0;
            acc_t tdm;
            using acc_p_t = decltype( tdm.extract<Va/4>( 0 ));
            tdm.insert( 3, acc_p_t( aie::load_v<Va/4>( pTdm2 + Va / 4 )));
            tdm.insert( 2, acc_p_t( aie::load_v<Va/4>( pTdm2 )));
            tdm.insert( 1, acc_p_t( aie::load_v<Va/4>( pTdm1 + Va / 4 )));
            tdm.insert( 0, acc_p_t( aie::load_v<Va/4>( pTdm1 )));
            m[l] = mm_t( aie::op_zero( tdm, zero_acc ));
            pTdm1 = byte_incr( pTdm1, incT );
            pTdm2 = byte_incr( pTdm2, incT );
        }

        int bound = params.inner_g;

        aie::vector<int8,64> s0, s1, x0, y0, y1;

        // Peeled first inner iteration: the only one that passes zero_acc to
        // the MAC; the loop below passes 0 in that position.
        s0 = fifo_ld_popx( pA, fA, params.step_align, 63 );
        s1 = fifo_ld_pop_3d_byte( pA, fA, dimsAI );

        y0 = aie::load_v<64>( pW );       pW = byte_incr( pW, params.incW );
        y1 = aie::load_v<64>( pW );
        pW = add_3d_byte( pW, dimsW_i );

       #if __AIE_MODEL_VERSION__ > 11300
        x0 = shuffle(s0, s1, params.shfl_0);
        m[0] = acc_t( mac_8x8_8x8( x0, ctrl.sign_A, y0, ctrl.sign_W, m[0].to_accum( ), zero_acc, is_sub, ctrl.is_in1_T ));
        m[1] = acc_t( mac_8x8_8x8( x0, ctrl.sign_A, y1, ctrl.sign_W, m[1].to_accum( ), zero_acc, is_sub, ctrl.is_in1_T ));

       #else
        if constexpr( has_in1_T_wa ) {
            y0 = shuffle( y0, ctrl.is_in1_T ? T8_8x8 : T512_1x2_lo );
            y1 = shuffle( y1, ctrl.is_in1_T ? T8_8x8 : T512_1x2_lo );
        }

        x0 = shuffle(s0, s1, params.shfl_0);
        m[0] = acc_t( mac_8x8_8x8_conf( x0, ctrl.sign_A, y0, ctrl.sign_W, m[0].to_accum( ), zero_acc, 0, is_sub, 0 ));
        m[1] = acc_t( mac_8x8_8x8_conf( x0, ctrl.sign_A, y1, ctrl.sign_W, m[1].to_accum( ), zero_acc, 0, is_sub, 0 ));
       #endif

        // Remaining inner iterations. The loop starts at 1 and carries
        // min_loop_count( 7 ), so the hint presumably relies on
        // params.inner_g >= 8 - TODO confirm against the parameter generator.
        [[ using chess: prepare_for_pipelining, min_loop_count( 7 )]]
        for (int i=1; i<bound; i++)
        {
            s0 = fifo_ld_popx( pA, fA, params.step_align, 63 );
            s1 = fifo_ld_pop_3d_byte( pA, fA, dimsAI );

            y0 = aie::load_v<64>( pW );       pW = byte_incr( pW, params.incW );
            y1 = aie::load_v<64>( pW );
            pW = add_3d_byte( pW, dimsW_i );

           #if __AIE_MODEL_VERSION__ > 11300
            x0 = shuffle(s0, s1, params.shfl_0);
            m[0] = acc_t( mac_8x8_8x8( x0, ctrl.sign_A, y0, ctrl.sign_W, m[0].to_accum( ), 0, is_sub, ctrl.is_in1_T ));
            m[1] = acc_t( mac_8x8_8x8( x0, ctrl.sign_A, y1, ctrl.sign_W, m[1].to_accum( ), 0, is_sub, ctrl.is_in1_T ));

           #else
            if constexpr( has_in1_T_wa ) {
                y0 = shuffle( y0, ctrl.is_in1_T ? T8_8x8 : T512_1x2_lo );
                y1 = shuffle( y1, ctrl.is_in1_T ? T8_8x8 : T512_1x2_lo );
            }

            x0 = shuffle(s0, s1, params.shfl_0);
            m[0] = acc_t( mac_8x8_8x8_conf( x0, ctrl.sign_A, y0, ctrl.sign_W, m[0].to_accum( ), 0, 0, is_sub, 0 ));
            m[1] = acc_t( mac_8x8_8x8_conf( x0, ctrl.sign_A, y1, ctrl.sign_W, m[1].to_accum( ), 0, 0, is_sub, 0 ));
           #endif
        }

        // Step the input pointer to the next outer position.
        pA = add_3d_byte( pA, dimsAO_i );

        // Write the accumulators back with the same quarter layout used for
        // loading: quarters 3/2 go to pOut2, quarters 1/0 go to pOut1.
        #pragma unroll
        for ( int l=0; l<Na; l++ ) {
            int incT = l == Na - 1 ? incT_1 : incT_0;
            aie::accum<acc32,64> a64 = m[l].to_accum( );
            aie::store_v( pOut2 + Va / 4, a64.extract<Va/4>( 3 ).to_vector<int32>( ));
            aie::store_v( pOut2,          a64.extract<Va/4>( 2 ).to_vector<int32>( ));
            aie::store_v( pOut1 + Va / 4, a64.extract<Va/4>( 1 ).to_vector<int32>( ));
            aie::store_v( pOut1,          a64.extract<Va/4>( 0 ).to_vector<int32>( ));
            pOut1 = byte_incr( pOut1, incT );
            pOut2 = byte_incr( pOut2, incT );
        }
    }
}


// Runs one pass of the int8x8 direct-convolution kernel: prepares operands for
// the mode selected by op_mode, calls direct_conv_int8x8_generic_base and, for
// a pure SUM pass on the final iteration, rearranges the result.
//
// Operand preparation:
//   - Default: input/weights/tdm1/tdm2 are used as given, with tdm byte
//     increments incT_0 / incT_1 of 128 / 128 (incT_0 becomes 0 when has_dwc
//     and ctrl.tdm_overwrite are both set).
//   - When has_conv or has_dwc is enabled and ctrl.is_conv is set without
//     ctrl.is_sum, params.conv.dimsAO is used as outer input stepping;
//     otherwise that stepping is all zeros.
//   - has_dwc and OP_DWC set: each 8-byte weight group is broadcast to a
//     64-byte vector and masked into weight_unpack, which then replaces the
//     weight pointer. The mask 0x8040201008040201 has bits 0, 9, 18, ..., 63
//     set, i.e. the diagonal of an 8x8 byte layout. If OP_SUM is set as well,
//     the first unpacked vector is overwritten with params.dwc.zp_wght under
//     the same mask, the weight stepping (num0, num1, inc2) is zeroed and
//     is_sub is set.
//   - Otherwise, has_sum and OP_SUM set: weight_unpack is filled with eight
//     64-byte vectors built from a 0xFF lane mask rotated up by 8 lanes per
//     vector. The operand roles are swapped (weight_unpack feeds the input
//     side, input feeds the weight side), the tdm increments become 0 / 256
//     and the second tdm pointer becomes tdm1 + 32.
//
// Post-processing (has_sum, final_tdm_iter, OP_SUM set and OP_DWC clear):
//   params.sum.loop groups of 8 int32 values are copied from tdm1 to an output
//   cursor stepped by params.sum.dimsSum. The cursor starts at tdm1, or at
//   weight_unpack when has_conv and ctrl.is_conv are set, in which case
//   sum_conv_2d_c1 is called afterwards with tdm1 as its second argument.
//
// Parameters:
//   input, weights   operand buffers.
//   weight_unpack    scratch buffer written by the DWC and SUM preparations.
//   tdm1, tdm2       accumulator buffers.
//   zero_init        forwarded to the base kernel.
//   final_tdm_iter   enables the SUM post-processing.
//   op_mode          bit set of DirectConvInt8x8::OP_* flags.
//   params           low-level parameter block.
template<bool has_dwc, bool has_conv, bool has_sum>
void direct_conv_int8x8_generic
(
        int8_t * input,
        int8_t * weights,
        int8_t * weight_unpack,
        int32_t * restrict tdm1,
        int32_t * restrict tdm2,
        bool zero_init,
        bool final_tdm_iter,
        int op_mode,
        const DirectConvInt8x8Generic::LowParams &params
)
chess_extra_options("mist2 +Opnll")
{
    constexpr bool has_any_conv = has_conv || has_dwc;

    DirectConvInt8x8Generic::BaseParams::Control ctrl = params.ctrl;
    int8_t * pA = input;
    int8_t * pW = weights;
    int32_t * pTdm2 = tdm2;
    dims_3d_param_s16 dimsAO = { 0, 0, 0, 0, 0 };
    dims_3d_param_s16 dimsW = params.dimsW;
    int incT_0 = has_dwc && ctrl.tdm_overwrite ? 0 : 128;
    int incT_1 = 128;
    bool is_sub = 0;

    if ( has_any_conv && ctrl.is_conv && !ctrl.is_sum ) {
        dimsAO = params.conv.dimsAO;
    }
    if ( has_dwc && (( op_mode & DirectConvInt8x8::OP_DWC ) != 0 )) {
        // Unpack depth-wise weights: one 8-byte group per 64-byte vector.
        unsigned wghts = params.dwc.weight_size;
        v8int8 * pI = (v8int8*) weights;
        v64int8 * restrict pO = (v64int8*) weight_unpack;
        pW = weight_unpack;
        #define AIE_ATTRIBUTES chess
        [[ using AIE_ATTRIBUTES: prepare_for_pipelining, min_loop_count( 2 ), no_hw_loop ]]
        for ( unsigned kc = 0; kc < wghts; kc++ )
        {
            aie::vector<int8, 64> w = broadcast_to_v64int8( *pI++ );
            *pO++ = aie::select( int8( 0 ), w, aie::mask<64>::from_uint64( 0x8040201008040201ull ));
        }
        if (( op_mode & DirectConvInt8x8::OP_SUM ) != 0 ) {
            // Overwrites the first unpacked vector and freezes the weight
            // pointer so the same vector is used for every step.
            aie::store_v( pW, aie::select( int8( 0 ), params.dwc.zp_wght, aie::mask<64>::from_uint64( 0x8040201008040201ull )));
            dimsW.num0 = 0;
            dimsW.num1 = 0;
            dimsW.inc2 = 0;
            is_sub = 1;
        }
    } else if ( has_sum && (( op_mode & DirectConvInt8x8::OP_SUM ) != 0 )) {
        // Pure SUM pass: generated selector vectors take the input role and
        // the real input takes the weight role.
        pA = weight_unpack;
        pW = input;
        unsigned wghts = 8;
        uint64_t mask = 0xFFull;
        aie::vector<int8,64> mask_v = aie::select( int8( 0 ), int8( 1 ), aie::mask<64>::from_uint64( mask ));
        for ( unsigned i=0; i < wghts ; i++ ) {
            aie::store_v( weight_unpack + 64*i, mask_v );
            mask_v = aie::shuffle_up_rotate( mask_v, 8 );
        }
        incT_0 = 0;
        incT_1 = 256;
        pTdm2 = tdm1 + 32;   // 32 int32 elements = 128 bytes past tdm1
    }

#ifdef PROC_OPT
    if (ctrl.is_in1_T) {
        direct_conv_int8x8_generic_base<1>(
                pA, pW, tdm1, pTdm2,
                zero_init, params, dimsAO, dimsW, incT_0, incT_1, 0, is_sub
                );
    } else {
        direct_conv_int8x8_generic_base<0>(
                pA, pW, tdm1, pTdm2,
                zero_init, params, dimsAO, dimsW, incT_0, incT_1, 0, is_sub
                );
    }
#else
    direct_conv_int8x8_generic_base(
        pA, pW, tdm1, pTdm2,
        zero_init, params, dimsAO, dimsW, incT_0, incT_1, 0, is_sub
    );
#endif

    // SUM post-processing only applies to a pure SUM pass (OP_SUM without
    // OP_DWC) on the final tdm iteration.
    bool sum_pp_cond = chess_copy( final_tdm_iter && (( op_mode & ( DirectConvInt8x8::OP_SUM | DirectConvInt8x8::OP_DWC )) == DirectConvInt8x8::OP_SUM ));
    if ( has_sum && sum_pp_cond ) {
        int32_t * pI = tdm1;
        int32_t * restrict pO = tdm1;
        dims_3d_t dims = params.sum.dimsSum.instantiate( );
        if ( has_conv && ctrl.is_conv ) {
            pO = ( int32_t * )weight_unpack;
        }
        int32_t * pS = pO;   // start of the rearranged data, kept for sum_conv_2d_c1
        unsigned bound = params.sum.loop;
        [[ using chess: no_hw_loop, min_loop_count( 1 ) ]]
        for ( unsigned pix = 0; pix < bound; pix++ )
        {
            aie::store_v( pO, aie::load_v<8>( pI ));
            pO = add_3d_byte( pO, dims );
            pI += 8;
        }
        if ( has_conv && ctrl.is_conv ) {
            sum_conv_2d_c1( pS, tdm1, params.conv_sum );
        }
    }
}



// Folds the input sums into the c0 coefficients: for each of param.N_g groups
// of 8 lanes it computes  c0 + sum * c3  in a 64-bit accumulator and stores
// the result, cast to int32, into the coeff buffer.
//
// vector_coeffs selects where c0 and c3 are read from:
//   < 0  : both come from the scalar arguments (broadcast into a local buffer)
//   == 0 : c0 is read from coeff, c3 comes from the scalar argument
//   > 0  : c0 is read from coeff and c3 from coeff + 128 bytes
// A pointer that reads from coeff is advanced by param.coeff_step bytes per
// group; a pointer into the local buffer stays put. The output pointer always
// advances by param.coeff_step bytes.
//
// has_vector_coeffs: when true the c3 factor is loaded as a vector through the
// c3 pointer; when false the scalar c3 argument is used in the MAC directly.
template<bool has_vector_coeffs>
void sum_to_c0( int64_t * coeff, int32_t * ifm_sum_1, int vector_coeffs, int64_t c0, int32_t c3, const DirectConvInt8x8Generic::SumToC0Params &param )
{
    // Local scratch: 8 x int64 copies of c0 followed by 8 x int32 copies of c3.
    alignas( 64 ) int64_t scalar_coeffs[12];
    *( v16int32* )scalar_coeffs = broadcast_to_v16int32(( mask64 ) c0 );
    aie::store_v(( int32_t * )( scalar_coeffs + 8 ), aie::broadcast<int32,8>( c3 ));

    const bool c0_from_coeff = vector_coeffs >= 0;
    const bool c3_from_coeff = vector_coeffs > 0;

    int64_t __aie_dm_resource_b * c0_ptr = (int64_t __aie_dm_resource_b *) ( c0_from_coeff ? coeff : scalar_coeffs );
    int32_t __aie_dm_resource_b * c3_ptr = (int32_t __aie_dm_resource_b *) ( c3_from_coeff ? byte_incr( coeff, 128 ) : scalar_coeffs + 8 );

    auto sum_it = aie::begin_vector<8>( ifm_sum_1 );
    int32_t __aie_dm_resource_b * restrict out_ptr = (int32_t __aie_dm_resource_b *) coeff;

    // Per-group byte steps; zero keeps a pointer parked on the scratch buffer.
    const auto c0_step  = param.coeff_step * c0_from_coeff;
    const auto c3_step  = param.coeff_step * c3_from_coeff;

    [[ using chess: no_hw_loop, min_loop_count( 1 ) ]]
    for ( unsigned grp=0; grp<param.N_g; grp++ )
    {
        v8acc64 c0_raw = *( v8acc64 __aie_dm_resource_b *)c0_ptr;
        aie::accum<acc64, 8> total = aie::accum<acc64, 8>( c0_raw );
        aie::vector<int32, 8> sums = *sum_it++;

        if constexpr( has_vector_coeffs ) {
            total = aie::mac( total, sums, aie::load_v<8, aie_dm_resource::b>( c3_ptr ));
        } else {
            total = aie::mac( total, sums, c3 );
        }

        aie::store_v( out_ptr, total.template cast_to<acc32>( ).template to_vector<int32>( ));

        c0_ptr  = byte_incr( c0_ptr,  c0_step );
        c3_ptr  = byte_incr( c3_ptr,  c3_step );
        out_ptr = byte_incr( out_ptr, param.coeff_step );
    }
}




template<bool has_dwc, bool has_conv, bool has_sum, bool has_vector_coeffs>
void direct_conv_int8x8_generic
(
        int8_t * input,
        int8_t * weights,
        int8_t * weight_unpack,
        int32_t * tdm1,
        int32_t * tdm2,
        int32_t * ifm_sum,
        int64_t * coeffs,
        int8_t * restrict output,
        bool zero_init,
        bool final_tdm_iter,
        int op_mode,
        int64_t qdq_c0,
        int32_t qdq_c1,
        int32_t qdq_c2,
        int32_t qdq_c3,
        int8_t shift_res,
        DirectConvInt8x8GenericKernelParams &params
) {
    void * prm_ptr = &params;
    DirectConvInt8x8Generic::LowParams &prm_conv = *( DirectConvInt8x8Generic::LowParams* ) prm_ptr;

    if (( op_mode & ( DirectConvInt8x8::OP_CONV | DirectConvInt8x8::OP_DWC )) != 0 ) {
        int conv_param_size = prm_conv.ctrl.is_conv ? sizeof( DirectConvInt8x8Generic::ConvParams ) : sizeof( DirectConvInt8x8Generic::BaseParams );
        prm_ptr = byte_incr( prm_ptr, ( op_mode & DirectConvInt8x8::OP_DWC ) != 0 ? sizeof( DirectConvInt8x8Generic::DwcParams ) : conv_param_size );
        direct_conv_int8x8_generic<has_dwc, has_conv, has_sum>( input, weights, weight_unpack, tdm1, tdm2, zero_init, final_tdm_iter, op_mode & ( DirectConvInt8x8::OP_CONV | DirectConvInt8x8::OP_DWC ), prm_conv );
    }

    if ( has_sum && (( op_mode & ( DirectConvInt8x8::OP_SUM | DirectConvInt8x8::OP_DWC )) == DirectConvInt8x8::OP_SUM )) {
        DirectConvInt8x8Generic::LowParams &prm_sum  = *( DirectConvInt8x8Generic::LowParams* ) prm_ptr;
        prm_ptr = byte_incr( prm_ptr, prm_sum.ctrl.is_conv ? sizeof( DirectConvInt8x8Generic::ConvSumParams ) : sizeof( DirectConvInt8x8Generic::SumParams ));
        direct_conv_int8x8_generic<has_dwc, has_conv, has_sum>( input, nullptr, weight_unpack, ifm_sum, nullptr, zero_init, final_tdm_iter,  DirectConvInt8x8::OP_SUM, prm_sum );

    } else if ( has_dwc && (( op_mode & ( DirectConvInt8x8::OP_SUM | DirectConvInt8x8::OP_DWC )) == ( DirectConvInt8x8::OP_SUM | DirectConvInt8x8::OP_DWC ))) {
        direct_conv_int8x8_generic<has_dwc, has_conv, has_sum>( input, weights, weight_unpack, tdm1, tdm2, 0, final_tdm_iter, DirectConvInt8x8::OP_SUM | DirectConvInt8x8::OP_DWC, prm_conv );
    }
    
    if ( has_sum && (( op_mode & DirectConvInt8x8::OP_SUM_2 ) != 0 )) {
        DirectConvInt8x8Generic::LowParams &prm_sum_1  = *( DirectConvInt8x8Generic::LowParams* ) prm_ptr;
        prm_ptr = byte_incr( prm_ptr, sizeof( DirectConvInt8x8Generic::SumParams ));
        DirectConvInt8x8Generic::SumToC0Params &prm_sum_to_c0  = *( DirectConvInt8x8Generic::SumToC0Params* ) prm_ptr;
        prm_ptr = byte_incr( prm_ptr, sizeof( DirectConvInt8x8Generic::SumToC0Params ));
        int32_t * ifm_sum_1 = byte_incr( ifm_sum, prm_sum_to_c0.offset );
        direct_conv_int8x8_generic<has_dwc, has_conv, has_sum>( weights, nullptr, weight_unpack, ifm_sum_1, nullptr, zero_init, final_tdm_iter, DirectConvInt8x8::OP_SUM, prm_sum_1 );

        if ( final_tdm_iter ) {
            DirectConvInt8x8Generic::QDQParams &prm_qdq  = *( DirectConvInt8x8Generic::QDQParams* ) prm_ptr;
            sum_to_c0<has_vector_coeffs>( coeffs, ifm_sum_1, prm_qdq.vector_coeffs, qdq_c0, qdq_c3, prm_sum_to_c0 );
        }   
    }

    if ( final_tdm_iter && (( op_mode & DirectConvInt8x8::OP_QDQ ) != 0 )) {
        DirectConvInt8x8Generic::QDQParams &prm_qdq  = *( DirectConvInt8x8Generic::QDQParams* ) prm_ptr;
        qdq<has_sum, has_vector_coeffs>( tdm1, tdm2, ifm_sum, coeffs, qdq_c1, qdq_c2, shift_res, output, prm_qdq );
    }
}
#endif // __DIRECT_CONV_INT8X8_GENERIC_IMPL_HPP__
