/*  (c) Copyright 2019 - 2021 Xilinx, Inc. All rights reserved.

    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.

    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.

    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.

    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */
#ifndef __GROUP_NORM_H__
#define __GROUP_NORM_H__
#include "global_reduce_impl.hpp"
#define GPN_VEC_LEN 32
#include "q/q.hpp"
#include "q/q_impl.hpp"
#include "dq/dq.hpp"
#include "dq/dq_impl.hpp"
#include "uniop/uniop_qdq.hpp"

//#include "softmax.cc"
/*void __attribute__((noinline)) mask_w8_cols(bfloat16* ptr, int rows, int step_cols, bfloat16 mask_value, uint32_t mask) 
{
    bfloat16* pI = ptr + step_cols * rows;
    bfloat16* pO = ptr + step_cols * rows;
    for (unsigned i = 0; i < rows / 4; i++) 
        chess_no_hw_loop
        chess_prepare_for_pipelining
        chess_loop_range(2,)
    {
            aie::store_v(pO, aie::select(mask_value, aie::load_v<32>(pI), aie::mask<32>::from_uint32(mask)));
            pI += 32;
            pO += 32;
    } 
}*/
// expands each elements repeated 16 times
template<typename T>
inline __attribute__((always_inline)) void expansionMeanVar(float* input,float* output,int num_groups){
    v32float* out = (v32float*)output;

    for(int i=0;i<num_groups;i++)
    chess_no_hw_loop
    {
        out[i] = broadcast_to_v32float(input[i]); // Same as AIE2P
    }
}

// NOTE(review): NOTPASSTHROUGH is not referenced anywhere in this file;
// presumably it is consumed by an including translation unit — confirm or remove.
#define NOTPASSTHROUGH 1
// Element-wise reciprocal square root over a float array:
//     output[k] = invsqrt(input[k])   for k in [0, N_Dim)
// Each element is read before the corresponding element is written, so
// input and output may be the same buffer (in-place use is supported).
// Nothing is done when N_Dim <= 0. The template parameter T is unused.
template<typename T>
inline __attribute__((always_inline)) void invsqrt_vectorwise
(
  float * input,
  float * output,
  int16_t N_Dim
)
{
    const float * src = input;
    float * dst = output;

    for (int16_t remaining = N_Dim; remaining > 0; --remaining)
    chess_no_hw_loop
    {
        *dst++ = invsqrt(*src++);
    }
}

