#include <string>
#include <sstream>
#include <fstream>
#include <assert.h>
#include <stdlib.h>
#include <stdio.h>
#if !ASM_MODE
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif // !ASM_MODE
#ifdef __AIESIM__
#if !ASM_MODE
#include "dma.hpp"
#endif // !ASM_MODE
#endif // __AIESIM__
#include "conv.hpp"
#include "dwc.hpp"

using namespace std;

#if !ASM_MODE
// Graph instance shared by the run_* drivers in this file: in __AIESIM__
// builds it is either handed to run_cert_sim() or bracketed by init()/end()
// around run_dma_layer_config(). It is not defined when ASM_MODE is set.
ComputeGraph g_compute_graph;
#endif // !ASM_MODE

template<typename Ta, typename Tw, typename To, typename Tb>
int run_conv_noqdq(
    int const Ci, int const Yi, int const Xi,
    int const Co, int const Yo, int const Xo,
    int const Ci_orig, int const Yi_orig, int const Xi_orig,
    int const Co_orig, int const Yo_orig, int const Xo_orig,
    int const Ky, int const Kx,
    int const Sy, int const Sx,
    int const Py, int const Px,
    int const Cis, int const Yis, int const Xis,
    int const Cos, int const Yos, int const Xos,
    int const act_mode, int const out_shift, int const bias_shift,
    int const ifm_sign, int const wgt_sign, int const ofm_sign,
    int const Co_split,
    int const read_ifm, int const read_wgt,
    int const debug_mode,
    int const rd_md,
    std::string const md_path,
    std::string const node_name
){
    ConvWgtTensor_noqdq_RT_Params qdq_params;
    qdq_params.lrelu_alpha = 0;
    if(ofm_sign == 1){
        qdq_params.max_value = int8_t(127);
    } else {
        qdq_params.max_value = int8_t(255);
    }
    qdq_params.shift_bias = bias_shift;
    qdq_params.shift_lrelu_in = 0;
    qdq_params.shift_out = out_shift;
    qdq_params.ctrl.sign_A = ifm_sign;
    qdq_params.ctrl.sign_W = wgt_sign;
    qdq_params.ctrl.sign_O = ofm_sign;
    int Co_padded = iceil(Co, (Cos * Co_split));
    int Ci_padded = iceil(Ci, Cis);
    int Kx_padded = (Ci < 64) ? ceildiv(64, Cis) : Kx;

    if (debug_mode) {
        printf("Padded COUT = %d \n", Co_padded);
        printf("Padded CIN = %d \n", Ci_padded);
        printf("Padded WGT dimensions: CIN x Ky x Kx x COUT: %d x %d x %d x %d \n", Ci_padded, Ky, Kx_padded, Co_padded);
    }
    int ifm_size = ActTensor<Ta>::size(Ci, Yi, Xi);
    int wgt_size = ConvWgtTensor_noqdq<Tw, Tb>::size(Co_padded, Ci_padded, Ky, Kx_padded, Cis, Cos);
    int ofm_size = ActTensor<To>::size(Co, Yo, Xo);

    if (debug_mode) {
        printf("DDR IFM SIZE = %d \n", ifm_size);
        printf("DDR WGT SIZE = %d \n", wgt_size);
        printf("DDR OFM SIZE = %d \n", ofm_size);
    }

#if !ASM_MODE
    ActTensor<Ta> aie_ifm(
        Ci, Yi, Xi,
        adf::GMIO::malloc(ifm_size)
    );
    ConvWgtTensor_noqdq<Tw, Tb> aie_wgt(
        Co_padded, Ci_padded, Ky, Kx, Cis, Cos,
        adf::GMIO::malloc(wgt_size)
    );
    ActTensor<To> aie_ofm(
        Co, Yo, Xo,
        adf::GMIO::malloc(ofm_size)
    );
#else
    ActTensor<Ta> aie_ifm(
        Ci, Yi, Xi,
        malloc(ifm_size)
    );
    ConvWgtTensor_noqdq<Tw, Tb> aie_wgt(
        Co_padded, Ci_padded, Ky, Kx, Cis, Cos,
        malloc(wgt_size)
    );
#endif // !ASM_MODE
    ActTensor<To> cpu_ofm(
        Co, Yo, Xo,
        malloc(ofm_size)
    );
    // Print the subv sizes
    if (debug_mode) {
        printf("subv_wgt_size = %d \n", aie_wgt.subv_wgt_size);
        printf("subv_bias_size = %d \n", aie_wgt.subv_bias_size);
        printf("subv_qdq_size = %d \n", aie_wgt.subv_qdq_size);
        printf("total_subv_size = %d \n", aie_wgt.subv_size);
    }

    if (!rd_md)
    {
        init_random_conv_noqdq_a8w8(aie_ifm, aie_wgt, qdq_params, Co, Ci);

        std::string const ifm_bin_path = "../intermediate_bins/ifm1.bin";
        std::string const wgt_bin_path = "../intermediate_bins/wgt.bin";
        if (read_ifm) {
            read_bin_file(ifm_bin_path, reinterpret_cast<char*>(aie_ifm.data), ifm_size);
        }
        if (read_wgt) {
            read_bin_file(wgt_bin_path, reinterpret_cast<char*>(aie_wgt.data), wgt_size);
        }

        cpu_iconv_2d(aie_ifm, aie_wgt, cpu_ofm, Sy, Sx, Py, Px);
    }
    else
    {
        init_conv_noqdq_model_data(md_path, node_name, aie_wgt, qdq_params, Co_orig, Ci_orig, 0, debug_mode);
    }

#if ASM_MODE
    write_bin_file("ifm.bin", (char*)aie_ifm.data, ifm_size);
    write_bin_file("wgt.bin", (char*)aie_wgt.data, wgt_size);
    write_bin_file("ofm.bin", (char*)cpu_ofm.data, ofm_size);
    write_external_buffer_json(ofm_size, ifm_size, wgt_size);
#endif // ASM_MODE
#ifdef __AIESIM__
#if !ASM_MODE
    log_tensor(aie_ifm, "ifm", Yi, Xi, Cis, 64);
    log_tensor(aie_wgt, "wgt");
    log_tensor(cpu_ofm, "cpu_ofm");
#if USE_CERT_LIBRARY
    run_cert_sim(g_compute_graph,
                 reinterpret_cast<void*>(aie_ofm.data), ofm_size,
                 reinterpret_cast<void*>(aie_ifm.data), ifm_size,
                 reinterpret_cast<void*>(aie_wgt.data), wgt_size);
#else
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_ofm.data, aie_ifm.data, aie_wgt.data);
    g_compute_graph.end();
