#ifndef __STX_GEMM_INT16xINT16_W4_CC__
#define __STX_GEMM_INT16xINT16_W4_CC__

#include <adf.h>
#include <aie_api/aie.hpp>
#include <aie_api/aie_adf.hpp>
#include <aie_api/utils.hpp>
#include <stdint.h>

// M0, K0 and N0 are the tile dimensions processed by ONE iteration of the inner
// loop -- NOT the dimensions handled by a single call to the MAC intrinsic.
#define I16xI16_GEMM_M0 8
#define I16xI16_GEMM_K0 4
#define I16xI16_GEMM_N0 8

// Pointer-walk parameters for the input (I) and weight (W) operands.
// Filled in by setup_gemm_parameters() and consumed by add_3d_byte() in
// accumulate_w4xw8(). All increments are in bytes.
//   inc*3          : first increment argument passed to add_3d_byte
//   num*1 / inc*1  : count / increment of the first wrap dimension
//   num*2 / inc*2  : count / increment of the second wrap dimension
// (exact stepping semantics are those of add_3d_byte)
// Field order is significant (layout); keep it unchanged.
struct gemmIncrements property(keep_in_registers)
{
    // input-operand walk
    int numI1, incI1;
    int numI2, incI2;
    int incI3;

    // weight-operand walk
    int numW1, incW1;
    int numW2, incW2;
    int incW3;
};

// Compute the add_3d_byte walk parameters for one (core_m x core_k) input tile
// and its matching weight tile, stepping k0 columns / m0 rows at a time.
//
// core_m, core_k : per-core problem size along M and K (elements)
// core_n, n0     : currently unused -- the weight tile width is fixed at 8 columns
// m0, k0         : per-iteration step along M and K (elements)
// incrs          : output; every increment is expressed in bytes of int16 data
inline void setup_gemm_parameters
(
    int core_m,
    int core_k,
    int core_n,
    int m0,
    int k0,
    int n0,
    gemmIncrements &incrs
)
{
    const int elem_sz  = 2;               // sizeof(int16)
    const int k_steps  = core_k / k0;     // iterations along K
    const int m_steps  = core_m / m0;     // iterations along M
    const int w_cols   = 8;               // weight tile width (hard-coded)

    // Input walk.
    incrs.numI1 = k_steps - 1;
    incrs.incI1 = elem_sz * (k0 * core_m);
    incrs.numI2 = m_steps - 1;
    incrs.incI2 = elem_sz * (m0 * k0 - core_m * (core_k - k0));
    incrs.incI3 = elem_sz * (m0 * k0 - core_m * core_k);

    // Weight walk.
    incrs.numW1 = k_steps - 1;
    incrs.incW1 = elem_sz * (k0 * w_cols);
    incrs.numW2 = m_steps - 1;
    incrs.incW2 = elem_sz * ((k0 - core_k) * w_cols);
    incrs.incW3 = incrs.incW1;

    (void)core_n;
    (void)n0;
}

// Re-layout one 8x8 int16 tile stored as two consecutive 4x8 halves.
// Per the original annotations, the tile is split into its left 8x4 and
// right 8x4 halves. The left half is written at ptrO and the right half at
// ptrO + num_out_rows * 4 int16. Neither reference argument is modified.
inline __attribute__((always_inline))
void w8_to_w4_8x8block
(
    v32int16 * /*__restrict*/ & ptrA0,
    v32int16 * /*__restrict*/ & ptrO,
    int num_out_rows
)
{
    const int half_tile_bytes = 4 * 8 * 2;             // 4 rows x 8 int16
    const int right_off_bytes = num_out_rows * 4 * 2;  // num_out_rows rows x 4 int16

    v32int16 *src_bot   = byte_incr(ptrA0, half_tile_bytes);
    v32int16 *dst_right = byte_incr(ptrO, right_off_bytes);

    // Read both halves of the source tile.
    v32int16 top = *ptrA0;
    v32int16 bot = *src_bot;

    // Interleave into the two 8x4 column halves.
    v32int16 left_half  = shuffle(top, bot, T64_8x2_lo);
    v32int16 right_half = shuffle(top, bot, T64_8x2_hi);

    *ptrO      = left_half;
    *dst_right = right_half;
}

