// Description:
// Each concat input requires its own dequantize (dq) pass. However, the qdq kernels
// work on element counts aligned to 64, and the qdq input and output addresses have
// to be 64-byte aligned. The analysis below explains how the input/output addresses
// and the total element counts are calculated under those constraints.


/* in the concat kernel layer params.  -- currently only for two inputs

----> Theoretical calculation

                |->in_addr_A, 64bytes aligned                       |->in_addr_B, not 64bytes aligned
----------------|---------------------------------------------------|---------------------------------------------------------------------------------|
                |                  input A                          |                                  input B                                        |
----------------|-------------------------------------------|---|---|-----------------------------------------------------------------------------|---|
    |->dq_out_addrA, 64bytes aligned                        |   |->dq_B_start here for 64bytes aligned                                            |
    |                                                       |->dq_out_addr_B, 64 bytes aligned                                                    |
    |                                                       |-> end_dq_A, not corrupt B                                                            |-> end_dq_B(need correction)
    |                                                           |                                                                                 |
    |->q_in_start                                               |                                                                                 |-> end_q
    |->concat_A_in(scalar, 64bytes not needed)                  |->concat_B_in(scalar, 64bytes not needed)

(end_dq_A - dq_out_addr_A) = iceil(total_A_bytes, 64) = inputA_elems * ifm_bits//8
(end_dq_B - dq_B_start)   = iceil(total_B_bytes, 64) = inputB_elems * ifm_bits//8  --> this might miss samples. => inputB_elems += (in_addr_B - dq_B_start)

1-1) dq-int16:
    dq_A_in  = ifm                                                       // 64-byte aligned, which is guaranteed by the dataflow address generation
    dq_A_out = scratch                                                   // either reuse ifm or a scratch buf but it will be 64bytes aligned
    dq_B_in  = ifm + offset_bytes                                        // offset_bytes might not be 64bytes aligned -- it is the inputA subv size
            dq_B_in_start   = ifloor(ifm + offset_bytes, 64)             //dq_B_in_start is the first sample dq B taken
                            = ifm + ifloor(offset_bytes, 64)
                            = ifm + offset_bytes - constant_B            // constant_B = offset_bytes - ifloor(offset_bytes, 64)
    dq_B_out = dq_A_out + offset_bytes                                   // might not be 64bytes aligned.
            dq_B_out_start = ifloor(dq_A_out + offset_bytes, 64)         //dq_B_out_start is the first valid sample dq B output
                           = dq_A_out + ifloor(offset_bytes, 64)
                           = dq_A_out + offset_bytes - constant_B
1-2) dq-int8:
    dq_A_in  = ifm                                                       // 64-byte aligned, which is guaranteed by the dataflow address generation
    dq_A_out = scratch                                                   // either reuse ifm or a scratch buf but it will be 64bytes aligned
    dq_B_in  = ifm + offset_bytes                                        // offset_bytes might not be 64bytes aligned
            dq_B_in_start   = ifloor(ifm + offset_bytes, 64)             //dq_B_in_start is the first sample dq B taken
                            = ifm + ifloor(offset_bytes, 64)
                            = ifm + offset_bytes - constant_B            // constant_B = offset_bytes - ifloor(offset_bytes, 64)
    dq_B_out = dq_A_out + offset_bytes *2                                // might not be 64bytes aligned. why *2, because dq doing 8bit->16bits conversion
            dq_B_out_start = ifloor(dq_A_out + offset_bytes*2, 64)        //dq_B_out_start is the first valid sample dq B output
                           = dq_A_out + ifloor(offset_bytes*2, 64)
                           = dq_A_out + offset_bytes*2 - 2*constant_B
2-1) q-int16:
    q_in     = dq_A_out                                                  // 64-byte aligned, and this is the first valid sample of input A as well
    q_out    = dq_A_out                                                  // 64-byte aligned
            total_q_sample = inputA_elems + inputB_elems
                            + iceil(constant_B, 64)                      // iceil(constant_B, 64) compensates for the last non-64-aligned part of B,
                                                                         // because end_dq_B ends early when dq_B_start starts early
2-2) q-int8:
    q_in     = dq_A_out                                                  // 64-byte aligned, and this is the first valid sample of input A as well
    q_out    = dq_A_out                                                  // 64-byte aligned
            total_q_sample = inputA_elems + inputB_elems
                            + iceil(constant_B, 64)                      // iceil(constant_B, 64) compensates for the last non-64-aligned part of B,
                                                                         // because end_dq_B ends early when dq_B_start starts early
3-1) concat:
    concat_A = q_out = dq_A_out
    concat_B = dq_B_out + constant_B = dq_A_out + offset_bytes - constant_B + constant_B
            = dq_A_out + offset_bytes

----> end of Theoretical calculation

*/