template<typename T>
inline __attribute__((always_inline)) void norm_and_affine
(
        T * input,
        T * restrict parameters,
        T * output,
        uint16_t num_groups,
        int16 ElemPerGrpCore,
        float *scratch_addr,
        const int rowIdx, const int colIdx
)
{
    int NG = num_groups;
    float* scratch_addr_var = scratch_addr + NG*GPN_VEC_LEN; //NG * 16
    using VecT = typename std::conditional<std::is_same<T, float16>::value, v32float16, v32bfloat16>::type;
    VecT *outIter    = (VecT*) output;

    v32accfloat * mean_float  = (v32accfloat *)scratch_addr;
    v32accfloat * vec_inv_var = (v32accfloat *)(scratch_addr_var);

    VecT * p_bias   = (VecT *)(parameters) ; // TODO: Check offset
    VecT * p_scale    = (VecT *)(parameters + NG*ElemPerGrpCore);

    VecT one_32_lane;
    if constexpr (std::is_same<T, float16>::value) {
        one_32_lane = broadcast_float16( 1.0 );
    } else {
        one_32_lane = broadcast_bfloat16( 1.0 );
    }
    VecT *inIter2    = (VecT *) input;

    auto mean_iter      = aie::begin_vector<32>(scratch_addr);
    auto inv_var_iter   = aie::begin_vector<32>(scratch_addr_var);
    auto bias_iter      = aie::begin_vector<32>((T*)p_bias); // TODO: Make it circular
    auto scale_iter     = aie::begin_vector<32>((T*)p_scale);

    for (int miter = 0; miter < NG; ++miter)
    chess_prepare_for_pipelining
    chess_loop_range(1,) 
    {
        VecT mean_vec, negative_mu, mac_op_1_st3, mac_op_2_st3;
        if constexpr (std::is_same<T, float16>::value) {
            mean_vec = to_v32float16(v32accfloat((*mean_iter++).to_native()));
            negative_mu = to_v32float16(negmul_elem_32(mean_vec, one_32_lane));
            mac_op_1_st3 = to_v32float16(v32accfloat((*inv_var_iter++).to_native()));
            mac_op_2_st3 = to_v32float16(mul_elem_32(mac_op_1_st3, negative_mu));
        } else {
            mean_vec = to_v32bfloat16(v32accfloat((*mean_iter++).to_native()));
            negative_mu = to_v32bfloat16(negmul_elem_32(mean_vec, one_32_lane));
            mac_op_1_st3 = to_v32bfloat16(v32accfloat((*inv_var_iter++).to_native()));
            mac_op_2_st3 = to_v32bfloat16(mul_elem_32(mac_op_1_st3, negative_mu));
        }
        for (int niter = 0; niter < (ElemPerGrpCore/GPN_VEC_LEN); ++niter)
        chess_prepare_for_pipelining
        chess_loop_range(1,) 
        {
            //VecT scale   = std::is_same<T, float16>::value ? broadcast_float16( 1.0 ) : broadcast_bfloat16( 1.0 ); //*scale_iter++; //TODO: Use scale iterator
            //VecT bias_t  = std::is_same<T, float16>::value ? broadcast_float16( 0.0 ) : broadcast_bfloat16( 0.0 ); //*bias_iter++; //TODO: Use bias iterator
            VecT scale   = *scale_iter++; //TODO: Use scale iterator
            VecT bias_t  = *bias_iter++; //TODO: Use bias iterator
            v32accfloat bias    = to_v32accfloat(bias_t);
            VecT scaled_op_1;
            if constexpr (std::is_same<T, float16>::value) {
                scaled_op_1 = to_v32float16(mul_elem_32(scale, mac_op_1_st3));
            } else {
                scaled_op_1 = to_v32bfloat16(mul_elem_32(scale, mac_op_1_st3));
            }
            v32accfloat scaled_b_op_1 = mac_elem_32(scale, mac_op_2_st3, bias);
            v32accfloat res = mac_elem_32(*inIter2++, scaled_op_1, scaled_b_op_1);
            if constexpr (std::is_same<T, float16>::value) {
                *outIter++ = to_v32float16(res);
            } else {
                *outIter++ = to_v32bfloat16(res);
            }
        }
    }

#if 0
    if(rowIdx == 2 && colIdx == 0)
    {
        /*auto outIter_print   = aie::begin_vector<32>(output);
        for(int i=0;i<4*ElemPerGrpCore/32;i++){
            aie::print(*outIter_print++, true, "res = ");
        }*/
        auto bias_iter_print      = aie::begin_vector<32>((T*)p_bias); // TODO: Make it circular
        auto scale_iter_print     = aie::begin_vector<32>((T*)p_scale);
        for(int i=0;i<4;i++){
            aie::print(*scale_iter_print++, true, "scale = ");
            aie::print(*bias_iter_print++, true, "bias = ");
        }
    }
    if(rowIdx == 2 && colIdx == 0)
    {
        //aie::print(aie::accum<accfloat, 16>(*(v16accfloat*)addr_mean), true, "mean = ");
        //aie::print(aie::accum<accfloat, 16>(*(v16accfloat*)(addr_mean+16)), true, "mean = ");
        aie::print(*inv_var_iter++, true, "res = ");
        //aie::print(aie::accum<accfloat, 16>(*(v16accfloat*)(addr_var+16)), true, "variance = ");
        printf("--------------------------------------------------\n");
    }
#endif

}

