#ifndef __BDCAST_ADD2D_INT8X8_IMPL_H__
#define __BDCAST_ADD2D_INT8X8_IMPL_H__

#include "aie_api/aie.hpp"
#include "add2d_int8x8.hpp"
#include "aie_api/utils.hpp"

#include <cstdint> 
#include <type_traits>

#ifndef ReLU_uMAX
#define ReLU_uMAX UINT8_MAX
#endif
#ifndef ReLU_sMAX
#define ReLU_sMAX INT8_MAX
#endif

using namespace aie;

/**
 * Element-wise add kernel, vector-MAC variant. Each loop iteration consumes
 * 64 elements from `ifm`, 64 elements (or one broadcast scalar) from `wgt`,
 * and writes 64 elements to `ofm`.
 *
 * Template parameters:
 *   has_relu6            - when true, the result is clamped from above with
 *                          aie::min against params.max_value before storing.
 *   loop_range           - forwarded as the first argument of chess_loop_range
 *                          (presumably the minimum trip count - confirm against
 *                          the Chess compiler documentation).
 *   has_scalar_broadcast - 0: load a full 64-element vector from `wgt` each
 *                          iteration; non-zero: broadcast the single element
 *                          at the current `wgt` pointer to all 64 lanes.
 *   Ti0 / Ti1 / To       - element types of `ifm`, `wgt` and `ofm`.
 *
 * Parameters:
 *   ifm    - first input operand; advanced per iteration by add_2d_byte with
 *            params.dims_x.
 *   wgt    - second input operand; advanced per iteration by add_2d_byte with
 *            params.dims_y (also in the scalar-broadcast case).
 *   ofm    - output buffer; advanced linearly by 64 elements per iteration.
 *   params - reads dims_x, dims_y, shift_in, shift_in1, shift_res, outer_loop,
 *            max_value (only if has_relu6) and ctrl.sign_A / sign_W /
 *            sign_srs / sign_O.
 *
 * NOTE(review): the arithmetic appears to be
 *   out = srs( (wgt << shift_in1) + ifm * (1 << shift_in), shift_res )
 * based on the sups / mac_elem_32 / to_vector_sign sequence below; the exact
 * intrinsic semantics are not visible in this file - confirm against the
 * intrinsic reference.
 */
template<bool has_relu6=false, unsigned loop_range=8, unsigned has_scalar_broadcast = 0 , typename Ti0 = int8, typename Ti1 = int8, typename To = int8>
ALWAYS_INLINE void add2d_int8x8_vmac
(
    Ti0 * ifm,
    Ti1 * wgt,
    To  * __restrict ofm,
    const bdcast_Add2dInt8x8Params &params
) {

    const Ti0  * pA = ifm;
    const Ti1  * pW = wgt;
    To         * restrict pO = ofm;

    // Number of elements processed per loop iteration.
    constexpr unsigned gran = 64;
    // 2D address-stepping descriptors consumed by add_2d_byte at the end of each iteration.
    dims_2d_t dims_x = params.dims_x.instantiate( );
    dims_2d_t dims_y = params.dims_y.instantiate( );
    // Scalar multiplier 2^shift_in applied to the ifm operand inside mac_elem_32.
    uint16_t shift_mul = 1 << params.shift_in;
    
    // 16-bit type the ifm data is unpacked to: int16 when Ti0 is int8, uint16 otherwise.
    using xtype = std::conditional_t<std::is_same_v<Ti0,int8>, int16, uint16>;

    for ( unsigned i = 0; i < params.outer_loop; i++ )
        chess_prepare_for_pipelining
        chess_loop_range( loop_range, )
    {
        
        accum<acc32, 32> acc0, acc1;
        accum<acc32, 32> y0, y1;
        vector<xtype, 64> x;
        vector<Ti1, 64> y;
        vector<To, 32> out0, out1;
        vector<To, 64> out; // only used on the has_relu6 path
        /* Load data */
        // Load 64 ifm elements and unpack them to 16 bit, using sign_A as the signedness flag.
        x = unpack(aie::load_v<64,aie_dm_resource::a>( pA ), params.ctrl.sign_A);
        if constexpr( has_scalar_broadcast == 0 ) {
            y = load_v<gran,aie_dm_resource::a>( pW ); 
        } else {
            // Scalar-broadcast mode: replicate the single element at pW across all lanes.
            y = aie::broadcast<Ti1,gran>(*pW);
        }
        // Move each 32-lane half of y into a 32-bit accumulator, shifted by shift_in1.
        y0 = sups(y.template extract<32>(0), params.shift_in1, params.ctrl.sign_W);
        y1 = sups(y.template extract<32>(1), params.shift_in1, params.ctrl.sign_W);

        // Multiply each half of x by shift_mul and accumulate onto the pre-loaded
        // y accumulators (presumably acc = y + x * shift_mul - TODO confirm).
        acc0 = mac_elem_32(x.template extract<32>(0), params.ctrl.sign_A, shift_mul, 0, y0);
        acc1 = mac_elem_32(x.template extract<32>(1), params.ctrl.sign_A, shift_mul, 0, y1);
        // Convert the accumulators back to the output element type, shifting by shift_res.
        out0 = acc0.template to_vector_sign<To>( params.ctrl.sign_srs, params.shift_res);
        out1 = acc1.template to_vector_sign<To>( params.ctrl.sign_srs, params.shift_res);

        if constexpr( has_relu6 ) {
            // Upper clamp only: min(result, max_value). sign_O is passed as the
            // signedness flag of the comparison (assumed - confirm).
            out = aie::min( aie::concat(out0,out1), (To)params.max_value, params.ctrl.sign_O);
            store_v( pO, out );
        } else {
            // Two 32-element stores, upper half first.
            // NOTE(review): the order is presumably deliberate for scheduling; keep as is.
            store_v( pO+32, out1 );
            store_v( pO, out0 );
	    }

        /* Increment pointers */
        pA = add_2d_byte( pA, dims_x );
        pW = add_2d_byte( pW, dims_y );
        pO += gran;
    }
}


