/*
    Copyright (C) 2019 - 2022 Xilinx, Inc. All rights reserved.
    Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.

    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.

    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.

    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.

    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */
#ifndef __REDUCE_OP_H__
#define __REDUCE_OP_H__

#include "aie_api/aie.hpp"
#include <type_traits>
#include <stdio.h>

// Bit flags packed into the 16-bit `sign_type` kernel parameter; each one is
// tested with a bitwise AND in run_reduce_qdq_impl.
// Set => IFM (input) data is signed.
#define IFM_SIGNED 0x0001
// Set => OFM (output) data is signed.
#define OFM_SIGNED 0x0010
// Set => IFM (input) elements are 16 bits wide.
#define IFM_16BITS 0x0100
// Set => OFM (output) elements are 16 bits wide.
#define OFM_16BITS 0x1000

// template<bool en_c=true, bool en_hw=true>
/**
 * Sum / mean / max reduction over bfloat16 data, processed 64 elements
 * (two 32-lane vectors) per step.
 *
 * For every one of the `outer_loop` outputs, `inner_loop` steps are executed.
 * Each step loads 2*N bfloat16 values from the read pointer, applies the lane
 * mask, folds the values into the running state and then advances the read
 * pointer with add_3d_byte(pI, dimsI).
 *
 * Two output layouts are supported, selected by `axis_flags`:
 *  - axis_flags == 0 : the 64 lanes are additionally reduced across each
 *    other and a single scalar is written to output[o]. `mask_last` is
 *    applied on the last inner step of every outer iteration.
 *  - axis_flags != 0 : lanes stay independent and 2*N values are stored per
 *    outer iteration. `mask_last` is applied on every step of the last outer
 *    iteration.
 *
 * @param input       Source buffer (bfloat16).
 * @param output      Destination buffer (bfloat16).
 * @param Nlrn        Divisor used for the mean (result is scaled by 1/Nlrn);
 *                    ignored for sum and max.
 * @param inner_loop  Number of 64-element steps folded into one output.
 *                    The loops are annotated with min_loop_count(1), so this
 *                    is expected to be >= 1.
 * @param outer_loop  Number of outputs (scalars or 64-lane groups) produced;
 *                    expected to be >= 1 for the same reason.
 * @param mask_last   64-bit lane mask for the partial step. The low 32 bits
 *                    cover the first vector, the high 32 bits the second.
 *                    Lanes selected as "masked off" are replaced by a fill
 *                    value that is neutral for the reduction.
 * @param dims_in     Addressing descriptor; instantiated once and handed to
 *                    add_3d_byte to step through the input.
 * @param axis_flags  Output layout selector, see above.
 * @param c1          Shift passed to accum::to_vector in the per-lane
 *                    sum/mean path.
 * @param rowIdx      Not used by this function.
 * @param colIdx      Not used by this function.
 * @param op_type     0 = sum, 1 = mean, 2 = max. Any value other than 2 takes
 *                    the accumulate path; only 1 applies the 1/Nlrn scale.
 * @param split_type  Not used by this function.
 */
