#ifndef __WRAPPER_TRANSPOSE_CC__
#define __WRAPPER_TRANSPOSE_CC__
#include "generic_transpose_impl.hpp"
#include "qdq/qdq_int16_bfloat16.hpp"

// Transposes a tile of int8 elements by forwarding to generic_transpose.
//   input         - source tile, int8 elements
//   output        - destination tile, int8 elements
//   kernel_params - transpose configuration, forwarded unchanged
void run_int8_transpose(int8_t* input, int8_t* output, KernelGenericTransposeParam& kernel_params)
{
    // The buffers already have the int8_t element type generic_transpose should
    // operate on, so no cast is needed.
    generic_transpose(input, output, kernel_params);
}

// Transposes a tile of int16 elements by forwarding to generic_transpose.
//   input         - source tile, passed as a raw byte pointer; holds int16 elements
//   output        - destination tile, passed as a raw byte pointer; receives int16 elements
//   kernel_params - transpose configuration, forwarded unchanged
void run_int16_transpose(int8_t* input, int8_t* output, KernelGenericTransposeParam& kernel_params)
{
    // Reinterpret the byte pointers as 16-bit element pointers so generic_transpose
    // moves whole int16 elements (explicit named cast instead of a C-style cast).
    generic_transpose(
        reinterpret_cast<int16_t*>(input),
        reinterpret_cast<int16_t*>(output),
        kernel_params
    );
}

// Kernel entry point for the 16-bit transpose path.
// Runs (in place on the input buffer) a dequant stage, then a quant stage, each
// passed its own enable flag. It then transposes the 16-bit input into the output buffer.
//
// Buffer bit-widths per configuration (design notes):
//   is_int16          : ifm 16, ofm 16, transpose 16, no scratch buffer (this function).
//   int8, qdq mode 0  : dq only.  ifm 8,  ofm 16, transpose 8, no scratch.
//                       1) transpose 8-bit input -> upper half of the output buffer;
//                       2) dq from that upper half (8-bit) -> 16-bit output buffer.
//   int8, qdq mode 1  : q only.   ifm 16, ofm 8,  transpose 8, no scratch.
//                       1) q 16-bit input -> 8-bit, written back into the input buffer;
//                       2) transpose 8-bit input -> 8-bit output.
//   int8, qdq mode 2  : dq + q.   ifm 8,  ofm 8,  transpose 8, 16-bit scratch
//                       buffer with as many elements as the ifm.
//                       1) dq 8-bit input -> 16-bit scratch;
//                       2) q 16-bit scratch -> 8-bit scratch;
//                       3) transpose 8-bit scratch -> 8-bit output.
//   int8, qdq mode 3  : bypass.   ifm 8,  ofm 8,  transpose 8, no scratch.
//                       1) transpose 8-bit input -> 8-bit output.
// The int8 modes are dispatched by run_transpose_a8.
void run_transpose(KernelArgs& args)
{
    // Vector-unit setup: enable saturation and select the rnd_conv_even rounding mode.
    set_sat();
    set_rnd(rnd_conv_even);

    // Parameter block read from args.params_data. The field order and types define
    // its binary layout, so they must not be reordered or resized.
    struct TransposeLayerParams
    {
        KernelGenericTransposeParam kernel_params;
        int32_t dtype;
        int32_t input_subv_elems;
        int32_t output_subv_elems;
        int32_t is_int16;
        int32_t is_signed;
        int32_t scratch_buf;
    };
    auto* prm = static_cast<TransposeLayerParams*>(args.params_data);
    KernelGenericTransposeParam& transpose_cfg = prm->kernel_params;

    const bool signed_data = prm->is_signed != 0;
    const bool int16_data  = prm->is_int16 != 0;

    // qdq parameter block on s2mm channel 1, one value per 4-byte slot:
    //   byte  0: dequant zero point (uint16)   byte  4: dequant scale (bfloat16)
    //   byte  8: quant zero point   (uint16)   byte 12: quant scale   (bfloat16)
    //   byte 16: dequant enable     (uint16)   byte 20: quant enable  (uint16)
    int8_t* qdq = static_cast<int8_t*>(args.s2mm_ch1_data);
    const int dq_slot = 0;      // 2 values * 4 bytes * input index 0
    const int q_slot  = 2 * 4;  // quant block starts after 2 four-byte slots

    uint16_t* dq_zp     = reinterpret_cast<uint16_t*>(byte_incr(qdq, dq_slot));
    bfloat16* dq_sc     = reinterpret_cast<bfloat16*>(byte_incr(qdq, dq_slot + 4));
    uint16_t* q_zp      = reinterpret_cast<uint16_t*>(byte_incr(qdq, q_slot));
    bfloat16* q_sc      = reinterpret_cast<bfloat16*>(byte_incr(qdq, q_slot + 4));
    uint16_t* dq_enable = reinterpret_cast<uint16_t*>(byte_incr(qdq, q_slot + 8));
    uint16_t* q_enable  = reinterpret_cast<uint16_t*>(byte_incr(qdq, q_slot + 12));

    int8_t* in_buf  = static_cast<int8_t*>(args.s2mm_ch0_data);
    int8_t* out_buf = static_cast<int8_t*>(args.mm2s_ch0_data);

    // Every stage is 16 bits wide, so dequant and quant can both work in place
    // on the input buffer before the 16-bit transpose writes the output.
    dequant_int16_to_bf16(in_buf, in_buf, prm->input_subv_elems,
                          *dq_zp, *dq_sc, signed_data, *dq_enable, int16_data);
    quant_bf16_to_int16(in_buf, in_buf, prm->input_subv_elems,
                        *q_zp, *q_sc, signed_data, *q_enable, int16_data);
    run_int16_transpose(in_buf, out_buf, transpose_cfg);
}

