/*
    Copyright (C) 2019 - 2022 Xilinx, Inc. All rights reserved.
    Copyright (C) 2022 - 2026 Advanced Micro Devices, Inc. All rights reserved.
    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.
    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.
    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.
    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */
#ifndef __MATADD_KERNEL_WRAPPER_C__
#define __MATADD_KERNEL_WRAPPER_C__

#include "kernel_helpers.h"
#include "qdq/qdq_int16_bfloat16.hpp"
#include "qdq/qdq_int8_bfloat16.hpp"
#include "qdq/qdq_kernel_helpers.h"

/*
 * The layer parameter op_select is a uint16 bit field:
 *   bits [1:0] - op type: 0 - Add ; 1 - Mul ; 2 - Sub ; 3 - Div
 *   bit  3     - broadcast toggle (swaps which operand is treated as the
 *                broadcast one)
 *   bit  4     - broadcast                        - 0b0001 00xx
 *   bit  5     - innermost (single element) broadcast, set together with
 *                bit 4                            - 0b0011 00xx
 *   bit  8     - fusion (cascading)               - 0b0001 xxxx 00xx
 * Without bit 4 set the operation is plain elementwise (elw).
 */
// Operation codes stored in the two low bits of op_select (see OP_VAL_MASK).
#define ELW_ADD 0x00
#define ELW_MUL 0x01
#define ELW_SUB   0x02
#define ELW_DIV   0x03
// Bit 3 of op_select: when set, the kernels swap the roles of operands A and B.
#define BCAST_TOGGLE    0x08
// Broadcast variants: BCAST_MASK (0x10) combined with the op code.
#define BCAST_ADD 0x10
#define BCAST_MUL 0x11

// Single-element broadcast variants: BCAST_MASK | BCAST_SINGLE_ELEM_MASK (0x30)
// combined with the op code.
#define BCAST_ADD_SINGLE_ELEM 0x30
#define BCAST_MUL_SINGLE_ELEM 0x31

// Bit 8 of op_select: cascaded (fused) add.
#define CASCADE_ADD 0x0100

// Bytes per element used when computing operand offsets inside the input
// buffer: 2 for 16-bit data, 1 for 8-bit data.
#define MATADD_NUM_BYTES 2
#define MATADD_A8_NUM_BYTES 1
// Set to 1 to compile in the printf / chess_report tracing guarded by
// "#if DEBUG" in this file.
#define DEBUG 0

// Masks used to decode op_select.
#define BCAST_MASK 0x10
#define BCAST_SINGLE_ELEM_MASK 0x20
#define OP_VAL_MASK 0x3

/*
 * Sign mask and Tensor precision masks (one 4-bit field per tensor):
 *   Mat A - 0x0000 0000 XXXX  (bits [3:0])
 *   Mat B - 0x0000 XXXX 0000  (bits [7:4])
 *   Mat C - 0xXXXX 0000 0000  (bits [11:8])
 * Field value X (precision / sign):
 *   0 = uint16; 1 = uint8; 2 = int16; 3 = int8
 */

// Values a per-tensor dtype field can take.
#define BINARY_OPS_DTYPE_UINT16 0
#define BINARY_OPS_DTYPE_UINT8 1
#define BINARY_OPS_DTYPE_INT16 2
#define BINARY_OPS_DTYPE_INT8 3

// Masks selecting the dtype field of Mat A / Mat B / Mat C inside
// LayerParam::data_type; the kernels shift the masked value right by
// 0 / 4 / 8 bits respectively.
#define ADD_MAT_A_DTYPE 0xf
#define ADD_MAT_B_DTYPE 0xf0
#define ADD_MAT_C_DTYPE 0xf00

// Parameter layouts shared by the matadd / elementwise / RoPE kernel wrappers.
// The kernels obtain these structs by casting raw parameter memory, so field
// order and widths are part of the contract with whoever writes that memory.
namespace matadd_qdq {
  // Per-layer parameters, read from KernelArgs::params_data.
  struct LayerParam {
    // Bit field: op code, broadcast / toggle flags, cascade flag.
    uint16_t op_select;
    // Number of rows iterated by matadd_bf16_bf16_bf16.
    uint16_t Msubv;
    // Row length in elements; the compute loop walks it in vectors of 32.
    uint16_t Nsubv;
    // Address (passed to conv_to_local_ptr) of the QdqParams / QdqWrapper data.
    uint16_t qdq_addr;
    // Address (passed to conv_to_local_ptr) of the first scratch buffer; the
    // matadd wrappers read operand A from it in broadcast mode, RoPE uses it
    // as an intermediate buffer.
    uint16_t tdm1_addr;
    // Address (passed to conv_to_local_ptr) of a second intermediate buffer
    // used by RoPE.
    uint16_t tdm2_addr;
    // RoPE: selects the fused path (input taken from the output buffer).
    uint16_t fused_op_flag;
    // RoPE: loaded by the kernel but overwritten before it is used there.
    uint16_t do_neg;
    // Stage index for the matadd wrappers; the standalone RoPE path reuses
    // this field as the original (unpadded) N dimension.
    uint16_t itr_stage;
    // Packed dtype fields for Mat A / B / C (see ADD_MAT_*_DTYPE).
    uint16_t data_type;
    // Element count used when dequantizing operand A.
    uint16_t matA_elems;
    // Scratch destinations for the dequantized operands; 0xffff means
    // "dequantize in place".
    uint16_t scratchA;
    uint16_t scratchB;
  };

