#ifndef __KERNEL_SOFTMAX_FP16X16_HPP__
#define __KERNEL_SOFTMAX_FP16X16_HPP__

#include "aie_api/aie.hpp"
#include "aie_api/utils.hpp"
#include "adf.h"
#include "stdint.h"
#include "common.hh"
//#include "ml_params.h"
//#include "kernel_helpers.h"
#include "access_helpers.hpp"
#include <stdio.h>

/// Parameter block handed to softmax_fp16x16() (taken there by const reference).
///
/// NOTE(review): this header does not document the individual fields; the
/// hints below are inferred from the member names only and must be confirmed
/// against the kernel implementation. Member order and widths are exactly as
/// before — presumably the struct is populated outside the kernel, so do not
/// reorder without checking the producer.
struct KernelSoftmax_fp16x16Param
{
    // 8-bit values followed by a 16-bit one; presumably loop/group counts — TODO confirm.
    uint8_t  Co_g, X_g;
    uint16_t outer_g;

    // 16-bit step values, paired by the Ci/Co and Yi/Yo suffixes;
    // presumably pointer strides — TODO confirm units (bytes vs. elements).
    uint16_t step_Ci, step_Co;
    uint16_t step_Yi, step_Yo;

    // 2-D dimension descriptors (dims_2d_param is declared elsewhere).
    // Presumably I/O/M = input/output/mask and _il/_ol = inner/outer loop — TODO confirm.
    dims_2d_param dimsI_il, dimsI_ol;
    dims_2d_param dimsO_il, dimsO_ol;
    dims_2d_param dimsM_il, dimsM_ol;
};

/// @brief Declaration of the softmax_fp16x16 kernel entry point (presumably a
///        softmax over 16-bit floating-point data, going by the name — the
///        definition is not part of this header section).
///
/// The requires-clause restricts the template arguments:
///   - Ti must be float16 or bfloat16.
///   - To must be float16, bfloat16, or float8.
///
/// NOTE(review): std::is_same_v lives in <type_traits>, which this header does
/// not include directly; it is presumably pulled in transitively by one of the
/// includes above — confirm, or include it explicitly.
/// NOTE(review): any redeclaration/definition of this template must repeat the
/// same signature and requires-clause, so keep the two in sync when editing.
///
/// @tparam Ti     Element type read through @p input.
/// @tparam To     Element type written through @p output.
/// @param input   Pointer to the input elements (not const-qualified here).
/// @param mask    Pointer to int mask data; how it is applied is not visible
///                here — presumably an attention-style mask, TODO confirm.
/// @param output  Destination pointer; restrict-qualified, i.e. the caller
///                promises it does not alias the other pointers.
/// @param params  Kernel parameter block, passed by const reference.
template<typename Ti, typename To>
requires(( std::is_same_v<Ti, float16> || std::is_same_v<Ti, bfloat16> ) && ( std::is_same_v<To, float16> || std::is_same_v<To, bfloat16> || std::is_same_v<To, float8> ))
ALWAYS_INLINE void softmax_fp16x16(
    Ti * input,
    int * mask,
    To * restrict output,
    const KernelSoftmax_fp16x16Param &params
);

/// Layer-level parameter record for the softmax kernel: a run of 32-bit
/// fields followed by the embedded KernelSoftmax_fp16x16Param.
///
/// NOTE(review): field meanings are not documented in this header; the hints
/// below are inferred from names only. Member order is unchanged — presumably
/// this layout is shared with whatever produces the record, so confirm before
/// reordering or resizing anything.
struct softmax_layer_param {
    // Buffer locations stored as 32-bit values (presumably local-memory
    // addresses or offsets — TODO confirm).
    uint32_t input_addr, mask_addr, output_addr;
    uint32_t qdq_param_addr, dq_buffer_addr, q_buffer_addr;

    // Presumably the un-padded column count — TODO confirm.
    uint32_t true_num_cols;

    // Sub-volume shape and element count (presumably num_elem_subv relates to
    // Msubv * Nsubv — verify against the caller).
    uint32_t Msubv, Nsubv, num_elem_subv;

    // Size of the mask data in bytes, going by the name.
    uint32_t msk_num_bytes;

    // Presumably signedness flags for the quantized input (A) and output (O).
    uint32_t sign_A, sign_O;

    // Parameter block forwarded to softmax_fp16x16().
    KernelSoftmax_fp16x16Param krn_param;
};

#endif