#ifndef __ADD2D_H__
#define __ADD2D_H__

#include "aie_api/aie.hpp"
#include <adf.h>

#include "conv/conv_xint8/biased_conv_int8x8_template.h"
#include "conv/conv_xint8/utils.h"
#include "conv/conv_xint8/mllib_const.h"

//template <typename dtype_in, typename dtype_in1, typename dtype_out, >
/**
 * Element-wise add of two int8 buffers, processed 32 lanes per iteration.
 *
 * Inputs and outputs are taken from `args`:
 *   - args.s2mm_ch0_data : first int8 operand buffer.
 *   - args.mm2s_ch0_data : int8 result buffer.
 *   - args.params_data   : viewed as an array of uint16_t words, where
 *       [0] is passed through conv_to_local_ptr() to obtain a conv3d_params_t
 *           holding the three shift amounts used below,
 *       [1] is passed through conv_to_local_ptr() to obtain the second int8
 *           operand buffer,
 *       [2] is the number of elements to process.
 *
 * Per 32-element group: each operand is loaded into a 32-lane acc32
 * accumulator with its own shift (upshift_elw_ifm1 / upshift_elw_ifm2), the
 * two accumulators are added, and the sum is converted back to int8 with
 * downshift_eltw_res and stored.
 *
 * Side effects: calls set_rnd(rnd_conv_even) and set_sat(); the previous
 * settings are not restored before returning.
 *
 * NOTE(review): this is a non-inline function defined in a header. Including
 * the header from more than one translation unit of the same program would
 * produce duplicate definitions - confirm it is only included once, or mark
 * the function `inline`.
 */
void run_matadd(KernelArgs& args) {
    set_rnd(rnd_conv_even);
    set_sat();

    auto runtime_params = static_cast<uint16_t const*>(args.params_data);

    int8_t* ifm0 = static_cast<int8_t*>(args.s2mm_ch0_data);                    // tdm2 buffer
    int8_t* ifm1 = static_cast<int8_t*>(conv_to_local_ptr(runtime_params[1]));  // ofm buffer
    int8_t* ofm  = static_cast<int8_t*>(args.mm2s_ch0_data);                    // tdm1 buffer
    const int ofmsize = runtime_params[2];

    // Only read from here on, so bind it as const.
    const conv3d_params_t& params = *((conv3d_params_t *) conv_to_local_ptr(runtime_params[0]));

    const int incr = 32;          // Elements consumed/produced per iteration (number of acc lanes)
    const int result_sign = 1;    // Sign flag handed to to_vector_sign for the int8 result

    // Number of 32-element groups (H*W*N*C/32).
    // NOTE(review): integer division - if ofmsize is not a multiple of 32 the
    // trailing (ofmsize % 32) elements are not processed. Confirm callers
    // always pass a multiple of 32.
    const int num_inputs = ofmsize / incr;

    const int8_t __aie_dm_resource_a * __restrict in_ptr0 = (const int8_t __aie_dm_resource_a *) ifm0; // Input coming from IFM bank
    const int8_t __aie_dm_resource_a * __restrict in_ptr1 = (const int8_t __aie_dm_resource_a *) ifm1; // Output of conv coming from OFM bank
    int8_t       __aie_dm_resource_a * __restrict out_ptr = (int8_t __aie_dm_resource_a *) ofm;        // Output is in same bank as above

#if 0
    // Debug dump (disabled): prints the shifts, buffer addresses and the first
    // 64 operand pairs when the core id matches column 0 / row 2.
    int ColIdx = get_coreid() >> 16;
    int RowIdx = get_coreid() & 0xf;
    if( ColIdx == 0 && RowIdx == 2) {
        printf(" MATADD SHIFT %d \n",params.upshift_elw_ifm1);
        printf(" MATADD SHIFT %d \n",params.upshift_elw_ifm2);
        printf(" MATADD SHIFT %d \n",params.downshift_eltw_res);
        printf(" MATADD: ifm0 %p \n",ifm0);
        printf(" MATADD: ifm1 %p \n",ifm1);
        printf(" MATADD: ofm %p \n",ofm);
        printf(" MATADD: ofmsize %d \n",ofmsize);
        printf("input 1 input 2 output\n");
        for(int i=0;i<64;i++){
            printf("ind: %d, ifm0: %d, ifm1: %d, out: %d\n",i,in_ptr0[i],in_ptr1[i],(in_ptr0[i]+in_ptr1[i]));
        }
    }
#endif

    const int shift_in0 = params.upshift_elw_ifm1;    // Shift applied when loading operand 0
    const int shift_in1 = params.upshift_elw_ifm2;    // Shift applied when loading operand 1
    const int shift_out = params.downshift_eltw_res;  // Shift applied when converting the sum to int8

    // The loop body used to live in a function-local ADD2D_BODY macro that was
    // never #undef'd (and therefore leaked out of this header). It had a single
    // use, so it is expanded in place here.
    for (int i = 0; i < num_inputs; i++)
    chess_prepare_for_pipelining
    //chess_loop_range(8,)
    //chess_unroll_loop(2)
    {
        accum<acc32,32> acc0;
        accum<acc32,32> acc1;
        accum<acc32,32> acc_res;

        // Both operands are loaded with sign flag 1.
        acc0.from_vector_sign( load_v<32>( in_ptr0 ), 1, shift_in0 );
        in_ptr0 += incr;
        acc1.from_vector_sign( load_v<32>( in_ptr1 ), 1, shift_in1 );
        in_ptr1 += incr;

        acc_res = aie::add(acc0, acc1);

        store_v(out_ptr, acc_res.to_vector_sign<int8_t>(result_sign, shift_out));
        out_ptr += incr;
    }
}

#endif