__attribute__((always_inline)) inline void reduce_axis(
    bfloat16 * restrict input,
    bfloat16 * restrict output,
    int Nlrn,
    uint16_t inner_loop,
    uint16_t outer_loop,
    uint64_t mask_last,
    dims_3d_param dims_in,
    uint8_t axis_flags,
    int c1,
    const int rowIdx, const int colIdx,
    uint16_t op_type = 0,    // 0=sum, 1=mean, 2=max
    int split_type = 0      // 0=local reduce (no cross-core), 1=global reduce (cross-core via global_add_reduce)
)
{
    bfloat16 * pI = input;
    bfloat16 * pO = output;
    dims_3d_t dimsI = dims_in.instantiate();
    constexpr unsigned N = 32;
    constexpr uint64_t FULL_MASK = 0xffffffffffffffff;

    bfloat16 neg_inf;
    *(uint16_t*)&neg_inf = 0xFF80;  // bfloat16 representation of -inf

    // Value written into masked-off lanes. It has to be neutral for the
    // reduction: 0 for sum/mean, but -inf for a max that is reduced across
    // lanes -- otherwise a 0 in a padding lane would win over an all-negative
    // input and the result would be wrong. In the per-lane layout the masked
    // lanes only end up in padding output positions, so 0 is kept there to
    // leave the stored padding values unchanged.
    bfloat16 mask_fill = (bfloat16)0;
    if (op_type == 2 && axis_flags == 0) {
        mask_fill = neg_inf;
    }

    // Full mask for every step except the partial one, which uses mask_last.
    // Which loop index marks the partial step depends on the output layout.
    auto get_mask_value = [&](unsigned i, unsigned o) __aie_inline -> uint64_t {
        return (axis_flags == 0) ? 
            ((i < inner_loop - 1) ? FULL_MASK : mask_last) :
            ((o < outer_loop - 1) ? FULL_MASK : mask_last);
    };

    // Loads 2*N elements at the current read pointer and replaces the
    // masked-off lanes with mask_fill.
    auto load_and_mask_vectors = [&](unsigned i, unsigned o) __aie_inline {
        uint64_t mask_val = get_mask_value(i, o);
        uint32_t mask_val1 = (uint32_t)(mask_val & 0xffffffff);
        uint32_t mask_val2 = (uint32_t)(mask_val >> 32);

        auto vec1 = aie::load_v<N>(pI);
        auto vec2 = aie::load_v<N>(pI + N);

        vec1 = aie::select(mask_fill, vec1, aie::mask<32>::from_uint32(mask_val1));
        vec2 = aie::select(mask_fill, vec2, aie::mask<32>::from_uint32(mask_val2));

        return std::make_pair(vec1, vec2);
    };

    // Unified inner loop for all reduction operations: load, fold into the
    // caller-provided state via op_func, then advance the read pointer.
    auto process_inner_loop = [&](auto& state1, auto& state2, auto op_func, unsigned o) __aie_inline {
        [[using chess: min_loop_count( 1 ), no_hw_loop ]]
        for (unsigned i = 0; i < inner_loop; i++) {
            auto [vec1, vec2] = load_and_mask_vectors(i, o);
            op_func(state1, state2, vec1, vec2, i);
            pI = add_3d_byte(pI, dimsI);
        }
    };

    if (op_type == 2) {// Max
        [[using chess: min_loop_count(1), no_hw_loop]]
        for (unsigned o = 0; o < outer_loop; o++) {
            // Start every output from -inf so the first real value always wins.
            v32bfloat16 vec1_max = broadcast_bfloat16(neg_inf);
            v32bfloat16 vec2_max = broadcast_bfloat16(neg_inf);

            process_inner_loop(vec1_max, vec2_max, 
                [](v32bfloat16& m1, v32bfloat16& m2, auto v1, auto v2, unsigned) __aie_inline {
                    m1 = max(m1, v1);
                    m2 = max(m2, v2);
                }, o);

            if (axis_flags == 0) {
                // Collapse both halves, then all lanes, into one scalar.
                vec1_max = max(vec1_max, vec2_max);
                pO[o] = aie::reduce_max(aie::vector<bfloat16, N>(vec1_max));
            } else {
                aie::store_v(pO, aie::vector<bfloat16, N>(vec1_max));
                aie::store_v(pO + N, aie::vector<bfloat16, N>(vec2_max));
                pO += 2 * N;
            }
        }
    } else { // Mean (op_type==1) or Sum (op_type==0)
        // Sum is implemented as a mean with a scale factor of 1.
        bfloat16 inv_n = (op_type == 1) ? (bfloat16)inv(fix2float(Nlrn)) : (bfloat16)1.0;
        
        [[using chess: min_loop_count(1), no_hw_loop]]
        for (unsigned o = 0; o < outer_loop; o++) {
            aie::accum<accfloat, N> acc1, acc2;

            // The accumulators are not initialised; op_zero clears them on the
            // first inner step only, after which zero_acc is dropped to 0.
            int zero_acc = 1;
            process_inner_loop(acc1, acc2,
                [&zero_acc](aie::accum<accfloat, N>& a1, aie::accum<accfloat, N>& a2, auto v1, auto v2, unsigned) __aie_inline {
                    a1 = aie::add(aie::op_zero(a1, zero_acc), v1);
                    a2 = aie::add(aie::op_zero(a2, zero_acc), v2);
                    zero_acc = 0;
                }, o);

            if (axis_flags == 0) {
                // Collapse both halves, then all lanes, and scale the scalar.
                acc1 = aie::add(acc1, acc2);
                pO[o] = (bfloat16)(aie::reduce_add(acc1.to_vector<bfloat16>()) * inv_n);
            } else {
                auto vec1 = acc1.to_vector<bfloat16>(c1);
                auto vec2 = acc2.to_vector<bfloat16>(c1);
                aie::store_v(pO, aie::mul(vec1, inv_n).to_vector<bfloat16>(0));
                aie::store_v(pO + N, aie::mul(vec2, inv_n).to_vector<bfloat16>(0));
                pO += 2 * N;
            }
        }
    }
}