// Kernel entry point for the 8-bit transpose path with optional dequant (dq)
// and quant (q) stages. The qdq mode is selected at run time from the two enable
// flags in the qdq parameter block:
//   dq off        -> mode 1 (q only) or mode 3 (bypass): q in place, then transpose.
//   dq on,  q off -> mode 0 (dq only): transpose, then dq into the output buffer.
//   dq on,  q on  -> mode 2 (dq + q): dq and q through a scratch buffer, then transpose.
void run_transpose_a8(KernelArgs& args)
{
    // Vector-unit setup: enable saturation and select the rnd_conv_even rounding mode.
    set_sat();
    set_rnd(rnd_conv_even);

    // Parameter block read from args.params_data. The field order and types define
    // its binary layout, so they must not be reordered or resized.
    struct TransposeA8Params
    {
        KernelGenericTransposeParam kernel_params;
        int32_t dtype;
        int32_t input_subv_elems;
        int32_t output_subv_elems;
        int32_t is_int16;
        int32_t is_signed;
        int32_t scratch_buf;
    };
    auto* prm = static_cast<TransposeA8Params*>(args.params_data);
    KernelGenericTransposeParam& transpose_cfg = prm->kernel_params;

    // qdq parameter block on s2mm channel 1, one value per 4-byte slot:
    //   byte  0: dequant zero point (uint16)   byte  4: dequant scale (bfloat16)
    //   byte  8: quant zero point   (uint16)   byte 12: quant scale   (bfloat16)
    //   byte 16: dequant enable     (uint16)   byte 20: quant enable  (uint16)
    int8_t* qdq = static_cast<int8_t*>(args.s2mm_ch1_data);
    const int dq_slot = 0;      // 2 values * 4 bytes * input index 0
    const int q_slot  = 2 * 4;  // quant block starts after 2 four-byte slots

    uint16_t* dq_zp     = reinterpret_cast<uint16_t*>(byte_incr(qdq, dq_slot));
    bfloat16* dq_sc     = reinterpret_cast<bfloat16*>(byte_incr(qdq, dq_slot + 4));
    uint16_t* q_zp      = reinterpret_cast<uint16_t*>(byte_incr(qdq, q_slot));
    bfloat16* q_sc      = reinterpret_cast<bfloat16*>(byte_incr(qdq, q_slot + 4));
    uint16_t* dq_enable = reinterpret_cast<uint16_t*>(byte_incr(qdq, q_slot + 8));
    uint16_t* q_enable  = reinterpret_cast<uint16_t*>(byte_incr(qdq, q_slot + 12));

    const bool int16_data  = prm->is_int16 != 0;
    const bool signed_data = prm->is_signed != 0;

    int8_t* in_buf  = static_cast<int8_t*>(args.s2mm_ch0_data);
    int8_t* out_buf = static_cast<int8_t*>(args.mm2s_ch0_data);

    // Modes 1 (q only) and 3 (both bypassed): dq and q run in place on the input,
    // then the 8-bit result is transposed straight into the output.
    if (!*dq_enable)
    {
        // dq is called with its enable flag, which is 0 on this path.
        dequant_int16_to_bf16(in_buf, in_buf, prm->input_subv_elems,
                              *dq_zp, *dq_sc, signed_data, *dq_enable, int16_data);
        quant_bf16_to_int16(in_buf, in_buf, prm->input_subv_elems,
                            *q_zp, *q_sc, signed_data, *q_enable, int16_data);
        run_int8_transpose(in_buf, out_buf, transpose_cfg);
        return;
    }

    // Mode 0 (dq only): the output buffer holds 2 bytes per element. Its upper half
    // (starting output_subv_elems bytes in) stages the 8-bit transpose result,
    // which dq then expands into the full 16-bit output.
    if (!*q_enable)
    {
        int8_t* staging = byte_incr(out_buf, prm->output_subv_elems);
        run_int8_transpose(in_buf, staging, transpose_cfg);
        dequant_int16_to_bf16(staging, out_buf, prm->output_subv_elems,
                              *dq_zp, *dq_sc, signed_data, *dq_enable, int16_data);
        // q is called with its enable flag, which is 0 on this path; presumably a
        // no-op (marked redundant in earlier notes) - confirm before removing.
        quant_bf16_to_int16(out_buf, out_buf, prm->output_subv_elems,
                            *q_zp, *q_sc, signed_data, *q_enable, int16_data);
        return;
    }

    // Mode 2 (dq + q): dq expands the input into the scratch buffer, q narrows it
    // back in place, and the 8-bit scratch contents are transposed into the output.
    int8_t* scratch = static_cast<int8_t*>(conv_to_local_ptr(prm->scratch_buf));
    dequant_int16_to_bf16(in_buf, scratch, prm->input_subv_elems,
                          *dq_zp, *dq_sc, signed_data, *dq_enable, int16_data);
    quant_bf16_to_int16(scratch, scratch, prm->input_subv_elems,
                        *q_zp, *q_sc, signed_data, *q_enable, int16_data);
    run_int8_transpose(scratch, out_buf, transpose_cfg);
}

#endif
