/*
    Copyright (C) 2019 - 2022 Xilinx, Inc. All rights reserved.
    Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.
    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.
    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.
    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.
    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */
#ifndef __DIRECT_CONV_INT16X8_GENERIC_KERNEL_C__
#define __DIRECT_CONV_INT16X8_GENERIC_KERNEL_C__

#include "direct_conv_int16x8_generic_kernel.h"
#include "direct_conv_int16x8_generic_params.h"
#include "direct_conv_int16x8_generic_template.h"
#include "qdq/qdq_helpers.hpp"
#include "qdq/qdq.cc"

#ifndef INLINE

// Accumulates Ky*Kx loaded int32 vectors per output group and stores N (= 8)
// int32 results per group to `out`, for Y*X/N groups in total.
//
//   in      - source buffer, read through the FIFO load helpers.
//   out     - destination; advanced by N int32 after every group.
//   Y, X    - the number of groups processed is Y * X / N.
//   Ky, Kx  - Ky * Kx vectors are summed per group (at least one is always read).
//   Sx      - selects the shuffle mode (Sx > 1 -> T32_16x2_lo, else T512_1x2_lo)
//             and scales the step between consecutive groups.
//   stepY   - byte step used for the outer wrap of the group pointer.
//   stepKy  - byte step used for the outer wrap of the tap pointer.
__aie_inline void sum_conv_2d_c1( int32 * in, int32 * restrict out, int Y, int X, int Ky, int Kx, int Sx, int stepY, int stepKy ) {
    constexpr unsigned N = 8;
    constexpr unsigned Nop = std::max( 16u, N );

    v16int32 * src = ( v16int32 * ) in;
    int32 * dst = out;

    fifo_state_t fifo;
    fifo.pos = 0;
    addr_t tapCnt = 0;
    addr_t grpCnt = 0;

    // Increments for the tap walk (inner: Kx, outer: Ky) ...
    int wrap = -64;
    int incKx = add_dimension( wrap, Kx-1, 4 );
    int incKy = stepKy + wrap;
    // ... and for the group walk (inner: X/N, outer: Y).
    wrap = -Ky * stepKy;
    int incX = add_dimension( wrap, X/N-1, 4*N*Sx );
    int incY = stepY + wrap;

    // The shuffle mode does not change inside the loops.
    const auto shflMode = Sx > 1 ? T32_16x2_lo : T512_1x2_lo;

    for ( unsigned grp = 0; grp < Y * X / N; grp++ )
        chess_no_hw_loop
    {
        aie::vector<int32,Nop> acc = aie::zeros<int32,Nop>( );

        // The first tap is loaded ahead of the loop; every iteration then
        // accumulates the previous tap and fetches the next one.
        fifo_ld_fill( src, fifo );
        aie::vector<int32,Nop> tap = fifo_ld_pop_2d_byte( src, fifo, incKy, Kx-1, tapCnt, incKx );
        for ( unsigned t = 1; t < Ky * Kx; t++ )
            chess_no_hw_loop
        {
            acc = aie::add( acc, aie::vector( shuffle( tap, shflMode )));
            fifo_ld_fill( src, fifo );
            tap = fifo_ld_pop_2d_byte( src, fifo, incKy, Kx-1, tapCnt, incKx );
        }
        // Accumulate the last tap that is still pending.
        acc = aie::add( acc, aie::vector( shuffle( tap, shflMode )));

        aie::store_v( dst, acc.extract<N>( 0 ));
        dst += N;
        src = add_2d_byte( src, incY, X/N-1, grpCnt, incX );
    }
}

// Describes a pointer walk with up to three nested levels as (size, step)
// pairs plus an outermost step. The values are forwarded unchanged to
// dims_3d_from_steps() by to_dims().
struct Steps3D {
    int size1;  // element count of the innermost level
    int step1;  // step of the innermost level
    int size2;  // element count of the middle level
    int step2;  // step of the middle level
    int step3;  // step of the outermost level

    // Single-level walk advancing by aie::vector_decl_align.
    Steps3D( ) : size1(1), step1(0), size2(1), step2(0), step3(aie::vector_decl_align) {}

    // Single-level walk advancing by `step`.
    Steps3D( int step ) : size1(1), step1(0), size2(1), step2(0), step3(step) {}