// Transpose one 8x8 int16 tile.
// The source is two consecutive v32int16 (top / bottom 4x8) at ptrA0. The
// transposed tile is written as two consecutive v32int16 at ptrOut. Neither
// reference argument is modified.
inline __attribute__((always_inline))
void w8_to_w8_Transpose8x8block
(
    v32int16 * /*__restrict*/ & ptrA0,
    v32int16 * /*__restrict*/ & ptrOut,
    int num_out_rows
)
{
    const int half_tile_bytes = 4 * 8 * 2;   // one v32int16: 4 rows x 8 int16

    v32int16 *src_top = ptrA0;
    v32int16 *src_bot = byte_incr(src_top, half_tile_bytes);
    v32int16 *dst_top = ptrOut;
    v32int16 *dst_bot = byte_incr(dst_top, half_tile_bytes);

    // Gather the full tile into one 64-lane vector and transpose it 8x8.
    v64int16 tile   = concat(*src_top, *src_bot);
    v64int16 tile_t = aie::transpose(aie::vector<int16, 64>(tile), 8, 8);

    // Scatter the transposed tile back as two 32-lane halves.
    *dst_top = extract_v32int16(tile_t, 0);
    *dst_bot = extract_v32int16(tile_t, 1);

    (void)num_out_rows;   // unused: output stride is fixed to one v32int16
}


// One inner-loop step of the (m0,k0,n0) = (8,4,8) micro-kernel.
//
// Loads an 8x4 input slice (v32int16 at ptrA0: upper 4x4 in lanes 0..15,
// lower 4x4 in lanes 16..31) and a 4x8 weight slice (v32int16 at ptrB0). Then:
//   accumulator1 += upper 4x4 * 4x8
//   accumulator2 += lower 4x4 * 4x8
// zero_acc is forwarded as the zero_acc flag of mac_4x4_4x8_conf.
// ptrA0 / ptrB0 are advanced with add_3d_byte using the walk stored in incrs
// and the caller-owned counters cnt_I_* / cnt_W_*.
// rowIdx / colIdx are unused; they are kept for signature compatibility.
inline __attribute__((always_inline))
void accumulate_w4xw8      // performs (8,4,8) == (m0,k0,n0) in a single inner loop iteration
(
    v32int16 *& ptrA0,     // 8x4    into two 4x4 (v0A , v1A)
    v32int16 *& ptrB0,     //   4x8  into one 4x8 (v0B      )
    v32acc64& accumulator1,
    v32acc64& accumulator2,
    addr_t& cnt_I_1,
    addr_t& cnt_I_2,
    addr_t& cnt_W_1,
    addr_t& cnt_W_2,
    gemmIncrements& incrs,
    int zero_acc,
    int rowIdx,
    int colIdx
)
{
    (void)rowIdx;
    (void)colIdx;

    // Load (8x4) from the input and (4x8) from the weights.
    v32int16 v0A = *ptrA0;   // lanes 0..15: upper 4x4
    // Move the lower 4x4 (lanes 16..31) into lanes 0..15. Insert into a copy of
    // v0A rather than into an uninitialized vector, so that no indeterminate
    // value is read and every lane of v1A is defined.
    v32int16 v1A = insert(v0A, 0, extract_v16int16(v0A, 1));

    v32int16 v0B = *ptrB0;   // 4x8

    ptrA0 = add_3d_byte(ptrA0, incrs.incI3,
                        incrs.numI1, cnt_I_1, incrs.incI1,
                        incrs.numI2, cnt_I_2, incrs.incI2);

    ptrB0 = add_3d_byte(ptrB0, incrs.incW3,
                        incrs.numW1, cnt_W_1, incrs.incW1,
                        incrs.numW2, cnt_W_2, incrs.incW2);

    // v32acc64 mac_4x4_4x8_conf(v32int16 a, int sgn_x, v32int16 b, int sgn_y, v32acc64 acc,
    //                           int zero_acc, int shift16, int sub_mul, int sub_acc1)
    accumulator1 = mac_4x4_4x8_conf(v0A, 0, v0B, 0, accumulator1, zero_acc, 0, 0, 0);
    accumulator2 = mac_4x4_4x8_conf(v1A, 0, v0B, 0, accumulator2, zero_acc, 0, 0, 0);
}


