/*
    Copyright (C) 2019 - 2022 Xilinx, Inc. All rights reserved.
    Copyright (C) 2022 - 2026 Advanced Micro Devices, Inc. All rights reserved.
    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.
    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.
    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.
    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */
#ifndef __MATADD_KERNEL_WRAPPER_C__
#define __MATADD_KERNEL_WRAPPER_C__

#include "kernel_helpers.h"
#include "qdq/qdq_int16_bfloat16.hpp"
#include "qdq/qdq_int8_bfloat16.hpp"
#include "qdq/qdq_kernel_helpers.h"

/*
 * LayerParam::op_select is a uint16 bit field (bit numbers are 0-indexed):
 *   bits [1:0] (OP_VAL_MASK)            - op type: 0 = Add, 1 = Mul, 2 = Sub, 3 = Div
 *   bit  3     (BCAST_TOGGLE)           - swap the roles of the two operands
 *                                         (the kernels swap matA/matB when set)
 *   bit  4     (BCAST_MASK)             - broadcast                       0b0001 00xx
 *   bit  5     (BCAST_SINGLE_ELEM_MASK) - with bit 4: single-element
 *                                         (innermost) broadcast           0b0011 00xx
 *   bit  8     (CASCADE_ADD)            - cascading / fusion              0b0001 xxxx 00xx
 */
// Op codes (value of op_select & OP_VAL_MASK).
#define ELW_ADD 0x00
#define ELW_MUL 0x01
#define ELW_SUB   0x02
#define ELW_DIV   0x03
// Operand-swap flag (bit 3).
#define BCAST_TOGGLE    0x08
// Complete op_select values for row broadcast add / mul (not referenced in this file).
#define BCAST_ADD 0x10
#define BCAST_MUL 0x11

// Complete op_select values for single-element broadcast add / mul (not referenced in this file).
#define BCAST_ADD_SINGLE_ELEM 0x30
#define BCAST_MUL_SINGLE_ELEM 0x31

// Cascade flag (bit 8); not referenced in this file.
#define CASCADE_ADD 0x0100

// Bytes per element: MATADD_A8_NUM_BYTES is used to size matA in the input
// stream; MATADD_NUM_BYTES and DEBUG are not referenced in this file.
#define MATADD_NUM_BYTES 2
#define MATADD_A8_NUM_BYTES 1
#define DEBUG 0

// Masks for extracting the op_select fields described above.
#define BCAST_MASK 0x10
#define BCAST_SINGLE_ELEM_MASK 0x20
#define OP_VAL_MASK 0x3

/*
 * LayerParam::data_type packs one 4-bit dtype code per tensor; the code
 * encodes both storage precision (8/16-bit) and signedness:
 *   bits [3:0]  (ADD_MAT_A_DTYPE) - Mat A
 *   bits [7:4]  (ADD_MAT_B_DTYPE) - Mat B
 *   bits [11:8] (ADD_MAT_C_DTYPE) - Mat C (output)
 * Codes (BINARY_OPS_DTYPE_*): 0 = uint16, 1 = uint8, 2 = int16, 3 = int8.
 */

// dtype codes stored in each 4-bit field.
#define BINARY_OPS_DTYPE_UINT16 0
#define BINARY_OPS_DTYPE_UINT8 1
#define BINARY_OPS_DTYPE_INT16 2
#define BINARY_OPS_DTYPE_INT8 3

// Field masks within data_type (shift right by 0 / 4 / 8 after masking).
#define ADD_MAT_A_DTYPE 0xf
#define ADD_MAT_B_DTYPE 0xf0
#define ADD_MAT_C_DTYPE 0xf00

