#if !ASM_MODE
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif // !ASM_MODE
#ifdef __AIESIM__
#if !ASM_MODE
#include "dma.hpp"
#endif // !ASM_MODE
#endif // __AIESIM__

#include <stdexcept>

using namespace std;
#include <variant>
#include <type_traits>

#include "common.hpp"
#include "broadcast.hpp"

/// Parameters for the 8-bit broadcast test bench (run_broadcast_8).
/// Plain aggregate: it is brace-initialised positionally, so the declaration
/// order of the members below is part of its interface.
struct BroadcastParams8 {
    // Operand A geometry.
    int ifm_a_n;
    int ifm_a_chs;
    int ifm_a_chs_orig;
    int ifm_a_cols;
    int ifm_a_rows;

    // Operand B geometry.
    int ifm_b_n;
    int ifm_b_chs;
    int ifm_b_chs_orig;
    int ifm_b_cols;
    int ifm_b_rows;

    // Output geometry.
    int ofm_n;
    int ofm_chs;
    int ofm_chs_orig;
    int ofm_cols;
    int ofm_rows;

    // Element sizes and size of the weight/parameter region.
    int ifm_bytes;
    int ofm_bytes;
    int wgt_size;

    // Shift parameters written into the WGT buffer.
    int shift_in;
    int shift_in1;
    int shift_res;

    // Run-control flags.
    int read_ifm;
    int op_type;
    int sign_A;
    int sign_O;
    int a_on_wgt;
    int b_on_wgt;
    int has_scalar_broadcast;
    int debug_mode;     // 1 enables CPU reference path
};

#if !ASM_MODE
// AIE compute graph instance driven by the simulator paths of the runners
// (only defined when building against the adf toolchain, i.e. !ASM_MODE).
ComputeGraph g_compute_graph;
#endif // !ASM_MODE

// Input activation binaries for operands A and B.
const std::string ifm_bin_path1{"../intermediate_bins/ifm1.bin"};
const std::string ifm_bin_path2{"../intermediate_bins/ifm2.bin"};


