/*
    Copyright (C) 2019 - 2022 Xilinx, Inc. All rights reserved.
    Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.

    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.

    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.

    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.

    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */

#ifndef __ARCH_OX8_KERNEL_HELPERS_H__
#define __ARCH_OX8_KERNEL_HELPERS_H__

#include <aie_api/aie.hpp>
#include <aie_api/utils.hpp>
#include "include/kernel_helpers.h"
#include "qdq/qdq_kernel_helpers.h"
#include "common/ml_params.h"

// Force-inline qualifier. With clang it expands to `inline` plus the
// always_inline attribute; with any other compiler it is plain `inline`.
// Skipped entirely when ALWAYS_INLINE is already defined.
#ifndef ALWAYS_INLINE
#ifdef  __clang__
#define ALWAYS_INLINE inline __attribute__((always_inline))
#else
#define ALWAYS_INLINE inline
#endif
#endif

#ifndef __KERNEL_HELPERS_H__

// Group of four v32acc32 accumulators (a0..a3) handled as one unit.
// The struct carries the keep_in_registers property.
struct Accs4_i32 property(keep_in_registers) {
    v32acc32 a0;
    v32acc32 a1;
    v32acc32 a2;
    v32acc32 a3;
};

// Assigns the result of undef_v32acc32() to each of the four accumulators
// in `accs`, i.e. marks all of them as holding no defined value.
inline void undef(Accs4_i32 &accs) {
    accs.a0 = undef_v32acc32();
    accs.a1 = undef_v32acc32();
    accs.a2 = undef_v32acc32();
    accs.a3 = undef_v32acc32();
}
#endif // __KERNEL_HELPERS_H__

#ifndef __ML_PARAMS_H__
// Kernel configuration selector. Enumerators take the implicit values
// 0 (KC_ZERO) through 12 (KC_RESULT32_CASC). Only defined here when
// __ML_PARAMS_H__ has not been defined.
// NOTE(review): the names suggest zero-init / result-width / cascade / TDM
// variants, but this file does not use the enumerators - confirm the exact
// meaning against the kernels that consume them.
enum KernelConfig {
    KC_ZERO,
    KC_RESULT4,
    KC_RESULT8,
    KC_RESULT16,
    KC_RESULT32,
    KC_CASC,
    KC_TDM16,
    KC_TDM32,
    KC_TDM64,
    KC_TDM16_CASC,
    KC_TDM32_CASC,
    KC_TDM64_CASC,
    KC_RESULT32_CASC,
};
#endif // __ML_PARAMS_H__

// Pointer no-alias qualifier used by the store helpers in this file:
// `__restrict` under clang, empty for any other compiler.
#ifdef __clang__
#define RSTRCT __restrict
#else
#define RSTRCT
#endif

// Fills the four v8acc32 sub-parts (indices 0..3) of `acc` from four
// consecutive elements read through `ptr` (DM_bankC storage), reinterpreting
// the v8int32 pointer as a v8acc32 pointer.
// On return `ptr` has advanced by three elements: it is left pointing at the
// last element read, not one past it.
inline void load_full(v32acc32 &acc, v8int32 chess_storage(DM_bankC) *& ptr) {
    v8acc32 chess_storage(DM_bankC) * pIn = (v8acc32 chess_storage(DM_bankC) *) ptr;
    //acc = undef_v32acc32();
    acc = insert(acc, 0, *pIn++);
    acc = insert(acc, 1, *pIn++);
    acc = insert(acc, 2, *pIn++);
    // Final read deliberately has no post-increment.
    acc = insert(acc, 3, *pIn);
    ptr = (v8int32 chess_storage(DM_bankC) *) pIn;
}
// Fills the four v8acc32 sub-parts (indices 0..3) of `acc` from four
// consecutive elements read through `ptr` (DM_bankD storage), reinterpreting
// the v8int32 pointer as a v8acc32 pointer.
// On return `ptr` has advanced by three elements: it is left pointing at the
// last element read, not one past it.
inline void load_full(v32acc32 &acc, v8int32 chess_storage(DM_bankD) *& ptr) {
    v8acc32 chess_storage(DM_bankD) * pIn = (v8acc32 chess_storage(DM_bankD) *) ptr;
    //acc = undef_v32acc32();
    acc = insert(acc, 0, *pIn++);
    acc = insert(acc, 1, *pIn++);
    acc = insert(acc, 2, *pIn++);
    // Final read deliberately has no post-increment.
    acc = insert(acc, 3, *pIn);
    ptr = (v8int32 chess_storage(DM_bankD) *) pIn;
}

// Writes the four v8acc32 sub-parts (indices 0..3) of `acc` to four
// consecutive elements through `ptr` (DM_bankC storage), reinterpreting the
// v8int32 pointer as a v8acc32 pointer.
// On return `ptr` has advanced by three elements: it is left pointing at the
// last element written, not one past it.
inline void store_full(v8int32 chess_storage(DM_bankC) * RSTRCT & ptr, v32acc32 acc) {
    v8acc32 chess_storage(DM_bankC) * pOut = (v8acc32 chess_storage(DM_bankC) *) ptr;
    *pOut++ = extract_v8acc32(acc, 0);
    *pOut++ = extract_v8acc32(acc, 1);
    *pOut++ = extract_v8acc32(acc, 2);
    // Final write deliberately has no post-increment.
    *pOut   = extract_v8acc32(acc, 3);
    ptr = (v8int32 chess_storage(DM_bankC) *) pOut;
}
// Writes the four v8acc32 sub-parts (indices 0..3) of `acc` to four
// consecutive elements through `ptr` (DM_bankD storage), reinterpreting the
// v8int32 pointer as a v8acc32 pointer.
// On return `ptr` has advanced by three elements: it is left pointing at the
// last element written, not one past it.
inline void store_full(v8int32 chess_storage(DM_bankD) * RSTRCT & ptr, v32acc32 acc) {
    v8acc32 chess_storage(DM_bankD) * pOut = (v8acc32 chess_storage(DM_bankD) *) ptr;
    *pOut++ = extract_v8acc32(acc, 0);
    *pOut++ = extract_v8acc32(acc, 1);
    *pOut++ = extract_v8acc32(acc, 2);
    // Final write deliberately has no post-increment.
    *pOut   = extract_v8acc32(acc, 3);
    ptr = (v8int32 chess_storage(DM_bankD) *) pOut;
}