  // One (zero point, scale, enable) triple.
  struct Qdq {
    uint16_t zero_point;
    bfloat16 scale;
    uint16_t enable;
  };

  // Fixed set of Q/DQ parameters. The first three triples have the same
  // layout as Qdq; run_matadd_qdq / run_elw_inv_qdq rely on this when they
  // reinterpret a QdqParams pointer as a Qdq array.
  struct QdqParams {
    uint16_t matA_zero_point;
    bfloat16 matA_scale;
    uint16_t dq_A_enable;
    uint16_t matB_zero_point;
    bfloat16 matB_scale;
    uint16_t dq_B_enable;
    uint16_t out_zero_point;
    bfloat16 out_scale;
    uint16_t q_out_enable;
    // RoPE-only parameters for the sin / cos tables.
    uint16_t sin_zero_point;
    bfloat16 sin_scale;
    uint16_t cos_zero_point;
    bfloat16 cos_scale;
    uint16_t sin_enable;
    uint16_t cos_enable;
  };

  // Variable-length list of Qdq triples used by the cascaded kernel, which
  // reads the output quantization parameters from qdq[nItems].
  struct QdqWrapper {
    uint16_t nItems;
    Qdq qdq[];
  };
}

/*
// matC = matA + matB
*/
void __attribute__((noinline)) matadd_bf16_bf16_bf16
(
    int8_t* matA,
    int8_t* matB,
    int8_t* output,
    uint16_t op_select,
    uint16_t matadd_num_elems,
    uint16_t Msubv,
    uint16_t Nsubv
) {
	bfloat16*    matA_bf16 = reinterpret_cast<bfloat16*>(matA);
	v32bfloat16* matA_v    = (v32bfloat16*) matA;
	v32bfloat16* matB_v    = (v32bfloat16*) matB;
        bool is_toggle         = (op_select & BCAST_TOGGLE) >> 3;
        if(is_toggle)
        {
	   matA_v    = (v32bfloat16*) matB;
	   matB_v    = (v32bfloat16*) matA;
        }
	v32bfloat16* matC_v    = (v32bfloat16*) output;
	bool single_elem       = (op_select & BCAST_SINGLE_ELEM_MASK) >> 5;
	bool is_broadcast      = (op_select & BCAST_MASK) >> 4;
	uint16_t op            = (op_select & OP_VAL_MASK);
        //Note: 
        //      -> is_toggle == False, indicating natural ifmA is broadcast.
        //      -> is_toggle == True,  indicating natural ifmB is broadcast.
	int idxA               = (is_broadcast && !is_toggle) ? 0 : Nsubv / 32;
	int idxB               = (is_broadcast && is_toggle)  ? 0 : Nsubv / 32;
	bfloat16 factor        = op == ELW_SUB ? -1.0 : 1.0;

	for(int i = 0; i < Msubv; i++)
	chess_no_hw_loop
	{
	    aie::vector<bfloat16,32> va1;
	    aie::vector<bfloat16,32> va2;
	    va2 = aie::broadcast(matA_bf16[i * 8]);
	    for(int j = 0; j < Nsubv / 32; j++)
	    chess_no_hw_loop
	    {
	        va1 = matA_v[j + (idxA * i)];
	        aie::vector<bfloat16,32> va0   = aie::select( va1, va2, aie::mask<32>( single_elem && !is_toggle));
	        aie::vector<bfloat16,32> vb0   = matB_v[j + (idxB * i)];
                //Note: Handling for single element broadcast with toggle cases.
	        vb0                            = aie::select( vb0, va2, aie::mask<32>( single_elem && is_toggle));
	        aie::vector<bfloat16,32> op_b  = aie::select( va0, bfloat16( factor ), aie::mask<32>( op == ELW_ADD || op == ELW_SUB));

                //Note: Performing B * op + A
                *matC_v++                      = to_v32bfloat16( mac_elem_32_conf( vb0, op_b, aie::accum<accfloat,32>( va0 ), op == ELW_MUL || op == ELW_DIV, 0, 0 ));
	    }
	}
}

/*
 * Replaces every bf16 element of matA with aie::inv of it, in place.
 *
 *   matA      - buffer holding bf16 data (behind an int8_t pointer)
 *   num_elems - element count; processed in whole vectors of 32
 */
INLINE_DECL void inv_bf16(
    int8_t* matA,
    uint16_t num_elems
) {
    v32bfloat16* vec_buf = reinterpret_cast<v32bfloat16*>(matA);
    const int num_vecs   = num_elems / 32;

    // The loop is declared to run at least twice because num_elems is
    // expected to be 64 or more.
    [[using chess: prepare_for_pipelining, min_loop_count( 2 )]]
    for (int v = 0; v < num_vecs; ++v)
    {
        aie::vector<bfloat16,32> x     = vec_buf[v];
        aie::vector<bfloat16,32> x_inv = aie::inv(x);
        vec_buf[v] = x_inv;
    }
}

/*
 * Cascaded quantized elementwise kernel.
 *
 * Each call handles one stage (LayerParam::itr_stage) of a chain of
 * operations whose running result lives in the output buffer:
 *   stage 0  - dequantize A and B (Q/DQ entries 0 and 1), compute, no
 *              output quantization.
 *   stage 1  - dequantize A with the next Q/DQ entry; B is the previous
 *              result already in the output buffer; no output quantization.
 *   others   - as stage 1, plus output quantization taken from
 *              qdq[nItems] of the QdqWrapper.
 * A function-local static counter tracks which Q/DQ entry the next stage
 * uses; it is reset whenever stage 0 is seen.
 *
 *   args - kernel arguments; params_data points at a matadd_qdq::LayerParam,
 *          s2mm_ch0_data / mm2s_ch0_data are the input / output buffers.
 */
