/*  (c) Copyright 2019 - 2025 Xilinx, Inc. All rights reserved.

    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.

    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.

    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.

    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */

#ifndef __ACTIVATED_DWC_QDQ_INT16X8_S1_IMPL_HPP__
#define __ACTIVATED_DWC_QDQ_INT16X8_S1_IMPL_HPP__

#include "aie_api/aie.hpp"
#include "common.hh"
// #include "activated_dwc_qdq_int16x8_s1.hpp"

/*
 * activated_dwc_qdq_int16x8_s1
 *
 * Hand-scheduled AIE kernel. Per outer-loop iteration it:
 *   1. Loads 6 consecutive input "rows" (each row = four v64int8 loads at
 *      pIn[0], pIn[2], pIn[4], pIn[6]).
 *   2. Splits each row into a low-byte plane (xa) and a high-byte plane (xb)
 *      with the T8_64x2_lo / T8_64x2_hi shuffles.
 *   3. Convolves each plane against 3 weight vectors pW[0..2] in a sliding
 *      window, so that output row y accumulates input rows y, y+1, y+2
 *      (4 output rows from 6 input rows).
 *   4. Runs the same convolutions against a zero-point weight vector (zp_w) and
 *      subtracts that result. This removes the weight zero-point contribution.
 *   5. Recombines the planes in float as (hi * 256 + lo), then applies
 *      res = pC[0] + pC[1] * value. The result is converted to int16 with
 *      ctrl.sign_O and a shift of -shift_res, and stored to pO.
 *
 * The 256.0f recombination factor and the unsigned treatment of the low plane
 * (sign flag `false`) vs. ctrl.sign_A for the high plane suggest the activations
 * are 16-bit values split into low/high bytes. The file name suggests int16
 * activations x int8 weights with stride 1 ("s1"). TODO confirm against the
 * caller / params producer.
 *
 * Parameters:
 *   input      - activation buffer, reinterpreted as v64int8 (pinned to DM_bankA).
 *   weights    - weight buffer, reinterpreted as v128int8 (pinned to DM_bankB).
 *                Three v128int8 vectors are consumed per iteration, then pW
 *                advances by dims_W2.
 *   coeff_buf  - float coefficients, reinterpreted as v32float (DM_bankB).
 *                pC[0] is used as the additive term and pC[1] as the
 *                multiplier; pC advances by dims_C2 once per iteration.
 *   output     - int16 output, reinterpreted as v32int16 (DM_bankC).
 *   zp_w_sc    - scalar broadcast into every byte of zp_w except the 4th
 *                32-byte part (index 3), which is zeroed. Presumably that part
 *                lines up with padding in each weight vector. TODO confirm.
 *   params     - strides / addressing dims (incA_0, incS_0, dims_*), outer
 *                trip count (outer_loop) and sign controls (ctrl).
 *   shift_res  - passed as aie::neg(shift_res) to to_vector_sign<int16>.
 *
 * Template parameter:
 *   has_actv_sum - unused. The zero-point correction (sum_actv_*) is always
 *                  computed.
 *
 * NOTE(review): chess_loop_range(2, ) tells the compiler the loop runs at least
 * twice, so callers presumably must guarantee params.outer_loop >= 2. Also,
 * ol_bound is a signed int compared against an unsigned index, so a negative
 * value would be converted to a very large unsigned bound.
 */