// aie::accum overload of load_full for DM_bankC pointers.
// Reads four consecutive v8int32 elements through `ptr`, converts each to an
// 8-lane acc32 accumulator via an aie::vector<int32,8>, and inserts them at
// sub-part indices 0..3 of `acc`.
// `ptr` itself is post-incremented three times; the fourth read does not
// advance it, so it ends up on the last element read.
inline void load_full(aie::accum<acc32,32> &acc, v8int32 chess_storage(DM_bankC) *& ptr) {
    using acc_t = aie::accum<acc32,8>;
    aie::vector<int32,8> v;
    acc.insert( 0, acc_t( v = *ptr++ ));
    acc.insert( 1, acc_t( v = *ptr++ ));
    acc.insert( 2, acc_t( v = *ptr++ ));
    acc.insert( 3, acc_t( v = *ptr ));
}
// aie::accum overload of load_full for DM_bankD pointers.
// Reads four consecutive v8int32 elements through `ptr`, converts each to an
// 8-lane acc32 accumulator via an aie::vector<int32,8>, and inserts them at
// sub-part indices 0..3 of `acc`.
// `ptr` itself is post-incremented three times; the fourth read does not
// advance it, so it ends up on the last element read.
inline void load_full(aie::accum<acc32,32> &acc, v8int32 chess_storage(DM_bankD) *& ptr) {
    using acc_t = aie::accum<acc32,8>;
    aie::vector<int32,8> v;
    acc.insert( 0, acc_t( v = *ptr++ ));
    acc.insert( 1, acc_t( v = *ptr++ ));
    acc.insert( 2, acc_t( v = *ptr++ ));
    acc.insert( 3, acc_t( v = *ptr ));
}

// Loads two consecutive v16int16 vectors through `ptr` (DM_bankC storage),
// passes each through sups(value, shift, sign) and inserts the results at
// sub-part indices 0 and 1 of `acc`.
// `ptr` is post-incremented once, so it ends up on the second vector read.
// NOTE(review): sups is presumably the shift-up int16 -> acc32 conversion with
// `sign` selecting signed interpretation - confirm against the intrinsics guide.
// NOTE(review): `pIn` is assigned twice but never read; all loads go through
// `ptr`. Possibly a leftover - confirm before removing, since chess_copy may
// have been placed here on purpose.
inline void load_half(aie::accum<acc32,32> &acc, v16int16 chess_storage(DM_bankC) *& ptr, uint6_t shift, bool sign) {
    //acc = undef_v32acc32();
    ptr = chess_copy( ptr );
    v16int16 chess_storage(DM_bankC)  * pIn = ( v16int16 chess_storage(DM_bankC)* ) ptr;
    acc = insert(acc, 0, sups(*ptr++, shift, sign));
    pIn = chess_copy(ptr);
    acc = insert(acc, 1, sups(*ptr,   shift, sign));
}
// Loads two consecutive v16int16 vectors through `ptr` (DM_bankD storage),
// passes each through sups(value, shift, sign) and inserts the results at
// sub-part indices 0 and 1 of `acc`.
// `ptr` is post-incremented once, so it ends up on the second vector read.
// NOTE(review): sups is presumably the shift-up int16 -> acc32 conversion with
// `sign` selecting signed interpretation - confirm against the intrinsics guide.
// NOTE(review): `pIn` is assigned twice but never read; all loads go through
// `ptr`. Possibly a leftover - confirm before removing, since chess_copy may
// have been placed here on purpose.
inline void load_half(aie::accum<acc32,32> &acc, v16int16 chess_storage(DM_bankD) *& ptr, uint6_t shift, bool sign) {
    //acc = undef_v32acc32();
    ptr = chess_copy( ptr );
    v16int16 chess_storage(DM_bankD)  * pIn = ( v16int16 chess_storage(DM_bankD)* ) ptr;
    acc = insert(acc, 0, sups(*ptr++, shift, sign));
    pIn = chess_copy(ptr);
    acc = insert(acc, 1, sups(*ptr,   shift, sign));
}

// Used in strix to work around 512bit load alignment restrictions.
// Builds a 64-lane int8 vector from two separate 32-element loads, one at pW
// and one at pW + 32, and concatenates them (first load in the low half).
// NOTE(review): the offset is 32 elements of Tw, so this presumably expects
// Tw to be a 1-byte type - confirm against callers.
template <typename Tw>
[[gnu::always_inline]]
inline aie::vector<int8, 64> load64(Tw *pW) {
    return aie::concat( aie::load_v<32>(pW),
                        aie::load_v<32>(pW + 32) );
}

// Stores a 32-lane acc32 accumulator as two v16int16 vectors (DM_bankC).
// The two 16-lane parts of `acc` (indices 0 and 1) are each converted with
// lsrs(part, shift, sign) and written to consecutive elements through `ptr`.
// `ptr` is post-incremented once, so it ends up on the second vector written.
// NOTE(review): lsrs is presumably the shift-round-saturate down-conversion,
// with `sign` selecting signed output - confirm against the intrinsics guide.
inline void store_half(v16int16 chess_storage(DM_bankC) * RSTRCT & ptr, aie::accum<acc32,32> acc, uint6_t shift, bool sign) {
    *ptr++ = lsrs(extract_v16acc32(acc, 0), shift, sign);
    // Second store leaves `ptr` untouched.
    *ptr   = lsrs(extract_v16acc32(acc, 1), shift, sign);
}
// Stores a 32-lane acc32 accumulator as two v16int16 vectors (DM_bankD).
// The two 16-lane parts of `acc` (indices 0 and 1) are each converted with
// lsrs(part, shift, sign) and written to consecutive elements through `ptr`.
// `ptr` is post-incremented once, so it ends up on the second vector written.
// NOTE(review): lsrs is presumably the shift-round-saturate down-conversion,
// with `sign` selecting signed output - confirm against the intrinsics guide.
inline void store_half(v16int16 chess_storage(DM_bankD) * RSTRCT & ptr, aie::accum<acc32,32> acc, uint6_t shift, bool sign) {
    *ptr++ = lsrs(extract_v16acc32(acc, 0), shift, sign);
    // Second store leaves `ptr` untouched.
    *ptr   = lsrs(extract_v16acc32(acc, 1), shift, sign);
}

