/*  (c) Copyright 2019 - 2024 Xilinx, Inc. All rights reserved.

    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.

    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.

    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.

    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */

#ifndef __ACTIVATED_MMULT_QDQ_INT16X16_TRANSPOSE_TEMPLATE_H__
#define __ACTIVATED_MMULT_QDQ_INT16X16_TRANSPOSE_TEMPLATE_H__

#include "aie_api/aie.hpp"
#include "aie_api/utils.hpp"
#include "../common/common.hh"

#include "assert.h"
#include "access_helpers.hpp"
#ifdef DEBUG_KERNEL
#include "stdio.h"
#endif

namespace activated_mmult_qdq_int16x16_transpose_helpers {

/* Weights split.
 *
 * Runs `count_v128` iterations. Each iteration takes two v128int8 words a0, a1
 * (read from `buf`, or from the weight stream when `stream` is set) and writes
 * two reshuffled v128int8 words back into `buf`. The write pointer advances at
 * the same rate as the read pointer, so the non-stream variant works in place.
 *
 *   b0 = { shuffle( a0, mode0 ), shuffle( a1, mode0 ) }
 *   b1 = { shuffle( a0, mode1 ), shuffle( a1, mode1 ) }
 *
 * where mode0/mode1 are T8_64x2_lo/T8_64x2_hi, or T512_1x2_lo for both when
 * `is_split` is set.
 *
 * Template parameters:
 *   stream - take the input words from the weight stream instead of `buf`.
 *   lr     - minimum loop count hint forwarded to the chess pipeliner.
 */
template<bool stream=0, unsigned lr=12>
inline void split_i16( int16_t * buf, int count_v128, auto is_split ) {
    v128int8 chess_storage( DM_bankA ) * pI = ( v128int8 chess_storage( DM_bankA ) * ) buf;
    v128int8 chess_storage( DM_bankA ) * restrict pO = ( v128int8 chess_storage( DM_bankA ) * ) buf;
    // The shuffle modes depend only on is_split, so select them once instead of
    // on every pipelined iteration.
    const auto mode0 = is_split ? T512_1x2_lo : T8_64x2_lo;
    const auto mode1 = is_split ? T512_1x2_lo : T8_64x2_hi;
    [[using chess: prepare_for_pipelining, min_loop_count( lr )]]
    for ( unsigned i=0; i < count_v128; i++ ) {
        v128int8 a0, a1;
        if constexpr( stream ) {
            a0 = get_ss_v128int8_weight( 1 );
            a1 = get_ss_v128int8_weight( 1 );
        } else {
            a0 = *pI++;
            a1 = *pI++;
        }
        v128int8 b0 = set_v128int8( 0, shuffle( a0, mode0 ));
        v128int8 b1 = set_v128int8( 0, shuffle( a0, mode1 ));
        b0 = insert( b0, 1, shuffle( a1, mode0 ));
        b1 = insert( b1, 1, shuffle( a1, mode1 ));
        *pO++ = b0;
        *pO++ = b1;
    }
}

/* Load N elements at index i of p via access<N>( p, i ).
 * For vector loads (N > 1) the result is pinned to vector register `reg` with
 * locate_in_register; scalar loads are returned as-is.
 */
template<unsigned N, unsigned reg, typename T> inline __attribute__(( always_inline ))
auto access_loc( T * p, int i ) {
    decltype(  access<N>( p, i ) ) load_v;
    // `if constexpr` so locate_in_register is only instantiated for vector
    // loads; a plain `if` on a template parameter instantiates both branches.
    if constexpr( N > 1 )
        load_v = locate_in_register<reg>(access<N>( p, i ));
    else
        load_v = access<N>(p, i );
    return load_v;
}

/* Multiply-accumulate with a scalar int32 multiplicand: returns
 * a + broadcast( x ) * y, where x is broadcast to the lane count of `a`.
 */
template<typename Ta, typename Ty> inline __attribute__(( always_inline ))
Ta aie_mac( Ta a, int32 x, Ty y ) {
    return aie::mac( a, aie::broadcast<int32, a.size( )>( x ), y );
}
/* 32-lane float accumulator specialization: broadcasts x and uses the
 * mac_elem_32 intrinsic (int32 vector times float scalar) directly.
 */
template<> inline __attribute__(( always_inline ))
aie::accum<accfloat, 32> aie_mac<aie::accum<accfloat, 32>, float>( aie::accum<accfloat, 32> a, int32 x, float y ) {
    v32int32 vx = broadcast_to_v32int32( x );
    return aie::accum( mac_elem_32( vx, y, a ));
}

/* Multiply-accumulate with an N-lane int32 vector: returns a + x * y. */
template<typename Ta, typename Ty, unsigned N> inline __attribute__(( always_inline ))
Ta aie_mac( Ta a, aie::vector<int32,N> x, Ty y ) {
    return aie::mac( a, x, y );
}
/* 32-lane float accumulator specialization via the mac_elem_32 intrinsic. */
template<> inline __attribute__(( always_inline ))
aie::accum<accfloat, 32> aie_mac<aie::accum<accfloat, 32>, float, 32>( aie::accum<accfloat, 32> a, aie::vector<int32, 32> x, float y ) {
    return aie::accum( mac_elem_32( x, y, a ));
}


/* Generic multiply: forwards to aie::mul. */
template<typename Ta, typename Tb>
inline auto aie_mul( Ta a, Tb b ) {
    return aie::mul( a, b );
}
/* Scalar float specialization: uses the scalar `mul` intrinsic. */
template<>
inline auto aie_mul( float a, float b ) {
    return mul( a, b );
}


/* Take the i-th V-lane slice of the 64-lane acc32 accumulator x, reinterpret it
 * as an int32 vector and accumulate it into a: returns a + slice_i( x ) * y.
 */
template<typename Ta, unsigned V, typename Ty> inline __attribute__(( always_inline ))
aie::accum<Ta, V> aie_mac_extract( aie::accum<Ta, V> a, aie::accum<acc32, 64> x, unsigned i, Ty y ) {
    return aie_mac( a, x.extract<V>( i ).template to_vector<int32>( ), y );
}
/* 32-lane float specialization: extract_v32acc32 + mac_elem_32 intrinsics. */
template<> inline __attribute__(( always_inline ))
aie::accum<accfloat, 32> aie_mac_extract<accfloat, 32, float>( aie::accum<accfloat, 32> a, aie::accum<acc32, 64> x, unsigned i, float y ) {
    return aie::accum( mac_elem_32(( v32int32 ) extract_v32acc32( x, i ), y, a ));
}

/* Copy (or broadcast) No coefficients of type T from `src` into `dst`.
 *
 * No == 1: dst[0] = *src. With `scaling`, also dst[1] = *src * 256 and
 *          dst[2] = *src * 65536.
 * No >  1: handles the No values in `iters` chunks of V lanes (one chunk per
 *          128 bytes of T data). Each chunk is add_conf of a broadcast operand
 *          (0 when `vector`, otherwise the bit pattern of *src) and the V
 *          values loaded from src, with add_conf's zeroing flag set to !vector.
 *          NOTE(review): the intent appears to be "copy src when vector,
 *          broadcast *src otherwise" - confirm which operand add_conf zeroes.
 *          With `scaling`, the chunk scaled by 256 is stored at dst + No and
 *          the chunk scaled by 65536 at dst + 2 * No.
 *
 * `dst` must have room for No values, or 3 * No values when `scaling` is set.
 * The chunked path casts through v32acc32/v32float, i.e. it assumes V == 32.
 */
template<unsigned No, typename T>
inline void copy_broadcast( T * restrict dst, T * src, bool vector=1, bool scaling=0 ) {
    if constexpr( No == 1 ) {
        *dst = *src;
        if ( scaling ) {
            dst[1] = mul( *src, 256.0f );
            dst[2] = mul( *src, 65536.0f );
        }
    } else {
        constexpr unsigned size = No * sizeof( T );
        constexpr unsigned iters = ( size + 127 ) / 128;
        constexpr unsigned V = No / iters;
        int32_t s = as_int32( *src );
        for ( unsigned it = 0; it < iters; it++ ) {
            aie::vector c = ( v32float ) add_conf(( v32acc32 ) aie::broadcast<int32,V>( vector ? int32( 0 ) : s ), ( v32acc32 ) aie::load_v<V>( src + it * V ), !vector, 0, 0 );
            aie::store_v( dst + it * V, c );
            if ( scaling ) {
                aie::store_v( dst + No + it * V, aie::mul( c, 256.0f ).template to_vector<float>( ));
                aie::store_v( dst + 2 * No + it * V, aie::mul( c, 65536.0f ).template to_vector<float>( ));
            }
        }
    }
}

}