#endif //USE_CERT_LIBRARY

    int epsilon = 1;
    int err = cmp_tensor(cpu_ofm, aie_ofm, qdq_params.ctrl.sign_O, epsilon);
    if (err == 0)
        printf("CONV_NO_QDQ_A8W8 DI_PASS: Yi=%d, Xi=%d, Ci=%d, Yo=%d, Xo=%d, Co=%d, Ky=%d, Kx=%d, Sy=%d, Sx=%d, Py=%d, Px=%d \n", Yi, Xi, Ci, Yo, Xo, Co, Ky, Kx, Sy, Sx, Py, Px);
    else
        printf("CONV_NO_QDQ_A8W8 DI_FAIL: Yi=%d, Xi=%d, Ci=%d, Yo=%d, Xo=%d, Co=%d, Ky=%d, Kx=%d, Sy=%d, Sx=%d, Py=%d, Px=%d \n", Yi, Xi, Ci, Yo, Xo, Co, Ky, Kx, Sy, Sx, Py, Px);
#endif // ASM_MODE
#endif // __AIESIM__

#if !ASM_MODE
    adf::GMIO::free(aie_ifm.data);
    adf::GMIO::free(aie_wgt.data);
    adf::GMIO::free(aie_ofm.data);
#else
    free(aie_ifm.data);
    free(aie_wgt.data);
#endif // !ASM_MODE
    free(cpu_ofm.data);
    return 0;
}