// aie::accum overload of store_full for DM_bankC pointers.
// Writes the four 8-lane parts of `acc` (indices 0..3) to four consecutive
// elements through `ptr`, reinterpreting the v8int32 pointer as v8acc32.
// On return `ptr` has advanced by three elements: it is left pointing at the
// last element written, not one past it.
inline void store_full(v8int32 chess_storage(DM_bankC) * RSTRCT & ptr, aie::accum<acc32,32> acc) {
    v8acc32 chess_storage(DM_bankC) * pOut = (v8acc32 chess_storage(DM_bankC) *) ptr;
    *pOut++ = acc.extract<8>( 0 );
    *pOut++ = acc.extract<8>( 1 );
    *pOut++ = acc.extract<8>( 2 );
    // Final write deliberately has no post-increment.
    *pOut   = acc.extract<8>( 3 );
    ptr = (v8int32 chess_storage(DM_bankC) *) pOut;
}
// aie::accum overload of store_full for DM_bankD pointers.
// Writes the four 8-lane parts of `acc` (indices 0..3) to four consecutive
// elements through `ptr`, reinterpreting the v8int32 pointer as v8acc32.
// On return `ptr` has advanced by three elements: it is left pointing at the
// last element written, not one past it.
inline void store_full(v8int32 chess_storage(DM_bankD) * RSTRCT & ptr, aie::accum<acc32,32> acc) {
    v8acc32 chess_storage(DM_bankD) * pOut = (v8acc32 chess_storage(DM_bankD) *) ptr;
    *pOut++ = acc.extract<8>( 0 );
    *pOut++ = acc.extract<8>( 1 );
    *pOut++ = acc.extract<8>( 2 );
    // Final write deliberately has no post-increment.
    *pOut   = acc.extract<8>( 3 );
    ptr = (v8int32 chess_storage(DM_bankD) *) pOut;
}

#if __AIE_ARCH__ >= 21
// 4x8 by 8x8 int8 matrix multiply-accumulate, AIE API variant (arch >= 21).
//   x        - A operand; only its first 32-element part (index 0) is used.
//   y        - B operand, all 64 elements.
//   acc      - incoming accumulator.
//   zero_acc - forwarded to aie::op_zero(acc, zero_acc) to build the starting
//              accumulator (presumably non-zero means "start from zero").
//   casc     - unused in this variant.
// Returns the mmul object converted to aie::accum<acc32, 32>.
inline aie::accum<acc32, 32> mac_4x8_8x8_C(aie::vector<int8, 64> x,
                                           aie::vector<int8, 64> y,
                                           aie::accum<acc32, 32> acc,
                                           int zero_acc, int casc) {
    aie::mmul<4, 8, 8, int8, int8, acc32> mm{aie::op_zero(acc, zero_acc)};
    mm.mac(x.extract<32>(0), y);
    return mm;
}
#else
// 4x8 by 8x8 int8 matrix multiply-accumulate, intrinsic variant (arch < 21).
// Thin wrapper over mac_4x8_8x8_conf: x, y, acc and zero_acc are forwarded
// and the remaining four configuration arguments are fixed to 0.
// `casc` is accepted for signature parity with the other variant but unused.
inline v32acc32 mac_4x8_8x8_C(v64int8 x, v64int8 y, v32acc32 acc, int zero_acc, int casc) {
    return mac_4x8_8x8_conf(x, y, acc, zero_acc, 0, 0, 0);
}
#endif


#if __AIE_ARCH__ == 20
// Arch 20 only: every later stand-alone token `DM` is replaced by `CM`.
// NOTE(review): the reason for the alias is not visible in this file - confirm
// which code relies on it.
#define DM CM

// Software load-FIFO state for arch 20.
//   state - 64-byte vector carried from one pop to the next.
//   pos   - value written by fifo_ld_popx* (its `frac`) and read as the shift
//           amount by fifo_ld_pop_3d_byte.
struct fifo_state_t {
    v64int8 state;
    int pos;
};

