#if !ASM_MODE
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif // !ASM_MODE
#ifdef __AIESIM__
#if !ASM_MODE
#include "dma.hpp"
#endif // !ASM_MODE
#endif // __AIESIM__

using namespace std;

#include <iostream>
#include <vector>
#include <cstring>
#include <cassert>

#include "common.hpp"
#include "resize.hpp"

#if !ASM_MODE
// Graph instance used by run_nni_int16 for the simulator run; it is only
// declared when ASM_MODE is off.
ComputeGraph g_compute_graph;
#endif // !ASM_MODE

/**
 * Host-side test driver for the NNI resize layer (NNI is presumably
 * "nearest-neighbour interpolation" - confirm against resize.hpp).
 *
 * Steps performed:
 *   1. Allocate input (IFM), weight (WGT) and output (OFM) buffers. With
 *      ASM_MODE off they come from adf::GMIO::malloc, otherwise from malloc.
 *   2. Fill the IFM through rand_tensor (range depends on is_int16 /
 *      is_bfloat16) and compute the reference output with cpu_nni.
 *   3. ASM_MODE: dump ifm.bin / wgt.bin / ofm.bin (ofm.bin holds the CPU
 *      reference output).
 *   4. __AIESIM__ with ASM_MODE off: run the graph, compare the AIE output with
 *      the reference and print DI_PASS or DI_FAIL.
 *
 * @tparam T                    element type stored in the tensors
 * @param aie_rows              not used by this function
 * @param aie_cols              not used by this function
 * @param h_in, w_in, c_in      input tensor dimensions
 * @param h_out, w_out, c_out   output tensor dimensions
 * @param num_interpolation_h   forwarded to cpu_nni
 * @param num_interpolation_w   forwarded to cpu_nni
 * @param is_int16              non-zero: fill the IFM with rand_tensor(ifm, -32768, 32767, 1)
 * @param is_bfloat16           non-zero (and is_int16 zero): fill with rand_tensor(ifm, -128, 128, 0)
 * @return 0 when the run completed (the comparison verdict is only printed,
 *         not returned); -1 when the buffer sizes are invalid or an
 *         allocation failed.
 */
template <typename T>
int run_nni_int16(
    int aie_rows, int aie_cols,
    int h_in, int w_in, int c_in,
    int h_out, int w_out, int c_out,
    int num_interpolation_h, int num_interpolation_w,
    int is_int16, int is_bfloat16) {
    // Fixed seed for the C RNG (presumably consumed by rand_tensor so that the
    // generated input is repeatable - confirm in common.hpp).
    srand(0xABCD);
    using Telem = T;

    int ifm_size = h_in * w_in * c_in * sizeof(Telem);
    int wgt_size = 4;
    int ofm_size = h_out * w_out * c_out * sizeof(Telem);

    // A negative byte count would turn into a huge size_t inside the allocators.
    if (ifm_size < 0 || ofm_size < 0) {
        std::cerr << "run_nni_int16: invalid tensor dimensions\n";
        return -1;
    }

#if !ASM_MODE
    auto aie_ifm = static_cast<Telem*>(adf::GMIO::malloc(ifm_size));
    auto aie_wgt = static_cast<Telem*>(adf::GMIO::malloc(wgt_size));
    auto aie_ofm = static_cast<Telem*>(adf::GMIO::malloc(ofm_size));
#else
    auto aie_ifm = static_cast<Telem*>(malloc(ifm_size));
    auto aie_wgt = static_cast<Telem*>(malloc(wgt_size));
    auto aie_ofm = static_cast<Telem*>(malloc(ofm_size));
#endif // !ASM_MODE
    auto cpu_ofm = static_cast<Telem*>(malloc(ofm_size));

    // Single release path shared by the error exit and the normal exit.
    auto release_buffers = [&]() {
#if !ASM_MODE
        if (aie_ifm != nullptr) adf::GMIO::free(aie_ifm);
        if (aie_wgt != nullptr) adf::GMIO::free(aie_wgt);
        if (aie_ofm != nullptr) adf::GMIO::free(aie_ofm);
#else
        free(aie_ifm);
        free(aie_wgt);
        free(aie_ofm);
#endif // !ASM_MODE
        free(cpu_ofm);
    };

    // A null result only counts as a failure for a non-empty request, because
    // an allocator may legitimately return null for a zero-byte request.
    const bool alloc_failed =
        (ifm_size != 0 && aie_ifm == nullptr) ||
        aie_wgt == nullptr ||
        (ofm_size != 0 && (aie_ofm == nullptr || cpu_ofm == nullptr));
    if (alloc_failed) {
        std::cerr << "run_nni_int16: buffer allocation failed\n";
        release_buffers();
        return -1;
    }

    std::memset(aie_wgt, 0, wgt_size);
    // Zero the input so it has a defined value even when neither is_int16 nor
    // is_bfloat16 is set and rand_tensor is therefore never called.
    if (ifm_size > 0) {
        std::memset(aie_ifm, 0, ifm_size);
    }

    ActTensor<Telem> ifm(c_in, h_in, w_in, aie_ifm);
    ActTensor<Telem> aie_out_mat(c_out, h_out, w_out, aie_ofm);
    ActTensor<Telem> cpu_out_mat(c_out, h_out, w_out, cpu_ofm);

    if (is_int16) {
        rand_tensor(ifm, -32768, 32767, 1);
    } else if (is_bfloat16) {
        rand_tensor(ifm, -128, 128, 0);
    }

    // Reference result computed on the host.
    cpu_nni(ifm, cpu_out_mat, num_interpolation_h, num_interpolation_w);

#if ASM_MODE
    // ofm.bin is written from the CPU reference buffer, not from aie_ofm.
    write_bin_file("ifm.bin", reinterpret_cast<char*>(aie_ifm), ifm_size);
    write_bin_file("wgt.bin", reinterpret_cast<char*>(aie_wgt), wgt_size);
    write_bin_file("ofm.bin", reinterpret_cast<char*>(cpu_ofm), ofm_size);
#endif

#ifdef __AIESIM__
#if !ASM_MODE
    log_tensor(ifm, "AIE IFM =\n");
    log_tensor(cpu_out_mat, "CPU OFM =\n");
#if USE_CERT_LIBRARY
    run_cert_sim(g_compute_graph,
                 reinterpret_cast<void*>(aie_ofm), ofm_size,
                 reinterpret_cast<void*>(aie_ifm), ifm_size,
                 reinterpret_cast<void*>(aie_wgt), wgt_size);
#else
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_ofm, aie_ifm, aie_wgt);
    g_compute_graph.end();
