/*
    Copyright (C) 2019 - 2022 Xilinx, Inc. All rights reserved.
    Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.

    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.

    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.

    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.

    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */
#ifndef __GROUP_NORM_H__
#define __GROUP_NORM_H__
#include "softmax.cc"

// Bit flags packed into the `sign_type` kernel parameter; run_group_norm_qdq
// tests each one with a bitwise AND to pick the dequant/quant data formats.
#define IFM_SIGNED 0x0001  // set --> IFM is signed
#define OFM_SIGNED 0x0010  // set --> OFM is signed
#define IFM_16BITS 0x0100  // set --> IFM is 16 bits wide
#define OFM_16BITS 0x1000  // set --> OFM is 16 bits wide

// expands each elements repeated 16 times
inline __attribute__((always_inline)) void expansionMeanVar(float* input,float* output,int num_groups){
    v16float* out = (v16float*)output;

    for(int i=0;i<num_groups;i++)
    chess_no_hw_loop
    {
        out[i] = broadcast_float(input[i]);
    }

}

// Compile-time switch. Nothing in the code visible in this header tests it;
// presumably it is consumed by a file that includes this one (or by softmax.cc)
// -- TODO confirm before removing.
#define NOTPASSTHROUGH 1
// Scalar, element-by-element reciprocal square root:
//   output[k] = invsqrt(input[k])   for k in [0, N_Dim)
// Each element is read before the matching element is written, so `input`
// and `output` may point at the same buffer (mean_and_var calls it that way).
inline __attribute__((always_inline)) void invsqrt_vectorwise
(
  float * input,
  float * output,
  int16_t N_Dim
)
{
    const float *src = input;
    float *dst = output;

    for (int remaining = N_Dim; remaining > 0; --remaining)
    chess_no_hw_loop
    {
        const float value = *src++;
        *dst++ = invsqrt (value);
    }
}

/*
 * Normalise-and-affine pass of group norm (run when op_select == 1 in
 * run_group_norm_qdq).
 *
 * For every 32-element chunk of `input` the loop combines the per-group
 * statistics left in `scratch_addr` with the per-channel scale/bias taken from
 * `parameters` and writes a bf16 result to `output`. Assuming the usual AIE
 * intrinsic semantics (negmul = -(a*b), mac = acc + a*b -- TODO confirm), the
 * value produced is
 *
 *     out = in * (scale * inv_std) + (bias + scale * (inv_std * -mean))
 *         = (in - mean) * inv_std * scale + bias
 *
 * Parameters
 *   input          bf16 data to normalise; read as v32bfloat16 chunks.
 *   parameters     byte buffer holding scale then bias tables:
 *                  scale of instance k at byte offset GammaLen*BF16_BYTES*k,
 *                  bias tables start after numAieInst scale tables.
 *   output         bf16 destination; written as v32bfloat16 chunks.
 *   num_groups     number of groups; sizes the two statistic regions.
 *   nsubv_core     element count to process; the loop runs nsubv_core/32 times.
 *   ElemPerGrpCore not used by this function.
 *   GammaLen       number of bf16 entries in one scale (or bias) table; also
 *                  the wrap length of the circular scale/bias iterators.
 *   is_last_iter   not used by this function.
 *   scratch_addr   float scratch: first num_groups*vec_len floats are read as
 *                  the mean, the next num_groups*vec_len floats as the inverse
 *                  standard deviation (mean_and_var fills both regions with
 *                  each group's value repeated across one 16-float vector).
 *   rowIdx, colIdx only referenced by the disabled (#if 0) debug print.
 *   aieInst        index of this AIE instance; selects its scale/bias table.
 *   numAieInst     total number of AIE instances sharing `parameters`.
 *   vec_len        floats stored per group in each statistic region.
 */