// Narrow a 32-lane 64-bit accumulator to 32 int16 values.
// Each 16-lane half goes through lsrs() (shift 0) down to 32 bits. It is then
// re-wrapped as v16acc32 and srs()'d (shift 0) to int16, and the two halves are
// concatenated in order.
// NOTE(review): saturation vs. wrap on out-of-range values depends on the
// core's SRS/saturation mode -- confirm against the caller's expectations.
// Declared inline because this include-guarded file is meant to be #included;
// a non-inline definition would be multiply defined if two TUs include it.
inline v32int16 v32acc64_to_v32int16(v32acc64 accv)
{
    v16int16 o0, o1;
    o0 = srs_to_v16int16(v16acc32(lsrs(extract_v16acc64(accv, 0), 0)), 0);
    o1 = srs_to_v16int16(v16acc32(lsrs(extract_v16acc64(accv, 1), 0)), 0);
    return concat(o0, o1);
}
// Narrow a 32-lane 64-bit accumulator to 32 int32 values.
// Each 16-lane half is converted with lsrs() using shiftamt_acc64_to_int32,
// and the two halves are concatenated in order.
// Declared inline because this include-guarded file is meant to be #included;
// a non-inline definition would be multiply defined if two TUs include it.
inline v32int32 v32acc64_to_v32int32(v32acc64 accv, int shiftamt_acc64_to_int32)
{
    v16int32 o0, o1;
    o0 = lsrs(extract_v16acc64(accv, 0), shiftamt_acc64_to_int32);
    o1 = lsrs(extract_v16acc64(accv, 1), shiftamt_acc64_to_int32);
    return concat(o0, o1);
}