// Arch 20 software version of fifo_ld_popx for a plain pointer reference.
// Reads two 32-byte vectors around a possibly unaligned `ptr`, merges them
// with the carried fifo.state through shiftx, and returns the result.
//   ptr  - read position; advanced on return (see last statement).
//   fifo - state and pos are both overwritten.
//   step - forwarded to shiftx as step + 2.
//   mask - unused in this implementation.
// NOTE(review): floor() presumably rounds the address down to a 32-byte
// boundary, and byte_incr presumably adds a byte offset; if so `ptr` moves
// forward by 64 bytes in total - confirm against the intrinsics guide.
template<typename T>
inline aie_dm_resource_remove_t<T> fifo_ld_popx( T *& ptr, fifo_state_t &fifo, int step, int mask ) {
    using Tm = aie_dm_resource_set_t<v32int8,aie_dm_resource_get_v<T>>;
    // Bias by 31 bytes before the aligned reads.
    char * p = ( char * )ptr + 31;
    // Low 5 address bits of the biased pointer, offset by 33.
    int frac = (( long )p & 31 ) + 33;
    v64int8 val = set_v64int8( 0, *floor(( Tm * ) p ));     p += 32;
    val = insert( val, 1, *floor(( Tm * ) p ));
    val = shiftx( fifo.state, val, step + 2, frac );
    fifo.state = val;
    fifo.pos = frac;
    // Undo the 31-byte bias after advancing.
    ptr = ( T * )( byte_incr( p, 32 ) - 31 );
    return val;
}
// Arch 20 software version of fifo_ld_popx for a restrict-qualified pointer.
// Reads two 32-byte vectors around a possibly unaligned `ptr`, merges them
// with the carried fifo.state through shiftx, and returns the result.
//   ptr  - read position; advanced on return (see last statement).
//   fifo - state and pos are both overwritten.
//   step - forwarded to shiftx as step + 2.
//   mask - unused in this implementation.
// NOTE(review): floor() presumably rounds the address down to a 32-byte
// boundary, and byte_incr presumably adds a byte offset; if so `ptr` moves
// forward by 64 bytes in total - confirm against the intrinsics guide.
template<typename T>
inline aie_dm_resource_remove_t<T> fifo_ld_popx( T * restrict & ptr, fifo_state_t &fifo, int step, int mask ) {
    using Tm = aie_dm_resource_set_t<v32int8,aie_dm_resource_get_v<T>>;
    // Bias by 31 bytes before the aligned reads.
    char * p = ( char * )ptr + 31;
    // Low 5 address bits of the biased pointer, offset by 33.
    int frac = (( long )p & 31 ) + 33;
    v64int8 val = set_v64int8( 0, *floor(( Tm * ) p ));     p += 32;
    val = insert( val, 1, *floor(( Tm * ) p ));
    val = shiftx( fifo.state, val, step + 2, frac );
    fifo.state = val;
    fifo.pos = frac;
    // Undo the 31-byte bias after advancing.
    ptr = ( T * )( byte_incr( p, 32 ) - 31 );
    return val;
}
// Arch 20 software fifo_ld_popx with a 3D pointer update, plain pointer
// reference. The load/merge part matches fifo_ld_popx: two 32-byte reads
// merged with fifo.state via shiftx, then fifo.state and fifo.pos are updated.
// The pointer is advanced through add_3d_byte instead of a fixed step; each
// of inc0/inc1/inc2 is passed with 32 added, and the 31-byte bias is removed
// afterwards.
//   step              - forwarded to shiftx as step + 2.
//   mask              - unused in this implementation.
//   num0/cnt0, num1/cnt1 - forwarded unchanged to add_3d_byte (cnt0/cnt1 by
//                       reference).
// NOTE(review): presumably the +32 on each increment accounts for the second
// 32-byte read - confirm.
template<typename T>
inline aie_dm_resource_remove_t<T> fifo_ld_popx_3d_byte( T *& ptr, fifo_state_t &fifo, int step, int mask, int inc2, int num0, addr_t &cnt0, int inc0, int num1, addr_t &cnt1, int inc1 ) {
    using Tm = aie_dm_resource_set_t<v32int8,aie_dm_resource_get_v<T>>;
    char * p = ( char * )ptr + 31;
    int frac = (( long )p & 31 ) + 33;
    v64int8 val = set_v64int8( 0, *floor(( Tm * ) p ));     p += 32;
    val = insert( val, 1, *floor(( Tm * ) p ));
    val = shiftx( fifo.state, val, step + 2, frac );
    fifo.state = val;
    fifo.pos = frac;
    ptr = ( T * )( add_3d_byte( p, inc2+32, num0, cnt0, inc0+32, num1, cnt1, inc1+32 ) - 31 );
    return val;
}
// Arch 20 software fifo_ld_popx with a 3D pointer update, restrict-qualified
// pointer. The load/merge part matches fifo_ld_popx: two 32-byte reads
// merged with fifo.state via shiftx, then fifo.state and fifo.pos are updated.
// The pointer is advanced through add_3d_byte instead of a fixed step; each
// of inc0/inc1/inc2 is passed with 32 added, and the 31-byte bias is removed
// afterwards.
//   step              - forwarded to shiftx as step + 2.
//   mask              - unused in this implementation.
//   num0/cnt0, num1/cnt1 - forwarded unchanged to add_3d_byte (cnt0/cnt1 by
//                       reference).
// NOTE(review): presumably the +32 on each increment accounts for the second
// 32-byte read - confirm.
template<typename T>
inline aie_dm_resource_remove_t<T> fifo_ld_popx_3d_byte( T * restrict & ptr, fifo_state_t &fifo, int step, int mask, int inc2, int num0, addr_t &cnt0, int inc0, int num1, addr_t &cnt1, int inc1 ) {
    using Tm = aie_dm_resource_set_t<v32int8,aie_dm_resource_get_v<T>>;
    char * p = ( char * )ptr + 31;
    int frac = (( long )p & 31 ) + 33;
    v64int8 val = set_v64int8( 0, *floor(( Tm * ) p ));     p += 32;
    val = insert( val, 1, *floor(( Tm * ) p ));
    val = shiftx( fifo.state, val, step + 2, frac );
    fifo.state = val;
    fifo.pos = frac;
    ptr = ( T * )( add_3d_byte( p, inc2+32, num0, cnt0, inc0+32, num1, cnt1, inc1+32 ) - 31 );
    return val;
}
// Arch 20 software FIFO pop with a 3D pointer update, plain pointer reference.
// Reads two 32-byte vectors around `ptr`, combines them with fifo.state using
// shift_bytes and the previously stored fifo.pos, and returns the result.
// Unlike the popx variants, `fifo` is only read here: neither state nor pos is
// written back.
// The pointer is advanced through add_3d_byte with 32 added to each of
// inc0/inc1/inc2, then the 31-byte bias is removed.
template<typename T>
inline aie_dm_resource_remove_t<T> fifo_ld_pop_3d_byte( T *& ptr, fifo_state_t &fifo, int inc2, int num0, addr_t &cnt0, int inc0, int num1, addr_t &cnt1, int inc1 ) {
    using Tm = aie_dm_resource_set_t<v32int8,aie_dm_resource_get_v<T>>;
    char * p = ( char * )ptr + 31;
    v64int8 val = set_v64int8( 0, *floor(( Tm * ) p ));     p += 32;
    val = insert( val, 1, *floor(( Tm * ) p ));
    val = shift_bytes( fifo.state, val, fifo.pos );
    ptr = ( T * )( add_3d_byte( p, inc2+32, num0, cnt0, inc0+32, num1, cnt1, inc1+32 ) - 31 );
    return val;
}
// Arch 20 software FIFO pop with a 3D pointer update, restrict-qualified
// pointer. Reads two 32-byte vectors around `ptr`, combines them with
// fifo.state using shift_bytes and the previously stored fifo.pos, and returns
// the result.
// Unlike the popx variants, `fifo` is only read here: neither state nor pos is
// written back.
// The pointer is advanced through add_3d_byte with 32 added to each of
// inc0/inc1/inc2, then the 31-byte bias is removed.
template<typename T>
inline aie_dm_resource_remove_t<T> fifo_ld_pop_3d_byte( T * restrict & ptr, fifo_state_t &fifo, int inc2, int num0, addr_t &cnt0, int inc0, int num1, addr_t &cnt1, int inc1 ) {
    using Tm = aie_dm_resource_set_t<v32int8,aie_dm_resource_get_v<T>>;
    char * p = ( char * )ptr + 31;
    v64int8 val = set_v64int8( 0, *floor(( Tm * ) p ));     p += 32;
    val = insert( val, 1, *floor(( Tm * ) p ));
    val = shift_bytes( fifo.state, val, fifo.pos );
    ptr = ( T * )( add_3d_byte( p, inc2+32, num0, cnt0, inc0+32, num1, cnt1, inc1+32 ) - 31 );
    return val;
}