template<typename Ta, typename Tw, typename To, typename Tb>
int run_conv_qdq(
    int const Ci, int const Yi, int const Xi,
    int const Co, int const Yo, int const Xo,
    int const Ci_orig, int const Yi_orig, int const Xi_orig,
    int const Co_orig, int const Yo_orig, int const Xo_orig,
    int const Ky, int const Kx,
    int const Sy, int const Sx,
    int const Py, int const Px,
    int const Cis, int const Yis, int const Xis,
    int const Cos, int const Yos, int const Xos,
    int const act_mode, int const out_shift, int const bias_shift,
    int const ifm_sign, int const wgt_sign, int const ofm_sign,
    int const Co_split, int const vec_coeff,
    int const read_ifm, int const read_wgt,
    int const debug_mode,
    int const rd_md,
    std::string const md_path,
    std::string node_name
){
    using Tacc = int64_t;
    using Tc0 = float;
    using Tc1 = float;
    using Tc2 = float;
    ConvWgtTensor_qdq_RT_Params qdq_params;
    qdq_params.shift_out = out_shift;
    qdq_params.ifm_sign = ifm_sign;
    qdq_params.wgt_sign = wgt_sign;
    qdq_params.ofm_sign = ofm_sign;
    int Co_padded = iceil(Co, (Cos * Co_split));
    int Ci_padded = iceil(Ci, Cis);
    int Kx_padded = (Ci < 64) ? ceildiv(64, Cis) : Kx;

    if (debug_mode) {
        printf("Padded COUT = %d \n", Co_padded);
        printf("Padded CIN = %d \n", Ci_padded);
        printf("Padded WGT dimensions: CIN x Ky x Kx x COUT: %d x %d x %d x %d \n", Ci_padded, Ky, Kx_padded, Co_padded);
    }
    int ifm_size = ActTensor<Ta>::size(Ci, Yi, Xi);
    int wgt_size = ConvWgtTensor_qdq<Tw, Tc0, Tc1, Tc2>::size(Co_padded, Ci_padded, Ky, Kx_padded, Cis, Cos);
    int acc_size = ActTensor<Tacc>::size(Co, Yo, Xo);
    int ofm_size = ActTensor<To>::size(Co, Yo, Xo);

    if (debug_mode) {
        printf("DDR IFM SIZE = %d \n", ifm_size);
        printf("DDR WGT SIZE = %d \n", wgt_size);
        printf("DDR OFM SIZE = %d \n", ofm_size);
    }

#if !ASM_MODE
    ActTensor<Ta> aie_ifm(
        Ci, Yi, Xi,
        adf::GMIO::malloc(ifm_size)
    );
    ConvWgtTensor_qdq<Tw, Tc0, Tc1, Tc2> aie_wgt(
        Co_padded, Ci_padded, Ky, Kx, Cis, Cos, vec_coeff,
        adf::GMIO::malloc(wgt_size)
    );
    ActTensor<To> aie_ofm(
        Co, Yo, Xo,
        adf::GMIO::malloc(ofm_size)
    );
#else
    ActTensor<Ta> aie_ifm(
        Ci, Yi, Xi,
        malloc(ifm_size)
    );
    ConvWgtTensor_qdq<Tw, Tc0, Tc1, Tc2> aie_wgt(
        Co_padded, Ci_padded, Ky, Kx, Cis, Cos, vec_coeff,
        malloc(wgt_size)
    );
#endif // !ASM_MODE
    ActTensor<Tacc> cpu_conv_out(
        Co, Yo, Xo,
        malloc(acc_size)
    );

    ActTensor<To> cpu_ofm(
        Co, Yo, Xo,
        malloc(ofm_size)
    );
    // Print the subv sizes
    if (debug_mode) {
        printf("subv_wgt_size = %d \n", aie_wgt.subv_wgt_size);
        printf("subv_c0_size = %d \n", aie_wgt.subv_c0_size);
        printf("subv_c1_size = %d \n", aie_wgt.subv_c1_size);
        printf("subv_c2_size = %d \n", aie_wgt.subv_c2_size);
        printf("subv_qdq_size = %d \n", aie_wgt.subv_qdq_size);
        printf("total_subv_size = %d \n", aie_wgt.subv_size);
    }

    if (!rd_md)
    {
        init_random_conv_qdq_a16w8(aie_ifm, aie_wgt, qdq_params, Co, Ci);

        std::string const ifm_bin_path = "../intermediate_bins/ifm1.bin";
        std::string const wgt_bin_path = "../intermediate_bins/wgt.bin";
        if (read_ifm) {
            read_bin_file(ifm_bin_path, reinterpret_cast<char*>(aie_ifm.data), ifm_size);
        }
        if (read_wgt) {
            read_bin_file(wgt_bin_path, reinterpret_cast<char*>(aie_wgt.data), wgt_size);
        }

        cpu_iconv_2d(aie_ifm, aie_wgt, cpu_conv_out, Sy, Sx, Py, Px);
        cpu_3term_qdq(aie_ifm, cpu_conv_out, aie_wgt, cpu_ofm, Sy, Sx, Py, Px, debug_mode);
    }
    else
    {
        init_conv_qdq_model_data<Tw,Tc0,Tc1,Tc2,Tb>(md_path, node_name, aie_wgt, qdq_params, Co_orig, Ci_orig, vec_coeff, debug_mode);
    }

#if ASM_MODE
    write_bin_file("ifm.bin", (char*)aie_ifm.data, ifm_size);
    write_bin_file("wgt.bin", (char*)aie_wgt.data, wgt_size);
    write_bin_file("ofm.bin", (char*)cpu_ofm.data, ofm_size);
    write_external_buffer_json(ofm_size, ifm_size, wgt_size);
#endif // ASM_MODE
#ifdef __AIESIM__
#if !ASM_MODE
    log_tensor(aie_ifm, "ifm", Yi, Xi, Cis, 64);
    log_tensor(aie_wgt, "wgt");
    log_tensor(cpu_conv_out, "cpu_conv_out");
    log_tensor(cpu_ofm, "cpu_ofm");
#if USE_CERT_LIBRARY
    run_cert_sim(g_compute_graph,
                 reinterpret_cast<void*>(aie_ofm.data), ofm_size,
                 reinterpret_cast<void*>(aie_ifm.data), ifm_size,
                 reinterpret_cast<void*>(aie_wgt.data), wgt_size);
#else
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_ofm.data, aie_ifm.data, aie_wgt.data);
    g_compute_graph.end();
