#ifndef RUN_BDCASTMUL_WRAPPER_CC
#define RUN_BDCASTMUL_WRAPPER_CC

#include "common.hh"
#include "aie_api/aie.hpp"
#include "broadcast/mul_bf16x16.hpp"
#include "broadcast/mul_bf16x16_impl.hpp"
#include "q/q_impl.hpp"
#include "dq/dq_impl.hpp"
using namespace aie;

// Layer parameters for run_bdcastmul_16, read directly from args.params_data.
// The struct is byte-packed (no padding). Its layout presumably has to match whatever
// writes params_data (host side) byte-for-byte, so do not reorder, resize or insert fields.
#pragma pack(push,1)
struct bdcast_mul16_layer_params{
    uint32_t ifm_a_elements;     // number of IFM A elements; sizes IFM A's region and its dq temp offset
    uint32_t ifm_b_elements;     // number of IFM B elements; sizes IFM B's dq temp offset
    uint32_t core_qbuf_offset;   // byte offset from the QDQ params block (s2mm_ch1) to the q buffer
    uint32_t core_dqbuf_offset;  // byte offset from the QDQ params block to the interleaved dq A / dq B buffers
    uint32_t dq_a_inner_g;       // copied into IFM A's KernelDqParam::inner_g (presumably a group size - confirm)
    uint32_t dq_b_inner_g;       // copied into IFM B's KernelDqParam::inner_g
    uint32_t q_inner_g;          // copied into the output KernelQParam::inner_g
    uint32_t has_scalar_broadcast; // forwarded as the first argument of mul_bf16x16_nontemplatized
    uint32_t sign_A;             // sign flag for IFM A dequantization (presumably nonzero = signed - confirm)
    uint32_t sign_W;             // reused as the sign flag for IFM B dequantization (there are no weights)
    uint32_t sign_O;             // sign flag for output quantization
    uint32_t is_input_16_bit;    // nonzero: 2-byte inputs, dequantized in place; zero: 1-byte inputs + separate temp region
    uint32_t is_output_16_bit;   // forwarded to q_float16_to_int16_v32
    KernelMulBf16x16Param mul_params; // forwarded unchanged to mul_bf16x16_nontemplatized
};
#pragma pack(pop)

/*
 * run_bdcastmul_16 - element-wise multiply of IFM A and IFM B (optionally with
 * scalar broadcast) wrapped in dequantize / quantize stages.
 *
 * Steps:
 *   1. Dequantize IFM A and IFM B (dq_float16_v32) into QDQFloatType temp buffers.
 *   2. Multiply them (mul_bf16x16_nontemplatized), writing QDQFloatType results to
 *      the output buffer.
 *   3. Quantize the product in place in the output buffer (q_float16_to_int16_v32).
 *
 * Inputs:
 *   args.params_data   -> bdcast_mul16_layer_params
 *   args.s2mm_ch0_port -> points at the start of IFM B (IFM A precedes it)
 *   args.s2mm_ch1_port -> BinaryQDQParams block holding the q / dq buffers
 *   args.mm2s_ch0_port -> output buffer
 */
