/*  (c) Copyright 2019 - 2021 Xilinx, Inc. All rights reserved.
    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.
    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.
    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.
    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */
#ifndef __SOFTMAX_C__
#define __SOFTMAX_C__

#include "qdq/qdq_kernel_helpers.h"
#include "qdq/qdq_int8_bfloat16.hpp"
#include "qdq/qdq_int16_bfloat16.hpp"
#include "nonlinear/softmax_bf16x16/softmax_bf16x16_kernel.c"
#include "nonlinear/global_reduce.h"

// Bit flags packed into the 16-bit `sign_type` kernel parameter
// (read from args_params[14] in run_softmax_qdq and tested with bitwise AND).
// Each flag occupies its own hex nibble.

// Set: input feature map (IFM) is signed.
#define IFM_SIGNED 0x0001
// Set: output feature map (OFM) is signed.
#define OFM_SIGNED 0x0010
// Set: input feature map (IFM) is 16 bits wide.
#define IFM_16BITS 0x0100
// Set: output feature map (OFM) is 16 bits wide.
#define OFM_16BITS 0x1000

/**
 * Overwrites selected lanes of a bfloat16 buffer with `mask_value`, in place.
 *
 * The buffer is processed as consecutive 32-element vectors; `rows / 4`
 * vectors are touched (i.e. `rows * 8` elements when `rows` is a multiple
 * of 4; any remainder of `rows % 4` is not processed).
 *
 * @param ptr         Start of the bfloat16 data to patch (read and written).
 * @param rows        Row count; a value <= 3 (including negative values)
 *                    results in no work.
 * @param mask_value  Value written into the lanes that are masked out.
 * @param mask        32-bit lane mask applied to every vector: lanes whose
 *                    bit is 1 keep the loaded data, lanes whose bit is 0 are
 *                    replaced by `mask_value` (so `mask == 0` overwrites
 *                    everything).
 */
void __attribute__((noinline)) mask_w8_cols(bfloat16* ptr, int rows, bfloat16 mask_value, uint32_t mask) 
{
    // Loop-invariant: the same lane selection is applied to every vector.
    const aie::mask<32> keep_lanes = aie::mask<32>::from_uint32(mask);

    // Signed counter on purpose: with an unsigned index a negative `rows`
    // would be converted to a huge trip count and run far past the buffer.
    const int num_vectors = rows / 4;
    for (int i = 0; i < num_vectors; i++) 
        chess_no_hw_loop
    {
            aie::store_v(ptr, aie::select(mask_value, aie::load_v<32>(ptr), keep_lanes));
            ptr += 32;
    } 
}

/**
 * Runs the (de-quantize) -> mask -> softmax -> (quantize) pipeline on one
 * MSUB x Nsubv sub-volume.
 *
 * The same sequence of calls is executed for every dQ/Q combination; the
 * enable flags are forwarded to the QDQ helpers, which are expected to skip
 * their work when disabled (NOTE(review): confirm in the helper sources).
 *
 * Buffer plan (softmax is always on):
 *
 *                          |  dQ_in  |  dQ_out  | SM_inout | SM_scratch |  Q_in  |  Q_out  |
 *   dQ (int16 in) / no dQ  |  matA   |   matA   |   matA   |   output   | output | output  |
 *   dQ (int8  in)          |  matA   | matA + N | matA + N |   output   | output | output  |
 *
 * where N = MSUB * Nsubv bytes.
 *
 * @param matA                 Input buffer; also used as bf16 working buffer.
 * @param wgt                  Unused.
 * @param output               Softmax scratch buffer and quantization in/out buffer.
 * @param Nlayer               Number of valid columns; columns at index >= Nlayer
 *                             are overwritten with a large negative bf16 value
 *                             before the softmax.
 * @param MSUB                 Row count passed to the masking and softmax kernels.
 * @param Nsubv                Column count of the (padded) sub-volume.
 * @param multi_core_sm        Forwarded to softmax_bf16x16.
 * @param colIdx               Forwarded to softmax_bf16x16 as `colIdx % 4`.
 * @param rowIdx               Forwarded to softmax_bf16x16 as `rowIdx + 2`.
 * @param dequant_en           Enables de-quantization; also selects the buffer plan.
 * @param s_in                 De-quantization scale.
 * @param zp_in                De-quantization zero point.
 * @param quant_en             Enables quantization.
 * @param s_out                Quantization scale.
 * @param zp_out               Quantization zero point.
 * @param softmax_scalefactor  Forwarded to softmax_bf16x16. Per the original
 *                             note, the kernel is 2^x based and a factor of
 *                             1/ln(2) = 1.4426950408889 makes it compute e^x.
 * @param in_sign              Input is signed (forwarded to de-quantization).
 * @param out_sign             Output is signed (forwarded to quantization).
 * @param is_in_int16          Input is 16 bits wide (otherwise 8 bits).
 * @param is_out_int16         Output is 16 bits wide (forwarded to quantization).
 */