    // Two-level walk: `size1` elements `step1` apart, then `step2`. The
    // innermost level is left degenerate (size 1, step 0) and the pattern is
    // stored in the middle/outer slots.
    // Fix: size2 was previously initialised from `step1`, which left the
    // `size1` argument unused and turned a step into an element count.
    Steps3D( int size1, int step1, int step2  ) : size1(1), step1(0), size2(size1), step2(step1), step3(step2) {}

    // Full three-level walk.
    Steps3D( int size1, int step1, int size2, int step2, int step3 ) : size1(size1), step1(step1), size2(size2), step2(step2), step3(step3) {}

    // Converts the description into the dims_3d_t used by add_3d_byte().
    inline dims_3d_t to_dims() {
        return dims_3d_from_steps( size1, step1, size2, step2, step3 );
    }
};

//__attribute__((noinline))
// Two-stage shuffle of groups of four v16int32 vectors.
//
// Per block, four vectors are read from `in` (advancing with stepsIn), paired
// and shuffled with mode1 / mode1+1, then paired again and shuffled with
// mode2 / mode2+1 before four vectors are written to `out` (advancing with
// stepsOut).
//
//   blocks      - number of four-vector groups to process.
//   mode1       - shuffle mode of the first stage.
//   is_2k_mode  - when false, mode2 is ignored and T512_1x2_lo is used.
//   mode2       - shuffle mode of the second stage (only if is_2k_mode).
void transpose_2k( void * in, Steps3D stepsIn, void * restrict out, Steps3D stepsOut, int blocks, int mode1, bool is_2k_mode=0, int mode2=T512_1x2_lo ) clobbers_not( p2, p3, p4, p5, p6, RS16 ) {
    dims_3d_t rdDims = stepsIn.to_dims();
    dims_3d_t wrDims = stepsOut.to_dims();
    v16int32 * src = ( v16int32 * ) in;
    v16int32 * dst = ( v16int32 * ) out;

    if ( !is_2k_mode ) {
        mode2 = T512_1x2_lo;
    }
    const int mode1Next = mode1 + 1;
    const int mode2Next = mode2 + 1;

    for ( unsigned blk = 0; blk < blocks; blk++ ) {
        // Gather the four source vectors of this block.
        v16int32 r0 = *src;
        src = add_3d_byte( src, rdDims );
        v16int32 r1 = *src;
        src = add_3d_byte( src, rdDims );
        v16int32 r2 = *src;
        src = add_3d_byte( src, rdDims );
        v16int32 r3 = *src;
        src = add_3d_byte( src, rdDims );

        // Stage 1: combine (r0, r1) and (r2, r3).
        v16int32 s01a = shuffle( r0, r1, mode1 );
        v16int32 s01b = shuffle( r0, r1, mode1Next );
        v16int32 s23a = shuffle( r2, r3, mode1 );
        v16int32 s23b = shuffle( r2, r3, mode1Next );

        // Stage 2: combine the stage-1 results and emit them.
        *dst = shuffle( s01a, s23a, mode2 );
        dst = add_3d_byte( dst, wrDims );
        *dst = shuffle( s01b, s23b, mode2 );
        dst = add_3d_byte( dst, wrDims );
        *dst = shuffle( s01a, s23a, mode2Next );
        dst = add_3d_byte( dst, wrDims );
        *dst = shuffle( s01b, s23b, mode2Next );
        dst = add_3d_byte( dst, wrDims );
    }
}
void transpose_2k( void * in, void * restrict out, int blocks, int mode1, bool is_2k_mode=0, int mode2=T512_1x2_lo ) {
    transpose_2k( in, Steps3D(), out, Steps3D( 2, 64+64*is_2k_mode, 2, 128-64*is_2k_mode, 256 ), blocks, mode1, is_2k_mode, mode2 );
}