void run_bdcastmul_16(KernelArgs& args)
{
    bdcast_mul16_layer_params* layer_params = static_cast<bdcast_mul16_layer_params*>(args.params_data);

    // Round a byte count up to the next multiple of 128 bytes.
    auto align_up_128 = [](int bytes) { return ((bytes + 127) / 128) * 128; };

    const bool input_is_1_byte = !layer_params->is_input_16_bit;

    /*
    Input layout:
    1 byte inputs:
    | IFM A (int 8) |     IFM A (2-byte dq temp)    | IFM B (int 8) |     IFM B (2-byte dq temp)    |
    2 byte inputs:
    |       IFM A (int 16)        |       IFM B (int 16)        |
    With 1-byte inputs each operand needs an extra 2-byte-per-element temp region
    for its dequantized values, i.e. 3 bytes per element instead of 2.
    */
    const int bytes_per_element = input_is_1_byte ? 3 : 2;
    // 384 elements per dq iteration x 2 bytes (regardless of input width);
    // presumably head-room so dq on IFM A cannot spill into IFM B.
    constexpr int DQ_PADDING = 768;

    // Byte distance from the start of IFM A to the start of IFM B (128-byte aligned + padding).
    const int ifm_a_offset =
        align_up_128(static_cast<int>(layer_params->ifm_a_elements * bytes_per_element)) + DQ_PADDING;

    // IFM A is sent first, then IFM B, so the port points at the start of IFM B.
    uint16_t* matB = static_cast<uint16_t*>(args.s2mm_ch0_port->data());
    uint16_t* matA = byte_incr(matB, -ifm_a_offset);
    uint16_t* output = static_cast<uint16_t*>(args.mm2s_ch0_port->data());

    // Dequantized values start after the raw 1-byte data (128-byte aligned).
    // For 2-byte inputs the offset is 0, so dequantization happens in place.
    const int a_tmp_offset = input_is_1_byte ? align_up_128(static_cast<int>(layer_params->ifm_a_elements)) : 0;
    const int b_tmp_offset = input_is_1_byte ? align_up_128(static_cast<int>(layer_params->ifm_b_elements)) : 0;
    uint16_t* matA_tmp = byte_incr(matA, a_tmp_offset);
    uint16_t* matB_tmp = byte_incr(matB, b_tmp_offset);

    BinaryQDQParams* qdq_prm = reinterpret_cast<BinaryQDQParams*>(args.s2mm_ch1_port->data());

    // Value-initialize the kernel parameter structs so that any field not assigned
    // below is zero instead of indeterminate.
    KernelDqParam dq_a_krn_param{};
    dq_a_krn_param.inner_g = layer_params->dq_a_inner_g;
    dq_a_krn_param.sign_A = layer_params->sign_A;
    KernelDqParam dq_b_krn_param{};
    dq_b_krn_param.inner_g = layer_params->dq_b_inner_g;
    dq_b_krn_param.sign_A = layer_params->sign_W; // there are no weights, so sign_W is reused for matB

    // Per-core q / dq buffers at fixed byte offsets from the QDQ params block.
    v32accfloat* q_buf = (v32accfloat*)byte_incr(qdq_prm, layer_params->core_qbuf_offset);
    v32accfloat* dq_a_buf = (v32accfloat*)byte_incr(qdq_prm, layer_params->core_dqbuf_offset);
    // v32accfloat is 128 bytes and dq_a / dq_b buffers are interleaved (host/broadcast.hpp),
    // so dq_b's buffer is the next v32accfloat.
    v32accfloat* dq_b_buf = dq_a_buf + 1;

    dq_float16_v32((int8_t*) matA, (float*) dq_a_buf, (QDQFloatType*) matA_tmp, dq_a_krn_param, qdq_prm->dq_enable, layer_params->is_input_16_bit);
    dq_float16_v32((int8_t*) matB, (float*) dq_b_buf, (QDQFloatType*) matB_tmp, dq_b_krn_param, qdq_prm->dq_enable, layer_params->is_input_16_bit);

    mul_bf16x16_nontemplatized(
        layer_params->has_scalar_broadcast,
        (QDQFloatType*) matA_tmp, (QDQFloatType*) matB_tmp, (QDQFloatType*) output, layer_params->mul_params
    );

    // Quantize in place: output is both the QDQFloatType source and the int16 destination.
    KernelQParam q_krn_param{};
    q_krn_param.inner_g = layer_params->q_inner_g;
    q_krn_param.sign_O = layer_params->sign_O;
    q_float16_to_int16_v32((QDQFloatType*) output, (float*) q_buf, (int16*) output, q_krn_param, qdq_prm->q_enable, layer_params->is_output_16_bit);
}

#endif