namespace matadd_qdq {
  // Per-layer parameters, read from args.params_data by run_matadd_qdq.
  // The field order defines the in-memory layout shared with whoever writes
  // this block, so do not reorder or insert fields.
  struct LayerParam {
    // Op code in bits [1:0] plus BCAST_TOGGLE / BCAST_MASK /
    // BCAST_SINGLE_ELEM_MASK flags (see the op_select macros).
    uint16_t op_select;
    // Sub-volume rows; Msubv * Nsubv is the element count processed.
    uint16_t Msubv;
    // Sub-volume columns; the bf16 kernels consume rows in 32-element vectors.
    uint16_t Nsubv;
    // Address of the QdqParams block, converted with conv_to_local_ptr.
    uint16_t qdq_addr;
    // Address of the broadcast operand (used as matA when BCAST_MASK is set).
    uint16_t tdm1_addr;
    // Not read in this file.
    uint16_t tdm2_addr;
    // Not read in this file.
    uint16_t fused_op_flag;
    // Not read in this file.
    uint16_t do_neg;
    // matA dequantization is only enabled when this is 0.
    uint16_t itr_stage;
    // Packed dtype codes: bits [3:0] A, [7:4] B, [11:8] C.
    uint16_t data_type;
    // Number of elements dequantized from matA.
    uint16_t matA_elems;
    // Address of matA's bf16 scratch buffer; 0xffff = dequantize in place.
    uint16_t scratchA;
    // Address of matB's bf16 scratch buffer; 0xffff = dequantize in place.
    uint16_t scratchB;
  };

  // Quantization parameters for one tensor.
  struct Qdq {
    uint16_t zero_point;
    bfloat16 scale;
    uint16_t enable;
  };

  // Full QDQ parameter block pointed to by LayerParam::qdq_addr.
  // run_matadd_qdq reinterprets this struct as an array of Qdq records:
  // record 0 aliases the matA_* fields and record 1 the matB_* fields. This
  // relies on each (zero_point, scale, enable) triple having the same layout
  // as Qdq (NOTE(review): assumes no extra padding differences - confirm).
  // The out_* fields are read directly for the final quantization; the
  // sin_* / cos_* fields are not read in this file.
  struct QdqParams {
    uint16_t matA_zero_point;
    bfloat16 matA_scale;
    uint16_t dq_A_enable;
    uint16_t matB_zero_point;
    bfloat16 matB_scale;
    uint16_t dq_B_enable;
    uint16_t out_zero_point;
    bfloat16 out_scale;
    uint16_t q_out_enable;
    uint16_t sin_zero_point;
    bfloat16 sin_scale;
    uint16_t cos_zero_point;
    bfloat16 cos_scale;
    uint16_t sin_enable;
    uint16_t cos_enable;
  };

  // Count-prefixed list of Qdq records; not referenced in this file.
  // Note: `qdq[]` is a flexible array member, which standard C++ does not
  // allow - it relies on a compiler extension.
  struct QdqWrapper {
    uint16_t nItems;
    Qdq qdq[];
  };
}

/*
 * Element-wise bf16 binary operation:
 *   ADD: C = A + B      SUB: C = A - B      MUL / DIV: C = A * B
 * For DIV this kernel only multiplies; run_matadd_qdq's comments say the
 * division-specific work happens in elw_inv_qdq (presumably the divisor is
 * already inverted - TODO confirm).
 *
 * matA, matB, output - bf16 buffers, processed 32 lanes at a time; only
 *                      (matadd_num_elems / 32) * 32 elements are processed.
 * op_select          - bits [1:0] pick the op; BCAST_TOGGLE swaps which
 *                      buffer plays A and which plays B.
 * Msubv, Nsubv       - unused; present so the signature matches
 *                      bf16_kernel_fn / broadcast_bf16_bf16_bf16.
 */