/**
 * Kernel entry point: dequantize -> reduce_axis -> quantize.
 *
 * The input buffer (args.s2mm_ch0_data) is dequantized in place, reduced into
 * the output buffer (args.mm2s_ch0_data) as bfloat16, and the output buffer is
 * then quantized in place. Whether the (de)quantization steps actually convert
 * anything is decided inside dequant_int16_to_bf16 / quant_bf16_to_int16 from
 * the enable flags passed to them (helpers defined elsewhere -- exact
 * semantics to be confirmed there).
 *
 * The rounding mode is set to rnd_conv_even and left that way. The tile
 * saturation mode is switched to `saturate` for the duration of the call and
 * restored to its previous value before returning.
 *
 * Layout of args.params_data, read as uint16_t[]:
 *   [0]  Nlrn         divisor for the mean
 *   [1]  (Nsubv)      not used here
 *   [2]  split_type   0 -> row split, 1 -> col split (forwarded to reduce_axis)
 *   [3]  colIdx       forwarded to reduce_axis
 *   [4]  (bias_addr)  not used here
 *   [5]  mode_index   0 -> sum, 1 -> mean, 2 -> max
 *   [6]  (beta_offset) not used here
 *   [7]  qdq_addr     address of the quant/dequant parameter block
 *   [8]  sign_type    IFM_SIGNED / OFM_SIGNED / IFM_16BITS / OFM_16BITS flags
 *   [9]  dim_type     Reduce_C = 0, Reduce_W = 1, Reduce_H = 2, Reduce_N = 3,
 *                     Reduce_HW = 4, Reduce_WC = 5, Reduce_HC = 6,
 *                     Reduce_HWC = 7, Reduce_NHWC = 8
 *   [10] inner_count  steps per output
 *   [11] outer_count  number of outputs
 *   [12..14] step0..step2   address increments (reinterpreted as signed 16-bit)
 *   [15..16] wrap0, wrap1   wrap counts (stored minus one in dims_3d_param)
 *   [17..20] mask_part3..mask_part0  64-bit last-step lane mask, most
 *                     significant 16 bits first
 *
 * Layout of the quant/dequant block at qdq_addr (byte offsets):
 *   +0 s_out (bfloat16), +4 zp_out (uint16_t), +8 quant_en (bool),
 *   +12 s_in (bfloat16), +16 zp_in (uint16_t), +20 dequant_en (bool)
 *
 * @param args  Kernel arguments providing the input, output and parameter
 *              buffers.
 */