#endif //USE_CERT_LIBRARY

    int epsilon = 1;
    int err = cmp_tensor(cpu_ofm, aie_ofm, qdq_params.ofm_sign, epsilon);
    if (err == 0)
        printf("CONV_QDQ_A16W8 DI_PASS: Yi=%d, Xi=%d, Ci=%d, Yo=%d, Xo=%d, Co=%d, Ky=%d, Kx=%d, Sy=%d, Sx=%d, Py=%d, Px=%d Vector_coeff=%d \n", Yi, Xi, Ci, Yo, Xo, Co, Ky, Kx, Sy, Sx, Py, Px, vec_coeff);
    else
        printf("CONV_QDQ_A16W8 DI_FAIL: Yi=%d, Xi=%d, Ci=%d, Yo=%d, Xo=%d, Co=%d, Ky=%d, Kx=%d, Sy=%d, Sx=%d, Py=%d, Px=%d Vector_coeff=%d \n", Yi, Xi, Ci, Yo, Xo, Co, Ky, Kx, Sy, Sx, Py, Px, vec_coeff);
#endif // ASM_MODE
#endif // __AIESIM__

#if !ASM_MODE
    adf::GMIO::free(aie_ifm.data);
    adf::GMIO::free(aie_wgt.data);
    adf::GMIO::free(aie_ofm.data);
#else
    free(aie_ifm.data);
    free(aie_wgt.data);
#endif // !ASM_MODE
    free(cpu_ofm.data);
    return 0;
}


template<typename Ta, typename Tw, typename To>
int run_dwc_qdq(
    int const Ci, int const Yi, int const Xi,
    int const Co, int const Yo, int const Xo,
    int const Ci_orig, int const Yi_orig, int const Xi_orig,
    int const Co_orig, int const Yo_orig, int const Xo_orig,
    int const Ky, int const Kx,
    int const Sy, int const Sx,
    int const Py, int const Px,
    int const Cis, int const Yis, int const Xis,
    int const Cos, int const Yos, int const Xos,
    int const out_shift,
    int const ifm_sign, int const wgt_sign, int const ofm_sign,
    int const Co_split, int const vec_coeff,
    int const read_ifm, int const read_wgt,
    int const debug_mode,
    int const rd_md,
    std::string const md_path,
    std::string node_name
) {
    assert((Ci == Co) && "In DWC, CI must be equal to CO.");
    assert((Cis == Cos) && "In DWC, CIS must be equal to COS.");
    using Tacc = int32_t;
    using Tc0 = float;
    using Tc2 = float;
    int Co_padded = iceil(Co, (Cos * Co_split));
    int Kx_padded = (Kx < 4) ? 4 : Kx;
    int Ky_padded = (Ky < 3) ? 3 : Ky;
    DwcWgtTensor_qdq_RT_Params qdq_params;
    qdq_params.shift_out = out_shift;
    if (debug_mode) {
        printf("Padded COUT = %d \n", Co_padded);
        printf("Padded WGT dimensions: Ky_padded x Kx_padded x COUT: %d x %d x %d \n", Ky_padded, Kx_padded, Co_padded);
    }
    int ifm_size = ActTensor<Ta>::size(Ci, Yi, Xi);
    int wgt_size = DwcWgtTensor_qdq<Tw, Tc0, Tc2>::size(Co_padded, Ky_padded, Kx_padded, Cos);
    int acc_size = ActTensor<Tacc>::size(Co, Yo, Xo);
    int ofm_size = ActTensor<To>::size(Co, Yo, Xo);

    if (debug_mode) {
        printf("DDR IFM SIZE = %d \n", ifm_size);
        printf("DDR WGT SIZE = %d \n", wgt_size);
        printf("DDR OFM SIZE = %d \n", ofm_size);
    }

#if !ASM_MODE
    ActTensor<Ta> aie_ifm(
        Ci, Yi, Xi,
        adf::GMIO::malloc(ifm_size)
    );
    DwcWgtTensor_qdq<Tw, Tc0, Tc2> aie_wgt(
        Co_padded, Ky_padded, Kx_padded, Cos, vec_coeff,
        adf::GMIO::malloc(wgt_size)
    );
    ActTensor<To> aie_ofm(
        Co, Yo, Xo,
        adf::GMIO::malloc(ofm_size)
    );
#else
    ActTensor<Ta> aie_ifm(
        Ci, Yi, Xi,
        malloc(ifm_size)
    );
    DwcWgtTensor_qdq<Tw, Tc0, Tc2> aie_wgt(
        Co_padded, Ci, Ky_padded, Kx_padded, Cos,
        malloc(wgt_size)
    ); 
#endif // !ASM_MODE
    ActTensor<Tacc> cpu_dwc_out(
        Co, Yo, Xo,
        malloc(acc_size)
    );
    ActTensor<To> cpu_ofm(
        Co, Yo, Xo,
        malloc(ofm_size)
    );
    if (debug_mode) {
        printf("subv_wgt_size = %d \n", aie_wgt.subv_wgt_size);
        printf("subv_c0_size = %d \n", aie_wgt.subv_c0_size);
        printf("subv_c2_size = %d \n", aie_wgt.subv_c2_size);
        printf("subv_qdq_size = %d \n", aie_wgt.subv_qdq_size);
        printf("total_subv_size = %d \n", aie_wgt.subv_size);
    }
    init_random_dwc_qdq_a16w8(aie_ifm, aie_wgt, qdq_params, Co, Ky, Kx); 
    cpu_dwc(aie_ifm, aie_wgt, cpu_dwc_out, Sy, Sx, Py, Px);
    cpu_3term_qdq(aie_ifm, cpu_dwc_out, aie_wgt, cpu_ofm, Ky, Kx, Sy, Sx, Py, Px, debug_mode, ofm_sign);

    #if ASM_MODE
    write_bin_file("ifm.bin", (char*)aie_ifm.data, ifm_size);
    write_bin_file("wgt.bin", (char*)aie_wgt.data, wgt_size);
    write_bin_file("ofm.bin", (char*)cpu_ofm.data, ofm_size);
    write_external_buffer_json(ofm_size, ifm_size, wgt_size);
#endif // ASM_MODE
#ifdef __AIESIM__
#if !ASM_MODE
    log_tensor(aie_ifm, "ifm", Yi, Xi, Cis, 64);
    log_tensor(aie_wgt, "wgt");
    log_tensor(cpu_dwc_out, "cpu_dwc_out");
    log_tensor(cpu_ofm, "cpu_ofm");
#if USE_CERT_LIBRARY
    run_cert_sim(g_compute_graph,
                 reinterpret_cast<void*>(aie_ofm.data), ofm_size,
                 reinterpret_cast<void*>(aie_ifm.data), ifm_size,
                 reinterpret_cast<void*>(aie_wgt.data), wgt_size);
#else
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_ofm.data, aie_ifm.data, aie_wgt.data);
    g_compute_graph.end();