void __attribute__((noinline)) elewise_bf16_bf16_bf16
(
    int8_t* matA,
    int8_t* matB,
    int8_t* output,
    uint16_t op_select,
    uint16_t matadd_num_elems,
    uint16_t Msubv,
    uint16_t Nsubv
) {
  bool swap_operands = (op_select & BCAST_TOGGLE) != 0;

  v32bfloat16 * restrict lhs = (v32bfloat16* restrict) (swap_operands ? matB : matA);
  v32bfloat16 * restrict rhs = (v32bfloat16* restrict) (swap_operands ? matA : matB);
  v32bfloat16 * restrict dst = (v32bfloat16* restrict) output;

  uint16_t op_code = op_select & OP_VAL_MASK;

  // ADD/SUB: rhs is scaled by +1 / -1 and accumulated onto lhs.
  // MUL/DIV: the accumulator input is suppressed and rhs is multiplied by lhs.
  bfloat16 sign            = (op_code == ELW_SUB) ? -1.0 : 1.0;
  aie::mask<32> use_sign   = aie::mask<32>(op_code == ELW_ADD || op_code == ELW_SUB);
  int is_multiplicative    = int(op_code == ELW_MUL || op_code == ELW_DIV);

  int num_vectors = matadd_num_elems / 32;

  // NOTE(review): an older comment here claimed the minimum element count is
  // 64 (2 iterations), but min_loop_count( 8 ) promises at least 8 iterations
  // (256 elements). If smaller tiles can reach this kernel the hint is
  // wrong - confirm.
  [[using chess: prepare_for_pipelining, min_loop_count( 8 )]]
  for (int n = 0; n < num_vectors; n++)
  {
    aie::vector<bfloat16,32> a = *lhs++;
    aie::vector<bfloat16,32> b = *rhs++;

    // Lanes take `sign` where use_sign is set (ADD/SUB), otherwise `a` (MUL/DIV).
    aie::vector<bfloat16,32> multiplier = aie::select( a, bfloat16( sign ), use_sign );

    // b * multiplier, accumulated onto `a`, or onto zero for MUL/DIV
    // (4th argument presumably the zero-accumulator flag - confirm in intrinsic docs).
    *dst++ = to_v32bfloat16( mac_elem_32_conf( b, multiplier, aie::accum<accfloat,32>( a ), is_multiplicative, 0, 0 ));
  }
}

/*
 * Broadcast bf16 binary operation over an Msubv x Nsubv tile (rows are
 * processed as Nsubv / 32 vectors of 32 lanes).
 *   ADD: C = A + B      SUB: C = A - B      MUL / DIV: C = A * B
 *
 * The buffer `matA` always holds the broadcast operand and `matB` the full
 * operand; BCAST_TOGGLE chooses which side of the op the broadcast operand
 * occupies:
 *   toggle == 0: A = broadcast (matA), B = full (matB)
 *   toggle == 1: A = full (matB),      B = broadcast (matA)
 * Broadcast modes (op_select):
 *   BCAST_MASK only            - row broadcast: row 0 of the broadcast
 *                                operand is reused for every row i.
 *   + BCAST_SINGLE_ELEM_MASK   - single-element broadcast: row i uses the
 *                                scalar at bf16 index i * 8 of `matA`
 *                                (stride of 8 presumably due to padding - confirm).
 * matadd_num_elems is unused here; it is kept for the bf16_kernel_fn signature.
 */
