#ifndef __ADD2D_INT8X8_IMPL_H__
#define __ADD2D_INT8X8_IMPL_H__

#include "common.hh"
#include "aie_api/aie.hpp"
#include "aie_api/utils.hpp"

// Default ReLU upper bounds, as the names suggest: the unsigned and signed
// 8-bit maxima. Define either macro before including this header to override it.
// Neither macro is referenced by the code in this header.
// NOTE(review): presumably callers use them to fill
// Add2dInt8x8Params::max_value — confirm.
#if !defined(ReLU_uMAX)
#   define ReLU_uMAX UINT8_MAX
#endif

#if !defined(ReLU_sMAX)
#   define ReLU_sMAX INT8_MAX
#endif

/*
 * Runtime parameters consumed by add2d_int8x8 in this header.
 *
 * Packed with 1-byte alignment, so no padding is inserted between fields.
 * NOTE(review): the packing suggests the byte layout must match whatever
 * produces these parameters (e.g. host-side code). Confirm that before
 * reordering, resizing or adding any field.
 */
#pragma pack(push,1)
struct Add2dInt8x8Params {
    uint16_t outer_loop;     // Trip count of the kernel loop; each iteration processes 64 elements.
    dims_2d_param dims;      // Instantiated into dims_2d_t; used with add_2d_byte to advance the wgt pointer.
    int8_t shift_in;         // The ifm operand is multiplied by (1 << shift_in) in mac_elem_32.
    int8_t shift_in1;        // Shift argument passed to sups() together with each wgt half.
    int8_t shift_res;        // Shift passed to to_vector_sign() when narrowing accumulators to To.
    struct Control {
        uint8_t sign_A:1;    // Sign flag passed to unpack()/mac_elem_32 for ifm (presumably signed vs unsigned).
        uint8_t sign_W:1;    // Sign flag passed to sups() for wgt.
        uint8_t sign_O:1;    // Sign flag passed to aie::min in the has_relu6 clamp.
        uint8_t sign_srs:1;  // Sign flag passed to to_vector_sign() for the accumulator -> To conversion.
        // Bit-field allocation order within the byte is implementation-defined.
        // NOTE(review): confirm it matches the producer of these params.
    } ctrl;
    int8_t max_value;        // Upper clamp (cast to To); applied only when has_relu6 is true.
};
#pragma pack(pop)

// Brings the whole aie namespace into scope at header level. It therefore also
// takes effect in every translation unit that includes this header.
// add2d_int8x8 below relies on it for unqualified names such as `accum` and
// `vector`.
// NOTE(review): consider qualifying those names (or scoping the directive
// inside the function) to stop it leaking to includers. Check includers
// first, since some may depend on it.
using namespace aie;

/**
 * Vectorized element-wise kernel combining ifm with a second operand (wgt).
 *
 * Each loop iteration:
 *   - loads 64 ifm elements and widens them to int16 with unpack(), using sign_A;
 *   - loads 64 wgt elements and turns each 32-element half into an accumulator
 *     with sups(), using params.shift_in1 and sign_W;
 *   - calls mac_elem_32() on each ifm half with the multiplier
 *     (1 << params.shift_in) and the matching wgt accumulator.
 *     NOTE(review): presumably acc = ifm * shift_mul + wgt_acc — confirm
 *     against mac_elem_32;
 *   - narrows both accumulators to To with params.shift_res and sign_srs,
 *     optionally clamps, and stores 64 outputs.
 *
 * @tparam has_relu6  When true, outputs are clamped from above with aie::min
 *                    against params.max_value. This function applies no
 *                    lower clamp itself.
 * @tparam loop_range First argument of chess_loop_range(), i.e. the minimum
 *                    trip count the compiler may assume.
 *                    NOTE(review): callers presumably must guarantee
 *                    params.outer_loop >= loop_range — confirm.
 * @tparam Ti0, Ti1, To Element types of ifm, wgt and ofm respectively.
 * @param ifm    Input read contiguously, 64 elements per iteration.
 * @param wgt    Second input, 64 elements per iteration. Advanced between
 *               iterations by add_2d_byte() following params.dims.
 * @param ofm    Output written contiguously, 64 elements per iteration.
 *               Declared __restrict, so it must not alias ifm or wgt.
 * @param params Loop count, shifts, sign flags and clamp value.
 */