inline __attribute__((always_inline)) void norm_and_affine
(
        bfloat16 * input,
        int8_t * restrict parameters,
        bfloat16 * output,
        uint16_t num_groups,
        uint16_t nsubv_core,
        int16 ElemPerGrpCore,
        int16 GammaLen, 
        uint16_t is_last_iter,
        float *scratch_addr,
        const int rowIdx, const int colIdx,
        int aieInst,
        int numAieInst,
        const int vec_len
)
{

    // Second statistic region (inverse std-dev) starts right after the mean region.
    float* scratch_addr_var = scratch_addr + num_groups*vec_len; //num_groups * 16
    v32bfloat16 *outIter = (v32bfloat16*) output;

    // NOTE(review): mean_float and vec_inv_var are never read below; the
    // circular iterators created further down are what the loop actually uses.
    v32accfloat * mean_float  = (v32accfloat *)scratch_addr;
    v32accfloat * vec_inv_var = (v32accfloat *)(scratch_addr_var);

    // Scale tables for all instances come first, followed by the bias tables.
    v32bfloat16 * p_scale  = (v32bfloat16 *)(parameters + GammaLen*BF16_BYTES*(aieInst)) ;
    v32bfloat16 * p_bias = (v32bfloat16 *)(parameters + GammaLen*BF16_BYTES*numAieInst + GammaLen*BF16_BYTES*(aieInst));

    // Constant 1.0 in every lane; used to negate the mean via negmul below.
    v32bfloat16 one_32_lane = broadcast_bfloat16( 1.0 );
    v32bfloat16 *inIter2 = (v32bfloat16 *) input;

    // Circular iterators: each wraps back to its start once the given element
    // count has been consumed, so the statistics and the scale/bias tables are
    // reused cyclically while the input is walked linearly.
    auto mean_iter = aie::begin_vector_circular<32>(scratch_addr, num_groups*vec_len);
    auto inv_var_iter = aie::begin_vector_circular<32>(scratch_addr_var, num_groups*vec_len);
    auto bias_iter = aie::begin_vector_circular<32>((bfloat16*)p_bias, GammaLen);
    auto scale_iter = aie::begin_vector_circular<32>((bfloat16*)p_scale, GammaLen);

    //printf("INSIDE NORM AND AFFINE");
    // One iteration per 32 input elements. chess_loop_range(2, ) presumably
    // promises the compiler at least two iterations -- callers must then pass
    // nsubv_core >= 64 (TODO confirm).
    for (int ii = 0; ii < nsubv_core/32; ii++)
    chess_prepare_for_pipelining
    chess_loop_range(2, )
    {
        // Load 32 float means, reinterpret as an accumulator and narrow to bf16.
        v32bfloat16 mean_bf16 = to_v32bfloat16(v32accfloat((*mean_iter++).to_native()));
        // -mean (mean multiplied by 1.0 with negation).
        v32bfloat16 negative_mu = to_v32bfloat16(negmul_elem_32(mean_bf16, one_32_lane));
        // inv_std, narrowed to bf16 the same way as the mean.
        v32bfloat16 mac_op_1_st3 = to_v32bfloat16(v32accfloat((*inv_var_iter++).to_native()));
        // inv_std * (-mean)
        v32bfloat16 mac_op_2_st3 = to_v32bfloat16(mul_elem_32(mac_op_1_st3, negative_mu));

        v32bfloat16 scale   = *scale_iter++; 
        v32bfloat16 bias_t  = *bias_iter++; 

        // Bias widened to accumulator precision so it can seed the MAC chain.
        v32accfloat bias    = to_v32accfloat(bias_t);

        // scale * inv_std : the multiplier applied to the raw input.
        v32bfloat16 scaled_op_1 = to_v32bfloat16(mul_elem_32(scale, mac_op_1_st3));

        // bias + scale * (inv_std * -mean) : the additive term.
        v32accfloat scaled_b_op_1 = mac_elem_32(scale, mac_op_2_st3, bias);
        // additive term + in * multiplier, then narrowed to bf16 for the output.
        v32accfloat res = (mac_elem_32((inIter2[ii]), scaled_op_1, scaled_b_op_1));
        outIter[ii] = to_v32bfloat16(res);

    }
    // Disabled debug dump for the core at rowIdx == 2, colIdx == 0.
#if 0
        if(rowIdx == 2 && colIdx == 0)
        {
            //aie::print(aie::accum<accfloat, 16>(*(v16accfloat*)addr_mean), true, "mean = ");
            //aie::print(aie::accum<accfloat, 16>(*(v16accfloat*)(addr_mean+16)), true, "mean = ");
            aie::print(*inv_var_iter++, true, "res = ");
            //aie::print(aie::accum<accfloat, 16>(*(v16accfloat*)(addr_var+16)), true, "variance = ");
            printf("--------------------------------------------------\n");
        }
#endif

}