void __attribute__((noinline)) broadcast_bf16_bf16_bf16
(
    int8_t* matA,
    int8_t* matB,
    int8_t* output,
    uint16_t op_select,
    uint16_t matadd_num_elems,
    uint16_t Msubv,
    uint16_t Nsubv
) {
  // Scalar view of the broadcast operand (used for single-element broadcast).
  bfloat16*    matA_bf16 = reinterpret_cast<bfloat16*>(matA);
  v32bfloat16 * restrict matA_v    = (v32bfloat16* restrict) matA;
  v32bfloat16 * restrict matB_v    = (v32bfloat16* restrict) matB;
  bool is_toggle         = (op_select & BCAST_TOGGLE) >> 3;
  uint16_t op            = (op_select & OP_VAL_MASK);
  if(is_toggle)
  {
    matA_v    = (v32bfloat16*) matB;
    matB_v    = (v32bfloat16*) matA;
  }
  v32bfloat16 * restrict matC_v    = (v32bfloat16* restrict) output;
  bool single_elem       = (op_select & BCAST_SINGLE_ELEM_MASK) >> 5;
  bool is_broadcast      = (op_select & BCAST_MASK) >> 4;
  // Row strides in 32-lane vectors; a stride of 0 re-reads row 0 for every i
  // (that side is the broadcast operand).
  int idxA               = (is_broadcast && !is_toggle) ? 0 : Nsubv / 32;
  int idxB               = (is_broadcast && is_toggle)  ? 0 : Nsubv / 32;
  // SUB must scale B by -1 (as elewise_bf16_bf16_bf16 does); previously this
  // was hard-coded to 1.0, so broadcast SUB silently computed A + B.
  bfloat16 factor        = op == ELW_SUB ? -1.0 : 1.0;

  // Loop-invariant masks / flags.
  aie::mask<32> scalar_into_A = aie::mask<32>( single_elem && !is_toggle );
  aie::mask<32> scalar_into_B = aie::mask<32>( single_elem && is_toggle );
  aie::mask<32> mask_add      = aie::mask<32>( op == ELW_ADD || op == ELW_SUB );
  int mask_mul                = int( op == ELW_MUL || op == ELW_DIV );

  for(int i = 0; i < Msubv; i++)
  chess_no_hw_loop
  {
      aie::vector<bfloat16,32> va1;
      aie::vector<bfloat16,32> va2;
      // Per-row scalar; only selected when single_elem is set.
      va2 = aie::broadcast(matA_bf16[i * 8]);
      for(int j = 0; j < Nsubv / 32; j++)
      chess_no_hw_loop
      {
          va1 = matA_v[j + (idxA * i)];
          aie::vector<bfloat16,32> va0   = aie::select( va1, va2, scalar_into_A );
          aie::vector<bfloat16,32> vb0   = matB_v[j + (idxB * i)];
          // Single-element broadcast with toggle: the scalar is operand B.
          vb0                            = aie::select( vb0, va2, scalar_into_B );
          // ADD/SUB: multiplier = +/-1 ; MUL/DIV: multiplier = A.
          aie::vector<bfloat16,32> op_b  = aie::select( va0, bfloat16( factor ), mask_add );

          // B * op_b + A (accumulator input suppressed for MUL/DIV).
          *matC_v++                      = to_v32bfloat16( mac_elem_32_conf( vb0, op_b, aie::accum<accfloat,32>( va0 ), mask_mul, 0, 0 ));
      }
  }
}

/*
 * Common signature of the bf16 compute kernels (elewise_bf16_bf16_bf16 and
 * broadcast_bf16_bf16_bf16), letting run_matadd_qdq pick one at run time.
 */
using bf16_kernel_fn = void (*)(int8_t *matA, int8_t *matB, int8_t *output,
                                uint16_t op_select, uint16_t matadd_num_elems,
                                uint16_t Msubv, uint16_t Nsubv);

/*
 * Quantized binary op (add / sub / mul / div) on one sub-volume:
 *   1. dequantize matA and matB (int8/int16, signed or unsigned) to bf16,
 *   2. run the bf16 element-wise or broadcast kernel into `output`,
 *   3. quantize `output` back in place.
 *
 * Input layout:
 *   - element-wise: s2mm_ch0_data holds matA immediately followed by matB;
 *   - broadcast (BCAST_MASK): matA (the broadcast operand) is read from the
 *     local address tdm1_addr and s2mm_ch0_data holds matB.
 *
 * QDQ parameter selection: QdqParams is viewed as an array of Qdq records
 * (record 0 = matA params, record 1 = matB params; assumes matching layout).
 * When BCAST_TOGGLE is set the records are swapped: the buffer `matA` (which
 * then holds the logical B operand, the broadcast one) is dequantized with
 * record 1, and the buffer `matB` with record 0.
 *
 * For DIV both dequantizations are disabled here; per the existing design
 * they are performed in elw_inv_qdq instead.
 */
