#ifndef __CONCAT_16B_INNER_IMPL_H__
#define __CONCAT_16B_INNER_IMPL_H__

#include "aie_api/utils.hpp"
#include "aie_api/aie.hpp"
#include "common/api_loop_pipe_helper.hpp"
#include "common/ml_params.h"

// Addressing parameters consumed by concat_16b_inner.
// The kernel runs a pad fill followed by two copy passes (input 1, then input 2);
// each field below is described by how the kernel reads it.
// NOTE(review): field order and widths are presumably shared with the code that
// generates these parameters — confirm before reordering or retyping anything.
struct ConcatParams {
    uint32_t loop_s1;    // iteration count of the first copy pass (in1 -> out)
    uint32_t loop_s2;    // iteration count of the second copy pass (in2 -> out)
    uint32_t s2_offset;  // offset applied to out_ptr via byte_incr to get the second pass's first write address
    uint32_t innerC;     // inner size of the first input; the kernel passes (innerC - 1) as the first dims_2d_t argument for pass 1
    uint16_t incO1;      // third dims_2d_t argument for the output pointer in pass 1 (presumably the wrap increment — confirm)
    uint16_t numO2;      // first dims_2d_t argument for both pointers in pass 2 (used as-is, no "- 1")
    uint16_t incO2;      // third dims_2d_t argument for the output pointer in pass 2
    uint16_t incI1;      // third dims_2d_t argument for the input pointer in pass 1
    uint16_t incI2;      // third dims_2d_t argument for the input pointer in pass 2
    uint16_t sizeO;      // number of vector stores in the pad fill; each store writes 64 / sizeof(T) elements
    int16_t pad_value;   // value broadcast into every element written by the pad fill
};

/*
Kernel that concatenates two tensors along the inner dimension.
The concatenation is controlled through the addressing parameters in ConcatParams;
how these parameters are calculated is described in the metafile.
innerC/concat_inner is the inner-dimension size of the first input being concatenated.
Example: (1,1,64,63) + (1,1,64,1) -> (1,1,64,64)
         innerC/concat_inner = 63
*/
template< unsigned loop_range=5, typename T>
__attribute__ ((always_inline))
void concat_16b_inner (
        T * in1_ptr,
        T * in2_ptr,
        T * __restrict out_ptr,
        const ConcatParams &params
) {
    T * pIs1 = in1_ptr;
    T * restrict pOs1 = ( T* restrict) out_ptr;

    constexpr unsigned N = 64 / sizeof( T );
    [[ using chess: min_loop_count( 1 )]]
    for ( int i = 0; i < params.sizeO; i++ ) {
        aie::store_v( pOs1 + i * N, aie::broadcast<T, N>( params.pad_value ));
    }

    dims_2d_t dimsS(params.innerC-1, sizeof( T ), params.incI1);
    dims_2d_t dimsO(params.innerC-1, sizeof( T ), params.incO1);

    auto fetch = [&]( ) __aie_inline -> T {
        T val = *pIs1;
        pIs1 = add_2d_byte( pIs1, dimsS );
        return val;
    };
    unsigned loop_count = params.loop_s1;
    [[ using chess: no_unroll, min_loop_count( 1 )]]
    for(int i =0; i<2; i++) {

        aie::pipelined_loop<8>( loop_count, [&]( auto i ) __aie_inline {
            T val = fetch( );
            *pOs1 = val; pOs1 = add_2d_byte( pOs1, dimsO );
        });
        pIs1 = in2_ptr;
        pOs1 = ( T* restrict) byte_incr(out_ptr,params.s2_offset);
        loop_count = params.loop_s2;
        dimsS = dims_2d_t(params.numO2, sizeof( T ), params.incI2);
        dimsO = dims_2d_t(params.numO2, sizeof( T ), params.incO2);
    }
}
#endif //__CONCAT_16B_INNER_IMPL_H__