/// 8-bit broadcast test bench: builds operands A and B (random, or read from
/// ifm_bin_path1/ifm_bin_path2 when read_ifm is set), optionally computes a
/// CPU reference (debug_mode, ADD only), then either dumps ifm/wgt/ofm
/// binaries (ASM_MODE) or runs the AIE graph and compares (__AIESIM__).
///
/// Buffer layout: the IFM buffer holds A (unless a_on_wgt) followed by B
/// (unless b_on_wgt); an operand placed on WGT starts at the beginning of the
/// WGT buffer, whose first three ints also carry the shift parameters.
///
/// @throws std::invalid_argument if both a_on_wgt and b_on_wgt are set.
/// @return 0
int run_broadcast_8(const BroadcastParams8& params)
{
    // This path appears to be intentionally disabled: debug builds abort here.
    assert(false);

    // With both operands on WGT, A and B would alias the same buffer start.
    // Reject it, matching run_broadcast_16.
    if (params.a_on_wgt && params.b_on_wgt) {
        throw std::invalid_argument("Both A_ON_WGT and B_ON_WGT cannot be enabled simultaneously");
    }

    int ifm_a_size = ActTensor<int8_t>::size(params.ifm_a_chs, params.ifm_a_n * params.ifm_a_rows, params.ifm_a_cols);
    int ifm_b_size = ActTensor<int8_t>::size(params.ifm_b_chs, params.ifm_b_n * params.ifm_b_rows, params.ifm_b_cols);
    int ofm_size = ActTensor<int8_t>::size(params.ofm_chs, params.ofm_n * params.ofm_rows, params.ofm_cols);
    int total_ifm_size = (!params.b_on_wgt * ifm_b_size) + (!params.a_on_wgt * ifm_a_size);
    int total_wgt_size = params.wgt_size + (params.a_on_wgt * ifm_a_size) + (params.b_on_wgt * ifm_b_size);

#if !ASM_MODE
    auto aie_ifm = static_cast<int8_t*>(adf::GMIO::malloc(total_ifm_size));
    auto aie_wgt = static_cast<int8_t*>(adf::GMIO::malloc(total_wgt_size));
#else
    auto aie_ifm = static_cast<int8_t*>(malloc(total_ifm_size));
    auto aie_wgt = static_cast<int8_t*>(malloc(total_wgt_size));

#endif // !ASM_MODE

    auto aie_matA = params.a_on_wgt ? static_cast<int8_t*>(aie_wgt) : static_cast<int8_t*>(aie_ifm);
    auto aie_matB = params.b_on_wgt ? static_cast<int8_t*>(aie_wgt) : (static_cast<int8_t*>(aie_ifm) + (ifm_a_size * !params.a_on_wgt));

    // The shift parameters occupy the first three ints of the WGT buffer.
    // NOTE(review): when a_on_wgt or b_on_wgt is set, that operand also starts
    // at aie_wgt. These values then overlap it and are overwritten when the
    // operand is filled below. Confirm the intended WGT layout.
    int* shift_in_val = (int*)aie_wgt;
    shift_in_val[0] = params.shift_in;
    shift_in_val[1] = params.shift_in1;
    shift_in_val[2] = params.shift_res;

    ActTensor<int8_t> aie_ifm_a(
        params.ifm_a_chs, params.ifm_a_n * params.ifm_a_rows, params.ifm_a_cols,
        aie_matA
    );

    ActTensor<int8_t> aie_ifm_b(
        params.ifm_b_chs, params.ifm_b_n * params.ifm_b_rows, params.ifm_b_cols,
        aie_matB
    );

    // CPU reference output; only allocated when the reference path runs.
    ActTensor<int8_t> cpu_ofm(
        params.ofm_chs, params.ofm_n * params.ofm_rows, params.ofm_cols,
        params.debug_mode ? malloc(ofm_size) : nullptr
    );

#if !ASM_MODE
    printf("Creating AIE OFM 1 of size %d with ofm_chs=%d, ofm_rows=%d, ofm_cols=%d\n", ofm_size, params.ofm_chs, params.ofm_n * params.ofm_rows, params.ofm_cols);
    ActTensor<int8_t> aie_ofm(
        params.ofm_chs, params.ofm_n * params.ofm_rows, params.ofm_cols,
        adf::GMIO::malloc(ofm_size)
    );
#else
    ActTensor<int8_t> aie_ofm(
        params.ofm_chs, params.ofm_n * params.ofm_rows, params.ofm_cols,
        malloc(ofm_size)
    );
#endif // !ASM_MODE

    if (params.read_ifm) {
        read_bin_file(ifm_bin_path1, reinterpret_cast<char*>(aie_matA), ifm_a_size);
        read_bin_file(ifm_bin_path2, reinterpret_cast<char*>(aie_matB), ifm_b_size);
    }
    else {
        int min_val = (params.sign_A == 1 && params.sign_O == 1) ? -4 : 4;
        init_tensor_random(aie_ifm_a, params.ifm_a_chs_orig, min_val, 8);
        init_tensor_random(aie_ifm_b, params.ifm_b_chs_orig, min_val, 8);
    }

    if (params.debug_mode) {
        switch(params.op_type){
            case ADD:
                add_broadcast_8(aie_ifm_a, aie_ifm_b, cpu_ofm, params.shift_in, params.shift_in1, params.shift_res); break;
            default:
                // No CPU reference exists for this op; cpu_ofm stays uninitialised.
                printf("WARNING: no 8-bit CPU reference for op_type %d\n", params.op_type);
                break;
        }
    }

#if ASM_MODE
    write_bin_file("ifm.bin", reinterpret_cast<char*>(aie_ifm), total_ifm_size);
    write_bin_file("wgt.bin", reinterpret_cast<char*>(aie_wgt), total_wgt_size);
    if (params.debug_mode) {
        write_bin_file("ofm.bin", reinterpret_cast<char*>(cpu_ofm.data), ofm_size);
    }
#endif

#ifdef __AIESIM__
#if !ASM_MODE
    if (params.debug_mode) {
        log_tensor(aie_ifm_a, "IFM_A =\n");
        log_tensor(aie_ifm_b, "IFM_B =\n");
        log_tensor(cpu_ofm, "CPU OFM =\n");
    }
#if USE_CERT_LIBRARY
    run_cert_sim(g_compute_graph,
                 reinterpret_cast<void*>(aie_ofm.data), ofm_size,
                 reinterpret_cast<void*>(aie_ifm), total_ifm_size,
                 reinterpret_cast<void*>(aie_wgt), total_wgt_size);
#else
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_ofm.data, aie_ifm, aie_wgt);
    g_compute_graph.end();