/*
 * Statistics pass of group norm (run when op_select == 0 in
 * run_group_norm_qdq).
 *
 * Stage 1 (every call): consumes `input` 16 bf16 elements at a time, cycling
 * through the groups, and accumulates for each group a 16-lane running sum
 * and a 16-lane running sum of squares in `scratch_addr`.
 *
 * Stage 2 (only when is_last_iter is non-zero): collapses the lane sums to
 * one scalar per group, combines them across cores with global_add_reduce,
 * scales by 1/(ColPerGrp*mparam), derives invsqrt(E[x^2] - mean^2 + EPSILON)
 * and finally overwrites the two accumulator regions with the per-group mean
 * and inverse std-dev, each repeated across one 16-float vector, which is the
 * layout norm_and_affine reads.
 *
 * The accumulators persist between calls through the function-local static
 * `zero_init`: it requests a fresh accumulator on the first row handled after
 * start-up or after a Stage 2, and plain accumulation otherwise.
 *
 * Parameters
 *   input          bf16 data; read sequentially, 16 elements per (row, group).
 *   nsubv          not used by this function.
 *   mparam,
 *   ColPerGrp      their product is the element count per group used as the
 *                  divisor for the mean and for E[x^2].
 *   ElemPerGrpCore ElemPerGrpCore/16 is the number of rows processed per call.
 *   is_last_iter   non-zero triggers Stage 2.
 *   scratch_addr   float scratch laid out as
 *                    [0, G*vec_len)            sums      -> expanded mean
 *                    [G*vec_len, 2*G*vec_len)  sums of squares -> expanded inv std
 *                    [2*G*vec_len, ...)        per-group scalar means, with the
 *                                              scalar variances var_offset
 *                                              floats further on
 *                  where G = num_groups.
 *   num_groups     number of groups (G). Stage 2 handles exactly 32 values in
 *                  its variance step -- see the NOTE in the body.
 *   output         only used by the disabled (#else) alternative placement of
 *                  the scalar means.
 *   rowIdx, colIdx forwarded to global_add_reduce.
 *   aieInst        only referenced by the disabled (#if 0) debug print.
 *   var_offset     distance in floats from the scalar means to the scalar
 *                  variances; also sizes the cross-core reduction loop.
 *   vec_len        floats reserved per group in each accumulator region.
 */