__attribute__((always_inline)) inline void run_reduce_qdq_impl(KernelArgs& args)
{
    set_rnd(rnd_conv_even);
    // Remember the caller's saturation mode so it can be restored on exit.
    const aie::saturation_mode sat = aie::tile::current().get_saturation();
    aie::tile::current().set_saturation(aie::saturation_mode::saturate);
    bfloat16 * input   = static_cast<bfloat16*>(args.s2mm_ch0_data);
    int8_t * output = static_cast<int8_t*>(args.mm2s_ch0_data);
    int rowIdx = (get_coreid() & 0xF);
    //int colIdx = (get_coreid() >> 16);
    uint16_t* args_params   = (uint16_t*)args.params_data;
    uint16_t Nlrn           = args_params[0];
    uint16_t split_type     = args_params[2]; // 0-> row split ; 1 -> col_split
    uint16_t colIdx         = args_params[3];
    uint16_t mode_index     = args_params[5]; //0 -> sum, 1-> mean, 2-> max
    uint16_t qdq_addr       = args_params[7];
    uint16_t sign_type      = args_params[8];
    uint16_t dim_type       = args_params[9]; //Reduce_C = 0,Reduce_W = 1,Reduce_H = 2,Reduce_N = 3,Reduce_HW = 4,Reduce_WC = 5,Reduce_HC = 6,Reduce_HWC = 7,Reduce_NHWC =8
    uint16_t inner_count    = args_params[10];
    uint16_t outer_count    = args_params[11];
    uint16_t step0          = args_params[12];
    uint16_t step1          = args_params[13];
    uint16_t step2          = args_params[14];
    uint16_t wrap0          = args_params[15];
    uint16_t wrap1          = args_params[16];
    uint16_t mask_part3     = args_params[17];  
    uint16_t mask_part2     = args_params[18];
    uint16_t mask_part1     = args_params[19];
    uint16_t mask_part0     = args_params[20];

    // Reassemble the 64-bit lane mask from its four 16-bit pieces.
    uint64_t last_cg_mask = (static_cast<uint64_t>(mask_part3) << 48) |
                            (static_cast<uint64_t>(mask_part2) << 32) |
                            (static_cast<uint64_t>(mask_part1) << 16) |
                            static_cast<uint64_t>(mask_part0);

    // reduce_axis consumes 64 elements per inner step.
    int total_num_elems = outer_count * inner_count * 64;//outer_count*Nsubv;

    uint16_t op_type  = mode_index;
    bool in_sign      = sign_type & IFM_SIGNED; // true --> IFM Signed
    bool out_sign     = sign_type & OFM_SIGNED; // true --> OFM Signed
    bool is_in_int16  = sign_type & IFM_16BITS; // true --> IFM 16 bits
    bool is_out_int16 = sign_type & OFM_16BITS; // true --> OFM 16 bits

    int8_t * qdq_offset   = static_cast<int8_t*>(conv_to_local_ptr(qdq_addr));
    
    bfloat16 s_out   = *(bfloat16 *) (qdq_offset);
    uint16_t zp_out  = *(uint16_t *) (qdq_offset + 4);
    bool quant_en    = *(bool *) (qdq_offset + 8);
    bfloat16 s_in    = *(bfloat16 *) (qdq_offset + 12);  
    uint16_t zp_in   = *(uint16_t *) (qdq_offset + 16);
    bool dequant_en  = *(bool *) (qdq_offset + 20);

    bfloat16* lrn_in_ptr = (bfloat16 *) input;
    bfloat16* lrn_out_ptr = (bfloat16 *) output;

    // if(dequant_en){
    //     lrn_in_ptr = (bfloat16 *) output;
    //     lrn_out_ptr = (bfloat16 *) input;
    // }

#if 1

    // In-place conversion of the input buffer (source == destination).
    dequant_int16_to_bf16((int8_t *)input, (int8_t *)input, total_num_elems, zp_in, s_in, in_sign, dequant_en, is_in_int16);

    // Setup dims_3d_param for memory traversal
    dims_3d_param dims_in;
    dims_in.num0 = wrap0 -1 ;
    dims_in.num1 = wrap1 -1 ;
    // Steps arrive as uint16_t but may be negative: sign-extend via int16_t.
    dims_in.inc0 = (int32_t)(int16_t)step0;
    dims_in.inc1 = (int32_t)(int16_t)step1;
    dims_in.inc2 = (int32_t)(int16_t)step2;

    // Reduce_C, Reduce_WC, Reduce_HWC and Reduce_NHWC use the scalar-per-output
    // layout (axis_flags == 0); every other dim_type keeps lanes separate.
    uint8_t axis_flags = (dim_type==0 || dim_type==5 || dim_type==7 || dim_type==8)? 0:1;
    int c1 = 0;  // Shift parameter for to_vector

    reduce_axis(lrn_in_ptr, lrn_out_ptr, Nlrn,
                            inner_count, outer_count, last_cg_mask,
                            dims_in, axis_flags, c1,
                            rowIdx, colIdx, op_type, split_type);

    // One scalar per output in the cross-lane layout, 64 values otherwise.
    int reduce_output_size = (axis_flags == 0) ? outer_count : outer_count * 64;

    // Never hand fewer than 64 elements to the quantizer
    // (presumably its minimum processing granularity -- TODO confirm).
    if (reduce_output_size < 64) { reduce_output_size = 64;}

    // In-place conversion of the output buffer.
    quant_bf16_to_int16((int8_t*)lrn_out_ptr, output, reduce_output_size, zp_out, s_out, out_sign, quant_en, is_out_int16);

#else
    // Debug bypass: plain copy of the input to the output.
    bfloat16 * input_pt = input;
    bfloat16 * output_pt = (bfloat16 *)output;
    for(int ii = 0; ii < total_num_elems; ii++)
    {
        output_pt[ii] = input_pt[ii];
    }
#endif

    // Put the tile back into the saturation mode the caller was using.
    aie::tile::current().set_saturation(sat);
}

#endif // __REDUCE_OP_H__