/**
 * Element-wise add kernel, matrix-MAC variant. Each loop iteration consumes
 * 128 elements from `ifm`, 128 elements (or one broadcast scalar) from `wgt`,
 * and writes 128 elements to `ofm` as two 64-element halves.
 *
 * Template parameters:
 *   has_relu6            - when true, both output halves are clamped from
 *                          above with aie::min against params.max_value.
 *   loop_range           - forwarded as the first argument of chess_loop_range
 *                          (presumably the minimum trip count - confirm).
 *   has_scalar_broadcast - 0: load a full 128-element vector from `wgt` each
 *                          iteration; non-zero: broadcast the single element
 *                          at the current `wgt` pointer to all 128 lanes.
 *   Ti0 / Ti1 / To       - element types of `ifm`, `wgt` and `ofm`.
 *
 * Parameters:
 *   ifm    - first input operand; advanced per iteration by add_2d_byte with
 *            params.dims_x.
 *   wgt    - second input operand; advanced per iteration by add_2d_byte with
 *            params.dims_y (also in the scalar-broadcast case).
 *   ofm    - output buffer; advanced linearly by 128 elements per iteration.
 *   params - reads dims_x, dims_y, shift_in, shift_in1, shift_res, outer_loop,
 *            max_value (only if has_relu6) and ctrl.sign_A / sign_W /
 *            sign_srs / sign_O.
 *
 * Before the loop a 64-element int8 vector holding `1 << shift_in` on every
 * 9th lane (0, 9, 18, ..., 63) and zero elsewhere is written to the staging
 * register and transferred with staging_to_matrix_m64x64int8().
 * NOTE(review): this looks like an 8x8 diagonal matrix, so that mac_conf
 * scales `ifm` by 2^shift_in and adds it to the pre-shifted `wgt`
 * accumulator; the intrinsic semantics are not visible here - confirm.
 */