INLINE_DECL void run_matadd_qdq(KernelArgs &args) {
  // Saturating arithmetic with round-to-nearest-even for the conversions below.
  set_sat();
  set_rnd(rnd_conv_even);

  matadd_qdq::LayerParam *lp =
      static_cast<matadd_qdq::LayerParam *>(args.params_data);
  matadd_qdq::QdqParams *qdq_param = static_cast<matadd_qdq::QdqParams *>(
      conv_to_local_ptr(lp->qdq_addr));

  uint16_t op_sel = lp->op_select;
  bool swap_ab = (op_sel & BCAST_TOGGLE) != 0;

  // Qdq record applied to each physical buffer (see header comment).
  matadd_qdq::Qdq *records = reinterpret_cast<matadd_qdq::Qdq *>(qdq_param);
  matadd_qdq::Qdq qdq_for_bufA = records[swap_ab ? 1 : 0];
  matadd_qdq::Qdq qdq_for_bufB = records[swap_ab ? 0 : 1];

  // matA is only dequantized on iteration stage 0.
  uint16_t dq_bufA_enable = (lp->itr_stage == 0) ? qdq_for_bufA.enable : 0;

  int num_elems  = lp->Msubv * lp->Nsubv;
  int bufA_elems = lp->matA_elems;

  // Decode the 4-bit dtype code of each tensor.
  uint16_t dtype_word = lp->data_type;
  int dtype_A = dtype_word & ADD_MAT_A_DTYPE;
  int dtype_B = (dtype_word & ADD_MAT_B_DTYPE) >> 4;
  int dtype_C = (dtype_word & ADD_MAT_C_DTYPE) >> 8;

  // true = 16-bit storage, false = 8-bit storage.
  bool A_is_16bit = (dtype_A == BINARY_OPS_DTYPE_UINT16) ||
                    (dtype_A == BINARY_OPS_DTYPE_INT16);
  bool B_is_16bit = (dtype_B == BINARY_OPS_DTYPE_UINT16) ||
                    (dtype_B == BINARY_OPS_DTYPE_INT16);
  bool C_is_16bit = (dtype_C == BINARY_OPS_DTYPE_UINT16) ||
                    (dtype_C == BINARY_OPS_DTYPE_INT16);

  bool A_signed = (dtype_A == BINARY_OPS_DTYPE_INT8) ||
                  (dtype_A == BINARY_OPS_DTYPE_INT16);
  bool B_signed = (dtype_B == BINARY_OPS_DTYPE_INT8) ||
                  (dtype_B == BINARY_OPS_DTYPE_INT16);
  bool C_signed = (dtype_C == BINARY_OPS_DTYPE_INT8) ||
                  (dtype_C == BINARY_OPS_DTYPE_INT16);

  // Element-wise layout: matB starts right after matA in the input stream.
  int8_t *in_stream = static_cast<int8_t *>(args.s2mm_ch0_data);
  int bufA_bytes = num_elems * MATADD_A8_NUM_BYTES * (1 + (int)A_is_16bit);
  int8_t *matA = in_stream;
  int8_t *matB = byte_incr(matA, bufA_bytes);
  int8_t *output = static_cast<int8_t *>(args.mm2s_ch0_data);

  // Broadcast layout: broadcast operand comes from tdm1_addr, stream is matB.
  if (op_sel & BCAST_MASK) {
    matA = static_cast<int8_t *>(conv_to_local_ptr(lp->tdm1_addr));
    matB = in_stream;
  }

  // A scratch address of 0xffff means "dequantize in place".
  int8_t *scratchA = matA;
  if (lp->scratchA != 0xffff) {
    scratchA = static_cast<int8_t *>(conv_to_local_ptr(lp->scratchA));
  }
  int8_t *scratchB = matB;
  if (lp->scratchB != 0xffff) {
    scratchB = static_cast<int8_t *>(conv_to_local_ptr(lp->scratchB));
  }

  // Dequantization is skipped for DIV (handled in elw_inv_qdq).
  bool dq_allowed = (op_sel & OP_VAL_MASK) != ELW_DIV;

  bf16_kernel_fn kernel = (op_sel & BCAST_MASK) ? broadcast_bf16_bf16_bf16
                                                : elewise_bf16_bf16_bf16;

  dequant_int16_to_bf16(matA, scratchA, bufA_elems, qdq_for_bufA.zero_point,
                        qdq_for_bufA.scale, A_signed,
                        dq_bufA_enable && dq_allowed, A_is_16bit);
  dequant_int16_to_bf16(matB, scratchB, num_elems, qdq_for_bufB.zero_point,
                        qdq_for_bufB.scale, B_signed,
                        qdq_for_bufB.enable && dq_allowed, B_is_16bit);

  kernel(scratchA, scratchB, output, lp->op_select, num_elems, lp->Msubv,
         lp->Nsubv);

  quant_bf16_to_int16(output, output, num_elems, qdq_param->out_zero_point,
                      qdq_param->out_scale, C_signed, qdq_param->q_out_enable,
                      C_is_16bit);
}


#endif //__MATADD_KERNEL_WRAPPER_C__