#endif // USE_CERT_LIBRARY

    if (params.debug_mode) {
        log_tensor(aie_ofm, "AIE OFM =\n");
        int epsilon = 0;
        int err = cmp_tensor(cpu_ofm, aie_ofm, 1, epsilon);
        if (err == 0)
            printf("DI_PASS\n");
        else
            printf("DI_FAIL\n");
    }
#endif // !ASM_MODE
#endif // __AIESIM__

#if !ASM_MODE
    adf::GMIO::free(aie_ifm);
    adf::GMIO::free(aie_wgt);
    adf::GMIO::free(aie_ofm.data);
#else
    free(aie_ifm);
    free(aie_wgt);
    free(aie_ofm.data);  // allocated with malloc above in ASM_MODE
#endif // !ASM_MODE
    if (cpu_ofm.data) free(cpu_ofm.data);
    return 0;
}




/// 16-bit / mixed-dtype broadcast test bench.
///
/// Builds operands A and B (random, or read from ifm_bin_path1/ifm_bin_path2
/// when read_ifm is set) and writes the BinaryQDQParams block into the WGT
/// buffer. When debug_mode is set it computes a CPU reference (optional
/// dequant -> broadcast_16 -> optional quant). It then either dumps
/// ifm/wgt/ofm binaries (ASM_MODE) or runs the AIE graph and compares
/// (__AIESIM__ && !ASM_MODE).
///
/// Buffer layout:
///   IFM buffer: A (unless a_on_wgt) followed by B (unless b_on_wgt).
///   WGT buffer: the on-wgt operand (if any) followed by the QDQ params block;
///               total size = wgt_size + size of the on-wgt operand.
///
/// @tparam Ti  element type of both input operands
/// @tparam To  element type of the output
/// @throws std::invalid_argument for inconsistent dequant/quant/on-wgt settings
/// @throws std::runtime_error if an IFM bin file cannot be read
/// @return 0
template<typename Ti, typename To>
int run_broadcast_16(const BroadcastParams16& params)
{
    if (params.dq_enable && is_float16_type_v<Ti>) {
        throw std::invalid_argument("Dequant should not be enabled for float16 input");
    }
    if (params.q_enable && is_float16_type_v<To>) {
        throw std::invalid_argument("Quant should not be enabled for float16 output");
    }
    if (params.a_on_wgt && params.b_on_wgt){
        throw std::invalid_argument("Both A_ON_WGT and B_ON_WGT cannot be enabled simultaneously");
    }

    // We fold the n and y dimensions because ActTensor only supports 3 dimensions: C, Y, and X.
    // This is fine for the entire test bench except the CPU broadcast function, where we manually
    // handle addressing.
    int ifm_a_n_elements = ActTensor<int8_t>::size(params.ifm_a_chs, params.ifm_a_n * params.ifm_a_rows, params.ifm_a_cols);
    int ifm_a_size = ifm_a_n_elements * sizeof(Ti);
    int ifm_b_n_elements = ActTensor<int8_t>::size(params.ifm_b_chs, params.ifm_b_n * params.ifm_b_rows, params.ifm_b_cols);
    int ifm_b_size = ifm_b_n_elements * sizeof(Ti);
    int ofm_n_elements = ActTensor<int8_t>::size(params.ofm_chs, params.ofm_n * params.ofm_rows, params.ofm_cols);
    int ofm_size = ofm_n_elements * sizeof(To);
    int total_ifm_size = 0;
    int total_wgt_size = params.wgt_size;
    if (params.a_on_wgt){
        total_ifm_size += ifm_b_size;
        total_wgt_size += ifm_a_size;
    } else if (params.b_on_wgt){
        total_ifm_size += ifm_a_size;
        total_wgt_size += ifm_b_size;
    } else {
        total_ifm_size += ifm_a_size + ifm_b_size;
    }

#if !ASM_MODE
    auto aie_ifm = static_cast<void*>(adf::GMIO::malloc(total_ifm_size));
    auto aie_wgt = static_cast<void*>(adf::GMIO::malloc(total_wgt_size));
#else
    auto aie_ifm = static_cast<void*>(malloc(total_ifm_size));
    auto aie_wgt = static_cast<void*>(malloc(total_wgt_size));
#endif // !ASM_MODE
    // Zero the whole WGT buffer so that bytes not written below are deterministic.
    for (int i = 0; i < total_wgt_size; i++) {
        reinterpret_cast<int8_t*>(aie_wgt)[i] = 0;
    }

    Ti* aie_matA = static_cast<Ti*>(params.a_on_wgt ? aie_wgt : aie_ifm);
    Ti* aie_matB = static_cast<Ti*>(params.b_on_wgt ? aie_wgt : (reinterpret_cast<int8_t*>(aie_ifm) + (ifm_a_size * !params.a_on_wgt)));

    ActTensor<Ti> aie_ifm_a(
        params.ifm_a_chs, params.ifm_a_n * params.ifm_a_rows, params.ifm_a_cols,
        aie_matA
    );

    ActTensor<Ti> aie_ifm_b(
        params.ifm_b_chs, params.ifm_b_n * params.ifm_b_rows, params.ifm_b_cols,
        aie_matB
    );

    // CPU reference output; only allocated when the reference path runs.
    ActTensor<To> cpu_ofm(
        params.ofm_chs, params.ofm_n * params.ofm_rows, params.ofm_cols,
        params.debug_mode ? malloc(ofm_size) : nullptr
    );

#if !ASM_MODE
    ActTensor<To> aie_ofm(
        params.ofm_chs, params.ofm_n * params.ofm_rows, params.ofm_cols,
        adf::GMIO::malloc(ofm_size)
    );
#else
    ActTensor<To> aie_ofm(
        params.ofm_chs, params.ofm_n * params.ofm_rows, params.ofm_cols,
        malloc(ofm_size)
    );
#endif // !ASM_MODE

    if (params.read_ifm) {
        // Any non-positive result from read_bin_file (0, or a negative error
        // code if it returns one) is treated as a failed read.
        const int bytes_a = read_bin_file(ifm_bin_path1, reinterpret_cast<char*>(aie_matA), ifm_a_size);
        if (bytes_a <= 0) {
            throw std::runtime_error("Error reading IFM bin file 1");
        }
        const int bytes_b = read_bin_file(ifm_bin_path2, reinterpret_cast<char*>(aie_matB), ifm_b_size);
        if (bytes_b <= 0) {
            throw std::runtime_error("Error reading IFM bin file 2");
        }
    }
    else {
        int min_val = (params.sign_A == 1 && params.sign_O == 1) ? -4 : 4;
        init_tensor_random<Ti>(aie_ifm_a, params.ifm_a_chs_orig, min_val, 8);
        init_tensor_random<Ti>(aie_ifm_b, params.ifm_b_chs_orig, min_val, 8);
    }

    // The QDQ parameter block follows the on-wgt operand (if any) in WGT.
    BinaryQDQParams* qdq_prm = reinterpret_cast<BinaryQDQParams*>(
        reinterpret_cast<int8_t*>(aie_wgt) +
        (params.b_on_wgt * ifm_b_size) +
        (params.a_on_wgt * ifm_a_size)
    );

    if (!params.read_model_data)
    {
        qdq_prm->dq_a_zp = params.dq_a_zp;
        qdq_prm->dq_a_sc = params.dq_a_sc;
        qdq_prm->dq_b_zp = params.dq_b_zp;
        qdq_prm->dq_b_sc = params.dq_b_sc;
        qdq_prm->q_zp = params.q_zp;
        qdq_prm->q_sc = 1.0f / params.q_sc;  // stored as the reciprocal scale
        qdq_prm->q_enable = params.q_enable;
        qdq_prm->dq_enable = params.dq_enable;
    }
    else
    {
        // Pass A's actual location. With a_on_wgt, A lives in the WGT buffer,
        // while aie_ifm holds B (aie_ifm == aie_matB in that case).
        broadcast_op_init_model_data<Ti>(params,
                                        qdq_prm,
                                        aie_matA,
                                        aie_matB,
                                        params.a_on_wgt,
                                        params.b_on_wgt);
    }
    initialize_qdq_buffer((float*)qdq_prm, params.dq_buf_offset, params.q_buf_offset); // broadcast zero points and scales across buffer

    if (params.debug_mode) {
        log_tensor(aie_ifm_a, "IFM A=\n");
        log_tensor(aie_ifm_b, "IFM B=\n");
        print_qdq_buffer((float*)qdq_prm, params.dq_buf_offset, params.q_buf_offset);

        // Snapshot the flags so that every allocation below is released with
        // the matching deallocator.
        const bool dq_on = qdq_prm->dq_enable;
        const bool q_on = qdq_prm->q_enable;

        // malloc'd float scratch buffers. They stay nullptr when unused, and
        // free(nullptr) is a no-op, so they are released unconditionally.
        QDQFloatType* T0 = nullptr;
        QDQFloatType* T1 = nullptr;
        QDQFloatType* T2 = nullptr;

        ActTensor<QDQFloatType> *aie_ifm_t0, *aie_ifm_t1, *cpu_ofm_t2;
        if (dq_on){
            T0 = static_cast<QDQFloatType*>(malloc(ifm_a_n_elements * sizeof(QDQFloatType)));
            T1 = static_cast<QDQFloatType*>(malloc(ifm_b_n_elements * sizeof(QDQFloatType)));
            dequant<Ti*, QDQFloatType*, float>((Ti*)aie_ifm_a.data, T0,  params.ifm_a_n * params.ifm_a_rows, params.ifm_a_cols, params.ifm_a_chs, qdq_prm->dq_a_sc, qdq_prm->dq_a_zp);
            dequant<Ti*, QDQFloatType*, float>((Ti*)aie_ifm_b.data, T1,  params.ifm_b_n * params.ifm_b_rows, params.ifm_b_cols, params.ifm_b_chs, qdq_prm->dq_b_sc, qdq_prm->dq_b_zp);
            aie_ifm_t0 = new ActTensor<QDQFloatType>( params.ifm_a_chs, params.ifm_a_n * params.ifm_a_rows, params.ifm_a_cols, T0);
            aie_ifm_t1 = new ActTensor<QDQFloatType>( params.ifm_b_chs, params.ifm_b_n * params.ifm_b_rows, params.ifm_b_cols, T1);
        } else {
            // NOTE(review): without dequant, the Ti tensors are viewed in place
            // as QDQFloatType tensors. This is only meaningful when Ti shares
            // QDQFloatType's representation. Confirm.
            aie_ifm_t0 = reinterpret_cast<ActTensor<QDQFloatType>*>(&aie_ifm_a);
            aie_ifm_t1 = reinterpret_cast<ActTensor<QDQFloatType>*>(&aie_ifm_b);
        }
        if (q_on){
            T2 = static_cast<QDQFloatType*>(malloc(ofm_n_elements * sizeof(QDQFloatType)));
            cpu_ofm_t2 = new ActTensor<QDQFloatType> (params.ofm_chs, params.ofm_n * params.ofm_rows, params.ofm_cols, T2 );
        } else {
            // Without quant, the reference result is written directly into cpu_ofm's buffer.
            cpu_ofm_t2 = new ActTensor<QDQFloatType> (params.ofm_chs, params.ofm_n * params.ofm_rows, params.ofm_cols, cpu_ofm.data );
        }
        broadcast_16(*aie_ifm_t0, *aie_ifm_t1, *cpu_ofm_t2,
                  params.ifm_a_n, params.ifm_a_rows, params.ifm_b_n, params.ifm_b_rows, params.ofm_n, params.ofm_rows,
            params.op_type, params.has_scalar_broadcast);
        if (q_on){
            quant_bfloat16_to_uint16<QDQFloatType*, float, To>(cpu_ofm_t2->data, cpu_ofm.data, params.ofm_n * params.ofm_rows, params.ofm_cols, params.ofm_chs, qdq_prm->q_sc, qdq_prm->q_zp);
        }

        // The wrappers came from new, so use delete. The scratch data came
        // from malloc, so use free. The non-dequant wrappers point at stack
        // tensors and must not be deleted.
        if (dq_on){
            delete aie_ifm_t0;
            delete aie_ifm_t1;
        }
        delete cpu_ofm_t2;
        free(T0);
        free(T1);
        free(T2);
    }

#if ASM_MODE

    write_bin_file("ifm.bin", reinterpret_cast<char*>(aie_ifm), total_ifm_size);
    write_bin_file("wgt.bin", reinterpret_cast<char*>(aie_wgt), total_wgt_size);
    if (params.debug_mode) {
        log_tensor(cpu_ofm, "CPU OFM =\n");
        write_bin_file("ofm.bin", reinterpret_cast<char*>(cpu_ofm.data), ofm_size);
    }
#endif // ASM_MODE

#ifdef __AIESIM__
#if !ASM_MODE

#if USE_CERT_LIBRARY
    run_cert_sim(g_compute_graph,
                 reinterpret_cast<void*>(aie_ofm.data), ofm_size,
                 reinterpret_cast<void*>(aie_ifm), total_ifm_size,
                 reinterpret_cast<void*>(aie_wgt), total_wgt_size);
#else
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_ofm.data, aie_ifm, aie_wgt);
    g_compute_graph.end();
