#ifndef __KERNEL_GENERIC_TRANSPOSE_IMPL_HPP__
#define __KERNEL_GENERIC_TRANSPOSE_IMPL_HPP__

#include "aie_api/aie.hpp"
#include "aie_api/utils.hpp"
#include "common/api_loop_pipe_helper.hpp"
#include "common/ml_params.h"

// Run-time parameters consumed by generic_transpose().
// NOTE(review): the field order/layout is presumably shared with whatever
// produces these parameters — confirm before reordering or resizing fields.
struct KernelGenericTransposeParam {
    // Iteration count handed to aie::pipelined_loop for each outer iteration.
    uint16_t inner_loop;
    // Number of outer iterations; each one runs the inner loop plus one final step.
    uint16_t outer_loop;
    // Inner step: the accumulator is shifted right by `shift` bits and the
    // input element is shifted left by `32 - shift` bits before being added.
    int8_t shift;
    // Same role as `shift`, but for the final step of each outer iteration.
    // When it is not below 32, the previously accumulated bits are masked out.
    int8_t shift_fin;
    // Number of 64-byte vectors zeroed at the start of the output buffer.
    uint16_t size_out;
    // Passed to byte_incr() to position the input pointer (presumably a byte offset).
    uint32_t offset_in;
    // Passed to byte_incr() to position the output pointer (presumably a byte offset).
    uint32_t offset_out;
    // instantiate() yields two dims_3d_t steppers; both are applied to the
    // input pointer (via add_3d_byte) after every element read.
    dims_5d_param dims_in;
    // Passed to byte_incr() to advance the output pointer after each inner step.
    int32_t inc_out;
    // instantiate() yields the dims_3d_t stepper applied to the output pointer
    // (via add_3d_byte) at the end of each outer iteration.
    dims_3d_param dims_out;
};

// Type selector for the accumulator/output word: yields uint64_t when T is
// exactly uint64_t, and uint32_t for every other T.
template<typename T>
struct uint32_or_larger {
    using type = std::conditional_t<std::is_same_v<T, uint64_t>, uint64_t, uint32_t>;
};

// Convenience alias for uint32_or_larger<T>::type.
template<typename T>
using uint32_or_larger_t = typename uint32_or_larger<T>::type;

template<typename T>
inline void generic_transpose( T * ifm, T * __restrict ofm, const KernelGenericTransposeParam &param )
{
    using Ti = std::make_unsigned_t<T>;
    using To = uint32_or_larger_t<T>;
    Ti * pI = ( Ti* ) byte_incr( ifm, param.offset_in );
    To * pO = ( To* ) byte_incr( ofm, param.offset_out );
    
    constexpr unsigned Nz = 64 / sizeof( T );
    [[using chess: min_loop_count( 1 )]]
    for ( int i=0; i<param.size_out; i++ ) {
        aie::store_v( ofm + Nz * i, aie::zeros<T,Nz>( ));
    }

    dims_3d_t dimsA, dimsB;
    std::tie( dimsA, dimsB ) = param.dims_in.instantiate( );
    dims_3d_t dimsO = param.dims_out.instantiate( );

    [[using chess: min_loop_count( 1 )]]
    for ( unsigned o = 0; o < param.outer_loop; o++ ) {
        To aggr = 0;
        aie::pipelined_loop<4>( param.inner_loop, [&]( auto i ) __aie_inline {
            To val = *pI << ( 32 - param.shift );
            aggr = ( aggr >> param.shift ) + val;
            *pO = aggr;
            pI = add_3d_byte( pI, dimsA );
            pI = add_3d_byte( pI, dimsB );
            pO = byte_incr( pO, param.inc_out );
        });
        To val = *pI << ( 32 - param.shift_fin );
        aggr = (( aggr >> param.shift_fin ) & -( param.shift_fin < 32 )) + val;
        *pO = aggr;
        pI = add_3d_byte( pI, dimsA );
        pI = add_3d_byte( pI, dimsB );
        pO = add_3d_byte( pO, dimsO );
    }
}

#endif