// Computes per-group statistics for group normalization.
//
// Stage 1 (every call): for each of the NG = num_groups groups, accumulates the
// lane-wise running sum and running sum of squares of this core's
// ElemPerGrpCore input elements into two arrays of v32accfloat in scratch:
//     scratch_addr[0              .. NG*GPN_VEC_LEN)   : sum
//     scratch_addr[NG*GPN_VEC_LEN .. 2*NG*GPN_VEC_LEN) : sum of squares
// The accumulators are cleared only on the first call after a finished
// reduction. This is tracked by the function-local static `zero_init`, which is
// shared by every call of this template instantiation. Partial sums therefore
// carry over between successive calls until a call with is_last_iter != 0.
//
// Stage 2 (only when is_last_iter != 0):
//   - lane-reduces the accumulators;
//   - optionally runs global_reduce (presumably a cross-core sum — confirm);
//   - multiplies by 1/K, with K = ColPerGrp * mparam, to get E[x] and E[x^2];
//   - forms 1/sqrt(E[x^2] - E[x]^2 + EPSILON);
//   - writes the per-group mean and inverse std, each broadcast to GPN_VEC_LEN
//     lanes, into scratch_addr[0 .. 2*NG*GPN_VEC_LEN). This is the layout
//     norm_and_affine reads.
//
// Parameters
//   input            : NG groups of ElemPerGrpCore elements, read sequentially
//                      group by group, GPN_VEC_LEN lanes at a time.
//   msubv, nsubv     : not used by this function.
//   mparam           : multiplied by ColPerGrp to form the element count K.
//   ElemPerGrpCore   : elements per group on this core. The inner loop is
//                      annotated for >= 1 iteration, so this is assumed to be a
//                      non-zero multiple of GPN_VEC_LEN.
//   scratch_addr     : needs 2*NG*GPN_VEC_LEN + 64 floats. Stage 2 uses a
//                      64-float staging area right after the accumulators.
//   num_groups       : assumed <= 32 (addr_var is placed 32 floats after
//                      addr_mean).
//   output, rowIdx, colIdx : only referenced by disabled debug code.
template<typename T>
inline __attribute__((always_inline)) void mean_and_var
(
    T * input,
    uint16_t msubv,
    uint16_t nsubv,
    int mparam,
    uint16_t ColPerGrp,
    int16 ElemPerGrpCore,
    uint16_t is_last_iter,
    uint16_t enable_global_reduce,
    float *scratch_addr,
    uint16_t num_groups,
    float* output, 
    const int rowIdx, const int colIdx
)
{
    
    float eps = EPSILON;
    //const int ElemPerGrpCore = 80; //16;
    int NG = num_groups;
    auto inIter1 = aie::begin_vector<GPN_VEC_LEN>(input);
    // Persists across calls: 1 => the next Stage 1 pass must clear the accumulators.
    static int zero_init = 1;
    v32accfloat* sum = (v32accfloat*)scratch_addr; // NG running-sum vectors (32 lanes each)
    v32accfloat* sum_sq = (v32accfloat*)(scratch_addr + NG*GPN_VEC_LEN); // NG running sum-of-squares vectors, right after `sum`
    using VecT = typename std::conditional<std::is_same<T, float16>::value, v32float16, v32bfloat16>::type;
    VecT one_32_lane;
    if constexpr (std::is_same<T, float16>::value) {
        one_32_lane = broadcast_float16( 1.0 );
    } else {
        one_32_lane = broadcast_bfloat16( 1.0 );
    }
    // ------------------------------  Stage 1: Local Sums and Sums of Squares ---------------------------
    for (int miter = 0; miter < NG; ++miter)
    chess_prepare_for_pipelining
    chess_loop_range(1,) 
    {
        // acc_zero is passed as the accumulator-zeroing argument of
        // mac_elem_32_conf. It is non-zero only on the first vector of a
        // group, and only when a fresh reduction starts.
        int acc_zero = zero_init;
        for (int niter = 0; niter < (ElemPerGrpCore/GPN_VEC_LEN); ++niter)
        chess_prepare_for_pipelining
        chess_loop_range(1,) 
        {
            VecT in_1x32;
            in_1x32 = *inIter1++; // Same as AIE2
            // sum += x * 1 ; sum_sq += x * x (lane-wise)
            sum[miter] = mac_elem_32_conf(in_1x32, one_32_lane, sum[miter], acc_zero, 0, 0);  
            sum_sq[miter] = mac_elem_32_conf(in_1x32, in_1x32, sum_sq[miter], acc_zero, 0, 0);
            acc_zero = 0;
        }
    }
    // Subsequent calls keep accumulating until the last iteration resets this.
    zero_init = 0;

    /*if(rowIdx == 2 && colIdx == 0)
    {
        for(int i=0;i<64;i++){
            printf("sum: %f \n", scratch_addr[i]);
        }
    }

    float *output_mean = output;
    global_reduce((v32float)sum[0], output, 1);
    if(rowIdx == 2 && colIdx == 0)
    {
        for(int i=0;i<32;i++){
            printf("sum: %f \n", output[i]);
        }
    }
    if(rowIdx == 2 && colIdx == 0)
    {
        printf("is_last_iter: %d \n", is_last_iter);
    }*/
    // ------------------------------  Stage 2 : Divergent Stage: -------------------------------------
    //                                Accumulate global sum for last iteration of a group
    if(is_last_iter){
#if 1
        // Staging area right after the two accumulator arrays.
        float* addr_mean = scratch_addr + 2 * NG * GPN_VEC_LEN;
#else
        float* addr_mean = output;
#endif
        // Placing addr_var exactly 32 floats after addr_mean makes it the second
        // v32 vector of mean_vec below. This requires NG <= 32.
        float* addr_var  = addr_mean + 32; //Assuming max num groups = 32

        v32accfloat* addr_sum = (v32accfloat*)scratch_addr;

        // Collapse each group's 32 lanes into one scalar. addr_sum[ngrp] is the
        // sum and addr_sum[ngrp+NG] is the matching sum of squares. Lanes
        // NG..31 of addr_mean/addr_var are not written here and keep stale
        // scratch contents. Only the first NG lanes are consumed afterwards.
        for (int ngrp=0; ngrp<NG; ngrp++)
        //chess_prepare_for_pipelining
        chess_no_hw_loop
        {
            addr_mean[ngrp] = aie::reduce_add(aie::accum<accfloat, 32>(addr_sum[ngrp]).to_vector<float>());  // processes both mean and variance - placed contiguously
            addr_var[ngrp] = aie::reduce_add(aie::accum<accfloat, 32>(addr_sum[ngrp+NG]).to_vector<float>());
        }
        
        // mean_vec[0] covers addr_mean, mean_vec[1] covers addr_var.
        v32accfloat* mean_vec = (v32accfloat*)addr_mean;

        // K = total number of elements per group.
        int tot_el_per_grp = ColPerGrp * mparam; 
        v32float inv_K = broadcast_to_v32float(inv(fix2float(tot_el_per_grp))); // Same as AIE2

        // niter 0: sum -> E[x];  niter 1: sum of squares -> E[x^2]
        for (int niter = 0; niter < 2; ++niter) // TODO: Generalize
        chess_no_hw_loop
        {
            if(enable_global_reduce){
                global_reduce((v32float)mean_vec[niter], (float *)(mean_vec + niter), 1);
            }
            mean_vec[niter] = mul_elem_32((v32float)mean_vec[niter], inv_K);
        }
        /*if(rowIdx == 2 && colIdx == 0)
        {
            for(int i=0;i<NG;i++){
                printf("addr_mean: %f \n", addr_mean[i]);
            }
            for(int i=0;i<NG;i++){
                printf("addr_var: %f \n", addr_var[i]);
            }
        }*/
        
        // NOTE: following snippet assume num groups = 32. TODO: generalize it.
        // Var[x] = E[x^2] - E[x]^2. The mean is rounded to T before squaring.
        // NOTE(review): with a bf16-rounded mean this can lose precision when
        // |mean| >> std. Consider an fp32 square if the target supports it.
        VecT mean_vec_t;
        if constexpr (std::is_same<T, float16>::value) {
            mean_vec_t = to_v32float16(*(v32accfloat*)mean_vec);
        } else {
            mean_vec_t = to_v32bfloat16(*(v32accfloat*)mean_vec);
        }
        v32accfloat* var_vec = (v32accfloat*)(addr_var);
        *var_vec = msc_elem_32(mean_vec_t, mean_vec_t, *var_vec); //Same as AIE2P

        // Add the global epsilon value to Var[x]
        v16accfloat epsilon16 = broadcast_to_v16accfloat(eps); // Changed to AIE4
        v32accfloat epsilon32 = concat(epsilon16, epsilon16); // Same as AIE2P
        *var_vec = add(*var_vec, epsilon32); // Same as AIE2P

        // The next call starts a fresh accumulation.
        zero_init = 1;  

        // In place: addr_var[g] = 1 / sqrt(Var[x] + eps) for the NG valid groups.
        invsqrt_vectorwise<T> ((float*)var_vec, (float*)var_vec, NG);
        // Broadcast mean and inverse std over 32 lanes per group, overwriting
        // the (now consumed) accumulator arrays at the start of scratch.
        expansionMeanVar<T>(addr_mean, scratch_addr, NG);
        expansionMeanVar<T>(addr_var, scratch_addr+NG*GPN_VEC_LEN, NG);
#if 0
    if(rowIdx == 2 && colIdx == 0)
    {
        for(int i=0;i<3*32;i++){
            printf("addr_mean: %f \n", scratch_addr[i]);
        }
        for(int i=0;i<3*32;i++){
            printf("addr_var: %f \n", scratch_addr[NG*GPN_VEC_LEN + i]);
        }
    }
#endif
    }
}