#else

// Non-arch-20 fifo_ld_popx_3d_byte for a plain pointer reference.
// Performs fifo_ld_popx( ptr, fifo, step, mask ) and then applies a separate
// add_3d_byte update to `ptr` with the increments passed through unchanged.
// NOTE(review): fifo_ld_popx and fifo_state_t are not defined in this branch
// of the file - presumably they come from the toolchain on these
// architectures; confirm.
template<typename T>
inline aie_dm_resource_remove_t<T> fifo_ld_popx_3d_byte( T *& ptr, fifo_state_t &fifo, int step, int mask, int inc2, int num0, addr_t &cnt0, int inc0, int num1, addr_t &cnt1, int inc1 ) {
    v64int8 val = fifo_ld_popx( ptr, fifo, step, mask );
    ptr = add_3d_byte( ptr, inc2, num0, cnt0, inc0, num1, cnt1, inc1 );
    return val;
}
// Non-arch-20 fifo_ld_popx_3d_byte for a restrict-qualified pointer.
// Performs fifo_ld_popx( ptr, fifo, step, mask ) and then applies a separate
// add_3d_byte update to `ptr` with the increments passed through unchanged.
// NOTE(review): fifo_ld_popx and fifo_state_t are not defined in this branch
// of the file - presumably they come from the toolchain on these
// architectures; confirm.
template<typename T>
inline aie_dm_resource_remove_t<T> fifo_ld_popx_3d_byte( T * restrict & ptr, fifo_state_t &fifo, int step, int mask, int inc2, int num0, addr_t &cnt0, int inc0, int num1, addr_t &cnt1, int inc1 ) {
    v64int8 val = fifo_ld_popx( ptr, fifo, step, mask );
    ptr = add_3d_byte( ptr, inc2, num0, cnt0, inc0, num1, cnt1, inc1 );
    return val;
}
#endif




// Rounds `val` up to the next multiple of `significance`.
// Implemented with bit masking, so the result is only a true multiple when
// `significance` is a power of two.
inline int ceil( int val, int significance ) {
    const int low_bits = significance - 1;
    const int bumped   = val + low_bits;
    return bumped & ~low_bits;
}

// Integer division of `val` by `div`, rounding the quotient up
// (for non-negative `val` and positive `div`).
inline int div_ceil( int val, int div ) {
    const int bias   = div - 1;
    const int biased = val + bias;
    return biased / div;
}

// Shift-based counterpart of div_ceil: adds div - 1 to `val` and shifts the
// sum right by 31 - clb( div ).
// NOTE(review): presumably `div` must be a power of two so that the shift
// amount equals log2(div) - confirm clb semantics in the intrinsics guide.
inline int div_ceil_p2( int val, int div ) {
    const int biased = val + ( div - 1 );
    const int shift  = 31 - clb( div );
    return biased >> shift;
}


#ifndef __KERNEL_HELPERS_H__
// Folds one more dimension into a running address modifier.
//   reset - in/out running value; decreased by num * incr.
//   num   - multiplier applied to `incr` when updating `reset`.
//   incr  - increment for this dimension.
// Returns the value of `reset` before the update, plus `incr`.
inline int add_dimension( int &reset, int num, int incr ) {
    const int stepped = reset + incr;
    reset = reset - num * incr;
    return stepped;
}
#endif // __KERNEL_HELPERS_H__


// Routes accs.a0..a3 through temporaries declared with chess_storage cm0..cm3
// and copies each back via chess_copy. Expands to a single brace block.
// NOTE(review): presumably this pins the four accumulators to registers
// cm0..cm3 - confirm against the Chess compiler manual.
#define FIX_CM_0(accs) {                                                        \
    v32acc32 chess_storage(cm0) tmp0 = accs.a0; accs.a0 = chess_copy( tmp0 );   \
    v32acc32 chess_storage(cm1) tmp1 = accs.a1; accs.a1 = chess_copy( tmp1 );   \
    v32acc32 chess_storage(cm2) tmp2 = accs.a2; accs.a2 = chess_copy( tmp2 );   \
    v32acc32 chess_storage(cm3) tmp3 = accs.a3; accs.a3 = chess_copy( tmp3 );   \
}
// Routes accs.a0..a3 through temporaries declared with chess_storage cm4..cm7
// and copies each back via chess_copy. Expands to a single brace block.
// NOTE(review): presumably this pins the four accumulators to registers
// cm4..cm7 - confirm against the Chess compiler manual.
#define FIX_CM_1(accs) {                                                        \
    v32acc32 chess_storage(cm4) tmp0 = accs.a0; accs.a0 = chess_copy( tmp0 );   \
    v32acc32 chess_storage(cm5) tmp1 = accs.a1; accs.a1 = chess_copy( tmp1 );   \
    v32acc32 chess_storage(cm6) tmp2 = accs.a2; accs.a2 = chess_copy( tmp2 );   \
    v32acc32 chess_storage(cm7) tmp3 = accs.a3; accs.a3 = chess_copy( tmp3 );   \
}
// Selects an accumulator register group for `accs`:
//   fix_dm == 1 -> FIX_CM_0(accs)
//   fix_dm == 2 -> FIX_CM_1(accs)
//   any other value -> no effect.
// The expansion is one brace block, so the macro behaves as a single statement
// under an un-braced `if` or loop, and `fix_dm` is parenthesized so that
// expression arguments compare as a whole. `fix_dm` is evaluated twice.
// A trailing semicolon at the call site is optional.
#define FIX_CM(accs, fix_dm) {              \
    if ((fix_dm) == 1) FIX_CM_0(accs);      \
    if ((fix_dm) == 2) FIX_CM_1(accs);      \
}