#endif //USE_CERT_LIBRARY

    int epsilon = 1;
    int err = cmp_tensor(cpu_ofm, aie_ofm, ofm_sign, epsilon);
    if (err == 0)
        printf("DWC_CONV_QDQ_A16W8 DI_PASS: Yi=%d, Xi=%d, Ci=%d, Yo=%d, Xo=%d, Co=%d, Ky=%d, Kx=%d, Sy=%d, Sx=%d, Py=%d, Px=%d Vector_coeff=%d \n", Yi, Xi, Ci, Yo, Xo, Co, Ky, Kx, Sy, Sx, Py, Px, vec_coeff);
    else
        printf("DWC_CONV_QDQ_A16W8 DI_FAIL: Yi=%d, Xi=%d, Ci=%d, Yo=%d, Xo=%d, Co=%d, Ky=%d, Kx=%d, Sy=%d, Sx=%d, Py=%d, Px=%d Vector_coeff=%d \n", Yi, Xi, Ci, Yo, Xo, Co, Ky, Kx, Sy, Sx, Py, Px, vec_coeff);
#endif // ASM_MODE
#endif // __AIESIM__

#if !ASM_MODE
    adf::GMIO::free(aie_ifm.data);
    adf::GMIO::free(aie_wgt.data);
    adf::GMIO::free(aie_ofm.data);
#else
    free(aie_ifm.data);
    free(aie_wgt.data);
#endif // !ASM_MODE
    free(cpu_ofm.data);
    return 0;
}

/**
 * Entry point: reads the layer description from conv_cfg.json, optionally
 * prints it (DEBUG != 0), and dispatches to the matching test driver.
 *
 * Dispatch rules:
 *   - GROUP == 1, QDQ == 0 : run_conv_noqdq, bias dtype must be "int16".
 *   - GROUP == 1, QDQ != 0 : run_conv_qdq, bias dtype must be "0" or "int32".
 *   - GROUP != 1, QDQ == 1 : run_dwc_qdq, bias dtype must be "0" or "int32".
 *   - anything else        : error message, return -1.
 * Within each rule the DTYPE_ACT / DTYPE_WGT / DTYPE_OFM strings select the
 * template instantiation; an unlisted combination prints an error and
 * returns -1.
 *
 * SHIFT_OUT from the config is overridden with 0 when READ_MD is non-zero.
 *
 * @return the value returned by the selected driver, or -1 when the
 *         configuration does not map to a supported driver. Previously the
 *         driver result was dropped for all but one branch and the
 *         GROUP != 1 / QDQ != 1 case exited with 0 without running anything.
 *
 * NOTE(review): the debug banner reads "GEMM WGT Formatting" although this
 * program drives conv layers; the text is left unchanged in case logs are
 * parsed downstream.
 */