// gemm_int16xint16 -- int16 x int16 matrix multiply over pre-tiled operands.
//
// Tiling (as implied by the pointer arithmetic below): A and B are stored as
// tiles of 64 int16 (128 bytes, read as two v32int16). Each output tile is 64
// values held in two v32acc64 accumulators (top / bottom halves, each built
// from two 4x4 * 4x8 MACs). M_g, K_g and N_g therefore count tiles.
//
// Iteration: for each of the M_g * N_g output tiles, the kernel accumulates K_g
// tile products and then emits the result. The A walk wraps every M_g tiles,
// so the M tile index advances fastest.
//   out_mode == 1 : each accumulator is converted with to_vector<int32>(shift_amt)
//                   and stored to bufO, i.e. 64 int32 (256 bytes) per output tile.
//   otherwise     : the raw v32acc64 accumulators are stored back to bufT1 / bufT2
//                   and shift_amt has no effect.
//
// Preconditions promised to the compiler via chess_loop_range:
//   M_g * N_g >= 2 and K_g >= 4.
//
// The accumulators are ALWAYS loaded from bufT1 / bufT2 at the start of each
// output tile, even when zero_acc != 0. Both buffers must therefore be readable
// (M_g * N_g v32acc64 each) in every mode. With zero_acc != 0, the first MAC of
// each tile receives zero_acc as its zero_acc flag (presumably discarding the
// loaded value -- confirm against the intrinsic documentation).
__attribute__((noinline))
void gemm_int16xint16
(
    int8_t* bufA,   // A tiles (int16); tile order depends on perform_transpose_A
    int8_t* bufB,   // k x n x 2 or n x k x 2 : !perform_transpose_B --> k x n x 2, otherwise --> n x k x 2
    int8_t* bufT1,  // v32acc64 per output tile (top half): always read; written when out_mode != 1
    int8_t* bufT2,  // v32acc64 per output tile (bottom half): always read; written when out_mode != 1
    int8_t* bufO,   // int32 output, 64 values per output tile; written only when out_mode == 1
    int M_g, // tiles along M (M_SUBV)
    int K_g, // tiles along K (K_SUBV)
    int N_g, // tiles along N (N_SUBV)
    int sizeA,      // bytes the A pointer is rewound by per pass; presumably M_g*K_g*128 -- TODO confirm
    int sizeB,      // unused
    int shift_amt,  // right shift for the acc64 -> int32 conversion (out_mode == 1 only)
    bool perform_transpose_B,
    int zero_acc = 1,
    int out_mode = 1,
    bool perform_transpose_A=0,   // added last to keep the signature compatible
    int sign_A=0,                 // forwarded as sgn_x to the MAC intrinsics
    int sign_B=0                  // forwarded as sgn_y to the MAC intrinsics
)
{

    v32int16 chess_storage(DM_bankA) * __restrict ptrI  = (v32int16 chess_storage(DM_bankA) * __restrict) bufA;
    v32int16 chess_storage(DM_bankA) * ptrW = (v32int16 chess_storage(DM_bankA)* ) bufB;
    v32acc64 * __restrict ptdm1in = (v32acc64 * __restrict) bufT1;
    v32acc64 * __restrict ptdm2in = (v32acc64 * __restrict) bufT2;
    // NOTE(review): the *out pointers alias the *in pointers (same buffers) while
    // all four are __restrict. Each element is read before it is written back, but
    // this is formally a restrict violation -- revisit if the scheduler changes.
    v32acc64 * __restrict ptdm1out = (v32acc64 * __restrict) bufT1;
    v32acc64 * __restrict ptdm2out = (v32acc64 * __restrict) bufT2;
    int32_t  * pO_32b = (int32_t * __restrict) bufO;

    // After each output tile the pointers jump back for the next one.
    int mi = M_g - 1;
    // B: step between consecutive K tiles.
    int incWI = perform_transpose_B ? N_g * 128 : 128;
    // B: rewind to the first K tile of the same column (M index not wrapping).
    int incW1 = -K_g * incWI;
    // B: on M wrap, land on the next N column.
    int incW2 = perform_transpose_B ? incW1 + 128 : 0;
    // A: step between consecutive K tiles.
    int incAI = perform_transpose_A ? 128 : M_g * 128;
    // A: move to the next M tile (M index not wrapping).
    int incA1 = perform_transpose_A ? 0 : 128 - sizeA;
    // A: on M wrap, go back to the first M tile.
    int incA2 = perform_transpose_A ? -sizeA : incA1 - incAI;

    dims_2d_t dimsI( mi, incA1, incA2 );
    dims_2d_t dimsW( mi, incW1, incW2 );


    v32acc64 accumulator1, accumulator2;
    for(int j = 0; j < M_g*N_g; j++)
    chess_prepare_for_pipelining
    chess_loop_range(2,)
    {

        int z = zero_acc;
        accumulator1 = *ptdm1in++;
        accumulator2 = *ptdm2in++;

        for(int i = 0; i < K_g; i++)
        chess_prepare_for_pipelining
        chess_loop_range(4,)
        {
            v32int16 i1 = ptrI[1];
            v32int16 ib = ptrI[perform_transpose_A];
            v32int16 i0 = ptrI[0];
            ptrI = byte_incr( ptrI, incAI );

            v32int16 w0 = ptrW[0];
            v32int16 w1 = ptrW[1];

            ptrW = byte_incr( ptrW, incWI );

            // Split the A tile into four 4x4 pieces: x0/x1 feed the top
            // accumulator, and x2/x3 (upper 256 bits moved down) feed the bottom one.
            int a_mode = perform_transpose_A ? T16_4x8 : T64_8x2_lo;
            v32int16 x0 = shuffle( i0, i1, a_mode );
            v32int16 x1 = shuffle( ib, i1, a_mode + !perform_transpose_A );
            v32int16 x2 = shuffle( x0, T256_2x2_hi );
            v32int16 x3 = shuffle( x1, T256_2x2_hi );

            int w_mode = perform_transpose_B ? T16_8x8_lo : T512_1x2_lo;
            v32int16 y0 = shuffle(w0, w1, w_mode);
            v32int16 y1 = shuffle(w0, w1, w_mode + 1);

            accumulator1 = mac_4x4_4x8_conf( x0, sign_A, y0, sign_B, accumulator1, z, 0, 0, 0 );
            accumulator1 = mac_4x4_4x8( x1, sign_A, y1, sign_B, accumulator1 );
            accumulator2 = mac_4x4_4x8_conf( x2, sign_A, y0, sign_B, accumulator2, z, 0, 0, 0 );
            accumulator2 = mac_4x4_4x8( x3, sign_A, y1, sign_B, accumulator2 );

            z = 0;
        }

        ptrI = add_2d_byte( ptrI, dimsI );
        ptrW = add_2d_byte( ptrW, dimsW );

        if(out_mode == 1)
        {
            aie::store_v( pO_32b, aie::accum<acc64,32>( accumulator1 ).to_vector<int32>( shift_amt ));      pO_32b += 32;
            aie::store_v( pO_32b, aie::accum<acc64,32>( accumulator2 ).to_vector<int32>( shift_amt ));      pO_32b += 32;
        }
        else
        {
            *ptdm1out++ = accumulator1;
            *ptdm2out++ = accumulator2;
        }
    }
}

#endif