/*
----> list of all conditions for easy understanding (refer to the Theoretical calculation above for why)

1. if int16

  a) no qdq (dq_offset = 0; constant_B = 0)
  b) dq only (dq_offset, constant_B calculated as above)
  c)  q only (dq_offset = 0; constant_B = 0)
  d) dq + q  (dq_offset, constant_B calculated as above)

  **ifm     = static_cast<int8_t*>(args.s2mm_ch0_data);
  **scratch = ifm;

    1) int8_t*  dq_A_in  = ifm
    2) int8_t*  dq_A_out = scratch
                dq_A_elem = inputA_elems
    3) int8_t*  dq_B_in  = ifm + offset_bytes
    4) int8_t*  dq_B_out = dq_A_out + offset_bytes
                dq_B_elem = inputB_elems + iceil(constant_B, 64)
    5) int8_t*  q_in   = dq_A_out
    6) int8_t*  q_out  = dq_A_out
                q_elem = inputA_elems + inputB_elems + iceil(constant_B, 64)
    7) int8_t*  concat_in_A = dq_A_out
                concat_in_B = dq_A_out + offset_bytes
    8) int8_t* concat_out = static_cast<int8_t*>(args.mm2s_ch0_data);


2. if int8
     **ifm     = static_cast<int8_t*>(args.s2mm_ch0_data);
  a) no qdq (dq_offset = 0; constant_B = 0)
            **scratch = ifm;
  b) dq only (dq_offset, constant_B calculated as above)
            ** scratch  = static_cast<int8_t*>(conv_to_local_ptr(layer_params->scratch_buf));
  c)  q only (dq_offset = 0; constant_B = 0)
            **scratch = ifm;
  d) dq + q  (dq_offset, constant_B calculated as above)
            ** scratch  = static_cast<int8_t*>(conv_to_local_ptr(layer_params->scratch_buf));

    **scratch = ifm;
    1) int8_t*  dq_A_in  = ifm
    2) int8_t*  dq_A_out = scratch
                dq_A_elem = inputA_elems
    3) int8_t*  dq_B_in  = ifm + offset_bytes
    4) int8_t*  dq_B_out = dq_A_out + offset_bytes
                dq_B_elem = inputB_elems + iceil(constant_B, 64)
    5) int8_t*  q_in   = dq_A_out
    6) int8_t*  q_out  = dq_A_out
                q_elem = inputA_elems + inputB_elems + iceil(constant_B, 64)
    7) int8_t*  concat_in_A = dq_A_out
                concat_in_B = dq_A_out + offset_bytes
    8) int8_t* concat_out = static_cast<int8_t*>(args.mm2s_ch0_data);
----> end of list of all conditions

*/

#ifndef __WRAPPER_CONCAT_CC__
#define __WRAPPER_CONCAT_CC__
#include "concat_16b_inner_impl.hpp"
#include "qdq/qdq_int16_bfloat16.hpp"