int main(void)
{
    auto cfg = load_json("conv_cfg.json");
    int const Ci = extract_json(cfg,"C_IN");
    int const Yi = extract_json(cfg,"Y_IN");
    int const Xi = extract_json(cfg,"X_IN");
    int const Co = extract_json(cfg,"C_OUT");
    int const Yo = extract_json(cfg,"Y_OUT");
    int const Xo = extract_json(cfg,"X_OUT");

    int const Ci_orig = extract_json(cfg,"C_IN_ORIG");
    int const Yi_orig = extract_json(cfg,"Y_IN_ORIG");
    int const Xi_orig = extract_json(cfg,"X_IN_ORIG");
    int const Co_orig = extract_json(cfg,"C_OUT_ORIG");
    int const Yo_orig = extract_json(cfg,"Y_OUT_ORIG");
    int const Xo_orig = extract_json(cfg,"X_OUT_ORIG");

    int const Ky = extract_json(cfg,"KERNEL_Y");
    int const Kx = extract_json(cfg,"KERNEL_X");

    int const Sy = extract_json(cfg,"STRIDE_Y");
    int const Sx = extract_json(cfg,"STRIDE_X");

    int const Cis = extract_json(cfg,"CIS");
    int const Yis = extract_json(cfg,"YIS");
    int const Xis = extract_json(cfg,"XIS");
    int const Cos = extract_json(cfg,"COS");
    int const Yos = extract_json(cfg,"YOS");
    int const Xos = extract_json(cfg,"XOS");
    int const vec_coeff = extract_json(cfg, "COEFF_VECTOR");
    int const group = extract_json(cfg, "GROUP");

    int const Py = extract_json(cfg,"PAD_Y");
    int const Px = extract_json(cfg,"PAD_X");

    int const Co_split = extract_json(cfg,"C_OUT_SPLIT");

    int const act_mode = extract_json(cfg, "ACT_MODE");
    int out_shift = extract_json(cfg, "SHIFT_OUT");
    int const bias_shift = extract_json(cfg, "BIAS_SHIFT");
    int const sign_act = extract_json(cfg, "SIGN_ACT");
    int const sign_wgt = extract_json(cfg, "SIGN_WGT");
    int const sign_out = extract_json(cfg, "SIGN_OUT");
    int const debug_mode = extract_json(cfg, "DEBUG");
    int const qdq = extract_json(cfg, "QDQ");
    std::string node_name = extract_json_str(cfg, "NODE_NAME");
    std::string const md_path = extract_json_str(cfg, "MD_PATH");
    std::string const dtype_act = extract_json_str(cfg, "DTYPE_ACT");
    std::string const dtype_wgt = extract_json_str(cfg, "DTYPE_WGT");
    std::string const dtype_bias = extract_json_str(cfg, "DTYPE_BIAS");
    std::string const dtype_ofm = extract_json_str(cfg, "DTYPE_OFM");
    int const read_md = extract_json(cfg, "READ_MD");

    int const read_ifm = extract_json(cfg, "READ_IFM");
    int const read_wgt = extract_json(cfg, "READ_WGT");

    // The configured output shift is ignored when model data is read.
    if (read_md) {
        out_shift = 0;
    }

    if (debug_mode) {
        std::cout << "=== GEMM WGT Formatting ===\n"
        << "NODE_NAME   : " << node_name  << '\n'
        << "MD_PATH     : " << md_path    << '\n'
        << "READ_MD     : " << read_md    << '\n'
        << "DTYPE_ACT   : " << dtype_act  << '\n'
        << "DTYPE_WGT   : " << dtype_wgt  << '\n'
        << "DTYPE_BIAS  : " << dtype_bias << '\n'
        << "DTYPE_OFM   : " << dtype_ofm  << '\n'
        << "==========================\n";
        printf("IFM dimension: YIN x XIN x CIN = %d x %d x %d \n", Yi, Xi, Ci);
        printf("WGT dimensions: CIN x Ky x Kx x COUT: %d x %d x %d x %d \n", Ci, Ky, Kx, Co);
        printf("OFM dimension: YOUT x XOUT x COUT = %d x %d x %d \n", Yo, Xo, Co);
        printf("IFM Subvol dimension: YIS x XIS x CIS = %d X %d x %d \n", Yis, Xis, Cis);
        printf("WGT Subvol dimension: COS x Ky x Kx x CIS = %d x %d X %d x %d \n", Cos, Ky, Kx, Cis);
        printf("OFM Subvol dimension: YOS x XOS x COS = %d X %d x %d \n", Yos, Xos, Cos);
        printf("Kernel dimension: Ky x Kx = %d x %d \n", Ky, Kx);
        printf("Stride dimension: Sy x Sx = %d x %d \n", Sy, Sx);
        printf("Padding dimension: Py x Px = %d x %d \n", Py, Px);
        printf("C_OUT_SPLIT = %d \n", Co_split);
        printf("ACT_MODE = %d \n", act_mode);
        printf("SHIFT_OUT = %d, BIAS_SHIFT = %d \n", out_shift, bias_shift);
        printf("SIGN_ACT = %d, SIGN_WGT = %d, SIGN_OUT = %d \n", sign_act, sign_wgt, sign_out);
        printf("QDQ = %d \n", qdq);
        printf("DTYPE_ACT = %s, DTYPE_WGT = %s, DTYPE_BIAS = %s, DTYPE_OFM = %s \n", dtype_act.c_str(), dtype_wgt.c_str(), dtype_bias.c_str(), dtype_ofm.c_str());
        printf("READ_IFM = %d, READ_WGT = %d \n", read_ifm, read_wgt);
    }

    // Helpers for the dtype checks shared by every dispatch branch.
    auto const dtypes_are = [&](char const* act, char const* wgt, char const* ofm) {
        return dtype_act == act && dtype_wgt == wgt && dtype_ofm == ofm;
    };
    bool const bias_is_int16 = (dtype_bias == "int16");
    bool const bias_is_int32_or_none = (dtype_bias == "0" || dtype_bias == "int32");

    // NOTE: GROUP=1 indicates normal conv
    if (group == 1 && qdq == 0) {
        if (bias_is_int16 && dtypes_are("uint8", "uint8", "uint8")) {
            return run_conv_noqdq<uint8_t, uint8_t, uint8_t, int16_t>(
                Ci, Yi, Xi, Co, Yo, Xo,
                Ci_orig, Yi_orig, Xi_orig, Co_orig, Yo_orig, Xo_orig,
                Ky, Kx, Sy, Sx, Py, Px,
                Cis, Yis, Xis, Cos, Yos, Xos,
                act_mode, out_shift, bias_shift,
                sign_act, sign_wgt, sign_out,
                Co_split,
                read_ifm, read_wgt, debug_mode, read_md, md_path, node_name
            );
        } else if (bias_is_int16 && dtypes_are("int8", "int8", "int8")) {
            return run_conv_noqdq<int8_t, int8_t, int8_t, int16_t>(
                Ci, Yi, Xi, Co, Yo, Xo,
                Ci_orig, Yi_orig, Xi_orig, Co_orig, Yo_orig, Xo_orig,
                Ky, Kx, Sy, Sx, Py, Px,
                Cis, Yis, Xis, Cos, Yos, Xos,
                act_mode, out_shift, bias_shift,
                sign_act, sign_wgt, sign_out,
                Co_split,
                read_ifm, read_wgt, debug_mode, read_md, md_path, node_name
            );
        } else if (bias_is_int16 && dtypes_are("uint8", "int8", "uint8")) {
            return run_conv_noqdq<uint8_t, int8_t, uint8_t, int16_t>(
                Ci, Yi, Xi, Co, Yo, Xo,
                Ci_orig, Yi_orig, Xi_orig, Co_orig, Yo_orig, Xo_orig,
                Ky, Kx, Sy, Sx, Py, Px,
                Cis, Yis, Xis, Cos, Yos, Xos,
                act_mode, out_shift, bias_shift,
                sign_act, sign_wgt, sign_out,
                Co_split,
                read_ifm, read_wgt, debug_mode, read_md, md_path, node_name
            );
        } else if (bias_is_int16 && dtypes_are("int8", "uint8", "int8")) {
            return run_conv_noqdq<int8_t, uint8_t, int8_t, int16_t>(
                Ci, Yi, Xi, Co, Yo, Xo,
                Ci_orig, Yi_orig, Xi_orig, Co_orig, Yo_orig, Xo_orig,
                Ky, Kx, Sy, Sx, Py, Px,
                Cis, Yis, Xis, Cos, Yos, Xos,
                act_mode, out_shift, bias_shift,
                sign_act, sign_wgt, sign_out,
                Co_split,
                read_ifm, read_wgt, debug_mode, read_md, md_path, node_name
            );
        }
        printf("ERROR: Data type combination not supported for no QDQ mode \n");
        return -1;
    }

    if (group == 1) {
        // Normal conv with QDQ.
        if (bias_is_int32_or_none && dtypes_are("uint16", "uint8", "uint16")) {
            return run_conv_qdq<uint16_t, uint8_t, uint16_t, int32_t>(
                Ci, Yi, Xi, Co, Yo, Xo,
                Ci_orig, Yi_orig, Xi_orig, Co_orig, Yo_orig, Xo_orig,
                Ky, Kx, Sy, Sx, Py, Px,
                Cis, Yis, Xis, Cos, Yos, Xos,
                act_mode, out_shift, bias_shift,
                sign_act, sign_wgt, sign_out,
                Co_split, vec_coeff,
                read_ifm, read_wgt, debug_mode, read_md, md_path, node_name
            );
        } else if (bias_is_int32_or_none && dtypes_are("int16", "int8", "int16")) {
            return run_conv_qdq<int16_t, int8_t, int16_t, int32_t>(
                Ci, Yi, Xi, Co, Yo, Xo,
                Ci_orig, Yi_orig, Xi_orig, Co_orig, Yo_orig, Xo_orig,
                Ky, Kx, Sy, Sx, Py, Px,
                Cis, Yis, Xis, Cos, Yos, Xos,
                act_mode, out_shift, bias_shift,
                sign_act, sign_wgt, sign_out,
                Co_split, vec_coeff,
                read_ifm, read_wgt, debug_mode, read_md, md_path, node_name
            );
        } else if (bias_is_int32_or_none && dtypes_are("uint16", "int8", "uint16")) {
            return run_conv_qdq<uint16_t, int8_t, uint16_t, int32_t>(
                Ci, Yi, Xi, Co, Yo, Xo,
                Ci_orig, Yi_orig, Xi_orig, Co_orig, Yo_orig, Xo_orig,
                Ky, Kx, Sy, Sx, Py, Px,
                Cis, Yis, Xis, Cos, Yos, Xos,
                act_mode, out_shift, bias_shift,
                sign_act, sign_wgt, sign_out,
                Co_split, vec_coeff,
                read_ifm, read_wgt, debug_mode, read_md, md_path, node_name
            );
        } else if (bias_is_int32_or_none && dtypes_are("int16", "uint8", "int16")) {
            return run_conv_qdq<int16_t, uint8_t, int16_t, int32_t>(
                Ci, Yi, Xi, Co, Yo, Xo,
                Ci_orig, Yi_orig, Xi_orig, Co_orig, Yo_orig, Xo_orig,
                Ky, Kx, Sy, Sx, Py, Px,
                Cis, Yis, Xis, Cos, Yos, Xos,
                act_mode, out_shift, bias_shift,
                sign_act, sign_wgt, sign_out,
                Co_split, vec_coeff,
                read_ifm, read_wgt, debug_mode, read_md, md_path, node_name
            );
        }
        std::cout << "ERROR: Data type combination not supported for QDQ mode"
                    << " DT_ACT: " << dtype_act
                    << " DT_WGT: " << dtype_wgt
                    << " DT_OFM: " << dtype_ofm
                    << " DT_BIAS: " << dtype_bias
                    << std::endl;
        return -1;
    }

    // Any GROUP value other than 1 is treated as depthwise convolution.
    if (qdq == 1) {
        if (bias_is_int32_or_none && dtypes_are("uint16", "uint8", "uint16")) {
            return run_dwc_qdq<uint16_t, uint8_t, uint16_t>(
                Ci, Yi, Xi, Co, Yo, Xo,
                Ci_orig, Yi_orig, Xi_orig, Co_orig, Yo_orig, Xo_orig,
                Ky, Kx, Sy, Sx, Py, Px,
                Cis, Yis, Xis, Cos, Yos, Xos,
                out_shift,
                sign_act, sign_wgt, sign_out,
                Co_split, vec_coeff,
                read_ifm, read_wgt, debug_mode, read_md, md_path, node_name
            );
        } else if (bias_is_int32_or_none && dtypes_are("int16", "int8", "int16")) {
            return run_dwc_qdq<int16_t, int8_t, int16_t>(
                Ci, Yi, Xi, Co, Yo, Xo,
                Ci_orig, Yi_orig, Xi_orig, Co_orig, Yo_orig, Xo_orig,
                Ky, Kx, Sy, Sx, Py, Px,
                Cis, Yis, Xis, Cos, Yos, Xos,
                out_shift,
                sign_act, sign_wgt, sign_out,
                Co_split, vec_coeff,
                read_ifm, read_wgt, debug_mode, read_md, md_path, node_name
            );
        }
        std::cout << "ERROR: Data type combination not supported for DWC QDQ mode"
                    << " DT_ACT: " << dtype_act
                    << " DT_WGT: " << dtype_wgt
                    << " DT_OFM: " << dtype_ofm
                    << std::endl;
        return -1;
    }

    // No driver exists for depthwise convolution outside QDQ == 1; report it
    // instead of exiting successfully without having run anything.
    std::cout << "ERROR: DWC (GROUP != 1) is only supported with QDQ = 1"
                << " GROUP: " << group
                << " QDQ: " << qdq
                << std::endl;
    return -1;
}