template<bool has_relu6=false, unsigned loop_range=8, typename Ti0 = int8, typename Ti1 = int8, typename To = int8>
ALWAYS_INLINE void add2d_int8x8 
(
    Ti0 * ifm,
    Ti1 * wgt,
    To  * __restrict ofm,
    const Add2dInt8x8Params &params
) {

    const Ti0  * pA = ifm;
    const Ti1  * pW = wgt;
    To         * restrict pO = ofm;

    // Elements consumed from ifm and produced to ofm per iteration.
    constexpr unsigned gran = 64;
    dims_2d_t dims = params.dims.instantiate( );
    // Power-of-two multiplier for the ifm operand.
    // NOTE(review): a negative shift_in, or one >= the width of int, is
    // undefined behavior. Values 16..30 are truncated by the uint16_t
    // conversion. shift_in is presumably expected in [0, 15].
    uint16_t shift_mul = 1 << params.shift_in;

    // chess_prepare_for_pipelining asks the Chess compiler to software-pipeline
    // the loop. chess_loop_range(loop_range, ) lets it assume at least
    // loop_range iterations, with no stated upper bound.
    for ( unsigned i = 0; i < params.outer_loop; i++ )
        chess_prepare_for_pipelining
        chess_loop_range( loop_range, )
    {
        accum<acc32, 32> acc0, acc1;
        accum<acc32, 32> y0, y1;
        vector<int16, 64> x;
        vector<int8, 64> y;
        vector<To, 32> out0, out1;
        vector<To, 64> out;

        /* Load data */
        // ifm is widened to int16 so it can feed mac_elem_32.
        x = unpack(aie::load_v<64,aie_dm_resource::a>( pA ), params.ctrl.sign_A);
        // NOTE(review): y is always vector<int8, 64> regardless of Ti1.
        // Confirm which Ti1 instantiations are supported.
        y = aie::load_v<64,aie_dm_resource::a>( pW );
        // Each 32-element half of wgt becomes its own accumulator.
        y0 = sups(y.extract<32>(0), params.shift_in1, params.ctrl.sign_W);
        y1 = sups(y.extract<32>(1), params.shift_in1, params.ctrl.sign_W);

        acc0 = mac_elem_32(x.extract<32>(0), params.ctrl.sign_A, shift_mul, 0, y0);
        acc1 = mac_elem_32(x.extract<32>(1), params.ctrl.sign_A, shift_mul, 0, y1);
        // Narrow each accumulator to 32 output elements using shift_res.
        out0 = acc0.template to_vector_sign<To>( params.ctrl.sign_srs, params.shift_res);
        out1 = acc1.template to_vector_sign<To>( params.ctrl.sign_srs, params.shift_res);

        if constexpr( has_relu6 ) {
            // Upper clamp only. sign_O is passed to aie::min as the sign flag
            // of the comparison.
            out = aie::min( aie::concat(out0,out1), (To)params.max_value, params.ctrl.sign_O);
            store_v( pO, out );
        } else {
            // The two stores cover the disjoint ranges [pO+32, pO+64) and
            // [pO, pO+32), so their order does not affect the result.
            store_v( pO+32, out1 );
            store_v( pO, out0 );
	    }

        /* Increment pointers */
        // ifm and ofm advance contiguously. wgt follows the 2-D pattern in dims.
        pA += gran;
        pW = add_2d_byte( pW, dims );
        pO += gran;
    }
}

#endif // __ADD2D_INT8X8_IMPL_H__