#endif // USE_CERT_LIBRARY

    if (params.debug_mode) {
        log_tensor(aie_ofm, "AIE OFM =\n");
        int epsilon = 1;
        int err = cmp_tensor(cpu_ofm, aie_ofm, 1, epsilon);
        if (err == 0)
            printf("DI_PASS\n");
        else
            printf("DI_FAIL\n");
    }
#endif // !ASM_MODE
#endif // __AIESIM__

#if !ASM_MODE
    adf::GMIO::free(aie_ifm);
    adf::GMIO::free(aie_wgt);
    adf::GMIO::free(aie_ofm.data);
#else
    free(aie_ifm);
    free(aie_wgt);
    free(aie_ofm.data);  // allocated with malloc above in ASM_MODE
#endif // !ASM_MODE
    if (cpu_ofm.data) free(cpu_ofm.data);
    return 0;
}

/// Second dispatch stage: maps the output dtype name onto the `To` template
/// argument of run_broadcast_16<Ti, To> and runs it.
/// @throws std::runtime_error when `dtype_ofm` is not one of the names below.
template<typename Ti>
void dispatch_output_type(const BroadcastParams16& params, const std::string& dtype_ofm) {
    if (dtype_ofm == "uint16")   { run_broadcast_16<Ti, uint16_t>(params);   return; }
    if (dtype_ofm == "int16")    { run_broadcast_16<Ti, int16_t>(params);    return; }
    if (dtype_ofm == "bfloat16") { run_broadcast_16<Ti, bfloat16_t>(params); return; }
    if (dtype_ofm == "float16")  { run_broadcast_16<Ti, float16_t>(params);  return; }
    if (dtype_ofm == "float32")  { run_broadcast_16<Ti, float>(params);      return; }
    if (dtype_ofm == "uint8")    { run_broadcast_16<Ti, uint8_t>(params);    return; }
    if (dtype_ofm == "int8")     { run_broadcast_16<Ti, int8_t>(params);     return; }

    throw std::runtime_error("Unsupported output dtype: " + dtype_ofm);
}