/* Temporary buffer for input-feature-map (IFM) rows.
 * In activated_mmult_qdq_int16x16_transpose, the first-half compute pass
 * stores each 128-byte IFM read here (through `pitw`), and the second-half
 * pass reads it back (through `pitr`) instead of reading pA again.
 * Being `static` at namespace scope in a header, every translation unit that
 * includes this file gets its own private copy. */
static v128int8 ifm_tmp_buf[16];

#ifdef __chess_clang__
template<unsigned has_actv_sum, int has_vector_coeffs, typename T_cf, class Bi, class Bb, class Bc, class Bo>
#else
template<unsigned has_actv_sum, int has_vector_coeffs, typename T_cf, class Bi, class Bb, class Bc, class Bo>
#endif
ALWAYS_INLINE void activated_mmult_qdq_int16x16_transpose
(
        Bi & restrict bufA,
        Bb & restrict bufB,
        Bc & restrict bufC,
        Bo & restrict bufO,
        int8 * tdm_buf,
        int8 * wght_transpose_sb,
        int8 * restrict cfqdq_buf,
        T_cf * scalar_coeffs,
        int vector_coeffs,
        const ActivatedMMultKernelParamsTranspose &params
) {
    bool has_tlast=1;
    using Ti =  int16; //buffer_element_t<Bi>;
    using Tb =  int16; //buffer_element_t<Bb>;
    using To =  int16; //buffer_element_t<Bo>;
    using Tc =  float; //buffer_element_t<Bc>;

    GemmQdqint16x16_RT_Params* mmult_rt_params = reinterpret_cast<GemmQdqint16x16_RT_Params*>(params.qdq_buf);
    using namespace activated_mmult_qdq_int16x16_transpose_helpers;

    static_assert( !is_stream_type_v<Ti> || !is_stream_type_v<Tb> );
    static_assert( std::is_same_v<T_cf, Tc> );

    using T_c0 = get_next_type_t<T_cf>;
    constexpr unsigned V_qdq = ( std::is_same_v<T_cf, float> ) ? 32 : 16;
    constexpr unsigned N_c0 = has_vector_coeffs > 0 ? 64 : 1;
    constexpr unsigned N_c1 = ( has_vector_coeffs > 1 ? 64 : 1 ) * ( has_actv_sum > 0 );
    constexpr unsigned N_c2 = has_vector_coeffs > 1 ? 64 : 1;
    constexpr unsigned N_c3 = has_vector_coeffs > 1 ? 64 : 1;
    constexpr unsigned V_c0 = std::min( V_qdq, N_c0 );
    constexpr unsigned V_c1 = std::min( V_qdq, N_c1 );
    constexpr unsigned V_c2 = std::min( V_qdq, N_c2 );
    constexpr unsigned V_c3 = std::min( V_qdq, N_c3 );
    unsigned offset_c1 = ( vector_coeffs == 1 ? 0 : ( vector_coeffs == 0 ? 1 : 64 )) * sizeof( T_c0 ) / sizeof( T_cf );
    unsigned offset_c2 = offset_c1 + ( vector_coeffs <= 1 ? 1 : 64 ) * ( has_actv_sum > 0 );
    unsigned offset_c3 = offset_c2 + ( vector_coeffs <= 1 ? 1 : 64 );

    constexpr unsigned granM  = 16;
    constexpr unsigned granK  = 64;
    constexpr unsigned granN  = 64;
    constexpr unsigned il_block = 32;
    constexpr bool stream_kernel = is_stream_type_v<Ti> && is_stream_type_v<Tb>;
    constexpr bool stream_any = is_stream_type_v<Ti> || is_stream_type_v<Tb> || is_stream_type_v<To>;

    #ifdef DEBUG_KERNEL
    constexpr unsigned il_lr = 1;
    constexpr unsigned il_peel = 0;
    constexpr unsigned ol_lr = 1;
    constexpr unsigned it_lr = 1;
    constexpr bool pm_opt = 0;
    #else
    constexpr unsigned il_lr = 2;
    constexpr unsigned il_peel = 0;
    constexpr unsigned ol_lr = 1;
    constexpr unsigned it_lr = 1;
    constexpr bool pm_opt = 0;
    #endif


    int tsl_bound = params.tsl_bound;
    int il_bound = params.inner_g;
    int ol_bound = 1;
    int tl_bound = params.inner_time_iters;
    if constexpr( stream_kernel ) {
        il_bound *= ol_bound;
        ol_bound = 1;
    }

    unsigned keep_sum_iters = 1;
    unsigned keep_sum_cnt = 0;

    int zero_acc = 1;
    int zero_sum = 1;
    bool is_tdm = params.inner_time_iters > 0;

    int8 * pA;
    int8 * pAs;
    int8 * pATdm;
    Tb * pB;
    Tb * pBTr;
    int8 * restrict pBTw;
    Tc * pC;
    To * pO;

    dims_3d_t dimsA = params.dimsA.instantiate();
    dims_3d_t dimsB = params.dimsB.instantiate();
    dims_2d_t dimsQ = params.dimsQ.instantiate_step();
    dims_2d_t dimsAs = params.dimsAs.instantiate();
    dims_3d_t dimsW = params.dimsW.instantiate();

    /* Accumulators */
    m32x64acc32 chess_storage( em0 ) acc0 = chess_dont_care( m32x64acc32 );
    m32x64acc32 chess_storage( em1 ) acc1 = chess_dont_care( m32x64acc32 );

    v128int8 * pitr = ifm_tmp_buf;
    v128int8 * restrict pitw = ifm_tmp_buf;
    dims_2d_t ditr = dims_2d_from_steps( 16, 1, 0 );
    dims_2d_t ditw = dims_2d_from_steps( 16, 1, 0 );

    aie::vector<int8, 64> a0, a1;
    aie::vector mask = aie::zeros<int8, 64>( );
    mask[0] = 1;

    /* Indices */
    uint5_t chess_storage( i0 ) im0 = chess_copy( 0 );
    uint5_t chess_storage( i1 ) im1 = chess_copy( 0 );
    uint5_t chess_storage( i2 ) ib  = chess_copy( 0 );
    uint5_t chess_storage( i3 ) ia  = chess_copy( 0 );


    struct coeff_cache_type {
        alignas( 128 ) T_c0 c0[64];
        T_cf c1[N_c1];
        T_cf c2[N_c2*3];
        T_cf c3[N_c3];
    };

    coeff_cache_type * restrict coeff_cache_ptr = (coeff_cache_type*) cfqdq_buf;
    coeff_cache_type &coeff_cache = *coeff_cache_ptr;

    union sum_tdm_type{
        v64acc32 a;
        int32 s[64];
        aie::vector<int32,16> v[4];
        float f[64];
    };
    v32int32 * wght_transpose_tdm = ( v32int32 * ) wght_transpose_sb;
    sum_tdm_type * restrict sum_tdm = (sum_tdm_type*) tdm_buf;
    
    //collect qdq coefficients from weight stream
    auto store_stream_coeff = [&]( auto * pCo, auto * pCs, unsigned N, unsigned vec ) __attribute__(( always_inline )) {
        #pragma unroll
        for ( unsigned l = 0; l < 2; l++ ) {       
            if ( N == 1 ) {
                *pCo = *pCs;
            } else if ( vec == 0 ) {
                aie::store_v( pCo + 32 * l, aie::broadcast<float, 32>( pCs ));
            } else {
                aie::vector<float,32> load_stream =  (v32float) aie::utils::locate_in_register<4,aie::utils::AIE_RegFile::Vector>( get_ss_v128int8_weight( 1 )); 
                aie::store_v( pCo + 32 * l, load_stream );
            }
        }
    };

    auto coeff_fetch = [&]( ) __attribute__(( always_inline )) {
        if constexpr (!is_stream_type_v<Tb>)
        {
            T_c0 * c0 = ( T_c0* )( vector_coeffs > 0 ? pC : scalar_coeffs );
            T_cf * c1 = ( vector_coeffs > 1 ? pC : scalar_coeffs ) + offset_c1;
            T_cf * c2 = ( vector_coeffs > 1 ? pC : scalar_coeffs ) + offset_c2;
            T_cf * c3 = ( vector_coeffs > 1 ? pC : scalar_coeffs ) + offset_c3;
            copy_broadcast<N_c0>( coeff_cache.c0, c0, vector_coeffs > 0 );
            if constexpr(has_actv_sum > 0) {
                copy_broadcast<N_c1>( coeff_cache.c1, c1, vector_coeffs > 1 );
            }
            copy_broadcast<N_c2>( coeff_cache.c2, c2, vector_coeffs > 1, 1 );
            if constexpr(has_actv_sum >= 2) {
                copy_broadcast<N_c3>( coeff_cache.c3, c3, vector_coeffs > 1 );
            };
        } else {
            T_c0 * c0 = scalar_coeffs;
            T_cf * c1 = scalar_coeffs + offset_c1;
            T_cf * c2 = scalar_coeffs + offset_c2;
            T_cf * c3 = scalar_coeffs + offset_c3;
            store_stream_coeff( coeff_cache.c0, c0, N_c0, vector_coeffs > 0 );
            if constexpr(has_actv_sum > 0) {
                store_stream_coeff( coeff_cache.c1, c1, N_c1, vector_coeffs > 1);
            }
            store_stream_coeff( coeff_cache.c2, c2, N_c2, vector_coeffs > 1);
            if constexpr(has_actv_sum >= 2) {
                store_stream_coeff( coeff_cache.c3, c3, N_c3, vector_coeffs > 1 );
            };
        }
    };

    auto spill = [&]( auto * buffer, auto & accs ) __aie_inline {
        const unsigned idx = 0;
        *(( chess_protect_access v32int32 * ) buffer     ) = ( v32int32 ) extract_v32acc32( accs[idx], 0 );
        *(( chess_protect_access v32int32 * ) buffer + 1 ) = ( v32int32 ) extract_v32acc32( accs[idx], 1 );
    };

    auto restore = [&]( auto * buffer, auto & accs ) __aie_inline {
        const unsigned idx = 0;
        accs = insert( accs, idx, 0, *(( chess_protect_access v32acc32 * ) buffer     ));
        accs = insert( accs, idx, 1, *(( chess_protect_access v32acc32 * ) buffer + 1 ));
    };

    auto fetch_body = [&]( auto idx ) __aie_inline {
        insert_staging( read_v<128>(( int8 * ) pBTr ), ia );
        pBTr = add_byte( pBTr, 128 );
        ia++;
    };

    auto transpose_half = [&]( auto idx ) __aie_inline {
        constexpr unsigned idx2 = 0;
        acc0[idx2] = mul( mask );
        aie::store_v( pBTw, aie::vector( to_v64int8( acc0[idx2], 0 )));
        mask = aie::shuffle_up_rotate( mask, 1 ); 
        pBTw = add_3d_byte( pBTw, dimsB );
    };

    auto transpose_body = [&]( auto fetch ) __aie_inline {
        auto transpose_body2 = [&]( auto idx ) __aie_inline {
            transpose_half( idx );
            transpose_half( idx );
        };
        if ( fetch ) {
            aie::pipelined_loops<4, aie::LoopOptions{.peel_front = 1, .peel_back = 2}, aie::LoopOptions{.peel_front = 2, .peel_back = 1}>( il_block, transpose_body2, fetch_body );
            staging_to_matrix_m64x64int8( );
        } else {
            aie::pipelined_loop<4, aie::LoopOptions{.peel_front = 2, .peel_back = 1}>( il_block, transpose_body2 );
        }
    };

    auto weight_acquire = [&]( ) __attribute__(( always_inline )) {
        bufB.acquire( );
        pB = (int16*) bufB.data();
        if ( params.transpose ) {
            if constexpr( !is_stream_type_v<Tb> ) {
                split_i16( pB, granN / 2 * params.inner_g * params.X_g, params.is_split );
                bufC.acquire();
                pC = (float*) bufC.data();
            } else {
                //TODO STREAM W
                split_i16<1>( pB, granN / 2 * params.inner_g * params.X_g, params.is_split );
            }
        } else {
            pBTr = chess_copy( pB );
            pBTw = ( int8 * ) chess_copy( pB );
            spill( wght_transpose_tdm, acc0 );
            aie::pipelined_loop<4, aie::LoopOptions{.peel_front = 0, .peel_back = 0}>( il_block, fetch_body );
            staging_to_matrix_m64x64int8( );
            for ( unsigned i = 1; i < tsl_bound; i++ )
                chess_prepare_for_pipelining
                chess_loop_range( 2, )
            {
                transpose_body( 1 );
            }
            transpose_body( 0 );
            restore( wght_transpose_tdm, acc0 );
            bufC.acquire();
            pC = (float*) bufC.data();
        }
    };

    auto weight_release = [&]( ) __attribute__(( always_inline )) {
        bufB.release( );
        bufC.release( );
    };

    auto ifm_acquire = [&]( ) __attribute__(( always_inline )) {
        bufA.acquire( );
        pA = (int8 *)bufA.data();
        if constexpr( !is_stream_type_v<Ti> ) {
            pAs = chess_copy( pA );
        }
    };

    auto ifm_release = [&]( ) __attribute__(( always_inline )) {
        bufA.release( );
    };

    auto ofm_acquire = [&]( ) __attribute__(( always_inline )) {
        bufO.acquire( );
        pO = (int16*) bufO.data( );
    };

    auto ofm_release = [&]( ) __attribute__(( always_inline )) {
        bufO.release( );
    };

    auto weight_fetch = [&]( ) __attribute__(( always_inline )) {
        v128int8 w = read_v<64>( pB ).template cast_to<int8>( );
        pB = add_3d_byte( pB, dimsW );
        insert_staging(w , ib++, 2 + mmult_rt_params->sign_W);
        ib = chess_copy( ib );
    };

    auto compute_prepare = [&]( auto half ) __attribute__(( always_inline )) {
        staging_to_matrix_m64x64int8( );
        if ( params.transpose ) {
            if ( half ) {
                pB = add_3d_byte( add_elem( pB, -128 ), dimsB );
            } else {
                pB = add_byte( pB, -( granN - 1 ) * 128);
            }
        } else {
            if ( half ) {
                dimsW.count1 = 0;
                dimsW.count2 = 0;
                pB = add_byte( pB, ( granK * granN ) - 256 );
            }
        }
    };

    auto report_acc = [&]( unsigned l=0 ) __attribute__(( always_inline )) {
        v16acc64 acc = mul_elem_16(( v16int32 ) extract_v16acc32( acc0, l, 0 ), 1 );
        acc = mac_elem_16(( v16int32 ) extract_v16acc32( acc1, l, 0 ), 256, acc );
        acc = mac_elem_16(( v16int32 ) extract_v16acc32( acc0, l+16, 0 ), 65536, acc );
        chess_report( acc );
    };

    auto compute_execute = [&]<int pin_regs=-1>( auto half, bool zero_acc = 0, bool pin_regs_comb=0 ) __attribute__(( always_inline )) {
        if ( !half ) {
            aie::vector<int8,128> a = read_v<128>( pA );
            *pitw = a;
            std::tie( a0, a1 ) = a.template split<64>( );
            std::tie( a0, a1 ) = aie::interleave_unzip( a0, a1, 1 );
            a0 = locate_in_register<6>( a0 );
            a1 = locate_in_register<7>( a1 );
            pitw = add_2d_ptr( pitw, ditw );
            pA = add_byte( pA, 128 );


            if ( chess_manifest( zero_acc == 1 )) {
                acc0 = insert( acc0, im0, mul( a0, false,  false ));
                acc1 = insert( acc1, im1, mul( a1, mmult_rt_params->sign_A, false ));
            } else {
                acc0 = insert( acc0, im0, mac_conf( a0, false,  false, extract_v64acc32( acc0, im0 ), zero_acc ));
                acc1 = insert( acc1, im1, mac_conf( a1, mmult_rt_params->sign_A, false, extract_v64acc32( acc1, im1 ), zero_acc ));
            }
        } else {
            auto a = aie::vector<int8, 128>( *pitr );
            std::tie( a0, a1 ) = a.split<64>( );
            std::tie( a0, a1 ) = aie::interleave_unzip( a0, a1, 1 );
            a0 = locate_in_register<6>( a0 );
            a1 = locate_in_register<7>( a1 );
            pitr = add_2d_ptr( pitr, ditr );
            if ( chess_manifest( zero_acc == 1 )) {
                acc1 = insert( acc1, im1, mac( a0, false,  mmult_rt_params->sign_W, extract_v64acc32( acc1, im1 )));
                acc0 = insert( acc0, im0, mul( a1, mmult_rt_params->sign_A, mmult_rt_params->sign_W ));
            } else {
                acc1 = insert( acc1, im1, mac_conf( a0, false,  mmult_rt_params->sign_W, extract_v64acc32( acc1, im1 ), 0 ));
                acc0 = insert( acc0, im0, mac_conf( a1, mmult_rt_params->sign_A, mmult_rt_params->sign_W, extract_v64acc32( acc0, im0 ), zero_acc ));
            }
        }
    };

    auto compute_incr = [&]( ) __attribute__(( always_inline )) {
        im0++; im1++;
    };

    auto compute_finalize = [&]( auto half, unsigned j ) __attribute__(( always_inline )) {
        if constexpr( has_actv_sum >= 2 ) {
            uint5_t idx = 24 + 4 * half + j;
            if ( chess_manifest( zero_acc == 1 )) {
                acc1 = insert( acc1, idx, mul( aie::broadcast<int8,64>( 1 ), true, half ? mmult_rt_params->sign_W : false ));
            } else {
                acc1 = insert( acc1, idx, mac_conf( aie::broadcast<int8,64>( 1 ), true, half ? mmult_rt_params->sign_W : false, extract_v64acc32( acc1, idx ), zero_acc ));
            }
        }
        im1 = 0;
        if ( half ) {
            zero_acc = 0;
        } else {
            pA = add_3d_byte( pA, dimsA );
        }
    };

    auto sum_fetch = [&]( ) __attribute__(( always_inline )) {
        insert_staging( read_v<128>( pAs ), ib++ );     pAs = add_byte(pAs,128);
        ib = chess_copy( ib );
    };

    auto sum_block = [&]( unsigned j, bool zero_acc = 0 ) __attribute__(( always_inline )) {
        uint5_t idx = 16 + j;
        aie::vector<int8,64> lo = aie::select( int8( 0 ), int8( 1 ), aie::mask<64>::from_uint64( 0x5555555555555555ll ));
        aie::vector<int8,64> hi = aie::select( int8( 1 ), int8( 0 ), aie::mask<64>::from_uint64( 0x5555555555555555ll ));
        if ( chess_manifest( zero_acc == 1 )) {
            acc1 = insert( acc1, idx,     mul( lo, true, false ));
            acc1 = insert( acc1, idx + 4, mul( hi, true, mmult_rt_params->sign_A ));
        } else {
            acc1 = insert( acc1, idx,     mac_conf( lo, true, false,       extract_v64acc32( acc1, idx     ), zero_acc ));
            acc1 = insert( acc1, idx + 4, mac_conf( hi, true, mmult_rt_params->sign_A, extract_v64acc32( acc1, idx + 4 ), zero_acc ));
        }

        zero_acc = 0;
        pAs = add_2d_byte( pAs, dimsAs );
        return zero_acc;
    };

    auto sum_to_qdq = [&]( ) __attribute__(( always_inline )) {
        v16acc64 lo_a0 = ( v16acc64 ) extract_v32acc32( acc1, 16, 0 );
        v16acc64 lo_a1 = ( v16acc64 ) extract_v32acc32( acc1, 16, 1 );
        v16acc64 hi_a0 = ( v16acc64 ) extract_v32acc32( acc1, 20, 0 );
        v16acc64 hi_a1 = ( v16acc64 ) extract_v32acc32( acc1, 20, 1 );
        
        aie::vector<int32,32> lo0, lo1, hi0, hi1;
        lo0.insert( 0, aie::vector( to_v16int32_conf( lo_a0, 0, 1, 0, rnd_floor )));
        lo0.insert( 1, aie::vector( to_v16int32_conf( lo_a1, 0, 1, 0, rnd_floor )));
        lo1.insert( 0, aie::vector( to_v16int32_conf( lo_a0, 32, 1, 0, rnd_floor )));
        lo1.insert( 1, aie::vector( to_v16int32_conf( lo_a1, 32, 1, 0, rnd_floor )));
        hi0.insert( 0, aie::vector( to_v16int32_conf( hi_a0, 0, 1, 0, rnd_floor )));
        hi0.insert( 1, aie::vector( to_v16int32_conf( hi_a1, 0, 1, 0, rnd_floor )));
        hi1.insert( 0, aie::vector( to_v16int32_conf( hi_a0, 32, 1, 0, rnd_floor )));
        hi1.insert( 1, aie::vector( to_v16int32_conf( hi_a1, 32, 1, 0, rnd_floor )));
        
        aie::accum acc = mul_elem_32( lo0 + lo1, 1.0f );
        acc = mac_elem_32( hi0 + hi1, 256.0f, acc );
        aie::store_v( sum_tdm[0].f, acc.to_vector<float>( ));

        if constexpr( has_actv_sum >= 2 ) {
            #pragma unroll
            for ( unsigned p = 0; p < 2; p++ ) {
                using Tc0 = decltype( access<V_c0>( coeff_cache.c0, p ));
                using Tc3 = decltype( access<V_c3>( coeff_cache.c3, p ));
                Tc0 c0v = access<V_c0>( coeff_cache.c0, p );
                Tc3 c3v = access<V_c3>( coeff_cache.c3, p ), c3v2;
                if constexpr( V_c3 == 1 )
                    c3v2 = mul( c3v, 256.0f );
                else
                    c3v2 = aie::mul( c3v, 256.0f ).template to_vector<float>( );
                aie::vector lo = ( v32int32 ) extract_v32acc32( acc1, 24, p );
                aie::vector hi = ( v32int32 ) extract_v32acc32( acc1, 28, p );
                aie::accum<accfloat, V_qdq> acc;
                if constexpr( V_c0 == 1 ) {
                    acc = mac_elem_32( lo, c3v, c0v );
                } else {
                    acc.from_vector( c0v );
                    acc = aie::mac( acc, lo, c3v );
                }
                aie::store_v( sum_tdm[1].f + V_qdq * p, aie::accum( mac_elem_32( hi, c3v2, acc )).template to_vector<float>( ));
            }
        }
        
    };
    
    uint5_t idx_o0;
    uint5_t idx_o1;

    auto write_output = [&]( auto l, unsigned j ) __attribute__(( always_inline )) {
            constexpr unsigned count = 64 / V_qdq;
            using Tc0 = decltype( access<V_c0>( coeff_cache.c0, 0 ));
            using Tc2 = decltype( access<V_c2>( coeff_cache.c2, 0 ));
            Tc0 c0_vec[count];
            Tc2 c2_vec[count], s1c2[count], s2c2[count];
            if constexpr( has_actv_sum < 2 )
                unroll_times<count>( [&]( auto p ) __attribute__(( always_inline )) { c0_vec[p] = access<V_c0>( coeff_cache.c0, p ); });
            unroll_times<count>( [&]( auto p ) __attribute__(( always_inline )) { c2_vec[p] = access<V_c2>( coeff_cache.c2, p ); });
            unroll_times<count>( [&]( auto p ) __attribute__(( always_inline )) { s1c2[p] = access<V_c2>( coeff_cache.c2 + N_c2, p ); });
            unroll_times<count>( [&]( auto p ) __attribute__(( always_inline )) { s2c2[p] = access<V_c2>( coeff_cache.c2 + 2 * N_c2, p ); });
                
            unroll_times<count>( [&]( auto p ) __attribute__(( always_inline ))
            {
                aie::accum<acc32, 64> gemm_out0 = aie::accum<acc32,64>( extract_v64acc32( acc0, idx_o0 ));
                aie::accum<acc32, 64> gemm_out1 = aie::accum<acc32,64>( extract_v64acc32( acc1, idx_o0 ));
                aie::accum<acc32, 64> gemm_out2 = aie::accum<acc32,64>( extract_v64acc32( acc0, idx_o1 ));
                aie::accum<AccumulatorType_t<T_cf>,V_qdq> acc;

                if constexpr( has_actv_sum >= 2 ) {
                    acc.from_vector( access_loc<V_qdq,4>( sum_tdm[1].f, p ));
                } else {
                    if constexpr( V_c0 == 1 ) {
                        acc.from_vector( aie::broadcast<float,V_qdq>( c0_vec[p] ));
                    } else {
                        acc.from_vector( c0_vec[p] );
                    }
                }
                
                acc = aie_mac_extract( acc, gemm_out0, p, c2_vec[p] );
                acc = aie_mac_extract( acc, gemm_out1, p, s1c2[p] );
                acc = aie_mac_extract( acc, gemm_out2, p, s2c2[p] );
                if constexpr( has_actv_sum ) {
                    auto c1v = access<V_c1>( coeff_cache.c1, p );
                    if constexpr( V_c1 == 1 ) {
                        //aie::vector vx = locate_in_register<3>( aie::broadcast<float,32>( sum_tdm->f[l+granM*j] ));
                        aie::vector vx = aie::broadcast<float,32>( sum_tdm->f[l+granM*j] );
                        acc = aie::mac( acc, c1v, vx );
                    } else {
                        float vx = sum_tdm->f[l+granM*j];
                        acc = aie::mac( acc, c1v, vx );
                    }
                }
                
                /* Convert output to int16 */
                aie::vector<int16, V_qdq> out;
                if constexpr( std::is_same_v<T_cf, int32> )
                    out = aie::vector( to_v16int16( acc, -mmult_rt_params->shift_res, mmult_rt_params->sign_O ));
                else
                    out = acc.template to_vector_sign<int16>( mmult_rt_params->sign_O, -mmult_rt_params->shift_res );
                write_v(pO, out); pO = add_byte( pO, 2*V_qdq );
            });
            idx_o0++; idx_o1++;
    };

    #ifdef FILE_IO
    constexpr bool is_fileio = 0;
    #else
    constexpr bool is_fileio = 0;
    #endif
    
    /* Loop structure */
    /* DEBUG_KERNEL - TDM in registers */
    if constexpr( il_lr <= 1 || is_fileio ) {
        for ( int it=0; it < params.outer_time_iters; it++ )
        {
            zero_acc = 1;

            /* TDM loop. Time iterations in accumulation depth */
            for ( unsigned t = 0; t < tl_bound; t++ ) 
            {
                /* Buffers acquired in each time iteration */
                weight_acquire( );
                ifm_acquire( );

                bool do_actv_sum = has_actv_sum && ( keep_sum_cnt == 0 );

                /* Outer loop */
                for ( int j = 0; j < ol_bound; j++ )
                {
                    if ( do_actv_sum ) {
                        auto z = zero_acc;
                        for ( int i = 0; i < il_bound; i++ )
                        {
                            for ( unsigned l = 0; l < granM; l++ ) {
                                sum_fetch( );
                            }
                            staging_to_matrix_m64x64int8( );
                            ib = chess_copy( 0 );
                            z = sum_block( j, z );
                        }
                    }                
                    
                    /* Inner loop */
                    for ( int i = 0; i < il_bound; i++ )
                    {
                        for ( unsigned l = 0; l < 2 * granM; l++ ) {
                            weight_fetch( );
                        }
                        
                        compute_prepare( 0 );

                        /* First loop */   
                        for ( unsigned l = 0; l < granM; l++ ) {
                            compute_execute( 0, zero_acc );
                            compute_incr( );
                            weight_fetch( );
                            weight_fetch( );
                        }
                        compute_finalize( 0, j );
                        compute_prepare( 1 );

                        /* Second loop */
                        for ( unsigned l = 0; l < granM; l++ ) {
                            compute_execute( 1, zero_acc );
                            compute_incr( );
                        }
                        
                        compute_finalize( 1, j );
                    }
                }

                ifm_release( );
                if ( t < tl_bound - 1 )
                    weight_release( );
            }
            coeff_fetch( );
            weight_release( );

            ofm_acquire( );

            sum_to_qdq( );
            keep_sum_cnt = keep_sum_cnt == 0 ? keep_sum_iters - 1 : 0;
            idx_o0 = 0;
            idx_o1 = 16;
            
            for ( unsigned l = 0; l < granM; l++ ) {
                write_output( l, 0 );
            }

            ofm_release( );
        }
    } else if constexpr( has_actv_sum ) {
        /* ASYM QDQ pipelined version TDM in registers */

        /* Iteration loop ( Outer TD dimensions ) */
        for ( unsigned it = 0; it < params.outer_time_iters; it++ )
        {
            zero_acc = 1;
            zero_sum = 1;

            /* TDM loop ( Inner TD dimension ) */
            for ( unsigned t = 0; t < tl_bound; t++ )
            {                
                weight_acquire( );
                ifm_acquire( );

                /* Peeled inner loop */
                /* First Inner loop iteration */
                /* loop_range=8, peel_front=6-peel_back, peel_back=2 */
                pipelined_loop<8, 6, 2>( 2 * granM, [&]( auto l ) __attribute__((always_inline))
                {
                    weight_fetch( );
                });
                
                compute_prepare( 0 );

                compute_execute( 0, zero_acc );
                compute_incr( );
                pipelined_loop<4, 3>( granM-1, [&]( auto l ) __attribute__((always_inline))
                {
                    compute_execute( 0, zero_acc );
                    compute_incr( );
                    weight_fetch( );
                    weight_fetch( );
                });
                weight_fetch( );
                weight_fetch( );
                
                compute_finalize( 0, 0 );
                compute_prepare( 1 );
                
                compute_execute( 1, zero_acc );
                compute_incr( );
                pipelined_loop<4, 2>( granM-1, [&]( auto l ) __attribute__((always_inline))
                {
                    compute_execute( 1, zero_acc );
                    compute_incr( );
                    sum_fetch( );
                });
                sum_fetch( );
                compute_finalize( 1, 0 );
                staging_to_matrix_m64x64int8( );
                ib = chess_copy( 0 );

                /* Second inner loop iteration */
                pipelined_loop<8, 7, 2>( 2 * granM, [&]( auto l ) __attribute__((always_inline))
                {
                    weight_fetch( );
                });
                /* Do SUM before staging the next set of weights */
                zero_sum = sum_block( 0, zero_sum );
                
                compute_prepare( 0 );

                compute_execute( 0, zero_acc );
                compute_incr( );
                pipelined_loop<4, 1>( granM-1, [&]( auto l ) __attribute__((always_inline))
                {
                    compute_execute( 0, zero_acc );
                    compute_incr( );
                    weight_fetch( );
                    weight_fetch( );
                });
                weight_fetch( );
                weight_fetch( );

                compute_finalize( 0, 0 );
                compute_prepare( 1 );

                compute_execute( 1, zero_acc );
                compute_incr( );
                for ( unsigned l = 0; l < granM-1; l++ )
                    chess_pipeline_adjust_preamble( -1 )
                {
                    compute_execute( 1, zero_acc );
                    compute_incr( );
                    sum_fetch( );
                }
                sum_fetch( );
                compute_finalize( 1, 0 );
                staging_to_matrix_m64x64int8( );
                ib = chess_copy( 0 );
               
                sum_block( 0, zero_sum );

                ifm_release( );
                if ( t < tl_bound - 1 )
                    weight_release( );
            }
            coeff_fetch( );
            weight_release( );
            
            ofm_acquire( );
            sum_to_qdq( );
            idx_o0 = 0;
            idx_o1 = 16;

            /* Write output */
            for ( unsigned l = 0; l < granM; l++ )
            chess_allocate(Y : 8)
            chess_allocate(R : 31)
            {
                write_output( l, 0 );
            }

            ofm_release( );

        }
    } else {
        /* SYM QDQ pipelined version TDM in registers */

        /* Iteration loop ( Outer TD dimensions ) */
        for ( unsigned it = 0; it < params.outer_time_iters; it++ )
        {
            zero_acc = 1;

            /* TDM loop ( Inner TD dimension ) */
            for ( unsigned t = 0; t < tl_bound; t++ )
            {

                weight_acquire( );
                ifm_acquire( );

                /* First inner loop iteration */
                for( unsigned l = 0; l < 2 * granM-2; l++ )
                    chess_peel_pipelined_loop( 4 )
                {
                    weight_fetch( );
                }
                weight_fetch( );
                weight_fetch( );

                compute_prepare( 0 );

                compute_execute( 0, zero_acc );
                compute_incr( );
                for( unsigned l = 0; l < granM-1; l++ )
                    chess_peel_pipelined_loop( 1 )
                {
                    compute_execute( 0, zero_acc );
                    compute_incr( );
                    weight_fetch( );
                    weight_fetch( );
                }
                weight_fetch( );
                weight_fetch( );

                compute_finalize( 0, 0 );
                compute_prepare( 1 );

                compute_execute( 1, zero_acc );
                compute_incr( );
                for( unsigned l = 0; l < granM-1; l++ )
                {
                    compute_execute( 1, zero_acc );
                    compute_incr( );
                    weight_fetch( );
                    weight_fetch( );
                }
                weight_fetch( );
                weight_fetch( );
                compute_finalize( 1, 0 );
                
                /* Second inner loop iteration */
                compute_prepare( 0 );

                compute_execute( 0, zero_acc );
                compute_incr( );
                for( unsigned l = 0; l < granM-1; l++ )
                {
                    compute_execute( 0, zero_acc );
                    compute_incr( );
                    weight_fetch( );
                    weight_fetch( );
                }
                weight_fetch( );
                weight_fetch( );

                compute_finalize( 0, 0 );
                compute_prepare( 1 );

                compute_execute( 1, zero_acc );
                compute_incr( );
                for( unsigned l = 0; l < granM-2; l++ )
                {
                    compute_execute( 1, zero_acc );
                    compute_incr( );
                }
                compute_execute( 1, zero_acc );
                compute_incr( );

                compute_finalize( 1, 0 );

                ifm_release( );
                if ( t < tl_bound - 1 )
                    weight_release( );
            }
            coeff_fetch( );
            weight_release( );

            ofm_acquire( );
            idx_o0 = 0;
            idx_o1 = 16;
            
            /* Write Output */
            for ( unsigned l = 0; l < granM; l++ )
            {
                write_output( l, 0 );
            }

            ofm_release( );

        }
    }

    event1( );
}


#endif // __ACTIVATED_MMULT_QDQ_INT16X16_TRANSPOSE_TEMPLATE_H__