/*
    Copyright (C) 2019 - 2022 Xilinx, Inc. All rights reserved.
    Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.

    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.

    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.

    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.

    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */
#ifndef __LAYER_NORM_H__
#define __LAYER_NORM_H__

// Bit masks for the `sign_type` parameter word read by run_lrn_qdq().
// Each flag is tested with a bitwise AND and converted to bool.
#define IFM_SIGNED 0x0001   // set --> input feature map (IFM) is signed
#define OFM_SIGNED 0x0010   // set --> output feature map (OFM) is signed
#define IFM_16BITS 0x0100   // set --> IFM is 16 bits wide
#define OFM_16BITS 0x1000   // set --> OFM is 16 bits wide

/**
 * Layer normalization of a bfloat16 tile, processed in blocks of 8 rows.
 *
 * For every row block the kernel makes two passes over the input:
 *   1. accumulate per-row sum(x) and sum(x^2) over Nsubv / 8 vectors of 64 lanes,
 *   2. apply  y = x * (gamma * r) + (beta - mean * gamma * r),
 *      with r = 1 / sqrt(Var[x] + EPSILON), and store the result.
 *
 * Template parameter:
 *   CORE_TILE_HIN  number of rows handled per call; the outer loop runs
 *                  CORE_TILE_HIN / 8 times.
 *
 * Parameters:
 *   input         bfloat16 input, read as consecutive 64-lane vectors.
 *                 NOTE(review): each vector is treated as an 8x8 (rows x cols)
 *                 tile by tree_add_8x8 -- confirm the tiling against the producer.
 *   parameters    byte buffer holding the scale (gamma) and bias (beta) vectors.
 *   output        bfloat16 output, written as consecutive 64-lane vectors.
 *   Nlrn          normalization length; the sums are multiplied by 1 / Nlrn.
 *   Nsubv         number of columns held by this call; Nsubv / 8 vectors are
 *                 consumed per row block. Presumably a non-zero multiple of 8:
 *                 if Nsubv / 8 is 0 the accumulators are never initialized.
 *   split_type    non-zero: per-row sums are combined through global_add_reduce();
 *                 zero: the local sums are used directly.
 *   gamma_offset  byte offset of the scale vectors inside `parameters`.
 *   beta_offset   byte offset of the bias vectors inside `parameters`.
 *   rowIdx/colIdx core coordinates forwarded to global_add_reduce().
 *   simplified    1 forces the mean term to zero (mean_factor = 0), so only
 *                 E[x^2] contributes -- presumably an RMS-style normalization.
 *
 * NOTE(review): `input` and `output` are both `restrict`; a caller passing the
 * same buffer for both (in-place use) should be checked against that contract.
 */