inline __attribute__((always_inline)) void mean_and_var
(
    bfloat16 * input,
    uint16_t nsubv,
    int mparam,
    uint16_t ColPerGrp,
    int16 ElemPerGrpCore,
    uint16_t is_last_iter,
    float *scratch_addr,
    uint16_t num_groups,
    float* output, 
    const int rowIdx, const int colIdx,     
    int aieInst, int var_offset,
    const int vec_len
)
{
    float eps = EPSILON;
    //const int ElemPerGrpCore = 80; //16;

    // Sequential reader delivering 16 bf16 elements per dereference.
    auto inIter1 = aie::begin_vector<16>(input);
    
    // Persists across calls: 1 = start accumulators from zero, 0 = keep adding.
    static int zero_init = 1;
    // Per-group 16-lane accumulators: sums first, sums of squares after them.
    v16accfloat* sum = (v16accfloat*)scratch_addr;
    v16accfloat* sum_sq = (v16accfloat*)(scratch_addr + num_groups*vec_len);

    // Constant 1.0 in every lane so that x * 1 can go through the MAC path.
    v32bfloat16 one_32_lane = broadcast_bfloat16( 1.0 );

    // ------------------------------  Stage 1: Local Sums and Sums of Squares ---------------------------
    
    // Outer loop: rows of the tile; inner loop: one 16-element slice per group.
    // The chess_loop_range hints presumably promise at least 4 rows and at
    // least 2 groups -- TODO confirm against the callers.
    for (int miter = 0; miter < (ElemPerGrpCore/16); ++miter)
    chess_prepare_for_pipelining
    chess_loop_range(4,) 
    {
        for (int niter = 0; niter < num_groups; ++niter)
        chess_prepare_for_pipelining
        chess_loop_range(2,) 
        {
            // Only the low 16 lanes are filled; the 16-lane MAC below is
            // presumed to ignore the upper half -- TODO confirm.
            v32bfloat16 in_1x32;
            in_1x32 = insert(in_1x32, 0, *inIter1++);
            // sum += x * 1 ; sum_sq += x * x (accumulator cleared when zero_init is set).
            sum[niter] = mac_elem_16_conf(in_1x32, one_32_lane, sum[niter], zero_init, 0, 0);
            sum_sq[niter] = mac_elem_16_conf(in_1x32, in_1x32, sum_sq[niter], zero_init, 0, 0);
        }
            // After the first row every group has been initialised; accumulate from now on.
            zero_init = 0;
    }


    // ------------------------------  Stage 2 : Divergent Stage: -------------------------------------
    //                                Accumulate global sum for last iteration of a group
    if(is_last_iter){
        // Scalar results live after the two accumulator regions.
#if 1
        float* addr_mean = scratch_addr + 2 * num_groups * vec_len;
#else
        float* addr_mean = output;
#endif

        float* addr_var  = addr_mean + var_offset;
        v16accfloat* addr_sum = (v16accfloat*)scratch_addr;

        // Collapse each group's 16 lanes to one scalar. addr_sum[ngrp+num_groups]
        // addresses the same storage as sum_sq[ngrp] only when one v16accfloat
        // spans exactly vec_len floats (vec_len == 16, as the caller passes).
        for (int ngrp=0; ngrp<num_groups; ngrp++)
        //chess_prepare_for_pipelining
        chess_no_hw_loop
        {
            addr_mean[ngrp] = aie::reduce_add(aie::accum<accfloat, 16>(addr_sum[ngrp]).to_vector<float>());  // processes mean - placed contiguously
            addr_var[ngrp] = aie::reduce_add(aie::accum<accfloat, 16>(addr_sum[ngrp+num_groups]).to_vector<float>());  // processes variance - placed contiguously
        }

        v16accfloat* mean_vec = (v16accfloat*)addr_mean;

        // 1/K, K = total elements per group, broadcast to 16 lanes.
        int tot_el_per_grp = ColPerGrp * mparam; 
        v16float inv_K = broadcast_to_v16float(inv(fix2float(tot_el_per_grp)));

        // Number of 16-float vectors spanning 2*var_offset floats, i.e. the
        // scalar means followed by the scalar sums of squares.
        int niter_run = (var_offset / vec_len)*2;
        for (int niter = 0; niter < niter_run; ++niter)
        chess_no_hw_loop
        {
                // global_add_reduce is defined elsewhere; presumably it sums the
                // vector across cooperating cores -- TODO confirm. The scaling by
                // 1/K then yields E[x] in the mean part and E[x^2] in the var part.
                mean_vec[niter] = global_add_reduce( mean_vec[niter], rowIdx, colIdx );  // processes mean - placed contiguously
                mean_vec[niter] = mul_elem_16(v16float(mean_vec[niter]), inv_K);
        }
        
        // NOTE: following snippet assume num groups = 32. TODO: generalize it.
        // var = E[x^2] - mean*mean for 32 groups at once (msc presumed to be
        // acc - a*b -- TODO confirm); the mean is rounded to bf16 first.
        v32bfloat16 mean_bf16 = to_v32bfloat16(*(v32accfloat*)mean_vec);
        v32accfloat* var_vec = (v32accfloat*)(addr_var);
        *var_vec = msc_elem_32(mean_bf16, mean_bf16, *var_vec);

        // Add the global epsilon value to Var[x]
        v16accfloat epsilon16 = v16accfloat(broadcast_float(eps));
        v32accfloat epsilon32 = concat(epsilon16, epsilon16);
        *var_vec = add(*var_vec, epsilon32);

        // Next call starts a fresh accumulation.
        zero_init = 1;  

        // In place: var + eps -> 1/sqrt(var + eps), one scalar per group.
        invsqrt_vectorwise ((float*)var_vec, (float*)var_vec, num_groups);
        // Overwrite the accumulator regions with each group's mean / inverse
        // std-dev repeated across a 16-float vector.
        expansionMeanVar(addr_mean, scratch_addr, num_groups);
        expansionMeanVar(addr_var, scratch_addr+num_groups*vec_len, num_groups);

    // Disabled debug dump for instance 0 on the core at rowIdx == 2, colIdx == 0.
#if 0
    if(rowIdx == 2 && colIdx == 0 && aieInst == 0)
    {
        aie::print(aie::accum<accfloat, 16>(*(v16accfloat*)addr_var), true, "variance = ");
        //aie::print(aie::accum<accfloat, 16>(*(v16accfloat*)(addr_var+16)), true, "variance = ");
        printf("--------------------------------------------------\n");
    }
#endif
    }
}