void softmax(int8_t* matA, int8_t* wgt, int8_t* output, 
             uint16_t Nlayer, uint16_t MSUB, uint16_t Nsubv, bool multi_core_sm, 
             uint16_t colIdx, uint16_t rowIdx,
             bool dequant_en, bfloat16 s_in, uint16_t zp_in,
             bool quant_en, bfloat16 s_out, uint16_t zp_out,
             float softmax_scalefactor, bool in_sign=0, bool out_sign=0,
             bool is_in_int16=0, bool is_out_int16=0) 
{
    int8_t* pDQ_in      = matA;
    int8_t* pDQ_out     = matA;
    int8_t* pSM_inout   = matA;
    int8_t* pSM_scratch = output;
    int8_t* pQ_in       = output;
    int8_t* pQ_out      = output;

    // 8-bit input: the bf16 result is written MSUB*Nsubv bytes past the start
    // of matA instead of on top of the input (presumably because each bf16
    // output element is wider than the int8 element it is produced from).
    if (dequant_en && !is_in_int16) {
        pDQ_out   = matA + MSUB*Nsubv;
        pSM_inout = matA + MSUB*Nsubv;
    }

#if 0 // 1 : for debug print only
    int rowIdx_reg = (get_coreid() & 0xF);
    int colIdx_reg = (get_coreid() >> 16);
    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    // Debugging Purpose ONLY !! :
    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    uint16_t s_out_fix = *(uint16_t *) (qdq_offset);
    uint16_t s_in_fix  = *(uint16_t *) (qdq_offset + 12);
    if(rowIdx_reg == 2 && colIdx_reg == 0)
    {
        printf("Nlayer,Nsubv : %d %d \n", Nlayer, Nsubv);
        printf("s_out_fix  : %d \n", s_out_fix);
        printf("zp_out     : %d \n", zp_out);
        printf("quant_en   : %d \n", (quant_en)? 1 : 0);
        printf("s_in_fix   : %d \n", s_in_fix );
        printf("zp_in      : %d \n", zp_in);
        printf("dequant_en : %d \n", (dequant_en)? 1 : 0);
    }
    /*
    uint16_t* matIn = (uint16_t*)matA;
    uint32_t* matIn32 = (uint32_t*)matA;   
    if(rowIdx_reg == 2 && colIdx_reg == 0)
    {
        for(int i = 0; i < MSUB * Nsubv; i++)
        {
            printf("%5d ", matIn[i]);
            if(i % 8 == 7)
                printf("\n");
        }
    }
    */
#endif

    set_rnd_wrapper();

    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    // De-Quantization:  Input (i8 / i16) ----> Output (bf16), Output @ pDQ_out
    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    dequant_int16_to_bf16(pDQ_in, pDQ_out, MSUB*Nsubv, zp_in, s_in, in_sign, dequant_en, is_in_int16);

    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    // Masking (bf16): every column at index >= Nlayer is overwritten with mask_val.
    // Offsets below are in bytes (2 bytes per bf16) and step in units of
    // MSUB rows x 8 columns.
    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    const int  Nlayer_mod_8    = (Nlayer & 0x07);
    const int  col_offset      = 8 * (Nlayer / 8);   // Nlayer rounded down to a multiple of 8
    const bool is_Nlayer_mul_8 = (Nlayer_mod_8 == 0);
    const bfloat16 mask_val =-3.3895314e38;
    if (!is_Nlayer_mul_8)  //NOTE make this function generic for other odd shapes
    {
        // Low Nlayer_mod_8 bits select the columns to keep inside one group of
        // 8 lanes; multiplying by 0x01010101 replicates that byte into all
        // four bytes of the 32-lane mask.
        const uint32_t mask = ((1u << Nlayer_mod_8) - 1u) * 0x01010101u;
        mask_w8_cols(reinterpret_cast<bfloat16*>(pSM_inout + (2 * MSUB * col_offset ) ), MSUB, mask_val, mask);
    }

    // Remaining whole 8-column blocks after the (possibly partial) last valid
    // block are masked completely (mask == 0).
    const int msk_res_blks = (Nsubv - Nlayer) / 8;
    // Guard: when Nlayer exceeds Nsubv by 8 or more the count is negative and
    // must not be handed to the masking loop as a row count.
    if (msk_res_blks > 0)
    {
        mask_w8_cols(reinterpret_cast<bfloat16*>(pSM_inout + (2 * MSUB * (col_offset + 8*(!is_Nlayer_mul_8)) ) ), MSUB * msk_res_blks, mask_val, 0);
    }

    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    // Perform Softmax: SoftMax Input (bf16) ----> SoftMax Output (bf16), described as in-place at pSM_inout
    // NOTE(review): the quantization step below reads from `output` (pQ_in), not from pSM_inout.
    // Confirm where softmax_bf16x16 leaves its final result.
    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    softmax_bf16x16(pSM_inout, pSM_scratch, MSUB, Nsubv, colIdx % 4, rowIdx + 2, multi_core_sm, softmax_scalefactor);

    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    // Debugging Purpose ONLY !! (DO NOT REMOVE) : Copy Softmax result in pSM_inout into output 
    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    /*
    if(0)
    {
        bfloat16 * input_pt  = (bfloat16*)pSM_inout;
        bfloat16 * output_pt = (bfloat16*)output;
        for(int ii = 0; ii < MSUB*Nsubv; ii++)
        {
            output_pt[ii] = input_pt[ii];
        }
    }
    */

    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    // Quantization:  Input (bf16) ----> Output (i8 / i16), Output @ pQ_out
    ////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
    quant_bf16_to_int16(pQ_in, pQ_out, MSUB*Nsubv, zp_out, s_out, out_sign, quant_en, is_out_int16);

    // Disabled alternative: copy the bf16 softmax result to `output` when quantization is off.
    /*bfloat16 * input_pt = (bfloat16 *)pSM_inout; //matA;
    bfloat16 * output_pt = (bfloat16 *)output;
    if(!quant_en)
    {
        for(int ii = 0; ii < MSUB*Nsubv; ii++)
        {
            output_pt[ii] = input_pt[ii];
        }
    }*/
}