INLINE_DECL void run_matadd_cascade_qdq(KernelArgs &args)
{
  set_sat();
  set_rnd(rnd_conv_even);

  matadd_qdq::LayerParam *LayerParamPtr =
      static_cast<matadd_qdq::LayerParam *>(args.params_data);

  int matadd_num_elems = LayerParamPtr->Msubv * LayerParamPtr->Nsubv;
  int matA_elems = LayerParamPtr->matA_elems;

  // Index of the Q/DQ entry for operand A; persists across calls and is
  // restarted at stage 0.
  static int count = 0;
  count = LayerParamPtr->itr_stage == 0 ? 0 : count;

  // Tensor precision and signed estimation
  uint16_t dType = LayerParamPtr->data_type;

  int matA_dtype = (dType & ADD_MAT_A_DTYPE) >> 0;
  int matB_dtype = (dType & ADD_MAT_B_DTYPE) >> 4;
  int matC_dtype = (dType & ADD_MAT_C_DTYPE) >> 8;

  // only supports 16-bit / 8-bit (true => 16-bit)
  bool matA_precision = ((matA_dtype == BINARY_OPS_DTYPE_INT16) ||
                         (matA_dtype == BINARY_OPS_DTYPE_UINT16));
  bool matB_precision = ((matB_dtype == BINARY_OPS_DTYPE_INT16) ||
                         (matB_dtype == BINARY_OPS_DTYPE_UINT16));
  bool matC_precision = ((matC_dtype == BINARY_OPS_DTYPE_INT16) ||
                         (matC_dtype == BINARY_OPS_DTYPE_UINT16));

  // only supports u16/u8/i16/i8
  bool is_matA_signed = ((matA_dtype == BINARY_OPS_DTYPE_INT16) ||
                         (matA_dtype == BINARY_OPS_DTYPE_INT8));
  bool is_matB_signed = ((matB_dtype == BINARY_OPS_DTYPE_INT16) ||
                         (matB_dtype == BINARY_OPS_DTYPE_INT8));
  bool is_output_signed = ((matC_dtype == BINARY_OPS_DTYPE_INT16) ||
                           (matC_dtype == BINARY_OPS_DTYPE_INT8));

  // Elementwise layout: B follows A in the input buffer, A occupying
  // 1 or 2 bytes per element depending on its precision.
  int8_t *matA = static_cast<int8_t *>(args.s2mm_ch0_data);
  int8_t *matB = byte_incr(
      matA, (matadd_num_elems * MATADD_A8_NUM_BYTES) +
                (matadd_num_elems * MATADD_A8_NUM_BYTES * (int)matA_precision));

  int8_t *output = static_cast<int8_t *>(args.mm2s_ch0_data);

  // Broadcast layout: A comes from the tdm1 buffer, B is the input buffer.
  if (LayerParamPtr->op_select & BCAST_MASK) {
    matA = static_cast<int8_t *>(conv_to_local_ptr(LayerParamPtr->tdm1_addr));
    matB = static_cast<int8_t *>(args.s2mm_ch0_data);
  }

  // 0xffff means "no separate scratch buffer": dequantize in place.
  int8_t *scratchA =
      (LayerParamPtr->scratchA == 0xffff)
          ? matA
          : static_cast<int8_t *>(conv_to_local_ptr(LayerParamPtr->scratchA));
  int8_t *scratchB =
      (LayerParamPtr->scratchB == 0xffff)
          ? matB
          : static_cast<int8_t *>(conv_to_local_ptr(LayerParamPtr->scratchB));

  // Value-initialised: every stage leaves some fields unset (output
  // parameters in stages 0/1, B parameters in later stages) and those fields
  // are still passed by value to the helpers below alongside a cleared
  // enable flag. Zeroing them avoids reading indeterminate values.
  matadd_qdq::QdqParams qdq_param{};
  matadd_qdq::QdqWrapper *qdq_param_cascade = static_cast<matadd_qdq::QdqWrapper *>(conv_to_local_ptr(LayerParamPtr->qdq_addr));

  qdq_param.matA_zero_point = qdq_param_cascade->qdq[count].zero_point;
  qdq_param.matA_scale = qdq_param_cascade->qdq[count].scale;
  qdq_param.dq_A_enable = qdq_param_cascade->qdq[count].enable;
  if (LayerParamPtr->itr_stage == 0) {
    qdq_param.matB_zero_point = qdq_param_cascade->qdq[count + 1].zero_point;
    qdq_param.matB_scale = qdq_param_cascade->qdq[count + 1].scale;
    qdq_param.dq_B_enable = qdq_param_cascade->qdq[count + 1].enable;
    qdq_param.q_out_enable = 0;
    // Entries 0 and 1 are consumed here; with the increment below the next
    // stage starts at entry 2.
    count = 1;
  } else if (LayerParamPtr->itr_stage == 1) {
    // B is the running result already held in the output buffer.
    scratchB = output;
    qdq_param.dq_B_enable = 0;
    qdq_param.q_out_enable = 0;
  } else {
    scratchB = output;
    qdq_param.dq_B_enable = 0;
    qdq_param.out_zero_point = qdq_param_cascade->qdq[qdq_param_cascade->nItems].zero_point;
    qdq_param.out_scale = qdq_param_cascade->qdq[qdq_param_cascade->nItems].scale;
    qdq_param.q_out_enable = qdq_param_cascade->qdq[qdq_param_cascade->nItems].enable;
  }
  count++;

  dequant_int16_to_bf16(matA, scratchA, matA_elems, qdq_param.matA_zero_point,
                        qdq_param.matA_scale, is_matA_signed, qdq_param.dq_A_enable,
                        matA_precision);
  dequant_int16_to_bf16(matB, scratchB, matadd_num_elems,
                        qdq_param.matB_zero_point, qdq_param.matB_scale,
                        is_matB_signed, qdq_param.dq_B_enable, matB_precision);
  matadd_bf16_bf16_bf16(scratchA, scratchB, output, LayerParamPtr->op_select,
                        matadd_num_elems, LayerParamPtr->Msubv,
                        LayerParamPtr->Nsubv);
  quant_bf16_to_int16(output, output, matadd_num_elems,
                      qdq_param.out_zero_point, qdq_param.out_scale,
                      is_output_signed, qdq_param.q_out_enable,
                      matC_precision);

}