/*
 * Masks the padded tail of a group-norm input tile by calling mask_w8_cols
 * once per row with a 32-bit lane mask and a replacement value of 0.0.
 *
 * The tile is treated as ElemPerGrpCore/vec_len rows of vec_len*num_groups
 * bf16 elements. Per row the mask is:
 *   - 0xFFFFFFFF            rows before msubv_residual_core;
 *   - low `true_elem_in_last_res_row` bits set, duplicated into both 16-bit
 *     halves            the row equal to msubv_residual_core;
 *   - 0                     rows after msubv_residual_core, or every row when
 *                           both msubv_residual_core and
 *                           true_elem_in_last_res_row are 0.
 * mask_w8_cols is defined elsewhere; presumably set bits keep the element and
 * cleared bits replace it with the mask value -- TODO confirm.
 *
 * Parameters
 *   matA                      bf16 tile, modified in place.
 *   num_groups                number of groups per row.
 *   ElemPerGrpCore            elements per group in this tile.
 *   is_last_iter              accepted for interface compatibility; NOT
 *                             consulted -- masking depends on mask_enb only.
 *   mask_enb                  masking runs only when this equals 1.
 *   msubv_residual_core       index of the partially valid row.
 *   true_elem_in_last_res_row number of valid lanes in that row; expected to
 *                             lie in [0, 16].
 *   vec_len                   elements per group per row.
 *
 * The partial mask is computed in uint32_t: with 16 valid lanes the product
 * 0xFFFF * 0x10001 equals 0xFFFFFFFF, which does not fit a signed int and
 * would be undefined behaviour if evaluated in `int`.
 */
void gpn_apply_mask_in_last_iter(
    bfloat16* matA,
    int num_groups,
    int ElemPerGrpCore,
    int is_last_iter,
    int mask_enb,
    int msubv_residual_core,
    int true_elem_in_last_res_row,
    const int vec_len
) {
    if (mask_enb != 1) {
        return;
    }

    const int msubv_core = ElemPerGrpCore / vec_len;
    const int elem_per_row = vec_len * num_groups;
    const bfloat16 mask_val = 0.0f;

    // Loop-invariant: no residual row and no valid lanes means nothing is kept.
    const bool nothing_valid =
        (msubv_residual_core == 0 && true_elem_in_last_res_row == 0);

    for (int row = 0; row < msubv_core; ++row) {
        uint32_t mask = 0xFFFFFFFF;
        if (row > msubv_residual_core || nothing_valid) {
            mask = 0;
        } else if (row == msubv_residual_core) {
            // Low-lane bit pattern, mirrored into the upper 16 bits (x * 0x10001).
            const uint32_t lane_bits =
                (uint32_t(1) << true_elem_in_last_res_row) - uint32_t(1);
            mask = lane_bits * uint32_t(65537);
        }
        mask_w8_cols(matA + row * elem_per_row,
                     elem_per_row / 8, mask_val, mask);
    }
}