template<int CORE_TILE_HIN = 8>
inline __attribute__((always_inline)) void layer_norm_col_split
(
        bfloat16 * restrict input,
        int8_t * restrict parameters,
        bfloat16 * restrict output,
        int Nlrn,
        int Nsubv,
        int split_type,
        int gamma_offset,
        int beta_offset,
        const int rowIdx, const int colIdx, 
        uint16_t simplified
)
{
    // Disabled alternative: per-core parameter offset.
    //int params_offset = (colIdx*NUM_ROWS+ (rowIdx-2)) * Nsubv * BF16_BYTES * 8;     // assuming 8x repeated and 8x8 tiled scale and bias params

    const int PER_ITERATION_ROW_BLOCK = 8;   // rows normalized per outer-loop iteration
    const int WIDTH_ALIGNMENT = 8;           // columns covered by one 64-lane vector
    const int num_input_lanes = 64;
    const int num_output_lanes = num_input_lanes / 2;   // not referenced below

    // EPSILON is defined outside this file.
    const float eps = EPSILON;
    const v16accfloat epsilon = v16accfloat(broadcast_float(eps));

    // Multiplier applied to x when accumulating sum(x): 0 disables the mean term.
    bfloat16 mean_factor = (simplified==1) ? 0.0 : 1.0;

    v32bfloat16 mac_op_1_values[2];
    v32bfloat16 mac_op_2_values[2];

    // inIter1 feeds the statistics pass, inIter2 the normalization pass. Both start
    // at `input` and advance by Nsubv / 8 vectors per row block, so they stay in step.
    auto inIter1 = aie::begin_vector<num_input_lanes>(input);
    auto inIter2 = aie::begin_vector<num_input_lanes>(input);
    auto outIter = aie::begin_vector<num_input_lanes>(output);

    for (int Row_Block = 0; Row_Block < (CORE_TILE_HIN / PER_ITERATION_ROW_BLOCK); ++Row_Block)
    chess_prepare_for_pipelining
    {

        // Per-iteration layer normalization of one block of PER_ITERATION_ROW_BLOCK rows.
        // sum[] is left uninitialized on purpose: the first MAC below runs with
        // zero_init = 1, which starts the accumulation from zero.
        v64accfloat sum[2];


        v32bfloat16 one_32_lane = broadcast_bfloat16( mean_factor );
        v64bfloat16 one_64_lane = concat( one_32_lane, one_32_lane);

        int zero_init = 1;

        // ------------------------------  Stage 1: Local Sums and Sums of Squares ---------------------------

        // Element-wise accumulation across all 64-lane vectors of the row block:
        // sum[0] -> accumulated x * mean_factor (plain sum when mean_factor is 1)
        // sum[1] -> accumulated x * x (sum of squares)

        for (int block_8x8 = 0; block_8x8 < (Nsubv / WIDTH_ALIGNMENT); ++block_8x8)
        chess_prepare_for_pipelining
        {
            v64bfloat16 one_8x8 = *inIter1++;

            sum[0] = mac_elem_64_conf(one_8x8, one_64_lane, sum[0], zero_init, 0, 0);
            sum[1] = mac_elem_64_conf(one_8x8, one_8x8, sum[1], zero_init, 0, 0);

            // Only the first iteration clears the accumulators.
            zero_init = 0;
        }


        // Reduce each accumulated 8x8 block to one value per row by summing across the columns
        v16accfloat sums_block_1 = tree_add_8x8(sum[0]);
        v16accfloat sums_block_2 = tree_add_8x8(sum[1]);

        // Pack the statistics into one 16-lane vector: the lower 8 lanes hold the
        // per-row sums, the upper 8 lanes hold the per-row sums of squares.
        v8accfloat local_sum = extract_v8accfloat(sums_block_1, 0);
        v8accfloat local_sum_sq = extract_v8accfloat(sums_block_2, 0);
        v16accfloat local_sums = concat(local_sum, local_sum_sq);

        // rec_sum and write_inter_sum are not referenced below.
        v16accfloat rec_sum;
        v16accfloat write_inter_sum;

        v16accfloat mac_op_1;
        //event0();

        // ------------------------------  Stage 2 : Divergent Stage: -------------------------------------
        //                                Accumulate a per-row global sum
        // With split_type set, the local statistics are combined through
        // global_add_reduce(); otherwise the local sums are already the full-row sums.
        if(split_type)
            mac_op_1 = global_add_reduce( local_sums, rowIdx, colIdx );
        else
            mac_op_1 = local_sums;
        event0();   // presumably a profiling/trace marker -- TODO confirm
        // ----------------------------------  Stage 3: Normalization Step ------------------------------------------
        // Use the per-row global sum and sums of squares to calculate the final normalized values


        // Calculate per row values of:
        //
        //              gamma                              E[x] * gamma
        //      -------------------------  , beta  -    -------------------
        //       sqrt(Var[x] + epsilon)                 sqrt(Var[x] + epsilon)
        //
        // from the global sum and sums of square column vectors



        // Multiply the global sums by 1 / Nlrn to get a per-row E[x] (lower 8 lanes)
        // and E[x^2] (upper 8 lanes)
        v16float inv_K = broadcast_to_v16float(inv(fix2float(Nlrn)));
        v16accfloat mean_sq_float = mul_elem_16(v16float(mac_op_1), inv_K);

        // Calculate per-row Variance using the formula E[x^2] - (E[x])^2
        // (msc_elem_16 presumably computes acc - a * b -- confirm against the intrinsics guide)
        v8accfloat mean_float = extract_v8accfloat(mean_sq_float, 0);
        v8accfloat sq_mean_float = extract_v8accfloat(mean_sq_float, 1);
        v32bfloat16 mean_bf16 = to_v32bfloat16(concat(mean_float, mean_float, mean_float, mean_float));
        v16accfloat variance = msc_elem_16(mean_bf16, mean_bf16, concat(sq_mean_float, sq_mean_float));

        // Add the global epsilon value to Var[x]
        variance = add(variance, epsilon);
        event1();   // presumably a profiling/trace marker -- TODO confirm

        // Calculate 1 / sqrt(Var[x] + epsilon)
        //
        // Intrinsics considered for this step:
        // v16int32 upd_elem(v16int32 v, int idx, int b)
        // extract_elem() [3/20] float extract_elem(v16float v,int idx)
        // invsqrt() [1/2] float invsqrt(float a)
        //v16float() [5/16] v16float(v16accfloat)

        // Disabled scalar variant over all 16 lanes:
        // v16int32 tmp;
        // for(int i =0; i<16; i++){
        //     tmp = upd_elem(tmp,  i, as_int32(invsqrt(extract_elem(v16float(variance),  i))));
        // }

        // Only lanes 0..7 of tmp are written (4 iterations x 2 floats); the upper 8
        // lanes stay unset. Only the first 8 lanes are consumed further down
        // (extract_v4bfloat16 indices 0 and 1 of operand_1 / operand_2).
        v16float tmp;

        for(int i = 0; i < 4; i++) chess_no_hw_loop {
            v2float in = extract_v2float( v16float(variance), i );
            v2float out = set_v2float( 0, invsqrt( extract_elem( in, 0 )));
            out = insert( out, 1, invsqrt( extract_elem( in, 1 )));
            tmp = insert( tmp, i, out );
        }

        v16accfloat invsqrt_var_val = v16accfloat(tmp);
        v16accfloat invsqrt_var_val2 = invsqrt_var_val;
        // Scale and bias pointers are reset for every row block, so each row block
        // reuses the same Nsubv / 8 parameter vectors.
        v64bfloat16* restrict p_scale = (v64bfloat16*) (parameters + gamma_offset);
        v64bfloat16* restrict p_bias = (v64bfloat16*) (parameters + beta_offset);

        // operand_1 = 1 / sqrt(Var[x] + epsilon)
        // operand_2 = -(E[x] * mean_factor) / sqrt(Var[x] + epsilon)
        v32bfloat16 operand_1 = to_v32bfloat16(concat(invsqrt_var_val2, invsqrt_var_val));
        v32bfloat16 negative_mu = to_v32bfloat16(negmul_elem_32(mean_bf16, one_32_lane));
        v32bfloat16 operand_2 = to_v32bfloat16(mul_elem_32(operand_1, negative_mu));
        // Expand the 8 per-row values into 64 lanes (two groups of 4 rows, each
        // broadcast and shuffled with T16_8x4).
        // NOTE(review): presumably this replicates each row value across the 8
        // columns of the 8x8 tile -- confirm the T16_8x4 shuffle pattern.
        mac_op_1_values[0] = shuffle( broadcast_to_v32bfloat16( extract_v4bfloat16( operand_1, 0 )), T16_8x4 );
        mac_op_1_values[1] = shuffle( broadcast_to_v32bfloat16( extract_v4bfloat16( operand_1, 1 )), T16_8x4 );

        mac_op_2_values[0] = shuffle( broadcast_to_v32bfloat16( extract_v4bfloat16( operand_2, 0 )), T16_8x4 );
        mac_op_2_values[1] = shuffle( broadcast_to_v32bfloat16( extract_v4bfloat16( operand_2, 1 )), T16_8x4 );

        v64bfloat16 mac_op_1_st3 = concat(mac_op_1_values[0], mac_op_1_values[1]);
        v64bfloat16 mac_op_2_st3 = concat(mac_op_2_values[0], mac_op_2_values[1]);

        // Second pass over the row block: apply scale and bias and store the result.
        for (int block_8x8 = 0; block_8x8 < (Nsubv / WIDTH_ALIGNMENT); block_8x8++)
        chess_prepare_for_pipelining
        {
            v64bfloat16 scale = *p_scale++;
            v64accfloat bias  = to_v64accfloat(*p_bias++);

            // gamma / sqrt(Var[x] + epsilon)
            v64bfloat16 scaled_op_1 = to_v64bfloat16(mul_elem_64(scale, mac_op_1_st3));

            // beta - E[x] * gamma / sqrt(Var[x] + epsilon)
            v64accfloat scaled_b_op_1 = mac_elem_64(scale, mac_op_2_st3, bias);

            // y = x * scaled_op_1 + scaled_b_op_1
            aie::vector<bfloat16, num_input_lanes> res = to_v64bfloat16(mac_elem_64(*inIter2++, scaled_op_1, scaled_b_op_1));
            *outIter++ = res;
        }
    }
}