// Threads one value through variables bound to registers r4..r11 and
// dc0..dc7. The value starts as get_rnd(), is passed from register to register
// via chess_copy, and the r11 copy is finally handed to set_rnd(), so the
// rounding setting is written back with the value that was read.
// NOTE(review): presumably the purpose is to make the compiler treat all of
// these registers as overwritten at this point - confirm intent with callers.
inline __attribute__((always_inline)) void wipe_registers() {
    char chess_storage(r4) r4 = get_rnd();
    char chess_storage(r5) r5 = chess_copy( r4 );
    char chess_storage(r6) r6 = chess_copy( r5 );
    char chess_storage(r7) r7 = chess_copy( r6 );
    char chess_storage(r8) r8 = chess_copy( r7 );
    char chess_storage(r9) r9 = chess_copy( r8 );
    char chess_storage(r10) r10 = chess_copy( r9 );
    char chess_storage(r11) r11 = chess_copy( r10 );
    addr_t chess_storage(dc0) dc0 = chess_copy( r11 );
    addr_t chess_storage(dc1) dc1 = chess_copy( dc0 );
    addr_t chess_storage(dc2) dc2 = chess_copy( dc1 );
    addr_t chess_storage(dc3) dc3 = chess_copy( dc2 );
    addr_t chess_storage(dc4) dc4 = chess_copy( dc3 );
    addr_t chess_storage(dc5) dc5 = chess_copy( dc4 );
    addr_t chess_storage(dc6) dc6 = chess_copy( dc5 );
    addr_t chess_storage(dc7) dc7 = chess_copy( dc6 );
    set_rnd( r11 );
}