/// First dispatch stage: maps the activation dtype name onto the `Ti`
/// template argument, then hands the output dtype to dispatch_output_type<Ti>.
/// @throws std::runtime_error when `dtype_act` is not one of the names below
///         (or whatever dispatch_output_type / run_broadcast_16 throw).
void dispatch_run_broadcast_16(const BroadcastParams16& params,
                               const std::string& dtype_act,
                               const std::string& dtype_ofm) {
    if (dtype_act == "uint16")   { dispatch_output_type<uint16_t>(params, dtype_ofm);   return; }
    if (dtype_act == "int16")    { dispatch_output_type<int16_t>(params, dtype_ofm);    return; }
    if (dtype_act == "bfloat16") { dispatch_output_type<bfloat16_t>(params, dtype_ofm); return; }
    if (dtype_act == "float16")  { dispatch_output_type<float16_t>(params, dtype_ofm);  return; }
    if (dtype_act == "float32")  { dispatch_output_type<float>(params, dtype_ofm);      return; }
    if (dtype_act == "uint8")    { dispatch_output_type<uint8_t>(params, dtype_ofm);    return; }
    if (dtype_act == "int8")     { dispatch_output_type<int8_t>(params, dtype_ofm);     return; }

    throw std::runtime_error("Unsupported input dtype: " + dtype_act);
}