/**
 * Kernel entry point: optional dequantize -> layer norm -> optional quantize.
 *
 * Reads its configuration from the 16-bit words at args.params_data:
 *   [0] Nlrn          normalization length
 *   [1] Nsubv         columns held by this core
 *   [2] split_type    0 -> row split, 1 -> col split
 *   [3] colIdx        column index forwarded to the layer-norm kernel
 *   [4] bias_addr     address of the scale/bias parameter buffer
 *   [5] gamma_offset  byte offset of gamma inside that buffer
 *   [6] beta_offset   byte offset of beta inside that buffer
 *   [7] qdq_addr      address of the quantize/dequantize parameter block
 *   [8] sign_type     IFM_SIGNED / OFM_SIGNED / IFM_16BITS / OFM_16BITS flags
 *   [9] simplified    forwarded to layer_norm_col_split
 *
 * QDQ parameter block layout (byte offsets from qdq_addr):
 *    0 s_out (bfloat16)    4 zp_out (uint16_t)    8 quant_en (bool)
 *   12 s_in  (bfloat16)   16 zp_in  (uint16_t)   20 dequant_en (bool)
 *
 * The tile's saturation mode is switched to `saturate` for the duration of the
 * call and restored before returning. The rounding mode set through set_rnd()
 * is left in place, as in the original code.
 *
 * @param args  kernel buffers: s2mm_ch0_data (input), mm2s_ch0_data (output),
 *              params_data (configuration words described above).
 */