/**
 * Kernel entry point: unpacks the parameter block in `args` and calls softmax().
 *
 * The parameter block is read as an array of uint16_t words. Slot 10 selects
 * the buffer source: when non-zero (MHA / fused), the input and output buffers
 * are obtained from the addresses in slots 8 and 9 via conv_to_local_ptr();
 * otherwise (standalone) the s2mm/mm2s channel-0 buffers in `args` are used.
 * Slots 6 and 7 hold the addresses of the quantization and de-quantization
 * records, each read as: zero point (uint16_t @ +0), scale (bfloat16 @ +4),
 * enable flag (bool @ +8).
 */
void run_softmax_qdq(KernelArgs& args) 
{
    // Word indices into the uint16_t parameter block (slots 11 and 13 are not
    // named: 12..13 together hold one float).
    enum ParamSlot {
        SLOT_NLAYER      = 0,
        SLOT_MSUB        = 1,
        SLOT_NSUBV       = 2,
        SLOT_SPLIT_TYPE  = 3,
        SLOT_COL_IDX     = 4,
        SLOT_ROW_IDX     = 5,
        SLOT_Q_ADDR      = 6,
        SLOT_DQ_ADDR     = 7,
        SLOT_MATA_ADDR   = 8,
        SLOT_OUT_ADDR    = 9,
        SLOT_FUSE_MODE   = 10,  // 1 : MHA, 0 : Standalone
        SLOT_SCALEFACTOR = 12,
        SLOT_SIGN_TYPE   = 14
    };

    uint16_t* params = (uint16_t*)args.params_data;

    const uint16_t n_layer    = params[SLOT_NLAYER];
    const uint16_t m_sub      = params[SLOT_MSUB];
    const uint16_t n_subv     = params[SLOT_NSUBV];
    const uint16_t split_type = params[SLOT_SPLIT_TYPE];
    const uint16_t col_idx    = params[SLOT_COL_IDX];
    const uint16_t row_idx    = params[SLOT_ROW_IDX];
    const uint16_t flags      = params[SLOT_SIGN_TYPE];

    // The softmax scale factor is stored as a float spanning two parameter words.
    const float scale = *reinterpret_cast<float*>(&params[SLOT_SCALEFACTOR]);

    // Decode the signedness / bit-width flags.
    const bool ifm_signed  = (flags & IFM_SIGNED) != 0;
    const bool ofm_signed  = (flags & OFM_SIGNED) != 0;
    const bool ifm_is_16b  = (flags & IFM_16BITS) != 0;
    const bool ofm_is_16b  = (flags & OFM_16BITS) != 0;

    // Default to the channel buffers; fused mode redirects to explicit addresses.
    int8_t* in_buf  = static_cast<int8_t*>(args.s2mm_ch0_data);
    int8_t* out_buf = static_cast<int8_t*>(args.mm2s_ch0_data);
    if (params[SLOT_FUSE_MODE] != 0)
    {
        in_buf  = static_cast<int8_t*>(conv_to_local_ptr(params[SLOT_MATA_ADDR]));
        out_buf = static_cast<int8_t*>(conv_to_local_ptr(params[SLOT_OUT_ADDR]));
    }

    int8_t* q_rec  = static_cast<int8_t*>(conv_to_local_ptr(params[SLOT_Q_ADDR]));
    int8_t* dq_rec = static_cast<int8_t*>(conv_to_local_ptr(params[SLOT_DQ_ADDR]));

    // De-quantization record.
    const uint16_t dq_zero_point = *reinterpret_cast<uint16_t*>(dq_rec);
    const bfloat16 dq_scale      = *reinterpret_cast<bfloat16*>(dq_rec + 4);
    const bool     dq_enabled    = *reinterpret_cast<bool*>(dq_rec + 8);

    // Quantization record.
    const uint16_t q_zero_point  = *reinterpret_cast<uint16_t*>(q_rec);
    const bfloat16 q_scale       = *reinterpret_cast<bfloat16*>(q_rec + 4);
    const bool     q_enabled     = *reinterpret_cast<bool*>(q_rec + 8);

    // No weight buffer is used by this kernel; split_type doubles as the
    // multi-core flag (non-zero -> true).
    softmax(in_buf, nullptr, out_buf,
            n_layer, m_sub, n_subv, split_type,
            col_idx, row_idx,
            dq_enabled, dq_scale, dq_zero_point,
            q_enabled, q_scale, q_zero_point,
            scale, ifm_signed, ofm_signed,
            ifm_is_16b, ofm_is_16b);
}
#endif
