#ifndef __SLICE_16B_INNER_IMPL_H__
#define __SLICE_16B_INNER_IMPL_H__

#include "aie_api/utils.hpp"
#include "aie_api/aie.hpp"
#include "common/api_loop_pipe_helper.hpp"
#include "common/ml_params.h"

// Addressing parameters consumed by slice_16b_inner.
// Fields suffixed "1" drive the first pass (reading from in_ptr, writing
// out1_ptr); fields suffixed "2" drive the second pass (reading from
// in_ptr + s2_offset bytes, writing out2_ptr).
// NOTE(review): declaration order and widths define the memory layout;
// keep them unchanged if this struct is filled from a raw parameter buffer.
struct SliceParams {
    uint32_t loop_s1, loop_s2;  // pipelined-loop trip counts for pass 1 / pass 2
    uint32_t s2_offset;         // byte offset from in_ptr where pass 2 starts reading
    uint16_t numS1, incS1;      // pass-1 input dims_2d_t: count (used as numS1 - 1) and outer byte increment
    uint16_t numS2, incS2;      // pass-2 input dims_2d_t: same meaning as numS1 / incS1
    uint16_t incO1, incO2;      // outer byte increments of the out1 / out2 dims_2d_t
    uint16_t size1, size2;      // number of 32-element zero-fill stores for out1 / out2
    uint32_t mask1, mask2;      // AND-mask applied to an output word when the delayed input count reaches 0
};

// Maps an element type T to the integer word type used to pack T elements
// for output: int32_t by default, int64_t for 64-bit integer types.
// Both 64-bit types need the wide mapping so that the packing factor
// sizeof( To ) / sizeof( T ) stays at least 1 (it would be 0 for uint64_t
// if it fell through to int32_t).
template<typename T>
struct int32_or_larger { using type = int32_t; };
template<> struct int32_or_larger<int64_t> { using type = int64_t; };
template<> struct int32_or_larger<uint64_t> { using type = int64_t; };
template<typename T>
using int32_or_larger_t = typename int32_or_larger<T>::type;

/*
Kernel to slice one tensor along its inner dimension into two output tensors.
Slicing is controlled through the addressing parameters in SliceParams;
see the metafile for how these parameters are calculated.
innerC/slice_inner is the position along the inner dimension of the input tensor at which the slice is made.
example (1,1,64,64) -> (1,1,64,63) (1,1,64,1)
innerC/slice_inner = 63
*/

template< unsigned loop_range=5, typename T>
__attribute__ ((always_inline))
void slice_16b_inner (
        T * in_ptr,
        T * __restrict out1_ptr,
        T * __restrict out2_ptr,
        const SliceParams &params
) {
    // Output is written in To-sized words; each word packs `gran`
    // consecutive T elements read from the input.
    using To = int32_or_larger_t<T>;
    constexpr unsigned gran = sizeof( To ) / sizeof( T );
    T * pIs1 = in_ptr;
    To * pOs1 = ( To* ) out1_ptr;
    int size = params.size1;
    int mask = params.mask1;

    // The input advances one T element per fetch and the output one To word
    // per store. Use sizeof( To ) rather than a literal 4 so the output step
    // is also correct when To is 64-bit.
    dims_2d_t dimsS( params.numS1 - 1, sizeof( T ), params.incS1 );
    dims_2d_t dimsO( params.numS1 / gran - 1, sizeof( To ), params.incO1 );

    // Reads one T through the 2-D input addressing pattern, widened to To.
    auto fetch = [&]( ) __aie_inline -> To {
        To val = *pIs1;
        pIs1 = add_2d_byte( pIs1, dimsS );
        return val;
    };
    unsigned loop_count = params.loop_s1;

    // Pass 0 fills out1_ptr; pass 1 fills out2_ptr.
    [[using chess: no_unroll, min_loop_count( 1 )]]
    for(int i =0; i<2; i++) {
        // Zero the output region first: `size` stores of 32 T elements each.
        // NOTE(review): presumably so bytes not covered by the packed
        // stores below read as zero -- confirm against the metafile.
        [[using chess: min_loop_count( 1 )]]
        for ( int l=0; l<size; l++ ) aie::store_v((( T* ) pOs1 ) + 32 * l, aie::zeros<T,32>( ));

        aie::pipelined_loop<loop_range>( loop_count, [&]( auto l ) __aie_inline {
            To val = fetch( );
            if constexpr( gran > 1 ) {
                // Pack gran elements into one word: element w occupies bits
                // [8*sizeof(T)*w, 8*sizeof(T)*(w+1)). Masking each element to
                // its own width stops sign extension of a signed T from
                // spilling into the neighbouring lanes.
                // (Named elem_mask so it does not shadow the pass mask below.)
                const int32_t elem_mask = (( 1 << ( 8 * sizeof( T ))) - 1 );
                val &= elem_mask;
                #pragma unroll
                for ( unsigned w = 1; w < gran; w++ ) {
                    To v = ( fetch( ) & elem_mask ) << ( 8 * sizeof( T ) * w );
                    val += chess_copy( v );
                }
            }
            // Apply the pass mask (mask1 / mask2) when the delayed inner input
            // count is 0; otherwise AND with -1, which keeps every bit.
            // NOTE(review): delay3 presumably compensates for the pipeline
            // latency between the fetch and this store -- confirm.
            To m = delay3((int) dimsS.count1 ) == 0 ? mask : -1;
            *pOs1 = val & m;
            pOs1 = add_2d_byte( pOs1, dimsO );
        });
        // Re-target input, output and addressing at the second slice.
        // After the final pass these assignments are simply unused.
        pIs1 = byte_incr( in_ptr, params.s2_offset );
        pOs1 = ( To* ) out2_ptr;
        loop_count = params.loop_s2;
        size = params.size2;
        mask = params.mask2;
        dimsS = dims_2d_t( params.numS2 - 1, sizeof( T ), params.incS2 );
        dimsO = dims_2d_t( params.numS2 / gran - 1, sizeof( To ), params.incO2 );
    }

}
#endif //__SLICE_16B_INNER_IMPL_H__