// Layer wrapper for the two-input concat kernel, 16-bit concat path.
//
// Sequence: set saturation and rounding mode, decode the layer parameters and
// the qdq parameter table, call the dequantize helper once per input (A, then
// B), call the quantize helper over the working buffer, and finally call
// concat_16b_inner with int16_t views of both concat inputs and the output.
//
// args fields read here:
//   params_data    - LayerParams block (layout declared below)
//   s2mm_ch0_data  - input feature map holding input A followed by input B
//   s2mm_ch1_data  - qdq parameter table
//   mm2s_ch0_data  - concat output buffer
void run_concat(KernelArgs& args)
{
    set_sat();
    set_rnd(rnd_conv_even);

    // Parameter block found at args.params_data; field order is the layout.
    struct LayerParams
    {
        int16_t inputA_elems_64_aligned;
        int16_t inputB_elems_64_aligned;
        int16_t dq_B_in_offset;
        int16_t dq_B_out_offset;
        int16_t is_int16;
        int16_t scratch_sel;
        int16_t q_elems_64_aligned;
        int16_t scratch_buf;
        int16_t concat_B_in_offset;
        int16_t is_signed;
        int16_t reserved_1;
        int16_t reserved_2;
        ConcatParams kernel_params;
    };
    LayerParams* lp = static_cast<LayerParams*>(args.params_data);
    ConcatParams& concat_cfg = lp->kernel_params;

    // qdq parameter table, byte offsets into s2mm_ch1_data:
    //    0: dq zero-point A     4: dq scale A
    //    8: dq zero-point B    12: dq scale B
    //   16: q zero-point       20: q scale
    //   24: dq enable          28: q enable
    int8_t* qdq_tbl = static_cast<int8_t*>(args.s2mm_ch1_data);
    const int field_bytes = 4;
    const int dq_A_base   = 0;
    const int dq_B_base   = 2 * field_bytes;
    const int q_base      = 4 * field_bytes;

    uint16_t* zp_A    = reinterpret_cast<uint16_t*>(byte_incr(qdq_tbl, dq_A_base));
    bfloat16* scale_A = reinterpret_cast<bfloat16*>(byte_incr(qdq_tbl, dq_A_base + field_bytes));
    uint16_t* zp_B    = reinterpret_cast<uint16_t*>(byte_incr(qdq_tbl, dq_B_base));
    bfloat16* scale_B = reinterpret_cast<bfloat16*>(byte_incr(qdq_tbl, dq_B_base + field_bytes));
    uint16_t* zp_q    = reinterpret_cast<uint16_t*>(byte_incr(qdq_tbl, q_base));
    bfloat16* scale_q = reinterpret_cast<bfloat16*>(byte_incr(qdq_tbl, q_base + field_bytes));
    uint16_t* dq_on   = reinterpret_cast<uint16_t*>(byte_incr(qdq_tbl, q_base + 2 * field_bytes));
    uint16_t* q_on    = reinterpret_cast<uint16_t*>(byte_incr(qdq_tbl, q_base + 3 * field_bytes));

    bool is_int16  = bool(lp->is_int16);
    bool is_signed = bool(lp->is_signed);

    // NOTE: the scratch selection is made here in the kernel; it might not be
    // possible to do it in dataflow because the ifm might be ping-pong buffered.
    int8_t* ifm  = static_cast<int8_t*>(args.s2mm_ch0_data);
    int8_t* work = ifm;   // default: operate in place on the input buffer
    if (bool(lp->scratch_sel))
    {
        work = static_cast<int8_t*>(conv_to_local_ptr(lp->scratch_buf));
    }

    // Element counts and B-side addresses all come precomputed from the layer params.
    int     elems_A   = lp->inputA_elems_64_aligned;
    int     elems_B   = lp->inputB_elems_64_aligned;
    int     elems_q   = lp->q_elems_64_aligned;
    int8_t* dq_src_B  = byte_incr(ifm,  lp->dq_B_in_offset);
    int8_t* dq_dst_B  = byte_incr(work, lp->dq_B_out_offset);
    int8_t* concat_B  = byte_incr(work, lp->concat_B_in_offset);
    int8_t* ofm       = static_cast<int8_t*>(args.mm2s_ch0_data);

    // Per-input dequantize (A: ifm -> work, B: offset ifm -> offset work),
    // followed by one quantize pass done in place on the working buffer.
    dequant_int16_to_bf16(ifm,      work,     elems_A, *zp_A, *scale_A, is_signed, *dq_on, is_int16);
    dequant_int16_to_bf16(dq_src_B, dq_dst_B, elems_B, *zp_B, *scale_B, is_signed, *dq_on, is_int16);
    quant_bf16_to_int16(work, work, elems_q, *zp_q, *scale_q, is_signed, *q_on, is_int16);

    // Concat reads A from the start of the working buffer and B at its own offset.
    concat_16b_inner(reinterpret_cast<int16_t*>(work),
                     reinterpret_cast<int16_t*>(concat_B),
                     reinterpret_cast<int16_t*>(ofm),
                     concat_cfg);
}