/// Reads the broadcast test configuration from `cfg` and assembles the
/// runner parameters.
///
/// Returns (params, activation dtype name, output dtype name). As written,
/// the variant always holds a BroadcastParams16. One randomly drawn
/// zero-point/scale pair (generateRandomFloat) is reused for dequant of A,
/// dequant of B and the output quant.
///
/// DEBUG_MODE is optional: any exception while reading it falls back to 1.
/// @throws std::runtime_error when SIGN_A / SIGN_O is 0 but the matching
///         dtype name does not start with "uint".
template<typename T>
std::tuple<std::variant<BroadcastParams8, BroadcastParams16>, std::string, std::string> create_broadcast_params(const T& cfg) {
    // Operand A shape (plus the original channel counts of A and B).
    const int ifm_a_n        = extract_json(cfg, "N_IN_A");
    const int ifm_a_rows     = extract_json(cfg, "Y_IN_A");
    const int ifm_a_cols     = extract_json(cfg, "X_IN_A");
    const int ifm_a_chs      = extract_json(cfg, "C_IN_A");
    const int ifm_a_chs_orig = extract_json(cfg, "C_IN_A_ORIG");
    const int ifm_b_chs_orig = extract_json(cfg, "C_IN_B_ORIG");

    // Operand B shape.
    const int ifm_b_n    = extract_json(cfg, "N_IN_B");
    const int ifm_b_rows = extract_json(cfg, "Y_IN_B");
    const int ifm_b_cols = extract_json(cfg, "X_IN_B");
    const int ifm_b_chs  = extract_json(cfg, "C_IN_B");
    const int ifm_bytes  = extract_json(cfg, "IFM_BYTES");

    // Output shape.
    const int ofm_n        = extract_json(cfg, "N_OUT");
    const int ofm_rows     = extract_json(cfg, "Y_OUT");
    const int ofm_cols     = extract_json(cfg, "X_OUT");
    const int ofm_chs      = extract_json(cfg, "C_OUT");
    const int ofm_chs_orig = extract_json(cfg, "C_OUT_ORIG");
    const int ofm_bytes    = extract_json(cfg, "OFM_BYTES");
    const int wgt_size     = extract_json(cfg, "WGT_SIZE");

    // The config stores DISABLE_* flags; the runner wants enable flags.
    const bool q_enable  = extract_json(cfg, "DISABLE_Q")   ? false : true;
    const bool dq_enable = extract_json(cfg, "DISABLE_DQ0") ? false : true;
    const int dq_buf_offset = extract_json(cfg, "DQ_BUF_OFFSET");
    const int q_buf_offset  = extract_json(cfg, "Q_BUF_OFFSET");

    // Run control.
    const int op_type  = extract_json(cfg, "OP_TYPE");
    const int read_ifm = extract_json(cfg, "READ_IFM");
    const int read_md  = extract_json(cfg, "READ_MD");
    const int sign_A   = extract_json(cfg, "SIGN_A");
    const int sign_O   = extract_json(cfg, "SIGN_O");
    const std::string node_name = extract_json_str(cfg, "NODE_NAME");
    const std::string md_path   = extract_json_str(cfg, "MD_PATH");
    const std::string dtype_act = extract_json_str(cfg, "DTYPE_ACT");
    const std::string dtype_ofm = extract_json_str(cfg, "DTYPE_OFM");
    const int a_on_wgt = extract_json(cfg, "A_ON_WGT");
    const int b_on_wgt = extract_json(cfg, "B_ON_WGT");
    const int has_scalar_broadcast = extract_json(cfg, "HAS_SCALAR_BROADCAST");

    const int debug_mode = [&]() -> int {
        try {
            return extract_json(cfg, "DEBUG_MODE");
        } catch (...) {
            return 1;  // key absent/unreadable: keep the CPU reference enabled
        }
    }();

    // An unsigned sign flag is only accepted together with a "uint*" dtype name.
    if (sign_A == 0 && dtype_act.rfind("uint", 0) != 0) {
        throw std::runtime_error("sign_A=0 requires dtype_act to be an integer type starting with 'uint', but got '" + dtype_act + "'");
    }
    if (sign_O == 0 && dtype_ofm.rfind("uint", 0) != 0) {
        throw std::runtime_error("sign_O=0 requires dtype_ofm to be an integer type starting with 'uint', but got '" + dtype_ofm + "'");
    }

    // Draw the shared zero point first, then the shared scale.
    const float q_zp = generateRandomFloat(0, 2);
    const float q_sc = generateRandomFloat(0.4, 0.5);

    BroadcastParams16 params{ifm_a_n, ifm_a_chs, ifm_a_chs_orig, ifm_a_cols, ifm_a_rows,
                             ifm_b_n, ifm_b_chs, ifm_b_chs_orig, ifm_b_cols, ifm_b_rows,
                             ofm_n, ofm_chs, ofm_chs_orig, ofm_cols, ofm_rows,
                             ifm_bytes, ofm_bytes, wgt_size, dq_buf_offset, q_buf_offset,
                             /* dq A */ q_zp, q_sc, /* dq B */ q_zp, q_sc, /* q */ q_zp, q_sc,
                             dq_enable, q_enable, read_ifm, op_type, read_md, md_path, node_name,
                             sign_A, sign_O, a_on_wgt, b_on_wgt, has_scalar_broadcast, debug_mode};
    return {params, dtype_act, dtype_ofm};
}

/// Test-bench entry point: loads broadcast_cfg.json, builds the parameters
/// and runs the matching (8-bit or 16-bit) test bench.
/// Any std::exception raised while loading the config or running the test
/// (e.g. unsupported dtype, unreadable IFM file) is reported on stderr and
/// turned into a non-zero exit code instead of terminating the process.
int main(void)
{
    try {
        auto cfg = load_json("broadcast_cfg.json");
        auto [params, dtype_act, dtype_ofm] = create_broadcast_params(cfg);

        if (const auto* params8 = std::get_if<BroadcastParams8>(&params)) {
            run_broadcast_8(*params8);
        } else {
            dispatch_run_broadcast_16(std::get<BroadcastParams16>(params), dtype_act, dtype_ofm);
        }
    } catch (const std::exception& e) {
        fprintf(stderr, "ERROR: %s\n", e.what());
        return 1;
    }

    return 0;
}