/*
 * Kernel entry point for group norm with optional dequantise-in /
 * quantise-out.
 *
 * Sequence: set rounding and saturation modes, unpack the parameter block,
 * run dequant_int16_to_bf16 (gated by its dequant_en argument), mask the
 * padded tail, then run one of the two passes selected by op_select:
 *   op_select == 0 : mean_and_var   (statistics accumulation)
 *   op_select == 1 : norm_and_affine followed by quant_bf16_to_int16
 *
 * args.params_data, viewed as uint16_t words:
 *   [0]  Nsubv_core          [1]  Mparam high 16 bits   [2]  Mparam low 16 bits
 *   [3]  ColPerGrp           [4]  scratch address (via conv_to_local_ptr)
 *   [5]  byte at offset 10 = is_last_iter, byte at offset 11 = op_select
 *   [6]  num_groups          [7]  ElemPerGrpCore        [8]  GammaLen
 *   [9]  mask_enb            [10] colIdx                [11] aieInst
 *   [12] numAieInst          [13] msubv_residual_core
 *   [14] true_elem_in_last_res_row                      [15] var_offset
 *   [16] qdq block address (via conv_to_local_ptr)      [17] sign_type flags
 *
 * qdq block, byte offsets: 0 s_out (bf16), 4 zp_out (u16), 8 quant_en,
 * 12 s_in (bf16), 16 zp_in (u16), 20 dequant_en.
 *
 * When dequant_en is set the dequantised data is written to the output buffer
 * and all later stages read from (and write to) that buffer in place.
 */