void run_lrn_qdq(KernelArgs& args)
{
    set_rnd(rnd_conv_even);
    // Remember the caller's saturation mode so it can be restored on exit.
    const aie::saturation_mode sat = aie::tile::current().get_saturation();
    aie::tile::current().set_saturation(aie::saturation_mode::saturate);
    const int MSUB = 8;   // rows processed per call
    bfloat16 * matA   = static_cast<bfloat16*>(args.s2mm_ch0_data);
    //int8_t * matB   = static_cast<int8_t*>(args.s2mm_ch1_data);
    int8_t * output = static_cast<int8_t*>(args.mm2s_ch0_data);
    // Row index is taken from the low 4 bits of the core id; the column index
    // comes from the parameter words instead of the core id.
    int rowIdx = (get_coreid() & 0xF);
    //int colIdx = (get_coreid() >> 16);
    uint16_t* args_params   = (uint16_t*)args.params_data;
    uint16_t Nlrn           = args_params[0];
    uint16_t Nsubv          = args_params[1];
    uint16_t split_type     = args_params[2]; // 0-> row split ; 1 -> col_split
    uint16_t colIdx         = args_params[3];
    uint16_t bias_addr      = args_params[4];
    uint16_t gamma_offset   = args_params[5];
    uint16_t beta_offset    = args_params[6];
    uint16_t qdq_addr       = args_params[7];
    uint16_t sign_type      = args_params[8];
    uint16_t simplified     = args_params[9];

    bool in_sign      = sign_type & IFM_SIGNED; // true --> IFM Signed
    bool out_sign     = sign_type & OFM_SIGNED; // true --> OFM Signed
    bool is_in_int16  = sign_type & IFM_16BITS; // true --> IFM 16 bits
    bool is_out_int16 = sign_type & OFM_16BITS; // true --> OFM 16 bits

    int8_t * matB   = static_cast<int8_t*>(conv_to_local_ptr(bias_addr));
    int8_t * qdq_offset   = static_cast<int8_t*>(conv_to_local_ptr(qdq_addr));

    // Unpack the QDQ parameter block (see layout in the function comment).
    bfloat16 s_out   = *(bfloat16 *) (qdq_offset);
    uint16_t zp_out  = *(uint16_t *) (qdq_offset + 4);
    bool quant_en    = *(bool *) (qdq_offset + 8);
    bfloat16 s_in    = *(bfloat16 *) (qdq_offset + 12);
    uint16_t zp_in   = *(uint16_t *) (qdq_offset + 16);
    bool dequant_en  = *(bool *) (qdq_offset + 20);

    bfloat16* lrn_in_ptr = (bfloat16 *) matA;
    // Overwrite the lrn input pointer for in-place compute: when dequantization
    // is enabled the layer norm reads from the output buffer, which is the
    // destination passed to dequant_int16_to_bf16() below.
    // NOTE(review): in that case layer_norm_col_split receives the same buffer
    // for its `restrict` input and output pointers -- confirm this is intended.
    if(dequant_en){
        lrn_in_ptr  = (bfloat16 *) output;
    }
    bfloat16* lrn_out_ptr = (bfloat16 *) output;

#if 1
    dequant_int16_to_bf16((int8_t *)matA, (int8_t *)output, MSUB*Nsubv, zp_in, s_in, in_sign, dequant_en, is_in_int16);

    layer_norm_col_split<MSUB>(lrn_in_ptr, matB, lrn_out_ptr, Nlrn, Nsubv, split_type, gamma_offset, beta_offset, rowIdx, colIdx, simplified);

    // Quantization runs in place on the output buffer.
    quant_bf16_to_int16(output, output, MSUB*Nsubv, zp_out, s_out, out_sign, quant_en, is_out_int16);
#else
    // Debug path: plain copy of the input to the output.
    bfloat16 * input_pt = matA;
    bfloat16 * output_pt = (bfloat16 *)output;
    for(int ii = 0; ii < MSUB*Nsubv; ii++)
    {
        output_pt[ii] = input_pt[ii];
    }
#endif

    // Restore the saturation mode that was active on entry; the original code
    // saved it in `sat` but never wrote it back.
    aie::tile::current().set_saturation(sat);
}
#endif