// may add
// Scalar overload. For reg in 0..11, passes `val` through a temporary bound
// with __aie_register to scalar register r<reg> and reads it back with
// __aie_copy. For any other `reg` the value is returned untouched.
// NOTE(review): presumably used to steer register allocation for `val`.
template<unsigned reg>
ALWAYS_INLINE int locate_in_register( int val ) {
        if constexpr( reg ==  0 ) { auto __aie_register( r0  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  1 ) { auto __aie_register( r1  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  2 ) { auto __aie_register( r2  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  3 ) { auto __aie_register( r3  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  4 ) { auto __aie_register( r4  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  5 ) { auto __aie_register( r5  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  6 ) { auto __aie_register( r6  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  7 ) { auto __aie_register( r7  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  8 ) { auto __aie_register( r8  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  9 ) { auto __aie_register( r9  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg == 10 ) { auto __aie_register( r10 ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg == 11 ) { auto __aie_register( r11 ) tmp = val; val = __aie_copy( tmp ); }
        return val;
}

// v64int8 overload. For reg in 0..11, passes `val` through a temporary bound
// with __aie_register to vector register x<reg> and reads it back with
// __aie_copy. For any other `reg` the value is returned untouched.
// NOTE(review): presumably used to steer register allocation for `val`.
template<unsigned reg>
ALWAYS_INLINE v64int8 locate_in_register( v64int8 val ) {
    //if constexpr( std::is_same_v<T, v64int8> ) {
        if constexpr( reg ==  0 ) { auto __aie_register( x0  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  1 ) { auto __aie_register( x1  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  2 ) { auto __aie_register( x2  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  3 ) { auto __aie_register( x3  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  4 ) { auto __aie_register( x4  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  5 ) { auto __aie_register( x5  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  6 ) { auto __aie_register( x6  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  7 ) { auto __aie_register( x7  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  8 ) { auto __aie_register( x8  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  9 ) { auto __aie_register( x9  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg == 10 ) { auto __aie_register( x10 ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg == 11 ) { auto __aie_register( x11 ) tmp = val; val = __aie_copy( tmp ); }
    //}
    return val;
}
// v128int8 overload. For reg in 0..5, passes `val` through a temporary bound
// with __aie_register to register y<reg> and reads it back with __aie_copy.
// For any other `reg` the value is returned untouched.
// NOTE(review): presumably used to steer register allocation for `val`.
template<unsigned reg>
ALWAYS_INLINE v128int8 locate_in_register( v128int8 val ) {
        if constexpr( reg ==  0 ) { auto __aie_register( y0  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  1 ) { auto __aie_register( y1  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  2 ) { auto __aie_register( y2  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  3 ) { auto __aie_register( y3  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  4 ) { auto __aie_register( y4  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  5 ) { auto __aie_register( y5  ) tmp = val; val = __aie_copy( tmp ); }
    return val;
}
// Generic aie::Vector variant. Only 512-bit vectors are handled: for reg in
// 0..11 the value is passed through a temporary bound with __aie_register to
// x<reg> and read back with __aie_copy. Vectors of any other width, or any
// other `reg`, are returned untouched without a diagnostic.
template<unsigned reg, aie::Vector Tv>
ALWAYS_INLINE Tv locate_in_register_t( Tv val ) {
    if constexpr( val.bits() == 512 ) {
        if constexpr( reg ==  0 ) { auto __aie_register( x0  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  1 ) { auto __aie_register( x1  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  2 ) { auto __aie_register( x2  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  3 ) { auto __aie_register( x3  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  4 ) { auto __aie_register( x4  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  5 ) { auto __aie_register( x5  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  6 ) { auto __aie_register( x6  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  7 ) { auto __aie_register( x7  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  8 ) { auto __aie_register( x8  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg ==  9 ) { auto __aie_register( x9  ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg == 10 ) { auto __aie_register( x10 ) tmp = val; val = __aie_copy( tmp ); }
        if constexpr( reg == 11 ) { auto __aie_register( x11 ) tmp = val; val = __aie_copy( tmp ); }
    }
    return val;
}

#ifndef __QDQ_KERNEL_HELPERS_H__
// Accumulator overload. Passes `acc` through a temporary bound with
// __aie_register to an accumulator register chosen by `reg`, then reads it
// back with __aie_copy. The register mapping depends on width and arch:
//   1024-bit, arch 20   : reg 0..8 -> cm0..cm8
//   1024-bit, arch >= 21: reg 0..9 -> cml0, cmh0, cml1, cmh1, ... cml4, cmh4
//   2048-bit, arch 20   : reg 0..3 -> the two 32-lane halves go to the pair
//                         (cm0,cm1), (cm2,cm3), (cm4,cm5), (cm6,cm7) and are
//                         concatenated again
//   2048-bit, arch >= 21: reg 0..4 -> dm0..dm4
// Any other accumulator width reaches chess_error. A `reg` outside the listed
// range returns `acc` untouched.
template<unsigned reg, aie::Accum Ta>
inline Ta locate_in_register( Ta acc ) {
    if constexpr( acc.bits( ) == 1024 ) {
      #if __AIE_ARCH__ == 20
        if constexpr( reg ==  0 ) { auto __aie_register( cm0  ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg ==  1 ) { auto __aie_register( cm1  ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg ==  2 ) { auto __aie_register( cm2  ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg ==  3 ) { auto __aie_register( cm3  ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg ==  4 ) { auto __aie_register( cm4  ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg ==  5 ) { auto __aie_register( cm5  ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg ==  6 ) { auto __aie_register( cm6  ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg ==  7 ) { auto __aie_register( cm7  ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg ==  8 ) { auto __aie_register( cm8  ) tmp = acc; acc = __aie_copy( tmp ); }
      #elif __AIE_ARCH__ >= 21
        if constexpr( reg ==  0 ) { auto __aie_register( cml0 ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg ==  1 ) { auto __aie_register( cmh0 ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg ==  2 ) { auto __aie_register( cml1 ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg ==  3 ) { auto __aie_register( cmh1 ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg ==  4 ) { auto __aie_register( cml2 ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg ==  5 ) { auto __aie_register( cmh2 ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg ==  6 ) { auto __aie_register( cml3 ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg ==  7 ) { auto __aie_register( cmh3 ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg ==  8 ) { auto __aie_register( cml4 ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg ==  9 ) { auto __aie_register( cmh4 ) tmp = acc; acc = __aie_copy( tmp ); }
       //#if __AIE_ARCH__ >= 22
       // if constexpr( reg == 10 ) { auto __aie_register( cml5 ) tmp = acc; acc = __aie_copy( tmp ); }
       // if constexpr( reg == 11 ) { auto __aie_register( cmh5 ) tmp = acc; acc = __aie_copy( tmp ); }
       // if constexpr( reg == 12 ) { auto __aie_register( cml6 ) tmp = acc; acc = __aie_copy( tmp ); }
       // if constexpr( reg == 13 ) { auto __aie_register( cmh6 ) tmp = acc; acc = __aie_copy( tmp ); }
       // if constexpr( reg == 14 ) { auto __aie_register( cml7 ) tmp = acc; acc = __aie_copy( tmp ); }
       // if constexpr( reg == 15 ) { auto __aie_register( cmh7 ) tmp = acc; acc = __aie_copy( tmp ); }
       //#endif
      #endif

    } else if constexpr( acc.bits( ) == 2048 ) {
      #if __AIE_ARCH__ == 20
        // Arch 20: a 2048-bit accumulator is split across two cm registers.
        if constexpr( reg == 0 ) {
            auto __aie_register( cm0 ) tmp0 = acc.template extract<32>( 0 );
            auto __aie_register( cm1 ) tmp1 = acc.template extract<32>( 1 );
            acc = aie::concat( __aie_copy( tmp0 ), __aie_copy( tmp1 )); }
        if constexpr( reg == 1 ) {
            auto __aie_register( cm2 ) tmp0 = acc.template extract<32>( 0 );
            auto __aie_register( cm3 ) tmp1 = acc.template extract<32>( 1 );
            acc = aie::concat( __aie_copy( tmp0 ), __aie_copy( tmp1 )); }
        if constexpr( reg == 2 ) {
            auto __aie_register( cm4 ) tmp0 = acc.template extract<32>( 0 );
            auto __aie_register( cm5 ) tmp1 = acc.template extract<32>( 1 );
            acc = aie::concat( __aie_copy( tmp0 ), __aie_copy( tmp1 )); }
        if constexpr( reg == 3 ) {
            auto __aie_register( cm6 ) tmp0 = acc.template extract<32>( 0 );
            auto __aie_register( cm7 ) tmp1 = acc.template extract<32>( 1 );
            acc = aie::concat( __aie_copy( tmp0 ), __aie_copy( tmp1 )); }
      #elif __AIE_ARCH__ >= 21
        if constexpr( reg == 0 ) { auto __aie_register( dm0 ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg == 1 ) { auto __aie_register( dm1 ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg == 2 ) { auto __aie_register( dm2 ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg == 3 ) { auto __aie_register( dm3 ) tmp = acc; acc = __aie_copy( tmp ); }
        if constexpr( reg == 4 ) { auto __aie_register( dm4 ) tmp = acc; acc = __aie_copy( tmp ); }
       //#if __AIE_ARCH__ >= 22
       // if constexpr( reg == 5 ) { auto __aie_register( dm5 ) tmp = acc; acc = __aie_copy( tmp ); }
       // if constexpr( reg == 6 ) { auto __aie_register( dm6 ) tmp = acc; acc = __aie_copy( tmp ); }
       // if constexpr( reg == 7 ) { auto __aie_register( dm7 ) tmp = acc; acc = __aie_copy( tmp ); }
       //#endif
      #endif
    } else {
        chess_error( "locate_in_register not yet implemented for this type" );
    }
    return acc;
}


// aie::mmul<4, 8, 8, int8, int8> overload.
// Arch 20: for reg in 0..8, mm.data is passed through a temporary bound with
// __aie_register to cm<reg> and read back with __aie_copy; any other `reg`
// leaves `mm` untouched.
// Arch >= 21: not implemented; the body reaches chess_error.
template<unsigned reg>
inline aie::mmul<4, 8, 8, int8, int8> locate_in_register( aie::mmul<4, 8, 8, int8, int8> mm ) {
    #if __AIE_ARCH__ == 20
    if constexpr( reg == 0 ) { auto __aie_register( cm0 ) tmp = mm.data; mm.data = __aie_copy( tmp ); }
    if constexpr( reg == 1 ) { auto __aie_register( cm1 ) tmp = mm.data; mm.data = __aie_copy( tmp ); }
    if constexpr( reg == 2 ) { auto __aie_register( cm2 ) tmp = mm.data; mm.data = __aie_copy( tmp ); }
    if constexpr( reg == 3 ) { auto __aie_register( cm3 ) tmp = mm.data; mm.data = __aie_copy( tmp ); }
    if constexpr( reg == 4 ) { auto __aie_register( cm4 ) tmp = mm.data; mm.data = __aie_copy( tmp ); }
    if constexpr( reg == 5 ) { auto __aie_register( cm5 ) tmp = mm.data; mm.data = __aie_copy( tmp ); }
    if constexpr( reg == 6 ) { auto __aie_register( cm6 ) tmp = mm.data; mm.data = __aie_copy( tmp ); }
    if constexpr( reg == 7 ) { auto __aie_register( cm7 ) tmp = mm.data; mm.data = __aie_copy( tmp ); }
    if constexpr( reg == 8 ) { auto __aie_register( cm8 ) tmp = mm.data; mm.data = __aie_copy( tmp ); }
    #elif __AIE_ARCH__ >= 21
    chess_error( "locate_in_register not yet implemented for this type" );
    #endif
    return mm;
}
// aie::mmul<8, 8, 8, int8, int8> overload.
// Arch 20: for reg in 0..3, mm.data[0] and mm.data[1] are each passed through
// a 32-lane accumulator temporary bound to the register pair
// (cm0,cm1), (cm2,cm3), (cm4,cm5) or (cm6,cm7) and read back with __aie_copy.
// Arch >= 21: for reg in 0..4, the whole product is taken with to_accum(),
// bound to dm<reg> as a 64-lane accumulator, and `mm` is rebuilt from the
// copy.
// Any other `reg` leaves `mm` untouched.
template<unsigned reg>
inline aie::mmul<8, 8, 8, int8, int8> locate_in_register( aie::mmul<8, 8, 8, int8, int8> mm ) {
    #if __AIE_ARCH__ == 20
    if constexpr( reg == 0 ) { aie::accum<acc32,32> __aie_register( cm0 ) tmp0 = mm.data[0]; mm.data[0] = __aie_copy( tmp0 );
                               aie::accum<acc32,32> __aie_register( cm1 ) tmp1 = mm.data[1]; mm.data[1] = __aie_copy( tmp1 ); }
    if constexpr( reg == 1 ) { aie::accum<acc32,32> __aie_register( cm2 ) tmp0 = mm.data[0]; mm.data[0] = __aie_copy( tmp0 );
                               aie::accum<acc32,32> __aie_register( cm3 ) tmp1 = mm.data[1]; mm.data[1] = __aie_copy( tmp1 ); }
    if constexpr( reg == 2 ) { aie::accum<acc32,32> __aie_register( cm4 ) tmp0 = mm.data[0]; mm.data[0] = __aie_copy( tmp0 );
                               aie::accum<acc32,32> __aie_register( cm5 ) tmp1 = mm.data[1]; mm.data[1] = __aie_copy( tmp1 ); }
    if constexpr( reg == 3 ) { aie::accum<acc32,32> __aie_register( cm6 ) tmp0 = mm.data[0]; mm.data[0] = __aie_copy( tmp0 );
                               aie::accum<acc32,32> __aie_register( cm7 ) tmp1 = mm.data[1]; mm.data[1] = __aie_copy( tmp1 ); }
    #elif __AIE_ARCH__ >= 21
    if constexpr( reg == 0 ) { aie::accum<acc32,64> __aie_register( dm0 ) tmp = mm.to_accum( ); mm = aie::mmul<8, 8, 8, int8, int8>( __aie_copy( tmp )); }
    if constexpr( reg == 1 ) { aie::accum<acc32,64> __aie_register( dm1 ) tmp = mm.to_accum( ); mm = aie::mmul<8, 8, 8, int8, int8>( __aie_copy( tmp )); }
    if constexpr( reg == 2 ) { aie::accum<acc32,64> __aie_register( dm2 ) tmp = mm.to_accum( ); mm = aie::mmul<8, 8, 8, int8, int8>( __aie_copy( tmp )); }
    if constexpr( reg == 3 ) { aie::accum<acc32,64> __aie_register( dm3 ) tmp = mm.to_accum( ); mm = aie::mmul<8, 8, 8, int8, int8>( __aie_copy( tmp )); }
    if constexpr( reg == 4 ) { aie::accum<acc32,64> __aie_register( dm4 ) tmp = mm.to_accum( ); mm = aie::mmul<8, 8, 8, int8, int8>( __aie_copy( tmp )); }
    //#if __AIE_ARCH__ >= 22
    //if constexpr( reg == 5 ) { aie::accum<acc32,64> __aie_register( dm5 ) tmp = mm.to_accum( ); mm = aie::mmul<8, 8, 8, int8, int8>( __aie_copy( tmp )); }
    //if constexpr( reg == 6 ) { aie::accum<acc32,64> __aie_register( dm6 ) tmp = mm.to_accum( ); mm = aie::mmul<8, 8, 8, int8, int8>( __aie_copy( tmp )); }
    //if constexpr( reg == 7 ) { aie::accum<acc32,64> __aie_register( dm7 ) tmp = mm.to_accum( ); mm = aie::mmul<8, 8, 8, int8, int8>( __aie_copy( tmp )); }
    //#endif
    #endif
    return mm;
}

// Implementation detail of the array overload of locate_in_register.
// Expands the index pack Is... with a comma fold so that element arr[Is] is
// replaced by locate_in_register<reg_start + Is>( arr[Is] ) for every index.
// The integer_sequence argument is used only to deduce Is...
template<typename T, unsigned size, unsigned reg_start, unsigned ... Is>
inline void locate_in_register_helper( T (&arr)[size], std::integer_sequence<unsigned, Is...> const & ) {
   (( arr[Is] = locate_in_register<reg_start+Is>( arr[Is] )), ... );
}

// Array overload: applies locate_in_register to every element of `arr`,
// using consecutive register indices reg_start, reg_start + 1, ...
// (element i uses index reg_start + i). reg_start defaults to 0.
template<unsigned reg_start=0, unsigned size, typename T>
inline void locate_in_register( T (&arr)[size] ) {
    locate_in_register_helper<T,size,reg_start>( arr, std::make_integer_sequence<unsigned, size>{} );
}
#endif // __QDQ_KERNEL_HELPERS_H__



#endif

