#ifndef __WRAPPER_SLICE_CC__
#define __WRAPPER_SLICE_CC__
#include "slice_16b_inner_range_impl.hpp"
#include "qdq/wrapper_qdq.cc"
#include "qdq/qdq_sum.hpp"
#include "qdq/qdq_int16_bfloat16.hpp"
#include "qdq/qdq.cc"

// Slice wrapper with an optional dequantize/quantize pre-pass whose output
// element width is chosen by LayerParams::output_dtype.
//
// Buffers taken from `args`:
//   params_data   - LayerParams blob (field order below is the wire layout).
//   s2mm_ch0_data - input sub-volume.
//   s2mm_ch1_data - QDQ parameter words, decoded below.
//   mm2s_ch0_data - destination of the slice.
//
// NOTE(review): nothing is written to the output when output_dtype is
// neither 1 nor 2 - confirm the producer never emits another value.
void run_slice_a8(KernelArgs& args)
{
    // Do not reorder or resize these fields: the struct is overlaid directly
    // on the raw params_data bytes.
    struct LayerParams
    {
        int32_t scratch_addr;
        int subv_elems;
        int32_t input_dtype;
        int32_t output_dtype;
        int32_t is_signed;
        SliceParams kernel_params;
    };
    LayerParams* layer_params = static_cast<LayerParams*>(args.params_data);
    SliceParams& kernel_params = layer_params->kernel_params;

    int8_t* input = static_cast<int8_t*>(args.s2mm_ch0_data);
    int8_t* scratch_buffer = static_cast<int8_t*>(conv_to_local_ptr(layer_params->scratch_addr));

    // Each QDQ field occupies its own 32-bit word; only the leading 16 bits
    // of the word are read.
    int32_t* qdq_prm = static_cast<int32_t*>(args.s2mm_ch1_data);

    uint16_t dq_zp        = reinterpret_cast<uint16_t*>(&qdq_prm[0])[0];
    bfloat16 dq_sc        = reinterpret_cast<bfloat16*>(&qdq_prm[1])[0];
    uint16_t q_zp         = reinterpret_cast<uint16_t*>(&qdq_prm[2])[0];
    bfloat16 q_sc         = reinterpret_cast<bfloat16*>(&qdq_prm[3])[0];
    uint16_t dq_enable    = reinterpret_cast<uint16_t*>(&qdq_prm[4])[0];
    uint16_t q_enable     = reinterpret_cast<uint16_t*>(&qdq_prm[5])[0];

    // Normalize the flags before combining them. A bitwise XOR of the raw
    // 16-bit values reports "exactly one enabled" for two nonzero but unequal
    // flags (e.g. 1 and 2), which disagrees with the truthiness tests used
    // for the other buffer selections.
    const bool dq_on = (dq_enable != 0);
    const bool q_on  = (q_enable != 0);

    // Buffer routing (dequant: matA -> matB, quant: matC -> matD):
    //   neither enabled : input -> input,   input   -> input
    //   dequant only    : input -> scratch, scratch -> scratch
    //   quant only      : input -> scratch, input   -> scratch
    //   both enabled    : input -> scratch, scratch -> input
    int8_t* matA = input;
    int8_t* matB = (!dq_on && !q_on) ? input : scratch_buffer;
    int8_t* matC = dq_on ? scratch_buffer : input;
    int8_t* matD = (dq_on != q_on) ? scratch_buffer : input;
    bool is_signed = bool(layer_params->is_signed);

    // Both helpers are always called and receive the raw enable flag.
    // NOTE(review): presumably each helper skips its work when its flag is
    // zero, and the trailing `false` selects a helper option - confirm both
    // against the qdq headers.
    dequant_int16_to_bf16(matA, matB, layer_params->subv_elems, dq_zp, dq_sc, is_signed, dq_enable, false);
    quant_bf16_to_int16(matC, matD, layer_params->subv_elems, q_zp, q_sc, is_signed, q_enable, false);

    if (layer_params->output_dtype == 1)
    {
        // Slice from the scratch buffer only when input_dtype is 2;
        // otherwise slice straight from the input buffer.
        int8_t* input_slice = (layer_params->input_dtype == 2) ? scratch_buffer : input;
        int8_t* output = static_cast<int8_t*>(args.mm2s_ch0_data);
        slice_16b_inner_range<2, int8_t>(
            input_slice,
            output,
            kernel_params
        );
    }
    else if (layer_params->output_dtype == 2)
    {
        // This path always slices from the scratch buffer.
        int16_t* input_slice = reinterpret_cast<int16_t*>(scratch_buffer);
        int16_t* output = static_cast<int16_t*>(args.mm2s_ch0_data);
        slice_16b_inner_range<2, int16_t>(
            input_slice,
            output,
            kernel_params
        );
    }
}


// Entry point for the 16-bit slice layer: calls the dequantize and quantize
// helpers in place on the input sub-volume, then slices it into the output.
//
// Buffers taken from `args`:
//   params_data   - SliceLayerParams blob.
//   s2mm_ch0_data - int16 input, also used as the in-place QDQ buffer.
//   s2mm_ch1_data - QDQ parameter words.
//   mm2s_ch0_data - int16 slice output.
void run_slice(KernelArgs& args)
{
    // Overlaid directly on the raw params_data bytes; keep the field order.
    struct SliceLayerParams
    {
        int32_t subv_elems;
        int32_t is_signed;
        SliceParams kernel_params;
    };
    SliceLayerParams* cfg = static_cast<SliceLayerParams*>(args.params_data);

    int16_t* src = static_cast<int16_t*>(args.s2mm_ch0_data);
    int16_t* dst = static_cast<int16_t*>(args.mm2s_ch0_data);

    // One 32-bit word per QDQ field; only the leading 16 bits of each word
    // are read.
    int32_t* qdq_words = static_cast<int32_t*>(args.s2mm_ch1_data);
    uint16_t dequant_zero_point = *reinterpret_cast<uint16_t*>(qdq_words + 0);
    bfloat16 dequant_scale      = *reinterpret_cast<bfloat16*>(qdq_words + 1);
    uint16_t quant_zero_point   = *reinterpret_cast<uint16_t*>(qdq_words + 2);
    bfloat16 quant_scale        = *reinterpret_cast<bfloat16*>(qdq_words + 3);
    uint16_t dequant_flag       = *reinterpret_cast<uint16_t*>(qdq_words + 4);
    uint16_t quant_flag         = *reinterpret_cast<uint16_t*>(qdq_words + 5);

    bool signed_data = (cfg->is_signed != 0);

    // Source and destination are the same pointer for both helpers, so the
    // input buffer is updated in place.
    int8_t* qdq_buf = reinterpret_cast<int8_t*>(src);
    dequant_int16_to_bf16(qdq_buf, qdq_buf, cfg->subv_elems,
                          dequant_zero_point, dequant_scale, signed_data, dequant_flag);
    quant_bf16_to_int16(qdq_buf, qdq_buf, cfg->subv_elems,
                        quant_zero_point, quant_scale, signed_data, quant_flag);

    // Debugging aid (formerly kept here as commented-out code): on a single
    // core, e.g. (get_coreid() >> 16) == 0 && (get_coreid() & 0xf) == 2,
    // printf the raw params_data bytes, the kernel_params fields
    // (loop_s1, startC, incS0, numS1, incS1, incO1, size1, mask1), the
    // enable flags, subv_elems and is_signed, then chess_report the v64int16
    // input vectors before the slice and the output vectors after it.
    slice_16b_inner_range(src, dst, cfg->kernel_params);
}

#endif