/*
 * Pre-pass for elementwise division: dequantizes both operands and replaces
 * the denominator with its reciprocal (inv_bf16), so that the following
 * run_matadd_qdq call can perform the division as a multiplication.
 * Nothing is written to the output buffer here.
 *
 *   args - kernel arguments; params_data points at a matadd_qdq::LayerParam,
 *          s2mm_ch0_data is the input buffer.
 */
INLINE_DECL void run_elw_inv_qdq(KernelArgs &args) {
  set_sat();
  set_rnd(rnd_conv_even);
  matadd_qdq::LayerParam *LayerParamPtr =
      static_cast<matadd_qdq::LayerParam *>(args.params_data);
  matadd_qdq::QdqParams *qdq_param = static_cast<matadd_qdq::QdqParams *>(
      conv_to_local_ptr(LayerParamPtr->qdq_addr));

  bool is_toggle = (LayerParamPtr->op_select & BCAST_TOGGLE) >> 3;
  //Note:
  //BroadCast:
  //    -> qdq_paramA => holds qdq params of the broadcast vector.
  //    -> qdq_paramB => holds qdq params of the non-broadcast vector.
  //    -> is_toggle == False, matA -> broadcast, so qdq_paramA => points to the MatA_zp address.
  //    -> is_toggle == True,  matB -> broadcast, so qdq_paramA => points to the MatB_zp address.
  //EleWise:
  //    -> qdq_paramA => holds qdq params of MatA.
  //    -> qdq_paramB => holds qdq params of MatB.
  // The leading (zero_point, scale, enable) triples of QdqParams are viewed
  // as an array of Qdq so the A/B entries can be picked by index.
  uint16_t idxA = static_cast<uint16_t>(is_toggle);
  uint16_t idxB = static_cast<uint16_t>(!is_toggle);
  matadd_qdq::Qdq *qdq        = reinterpret_cast<matadd_qdq::Qdq *>(qdq_param);
  matadd_qdq::Qdq qdq_paramA  = qdq[idxA];
  matadd_qdq::Qdq qdq_paramB  = qdq[idxB];

  // After stage 0 operand A is pinned (kept from the earlier stage), so it
  // must not be dequantized again.
  bool is_matA_pinned = (LayerParamPtr->itr_stage != 0);
  uint16_t dq_A_en_flag =
      (!is_matA_pinned) ? qdq_paramA.enable : 0;

  int matadd_num_elems = LayerParamPtr->Msubv * LayerParamPtr->Nsubv;
  int matA_elems = LayerParamPtr->matA_elems;

  // Tensor precision and signed estimation
  uint16_t dType = LayerParamPtr->data_type;

  int matA_dtype = (dType & ADD_MAT_A_DTYPE) >> 0;
  int matB_dtype = (dType & ADD_MAT_B_DTYPE) >> 4;

  // only supports 16-bit / 8-bit (true => 16-bit)
  bool matA_precision = ((matA_dtype == BINARY_OPS_DTYPE_INT16) ||
                         (matA_dtype == BINARY_OPS_DTYPE_UINT16));
  bool matB_precision = ((matB_dtype == BINARY_OPS_DTYPE_INT16) ||
                         (matB_dtype == BINARY_OPS_DTYPE_UINT16));

  // only supports u16/u8/i16/i8
  bool is_matA_signed = ((matA_dtype == BINARY_OPS_DTYPE_INT16) ||
                         (matA_dtype == BINARY_OPS_DTYPE_INT8));
  bool is_matB_signed = ((matB_dtype == BINARY_OPS_DTYPE_INT16) ||
                         (matB_dtype == BINARY_OPS_DTYPE_INT8));

  // Elementwise layout: B follows A in the input buffer, A occupying
  // 1 or 2 bytes per element depending on its precision.
  int8_t *matA = static_cast<int8_t *>(args.s2mm_ch0_data);
  int8_t *matB = byte_incr(
      matA, (matadd_num_elems * MATADD_A8_NUM_BYTES) +
                (matadd_num_elems * MATADD_A8_NUM_BYTES * (int)matA_precision));

  // Broadcast layout: A comes from the tdm1 buffer, B is the input buffer.
  if (LayerParamPtr->op_select & BCAST_MASK) {
    matA = static_cast<int8_t *>(conv_to_local_ptr(LayerParamPtr->tdm1_addr));
    matB = static_cast<int8_t *>(args.s2mm_ch0_data);
  }

  // 0xffff means "no separate scratch buffer": dequantize in place.
  int8_t *scratchA =
      (LayerParamPtr->scratchA == 0xffff)
          ? matA
          : static_cast<int8_t *>(conv_to_local_ptr(LayerParamPtr->scratchA));
  int8_t *scratchB =
      (LayerParamPtr->scratchB == 0xffff)
          ? matB
          : static_cast<int8_t *>(conv_to_local_ptr(LayerParamPtr->scratchB));


  dequant_int16_to_bf16(matA, scratchA, matA_elems, qdq_paramA.zero_point,
                        qdq_paramA.scale, is_matA_signed, dq_A_en_flag,
                        matA_precision);
  dequant_int16_to_bf16(matB, scratchB, matadd_num_elems,
                        qdq_paramB.zero_point, qdq_paramB.scale,
                        is_matB_signed, qdq_paramB.enable, matB_precision);
  //Note: Handling for pinning of the broadcast vector
  //    -> inv needs to be performed on the denominator (i.e. onnx matB).
  //    -> is_matA_pinned => the broadcast vector is pinned.
  //    -> is_toggle => True -> onnx matB is the broadcast one (held in scratchA),
  //                    False -> onnx matA is the broadcast one (denominator in scratchB).
  //    A pinned denominator is skipped so it is not inverted a second time.
  if(!is_matA_pinned || !is_toggle){
      inv_bf16(is_toggle ? scratchA : scratchB, is_toggle ? matA_elems : matadd_num_elems);
  }
}