// Layer wrapper for the two-input concat kernel, 8-bit concat path.
//
// Follows the same sequence as run_concat (saturation/rounding setup, layer
// parameter and qdq table decoding, one dequantize call per input, one
// quantize call over the working buffer); the difference is that
// concat_16b_inner is called with int8_t pointers.
//
// args fields read here:
//   params_data    - LayerParams block (layout declared below)
//   s2mm_ch0_data  - input feature map holding input A followed by input B
//   s2mm_ch1_data  - qdq parameter table
//   mm2s_ch0_data  - concat output buffer
void run_concat_a8(KernelArgs& args)
{
    set_sat();
    set_rnd(rnd_conv_even);

    // Parameter block found at args.params_data; field order is the layout.
    struct LayerParams
    {
        int16_t inputA_elems_64_aligned;
        int16_t inputB_elems_64_aligned;
        int16_t dq_B_in_offset;
        int16_t dq_B_out_offset;
        int16_t is_int16;
        int16_t scratch_sel;
        int16_t q_elems_64_aligned;
        int16_t scratch_buf;
        int16_t concat_B_in_offset;
        int16_t is_signed;
        int16_t reserved_1;
        int16_t reserved_2;
        ConcatParams kernel_params;
    };
    LayerParams* lp = static_cast<LayerParams*>(args.params_data);
    ConcatParams& concat_cfg = lp->kernel_params;

    bool is_signed = bool(lp->is_signed);
    bool is_int16  = bool(lp->is_int16);

    // The qdq table starts with one (zero-point, scale) pair per input; the
    // output quantizer pair and the two enable words follow. Every field is
    // spaced 4 bytes apart, so the fields sit at byte offsets 0, 4, ..., 28.
    int8_t* qdq_tbl = static_cast<int8_t*>(args.s2mm_ch1_data);
    const int stride        = 4;
    const int pair_bytes    = 2 * stride;       // one zero-point + one scale
    const int quant_section = 4 * stride;       // first byte after the two dq pairs

    uint16_t* zp_A    = reinterpret_cast<uint16_t*>(byte_incr(qdq_tbl, 0 * pair_bytes));
    bfloat16* scale_A = reinterpret_cast<bfloat16*>(byte_incr(qdq_tbl, 0 * pair_bytes + stride));
    uint16_t* zp_B    = reinterpret_cast<uint16_t*>(byte_incr(qdq_tbl, 1 * pair_bytes));
    bfloat16* scale_B = reinterpret_cast<bfloat16*>(byte_incr(qdq_tbl, 1 * pair_bytes + stride));

    uint16_t* zp_q    = reinterpret_cast<uint16_t*>(byte_incr(qdq_tbl, quant_section));
    bfloat16* scale_q = reinterpret_cast<bfloat16*>(byte_incr(qdq_tbl, quant_section + stride));
    uint16_t* dq_on   = reinterpret_cast<uint16_t*>(byte_incr(qdq_tbl, quant_section + 2 * stride));
    uint16_t* q_on    = reinterpret_cast<uint16_t*>(byte_incr(qdq_tbl, quant_section + 3 * stride));

    // NOTE: the scratch selection is made here in the kernel; it might not be
    // possible to do it in dataflow because the ifm might be ping-pong buffered.
    int8_t* ifm  = static_cast<int8_t*>(args.s2mm_ch0_data);
    int8_t* work = ifm;   // default: operate in place on the input buffer
    if (bool(lp->scratch_sel))
    {
        work = static_cast<int8_t*>(conv_to_local_ptr(lp->scratch_buf));
    }

    // Element counts and B-side addresses all come precomputed from the layer params.
    int     elems_A  = lp->inputA_elems_64_aligned;
    int8_t* dq_src_B = byte_incr(ifm,  lp->dq_B_in_offset);
    int8_t* dq_dst_B = byte_incr(work, lp->dq_B_out_offset);
    int     elems_B  = lp->inputB_elems_64_aligned;
    int     elems_q  = lp->q_elems_64_aligned;
    int8_t* concat_B = byte_incr(work, lp->concat_B_in_offset);
    int8_t* ofm      = static_cast<int8_t*>(args.mm2s_ch0_data);

    // Per-input dequantize, then a single in-place quantize over the working buffer.
    dequant_int16_to_bf16(ifm,      work,     elems_A, *zp_A, *scale_A, is_signed, *dq_on, is_int16);
    dequant_int16_to_bf16(dq_src_B, dq_dst_B, elems_B, *zp_B, *scale_B, is_signed, *dq_on, is_int16);
    quant_bf16_to_int16(work, work, elems_q, *zp_q, *scale_q, is_signed, *q_on, is_int16);

    // The buffers are already int8_t*, so no pointer conversion is needed here.
    concat_16b_inner(work, concat_B, ofm, concat_cfg);
}

#endif