/*void gpn_apply_mask_in_last_iter(
    bfloat16* matA,
    int NG,
    int ElemPerGrpCore,
    int is_last_iter,
    int mask_enb,
    int msubv_residual_core,
    int true_elem_in_last_res_row
) {
    const int vec_len = 16;
    const int msubv_core = ElemPerGrpCore / vec_len;
    const int elem_per_row = vec_len * NG/2;
    uint32_t mask = 0xFFFFFFFF;
    const bfloat16 mask_val = 0.0f;

    if (is_last_iter && mask_enb == 1) {
        for (int row = 0; row < msubv_core; ++row) {
            if (msubv_residual_core == 0 || row > (msubv_residual_core - 1)) {
                mask = 0;
            } else if (row == (msubv_residual_core - 1) && true_elem_in_last_res_row != 0) {
                mask = ((1 << true_elem_in_last_res_row) - 1) * 65537;
            } else {
                mask = 0xFFFFFFFF;
            }
            mask_w8_cols(reinterpret_cast<bfloat16*>(matA + row * elem_per_row),
                         elem_per_row / 8, 0, mask_val, mask);
        }
    }
}*/

// Group-normalization kernel entry point with optional dequantize/quantize.
//
// All parameters come from the groupnorm_layer_param block at args.params_data.
// Flow:
//   1. dq_float16_v32 converts the input buffer (input_addr) to QDQFloatType
//      and writes the result back over the same buffer. This is gated by
//      dq_enable.
//   2. op_select == 0: mean_and_var accumulates per-group statistics in
//      scratch_addr. It finalizes them when is_last_iter is set.
//      op_select == 1: norm_and_affine normalizes the input using those
//      statistics and the scale/bias at weight_addr, writing to output_addr.
//      q_float16_to_int16_v32 then quantizes output_addr in place, gated by
//      q_enable.
//   Neither step 2 branch runs when num_groups == 0.
//
// For the duration of the kernel, rounding is set to round-to-nearest-even and
// saturation to `saturate`. The caller's saturation mode is restored before
// returning so it does not leak into later code running on this tile.
void run_group_norm_qdq(KernelArgs& args)
{
    set_rnd(rnd_conv_even);
    // Remember the incoming saturation mode so it can be restored on exit.
    const aie::saturation_mode sat = aie::tile::current().get_saturation();
    aie::tile::current().set_saturation(aie::saturation_mode::saturate);
    // NOTE(review): the rounding mode is changed but not restored; save and
    // restore it as well if later kernels on this tile depend on it.
    auto rowIdx = get_coreid( ) & 7;
    auto colIdx = get_coreid( ) >> 16;

    groupnorm_layer_param* layer_params = (groupnorm_layer_param*)args.params_data;

    int8_t* matA            = reinterpret_cast<int8_t*>(layer_params->input_addr);
    int8_t* output          = reinterpret_cast<int8_t*>(layer_params->output_addr);
    int8_t* scratch_addr    = reinterpret_cast<int8_t*>(layer_params->scratch_addr);
    int8_t* wgt_addr        = reinterpret_cast<int8_t*>(layer_params->weight_addr);
    uint32_t Msubv_core     = layer_params->Msubv_core;
    uint32_t Nsubv_core     = layer_params->Nsubv_core;
    uint32_t GroupSize      = layer_params->GroupSize;
    uint32_t num_groups     = layer_params->num_groups;
    uint32_t op_select      = layer_params->op_select;
    uint32_t is_last_iter   = layer_params->is_last_iter;
    uint32_t Mparam_t       = layer_params->Mparam_t;
    uint32_t enable_global_reduce = layer_params->enable_global_reduce;
    uint32_t ElemPerGrpCore = Msubv_core * GroupSize;

    uniop_qdq_param* qdqprm = reinterpret_cast<uniop_qdq_param*>(layer_params->qdq_param_addr);

    bool dequant_en  = (qdqprm->dq_enable == 1);
    bool quant_en    = (qdqprm->q_enable == 1);

    KernelDqParam dq_krn_param;
    KernelQParam q_krn_param;

    // inner_g is presumably the number of 32-element vectors in the
    // Msubv_core x Nsubv_core tile — confirm against the dq/q kernels.
    dq_krn_param.inner_g    = Msubv_core*Nsubv_core/32;
    dq_krn_param.sign_A     = layer_params->sign_A;

    q_krn_param.inner_g     = Msubv_core*Nsubv_core/32;
    q_krn_param.sign_O      = layer_params->sign_O;

    v32accfloat *dq_buf = (v32accfloat*)(qdqprm->dq_buf);
    v32accfloat *q_buf = (v32accfloat*)(qdqprm->q_buf);

    // Source and destination are both matA: the input is dequantized in place.
    dq_float16_v32((int8_t*) matA, (float*) dq_buf, (QDQFloatType*) matA, dq_krn_param, dequant_en);

    if(op_select == 0 && num_groups > 0){
         mean_and_var<QDQFloatType>(
                 (QDQFloatType *)matA,
                 Msubv_core,
                 Nsubv_core,
                 Mparam_t,
                 GroupSize,
                 ElemPerGrpCore,
                 is_last_iter,
                 enable_global_reduce,
                 (float *)scratch_addr,
                 num_groups,
                 (float *)output,
                 rowIdx, colIdx);
    }

    if(op_select == 1 && num_groups > 0){
        norm_and_affine<QDQFloatType>(
                 (QDQFloatType *)matA,
                 (QDQFloatType *)wgt_addr,
                 (QDQFloatType *)output,
                 num_groups,
                 ElemPerGrpCore,
                 (float *)scratch_addr,
                 rowIdx, colIdx);
        // Source and destination are both `output`: quantized in place.
        q_float16_to_int16_v32((QDQFloatType*) output, (float*) q_buf, (int16*) output, q_krn_param, quant_en);
    }

    // Restore the caller's saturation mode saved on entry.
    aie::tile::current().set_saturation(sat);
}

#endif