/*
 * Quantized elementwise / broadcast binary op:
 *   dequantize A and B -> matadd_bf16_bf16_bf16 -> quantize the result.
 * For Div both dequantizations are disabled here because run_elw_inv_qdq
 * has already performed them.
 *
 *   args - kernel arguments; params_data points at a matadd_qdq::LayerParam,
 *          s2mm_ch0_data / mm2s_ch0_data are the input / output buffers.
 */
INLINE_DECL void run_matadd_qdq(KernelArgs &args) {
  set_sat();
  set_rnd(rnd_conv_even);

  matadd_qdq::LayerParam *lp =
      static_cast<matadd_qdq::LayerParam *>(args.params_data);
  matadd_qdq::QdqParams *qp = static_cast<matadd_qdq::QdqParams *>(
      conv_to_local_ptr(lp->qdq_addr));

  // Q/DQ selection.
  //   Broadcast:   entry "A" describes the broadcast vector, entry "B" the
  //                non-broadcast one. Without the toggle the broadcast vector
  //                is MatA (triple 0); with the toggle it is MatB (triple 1).
  //   Elementwise: entry "A" is MatA, entry "B" is MatB.
  const bool swap_operands = (lp->op_select & BCAST_TOGGLE) != 0;
  matadd_qdq::Qdq *triples = reinterpret_cast<matadd_qdq::Qdq *>(qp);
  matadd_qdq::Qdq qdq_a = triples[swap_operands ? 1 : 0];
  matadd_qdq::Qdq qdq_b = triples[swap_operands ? 0 : 1];

  // Dequantization of both operands for Div is done in run_elw_inv_qdq, so
  // it is switched off here. A is additionally only dequantized at stage 0.
  const bool is_div = (lp->op_select & OP_VAL_MASK) == ELW_DIV;
  const bool dequant_a = (lp->itr_stage == 0) && (qdq_a.enable != 0) && !is_div;
  const bool dequant_b = (qdq_b.enable != 0) && !is_div;

  int total_elems = lp->Msubv * lp->Nsubv;
  int a_elems = lp->matA_elems;

  // Tensor precision and sign, one nibble of data_type per tensor.
  // Only u16 / u8 / i16 / i8 are supported.
  const auto is_16bit = [](int dt) {
    return (dt == BINARY_OPS_DTYPE_INT16) || (dt == BINARY_OPS_DTYPE_UINT16);
  };
  const auto is_signed = [](int dt) {
    return (dt == BINARY_OPS_DTYPE_INT16) || (dt == BINARY_OPS_DTYPE_INT8);
  };
  const uint16_t packed_dtype = lp->data_type;
  const int dtype_a = (packed_dtype & ADD_MAT_A_DTYPE) >> 0;
  const int dtype_b = (packed_dtype & ADD_MAT_B_DTYPE) >> 4;
  const int dtype_c = (packed_dtype & ADD_MAT_C_DTYPE) >> 8;

  const bool a_is_16bit = is_16bit(dtype_a);
  const bool b_is_16bit = is_16bit(dtype_b);
  const bool c_is_16bit = is_16bit(dtype_c);
  const bool a_is_signed = is_signed(dtype_a);
  const bool b_is_signed = is_signed(dtype_b);
  const bool c_is_signed = is_signed(dtype_c);

  // Elementwise layout: B follows A in the input buffer, A occupying
  // 1 or 2 bytes per element depending on its precision.
  int8_t *matA = static_cast<int8_t *>(args.s2mm_ch0_data);
  int8_t *matB = byte_incr(
      matA, (total_elems * MATADD_A8_NUM_BYTES) +
                (total_elems * MATADD_A8_NUM_BYTES * (int)a_is_16bit));
  int8_t *output = static_cast<int8_t *>(args.mm2s_ch0_data);

  // Broadcast layout: A comes from the tdm1 buffer, B is the input buffer.
  if (lp->op_select & BCAST_MASK) {
    matA = static_cast<int8_t *>(conv_to_local_ptr(lp->tdm1_addr));
    matB = static_cast<int8_t *>(args.s2mm_ch0_data);
  }

  // 0xffff means "no separate scratch buffer": dequantize in place.
  int8_t *scratchA = matA;
  if (lp->scratchA != 0xffff) {
    scratchA = static_cast<int8_t *>(conv_to_local_ptr(lp->scratchA));
  }
  int8_t *scratchB = matB;
  if (lp->scratchB != 0xffff) {
    scratchB = static_cast<int8_t *>(conv_to_local_ptr(lp->scratchB));
  }

  dequant_int16_to_bf16(matA, scratchA, a_elems, qdq_a.zero_point,
                        qdq_a.scale, a_is_signed, dequant_a, a_is_16bit);
  dequant_int16_to_bf16(matB, scratchB, total_elems, qdq_b.zero_point,
                        qdq_b.scale, b_is_signed, dequant_b, b_is_16bit);
  matadd_bf16_bf16_bf16(scratchA, scratchB, output, lp->op_select,
                        total_elems, lp->Msubv, lp->Nsubv);
  quant_bf16_to_int16(output, output, total_elems, qp->out_zero_point,
                      qp->out_scale, c_is_signed, qp->q_out_enable,
                      c_is_16bit);
}