#endif // USE_CERT_LIBRARY
    log_tensor(aie_out_mat, "AIE OFM =\n");
    int epsilon = 1;
    int err = cmp_tensor(cpu_out_mat, aie_out_mat, 1, epsilon);
    if (err == 0)
        printf("DI_PASS\n");
    else
        printf("DI_FAIL\n");
#endif // !ASM_MODE
#endif // __AIESIM__

    release_buffers();
    return 0;
}

/**
 * Entry point of the resize test.
 *
 * Loads "resize_cfg.json", extracts the layer parameters and runs
 * run_nni_int16 instantiated for int16_t.
 *
 * NOTE(review): the element type is always int16_t, even when BFLOAT_16 is set
 * in the configuration - confirm this is intended.
 *
 * @return the status reported by run_nni_int16 (0 unless that function
 *         reports a failure).
 */
int main(void)
{
    auto cfg = load_json("resize_cfg.json");
    int aie_rows = extract_json(cfg, "AIE_ROWS");
    int aie_cols = extract_json(cfg, "AIE_COLS");
    int h_in = extract_json(cfg, "H_IN");
    int w_in = extract_json(cfg, "W_IN");
    int c_in = extract_json(cfg, "C_IN");
    int h_out = extract_json(cfg, "H_OUT");
    int w_out = extract_json(cfg, "W_OUT");
    int c_out = extract_json(cfg, "C_OUT");
    int num_interpolation_h = extract_json(cfg, "SCALE_Y");
    int num_interpolation_w = extract_json(cfg, "SCALE_X");
    int is_int16 = extract_json(cfg, "INT_16");
    int is_bfloat16 = extract_json(cfg, "BFLOAT_16");

    // The following keys are read from the configuration but not consumed by
    // this test; the reads are kept so any checking done by the extract
    // helpers still happens.
    int read_ifm = extract_json(cfg, "READ_IFM");
    int const read_md = extract_json(cfg, "READ_MD");
    std::string const node_name = extract_json_str(cfg, "NODE_NAME");
    std::string const md_path = extract_json_str(cfg, "MD_PATH");
    (void)read_ifm;
    (void)read_md;

    // Propagate the runner's status instead of unconditionally returning 0.
    return run_nni_int16<int16_t>(
        aie_rows, aie_cols,
        h_in, w_in, c_in,
        h_out, w_out, c_out,
        num_interpolation_h, num_interpolation_w,
        is_int16, is_bfloat16);
}