/*  (c) Copyright 2019 - 2024 Xilinx, Inc. All rights reserved.

    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.

    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.

    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.

    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */

#ifndef __ACTIVATED_CONV_QDQ_INT16X8_TEMPLATE_H__
#define __ACTIVATED_CONV_QDQ_INT16X8_TEMPLATE_H__

#include "aie_api/aie.hpp"
#include "aie_api/utils.hpp"
#include "common.hh"
// #include "activated_conv_qdq_int16x8.hpp"
#include "access_helpers.hpp"
// #include "kernel_helpers.h"
#ifdef DEBUG_KERNEL
#include "stdio.h"
#endif


// Byte stride multiplied by conv_params.max_accus when the int8 spill buffer
// is split into its regions. May be overridden before this header is included.
// The value is parenthesized so the macro expands as a single operand in any
// expression (e.g. `x / ACCUS_SIZE`), not as `x / 64 * 4`.
#ifndef ACCUS_SIZE
#define ACCUS_SIZE (64*4)
#endif

template<bool has_actv_sum, unsigned vector_coeff, int hardened_loop, class Ba, class Bw, class Bc, class Bo>
__aie_inline void activated_conv_qdq_int16x8
(
        Ba & restrict bufA_,
        Bw & restrict bufW_,
        Bc & restrict bufC_,
        Bo & restrict bufO_,
        int8 * restrict spill_buf,
        int8 * restrict cf_cache,
        const ActivatedConvInt16x8Params &conv_params
) {
    // Compile-time guard on the supported template combinations:
    //  - without the activation sum, vector_coeff must be 0 or 2;
    //  - with the activation sum, vector_coeff may be 0, 1 or 2.
    static_assert(     ( has_actv_sum == 0 && ( vector_coeff == 0 || vector_coeff == 2 ) ) 
                    || ( has_actv_sum == 1 && ( vector_coeff == 0 || vector_coeff == 1 || vector_coeff == 2 ) ) ,
                "Unsupported operating mode for QDQ" );

    bool has_tlast=1;

    // Local aliases for the four buffer arguments.
    Ba &bufA = bufA_;
    Bw &bufW = bufW_;
    Bc &bufC = bufC_;
    Bo &bufO = bufO_;

    // Element types are fixed here; deriving them from the buffer types
    // (buffer_element_t) is left commented out.
    using Ta = int16; //buffer_element_t<Ba>;
    using Tw = int8; //buffer_element_t<Bw>;
    
    using Tc = float; //buffer_element_t<Bc>;
    using To = int16; //buffer_element_t<Bo>;

    // Lane count used for QDQ coefficient access: 32 when Tc is float, 16 otherwise.
    constexpr unsigned V_qdq = ( std::is_same_v<Tc, float> ) ? 32 : 16;
    // Lengths of the c0/c1/c2 coefficient arrays: c0 always holds 64 entries;
    // c1 / c2 hold 64 entries only when vector_coeff is at least 1 / 2, else a single entry.
    constexpr unsigned N_c0 = 64;
    constexpr unsigned N_c1 = vector_coeff >= 1 ? 64 : 1;
    constexpr unsigned N_c2 = vector_coeff >= 2 ? 64 : 1;
    // Per-array access width: the QDQ lane count clamped to the array length.
    constexpr unsigned V_c0 = std::min( V_qdq, N_c0 );
    constexpr unsigned V_c1 = std::min( V_qdq, N_c1 );
    constexpr unsigned V_c2 = std::min( V_qdq, N_c2 );

    // Unroll factor handed to unroll_fn in the pipelined loops.
    constexpr unsigned unroll_factor = 2;

    // Trip count of the weight / MAC / output loops (before unrolling).
    constexpr unsigned granYX=32;
    constexpr unsigned granYX2 = granYX/2;

    // Runtime parameter block (sign_A, sign_W, sign_O, shift_res are read from it);
    // the pointer is set in weight_acquire().
    ConvQdQInt16x8_RT_Params* conv_rt_params;

    // Loop bounds taken from the kernel parameters.
    int il_bound = conv_params.inner_loop;
    int ol_bound = conv_params.inner_time_iters;
    int iters = conv_params.outer_time_iters;

    // Minimum trip counts passed to chess_loop_range. Debug and non-chess builds
    // use 1 so no minimum is assumed.
    #if defined( DEBUG_KERNEL ) || !defined( __chess__ )
    constexpr unsigned il_lr = 1; 
    constexpr unsigned Ky_x_Kx = 1; 
    #else
    constexpr unsigned il_lr = 3; /* kh*kw*Ci at least 3 iters */
    constexpr unsigned Ky_x_Kx = 4; /* kh*kw at least 2 iters */
    #endif

    // Both branches currently define the same values (1 and 1).
    #if defined( DEBUG_KERNEL ) || defined( DEBUG_TESTBENCH ) || !defined( __chess__ )
    constexpr unsigned ol_lr = 1;
    constexpr unsigned it_lr = 1;
    #else
    constexpr unsigned ol_lr = 1;
    constexpr unsigned it_lr = 1;
    #endif

    // check_bounds<il_lr, ol_lr, 4096, it_lr>( il_bound, ol_bound, 1, iters );

    // Working pointers. pA_in / pA / pAs are set in ifm_acquire(); pO, pO_l and
    // pO_h are set in ofm_acquire().
    Ta   * pA_in;
    v64int16 * pA;
    int8     * pAs;
    int8     *pWs;
    To * pO, * pO_l, * pO_h;

    // Base address of the weight buffer, plus a tensor buffer stream over it from
    // which aie::vector<int8,128> values are extracted with operator>>.
    pWs = (int8 *) bufW.data();
    auto in_desc_weight = aie::make_tensor_descriptor<int8, 128>( aie::tensor_dim( 16, 1 ));
    auto bufW_tbs = aie::make_tensor_buffer_stream( bufW, in_desc_weight );

    // Tensor buffer stream over the coefficient buffer; yields aie::vector<float,32>.
    auto in_desc_coeff  = aie::make_tensor_descriptor<float, 32>( aie::tensor_dim( 8, 1 ));
    auto bufC_tbs = aie::make_tensor_buffer_stream( bufC, in_desc_coeff );

    /* dimensions */
    // Addressing descriptors instantiated from the kernel parameters:
    //  dimsAl / dimsAi               - 2D / 3D steps for the activation loads in the MAC body,
    //  dims_in_inner / dims_in_outer - inner / outer steps used by sum_conv_2d_c1,
    //  dims_sum_actv                 - 3D steps for the activation-sum loads in sum_fetch.
    dims_2d_t dimsAl = conv_params.dims_A2.instantiate( );
    dims_3d_t dimsAi = conv_params.dims_A3.instantiate( );
    dims_2d_t dims_in_inner = conv_params.dims_conv2d_sum_inner.instantiate( );
    dims_2d_t dims_in_outer = conv_params.dims_conv2d_sum_outer.instantiate( );
    dims_3d_t dims_sum_actv = conv_params.dims_sum_actv.instantiate( );
    

    // Layout of the local coefficient cache that is overlaid on cf_cache.
    //  c0       - N_c0 coefficients, 128-byte aligned,
    //  c1, c2   - N_c1 / N_c2 coefficients (one entry each when not vectorized),
    //  c2_shift - c2[0] multiplied by 256 (written by coeff_fetch).
    struct coeff_cache_type {
        alignas( 128 ) Tc c0[N_c0];
        Tc c1[N_c1];
        Tc c2[N_c2];
        Tc c2_shift[1];
    };
    
    // Reinterpret the caller-provided cf_cache bytes as the cache structure.
    // NOTE(review): this relies on cf_cache satisfying the 128-byte alignment of c0 - confirm at the call site.
    coeff_cache_type *coeff_cache_ptr = (coeff_cache_type*) cf_cache;
    coeff_cache_type &coeff_cache = *coeff_cache_ptr;

    // Reads two aie::vector<float,32> values from the coefficient stream (bufC_tbs)
    // and stores them at `ptr`.
    //   vec >= 1 : each loaded vector is written as the l-th v32float at ptr
    //              (64 floats in total).
    //   vec == 0 : element 0 of each loaded vector is written to the single float
    //              at ptr, so the value from the second load is the one that remains.
    // NOTE(review): in the scalar case the first store is overwritten by the second;
    // confirm this matches the coefficient stream layout.
    auto store_coeff = [&]( int8* ptr, unsigned vec) __attribute__(( always_inline )) {
        #pragma unroll
        for ( unsigned l = 0; l < 2; l++ ) {
            aie::vector<float,32> load_vec;
            bufC_tbs >> load_vec;
            if ( vec >= 1 ) {
                *( (chess_protect_access v32float*) ptr+l) = load_vec;
            } else {
                *( (chess_protect_access float*) ptr) = extract_elem(extract_v16float(load_vec,0), 0);
            }
        }
    };
    
    // Returns 64 when the argument is greater than 1, and 1 otherwise.
    constexpr auto coeff_step = []( unsigned count ) {
        if ( count <= 1 ) {
            return 1;
        }
        return 64;
    };

    // Fills the coefficient cache from the coefficient stream, in stream order:
    //   c0 (always stored as vectors),
    //   c1 (only read when has_actv_sum is set; vectorized when vector_coeff >= 1),
    //   c2 (vectorized when vector_coeff >= 2).
    // Afterwards c2_shift[0] is set to c2[0] * 256.
    auto coeff_fetch = [&]( ) __attribute__(( always_inline )) {
        store_coeff((int8*)coeff_cache.c0, 1);
        if constexpr( has_actv_sum > 0) {
            store_coeff((int8*)coeff_cache.c1, vector_coeff >= 1);
        }
        store_coeff((int8*)coeff_cache.c2, vector_coeff >= 2);

        *( (Tc*) coeff_cache.c2_shift) = aie::mul( *((Tc*) coeff_cache.c2), 256.0f);
    };

    // The spill buffer is divided into four regions, each
    // conv_params.max_accus*ACCUS_SIZE bytes after the previous one:
    //   conv_tmp_l / conv_tmp_h       - spilled acc0 / acc1 convolution accumulators,
    //   sum_tmp_ifm_l / sum_tmp_ifm_h - spilled acc0 / acc1 activation-sum accumulators.
    // sum_ifm_qdq_in / sum_ifm_qdq_out are float views of the two sum regions.
    v32int32 * conv_tmp_l    = (v32int32*) (spill_buf );
    v32int32 * conv_tmp_h    = (v32int32*) (spill_buf + 1*conv_params.max_accus*ACCUS_SIZE);
    v32int32 * sum_tmp_ifm_l = (v32int32*) (spill_buf + 2*conv_params.max_accus*ACCUS_SIZE);
    v32int32 * sum_tmp_ifm_h = (v32int32*) (spill_buf + 3*conv_params.max_accus*ACCUS_SIZE);
    v32float * sum_ifm_qdq_in  = (v32float *) sum_tmp_ifm_l; /* 3x3 or 7x7 spill buffer is bigger than output - last sum pointing to spill_buff + output size. OT buffer shared in some cases. */
    v32float * sum_ifm_qdq_out = (v32float *) sum_tmp_ifm_h; 


    // Two matrix accumulators pinned to the EM0 / EM1 storages; contents start undefined.
    m32x64acc32 chess_storage(EM0) acc0 = chess_dont_care(m32x64acc32);
    m32x64acc32 chess_storage(EM1) acc1 = chess_dont_care(m32x64acc32);

    // Index registers, each pinned to a fixed storage:
    //   iso_0 / iso_1 - accumulator row read by the low / high output path,
    uint5_t chess_storage(i0) iso_0 = 0;
    uint5_t chess_storage(i1) iso_1 = 0;

    //   ib / iw       - staging slot for activation-sum data / weights,
    uint5_t chess_storage(i7) ib = 0;
    uint5_t chess_storage(i6) iw = 0;

    //   im0 / im1     - accumulator row written by the MAC body in acc0 / acc1,
    uint5_t chess_storage(i5) im0 = 0;
    uint5_t chess_storage(i4) im1 = 0;

    //   idx_r / idx_s - accumulator row used when restoring / spilling.
    uint5_t chess_storage(i2) idx_r = 0;
    uint5_t chess_storage(i3) idx_s = 0;

    /* fifo actvs */
    // Three load-FIFO states:
    //   fA - activation loads of the MAC body (through pA),
    //   fB - activation-sum loads of sum_fetch (through pAs),
    //   fI - loads of sum_conv_2d_c1.
    // Each FIFO register starts undefined and is bound to a fixed storage
    // (lf0 for fA, lf1 for fB and fI); the read position starts at 0.
    fifo_state_t fA, fB, fI;
    
    fA.fifo = chess_dont_care(sparse_fifo_t);
    sparse_fifo_t chess_storage(lf0) fAf;
    fA.fifo = fAf;
    fA.pos = 0;

    fB.fifo = chess_dont_care(sparse_fifo_t);
    sparse_fifo_t chess_storage(lf1) fBf;
    fB.fifo = fBf;
    fB.pos = 0;

    fI.fifo = chess_dont_care(sparse_fifo_t);
    sparse_fifo_t chess_storage(lf1) fIf;
    // Bind fI to its own placeholder (fIf); it previously copied fBf, leaving
    // fIf declared but unused.
    fI.fifo = fIf;
    fI.pos = 0;

    /* outputs actvs */
    // zero_acc is 1 until the first MAC pass has completed; while it is 1 the MAC
    // body starts the accumulators from scratch instead of accumulating.
    int zero_acc = 1;
    int zero_acc_actvs = 1;
    int zero_acc_actvs_intermediate = 1;

    constexpr bool pipelined_conv2d = false;

    // Resets FIFO state fA: marks the FIFO register contents as don't-care,
    // re-binds it to storage lf0 and rewinds the read position to 0.
    auto fifo_init_A = [&]( ) __attribute__(( always_inline )) {
        fA.fifo = chess_dont_care(sparse_fifo_t);
        sparse_fifo_t chess_storage(lf0) temp = chess_copy(fA.fifo); 
        fA.fifo = temp;                  
        fA.pos = 0;
    };

    // Resets FIFO state fB: marks the FIFO register contents as don't-care,
    // re-binds it to storage lf1 and rewinds the read position to 0.
    auto fifo_init_B = [&]( ) __attribute__(( always_inline )) {
        fB.fifo = chess_dont_care(sparse_fifo_t);
        sparse_fifo_t chess_storage(lf1) temp = chess_copy(fB.fifo); 
        fB.fifo = temp;
        fB.pos = 0;
    };
    
    // Resets FIFO state fI: marks the FIFO register contents as don't-care,
    // re-binds it to storage lf1 (the same storage fB uses) and rewinds the
    // read position to 0.
    auto fifo_init_I = [&]( ) __attribute__(( always_inline )) {
        fI.fifo = chess_dont_care(sparse_fifo_t);
        sparse_fifo_t chess_storage(lf1) temp = chess_copy(fI.fifo); 
        fI.fifo = temp;
        fI.pos = 0;
    };

    // Per-iteration reset: marks both accumulators as don't-care, re-arms the
    // zeroing flags, resets FIFOs fA and fB, and clears the two counters of the
    // 3D activation descriptor.
    auto conv_init = [&]( ) __attribute__(( always_inline )) {
        acc0 = chess_dont_care( m32x64acc32 );
        acc1 = chess_dont_care( m32x64acc32 );
        zero_acc = 1;
        zero_acc_actvs = 1;
        zero_acc_actvs_intermediate = zero_acc_actvs;
        fifo_init_A();
        fifo_init_B();
        dimsAi.count1 = 0;
        dimsAi.count2 = 0;
    };

    // Acquires the weight stream and points conv_rt_params at the bytes located
    // wgt_size + coeff_size past the start of the weight buffer.
    auto weight_acquire = [&]( ) __attribute__(( always_inline )) {
            bufW_tbs.acquire( );
            conv_rt_params = (ConvQdQInt16x8_RT_Params*)byte_incr(pWs, conv_params.wgt_size + conv_params.coeff_size);
    };

    // Releases the weight stream.
    auto weight_release = [&]( ) __attribute__(( always_inline )) {
            bufW_tbs.release( );
    };
    
    // Acquires the coefficient stream.
    auto coeff_acquire = [&]( ) __attribute__(( always_inline ) ){
            bufC_tbs.acquire( );
    };

    // Releases the coefficient stream.
    auto coeff_release = [&]( ) __attribute__(( always_inline )) {
            bufC_tbs.release( );
    };

    // Acquires the input activation buffer and points pA_in, pA (MAC path) and
    // pAs (activation-sum path) at its start.
    auto ifm_acquire = [&]( ) __attribute__(( always_inline )) {
            bufA.acquire( );
            pA_in = (Ta *) bufA.data();

            pA = (v64int16 *) pA_in;
            pAs = (int8 *) bufA.data();
    };

    // Releases the input activation buffer.
    auto ifm_release = [&]( ) __attribute__(( always_inline )) {
            bufA.release( );           
    };

    // Acquires the output buffer. pO_l starts at the buffer base and pO_h
    // 32 elements after it; the output bodies advance each by 64 elements per row.
    auto ofm_acquire = [&]( ) __attribute__(( always_inline )) {
            bufO.acquire( );
            pO = (To *)bufO.data( );
            pO_l = pO;
            pO_h = add_elem( pO, 32 );
    };

    // Releases the output buffer.
    auto ofm_release = [&]( ) __attribute__(( always_inline )) {
            bufO.release( );
    };    

    // Loop body: pulls one 128-byte weight vector from the weight stream and
    // inserts it into staging slot iw (mode argument 2 + sign_W), then advances iw.
    // `idx` is the loop index supplied by the loop helper and is not used.
    auto weights_body = [&](auto idx) __attribute__(( always_inline )) {
        aie::vector<int8, 128> w;
        bufW_tbs >> w;
        insert_staging( w, iw, 2 + conv_rt_params->sign_W );
        iw++;
    };

    // Placeholder null pointer; no use of it is visible in this part of the file.
    void * dummy = nullptr;

    // One MAC step on accumulator rows im0 / im1.
    //   unroll_idx - compile-time position inside the unrolled loop body; it offsets
    //                the vector registers the operands are pinned to.
    //   idx        - loop index (not used here).
    //   is_last    - selects the 3D address update (dimsAi) instead of the 2D one
    //                (dimsAl) for the second load, and resets dimsAl.count1.
    auto mac_body_generic = [&]<unsigned unroll_idx = 0>(auto idx, bool is_last, std::integral_constant<unsigned,unroll_idx> unroll_idx_dummy = {}) __attribute__(( always_inline )) {
        v64int16 a;
        v32int16 a0, a1;

        // Two activation loads through FIFO fA, advancing pA.
        a0 = aie::utils::locate_in_register<12 + 2*unroll_idx,aie::utils::AIE_RegFile::Vector>(v32int16(fifo_ld_popx(( v64int8 *& ) pA, fA, conv_params.step_align, 63)));
        if (!is_last) {
            a1 = aie::utils::locate_in_register<13 + 2*unroll_idx,aie::utils::AIE_RegFile::Vector>(v32int16(fifo_ld_pop_2d_byte(( v64int8 *& ) pA, fA, dimsAl )));
        }
        else {
            dimsAl.count1 = 0;
            a1 = aie::utils::locate_in_register<13 + 2*unroll_idx,aie::utils::AIE_RegFile::Vector>(v32int16(fifo_ld_pop_3d_byte(( v64int8 *& ) pA, fA, dimsAi )));
        }

        v64int8 b0, b1;

        // Shuffle the pair of int16 vectors into two int8 vectors.
        // NOTE(review): presumably b0 holds the low bytes and b1 the high bytes of
        // each int16 sample - confirm against the T8_64x2_lo / T8_64x2_hi shuffle modes.
        b0 = aie::utils::locate_in_register<4 + 2*unroll_idx, aie::utils::AIE_RegFile::Vector>((( v64int8 ) shuffle( a0, a1, T8_64x2_lo )));
        b1 = aie::utils::locate_in_register<5 + 2*unroll_idx, aie::utils::AIE_RegFile::Vector>( ( v64int8 ) shuffle( a0, a1, T8_64x2_hi ) );

        // First pass (zero_acc known to be 1): plain multiply. Otherwise accumulate.
        // b0 is always passed with sign flag `false`; b1 uses the runtime sign_A.
        if ( chess_manifest( zero_acc == 1 ) ) {
            acc0[im0] = mul( b0, false,       conv_rt_params->sign_W );
            acc1[im1] = mul( b1, conv_rt_params->sign_A, conv_rt_params->sign_W );
        } else {
            acc0[im0] = mac_conf( b0,       false, conv_rt_params->sign_W, acc0[im0], zero_acc );
            acc1[im1] = mac_conf( b1, conv_rt_params->sign_A, conv_rt_params->sign_W, acc1[im1], zero_acc );
        }
        im0++; 
        im1++;
    };
    
    // Loop-body wrapper around mac_body_generic: the step is treated as the last one
    // only when the compiler can prove idx == granYX-1 (chess_manifest).
    auto mac_body = [&]<unsigned unroll_idx = 0>(auto idx, std::integral_constant<unsigned,unroll_idx> unroll_idx_dummy = {}) __attribute__(( always_inline )) {
        bool is_last = chess_manifest(idx == granYX-1);
        mac_body_generic.template operator()<unroll_idx>(idx, is_last);
    };

    /* The trailing integral_constant parameter lets this body be invoked with a
       compile-time unroll index, so the same lambda can be used by both
       unroll_times and unroll_fn (unroll_fn passes the templated unroll index
       together with the loop index).

       sum_fetch loads one 128-byte activation vector for the activation sum:
       two 64-byte loads through FIFO fB (advancing pAs), the second one masked
       with conv_params.mask_Ci_high against zero, concatenated and inserted
       into staging slot ib (mode argument 2 + sign_A). ib is then advanced.
    */
    auto sum_fetch = [&]<unsigned unroll_idx = 0>( unsigned idx, std::integral_constant<unsigned,unroll_idx> unroll_idx_dummy = {} ) __attribute__(( always_inline )) {
        aie::vector<int8, 128> a;
        auto a0     = /*locate_in_register<0 + 8*(unroll_idx), AIE_RegFile_Vector>*/( (fifo_ld_popx(( v64int8 *& ) pAs, fB, conv_params.step_align_sum, 63)) );
        auto a1     = /*locate_in_register<2 + 8*(unroll_idx), AIE_RegFile_Vector>*/( (fifo_ld_pop_3d_byte(( v64int8 *& ) pAs, fB, dims_sum_actv )) );
        auto a1_tmp = /* locate_in_register<1 + 8*(unroll_idx), AIE_RegFile_Vector>*/( sel( /*locate_in_register<3, AIE_RegFile_Vector>*/(broadcast_zero_to_v64int8()), a1, conv_params.mask_Ci_high));     /* int32 mask param */
             a      = locate_in_register<0 + 4*(unroll_idx), AIE_RegFile_Vector>( concat(a0, a1_tmp) );
        insert_staging( a, ib, 2 + conv_rt_params->sign_A );
        ib++;
    };

    // Runs sum_fetch granYX times as a pipelined loop, unrolled by unroll_factor,
    // with the given number of peeled front / back iterations.
    auto sum_fetch_32 = [&]<unsigned peel_front = 3, unsigned peel_back = 2>(  ) __attribute__(( always_inline )) {
        aie::pipelined_loop<granYX/unroll_factor, aie::LoopOptions{.peel_front = peel_front, .peel_back = peel_back}>(granYX/unroll_factor, unroll_fn<unroll_factor>(sum_fetch));
    };

    // Runs sum_fetch granYX times fully unrolled (no loop).
    auto sum_fetch_32_unroll = [&](  ) __attribute__(( always_inline )) {
        aie::unroll_times<granYX>(sum_fetch);
    };

    // Saves both halves of row spill_idx of acc0 / acc1 (the convolution result)
    // to conv_tmp_l / conv_tmp_h, at vector slots 2*spill_idx and 2*spill_idx + 1.
    auto spill_start = [&]( auto spill_idx ) __attribute__(( always_inline )) {
        // Accumulators are spilled to preserve the conv
        *(conv_tmp_l + 2*spill_idx)     = (v32int32) extract_v32acc32( acc0[spill_idx], 0);
        *(conv_tmp_l + 2*spill_idx + 1) = (v32int32) extract_v32acc32( acc0[spill_idx], 1);
        *(conv_tmp_h + 2*spill_idx)     = (v32int32) extract_v32acc32( acc1[spill_idx], 0);
        *(conv_tmp_h + 2*spill_idx + 1) = (v32int32) extract_v32acc32( acc1[spill_idx], 1);
    };

    // Loads both halves of row spill_idx of acc0 / acc1 from the activation-sum
    // spill regions sum_tmp_ifm_l / sum_tmp_ifm_h (slots 2*spill_idx and 2*spill_idx + 1).
    auto restore_start = [&]( auto spill_idx ) __attribute__(( always_inline )) {
        // Accumulators are restored to create the ifm_sum
        acc0 = insert( acc0, spill_idx, 0, *((chess_protect_access v32acc32*) sum_tmp_ifm_l + 2*spill_idx) );
        acc0 = insert( acc0, spill_idx, 1, *((chess_protect_access v32acc32*) sum_tmp_ifm_l + 2*spill_idx + 1) );
        acc1 = insert( acc1, spill_idx, 0, *((chess_protect_access v32acc32*) sum_tmp_ifm_h + 2*spill_idx) );
        acc1 = insert( acc1, spill_idx, 1, *((chess_protect_access v32acc32*) sum_tmp_ifm_h + 2*spill_idx + 1) );
    };

    // Saves both halves of row spill_idx of acc0 / acc1 (the running activation sum)
    // to sum_tmp_ifm_l / sum_tmp_ifm_h; the counterpart of restore_start.
    auto spill_end_int16 = [&]( auto spill_idx ) __attribute__(( always_inline )) {
        // Accumulators are spilled to create the ifm_sum
        *(sum_tmp_ifm_l + 2*spill_idx + 0) = (v32int32) extract_v32acc32(acc0[spill_idx], 0);
        *(sum_tmp_ifm_l + 2*spill_idx + 1) = (v32int32) extract_v32acc32(acc0[spill_idx], 1);
        *(sum_tmp_ifm_h + 2*spill_idx + 0) = (v32int32) extract_v32acc32(acc1[spill_idx], 0);
        *(sum_tmp_ifm_h + 2*spill_idx + 1) = (v32int32) extract_v32acc32(acc1[spill_idx], 1);
    };

    // Loads both halves of row spill_idx of acc0 / acc1 back from conv_tmp_l /
    // conv_tmp_h; the counterpart of spill_start.
    auto restore_end_int16 = [&]( auto spill_idx ) __attribute__(( always_inline )) {
        // Accumulators are restored for the conv
        acc0 = insert( acc0, spill_idx, 0, *((chess_protect_access v32acc32*) conv_tmp_l + 2*spill_idx) );
        acc0 = insert( acc0, spill_idx, 1, *((chess_protect_access v32acc32*) conv_tmp_l + 2*spill_idx + 1) );
        acc1 = insert( acc1, spill_idx, 0, *((chess_protect_access v32acc32*) conv_tmp_h + 2*spill_idx) );
        acc1 = insert( acc1, spill_idx, 1, *((chess_protect_access v32acc32*) conv_tmp_h + 2*spill_idx + 1) );
    };

    // Returns one 32-lane int32 half of an activation-sum row.
    //   acc_idx - row index,
    //   accu_lh - which half of the row (0 or 1),
    //   ifm_lh  - 0 selects the "low" source (acc0 / sum_tmp_ifm_l), non-zero the
    //             "high" source (acc1 / sum_tmp_ifm_h),
    //   is_accu - when set, read the live accumulators; otherwise read the
    //             spilled copy in memory.
    auto access_ifm = [&]( auto acc_idx, int accu_lh, int ifm_lh, bool is_accu = 0 ) __attribute__(( always_inline )) {
        if ( is_accu ) {
            // Live accumulator row, converted to an int32 vector.
            if ( ifm_lh ) {
                return v32int32(extract_v32acc32(acc1[acc_idx], accu_lh));
            }
            return v32int32(extract_v32acc32(acc0[acc_idx], accu_lh));
        }

        // Spilled copy: two vector slots per row.
        if ( ifm_lh ) {
            return sum_tmp_ifm_h[2*acc_idx + accu_lh];
        }
        return sum_tmp_ifm_l[2*acc_idx + accu_lh];
    };
    
    // Combines the low and high activation-sum rows into one float vector and
    // stores it as the acc_idx-th v32float of `output`.
    //   acc_idx - row index,
    //   output  - destination, written as v32float at slot acc_idx,
    //   is_accu - forwarded to access_ifm: read the live accumulators when set,
    //             the spilled copy otherwise.
    // For each source (low / high) the two halves of the row are de-interleaved
    // and the two resulting streams are added lane by lane; the result is then
    // formed from vec_low with factor 1.0 and vec_high with factor 256.0.
    // Each interleave_unzip is evaluated once and both of its outputs reused.
    auto sum_int16_post_cache_rev_unroll = [&]( auto acc_idx, float * output, bool is_accu = 0 ) __attribute__(( always_inline )) {
            const auto unzip_low  = aie::interleave_unzip( aie::vector<int32,32>( access_ifm(acc_idx, 0, 0, is_accu) ), aie::vector<int32,32>( access_ifm(acc_idx, 1, 0, is_accu) ), 1);
            const auto unzip_high = aie::interleave_unzip( aie::vector<int32,32>( access_ifm(acc_idx, 0, 1, is_accu) ), aie::vector<int32,32>( access_ifm(acc_idx, 1, 1, is_accu) ), 1);

            aie::vector<int32,32> vec_low  = aie::add( unzip_low.first,  unzip_low.second );
            aie::vector<int32,32> vec_high = aie::add( unzip_high.first, unzip_high.second );

            aie::vector<float,32> vec_scale = (v32float) mul_elem_32( vec_low, 1.0f);
            vec_scale = (v32float) mac_elem_32(vec_high, 256.0f, (v32accfloat) vec_scale);
            // chess_report(vec_scale);
            *( ( (v32float*) output ) + acc_idx ) = vec_scale;
    };

    // Accumulates one block of the activation sum into row spill_idx of acc0 / acc1.
    //   spill_idx        - accumulator row / spill slot to use,
    //   zero_init        - when known at compile time to be 1, the row is started
    //                      from a plain multiply; otherwise the previous partial sum
    //                      is reloaded from memory and accumulated onto,
    //   compute_hilo_mix - when set, the result is combined into floats and written
    //                      to sum_ifm_qdq_in instead of being spilled as int32.
    auto sum_block_int16 = [&](  auto spill_idx, bool zero_init = 0, bool compute_hilo_mix = false ) __attribute__(( always_inline )) {
        //Accumulators are spilled and restored to create the ifm_sum
        staging_to_matrix_m64x64int8( );

        // Alternating 1/0 and 0/1 byte patterns, then masked with conv_params.mask_Ci_low.
        auto low_mask  = chess_duplicate(aie::interleave_zip(aie::broadcast<int8,64>( 1 ),aie::broadcast<int8,64>( 0 ),1).first);
        auto high_mask = chess_duplicate(aie::interleave_zip(aie::broadcast<int8,64>( 0 ),aie::broadcast<int8,64>( 1 ),1).first);

        low_mask  = locate_in_register<12, AIE_RegFile_Vector>(sel(broadcast_zero_to_v64int8(), low_mask,  conv_params.mask_Ci_low));
        high_mask = locate_in_register<13, AIE_RegFile_Vector>(sel(broadcast_zero_to_v64int8(), high_mask, conv_params.mask_Ci_low));

        if ( chess_manifest( zero_init == 1 )) {
            acc0[spill_idx] = mul( low_mask, false, false );
            acc1[spill_idx] = mul( high_mask, false, conv_rt_params->sign_A );
        } else {
            restore_start(spill_idx); /* actv to accu */
            acc0[spill_idx] = mac_conf( low_mask, false, false, acc0[spill_idx], zero_init );
            acc1[spill_idx] = mac_conf( high_mask, false, conv_rt_params->sign_A, acc1[spill_idx], zero_init );
        }

        if ( compute_hilo_mix ) {
            /* last it for the actv sum */
            sum_int16_post_cache_rev_unroll( spill_idx, (float *) sum_ifm_qdq_in, compute_hilo_mix ); /* implicit spill - last it */
        } else {
            spill_end_int16(spill_idx); /* actv sum to mem */
        }

    };

    // Sums windows of buffer_in and writes the results to buffer_out.
    //   buffer_out - destination; 8 floats are stored per outer iteration,
    //   buffer_in  - source, read as 16-lane vectors through FIFO fI,
    //   YX         - number of output values; the outer loop runs YX / 8 times,
    //   Ky, Kx     - the inner loop runs Ky * Kx times per output group,
    //   Sx         - when greater than 1 the loaded vector is shuffled with
    //                T32_16x2_lo, otherwise with T512_1x2_lo.
    // Addresses advance with dims_in_inner inside a window and dims_in_outer
    // between output groups. fI is reset on entry and on exit.
    auto sum_conv_2d_c1 = [&]( float * buffer_out, float * buffer_in, int YX, int Ky, int Kx, int Sx ) __attribute__(( always_inline )) {
        v16int32 * pI = ( v16int32 * ) buffer_in;
        float * restrict pO_conv2d = buffer_out;
        constexpr unsigned N = 8; /* Therefore, minimum amount of supported X pixels is 8 */
        constexpr unsigned samplebytes = 4;
        
        fifo_init_I();

        // int reset = -64; /* 16*4 bytes (vector load) */
        // int stepKy = Sy * samplebytes*(Sx*X + Kx - Sx);
        // dims_2d_t dims_in_inner = dims_2d_from_steps_reset( reset, Kx, samplebytes, stepKy);
        // reset = -Ky * stepKy;
        // dims_2d_t dims_in_outer = dims_2d_from_steps_reset( reset, X/N, samplebytes*N*Sx, stepKy);

        for ( unsigned pix = 0; pix < YX / N; pix++ ) 
        chess_prepare_for_pipelining chess_loop_range( 4, )
        {
            v32accfloat chess_storage(y0) sum; 
            v32float chess_storage(y1) val; 
            // 1 on the first window element so the accumulation starts fresh.
            int zero_acc_local = 1;
            for ( unsigned k = 0; k < Ky * Kx; k++ ) 
            chess_prepare_for_pipelining chess_loop_range( Ky_x_Kx, )
            {
                fifo_ld_fill( pI, fI );
                v16float data = v16float( shuffle( fifo_ld_pop_2d_byte( pI, fI, dims_in_inner ), Sx > 1 ? T32_16x2_lo : T512_1x2_lo ) );
                // Only the lower 16 lanes of val are written.
                val = insert(val, 0, data);
                sum = mac_elem_32_conf( val, 1.0f, sum, zero_acc_local, 0, 0);
                zero_acc_local = 0;
            }
            // Only the first 8 lanes of the sum are kept.
            *( (v8float *) pO_conv2d) = extract_v8float(v32float(sum), 0);
            pO_conv2d += N;
            pI = add_2d_byte( pI, dims_in_outer );
        }
        fifo_init_I();
    };

    // Dequantize/requantize one 32-lane half of an accumulator row and store it.
    //   hl_idx  - compile-time offset applied to some of the pinned vector registers,
    //   l       - accumulator row to read from acc0 / acc1,
    //   i       - half of the row (0 or 1); also selects the i-th v32float of the
    //             vectorized coefficient arrays,
    //   sum_ifm - activation-sum term, only used when has_actv_sum is set,
    //   pO_out  - destination for the 32 int16 results.
    // The int32 values from acc0 (gemm_low) and acc1 (gemm_high) are combined in
    // float with c0 as the starting value and c2 as the scale; gemm_high carries an
    // extra factor of 256 (explicitly, or through c2_shift in the scalar-c2 paths).
    // With has_actv_sum, sum_ifm * c1 is added. The result is converted to int16
    // using sign_O and the negated shift_res.
    // The body is compiled only when Tc is float.
    auto write_output = [&]<unsigned hl_idx = 0>( auto l , auto i, v32float sum_ifm, To * pO_out) __attribute__(( always_inline )) {
        if constexpr ( std::is_same_v<Tc, float> ){

            v32int32 gemm_low = ( v32int32 ) extract_v32acc32( (acc0[l]), i);
            v32int32 gemm_high = ( v32int32 ) extract_v32acc32( (acc1[l]), i);

            aie::vector<float,32>  c0_bias_vector;
            decltype(access<V_c2>( coeff_cache.c2, i ))  qdq_coeff2, c2_shifted;

            // Load c2: as a v32float when it is vectorized, otherwise through access<>
            // together with the pre-multiplied c2_shift.
            if constexpr ( vector_coeff > 1 ) {
                if constexpr ( has_actv_sum ) {
                    decltype(access<V_c2>( coeff_cache.c2, i )) qdq_coeff2_pin = locate_in_register<0 + hl_idx,AIE_RegFile_Vector>(*((chess_protect_access v32float __aie_dm_resource_b *)coeff_cache.c2 + i));
                    qdq_coeff2 = locate_in_register<0 + hl_idx,AIE_RegFile_Vector>(qdq_coeff2_pin);
                } else {
                    decltype(access<V_c2>( coeff_cache.c2, i )) qdq_coeff2_pin = (*((chess_protect_access v32float __aie_dm_resource_b *)coeff_cache.c2 + i));
                    qdq_coeff2 = (qdq_coeff2_pin);                    
                }
            } else {
                qdq_coeff2 = access<V_c2>( coeff_cache.c2, i );
                c2_shifted =  access<V_c2>( coeff_cache.c2_shift, i );
            }

            // Load the i-th v32float of c0; only the pinned register differs between the branches.
            if constexpr ( has_actv_sum ) {
                c0_bias_vector = locate_in_register<2 + hl_idx,AIE_RegFile_Vector>(*((chess_protect_access v32float __aie_dm_resource_b*)coeff_cache.c0 + i));
            } else {
                c0_bias_vector = locate_in_register<0 + hl_idx,AIE_RegFile_Vector>(*((chess_protect_access v32float __aie_dm_resource_b*)coeff_cache.c0 + i));
            }

            aie::vector<float,32> qdq_acc;
            aie::accum<accfloat,32> qdq_acc1;

            // Vector c2 (vector_coeff == 2): first form gemm_low + 256 * gemm_high,
            // then scale by c2 on top of c0.
            // Scalar c2 (vector_coeff < 2): gemm_low * c2 on top of c0, then
            // gemm_high * c2_shift on top of that.
            // The branches otherwise differ only in register pinning.
            if constexpr ( vector_coeff >= 2 && has_actv_sum ) {
                qdq_acc =  (v32float) /* locate_in_register<2,AIE_RegFile_Vector> */( mul_elem_32(gemm_low, 1.0f ) );
                qdq_acc1 = /* locate_in_register<7,AIE_RegFile_Vector> */( mac_elem_32(gemm_high, 256.0f, (v32accfloat) qdq_acc ));
                qdq_acc1 = locate_in_register<6 + hl_idx,AIE_RegFile_Vector>( mac_elem_32((v32float) qdq_acc1, qdq_coeff2, (v32accfloat) c0_bias_vector));
            } else if constexpr ( vector_coeff == 2 ) {
                qdq_acc =  (v32float) locate_in_register<3,AIE_RegFile_Vector>( mul_elem_32(gemm_low, 1.0f ) );
                qdq_acc1 = locate_in_register<4,AIE_RegFile_Vector>(mac_elem_32(gemm_high, 256.0f, (v32accfloat) qdq_acc ));
                qdq_acc1 = locate_in_register<5,AIE_RegFile_Vector>(mac_elem_32((v32float) qdq_acc1, qdq_coeff2, (v32accfloat) c0_bias_vector));
            } else if constexpr ( vector_coeff == 1 ) {
                qdq_acc =  (v32float) /* locate_in_register<3,AIE_RegFile_Vector> */( ( mac_elem_32(gemm_low, qdq_coeff2, (v32accfloat) c0_bias_vector) ));
                qdq_acc1 = /* locate_in_register<4,AIE_RegFile_Vector> */ ( ( mac_elem_32(gemm_high, c2_shifted, (v32accfloat) qdq_acc ) ));
            } else if constexpr ( vector_coeff == 0 && has_actv_sum ) {
                qdq_acc =  (v32float) ( ( mac_elem_32(gemm_low, qdq_coeff2, (v32accfloat) c0_bias_vector) ));
                qdq_acc1 = locate_in_register<6 + hl_idx,AIE_RegFile_Vector> ( ( mac_elem_32(gemm_high, c2_shifted, (v32accfloat) qdq_acc ) ));
            } else if constexpr ( vector_coeff == 0 ) {
                qdq_acc =  (v32float) ( ( mac_elem_32(gemm_low, qdq_coeff2, (v32accfloat) c0_bias_vector) ));
                qdq_acc1 = ( ( mac_elem_32(gemm_high, c2_shifted, (v32accfloat) qdq_acc ) ));
            }

            // Activation-sum term: add sum_ifm * c1, with c1 read as a vector when
            // vector_coeff >= 1 and as a scalar otherwise.
            if constexpr( has_actv_sum ) {
                if constexpr ( vector_coeff >= 2 ) {
                    v32float qdq_coeff1;
                    qdq_coeff1 = /* locate_in_register<4,AIE_RegFile_Vector> */(*((chess_protect_access __aie_dm_resource_b v32float*)coeff_cache.c1 + i));
                    qdq_acc1 = locate_in_register<6,AIE_RegFile_Vector>(mac_elem_32(sum_ifm, qdq_coeff1, (v32accfloat) qdq_acc1));
                } else if constexpr ( vector_coeff >= 1 ) {
                    v32float qdq_coeff1;
                    qdq_coeff1 = /* locate_in_register<3,AIE_RegFile_Vector> */(*((chess_protect_access __aie_dm_resource_b v32float*)coeff_cache.c1 + i));
                    qdq_acc1 = /*locate_in_register<5 + 2*hl_idx,AIE_RegFile_Vector>*/(mac_elem_32(sum_ifm, qdq_coeff1, (v32accfloat) qdq_acc1));                    
                } else {
                    float qdq_coeff1 = (access<V_c1>( coeff_cache.c1, i ));
                    qdq_acc1 = (mac_elem_32(sum_ifm, qdq_coeff1, (v32accfloat) qdq_acc1));
                }
            }

            // Convert to int16 with the runtime output sign and shift, then store.
            auto out = qdq_acc1.template to_vector_sign<int16>( conv_rt_params->sign_O, aie::neg(conv_rt_params->shift_res));

            write_v( pO_out, out );
        }
    };

    // Loop body producing one full output row (both 32-lane halves).
    // Four activation-sum scalars are read from sum_ifm_qdq_out at
    // conv_params.Co_blk*idx + (k >> conv_params.Co_shft) for k = 0..3; each is
    // broadcast to 16 lanes and they are concatenated in pairs into the low / high
    // sum vectors. The low half is written through pO_l from accumulator row iso_0,
    // the high half through pO_h from row iso_1; both pointers advance by 64 elements.
    auto output_body = [&](auto idx) __attribute__(( always_inline )) {
        float * sum = (float*) sum_ifm_qdq_out;
        
        v32float sum_ifm_l;
        v32float sum_ifm_h;

        // The insert-based variant is disabled at compile time; only the concat
        // variant below is used.
        if constexpr ( false ) {
            sum_ifm_l = (insert( undef_v32float( ), 0, broadcast_to_v16float( sum[conv_params.Co_blk*idx + ( 0 >> conv_params.Co_shft )] )));
            sum_ifm_l = (insert(         sum_ifm_l, 1, broadcast_to_v16float( sum[conv_params.Co_blk*idx + ( 1 >> conv_params.Co_shft )] )));
            sum_ifm_h = (insert(         sum_ifm_l, 0, broadcast_to_v16float( sum[conv_params.Co_blk*idx + ( 2 >> conv_params.Co_shft )] )));
            sum_ifm_h = (insert(         sum_ifm_h, 1, broadcast_to_v16float( sum[conv_params.Co_blk*idx + ( 3 >> conv_params.Co_shft )] )));
        } else {
            sum_ifm_l = (concat(broadcast_to_v16float( sum[conv_params.Co_blk*idx + ( 0 >> conv_params.Co_shft )] ), broadcast_to_v16float( sum[conv_params.Co_blk*idx + ( 1 >> conv_params.Co_shft )] )));
            sum_ifm_h = (concat(broadcast_to_v16float( sum[conv_params.Co_blk*idx + ( 2 >> conv_params.Co_shft )] ), broadcast_to_v16float( sum[conv_params.Co_blk*idx + ( 3 >> conv_params.Co_shft )] )));
        }
        
        chess_separator();

        // With the activation sum, the two halves use different hl_idx values (0 / 1)
        // and are separated by a scheduling barrier; without it both use hl_idx 0.
        if constexpr(has_actv_sum) {
            write_output.template operator()<0>(iso_0++, 0, sum_ifm_l, pO_l );
            pO_l = add_elem( pO_l, 64 );
            chess_separator();
            write_output.template operator()<1>(iso_1++, 1, sum_ifm_h, pO_h );
            pO_h = add_elem( pO_h, 64 );
        } else {
            write_output.template operator()<0>(iso_0++, 0, sum_ifm_l, pO_l );
            pO_l = add_elem( pO_l, 64 );
            write_output.template operator()<0>(iso_1++, 1, sum_ifm_h, pO_h );
            pO_h = add_elem( pO_h, 64 );
        }
    };

    // Loop body producing only the low half (half index 0) of an output row:
    // builds the low sum vector from sum entries k = 0 and 1, writes accumulator
    // row iso_0 through pO_l and advances pO_l by 64 elements.
    auto output_body_l = [&](auto idx) __attribute__(( always_inline )) {
        float   * sum = (float*) sum_ifm_qdq_out;    
        v32float sum_ifm_l = locate_in_register<3,AIE_RegFile_Vector>(concat(broadcast_to_v16float( sum[conv_params.Co_blk*idx + ( 0 >> conv_params.Co_shft )] ), broadcast_to_v16float( sum[conv_params.Co_blk*idx + ( 1 >> conv_params.Co_shft )] )));
        write_output(iso_0++, 0, sum_ifm_l, pO_l );
        pO_l = add_elem( pO_l, 64 );
    };

    // Loop body producing only the high half (half index 1) of an output row:
    // builds the high sum vector from sum entries k = 2 and 3, writes accumulator
    // row iso_1 through pO_h and advances pO_h by 64 elements.
    auto output_body_h = [&](auto idx) __attribute__(( always_inline )) {
        float * sum = (float*) sum_ifm_qdq_out;
        v32float sum_ifm_h = locate_in_register<3,AIE_RegFile_Vector>(concat(broadcast_to_v16float( sum[conv_params.Co_blk*idx + ( 2 >> conv_params.Co_shft )] ), broadcast_to_v16float( sum[conv_params.Co_blk*idx + ( 3 >> conv_params.Co_shft )] )));
        write_output(iso_1++, 1, sum_ifm_h, pO_h );
        pO_h = add_elem( pO_h, 64 );
    };

    // Loads granYX weight vectors into staging, then moves the staged data into
    // the matrix register.
    auto prefetch_weights = [&]() __attribute__(( always_inline )) {
        aie::pipelined_loop<granYX, aie::LoopOptions{.peel_front = 0, .peel_back = 0}>(granYX, weights_body);
        staging_to_matrix_m64x64int8();
    };

    // Runs the MAC body for granYX steps (unrolled by unroll_factor, last iteration
    // peeled) and clears zero_acc so later passes accumulate.
    auto compute = [&]() __attribute__(( always_inline )) {
        aie::pipelined_loop<granYX/unroll_factor, aie::LoopOptions{.peel_front = 0, .peel_back = 1}>(granYX/unroll_factor, unroll_fn<unroll_factor>(mac_body));
        zero_acc = 0;
    };

    // Writes granYX output rows. With vector_coeff < 2 both halves of a row are
    // produced in a single loop; otherwise all low halves are written first and
    // all high halves in a second loop.
    auto output = [&]() __attribute__(( always_inline )) {
        if constexpr ( vector_coeff < 2 ) {
            aie::pipelined_loop<granYX, aie::LoopOptions{.peel_front = 0, .peel_back = 0}>(granYX, output_body);
        } else {
            aie::pipelined_loop<granYX, aie::LoopOptions{.peel_front = 0, .peel_back = 0}>(granYX, output_body_l);
            aie::pipelined_loop<granYX, aie::LoopOptions{.peel_front = 0, .peel_back = 0}>(granYX, output_body_h);
        }
    };

    // Combined loop: stages the next granYX weight vectors (weights_body) while
    // running the MAC body on the current ones, then moves the staged weights into
    // the matrix register and clears zero_acc.
    auto fetch_and_compute = [&]() __attribute__(( always_inline )) {

        // Peeling options for the two loop bodies ([0] weights, [1] MAC), chosen at
        // compile time from has_actv_sum and from whether Tw is a stream type.
        constexpr auto opts_lambda = []() -> std::array<aie::LoopOptions,2> {
            if constexpr(has_actv_sum) {
                if constexpr( is_stream_type_v<Tw> ) {
                    return { aie::LoopOptions{.peel_front = 4, .peel_back = 1}, aie::LoopOptions{.peel_front = 3, .peel_back = 2} };
                } else {
                    return { aie::LoopOptions{.peel_front = 4, .peel_back = 1}, aie::LoopOptions{.peel_front = 2, .peel_back = 3} };
                }
            } else {
                if constexpr( is_stream_type_v<Tw> ) {
                    return { aie::LoopOptions{.peel_front = 5, .peel_back = 1}, aie::LoopOptions{.peel_front = 4, .peel_back = 2} };
                } else {
                    return { aie::LoopOptions{.peel_front = 5, .peel_back = 1}, aie::LoopOptions{.peel_front = 3, .peel_back = 3} };
                }
            }
        }();

        aie::pipelined_loops<granYX/unroll_factor, opts_lambda[0], opts_lambda[1]>(granYX/unroll_factor, unroll_fn<unroll_factor>(weights_body), unroll_fn<unroll_factor>(mac_body));
        staging_to_matrix_m64x64int8();
        zero_acc = 0;
    };

    // Combined loop over granYX steps running output_body and mac_body together,
    // each with its own peeling options.
    // NOTE(review): the relative ordering of the two bodies is determined by
    // aie::pipelined_loops and the peel settings - confirm against its documentation.
    auto compute_and_output = [&]() __attribute__(( always_inline )) {
        aie::pipelined_loops<granYX, aie::LoopOptions{.peel_front = 1, .peel_back = 2}, aie::LoopOptions{.peel_front = 2, .peel_back = 1}>(granYX, (output_body), (mac_body));
    };

    // Processes conv_params.sum_bound activation-sum blocks. Each iteration
    // accumulates the currently staged data into row idx_s (then advances idx_s)
    // and fetches the next granYX activation vectors: as a pipelined loop when
    // compute_conv2d is false, fully unrolled otherwise.
    // compute_conv2d is also forwarded as sum_block_int16's compute_hilo_mix flag.
    // ib is reset to 0 on exit.
    auto actv_sum_inner_first_blocks = [&]<bool compute_conv2d>(auto zero_init) __attribute__(( always_inline )) {
        auto z = zero_init;
        /* accum on Ci */
        for ( int r=0; r < conv_params.sum_bound; r++ ) 
        {
            sum_block_int16( idx_s++, z, compute_conv2d);
            if (chess_manifest(compute_conv2d == false))
            {
                sum_fetch_32.template operator()<3,2>();
            } else {
                sum_fetch_32_unroll();
            }
        }

        ib = 0;
        // ib = chess_copy( ib );
    };

    // Processes the final activation-sum block on row idx_s (idx_s is not advanced).
    // Unless is_last is set, another granYX activation vectors are fetched afterwards.
    auto actv_sum_inner_last_block = [&]<bool compute_conv2d, bool is_last>(auto zero_init) __attribute__(( always_inline )) {
        auto z = zero_init;
        if (true) {
            sum_block_int16( idx_s, z, compute_conv2d);
        }

        if constexpr(is_last == false) {
            sum_fetch_32.template operator()<3,2>();
        }
    };

    // Computes the activation sum while preserving the convolution accumulators.
    // Does nothing unless has_actv_sum is set. Steps:
    //   1. spill sum_bound + 1 accumulator rows of the convolution to memory,
    //   2. for each of conv_params.sum_outer passes: fetch the first block, process
    //      sum_bound blocks, reset the sum descriptor counters, re-point pAs at the
    //      next step_Ci offset from pA_in, and process the last block,
    //   3. restore the sum_bound + 1 convolution rows.
    // zero_init applies to the first outer pass only; later passes accumulate.
    auto actv_sum = [&](auto zero_init) __attribute__(( always_inline )) {

        bool do_actv_sum = has_actv_sum;
        auto z = zero_init;

        if ( do_actv_sum ) {

            idx_s = 0;
            for ( int i=0; i < conv_params.sum_bound + 1; i++ ) 
            // chess_prepare_for_pipelining chess_loop_range(1,) 
            {
                spill_start(idx_s++); /* conv to mem */
                // idx_s = chess_copy(idx_s);
            }
            
            for ( int i=0; i < conv_params.sum_outer; i++ )
            {
                ib = 0;
                // ib = chess_copy( ib );
                idx_s = 0;

                sum_fetch_32.template operator()<0,0>();
                actv_sum_inner_first_blocks.template operator()<false>(z);
                dims_sum_actv.count1 = 0;
                dims_sum_actv.count2 = 0;
                // NOTE(review): the sum loads in sum_fetch go through fB, but the FIFO
                // reset here is fifo_init_A() (fA) - confirm fifo_init_B() was not intended.
                fifo_init_A();
                pAs = (int8 *) byte_incr(pA_in, (i+1)*conv_params.step_Ci); // add Ci step
                actv_sum_inner_last_block.template operator()<false,true>(z);
                z = 0;
            }

            idx_r = 0;
            for ( int i=0; i < conv_params.sum_bound + 1; i++ )
            // chess_prepare_for_pipelining chess_loop_range(1,) 
            {
                restore_end_int16(idx_r++); /* conv back to accus */
                // idx_r = chess_copy(idx_r);
            }

        }            
    };

    // Top-level schedule without the activation-sum step. For each of `iters`
    // outer iterations:
    //   - reset state, acquire the weights and stage the first weight block;
    //   - for each of ol_bound input buffers: acquire it, run il_bound - 1
    //     fetch-and-compute passes and release the weights. On every buffer except
    //     the last, re-acquire the weights, run one more fetch-and-compute pass,
    //     reset fA and release the input buffer;
    //   - load the QDQ coefficients into the cache;
    //   - run the final MAC pass interleaved with the output conversion;
    //   - reset fA and release the last input buffer.
    // NOTE(review): `it` is unsigned while `iters` is int; a negative iters would
    // convert to a very large trip count.
    auto direct_conv_symmetric = [&]() __attribute__(( always_inline )) {
        for ( unsigned it=0; it<iters; it++ )
#if 0
        #if !defined( DEBUG_KERNEL ) && !defined( DEBUG_TESTBENCH )
            chess_prepare_for_pipelining chess_loop_range( it_lr, )
        #endif
        #ifdef __chess__
            chess_allocate( R:28 )
            chess_allocate( P:8 )
        #endif
#endif
        {
            conv_init();

            weight_acquire();
            prefetch_weights();

            for (int ol=0; ol<ol_bound; ol++)
            chess_prepare_for_pipelining chess_loop_range(ol_lr,)
            {
                ifm_acquire();
                // Starts at 1: the remaining pass for this buffer is either the
                // fetch_and_compute below or, for the last buffer, compute_and_output.
                for (int il=1; il<il_bound; il++) 
                chess_prepare_for_pipelining chess_loop_range(il_lr-1,)
                {
                    fetch_and_compute();
                }

                weight_release();

                // Last buffer: leave the loop with the input still acquired; it is
                // released after the output has been written.
                if ( ol == ol_bound - 1 )
                    break;

                weight_acquire();

                fetch_and_compute();

                fifo_init_A();
                ifm_release();
            }
            
            chess_separator();

            coeff_acquire();
            coeff_fetch();
            coeff_release();
            
            chess_separator();

            ofm_acquire();
            compute_and_output();
            ofm_release();

            fifo_init_A();
            ifm_release();
        }
    };

    // Kernel main loop used when has_actv_sum is true (see the dispatch at the end of the
    // kernel). Runs `iters` iterations. Every ol iteration performs the convolution, the
    // activation-sum pass (actv_sum) and a coefficient fetch; after the ol loop the
    // activation sums are post-processed and the output is written.
    //
    // Acquire/release pairing inside one iteration: weights and ifm are acquired together
    // before the ol loop, released together at the end of every ol iteration, and re-acquired
    // only when another ol iteration follows.
    auto direct_conv_asymmetric = [&]() __attribute__(( always_inline )) {
        for ( unsigned it=0; it<iters; it++ )
#if 0
        #if !defined( DEBUG_KERNEL ) && !defined( DEBUG_TESTBENCH )
            chess_prepare_for_pipelining chess_loop_range( it_lr, )
        #endif
        #ifdef __chess__
            chess_allocate( R:28 )
            chess_allocate( P:8 )
        #endif
#endif
        {
            conv_init();

            weight_acquire();
            ifm_acquire();

            prefetch_weights(); /* collision with actv_sum in z reg */

            for (int ol=0; ol < ol_bound; ol++) 
            chess_prepare_for_pipelining chess_loop_range(ol_lr,)
            {
                // il starts at 1: il_bound - 1 fetch_and_compute() calls, followed by a
                // single compute() with no fetch.
                for (int il = 1; il < il_bound; il++) 
                chess_prepare_for_pipelining chess_loop_range(il_lr-1,)
                {
                    fetch_and_compute(); 
                }
                compute();
                fifo_init_A(); /* End of life for direct conv FIFO */

                actv_sum(zero_acc_actvs);
                fifo_init_B(); /* End of life for sum of actv FIFO */
                
                coeff_acquire();
                coeff_fetch();
                coeff_release();
                chess_separator_scheduler();
                weight_release();
                ifm_release();

                // From here on actv_sum() receives 0, unless zero_acc_actvs is set again
                // elsewhere (e.g. in conv_init(), not visible here - TODO confirm).
                zero_acc_actvs = 0; /* Accumulate over sum of actv accus */

                // Final ol iteration: nothing is re-acquired and no prefetch is issued.
                if ( ol == ol_bound - 1 )
                    break;
                
                weight_acquire();
                ifm_acquire();

                prefetch_weights(); /* collision with actv_sum in z reg */
            }

            if constexpr( has_actv_sum ) {
                
                // Per-accumulator post-processing; each call is handed sum_ifm_qdq_in as float*.
                for (int n = 0; n < conv_params.n_accus; n++) {
                    sum_int16_post_cache_rev_unroll( n, (float *) sum_ifm_qdq_in );
                }

                chess_separator();

                // Presumably reads sum_ifm_qdq_in and writes sum_ifm_qdq_out, going by the
                // buffer names - TODO confirm against sum_conv_2d_c1.
                sum_conv_2d_c1( (float *) sum_ifm_qdq_out, (float *) sum_ifm_qdq_in, ((unsigned) conv_params.Co_blk)*granYX, conv_params.Ky_g, conv_params.Kx_g, conv_params.Sx_g);
            }
            
            chess_separator_scheduler();

            ofm_acquire();
            output();
            ofm_release();
        }
    };

    // NOTE: an `is_fileio` constant used to be defined here under `#ifdef FILE_IO`. Both
    // preprocessor branches set it to 0 and no statement up to the end of the kernel read
    // it, so the dead definition has been removed.

    chess_separator_scheduler( );

    // Compile-time selection of the kernel main loop: the asymmetric variant includes the
    // activation-sum pass, the symmetric variant does not.
    if constexpr(has_actv_sum) {
        direct_conv_asymmetric();
    } else {
        direct_conv_symmetric();
    }
    
    chess_separator_scheduler( );

    event1( );
}



#endif // __ACTIVATED_CONV_QDQ_INT16X8_TEMPLATE_H__