/*
 * RoPE (rotary position embedding) kernel.
 *
 * Steps, in order (buffers are reused in place, so the order matters):
 *   1. Dequantize the input (matA) and the sin / cos tables in place.
 *   2. tdm2 = matA * cos.
 *   3. tdm1 = "rotate half" of matA: within each block the first half is
 *      copied unchanged into the second half of the destination and the
 *      second half is negated and copied into the first half.
 *   4. cos  = tdm1 * sin   (the cos buffer is reused as scratch).
 *   5. output = cos + tdm2.
 *   6. Quantize output in place.
 *
 * args.params_data points at a matadd_qdq::LayerParam. When fused_op_flag
 * is 1 the input is taken from the output buffer (args.mm2s_ch0_data) and
 * sin / cos from args.s2mm_ch0_data; otherwise input, sin and cos are laid
 * out back to back in args.s2mm_ch0_data.
 */

void run_a16a16_rope_qdq(KernelArgs &args) {
  set_sat();
  set_rnd(rnd_conv_even);
  uint16_t *args_params = (uint16_t *)args.params_data;
  matadd_qdq::LayerParam *LayerParamPtr =
      static_cast<matadd_qdq::LayerParam *>(static_cast<void *>(args_params));

  matadd_qdq::QdqParams *qdq_param = static_cast<matadd_qdq::QdqParams *>(
      conv_to_local_ptr(LayerParamPtr->qdq_addr));

  int matadd_num_elems = LayerParamPtr->Msubv * LayerParamPtr->Nsubv;
  int fused_op = LayerParamPtr->fused_op_flag;
  // NOTE(review): this value is overwritten with 0 before the rotation
  // loops below, so the parameter field has no effect on the computation.
  int do_neg = LayerParamPtr->do_neg;

  // Elwise matadd matmul
  // Intermediate buffers: tdm1 receives the rotated input, tdm2 the
  // input * cos product.
  int8_t *output_tdm1_A =
      static_cast<int8_t *>(conv_to_local_ptr(LayerParamPtr->tdm1_addr));
  int8_t *output_tdm2_B =
      static_cast<int8_t *>(conv_to_local_ptr(LayerParamPtr->tdm2_addr));
  int8_t *output = static_cast<int8_t *>(args.mm2s_ch0_data);
  // Fused: the input already sits in the output buffer and the stream
  // buffer starts with sin. Standalone: input, sin, cos follow each other
  // in the stream buffer, each matadd_num_elems 16-bit elements long.
  int8_t *matA =
      (fused_op == 1) ? output : static_cast<int8_t *>(args.s2mm_ch0_data);
  int8_t *sin = (fused_op == 1)
                    ? static_cast<int8_t *>(args.s2mm_ch0_data)
                    : byte_incr(matA, matadd_num_elems * MATADD_NUM_BYTES);
  int8_t *cos = byte_incr(sin, matadd_num_elems * MATADD_NUM_BYTES);

#if DEBUG
  // Tracing is restricted to the core at column 0, row 2.
  int col_id = get_coreid() >> 16;
  int row_id = get_coreid() & 0xf;
  if (col_id == 0 && row_id == 2) {
    for (int idx = 0; idx < 1; idx++) {
        v16bfloat16* dbg_ptr = (v16bfloat16*) LayerParamPtr;
        chess_report(*(dbg_ptr));
    }
    printf("qdq addr: %u\n", LayerParamPtr->qdq_addr);
    printf("tdm1 Addr: %u\n", LayerParamPtr->tdm1_addr);
    printf("tdm2 Addr: %u\n", LayerParamPtr->tdm2_addr);
    printf("Msubv: %u\n", LayerParamPtr->Msubv);
    printf("Nsubv: %u\n", LayerParamPtr->Nsubv);
    printf("fused_op_flag: %u\n", LayerParamPtr->fused_op_flag);
    printf("do_neg_flag: %u\n", LayerParamPtr->do_neg);
    printf("N_orig: %u\n", LayerParamPtr->itr_stage);
    printf("op_select: %u\n", LayerParamPtr->op_select);
    printf("matadd_num_elems: %u\n", matadd_num_elems);
    printf("dq_zero_point: %u\n", qdq_param->matA_zero_point);
    printf("dq_scale: %f\n", (float)qdq_param->matA_scale);
    printf("q_zero_point: %u\n", qdq_param->out_zero_point);
    printf("q_scale: %f\n", (float)qdq_param->out_scale);
    printf("dq_enable: %u\n", qdq_param->dq_A_enable);
    printf("q_enable: %u\n", qdq_param->q_out_enable);
    printf("sin_zero_point: %u\n", qdq_param->sin_zero_point);
    printf("sin_scale: %f\n", (float)qdq_param->sin_scale);
    printf("cos_zero_point: %u\n", qdq_param->cos_zero_point);
    printf("cos_scale: %f\n", (float)qdq_param->cos_scale);
    printf("sin_dq_enable: %u\n", qdq_param->sin_enable);
    printf("cos_dq_enable: %u\n", qdq_param->cos_enable);

    // The dump loops below have a bound of 0, i.e. they are switched off;
    // raise the bound to dump that many vectors.
    for (int idx = 0; idx < 0; idx++) {
      v16bfloat16 *qdq_ptr = (v16bfloat16 *)qdq_param;
      chess_report(*(qdq_ptr + idx));
    }
    for (int idx = 0; idx < 0 /*8*4*/; idx++) {
      v16bfloat16 *in_ptr = (v16bfloat16 *)matA;
      chess_report(*(in_ptr + idx));
    }
    for (int idx = 0; idx < 0; idx++) {
      v16bfloat16 *sin_ptr = (v16bfloat16 *)sin;
      chess_report(*(sin_ptr + idx));
    }
    for (int idx = 0; idx < 0; idx++) {
      v16bfloat16 *cos_ptr = (v16bfloat16 *)cos;
      chess_report(*(cos_ptr + idx));
    }
  }
#endif

  // Step 1: dequantize input, sin and cos in place (unsigned data; the
  // precision argument is left at the helper's default).
  dequant_int16_to_bf16(matA, matA, matadd_num_elems,
                        qdq_param->matA_zero_point, qdq_param->matA_scale,
                        false, qdq_param->dq_A_enable);
  dequant_int16_to_bf16(sin, sin, matadd_num_elems, qdq_param->sin_zero_point,
                        qdq_param->sin_scale, false, qdq_param->sin_enable);
  dequant_int16_to_bf16(cos, cos, matadd_num_elems, qdq_param->cos_zero_point,
                        qdq_param->cos_scale, false, qdq_param->cos_enable);

#if DEBUG
  col_id = get_coreid() >> 16;
  row_id = get_coreid() & 0xf;
  if (col_id == 0 && row_id == 2) {
    v16bfloat16 *matA_ptr = (v16bfloat16 *)matA;
    v16bfloat16 *sin_ptr = (v16bfloat16 *)sin;
    v16bfloat16 *cos_ptr = (v16bfloat16 *)cos;
    for (int idx = 0; idx < 0; idx++) {
      chess_report(*(matA_ptr + idx));
      chess_report(*(sin_ptr + idx));
      chess_report(*(cos_ptr + idx));
    }
  }
#endif

  // Step 2: tdm2 = matA * cos.
  matadd_bf16_bf16_bf16(matA, cos, output_tdm2_B, ELW_MUL /*mul*/,
                        matadd_num_elems, LayerParamPtr->Msubv,
                        LayerParamPtr->Nsubv);
#if DEBUG
  col_id = get_coreid() >> 16;
  row_id = get_coreid() & 0xf;
  if (col_id == 0 && row_id == 2) {
    v16bfloat16 *tdm2_ptr = (v16bfloat16 *)output_tdm2_B;
    for (int idx = 0; idx < 0; idx++) {
      chess_report(*(tdm2_ptr + idx));
    }
  }
#endif

  
  // Step 3: build the rotated copy of matA in tdm1.
  do_neg = 0;
  if (fused_op) {
      // Original fused implementation (64-element vectors)
      // Nsubv / 8 vectors of 64 elements are treated as a single block.
      int orig_N = LayerParamPtr->Nsubv / 8;
      int mid_pt = orig_N / 2;
      
      v64bfloat16 *in_p = (v64bfloat16 *)(matA);
      bfloat16 *out_p = (bfloat16 *)(output_tdm1_A);
        
      for (int i = 0; i < orig_N; i++) {
          // From the midpoint on, vectors are negated and moved to the
          // first half; before it they are moved unchanged to the second.
          if (i == mid_pt) {
              do_neg = 1;
          }
          aie::vector<bfloat16, 64> v = *(in_p + i);
          aie::vector<bfloat16, 64> vec = (do_neg == 1) ? aie::neg(v) : v;
           
          if (do_neg) {
              aie::store_v(out_p + (i - mid_pt) * 64, vec);
          } else {
              aie::store_v(out_p + (i + mid_pt) * 64, vec);
          }
      }
  } else {
      // Standalone implementation (8-element vectors, block-based)
      // itr_stage is reused here as the original N dimension: the block
      // length in elements over which the half rotation is applied.
      int orig_N = LayerParamPtr->itr_stage;
      int vector_size = 8;
      int total_vecs = (LayerParamPtr->Msubv * LayerParamPtr->Nsubv) / vector_size;
      // NOTE(review): vecs_per_block is 0 when itr_stage < 8, which makes
      // the modulo / division below divide by zero - confirm callers
      // always pass itr_stage >= 8.
      int vecs_per_block = orig_N / vector_size;
      int mid_pt = vecs_per_block / 2;
      
      v8bfloat16 *in_p = (v8bfloat16 *)(matA);
      bfloat16 *out_p = (bfloat16 *)(output_tdm1_A);
      
      for (int i = 0; i < total_vecs; i++) {
          int vec_in_block = i % vecs_per_block;
          int block_num = i / vecs_per_block;
           
          // Negation switches on at the block midpoint and is reset at
          // the start of every block.
          if (vec_in_block == mid_pt) {
              do_neg = 1;
          } else if (vec_in_block == 0) {
              do_neg = 0;
          }
            
          aie::vector<bfloat16, 8> v = *(in_p + i);
          aie::vector<bfloat16, 8> vec = (do_neg == 1) ? aie::neg(v) : v;
            
          // Second-half vectors land in the first half and vice versa.
          int out_vec_in_block = do_neg ? (vec_in_block - mid_pt) : (vec_in_block + mid_pt);
          int out_vec_idx = block_num * vecs_per_block + out_vec_in_block;
            
          aie::store_v(out_p + out_vec_idx * 8, vec);
      }
    }

 /*
  Disabled earlier variant that handled the fused and standalone cases with
  a single 8-element-vector loop; kept for reference only.

  do_neg = 0;
  int alignment = fused_op ? 8 : 1; // IF FUSED OP THEN W8 ALIGNED
  int orig_N = fused_op ? LayerParamPtr->Nsubv / alignment : LayerParamPtr->itr_stage / alignment; // NOTE: Hacky way to reuse field ->Original N dimension before padding
  int vec_size = 8; // Process 8 elements at a time
  int total_vecs = fused_op ? LayerParamPtr->Nsubv / vec_size : (LayerParamPtr->Msubv * LayerParamPtr->Nsubv) / vec_size;
  int vecs_per_block = orig_N / vec_size; // Number of vectors in one orig_N block
  int mid_pt = vecs_per_block / 2; // Midpoint within each orig_N block

  for (int i = 0; i < total_vecs; i++) {
    // Determine position within the current orig_N block
    int vec_in_block = i % vecs_per_block; //vector if for block of orig_N
    int block_num = i / vecs_per_block; // Which orig_N block we're in
     
    // Apply rotation within each orig_N block
    if (vec_in_block == mid_pt) {
       do_neg = 1;
    } else if (vec_in_block == 0) {
      do_neg = 0; // Reset for each new block
    }
    
    v8bfloat16 *in_p = (v8bfloat16 *)(matA);
    bfloat16 *out_p = (bfloat16 *)(output_tdm1_A);
    aie::vector<bfloat16, 8> v = *(in_p + i);
     
    // Apply negation if needed
    aie::vector<bfloat16, 8> vec = (do_neg == 1) ? aie::neg(v) : v;
      
    // Calculate output position with rotation within the block
    int out_vec_in_block = do_neg ? (vec_in_block - mid_pt) : (vec_in_block + mid_pt);
    int out_vec_idx = block_num * vecs_per_block + out_vec_in_block;
     
    aie::store_v(out_p + out_vec_idx * 8, vec);
  }
  */

#if DEBUG
  col_id = get_coreid() >> 16;
  row_id = get_coreid() & 0xf;
  if (col_id == 0 && row_id == 2) {
    v64bfloat16 *matA_ptr = (v64bfloat16 *)matA;
    v64bfloat16 *tdm1_ptr = (v64bfloat16 *)output_tdm1_A;
    for (int idx = 0; idx < 8; idx++) {
      chess_report(*(matA_ptr++));
    }
    for (int idx = 0; idx < 8; idx++) {
      chess_report(*(tdm1_ptr++));
    }
  }
#endif

  // Step 4: cos = tdm1 * sin. The cos buffer is free to be overwritten
  // because its values were already consumed in step 2.
  matadd_bf16_bf16_bf16(output_tdm1_A, sin, cos, ELW_MUL /*mul*/,
                        matadd_num_elems, LayerParamPtr->Msubv,
                        LayerParamPtr->Nsubv);
  // Step 5: output = (tdm1 * sin) + (matA * cos).
  matadd_bf16_bf16_bf16(cos, output_tdm2_B, output, ELW_ADD /*add*/,
                        matadd_num_elems, LayerParamPtr->Msubv,
                        LayerParamPtr->Nsubv);
#if DEBUG
  col_id = get_coreid() >> 16;
  row_id = get_coreid() & 0xf;
  if (col_id == 0 && row_id == 2) {
    v16bfloat16 *tdm1_ptr = (v16bfloat16 *)cos;
    v16bfloat16 *tdm2_ptr = (v16bfloat16 *)output_tdm2_B;
    v16bfloat16 *out_ptr = (v16bfloat16 *)output;
    for (int idx = 0; idx < 0; idx++) {
      chess_report(*(tdm1_ptr + idx));
      chess_report(*(tdm2_ptr + idx));
      chess_report(*(out_ptr + idx));
    }
  }
#endif
  // Step 6: quantize the result in place (unsigned output; the precision
  // argument is left at the helper's default).
  quant_bf16_to_int16(output, output, matadd_num_elems,
                      qdq_param->out_zero_point, qdq_param->out_scale, false,
                      qdq_param->q_out_enable);
#if DEBUG
  col_id = get_coreid() >> 16;
  row_id = get_coreid() & 0xf;
  if (col_id == 0 && row_id == 2) {
    v64bfloat16 *out_ptr = (v64bfloat16 *)output;
    for (int idx = 0; idx < 8; idx++) {
      chess_report(*(out_ptr + idx));
    }
  }
#endif
}

#endif //__MATADD_KERNEL_WRAPPER_C__