template<bool has_relu6=false, unsigned loop_range=8, unsigned has_scalar_broadcast = 0, typename Ti0 = int8, typename Ti1 = int8, typename To = int8>
ALWAYS_INLINE void add2d_int8x8_mmac
(
    Ti0 * ifm,
    Ti1 * wgt,
    To  * __restrict ofm,
    const bdcast_Add2dInt8x8Params &params
) {

    const Ti0  * pA = ifm;
    const Ti1  * pW = wgt;
    To         * restrict pO = ofm;

    // Number of elements processed per loop iteration.
    constexpr unsigned gran = 128;
    // 2D address-stepping descriptors consumed by add_2d_byte inside the loop.
    dims_2d_t dims_x = params.dims_x.instantiate( );
    dims_2d_t dims_y = params.dims_y.instantiate( );
    uint16_t shift_mul = 1 << params.shift_in;
    // Bits 0, 9, 18, ..., 63 set: one bit per byte, each shifted one position further.
    aie::mask shift_mask = aie::mask<64>::from_uint64(0x8040201008040201);
    //constexpr uint64_t broadcast_mask_value = has_scalar_broadcast ? 0xFFFFFFFFFFFFFFFF : 0x0;
    //constexpr aie::mask broadcast_mask_64 = aie::mask<64>::from_uint64(broadcast_mask_value);
    //const aie::mask broadcast_mask = aie::mask<128>::from_masks(broadcast_mask_64, broadcast_mask_64);
    // shift_mul on the masked lanes, 0 everywhere else.
    // NOTE(review): shift_mul is narrowed from uint16_t to int8 here, so a
    // shift_in of 7 or more does not fit - confirm the valid range with callers.
    vector<int8,64> shift_vector = aie::select(aie::broadcast<int8,64>(0), aie::broadcast<int8,64>(shift_mul),shift_mask);

    // Load the constant multiplier matrix once, outside the pipelined loop.
    set_staging_8x8_8x8(shift_vector);
    staging_to_matrix_m64x64int8();

    for ( unsigned i = 0; i < params.outer_loop; i++ )
        chess_prepare_for_pipelining
        chess_loop_range( loop_range, )
    {
        
        vector<Ti0, gran> x;
        vector<Ti1, gran> y;
        vector<To, 64> out0, out1;
        // Pin the accumulators to distinct chess_storage locations (dma0/dma1/dmb0/dmb1);
        // per the original author this is needed to achieve a 1-cycle schedule.
        aie::accum<acc32,64> chess_storage(dma0) acc0, chess_storage(dma1) acc1, chess_storage(dmb0) acc0_ld, chess_storage(dmb1) acc1_ld;
        
        /* Load data */
        x = load_v<gran,aie_dm_resource::a>( pA ); pA = add_2d_byte( pA, dims_x );
        if constexpr( has_scalar_broadcast == 0 ) {
            y = load_v<gran,aie_dm_resource::a>( pW ); 
        } else {
            // Scalar-broadcast mode: replicate the single element at pW across all lanes.
            y = aie::broadcast<Ti1,gran>(*pW);
        }
        pW = add_2d_byte( pW, dims_y );
        //y = aie::select( load_v<gran,aie_dm_resource::a>( pW ), *pW, broadcast_mask ); pW = add_2d_byte( pW, dims_y );

        // Move each 64-lane half of y into a 32-bit accumulator, shifted by shift_in1.
        acc0_ld = sups(y.template extract<64>(0), params.shift_in1, params.ctrl.sign_W);
        acc1_ld = sups(y.template extract<64>(1), params.shift_in1, params.ctrl.sign_W);
        // Accumulate each half of x onto the pre-loaded y accumulators
        // (presumably using the matrix register set up before the loop - TODO confirm).
        acc0 = mac_conf(x.template extract<64>(0), params.ctrl.sign_A, 0, acc0_ld, 0);
        acc1 = mac_conf(x.template extract<64>(1), params.ctrl.sign_A, 0, acc1_ld, 0);
        // Re-wrap via the native v64acc32 type, then convert to the output element
        // type, shifting by shift_res.
        out0 = aie::accum<acc32,64>((v64acc32)acc0).template to_vector_sign<To>( params.ctrl.sign_srs, params.shift_res );
        out1 = aie::accum<acc32,64>((v64acc32)acc1).template to_vector_sign<To>( params.ctrl.sign_srs, params.shift_res );

        if constexpr( has_relu6 ) {
            // Upper clamp only: min(result, max_value). sign_O is passed as the
            // signedness flag of the comparison (assumed - confirm).
            out0 = aie::min( out0, (To)params.max_value, params.ctrl.sign_O );
            out1 = aie::min( out1, (To)params.max_value, params.ctrl.sign_O );
        }
        // Two 64-element stores, upper half first, then advance the output pointer.
        store_v( pO+64, out1);
        store_v( pO, out0); pO += gran;
    }
}

/**
 * Public entry point: selects the add kernel implementation at compile time.
 *
 * use_mmac == 0 forwards to add2d_int8x8_vmac; any other value forwards to
 * add2d_int8x8_mmac. All remaining template arguments and all function
 * arguments are passed through unchanged.
 *
 * Note that the template parameter order here (use_mmac, has_scalar_broadcast,
 * loop_range) differs from the kernels (loop_range, has_scalar_broadcast), so
 * the arguments are re-ordered in the forwarding calls below.
 */
template<bool has_relu6=false, unsigned use_mmac= false, unsigned has_scalar_broadcast = 0, unsigned loop_range=8, typename Ti0 = int8, typename Ti1 = int8, typename To = int8>
ALWAYS_INLINE void bdcast_add2d_int8x8
(
    Ti0 * ifm,
    Ti1 * wgt,
    To  * __restrict ofm,
    const bdcast_Add2dInt8x8Params &params
) {
    if constexpr( use_mmac != 0 ) {
        // Matrix-MAC implementation.
        add2d_int8x8_mmac<has_relu6, loop_range, has_scalar_broadcast, Ti0, Ti1, To>( ifm, wgt, ofm, params );
    } else {
        // Vector-MAC implementation.
        add2d_int8x8_vmac<has_relu6, loop_range, has_scalar_broadcast, Ti0, Ti1, To>( ifm, wgt, ofm, params );
    }
}

#endif // __BDCAST_ADD2D_INT8X8_IMPL_H__