// Core kernel entry point: prepares parameters, address increments and (for
// some modes) an unpacked weight buffer, then runs
// direct_conv_int16x8_generic_template<KC_TDM32, KC_TDM32>. For OP_SUM /
// OP_SUM_T on the final TDM iteration the accumulator contents are
// additionally reshuffled afterwards (in place, or into tmp_buf).
//
// Parameters:
//   input          - activations, forwarded to the template as int16*.
//   weights        - packed int8 weights; used directly unless the selected
//                    mode switches to weight_unpack.
//   weight_unpack  - scratch buffer that is written here in the DWC and SUM
//                    modes and then used as the weight pointer.
//   tdm1, tdm2     - accumulator buffers; their byte distance is passed to
//                    setup_parameters_dc_int16x8_generic.
//   tmp_buf        - scratch used by the sum post-processing when Yi_g > 1.
//   params_s       - kernel parameters; copied locally, the argument itself is
//                    not modified.
//   zero_init      - copied into ctrl.zero_init.
//   final_tdm_iter - enables the OP_SUM / OP_SUM_T post-processing.
//   op_mode_s      - overrides params_s.op_mode unless it is OP_NONE.
void __attribute__((noinline)) direct_conv_int16x8_generic
(
        int16_t * input,
        int8_t * weights,
        int8_t * weight_unpack,
        int32_t * tdm1,
        int32_t * tdm2,
        int32_t * tmp_buf,
        const KernelParams &params_s,
        bool zero_init=1,
        bool final_tdm_iter=0,
        int op_mode_s = OP_NONE
) 
chess_extra_options("mist2 +Opnll")
{
    Increments incrs;
    // Work on a private copy so the mode-specific overrides below stay local.
    KernelParams params = params_s;
    params.op_mode = op_mode_s == OP_NONE ? params.op_mode : op_mode_s;
    // Without HAS_ANY_CONV the stride is never evaluated and treated as 1.
    #ifdef DIRECT_CONV_INT16X8_GENERIC_HAS_ANY_CONV
    bool no_stride = params.S_g == 1;
    #else
    bool no_stride = 1;
    #endif
    params.N_g = 1; //This is currently the only mode supported

    // Third argument is the byte distance between the two accumulator buffers.
    setup_parameters_dc_int16x8_generic( params, incrs, ( long ) tdm2 - ( long ) tdm1 );

    KernelControl ctrl = params.ctrl;
    ctrl.zero_init = zero_init;
    int op_mode = params.op_mode;
    int8_t * pW = weights;
    // With a stride both accumulator pointers start at tdm1.
    int32_t * pTdm2 = no_stride ? tdm2 : tdm1;
    // NOTE(review): these three are only assigned on the OP_SUM / OP_SUM_T path
    // and only read by the post-processing that is guarded by the same modes.
    int Xi_g, Yi_g, Xi_block_words;
    int shift_tdm_in = params.shift_tdm;
    int shift_tdm_out = params.shift_tdm;
    #ifdef DIRECT_CONV_INT16X8_GENERIC_HAS_DWC
    if ( chess_copy( op_mode == OP_DWC || op_mode == OP_DWC_SUM )) {
        // Depth-wise modes: every group of 8 packed weights is broadcast to a
        // 64-byte vector and combined with zero through the mask
        // 0x8040201008040201 (bits 0, 9, 18, ..., 63). Presumably the weights
        // survive only in the lanes whose mask bit is set - TODO confirm the
        // operand order of aie::select.
        unsigned wghts = params.inner_g * params.Co_g;
        v8int8 * pI = (v8int8*) weights;
        v64int8 * restrict pO = (v64int8*) weight_unpack;
        pW = weight_unpack;
        for ( unsigned kc = 0; kc < wghts; kc++ )
            chess_prepare_for_pipelining
            chess_loop_range( 2, )
            chess_no_hw_loop
        {
            aie::vector<int8, 64> w = broadcast_to_v64int8( *pI++ );
            *pO++ = aie::select( int8( 0 ), w, aie::mask<64>::from_uint64( 0x8040201008040201ull ));
        }
        if ( op_mode == OP_DWC_SUM ) {
            // Overwrite the first unpacked vector with the weight zero point in
            // the same lane pattern and zero the weight walk (num1, num2,
            // inc3) - presumably so this one vector is reused throughout.
            aie::store_v( pW, aie::select( int8( 0 ), params.zp_wght, aie::mask<64>::from_uint64( 0x8040201008040201ull )));
            incrs.dimsW.num1 = 0;
            incrs.dimsW.num2 = 0;
            incrs.dimsW.inc3 = 0;
        }
    } else
    #endif
    #ifdef DIRECT_CONV_INT16X8_GENERIC_HAS_ASYM
    if ( op_mode == OP_SUM || op_mode == OP_SUM_T ) {
        // Sum modes: no accumulator shift, weights come from weight_unpack
        // (filled with 0/1 mask vectors at the end of this branch).
        shift_tdm_in = 0;
        shift_tdm_out = 0;
        pW = weight_unpack;
        bool is_transpose = op_mode == OP_SUM_T;
    // NOTE(review): repeats the assignment made two lines above; redundant.
    pW = weight_unpack;
        if ( !is_transpose ) {
            incrs.shfl_0 = T512_1x2_lo;
            incrs.shfl_1 = T512_1x2_hi;

            #ifdef DIRECT_CONV_INT16X8_GENERIC_HAS_ANY_CONV
                //Size requirement for ifm sum buffer:
                //size = ceil( Xi_g * Yi_g, 64 ) * 4
                //with
                Xi_g = params.X_g + ( params.Kx_g > 1 );
                Yi_g = ( params.Y_g - 1 ) * params.S_g + params.Ky_g;
                //Xi_g here is simplified under the assumption that we don't process more than a Kx=8 filter.
                //Xo = 8 * X_g / Sx
                //placing this into the generic formula for Xi:
                //Xi = ( Xo - 1 ) * Sx + Kx = ( 8 * X_g / Sx - 1 ) * Sx + Kx = 8 * X_g - Sx + Kx
                //Calculating the granularity again:
                //Xi_g = ceil( Xi / 8 ) = ceil(( 8 * X_g - Sx + Kx ) / 8 ) = X_g + ceil(( Kx - Sx ) / 8 )

                // Input walk: Ci_g * Yi_g steps of step_Ci, then Xi_g steps of incA_0.
                int reset = -incrs.incA_0;
                int incCi = add_dimension( reset, params.Ci_g * Yi_g - 1, params.step_Ci );
                int incXi = add_dimension( reset, Xi_g-1, incrs.incA_0 );

                incrs.dimsAI = dims_3d_t( params.Ci_g * Yi_g - 1, incCi, Xi_g-1, incXi, reset );
            #else
                Xi_g = params.X_g;
                Yi_g = 1;//params.Y_g;

                // Input walk: Ci_g steps of step_Ci, then Xi_g steps of incA_0.
                int reset = 64 - incrs.incA_0;
                int incCi = add_dimension( reset, params.Ci_g-1, params.step_Ci );
                int incXi = add_dimension( reset, Xi_g-1, incrs.incA_0 );

                incrs.dimsAI = dims_3d_t( params.Ci_g-1, incCi, Xi_g-1, incXi, reset );
            #endif

            // Replace the remaining walks set up by setup_parameters_dc_int16x8_generic.
            incrs.dimsAO.num1 = 0;
            incrs.dimsAO.num2 = 0;
            incrs.dimsAO.inc3 = 0;
            incrs.dimsO.num1 = 0;
            incrs.dimsO.num2 = 0;
            incrs.dimsO.inc3 = 128;
            incrs.dimsW = dims_3d_t( params.Ci_g-1, 0, 7, 64, -7*64 );
            incrs.incT_0 = 256;
            incrs.incT_1 = 256;
            // Second accumulator pointer sits 128 bytes after tdm1.
            pTdm2 = byte_incr( tdm1, 128 );

            // Performance optimization: Reduced inner dimension for small pixel problems.
            // yields: 0, 1, 2, 3 for blocks of 1, 2, 4, 8 words (for 1..2, 3..4, 5..8, 9..16 pixels)
            Xi_block_words = std::min( 3u, ( 31 - clb( Xi_g * Yi_g * params.N_g - 1 )));
            int inner_g = params.Ci_g << Xi_block_words;
            // Fall back to a block size derived from Ci_g when the inner dimension would drop below 8.
            if ( inner_g < 8 ) {
                Xi_block_words = clb( params.Ci_g ) - 28;
                inner_g = params.Ci_g << Xi_block_words;
            }
            params.inner_g = inner_g;
            params.outer_g = std::max( 2, ( Xi_g * Yi_g * params.N_g + 7 ) / 8 );
        } else {
            incrs.shfl_0 = T16_8x8_lo;
            incrs.shfl_1 = T512_1x2_lo;

            Xi_g = params.X_g;
            Yi_g = 1;//params.Y_g;

            #ifdef DIRECT_CONV_INT16X8_GENERIC_HAS_ANY_CONV
             int reset  = -128;
            #else
             int reset  = -64;
            #endif
            incrs.step_align = 0;
            // Input walk: 2 steps of 8 bytes, then Ci_g * Xi_g steps of 128 bytes.
            int incCi1 = add_dimension( reset, 1, 8 );
            int incCi2 = add_dimension( reset, params.Ci_g * Xi_g - 1, 128 );

            incrs.dimsAI = dims_3d_t( 1, incCi1, params.Ci_g * Xi_g - 1, incCi2, reset );
            incrs.dimsAO.num1 = 0;
            incrs.dimsAO.num2 = 0;
            incrs.dimsAO.inc3 = 0;
            incrs.dimsO.num1 = 0;
            incrs.dimsO.num2 = 0;
            incrs.dimsO.inc3 = 128;
            incrs.dimsW = dims_3d_from_steps( 2, 64, Xi_g, 0, 128 );
            incrs.incT_0 = 128;
            incrs.incT_1 = 128;
            pTdm2 = tdm1;

            //TODO fix me - this can be performance optimized
            Xi_block_words = 3;
            params.inner_g = Xi_g * 8;
            params.outer_g = std::max( 2, ( params.Ci_g + 3 ) / 4 );
        }

        // Fill weight_unpack with 64-byte 0/1 mask vectors: the first one is
        // built from the bit mask 0x0101...01, each following one is the
        // previous vector passed through aie::shuffle_up_rotate( ..., 1 ).
        unsigned wghts = is_transpose ? 8 * params.outer_g : 8;
        uint64_t mask = 0x0101010101010101ull;
        aie::vector<int8,64> mask_v = aie::select( int8( 0 ), int8( 1 ), aie::mask<64>::from_uint64( mask ));
        for ( unsigned i=0; i<wghts; i++ ) {
            aie::store_v( pW + 64*i, mask_v );
            mask_v = aie::shuffle_up_rotate( mask_v, 1 );
        }
    #else
    {
    #endif
    }

    // Main computation with 32-bit accumulator input and output.
    direct_conv_int16x8_generic_template<KC_TDM32, KC_TDM32> (
        (int16*)input, pW, tdm1, pTdm2, (int16*)0,
        params, ctrl, incrs, shift_tdm_in, shift_tdm_out
    );

    #ifdef DIRECT_CONV_INT16X8_GENERIC_HAS_ASYM
    // Post-processing of the sum result, only on the final TDM iteration.
    bool sum_pp_cond = chess_copy( final_tdm_iter && ( op_mode == OP_SUM || op_mode == OP_SUM_T ));
    if ( sum_pp_cond ) {
        v16int32 * pI = ( v16int32 * ) tdm1;
        #if defined( DIRECT_CONV_INT16X8_GENERIC_HAS_ANY_CONV )
        // fix large remainder address issue. This introduces gaps which makes the buffer size a bit larger if Y exists
        // pI requirement (bytes): max( 16, ceiling( Xi_g * Yi_g, 8 )) * 32
        // pO requirement (bytes): ceiling( max( 16, ceiling( Xi_g * Yi_g, 8 )), Yi_g ) * 32
        // Grow Xi_all until Xi_all * Yi_g covers the 8 * outer_g words written below.
        int Xi_all = Xi_g;
        while( chess_copy( Xi_all * Yi_g < 8 * params.outer_g )) chess_no_hw_loop {
            Xi_all++;
        }

        // alignas( aie::vector_decl_align ) int32 tmp[480];
        // assert( Xi_all * Yi_g * 32 <= sizeof( tmp ));
        // The reshuffled words go to tmp_buf when a second pass over Y follows, otherwise back into tdm1.
        v8int32 * pO = ( v8int32 * )( Yi_g > 1 ? tmp_buf : tdm1 );
        int step_Ky = Xi_all * 32;
        dims_3d_t dimsO = dims_3d_from_steps( Yi_g, step_Ky, Xi_all, 32, Yi_g * step_Ky );
        #endif
        // OP_SUM uses the "2k" shuffle modes; OP_SUM_T uses the other pair and
        // swaps which of the two middle vectors are paired (see a1 / a2 below).
        bool is_2k_mode = op_mode != OP_SUM_T;
        int mode1 = is_2k_mode ? T128_4x2_lo : T512_1x2_lo;
        int mode2 = is_2k_mode ? T32_8x4_lo : T32_4x8_lo;
        for ( unsigned pix = 0; pix < params.outer_g; pix++ )
            chess_no_hw_loop
        {
            // Two shuffle stages over four consecutive v16int32 vectors.
            v16int32 a0 = pI[4*pix];
            v16int32 a1 = pI[4*pix+2-is_2k_mode];
            v16int32 a2 = pI[4*pix+1+is_2k_mode];
            v16int32 a3 = pI[4*pix+3];
            v16int32 b0 = shuffle( a0, a1, mode1 );
            v16int32 b1 = shuffle( a2, a3, mode1 );
            v16int32 b2 = shuffle( a0, a1, mode1 + 1 );
            v16int32 b3 = shuffle( a2, a3, mode1 + 1 );
            v16int32 c0 = shuffle( b0, b1, mode2 );
            v16int32 c1 = shuffle( b0, b1, mode2 + 1 );
            v16int32 c2 = shuffle( b2, b3, mode2 );
            v16int32 c3 = shuffle( b2, b3, mode2 + 1 );
            #if defined( DIRECT_CONV_INT16X8_GENERIC_HAS_ANY_CONV )
            // Scatter the eight 8-word halves through the dimsO walk.
            *pO = extract_v8int32( c0, 0 );     pO = add_3d_byte( pO, dimsO );
            *pO = extract_v8int32( c0, 1 );     pO = add_3d_byte( pO, dimsO );
            *pO = extract_v8int32( c1, 0 );     pO = add_3d_byte( pO, dimsO );
            *pO = extract_v8int32( c1, 1 );     pO = add_3d_byte( pO, dimsO );
            *pO = extract_v8int32( c2, 0 );     pO = add_3d_byte( pO, dimsO );
            *pO = extract_v8int32( c2, 1 );     pO = add_3d_byte( pO, dimsO );
            *pO = extract_v8int32( c3, 0 );     pO = add_3d_byte( pO, dimsO );
            *pO = extract_v8int32( c3, 1 );     pO = add_3d_byte( pO, dimsO );
            #else
            // Without generic convolution support the result is written back in place.
            pI[4*pix]   = c0;
            pI[4*pix+1] = c1;
            pI[4*pix+2] = c2;
            pI[4*pix+3] = c3;
            #endif
        }
        // Reduced block size: move 2^Xi_block_words v8int32 words down by 8
        // positions, starting at index 2^Xi_block_words of the result buffer.
        if ( Xi_block_words < 3 ) {
            v8int32 * p = ( v8int32 * )( Yi_g > 1 ? tmp_buf : tdm1 ) + ( 1 << Xi_block_words );
            for ( int i = 0; i < ( 1 << Xi_block_words ); i++ ) chess_no_hw_loop {
                p[i] = p[i+8];
            }
        }
        #if defined( DIRECT_CONV_INT16X8_GENERIC_HAS_ANY_CONV ) && !defined( DIRECT_CONV_INT16X8_GENERIC_HAS_SLOW_CONV_SUM )
        // Second pass: sum the per-position values in tmp_buf over the kernel window into tdm1.
        if ( Yi_g > 1) {
            //int Sy = params.step_Yi / params.step_Ky;
            // Difference of leading-bit counts; presumably log2( step_Yi / step_Ky ) - TODO confirm.
            int Sy_shft = clb( params.step_Ky ) - clb( params.step_Yi );
            sum_conv_2d_c1( tmp_buf, tdm1, params.Y_g, params.X_g * 4 << no_stride, params.Ky_g, params.Kx_g, params.S_g, step_Ky << Sy_shft, step_Ky );
        }
        #endif
    }
    #endif
}



// Mode dispatcher around the core kernel: depending on the operation mode it
// runs the convolution pass, the input-sum pass (into ifm_sum) and, on the
// final TDM iteration of the *_SYM / *_ASYM modes, the QDQ output stage.
//
//   input, weights, weight_unpack, tdm1, tdm2, tmp_buf
//                   - forwarded to the core kernel.
//   ifm_sum         - accumulator of the sum pass and sum input of QDQ.
//   qdq_c0          - QDQ coefficient buffer (read as int32 when Vec_coeffs > 1).
//   qdq_c1, qdq_c2  - scalar QDQ coefficients (unused when Vec_coeffs > 1).
//   output          - int16 destination of the QDQ stage.
//   params          - kernel parameters.
//   op_mode_s       - operation mode; OP_NONE selects params.op_mode.
//   zero_init       - forwarded to the core kernel (forced to 0 for OP_DWC_ASYM's sum pass).
//   final_tdm_iter  - forwarded to the core kernel and gates the QDQ stage.
//   qdq_linear      - selects the linear accumulator layout for QDQ.
//   Vec_coeffs      - > 1 selects the vector-coefficient qdq variant.
void __attribute__((noinline)) direct_conv_int16x8_generic
(
        int16_t * input,
        int8_t * weights,
        int8_t * weight_unpack,
        int32_t * tdm1,
        int32_t * tdm2,
        int32_t * ifm_sum,
        int32_t * tmp_buf,
        int64_t * qdq_c0,
        int32_t qdq_c1,
        int32_t qdq_c2,
        int16_t * restrict output,
        const KernelParams &params,
        int op_mode_s = OP_NONE,
        bool zero_init=1,
        bool final_tdm_iter=1,
        bool qdq_linear=0,
        int Vec_coeffs=1
) {
    // Resolve the effective mode first; everything below keys off it.
    if ( op_mode_s == OP_NONE ) {
        op_mode_s = params.op_mode;
    }
    OPMode op_mode = *(OPMode*)&op_mode_s;
    int ofm_size = params.X_g * params.Y_g * params.Co_g * params.N_g * 64;
    bool sum_op = chess_copy( op_mode == OP_SUM || op_mode == OP_SUM_T || op_mode == OP_ASYM );

    // Convolution pass: skipped for the pure sum modes and for OP_QDQ.
    const bool skip_conv = chess_copy( sum_op || op_mode == OP_QDQ );
    if ( !skip_conv ) {
        #ifdef DIRECT_CONV_INT16X8_GENERIC_HAS_DWC
        OPMode op_conv = ( op_mode >= OP_DWC_SYM ? OP_DWC : OP_CONV );
        #else
        OPMode op_conv = OP_CONV;
        #endif
        // Combined modes (>= OP_CONV_SYM) are reduced to their plain convolution mode.
        const int conv_mode = op_mode >= OP_CONV_SYM ? op_conv : op_mode;
        direct_conv_int16x8_generic( input, weights, weight_unpack, tdm1, tdm2, tmp_buf, params, zero_init, final_tdm_iter, conv_mode );
    }

    #ifdef DIRECT_CONV_INT16X8_GENERIC_HAS_ASYM
    // Input-sum pass for the asymmetric modes.
    if ( chess_copy( sum_op || op_mode == OP_CONV_ASYM )) {
        const int sum_mode = op_mode >= OP_CONV_ASYM ? OP_SUM : op_mode;
        direct_conv_int16x8_generic( input, weights, weight_unpack, ifm_sum, ifm_sum, tmp_buf, params, zero_init, final_tdm_iter, sum_mode );
    }
    #ifdef DIRECT_CONV_INT16X8_GENERIC_HAS_DWC
    else if ( op_mode == OP_DWC_ASYM ) {
        // Accumulates onto tdm1/tdm2 without zero initialisation.
        direct_conv_int16x8_generic( input, weights, weight_unpack, tdm1, tdm2, tmp_buf, params, 0, final_tdm_iter, OP_DWC_SUM );
    }
    #endif
    #endif

    #ifdef DIRECT_CONV_INT16X8_GENERIC_HAS_QDQ
    if ( chess_copy( final_tdm_iter && op_mode >= OP_CONV_SYM )) {
        int ifm_sum_len = params.X_g * params.Y_g * params.N_g * 8;
        adf::output_buffer<int16> ofm_buf(( int16* )output, ofm_size, 0, ofm_size );
        adf::input_buffer<int32> ifm_sum_buf( ifm_sum, ifm_sum_len, 0, ifm_sum_len );

        // Distance between the two accumulator buffers in 32-bit words.
        const long tdm_gap_words = (( long )tdm2 - ( long )tdm1 ) / 4;

        QDQParams qp;
        qp.shift_Qout = params.shift_res;
        qp.Y_g = params.Y_g;
        qp.N_g = params.Co_g;
        qp.sign_out = params.ctrl.sign_O;

        // Settings shared by all accumulator layouts.
        qp.wrap2 = 1;
        qp.wrap3 = 1;
        qp.step2 = 0;
        qp.step3 = 0;
        qp.step4 = 32;

        if ( qdq_linear ) {
            qp.M_g = 2 * params.X_g;
            qp.wrap0 = 1;
            qp.wrap1 = 1;
            qp.step0 = tdm_gap_words;
            qp.step1 = 32 * qp.N_g;
        }
        #ifdef DIRECT_CONV_INT16X8_GENERIC_HAS_ANY_CONV
        else if ( params.S_g != 1 ) {
            //'Tdm_S2': 'C2>1YXNC<1(XN)4C8',    tile: (XN)4C8, D0: YXN, D1: C2>1, D2: C<1
            qp.M_g = params.X_g;
            qp.wrap0 = params.X_g * params.Y_g;
            qp.wrap1 = 2;
            qp.step0 = 16 * qp.N_g;
            qp.step1 = tdm_gap_words;
        }
        #endif
        else {
            //'Tdm':  '(XN)2>1YX<1N<1C(XN)4C8',   tile: (XN)4C8, D0: (XN)2>1, D1: YX<1N<1, D2: C
            qp.M_g = 2 * params.X_g;
            qp.wrap0 = 2;
            qp.wrap1 = params.X_g * params.Y_g;
            qp.step0 = tdm_gap_words;
            qp.step1 = 32 * qp.N_g;
        }

        #ifdef DIRECT_CONV_INT16X8_GENERIC_HAS_ASYM
        if ( Vec_coeffs > 1 ) {
            // Vector coefficients: qdq_c0 is reinterpreted as int32 data.
            adf::input_buffer<int32> coeffs( (int32_t*)qdq_c0, params.Co_g * 8, 0, params.Co_g * 8 );
            qdq<int32, int32, int64, int16, 4,8,4,8,3>( tdm1, ifm_sum_buf, coeffs, ofm_buf, qp );
        } else {
            adf::input_buffer<int64> coeffs( qdq_c0, params.Co_g * 8, 0, params.Co_g * 8 );
            qdq<4,8,4,8,3,2>( tdm1, ifm_sum_buf, coeffs, qdq_c1, qdq_c2, ofm_buf, qp );
        }
        #else
        adf::input_buffer<int64> coeffs( qdq_c0, params.Co_g * 8, 0, params.Co_g * 8 );
        qdq<4,8,4,8,2>( tdm1, ifm_sum_buf, coeffs, qdq_c1, qdq_c2, ofm_buf, qp );
        #endif
    }
    #endif
}


// This function is a wrapper to interface the generic testbench used in this example.
void __attribute__((noinline)) direct_conv_int16x8_generic
(
        int * input,
        int * weights,
        int * tdm1,
        int * tdm2,
        int * restrict output,
        const KernelParams &params
) {
    int op_mode = params.op_mode;
    int wgt_size = params.Kx_g * params.Ky_g * params.Co_g * 8 * ( op_mode == OP_DWC ? 10 : params.Ci_g * 8 );
    int ofm_size = params.X_g * params.Y_g * params.Co_g * params.N_g * 64;

    int64_t * c0 = (int64_t*)( weights + wgt_size );
    int32_t c1 = ((int32_t*)( weights + wgt_size + params.Co_g * 64 ))[0];
    int32_t c2 = ((int32_t*)( weights + wgt_size + params.Co_g * 64 ))[1];
    int8_t * weight_unpack = ( int8_t* )weights + (( params.Ky_g * params.Kx_g * params.Co_g * 8 + 63 ) & ~63 );
    int32_t * ifm_sum = ( op_mode == OP_SUM || op_mode == OP_SUM_T ? tdm1 : tdm1+ofm_size/2 );

    direct_conv_int16x8_generic(( int16_t* ) input, ( int8_t* )weights, weight_unpack, tdm1, tdm2, ifm_sum, nullptr, c0, c1, c2, ( int16_t* )output, params );

    if ( op_mode == OP_SUM || op_mode == OP_SUM_T ) {
        if ( OUT_MODE == KC_TDM32 ) {
            //this is not needed in the real design
            v16int32 * pI = ( v16int32 * )( tdm1 + ( params.Y_g * params.X_g * 8 / 2 >> ( params.S_g > 1 )));
            v16int32 * pO = ( v16int32 * ) tdm2;
            fifo_state_t f;
            fifo_ld_reset( pI, f );
            for ( unsigned i = 0; i < (( params.Y_g * params.X_g * 8 >> ( params.S_g > 1 )) + 31 ) / 32; i++ ) {
                *pO++ = fifo_ld_pop( pI, f );
            }
        }
    }
}

#endif
#endif