void run_group_norm_qdq(KernelArgs& args)
{
    const int vec_len = 16;
    set_rnd(rnd_conv_even);
    // NOTE(review): the previous saturation mode is captured here but never
    // restored, so the tile is left in saturate mode after the kernel returns.
    // Confirm whether that is intended before adding a restore.
    const aie::saturation_mode sat = aie::tile::current().get_saturation();
    aie::tile::current().set_saturation(aie::saturation_mode::saturate);
    bfloat16 * matA   = static_cast<bfloat16*>(args.s2mm_ch0_data);
    int8_t * wgt   = static_cast<int8_t*>(args.s2mm_ch1_data);
    int8_t * output = static_cast<int8_t*>(args.mm2s_ch0_data);
    int rowIdx = (get_coreid() & 0xF);
    //int colIdx = (get_coreid() >> 16);
    uint16_t* args_params   = (uint16_t*)args.params_data;
    
    uint16_t Nsubv_core     = args_params[0];
    uint16_t Mparam_msb     = args_params[1];
    uint16_t Mparam_lsb     = args_params[2];
    uint16_t ColPerGrp      = args_params[3];
    float *scratch_addr     = static_cast<float*>(conv_to_local_ptr(args_params[4]));
    // Word 5 carries two independent byte-wide fields.
    uint16_t is_last_iter   = *((uint8_t *)args_params+5*2);
    uint16_t op_select      = *((uint8_t *)args_params+5*2 + 1); //args_params[5];
    uint16_t num_groups     = args_params[6]; 
    uint16_t ElemPerGrpCore = args_params[7];
    uint16_t GammaLen       = args_params[8];
    uint16_t mask_enb       = args_params[9];
    int colIdx              = args_params[10];
    int aieInst             = args_params[11];
    int numAieInst          = args_params[12]; // Total number of AIE instances
    int msubv_residual_core = args_params[13];
    int true_elem_in_last_res_row = args_params[14];
    int var_offset = args_params[15];
    int8_t* qdq_offset = static_cast<int8_t*>(conv_to_local_ptr(args_params[16]));
    uint16_t sign_type      = args_params[17];
    
    bool in_sign      = sign_type & IFM_SIGNED; // true --> IFM Signed
    bool out_sign     = sign_type & OFM_SIGNED; // true --> OFM Signed
    bool is_in_int16  = sign_type & IFM_16BITS; // true --> IFM 16 bits 
    bool is_out_int16 = sign_type & OFM_16BITS; // true --> OFM 16 bits

    // The enable flags are read as raw bytes and compared with zero: loading
    // them through a bool* is undefined if the byte holds anything but 0 or 1.
    bfloat16 s_out   = *(bfloat16 *) (qdq_offset);
    uint16_t zp_out  = *(uint16_t *) (qdq_offset + 4);
    bool quant_en    = (*(qdq_offset + 8) != 0);
    bfloat16 s_in    = *(bfloat16 *) (qdq_offset + 12);
    uint16_t zp_in   = *(uint16_t *) (qdq_offset + 16);
    bool dequant_en  = (*(qdq_offset + 20) != 0);

    bfloat16* gpn_in_ptr = (bfloat16 *) matA;
    // Overwrite the gpn input pointer for in-place compute 
    if(dequant_en){
        gpn_in_ptr  = (bfloat16 *) output;
    }
    
    dequant_int16_to_bf16((int8_t *)matA, (int8_t *)gpn_in_ptr, Nsubv_core, zp_in, s_in, in_sign, dequant_en, is_in_int16);

    // Runs for both passes; masking itself is gated by mask_enb.
    gpn_apply_mask_in_last_iter(gpn_in_ptr, 
                                num_groups, 
                                ElemPerGrpCore, 
                                is_last_iter, 
                                mask_enb,
                                msubv_residual_core, 
                                true_elem_in_last_res_row,
                                vec_len);
    
    // Assemble the 32-bit value in unsigned arithmetic: shifting the promoted
    // (signed) 16-bit word left by 16 overflows int once its top bit is set.
    int Mparam_t = static_cast<int>((static_cast<uint32_t>(Mparam_msb) << 16) | Mparam_lsb);

    if(op_select == 0){
         mean_and_var(
                 gpn_in_ptr, 
                 Nsubv_core, 
                 Mparam_t, 
                 ColPerGrp,
                 ElemPerGrpCore,
                 is_last_iter, 
                 scratch_addr,
                 num_groups,
                 (float *)output, 
                 rowIdx, colIdx, aieInst, 
                 var_offset, vec_len);
    }
    if(op_select == 1){
         norm_and_affine(
                 gpn_in_ptr, 
                 wgt, 
                 (bfloat16 *)output,
                 num_groups,
                 Nsubv_core, 
                 ElemPerGrpCore,
                 GammaLen, 
                 is_last_iter,
                 scratch_addr,
                 rowIdx, colIdx, aieInst, 
                 numAieInst, vec_len);

        // In-place conversion of the bf16 result (gated by quant_en).
        quant_bf16_to_int16(output, output, Nsubv_core, zp_out, s_out, out_sign, quant_en, is_out_int16);
    }
    
    // Disabled debug aids: pass-through copy and statistics dump.
#if 0 
    bfloat16 * input_pt = matA;
    bfloat16 * output_pt = (bfloat16 *)output;
    if(op_select == 1){
        for(int ii = 0; ii < Nsubv_core; ii++)
        {
            output_pt[ii] = input_pt[ii];
        }
    }
    /*if(rowIdx == 2 && colIdx == 0 && is_last_iter)
    {
        for(int ii = 0; ii<32; ii++)
            printf("Mean_aie[%d]: %f \n", ii, *(scratch_addr + ii));
        for(int ii = 0; ii<32; ii++)
            printf("Variance_aie[%d]: %f \n", ii, *(scratch_addr + 32 + ii));
        //printf("variance: %f\n", variance);
        printf("--------------------------------------------------\n");
    }*/
#endif
    
}
#endif