template<unsigned has_actv_sum>     // NOTE: The template param is unused
ALWAYS_INLINE void activated_dwc_qdq_int16x8_s1
(
        int * input,
        int * weights,
        float * coeff_buf,
        int * restrict output,
        int8 zp_w_sc,
        const ActivatedDwcQdqInt16x8Params &params,
        int32_t shift_res
)
{
    // Typed views of the raw buffers, each pinned to a named data-memory bank.
    v64int8  chess_storage( DM_bankA ) * pIn   = ( v64int8  chess_storage( DM_bankA )* ) input;
    chess_protect_access v128int8 chess_storage( DM_bankB ) * pW    = ( v128int8 chess_storage( DM_bankB )* ) weights;
    chess_protect_access v32float chess_storage( DM_bankB ) * restrict pC = ( v32float chess_storage( DM_bankB )* ) coeff_buf;
    v32int16  chess_storage( DM_bankC ) * pO = ( v32int16  chess_storage( DM_bankC )* ) output;

    // Convolution accumulators, one per output row (0..3).
    // Suffix l = low-byte plane (xa), h = high-byte plane (xb).
    v64acc32 chess_storage( dma0 ) acc0l;
    v64acc32 chess_storage( dma1 ) acc1l;
    v64acc32 chess_storage( dma2 ) acc2l;
    v64acc32 chess_storage( dma3 ) acc3l;
    v64acc32 chess_storage( dma4 ) acc0h;
    v64acc32 chess_storage( dma5 ) acc1h;
    v64acc32 chess_storage( dma6 ) acc2h;
    v64acc32 chess_storage( dma7 ) acc3h;
    // Same layout, but convolved with zp_w instead of the real weights.
    // Subtracted from the results above during output conversion.
    v64acc32 chess_storage( dmb0 ) sum_actv_acc0l;
    v64acc32 chess_storage( dmb1 ) sum_actv_acc1l;
    v64acc32 chess_storage( dmb2 ) sum_actv_acc2l;
    v64acc32 chess_storage( dmb3 ) sum_actv_acc3l;
    v64acc32 chess_storage( dmb4 ) sum_actv_acc0h;
    v64acc32 chess_storage( dmb5 ) sum_actv_acc1h;
    v64acc32 chess_storage( dmb6 ) sum_actv_acc2h;
    v64acc32 chess_storage( dmb7 ) sum_actv_acc3h;

    // Multi-dimensional pointer-advance descriptors, instantiated once from
    // params. They are applied at the end of each outer iteration.
    dims_3d_t dims_A = params.dims_A3.instantiate( );
    dims_2d_t dims_A2 = params.dims_A2.instantiate( );
    dims_3d_t dims_O = params.dims_O3.instantiate( );
    dims_2d_t dims_O2 = params.dims_O2.instantiate( );
    dims_2d_t dims_W = params.dims_W2.instantiate( );
    dims_2d_t dims_C = params.dims_C2.instantiate( );
    const ActivatedDwcQdqInt16x8Params::Control ctrl = params.ctrl;

    // The four 64-byte chunks of the input row currently being processed.
    v64int8 chess_storage( x0 ) ifm0;
    v64int8 chess_storage( x1 ) ifm1;
    v64int8 chess_storage( x2 ) ifm2;
    v64int8 chess_storage( x3 ) ifm3;

    // Float accumulators for the output (dequantize / requantize) stage.
    aie::accum<accfloat, 32> chess_storage( y6 ) res0;
    aie::accum<accfloat, 32> chess_storage( y7 ) res1;
    aie::accum<accfloat, 32> chess_storage( y5 ) res;

    // zp_w: zp_w_sc in every byte, except the 4th 32-byte part, which is zeroed.
    v64int8 zero = broadcast_zero_s8( );
    v128int8 chess_storage( y4 ) zp_w = broadcast_to_v128int8( zp_w_sc );
    zp_w = insert( zp_w, 3, extract_v32int8( zero, 0 ));

    int ol_bound = params.outer_loop;

    set_convo_mode( );
    // event0/event1 bracket the main loop (presumably for profiling / trace).
    event0( );

    // Software-pipelined main loop. The pragmas and the chess_separator()
    // barriers below are tuned scheduling hints; statement order matters.
    for ( unsigned j = 0; j < ol_bound; j++ )
        chess_prepare_for_pipelining
        chess_modulo_scheduling_budget_ratio( 10000 )
        chess_pipeline_initiation_interval( 52 )
        chess_loop_range( 2, )
    {
        v128int8 chess_storage( y2 ) xa;
        v128int8 chess_storage( y2 ) xb;
        v128int8 chess_storage( y3 ) w; 

        /* Row 1: starts output row 0 (weights pW[0]) */
        /* Load X4C32 */
        ifm0 = pIn[0];
        ifm1 = pIn[2];
        ifm2 = pIn[4];
        ifm3 = pIn[6];             pIn = byte_incr( pIn+6, params.incA_0 );

        /* Split into low-byte (xa) and high-byte (xb) planes */
        xa = concat( shuffle( ifm0, ifm1, T8_64x2_lo ), shuffle( ifm2, ifm3, T8_64x2_lo ));
        xb = concat( shuffle( ifm0, ifm1, T8_64x2_hi ), shuffle( ifm2, ifm3, T8_64x2_hi ));

        set_staging_convo( xa );
        staging_to_matrix_m64x64int8( );
        /* A: X4C32 lo * W: X4C32 (low plane treated as unsigned: sign_A = false) */
        acc0l = mul_convo( w=pW[0], ctrl.sign_W, false );
        sum_actv_acc0l = mul_convo( zp_w, ctrl.sign_W, false );
        set_staging_convo( xb );
        staging_to_matrix_m64x64int8( );
        /* A: X4C32 hi * W: X4C32 (high plane uses ctrl.sign_A) */
        acc0h = mul_convo( w=pW[0], ctrl.sign_W, ctrl.sign_A );
        sum_actv_acc0h = mul_convo( zp_w, ctrl.sign_W, ctrl.sign_A );
        
        chess_separator( );

        /* Row 2: out row 0 += pW[1]; starts out row 1 with pW[0] */
        ifm0 = pIn[0];
        ifm1 = pIn[2];
        ifm2 = pIn[4];
        ifm3 = pIn[6];             pIn = byte_incr( pIn+6, params.incA_0 );
        
        xa = concat( shuffle( ifm0, ifm1, T8_64x2_lo ), shuffle( ifm2, ifm3, T8_64x2_lo ));
        xb = concat( shuffle( ifm0, ifm1, T8_64x2_hi ), shuffle( ifm2, ifm3, T8_64x2_hi ));

        set_staging_convo( xa );
        staging_to_matrix_m64x64int8( );
        acc0l = mac_convo( w=pW[1], acc0l, ctrl.sign_W, false );
        acc1l = mul_convo( w=pW[0],        ctrl.sign_W, false );
        sum_actv_acc0l = mac_convo( zp_w, sum_actv_acc0l, ctrl.sign_W, false );
        sum_actv_acc1l = mul_convo( zp_w, ctrl.sign_W, false );
        set_staging_convo( xb );
        staging_to_matrix_m64x64int8( );
        acc0h = mac_convo( w=pW[1], acc0h, ctrl.sign_W, ctrl.sign_A );
        acc1h = mul_convo( w=pW[0],        ctrl.sign_W, ctrl.sign_A );
        sum_actv_acc0h = mac_convo( zp_w, sum_actv_acc0h, ctrl.sign_W, ctrl.sign_A );
        sum_actv_acc1h = mul_convo( zp_w, ctrl.sign_W, ctrl.sign_A );
        
        chess_separator( );

        /* Row 3: finishes out row 0 (pW[2]); out row 1 += pW[1]; starts out row 2 */
        ifm0 = pIn[0];
        ifm1 = pIn[2];
        ifm2 = pIn[4];
        ifm3 = pIn[6];             pIn = byte_incr( pIn+6, params.incA_0 );

        xa = concat( shuffle( ifm0, ifm1, T8_64x2_lo ), shuffle( ifm2, ifm3, T8_64x2_lo ));
        xb = concat( shuffle( ifm0, ifm1, T8_64x2_hi ), shuffle( ifm2, ifm3, T8_64x2_hi ));

        set_staging_convo( xa );
        staging_to_matrix_m64x64int8( );
        acc0l = mac_convo( w=pW[2], acc0l, ctrl.sign_W, false );
        acc1l = mac_convo( w=pW[1], acc1l, ctrl.sign_W, false );
        acc2l = mul_convo( w=pW[0],        ctrl.sign_W, false );
        sum_actv_acc0l = mac_convo( zp_w, sum_actv_acc0l, ctrl.sign_W, false );
        sum_actv_acc1l = mac_convo( zp_w, sum_actv_acc1l, ctrl.sign_W, false );
        sum_actv_acc2l = mul_convo( zp_w, ctrl.sign_W, false );
        set_staging_convo( xb );
        staging_to_matrix_m64x64int8( );
        acc0h = mac_convo( w=pW[2], acc0h, ctrl.sign_W, ctrl.sign_A );
        acc1h = mac_convo( w=pW[1], acc1h, ctrl.sign_W, ctrl.sign_A );
        acc2h = mul_convo( w=pW[0],        ctrl.sign_W, ctrl.sign_A );
        sum_actv_acc0h = mac_convo( zp_w, sum_actv_acc0h, ctrl.sign_W, ctrl.sign_A );
        sum_actv_acc1h = mac_convo( zp_w, sum_actv_acc1h, ctrl.sign_W, ctrl.sign_A );
        sum_actv_acc2h = mul_convo( zp_w, ctrl.sign_W, ctrl.sign_A );
        
        chess_separator( );

        /* Row 4: finishes out row 1; out row 2 += pW[1]; starts out row 3 */
        ifm0 = pIn[0];
        ifm1 = pIn[2];
        ifm2 = pIn[4];
        ifm3 = pIn[6];             pIn = byte_incr( pIn+6, params.incA_0 );
        
        xa = concat( shuffle( ifm0, ifm1, T8_64x2_lo ), shuffle( ifm2, ifm3, T8_64x2_lo ));
        xb = concat( shuffle( ifm0, ifm1, T8_64x2_hi ), shuffle( ifm2, ifm3, T8_64x2_hi ));

        set_staging_convo( xa );
        staging_to_matrix_m64x64int8( );
        acc1l = mac_convo( w=pW[2], acc1l, ctrl.sign_W, false );
        acc2l = mac_convo( w=pW[1], acc2l, ctrl.sign_W, false );
        acc3l = mul_convo( w=pW[0],        ctrl.sign_W, false );
        sum_actv_acc1l = mac_convo( zp_w, sum_actv_acc1l, ctrl.sign_W, false );
        sum_actv_acc2l = mac_convo( zp_w, sum_actv_acc2l, ctrl.sign_W, false );
        sum_actv_acc3l = mul_convo( zp_w, ctrl.sign_W, false );
        set_staging_convo( xb );
        staging_to_matrix_m64x64int8( );
        acc1h = mac_convo( w=pW[2], acc1h, ctrl.sign_W, ctrl.sign_A );
        acc2h = mac_convo( w=pW[1], acc2h, ctrl.sign_W, ctrl.sign_A );
        acc3h = mul_convo( w=pW[0],        ctrl.sign_W, ctrl.sign_A );
        sum_actv_acc1h = mac_convo( zp_w, sum_actv_acc1h, ctrl.sign_W, ctrl.sign_A );
        sum_actv_acc2h = mac_convo( zp_w, sum_actv_acc2h, ctrl.sign_W, ctrl.sign_A );
        sum_actv_acc3h = mul_convo( zp_w, ctrl.sign_W, ctrl.sign_A );
        
        chess_separator( );

        /* Row 5: finishes out row 2; out row 3 += pW[1] */
        ifm0 = pIn[0];
        ifm1 = pIn[2];
        ifm2 = pIn[4];
        ifm3 = pIn[6];             pIn = byte_incr( pIn+6, params.incA_0 );
        
        xa = concat( shuffle( ifm0, ifm1, T8_64x2_lo ), shuffle( ifm2, ifm3, T8_64x2_lo ));
        xb = concat( shuffle( ifm0, ifm1, T8_64x2_hi ), shuffle( ifm2, ifm3, T8_64x2_hi ));

        set_staging_convo( xa );
        staging_to_matrix_m64x64int8( );
        acc2l = mac_convo( w=pW[2], acc2l, ctrl.sign_W, false );
        acc3l = mac_convo( w=pW[1], acc3l, ctrl.sign_W, false );
        sum_actv_acc2l = mac_convo( zp_w, sum_actv_acc2l, ctrl.sign_W, false );
        sum_actv_acc3l = mac_convo( zp_w, sum_actv_acc3l, ctrl.sign_W, false );
        set_staging_convo( xb );
        staging_to_matrix_m64x64int8( );
        acc2h = mac_convo( w=pW[2], acc2h, ctrl.sign_W, ctrl.sign_A );
        acc3h = mac_convo( w=pW[1], acc3h, ctrl.sign_W, ctrl.sign_A );
        sum_actv_acc2h = mac_convo( zp_w, sum_actv_acc2h, ctrl.sign_W, ctrl.sign_A );
        sum_actv_acc3h = mac_convo( zp_w, sum_actv_acc3h, ctrl.sign_W, ctrl.sign_A );
        
        chess_separator( );

        /* Row 6: finishes out row 3 (pW[2]) */
        ifm0 = pIn[0];
        ifm1 = pIn[2];
        ifm2 = pIn[4];
        ifm3 = pIn[6];
        // Instead of the per-row incA_0 step, advance pIn through the 3D then 2D
        // descriptors to the start of the next iteration's input.
        pIn = add_3d_byte( pIn, dims_A );
        pIn = add_2d_byte( pIn, dims_A2 );
        
        xa = concat( shuffle( ifm0, ifm1, T8_64x2_lo ), shuffle( ifm2, ifm3, T8_64x2_lo ));
        xb = concat( shuffle( ifm0, ifm1, T8_64x2_hi ), shuffle( ifm2, ifm3, T8_64x2_hi ));

        set_staging_convo( xa );
        staging_to_matrix_m64x64int8( );
        acc3l = mac_convo( w=pW[2], acc3l, ctrl.sign_W, false );
        sum_actv_acc3l = mac_convo( zp_w, sum_actv_acc3l, ctrl.sign_W, false );
        set_staging_convo( xb );
        staging_to_matrix_m64x64int8( );
        acc3h = mac_convo( w=pW[2], acc3h, ctrl.sign_W, ctrl.sign_A );
        sum_actv_acc3h = mac_convo( zp_w, sum_actv_acc3h, ctrl.sign_W, ctrl.sign_A );
        // Move to this iteration's successor set of 3 weight vectors.
        pW = add_2d_byte( pW, dims_W );
        
        chess_separator( );

        /* Write output */
        /* Y=0 */
        /* Convert to int16 and apply QDQ:
         *   res1 = (conv_hi*256 + conv_lo) - (sum_hi*256 + sum_lo)
         *   res  = pC[0] + pC[1] * res1
         *   out  = int16(res) with sign ctrl.sign_O and shift -shift_res
         * Each accumulator is emitted as two 32-lane halves: extract index 1
         * goes to pO[2] and extract index 0 goes to pO[0]. */
        // c2a and c2b hold the same value (pC[1]) in two different registers,
        // presumably to ease register pressure / scheduling. TODO confirm.
        v32accfloat chess_storage( cmal8 ) c2a = ( v32accfloat ) pC[1];
        v32accfloat chess_storage( cmbl8 ) c2b = ( v32accfloat ) pC[1];
        
        v32int32 conv_hi = ( v32int32 ) extract_v32acc32( acc0h, 1 );
        v32int32 conv_lo = ( v32int32 ) extract_v32acc32( acc0l, 1 );
        v32int32 sum_hi  = ( v32int32 ) extract_v32acc32( sum_actv_acc0h, 1 );
        v32int32 sum_lo  = ( v32int32 ) extract_v32acc32( sum_actv_acc0l, 1 );
        res0 = mac_elem_32( conv_hi, 256.0f, locate_in_register<6>( mul_elem_32( conv_lo, 1.0f )));
        res1 = msc_elem_32( sum_hi, 256.0f, res0 );
        res1 = msc_elem_32( sum_lo, 1.0f, res1 );
        res  = locate_in_register<5>( mac_elem_32(( v32float ) c2a, ( v32float ) res1, ( v32accfloat ) locate_in_register<5>( pC[0] )));
        v32int16 out = res.template to_vector_sign<int16>( ctrl.sign_O, aie::neg(shift_res) );
        pO[2] = out;
        
        chess_separator( );
        
        conv_hi = ( v32int32 ) extract_v32acc32( acc0h, 0 );
        conv_lo = ( v32int32 ) extract_v32acc32( acc0l, 0 );
        sum_hi  = ( v32int32 ) extract_v32acc32( sum_actv_acc0h, 0 );
        sum_lo  = ( v32int32 ) extract_v32acc32( sum_actv_acc0l, 0 );
        res0 = mac_elem_32( conv_hi, 256.0f, locate_in_register<6>( mul_elem_32( conv_lo, 1.0f )));
        res1 = msc_elem_32( sum_hi, 256.0f, res0 );
        res1 = msc_elem_32( sum_lo, 1.0f, res1 );
        res  = locate_in_register<5>( mac_elem_32(( v32float ) c2b, ( v32float ) res1, ( v32accfloat ) locate_in_register<5>( pC[0] )));
        out   = res.template to_vector_sign<int16>( ctrl.sign_O, aie::neg(shift_res) );
        // Step to the next output row.
        pO[0] = out;    pO = byte_incr( pO, params.incS_0 );

        chess_separator( );

        /* Y=1 */
        conv_hi = ( v32int32 ) extract_v32acc32( acc1h, 1 );
        conv_lo = ( v32int32 ) extract_v32acc32( acc1l, 1 );
        sum_hi  = ( v32int32 ) extract_v32acc32( sum_actv_acc1h, 1 );
        sum_lo  = ( v32int32 ) extract_v32acc32( sum_actv_acc1l, 1 );
        res0 = mac_elem_32( conv_hi, 256.0f, locate_in_register<6>( mul_elem_32( conv_lo, 1.0f )));
        res1 = msc_elem_32( sum_hi, 256.0f, res0 );
        res1 = msc_elem_32( sum_lo, 1.0f, res1 );
        res  = locate_in_register<5>( mac_elem_32(( v32float ) c2a, ( v32float ) res1, ( v32accfloat ) locate_in_register<5>( pC[0] )));
        out   = res.template to_vector_sign<int16>( ctrl.sign_O, aie::neg(shift_res) );
        pO[2] = out;

        chess_separator( );
        
        conv_hi = ( v32int32 ) extract_v32acc32( acc1h, 0 );
        conv_lo = ( v32int32 ) extract_v32acc32( acc1l, 0 );
        sum_hi  = ( v32int32 ) extract_v32acc32( sum_actv_acc1h, 0 );
        sum_lo  = ( v32int32 ) extract_v32acc32( sum_actv_acc1l, 0 );
        res0 = mac_elem_32( conv_hi, 256.0f, locate_in_register<6>( mul_elem_32( conv_lo, 1.0f )));
        res1 = msc_elem_32( sum_hi, 256.0f, res0 );
        res1 = msc_elem_32( sum_lo, 1.0f, res1 );
        res  = locate_in_register<5>( mac_elem_32(( v32float ) c2b, ( v32float ) res1, ( v32accfloat ) locate_in_register<5>( pC[0] )));
        out   = res.template to_vector_sign<int16>( ctrl.sign_O, aie::neg(shift_res) );
        pO[0] = out;    pO = byte_incr( pO, params.incS_0 );

        chess_separator( );

        /* Y=2 */
        conv_hi = ( v32int32 ) extract_v32acc32( acc2h, 1 );
        conv_lo = ( v32int32 ) extract_v32acc32( acc2l, 1 );
        sum_hi  = ( v32int32 ) extract_v32acc32( sum_actv_acc2h, 1 );
        sum_lo  = ( v32int32 ) extract_v32acc32( sum_actv_acc2l, 1 );
        res0 = mac_elem_32( conv_hi, 256.0f, locate_in_register<6>( mul_elem_32( conv_lo, 1.0f )));
        res1 = msc_elem_32( sum_hi, 256.0f, res0 );
        res1 = msc_elem_32( sum_lo, 1.0f, res1 );
        res  = locate_in_register<5>( mac_elem_32(( v32float ) c2a, ( v32float ) res1, ( v32accfloat ) locate_in_register<5>( pC[0] )));
        out   = res.template to_vector_sign<int16>( ctrl.sign_O, aie::neg(shift_res) );
        pO[2] = out;

        chess_separator( );
        
        conv_hi = ( v32int32 ) extract_v32acc32( acc2h, 0 );
        conv_lo = ( v32int32 ) extract_v32acc32( acc2l, 0 );
        sum_hi  = ( v32int32 ) extract_v32acc32( sum_actv_acc2h, 0 );
        sum_lo  = ( v32int32 ) extract_v32acc32( sum_actv_acc2l, 0 );
        res0 = mac_elem_32( conv_hi, 256.0f, locate_in_register<6>( mul_elem_32( conv_lo, 1.0f )));
        res1 = msc_elem_32( sum_hi, 256.0f, res0 );
        res1 = msc_elem_32( sum_lo, 1.0f, res1 );
        res  = locate_in_register<5>( mac_elem_32(( v32float ) c2b, ( v32float ) res1, ( v32accfloat ) locate_in_register<5>( pC[0] )));
        out   = res.template to_vector_sign<int16>( ctrl.sign_O, aie::neg(shift_res) );
        pO[0] = out;    pO = byte_incr( pO, params.incS_0 );

        chess_separator( );

        /* Y=3 */
        conv_hi = ( v32int32 ) extract_v32acc32( acc3h, 1 );
        conv_lo = ( v32int32 ) extract_v32acc32( acc3l, 1 );
        sum_hi  = ( v32int32 ) extract_v32acc32( sum_actv_acc3h, 1 );
        sum_lo  = ( v32int32 ) extract_v32acc32( sum_actv_acc3l, 1 );
        res0 = mac_elem_32( conv_hi, 256.0f, locate_in_register<6>( mul_elem_32( conv_lo, 1.0f )));
        res1 = msc_elem_32( sum_hi, 256.0f, res0 );
        res1 = msc_elem_32( sum_lo, 1.0f, res1 );
        res  = locate_in_register<5>( mac_elem_32(( v32float ) c2a, ( v32float ) res1, ( v32accfloat ) locate_in_register<5>( pC[0] )));
        out   = res.template to_vector_sign<int16>( ctrl.sign_O, aie::neg(shift_res) );
        pO[2] = out;

        chess_separator( );
        
        conv_hi = ( v32int32 ) extract_v32acc32( acc3h, 0 );
        conv_lo = ( v32int32 ) extract_v32acc32( acc3l, 0 );
        sum_hi  = ( v32int32 ) extract_v32acc32( sum_actv_acc3h, 0 );
        sum_lo  = ( v32int32 ) extract_v32acc32( sum_actv_acc3l, 0 );
        res0 = mac_elem_32( conv_hi, 256.0f, locate_in_register<6>( mul_elem_32( conv_lo, 1.0f )));
        res1 = msc_elem_32( sum_hi, 256.0f, res0 );
        res1 = msc_elem_32( sum_lo, 1.0f, res1 );
        res  = locate_in_register<5>( mac_elem_32(( v32float ) c2b, ( v32float ) res1, ( v32accfloat ) locate_in_register<5>( pC[0] )));
        out   = res.template to_vector_sign<int16>( ctrl.sign_O, aie::neg(shift_res) );
        pO[0] = out;

        // Last output row: advance pO via the 3D/2D descriptors instead of incS_0,
        // and move pC to the next coefficient pair (one pair per iteration).
        pO = add_3d_byte( pO, dims_O );
        pO = add_2d_byte( pO, dims_O2 );
        pC = add_2d_byte( pC, dims_C );
    }

    event1( );
}

#endif // __ACTIVATED_DWC_QDQ_INT16X8_S1_IMPL_HPP__