#include "fusion_test.hpp"
#include <any>
#include <chrono>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <utility>

#include "timing_logger.hpp"
#include "matrix.hpp"
#include "meta_state.hpp"
#include "meta_utils.hpp"
#include "passes.hpp"
#include "subgraph_op.hpp"
#include "txn_utils.hpp"
#include "utils.hpp"

// Stream-style logging gate: `LOG_VERBOSE() << ...;` writes to std::cout only
// when a `verbose_` flag visible at the call site is true.
//
// The macro is spelled `if (!verbose_) {} else std::cout` so that it carries
// its own `else`. With a bare `if (verbose_) std::cout`, a caller writing
//     if (cond) LOG_VERBOSE() << "x"; else other();
// would have its `else` silently bound to the macro's `if`.
#define LOG_VERBOSE()                                                                                                  \
    if (!verbose_)                                                                                                     \
    {                                                                                                                  \
    }                                                                                                                  \
    else                                                                                                               \
        std::cout
using json = nlohmann::json;

using txn_vec_t = std::vector<uint8_t>;
using namespace waic_runner;
namespace waic_runner
{
/// Apply `f` to every element of `src` and return the results, in order, as a
/// new vector.
///
/// The result element type is the decayed type of `f(const srcT &)`, which is
/// exactly how `f` is invoked below. Compared with probing `f(srcT{})` this
/// does not require `srcT` to be default-constructible, and a callable that
/// returns a reference produces a vector of values rather than an ill-formed
/// vector of references.
///
/// @param src  Input elements; not modified.
/// @param f    Callable invoked once per element with a `const srcT &`.
/// @return     `std::vector` holding one converted value per input element.
template <typename srcT, typename Func> auto for_each(const std::vector<srcT> &src, Func &&f)
{
    using dstT = std::decay_t<decltype(f(std::declval<const srcT &>()))>;
    std::vector<dstT> res;
    res.reserve(src.size());
    for (const auto &item : src)
    {
        res.push_back(f(item));
    }
    return res;
}

/// Read `key` from `js` converted to `T`, falling back to `value` when the key
/// is not present.
///
/// @param js     JSON value to look the key up in.
/// @param key    Member name to look up.
/// @param value  Default returned (by copy) when `key` is not found.
/// @return       The stored value converted with `get<T>()`, or `value`.
template <typename T> static T json_get(const json &js, const std::string &key, const T &value)
{
    const auto entry = js.find(key);
    if (entry == js.end())
    {
        return value;
    }
    return (*entry).template get<T>();
}

static std::map<std::string, std::any> extract_op_attrs(const json &op_info)
{
    std::map<std::string, std::any> attrs;
    if (op_info.find("attrs") == op_info.end())
    {
        return attrs;
    }

    for (const auto &[attr_name, attr_info] : op_info.at("attrs").items())
    {
        const std::string dtype = attr_info.at("type").template get<std::string>();
        const std::vector<std::string> values = attr_info.at("value").template get<std::vector<std::string>>();

        if (dtype == "float")
        {
            attrs[attr_name] = for_each(values, [](const auto &s) { return std::stof(s); });
        }
        else if (dtype == "int")
        {
            attrs[attr_name] = for_each(values, [](const auto &s) { return std::stoi(s); });
        }
        else if (dtype == "str")
        {
            attrs[attr_name] = values;
        }
        else
        {
            throw std::runtime_error("Unsupported dtype for attrs in JSON");
        }
    }
    return attrs;
}

/// Gather the optional "original_outputs" / "original_inputs" sections of the
/// auxiliary JSON.
///
/// Each section that is present is stored in the result under its own name as
/// a `std::map<std::string, Tensor>` wrapped in `std::any`. Every Tensor is
/// built with a null first field plus the entry's "shape" and "dtype".
///
/// @param aux_info  Auxiliary JSON object.
/// @return          Map holding zero, one or both of the sections.
static std::map<std::string, std::any> load_aux_info(const json &aux_info)
{
    std::map<std::string, std::any> res;

    // Copies one section into `res` if the JSON has it; otherwise does nothing.
    const auto import_section = [&aux_info, &res](const std::string &section) {
        if (aux_info.find(section) == aux_info.end())
        {
            return;
        }
        std::map<std::string, Tensor> described;
        for (const auto &[name, tinfo] : aux_info.at(section).items())
        {
            described[name] = Tensor{nullptr, tinfo.at("shape").template get<std::vector<size_t>>(),
                                     tinfo.at("dtype").template get<std::string>()};
        }
        res[section] = std::any(described);
    };

    // Outputs are handled before inputs, as before.
    import_section("original_outputs");
    import_section("original_inputs");

    return res;
}

/// Parse a metadata JSON document held in memory and build a Metadata from it.
///
/// The optional top-level "device" key (default "WXB2") selects the layout:
///   - "WXB2": reads "dd_meta_major_version"/"dd_meta_minor_version",
///     "op_list", "fused_tensors" and "tensor_map".
///   - "QHW4": reads "meta_major_version"/"meta_minor_version",
///     "subgraph_index", "bo_sizes", "inputs" and "outputs".
/// Any other device string matches neither branch, so the value-initialised
/// Metadata is returned unchanged.
///
/// @param meta_string  JSON text.
/// @return             Populated Metadata.
/// @throws std::runtime_error  when the text cannot be parsed as JSON.
/// NOTE(review): required keys are read with json::at()/operator[]; a missing
/// key presumably surfaces as an nlohmann exception - confirm against callers.
static Metadata load_meta_string(const std::string &meta_string)
{
    json data;
    try
    {
        data = json::parse(meta_string, nullptr, true);
    }
    catch (std::exception &e)
    {
        // The parser's own message in `e` is not forwarded to the caller.
        throw std::runtime_error("Failed to parse JSON String");
    }

    std::string device = json_get<std::string>(data, "device", "WXB2");
    Metadata meta = {};
    if (device == "WXB2")
    {
        meta.device = "WXB2";
        meta.json_path = "in-memory string";
        meta.major_version = data.at("dd_meta_major_version");
        meta.minor_version = data.at("dd_meta_minor_version");
        // `ver` is assembled here but not used anywhere else in this function.
        std::string ver = std::to_string(meta.major_version) + "." + std::to_string(meta.minor_version);

        // oplist
        for (const auto &opinfo : data.at("op_list"))
        {
            // The trailing {} leaves the attribute slot empty; it is filled on the next line.
            meta.op_list.push_back({opinfo.at("name").template get<std::string>(),
                                    opinfo.at("type").template get<std::string>(),
                                    opinfo.at("in_args").template get<std::vector<std::string>>(),
                                    opinfo.at("const_args").template get<std::vector<std::string>>(),
                                    opinfo.at("out_args").template get<std::vector<std::string>>(),
                                    {}});
            meta.op_list.back().attr = extract_op_attrs(opinfo);
        }

        // tensor info
        for (const auto &[name, tinfo] : data.at("fused_tensors").items())
        {
            meta.fused_tensors[name] = {tinfo.at("buffer_size").template get<size_t>(),
                                        tinfo.at("xrt_arg_id").template get<size_t>(),
                                        tinfo.at("packed_tensors").template get<std::vector<std::string>>()};
            // "in" and "out" are additionally stored under "in_onnx" / "out_onnx" with the same contents.
            if (name == "in") {
                meta.fused_tensors["in_onnx"] = { tinfo.at("buffer_size").template get<size_t>(),
                                        tinfo.at("xrt_arg_id").template get<size_t>(),
                                        tinfo.at("packed_tensors").template get<std::vector<std::string>>() };
            }
            if (name == "out") {
                meta.fused_tensors["out_onnx"] = { tinfo.at("buffer_size").template get<size_t>(),
                                        tinfo.at("xrt_arg_id").template get<size_t>(),
                                        tinfo.at("packed_tensors").template get<std::vector<std::string>>() };
            }
        }

        // tensor_map
        for (const auto &[name, offset_info] : data.at("tensor_map").items())
        {
            // "format", "file_name" and "file_size" are optional and default to "", "" and 0.
            meta.tensor_map[name] = {offset_info.at("packed_buffer_label").template get<std::string>(),
                                     offset_info.at("offset").template get<size_t>(),
                                     0, // additional offset is 0
                                     0,
                                     offset_info.at("xrt_arg_id").template get<size_t>(),
                                     offset_info.at("dtype").template get<std::string>(),
                                     offset_info.at("shape").template get<std::vector<size_t>>(),
                                     offset_info.at("size_in_bytes").template get<size_t>(),
                                     json_get<std::string>(offset_info, "format", ""),
                                     json_get<std::string>(offset_info, "file_name", ""),
                                     json_get<size_t>(offset_info, "file_size", 0)};
        }
    }
    else if (device == "QHW4")
    {
        meta.major_version = data.at("meta_major_version");
        meta.minor_version = data.at("meta_minor_version");
        meta.device = "QHW4";
        uint32_t subgraph_idx = data.at("subgraph_index");
        meta.subgraph_idx = std::to_string(subgraph_idx);
        // tensor info
        // One TensorInfo object is reused below: size and xrt_arg_idx are
        // overwritten before each assignment into fused_tensors.
        Metadata::TensorInfo tinfo;
        // in
        size_t bo_size = data["bo_sizes"]["BO_1"];
        tinfo.size = bo_size;
        tinfo.xrt_arg_idx = 1;
        meta.fused_tensors["in"] = tinfo; // { bo_size, 1, "" };
        meta.fused_tensors["in_onnx"] = tinfo;
        // out
        bo_size = data["bo_sizes"]["BO_0"];
        tinfo.size = bo_size;
        tinfo.xrt_arg_idx = 0;
        meta.fused_tensors["out"] = tinfo;
        meta.fused_tensors["out_onnx"] = tinfo;
        // scratch
        tinfo.size = 0;
        tinfo.xrt_arg_idx = 0;
        meta.fused_tensors["scratch"] = tinfo;
        // const
        tinfo.size = 0;
        tinfo.xrt_arg_idx = 2;
        meta.fused_tensors["const"] = tinfo;
        // param
        // "super_instr" and "ctrl_pkt" receive the same entry (size 0, arg index 3).
        tinfo.size = 0;
        tinfo.xrt_arg_idx = 3;
        meta.fused_tensors["super_instr"] = tinfo;
        meta.fused_tensors["ctrl_pkt"] = tinfo;

        for (const auto &fulltensorinfo : data.at("inputs"))
        {
            meta.fused_tensors["in"].packed_tensors.push_back(fulltensorinfo.at("name"));
            meta.fused_tensors["in_onnx"].packed_tensors.push_back(fulltensorinfo.at("name"));
            // L3_alloc[1], [0] and [2] land in the slots the WXB2 branch fills
            // with "offset", "xrt_arg_id" and "size_in_bytes" respectively.
            meta.tensor_map[fulltensorinfo.at("name")] = {
                "in",
                fulltensorinfo["L3_alloc"][1],
                0, 0,
                fulltensorinfo["L3_alloc"][0],
                fulltensorinfo.at("onnx_dtype").template get<std::string>(),
                fulltensorinfo.at("onnx_shape").template get<std::vector<size_t>>(),
                fulltensorinfo["L3_alloc"][2],
                fulltensorinfo.at("onnx_format").template get<std::string>(),
                json_get<std::string>(fulltensorinfo, "file_name", ""),
                json_get<size_t>(fulltensorinfo, "file_size", 0)};
            // Full per-tensor record (ONNX-side and hardware-side description) kept in declaration order.
            meta.in_tensors.push_back({fulltensorinfo.at("name").template get<std::string>(),
                                       fulltensorinfo.at("onnx_shape").template get<std::vector<size_t>>(),
                                       fulltensorinfo.at("onnx_dtype").template get<std::string>(),
                                       fulltensorinfo.at("onnx_format").template get<std::string>(),
                                       fulltensorinfo.at("hw_shape").template get<std::vector<size_t>>(),
                                       fulltensorinfo.at("hw_dtype").template get<std::string>(),
                                       fulltensorinfo.at("hw_format").template get<std::string>(),
                                       fulltensorinfo.at("L3_alloc").template get<std::vector<size_t>>(),
                                       json_get<std::string>(fulltensorinfo, "file_name", ""),
                                       json_get<size_t>(fulltensorinfo, "file_size", 0)});
        }

        for (const auto &fulltensorinfo : data.at("outputs"))
        {
            meta.fused_tensors["out"].packed_tensors.push_back(fulltensorinfo.at("name"));
            meta.fused_tensors["out_onnx"].packed_tensors.push_back(fulltensorinfo.at("name"));
            // Same field mapping as for the inputs, with packed buffer label "out".
            meta.tensor_map[fulltensorinfo.at("name")] = {
                "out",
                fulltensorinfo["L3_alloc"][1],
                0, 0,
                fulltensorinfo["L3_alloc"][0],
                fulltensorinfo.at("onnx_dtype").template get<std::string>(),
                fulltensorinfo.at("onnx_shape").template get<std::vector<size_t>>(),
                fulltensorinfo["L3_alloc"][2],
                fulltensorinfo.at("onnx_format").template get<std::string>(),
                json_get<std::string>(fulltensorinfo, "file_name", ""),
                json_get<size_t>(fulltensorinfo, "file_size", 0)};
            meta.out_tensors.push_back({fulltensorinfo.at("name").template get<std::string>(),
                                        fulltensorinfo.at("onnx_shape").template get<std::vector<size_t>>(),
                                        fulltensorinfo.at("onnx_dtype").template get<std::string>(),
                                        fulltensorinfo.at("onnx_format").template get<std::string>(),
                                        fulltensorinfo.at("hw_shape").template get<std::vector<size_t>>(),
                                        fulltensorinfo.at("hw_dtype").template get<std::string>(),
                                        fulltensorinfo.at("hw_format").template get<std::string>(),
                                        fulltensorinfo.at("L3_alloc").template get<std::vector<size_t>>(),
                                        json_get<std::string>(fulltensorinfo, "file_name", ""),
                                        json_get<size_t>(fulltensorinfo, "file_size", 0)});
        }
    }

    return meta;
}

/// Load a metadata JSON file from disk and parse it with load_meta_string().
///
/// On Windows the path is first passed through make_long_path() and the file
/// is opened in binary mode; elsewhere it is opened directly.
///
/// @param meta_json  Path of the metadata JSON file.
/// @return           Parsed Metadata whose `json_path` is set to `meta_json`.
/// @throws std::runtime_error  when the file cannot be opened.
static Metadata load_meta_json(const std::string &meta_json)
{
#ifdef _WIN32
    std::filesystem::path file_path(meta_json);
    std::string long_path = make_long_path(file_path);
    std::ifstream ifs(long_path, std::ios::binary);
#else
    std::ifstream ifs(meta_json);
#endif
    if (!ifs.is_open())
    {
        throw std::runtime_error("Can not open meta json file");
    }

    // Read the whole file into memory before parsing.
    std::stringstream contents;
    contents << ifs.rdbuf();
    const std::string meta_string = contents.str();

    Metadata meta = load_meta_string(meta_string);
    // The parser labels its result as in-memory; record the real file instead.
    meta.json_path = meta_json;
    ifs.close();
    return meta;
}

/// Copy the first entry of each op's "pm_id" attribute into the op's `pdi_id`.
///
/// Ops are left untouched when the attribute is absent, when it does not hold
/// a `std::vector<int>`, or when that vector is empty. The empty case used to
/// index element 0 of an empty vector.
///
/// @param meta  Metadata whose `op_list` is updated in place.
void update_pdi(Metadata *meta)
{
    for (auto &op : meta->op_list)
    {
        const auto pm_it = op.attr.find("pm_id");
        if (pm_it == op.attr.end())
        {
            continue;
        }

        // The pointer overload of any_cast yields nullptr on a type mismatch
        // instead of throwing, so it doubles as the type check.
        const auto *pm_ids = std::any_cast<std::vector<int>>(&pm_it->second);
        if (pm_ids == nullptr || pm_ids->empty())
        {
            continue;
        }
        op.pdi_id = pm_ids->front();
    }
}

/// Stamp every tensor packed into the "out_onnx" and "in_onnx" buffers with
/// its position inside that buffer's `packed_tensors` list.
///
/// Outputs are numbered first, then inputs; each list starts again at 0.
///
/// @param meta  Metadata whose `tensor_map` entries get their `ref_idx` set.
void update_ref_idx(Metadata* meta)
{
    // Numbers the tensors listed under one fused buffer label.
    const auto number_packed_tensors = [meta](const char *label) {
        int position = 0;
        for (const auto &tensor_name : meta->fused_tensors.at(label).packed_tensors)
        {
            auto &tensor_info = MAP_AT(meta->tensor_map, tensor_name);
            tensor_info.ref_idx = position;
            ++position;
        }
    };

    number_packed_tensors("out_onnx");
    number_packed_tensors("in_onnx");
}

/// Build a runtime from a runtime configuration and pre-loaded tiling data.
///
/// Copies the relevant configuration fields, loads the metadata file named by
/// `cfg.meta_json_path`, and - when the metadata contains ops - runs the
/// graph clean-up passes and prepares the input/output formatting tables.
/// With an empty op list the metadata is stored as loaded.
///
/// @param cfg           Runtime configuration.
/// @param tilings_data  Tiling description stored for later op construction.
FusionRuntime::FusionRuntime(const rtcfg &cfg, const json &tilings_data)
{
    // Plain configuration copies.
    verbose_ = cfg.debug_cfg.enable_trace;
    binfile_path_ = cfg.HWbin_path;
    prebuilt_bin_dir_ = cfg.prebuilt_bin_dir;
    cache_dir_ = cfg.cache_path;
    use_inmem_ = cfg.compile_cfg.use_inmem;
    prefix_ = cfg.prefix;
    tilings_data_ = tilings_data;

    Metadata init_meta = load_meta_json(cfg.meta_json_path);
    if (init_meta.op_list.empty())
    {
        meta_ = init_meta;
    }
    else
    {
        // Reference indices are assigned on the freshly loaded metadata,
        // before any pass rewrites it.
        update_ref_idx(&init_meta);

        // Pass order is kept exactly as before.
        meta_ = remove_identity_ops(init_meta);
        meta_ = remove_concat_runtime_ops(meta_);
        meta_ = remove_split_runtime_ops(meta_);
        meta_ = remove_gather_runtime_ops(meta_);
        meta_ = remove_slice_runtime_ops(meta_);

        update_pdi(&meta_);
        prepare_formatting_ops("out_onnx", producer_ops_out_);
        prepare_formatting_ops("in_onnx", producer_ops_in_);
    }

    elf_flow_ = check_elf_flow(cfg.xclbin_path, true);
}

/// Build a runtime from metadata given either as a file path or as JSON text.
///
/// @param meta_file       Path to the metadata JSON, or the JSON text itself.
/// @param HWbin_path      Stored as the binary file path.
/// @param cache_dir       Stored as the cache directory.
/// @param is_meta_string  True when `meta_file` holds JSON text rather than a path.
///
/// NOTE(review): unlike the rtcfg overload, this constructor does not assign
/// verbose_, prebuilt_bin_dir_, use_inmem_, elf_flow_, prefix_ or
/// tilings_data_; they keep whatever the class declaration gives them -
/// confirm those members have in-class defaults.
FusionRuntime::FusionRuntime(const std::string &meta_file, const std::string &HWbin_path, const std::string &cache_dir, const bool &is_meta_string)
{
    binfile_path_ = HWbin_path;
    cache_dir_ = cache_dir;

    Metadata init_meta = is_meta_string ? load_meta_string(meta_file) : load_meta_json(meta_file);
    if (init_meta.op_list.empty())
    {
        meta_ = init_meta;
        return;
    }

    // Reference indices are assigned before the passes rewrite the graph.
    update_ref_idx(&init_meta);

    // Pass order is kept exactly as before.
    meta_ = remove_identity_ops(init_meta);
    meta_ = remove_concat_runtime_ops(meta_);
    meta_ = remove_split_runtime_ops(meta_);
    meta_ = remove_gather_runtime_ops(meta_);
    meta_ = remove_slice_runtime_ops(meta_);

    update_pdi(&meta_);
    prepare_formatting_ops("out_onnx", producer_ops_out_);
    prepare_formatting_ops("in_onnx", producer_ops_in_);
}

/// Create the host-side backing storage for the fused buffers.
///
/// The const, input, super-instruction and ctrl-packet buffers each get a
/// temporary file from create_tmpfile(); the scratch and output buffers are
/// byte vectors sized from the "scratch" and "out" fused-tensor entries.
///
/// @param meta  Metadata supplying the fused-tensor sizes.
void FusionRuntime::allocate_host_bos(const Metadata &meta)
{
    const_vec_file_ptr_ = create_tmpfile();
    input_vec_file_ptr_ = create_tmpfile();

    const auto &scratch_info = MAP_AT(meta.fused_tensors, "scratch");
    scratch_vec_ = std::vector<uint8_t>(scratch_info.size);

    const auto &out_info = MAP_AT(meta.fused_tensors, "out");
    output_vec_ = std::vector<uint8_t>(out_info.size);

    super_instr_vec_file_ptr_ = create_tmpfile();
    ctrl_pkt_vec_file_ptr_ = create_tmpfile();
}

/// Close the temporary files that back the input, const, super-instruction and
/// ctrl-packet buffers.
///
/// Each handle is reset to null once closed, so calling this function again
/// is a no-op rather than a second fclose() on an already closed stream.
void FusionRuntime::release_host_resources()
{
    const auto close_and_reset = [](auto &file_ptr) {
        if (file_ptr)
        {
            fclose(file_ptr);
            file_ptr = nullptr;
        }
    };

    // Same closing order as before.
    close_and_reset(input_vec_file_ptr_);
    close_and_reset(const_vec_file_ptr_);
    close_and_reset(super_instr_vec_file_ptr_);
    close_and_reset(ctrl_pkt_vec_file_ptr_);
}

/// Let every op write its constant parameters into the const temporary file.
///
/// For each op the file position is moved to the op's offset in
/// `meta.const_map` (0 when the op has no entry), then a subgraph_op is built
/// and its initialize_const_params() is called with an I/O wrapper around the
/// file. The tensor list handed to the op is always empty here: nothing in
/// this function collects const tensors. The file is rewound when all ops are
/// done.
///
/// @param meta  Metadata providing the op list and the const offsets.
void FusionRuntime::load_const(const Metadata &meta)
{
    for (const auto &op_info : meta.op_list)
    {
        // Offset of this op's const region; ops without a const_map entry
        // still get their initialize_const_params() call, at offset 0.
        size_t offset = 0;
        const auto const_entry = meta.const_map.find(op_info.name);
        if (const_entry != meta.const_map.end())
        {
            offset = const_entry->second.offset;
        }

        fseek64(const_vec_file_ptr_, offset, SEEK_SET);
        auto io = TmpFileConst(const_vec_file_ptr_);

        waic_runner::subgraph_op subop_ = waic_runner::subgraph_op<int8_t, int8_t, int8_t>(
            op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
        subop_.set_verbose(verbose_);

        std::vector<Tensor> const_tensors;
        subop_.initialize_const_params(io, const_tensors, op_info.attr);
    }
    fseek64(const_vec_file_ptr_, 0, SEEK_SET);
}

void FusionRuntime::fill_super_instr(const Metadata &meta)
{
    for (const auto &op_info : meta.op_list)
    {

        auto offset = MAP_AT(meta.super_instr_map, op_info.name).offset;

        waic_runner::subgraph_op subop_ = waic_runner::subgraph_op<int8_t, int8_t, int8_t>(
            op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
        subop_.set_verbose(verbose_);
        std::vector<Tensor> tensors;
        auto super_instr = subop_.get_super_kernel_params(tensors, tensors, op_info.attr);

        fseek64(super_instr_vec_file_ptr_, offset, SEEK_SET);
        std::fwrite(super_instr.data(), 1, super_instr.size(), super_instr_vec_file_ptr_);
    }
    fseek64(super_instr_vec_file_ptr_, 0, SEEK_SET);
}

void FusionRuntime::fill_ctrl_pkts(const Metadata &meta)
{
    for (const auto &op_info : meta.op_list)
    {

        auto offset = MAP_AT(meta.ctrl_pkt_map, op_info.name).offset;

        waic_runner::subgraph_op subop_ = waic_runner::subgraph_op<int8_t, int8_t, int8_t>(
            op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
        subop_.set_verbose(verbose_);
        std::vector<Tensor> tensors;
        auto ctrl_pkts = subop_.get_ctrl_pkts(tensors, tensors, op_info.attr);

        fseek64(ctrl_pkt_vec_file_ptr_, offset, SEEK_SET);
        std::fwrite(ctrl_pkts.data(), 1, ctrl_pkts.size(), ctrl_pkt_vec_file_ptr_);
    }
    fseek64(ctrl_pkt_vec_file_ptr_, 0, SEEK_SET);
}

/// Fetch each op's transaction binary, patch it against the metadata and the
/// op's buffer requirements, and store the result in `op_info.txn_bin`.
///
/// @param meta  Metadata whose ops receive their patched transaction binary.
void FusionRuntime::fetch_txn_bins(Metadata &meta)
{
    for (auto &op_info : meta.op_list)
    {
        // The looked-up offset itself is not used, but the lookup is kept so
        // that MAP_AT is still evaluated for every op.
        // NOTE(review): presumably MAP_AT reports a missing key - confirm
        // whether this check is intentional before removing it.
        (void)MAP_AT(meta.ctrl_pkt_map, op_info.name).offset;

        subgraph_op<int8_t, int8_t, int8_t> kernel_op(op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_,
                                                      op_info.attr, use_inmem_);
        kernel_op.set_verbose(verbose_);

        // The same empty list is passed for both tensor arguments.
        std::vector<Tensor> no_tensors;
        auto raw_txn = kernel_op.get_transaction_bin(no_tensors, no_tensors, op_info.attr);
        auto buffer_reqs = kernel_op.get_buffer_reqs(no_tensors, no_tensors, op_info.attr);

        txn_util patcher(raw_txn);
        patcher.patch(op_info, meta, buffer_reqs);
        op_info.txn_bin = std::move(patcher.to_vector());
    }
}

/// Append `size` raw bytes starting at `data` to the end of `bo`.
///
/// @param bo    Destination byte buffer; grows by `size`.
/// @param data  Start of the bytes to copy.
/// @param size  Number of bytes to copy.
inline void append_bytes(std::vector<uint8_t> &bo, const void* data, size_t size) {
  const auto *const first = static_cast<const uint8_t *>(data);
  const auto *const last = first + size;
  bo.insert(bo.end(), first, last);
}

void FusionRuntime::write_in_vec(
    std::vector<uint8_t> &const_bo,
    std::vector<uint8_t> &instr_bo,
    std::vector<uint8_t> &super_instr_bo)
{
    save_tmpfile_in_vec(const_vec_file_ptr_, const_bo);
    save_tmpfile_in_vec(super_instr_vec_file_ptr_, super_instr_bo);
    save_tmpfile_in_vec(ctrl_pkt_vec_file_ptr_, super_instr_bo);
    size_t size;
    if (elf_flow_)
    {
        size = opt_txns_.size();
        append_bytes(instr_bo, &size, sizeof(size));
        for (const auto &instr : opt_txns_)
        {
            size = instr.size();
            append_bytes(instr_bo, &size, sizeof(size));
            append_bytes(instr_bo, instr.data(), instr.size());
        }
    }
    else
    {
        size = fused_instr_vec_.size();
        append_bytes(instr_bo, &size, sizeof(size));
        for (const auto &instr : fused_instr_vec_)
        {
            size = instr.size();
            append_bytes(instr_bo, &size, sizeof(size));
            append_bytes(instr_bo, instr.data(), instr.size());
        }
    }
}

void FusionRuntime::save_files()
{
    // create the cache_dir folder if not existing
    if (!std::filesystem::exists(cache_dir_.c_str()))
    {
        std::filesystem::create_directory(cache_dir_.c_str());
    }
    std::string filename = prefix_ + "const_bo_fname.bin";
    auto filepath = std::filesystem::path{cache_dir_} / filename;

    save_tmpfile_on_disk(filepath, const_vec_file_ptr_, "wb");

    filename = prefix_ + "super_instr_bo_fname.bin";
    filepath = std::filesystem::path{cache_dir_} / filename;
    save_tmpfile_on_disk(filepath, super_instr_vec_file_ptr_, "wb");
    // if (!elf_flow_) {
    save_tmpfile_on_disk(filepath, ctrl_pkt_vec_file_ptr_, "ab+");
    //}
    // else {
    //    filename = prefix_ + "ctrl_pkt.bin";
    //    filepath = std::filesystem::path{ cache_dir_ } / filename;
    //    save_tmpfile_on_disk(filepath, ctrl_pkt_vec_file_ptr_, "wb");
    //}

    filename = prefix_ + "instr_bo_fname.bin";
    filepath = std::filesystem::path{cache_dir_} / filename;
    size_t size;
    std::ofstream outFile(filepath.string(), std::ios::binary);
    if (elf_flow_)
    {
        size = opt_txns_.size();
        outFile.write(reinterpret_cast<const char *>(&size), sizeof(size));
        for (const auto &instr : opt_txns_)
        {
            size = instr.size();
            outFile.write(reinterpret_cast<const char *>(&size), sizeof(size));
            outFile.write(reinterpret_cast<const char *>(instr.data()), instr.size());
        }
    }
    else
    {
        size = fused_instr_vec_.size();
        outFile.write(reinterpret_cast<const char *>(&size), sizeof(size));
        for (const auto &instr : fused_instr_vec_)
        {
            size = instr.size();
            outFile.write(reinterpret_cast<const char *>(&size), sizeof(size));
            outFile.write(reinterpret_cast<const char *>(instr.data()), instr.size());
        }
    }
    outFile.close();
}

void FusionRuntime::save_ctrl_pkt_info(const Metadata &meta, bool elf_flow)
{
    const std::string ctrlpkt_info_fname = cache_dir_ + "./" + prefix_ + "ctrl_pkt_info.json";
    json ctrl_info;

    uint64_t ctrl_pkt_bo_offset = 0;
    // if (!elf_flow) {
    ctrl_pkt_bo_offset = MAP_AT(meta.fused_tensors, "super_instr").size;
    // }

    ctrl_info["version"] = "1.1";
    ctrl_info["ctrl_pkt_xrt_arg_idx"] = 5;
    ctrl_info["ctrl_pkt_patch_info"] = json::array();

    for (auto &op : meta.op_list)
    {
        auto &patch_info = op.ctrl_pkt_patch_info;
        // if ctrl packet ddoes not exist for an op, skip the patching
        if ((meta.ctrl_pkt_map.find(op.name) == meta.ctrl_pkt_map.end()) || (!patch_info.size()))
        {
            continue;
        }
        auto op_offset = meta.ctrl_pkt_map.at(op.name).offset + ctrl_pkt_bo_offset;

        for (auto &patch : patch_info)
        {
            auto offset = op_offset + patch.offset;
            ctrl_info["ctrl_pkt_patch_info"].push_back({//{"name", op.name},
                                                        {"offset", offset},
                                                        {"size", patch.size},
                                                        {"xrt_arg_idx", patch.xrt_arg_idx},
                                                        {"bo_offset", patch.bo_offset}});
        }
    }

    std::ofstream jsonf(ctrlpkt_info_fname);
    jsonf << std::setw(4) << ctrl_info << std::endl;
}

void FusionRuntime::write_ctrl_pkt_info_in_vec(const Metadata &meta, bool elf_flow,
                                               std::vector<uint8_t> &ctrl_pkt_info)
{
    json ctrl_info;
    uint64_t ctrl_pkt_bo_offset = 0;
    ctrl_pkt_bo_offset = MAP_AT(meta.fused_tensors, "super_instr").size;

    ctrl_info["version"] = "1.1";
    ctrl_info["ctrl_pkt_xrt_arg_idx"] = 5;
    ctrl_info["ctrl_pkt_patch_info"] = json::array();

    for (auto &op : meta.op_list)
    {
        auto &patch_info = op.ctrl_pkt_patch_info;
        // if ctrl packet ddoes not exist for an op, skip the patching
        if ((meta.ctrl_pkt_map.find(op.name) == meta.ctrl_pkt_map.end()) || (!patch_info.size()))
        {
            continue;
        }
        auto op_offset = meta.ctrl_pkt_map.at(op.name).offset + ctrl_pkt_bo_offset;

        for (auto &patch : patch_info)
        {
            auto offset = op_offset + patch.offset;
            ctrl_info["ctrl_pkt_patch_info"].push_back({//{"name", op.name},
                                                        {"offset", offset},
                                                        {"size", patch.size},
                                                        {"xrt_arg_idx", patch.xrt_arg_idx},
                                                        {"bo_offset", patch.bo_offset}});
        }
    }

    std::ostringstream oss;
    oss << std::setw(4) << ctrl_info << std::endl;
    std::string json_text = oss.str();
    ctrl_pkt_info.insert(ctrl_pkt_info.end(), json_text.begin(), json_text.end());
}

std::vector<std::vector<uint8_t>> FusionRuntime::generate_fused_txns(const Metadata &meta)
{

    std::vector<std::vector<uint8_t>> fused_txns;
    // assume partition size is 1
    fused_txns.reserve(meta.partitions.size());
    std::vector<std::vector<uint8_t>> txns;
    opt_txns_.clear();
    opt_txns_.reserve(meta.partitions.size());

    size_t partition_index = 0;

    for (const auto &partition : meta.partitions)
    {
        // get fused transactions
        std::vector<txn_vec_t> txn_vecs;
        auto &op_range = partition.op_range;
        size_t const num_ops = op_range.second > op_range.first ? (op_range.second - op_range.first) : (1);
        txn_vecs.reserve(num_ops);
        for (auto ind = op_range.first; ind < op_range.second; ind++)
        {
            const auto &op_info = meta.op_list.at(ind);
            auto txn_bin = op_info.txn_bin;
            txn_vecs.push_back(txn_bin);
        }
        auto fused_txn = txn_util::fuse_txns(txn_vecs);

        opt_txns_.push_back(fused_txn);

        txn_util txn = txn_util(opt_txns_.back());
        auto ibuf_op = transaction_op(txn.txn);

        fused_txns.push_back(std::move(ibuf_op.get_txn_op()));

        partition_index += 1;
    }

    return fused_txns;
}

void FusionRuntime::prepare_formatting_ops(const std::string &io_name, std::vector<ProducerEntry> &producer_ops)
{
    // 1) Clear any old data
    producer_ops.clear();

    // 2) Get the list of final output tensor names
    const auto &out_buf_names = meta_.fused_tensors.at(io_name).packed_tensors;

    // 3) Build a ProducerEntry for each final output in order
    if (io_name.find("out") != std::string::npos)
    {
        for (const auto &tensor_name : out_buf_names)
        {
            bool found = false;

            // Search for which op produces this tensor
            for (auto &op_info : meta_.op_list)
            {
                for (size_t out_i = 0; out_i < op_info.out_args.size(); ++out_i)
                {
                    if (op_info.out_args[out_i] == tensor_name)
                    {
                        // Build a single operator instance
                        // waic_runner::subgraph_op subop_ =
                        //    waic_runner::subgraph_op<int8_t, int8_t, int8_t>(binfile_path_, op_info.attr);

                        ProducerEntry entry{
                            op_info, // store operator metadata
                            out_i    // which output index in op_info.out_args
                        };

                        // Add to our vector in the same order as out_buf_names
                        producer_ops.push_back(std::move(entry));
                        found = true;
                        break; // Stop searching out_args
                    }
                }
                if (found)
                {
                    break; // Stop searching op_list
                }
            }

            // if (!found) {
            //     std::cout << "WARNING: No producing operator found for final output: " <<
            //         tensor_name << std::endl;
            // }
        }
    }
    else if (io_name.find("in") != std::string::npos)
    {
        // if one input tensor is connected to more than one op, then assume all ops process the same input the same
        // way.
        for (const auto &tensor_name : out_buf_names)
        {
            bool found = false;

            // Search for which op produces this tensor
            for (auto &op_info : meta_.op_list)
            {
                for (size_t out_i = 0; out_i < op_info.in_args.size(); ++out_i)
                {
                    if (op_info.in_args[out_i] == tensor_name)
                    {
                        // Build a single operator instance
                        // waic_runner::subgraph_op subop_ =
                        //    waic_runner::subgraph_op<int8_t, int8_t, int8_t>(binfile_path_, op_info.attr);

                        ProducerEntry entry{
                            op_info, // store operator metadata
                            out_i    // which output index in op_info.out_args
                        };

                        // Add to our vector in the same order as out_buf_names
                        producer_ops.push_back(std::move(entry));
                        found = true;
                        break; // Stop searching out_args
                    }
                }
                if (found)
                {
                    break; // Stop searching op_list
                }
            }

            // if (!found) {
            //     std::cout << "WARNING: No producing operator found for final output: " <<
            //         tensor_name << std::endl;
            // }
        }
    }
    else
    {
        throw std::runtime_error("Not valid io_name.");
    }
}

template <typename InT>
void format_input(const Tensor &in_tensor, void *hw_in_ptr, size_t sz, const FullTensorInfo &ftenInfo)
{
    // shape from the tensor
    auto raw_input_shape = in_tensor.shape;
    std::vector<size_t> input_shape_nhwc = convert_nchw_shape(in_tensor, ftenInfo.onnx_format);
    auto input_datatype_size = get_size_of_type(ftenInfo.onnx_dtype);
    size_t actual_input_size =
        std::accumulate(raw_input_shape.begin(), raw_input_shape.end(), (size_t)1, std::multiplies<size_t>());
    actual_input_size *= input_datatype_size;

    auto cpu_in = (InT *)in_tensor.data;
    auto aie_in = (InT *)hw_in_ptr;
    // std::vector<InT> out(sz/sizeof(InT));

    // shape from AIE
    size_t input_bo_size = get_size_of_type(ftenInfo.hw_dtype);
    // input_bo_size = ftenInfo.L3_alloc[2];
    for (size_t i = 0; i < ftenInfo.hw_shape.size(); i++)
    {
        input_bo_size *= ftenInfo.hw_shape[i];
    }

    if (input_bo_size > sz)
    {
        throw std::runtime_error("subgraph_op : The size of hw_in is not correct.");
    }

    if (actual_input_size > sz)
    {
        throw std::runtime_error("subgraph_op : The size of hw_in is smaller than onnx shape.");
    }

    std::vector<InT> in_trans(actual_input_size / sizeof(InT), 0);
    std::string skip_nchw_conversion = get_env_var("SKIP_NCHW_CONVERSION", "0");
    if (ftenInfo.onnx_format.find("NCHW") != std::string::npos && (skip_nchw_conversion == "0"))
    {
        if (raw_input_shape.size() == 2)
        {
            // raw_input_shape is from meta.json
            // cpu_in: nchw
            ActMatrix<InT> Cpu_in(raw_input_shape[0], raw_input_shape[1], cpu_in);
            // in_trans: nhwc
            ActMatrix<InT> In_trans(raw_input_shape[1], raw_input_shape[0], in_trans.data());
            for (size_t w = 0; w < raw_input_shape[1]; w++)
            {
                for (size_t c = 0; c < raw_input_shape[0]; c++)
                {
                    In_trans.at(w, c) = Cpu_in.at(c, w);
                }
            }
        }
        else if (raw_input_shape.size() == 3)
        {
            // raw_input_shape is from meta.json
            // cpu_in: nchw
            ActMatrix<InT> Cpu_in(raw_input_shape[0], raw_input_shape[1], raw_input_shape[2], cpu_in);
            // in_trans: nhwc
            ActMatrix<InT> In_trans(raw_input_shape[0], raw_input_shape[2], raw_input_shape[1], in_trans.data());
            for (size_t h = 0; h < raw_input_shape[0]; h++)
            {
                for (size_t w = 0; w < raw_input_shape[2]; w++)
                {
                    for (size_t c = 0; c < raw_input_shape[1]; c++)
                    {
                        In_trans.at(h, w, c) = Cpu_in.at(h, c, w);
                    }
                }
            }
        }
        else if (raw_input_shape.size() == 4)
        {
            // raw_input_shape is from meta.json
            // cpu_in: nchw
            ActMatrix<InT> Cpu_in(raw_input_shape[0], raw_input_shape[1], raw_input_shape[2], raw_input_shape[3],
                                  cpu_in);
            // in_trans: nhwc
            ActMatrix<InT> In_trans(raw_input_shape[0], raw_input_shape[2], raw_input_shape[3], raw_input_shape[1],
                                    in_trans.data());
            for (size_t n = 0; n < raw_input_shape[0]; n++)
            {
                for (size_t h = 0; h < raw_input_shape[2]; h++)
                {
                    for (size_t w = 0; w < raw_input_shape[3]; w++)
                    {
                        for (size_t c = 0; c < raw_input_shape[1]; c++)
                        {
                            In_trans.at(n, h, w, c) = Cpu_in.at(n, c, h, w);
                        }
                    }
                }
            }
        }
        else
        {
            throw std::runtime_error("output shape is not supported");
        }
    }
    else
    {
        memcpy((void *)in_trans.data(), (void *)cpu_in, actual_input_size);
    }

    // Get padded shape
    std::vector<size_t> input_shape_pad = ftenInfo.hw_shape;

    if ((actual_input_size == sz) && (ftenInfo.onnx_dtype == ftenInfo.hw_dtype))
    {
        memcpy((void *)hw_in_ptr, (void *)in_trans.data(), (actual_input_size));
    }
    else
    {
        if ((ftenInfo.onnx_dtype == ftenInfo.hw_dtype)) {
            // only support inner_dim padding
            size_t raw_in_ch = input_shape_nhwc[input_shape_nhwc.size() - 1];
            size_t padded_in_ch = input_shape_pad[input_shape_pad.size() - 1];
            size_t out_ch_size = 1;
            for (size_t i = 0; i < input_shape_pad.size()-1; i++)
            {
                out_ch_size *= input_shape_pad[i];
            }
            size_t raw_out_ch_size = 1;
            for (size_t i = 0; i < input_shape_nhwc.size() - 1; i++)
            {
                raw_out_ch_size *= input_shape_nhwc[i];
            }
            if (raw_out_ch_size != out_ch_size) {
                std::cout << "Padded outer size: " << out_ch_size << ", Raw outer size: " << raw_out_ch_size << std::endl;
                throw std::runtime_error("Input Tensor outer channel size doesn't match.");
            }            
            size_t offset = 0;
            for (size_t i = 0; i < out_ch_size; i++) {
                memcpy((void*)(static_cast<int8_t*>(hw_in_ptr) + offset), (void*)&in_trans[i * raw_in_ch], raw_in_ch * input_datatype_size);
                offset += padded_in_ch * input_datatype_size;
            }
        }
        else {
            throw std::runtime_error("!!! Invalid Input Tensor padding !!!");
        }
    }
}

template <typename OutT>
void format_output(const Tensor &out_tensor, void *hw_out_ptr, size_t sz, const FullTensorInfo &ftenInfo)
{
    // shape from the tensor
    auto output_shape = out_tensor.shape; // nchw shape
    std::vector<size_t> output_shape_nhwc = convert_nchw_shape(out_tensor, ftenInfo.onnx_format);
    auto output_datatype_size = get_size_of_type(ftenInfo.onnx_dtype);
    size_t actual_output_size =
        std::accumulate(output_shape.begin(), output_shape.end(), (size_t)1, std::multiplies<size_t>());
    actual_output_size *= output_datatype_size;

    auto aie_out = (OutT *)out_tensor.data;
    auto out = (OutT *)hw_out_ptr;
    size_t output_bo_size = get_size_of_type(ftenInfo.hw_dtype);
    for (size_t i = 0; i < ftenInfo.hw_shape.size(); i++)
    {
        output_bo_size *= ftenInfo.hw_shape[i];
    }
    if (sz < output_bo_size)
    {
        throw std::runtime_error("subgraph_op : The size of hw_out is not correct.");
    }

    // Get padded shape for tensor_idx and convert to size_t
    std::vector<size_t> output_shape_pad = ftenInfo.hw_shape;
    if ((actual_output_size == sz) && (ftenInfo.onnx_dtype == ftenInfo.hw_dtype))
    {
        memcpy((void *)aie_out, (void *)hw_out_ptr, (actual_output_size));
    }
    else
    {
        if ((ftenInfo.onnx_dtype == ftenInfo.hw_dtype)) {
            // only support inner_dim padding
            size_t raw_in_ch = output_shape_nhwc[output_shape_nhwc.size() - 1];
            size_t padded_in_ch = output_shape_pad[output_shape_pad.size() - 1];
            size_t out_ch_size = 1;
            for (size_t i = 0; i < output_shape_pad.size() - 1; i++)
            {
                out_ch_size *= output_shape_pad[i];
            }
            size_t raw_out_ch_size = 1;
            for (size_t i = 0; i < output_shape_nhwc.size() - 1; i++)
            {
                raw_out_ch_size *= output_shape_nhwc[i];
            }
            if (raw_out_ch_size != out_ch_size) {
                std::cout << "Padded outer size: " << out_ch_size << ", Raw outer size: " << raw_out_ch_size << std::endl;
                throw std::runtime_error("Output Tensor outer channel size doesn't match.");
            }
            size_t offset = 0;
            for (size_t i = 0; i < out_ch_size; i++) {
                memcpy((void*)&aie_out[i * raw_in_ch], (void*)(static_cast<int8_t*>(hw_out_ptr) + offset), raw_in_ch * output_datatype_size);
                offset += padded_in_ch * output_datatype_size;
            }
        }
        else {
            throw std::runtime_error("!!! Invalid Output Tensor padding !!!");
        }
    }

    std::string skip_nchw_conversion = get_env_var("SKIP_NCHW_CONVERSION", "0");
    if (ftenInfo.onnx_format.find("NCHW") != std::string::npos && (skip_nchw_conversion == "0"))
    {
        std::vector<OutT> out_trans(actual_output_size / sizeof(OutT), 0);
        if (output_shape.size() == 2)
        {
            // output_shape is from meta.json
            // out_trans: nchw
            ActMatrix<OutT> Out_trans(output_shape[0], output_shape[1], out_trans.data());
            // aie_out: nhwc
            ActMatrix<OutT> Aie_out(output_shape[1], output_shape[0], aie_out);
            for (size_t w = 0; w < output_shape[1]; w++)
            {
                for (size_t c = 0; c < output_shape[0]; c++)
                {
                    Out_trans.at(c, w) = Aie_out.at(w, c);
                }
            }
        }
        else if (output_shape.size() == 3)
        {
            // output_shape is from meta.json
            // out_trans: nchw
            ActMatrix<OutT> Out_trans(output_shape[0], output_shape[1], output_shape[2], out_trans.data());
            // aie_out: nhwc
            ActMatrix<OutT> Aie_out(output_shape[0], output_shape[2], output_shape[1], aie_out);
            for (size_t h = 0; h < output_shape[0]; h++)
            {
                for (size_t w = 0; w < output_shape[2]; w++)
                {
                    for (size_t c = 0; c < output_shape[1]; c++)
                    {
                        Out_trans.at(h, c, w) = Aie_out.at(h, w, c);
                    }
                }
            }
        }
        else if (output_shape.size() == 4)
        {
            // output_shape is from meta.json
            // out_trans: nchw
            ActMatrix<OutT> Out_trans(output_shape[0], output_shape[1], output_shape[2], output_shape[3],
                                      out_trans.data());
            // aie_out: nhwc
            ActMatrix<OutT> Aie_out(output_shape[0], output_shape[2], output_shape[3], output_shape[1], aie_out);
            for (size_t n = 0; n < output_shape[0]; n++)
            {
                for (size_t h = 0; h < output_shape[2]; h++)
                {
                    for (size_t w = 0; w < output_shape[3]; w++)
                    {
                        for (size_t c = 0; c < output_shape[1]; c++)
                        {
                            Out_trans.at(n, c, h, w) = Aie_out.at(n, h, w, c);
                        }
                    }
                }
            }
        }
        else
        {
            throw std::runtime_error("output shape is not supported");
        }
        memcpy((void *)aie_out, (void *)out_trans.data(), actual_output_size);
    }
}

void FusionRuntime::split_outputs(const std::vector<Tensor> &outputs, void *output_ptr)
{
    // 1) Basic check: user-provided outputs vs. metadata
    size_t n_meta_outputs = MetaUtils::get_num_outputs(meta_);
    if (outputs.size() != n_meta_outputs)
    {
        throw std::runtime_error("Number of outputs doesn't match with number of "
                                 "metadata outputs");
    }

    const auto &out_buf_names = meta_.fused_tensors.at("out_onnx").packed_tensors;
    const auto& npu_out_buf_names = meta_.fused_tensors.at("out").packed_tensors;
    auto hwout_tensors = MetaUtils::get_output_tensors(meta_);
    // 4) loop in the same order as out_buf_names & producer_ops_
    if (out_buf_names.size() != producer_ops_out_.size() && meta_.major_version < 2)
    {
        throw std::runtime_error("Mismatch: The number of final outputs doesn't match producer_ops_ "
                                 "size.");
    }

    for (int i = 0; i < out_buf_names.size(); i++)
    {
        if (outputs[i].shape != hwout_tensors[i].shape)
        {
            throw std::runtime_error("output tensor shapes doesn't match with the "
                                     "Runtime output tensor shapes");
        }
        if (outputs[i].dtype != hwout_tensors[i].dtype)
        {
            throw std::runtime_error("output tensor dtype doesn't match with the "
                                     "Runtime output tensor dtype");
        }
        
        const auto &tensor_info = MAP_AT(meta_.tensor_map, out_buf_names[i]);
        const auto& npu_tensor_info = MAP_AT(meta_.tensor_map, npu_out_buf_names[tensor_info.ref_idx]);
        size_t sz = std::accumulate(outputs[i].shape.begin(), outputs[i].shape.end(), size_t{1}, std::multiplies{}) *
                    get_size_of_type(outputs[i].dtype);

        // Pointer to hardware data in output_bo_
        void *hw_data_ptr = static_cast<char *>(output_ptr) + npu_tensor_info.offset + tensor_info.additional_offset;
        size_t hw_tensor_bo_sz = tensor_info.size_in_bytes;

        if (meta_.major_version > 1)
        { // new meta without compiling stage
            LOG_VERBOSE() << "Output tensor " << i << " offset: " << tensor_info.offset << std::endl;
            LOG_VERBOSE() << "Output tensor " << i << " OFM BO_id: " << meta_.out_tensors[i].L3_alloc[0] << std::endl;
            if (tensor_info.dtype.find("uint8") != std::string::npos)
            {
                format_output<uint8_t>(outputs[i],      // users final output
                                       hw_data_ptr,     // device buffer pointer
                                       hw_tensor_bo_sz, // device buffer size
                                       meta_.out_tensors[i]);
            }
            else if (tensor_info.dtype.find("int8") != std::string::npos)
            {
                format_output<int8_t>(outputs[i],      // users final output
                                      hw_data_ptr,     // device buffer pointer
                                      hw_tensor_bo_sz, // device buffer size
                                      meta_.out_tensors[i]);
            }
            else if (tensor_info.dtype.find("uint16") != std::string::npos)
            {
                format_output<uint16_t>(outputs[i],      // users final output
                                        hw_data_ptr,     // device buffer pointer
                                        hw_tensor_bo_sz, // device buffer size
                                        meta_.out_tensors[i]);
            }
            else if (tensor_info.dtype.find("int16") != std::string::npos)
            {
                format_output<int16_t>(outputs[i],      // users final output
                                       hw_data_ptr,     // device buffer pointer
                                       hw_tensor_bo_sz, // device buffer size
                                       meta_.out_tensors[i]);
            }
            else if (tensor_info.dtype.find("float") != std::string::npos)
            {
                format_output<float>(outputs[i],      // users final output
                                     hw_data_ptr,     // device buffer pointer
                                     hw_tensor_bo_sz, // device buffer size
                                     meta_.out_tensors[i]);
            }
            else
            {
                throw std::runtime_error("Not valid datatype");
            }
        }
        else
        {
            // Producer entry for this final output
            auto &entry = producer_ops_out_[i];
            auto &op_info = entry.op_info;
            size_t out_idx = entry.out_index; // which output of the operator

            // call format_output(...) with the correct out_idx
            if (tensor_info.dtype.find("uint8") != std::string::npos)
            {
                subgraph_op subop_ = subgraph_op<uint8_t, uint8_t, uint8_t>(
                    op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
                subop_.set_verbose(verbose_);
                subop_.format_output(outputs[i],      // users final output
                                     hw_data_ptr,     // device buffer pointer
                                     hw_tensor_bo_sz, // device buffer size
                                     out_idx,         // which operator output index
                                     op_info.attr);
            }
            else if (tensor_info.dtype.find("int8") != std::string::npos)
            {
                subgraph_op subop_ = subgraph_op<int8_t, int8_t, int8_t>(op_info.type, binfile_path_, prebuilt_bin_dir_,
                                                                         tilings_data_, op_info.attr, use_inmem_);
                subop_.set_verbose(verbose_);
                subop_.format_output(outputs[i],      // users final output
                                     hw_data_ptr,     // device buffer pointer
                                     hw_tensor_bo_sz, // device buffer size
                                     out_idx,         // which operator output index
                                     op_info.attr);
            }
            else if (tensor_info.dtype.find("uint16") != std::string::npos)
            {
                subgraph_op subop_ = subgraph_op<uint16_t, uint8_t, uint16_t>(
                    op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
                subop_.set_verbose(verbose_);
                subop_.format_output(outputs[i],      // users final output
                                     hw_data_ptr,     // device buffer pointer
                                     hw_tensor_bo_sz, // device buffer size
                                     out_idx,         // which operator output index
                                     op_info.attr);
            }
            else if (tensor_info.dtype.find("int16") != std::string::npos) {
                subgraph_op subop_ = subgraph_op<int16_t, int8_t, int16_t>(
                    op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
                subop_.set_verbose(verbose_);
                subop_.format_output(outputs[i],      // users final output
                    hw_data_ptr,     // device buffer pointer
                    hw_tensor_bo_sz, // device buffer size
                    out_idx,         // which operator output index
                    op_info.attr);
            }
            else if (tensor_info.dtype.find("float") != std::string::npos) {
                subgraph_op subop_ = subgraph_op<float, uint8_t, uint16_t>(
                    op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
                subop_.set_verbose(verbose_);
                subop_.format_output(outputs[i],      // users final output
                    hw_data_ptr,     // device buffer pointer
                    hw_tensor_bo_sz, // device buffer size
                    out_idx,         // which operator output index
                    op_info.attr);
            }
            else {
                throw std::runtime_error("Not valid datatype");
            }
        }
    }
}

void FusionRuntime::split_outputs(const std::vector<Tensor> &outputs, const std::vector<uint8_t> &output_bo)
{

    // 1) Basic check: user-provided outputs vs. metadata
    size_t n_meta_outputs = MetaUtils::get_num_outputs(meta_);
    if (outputs.size() != n_meta_outputs)
    {
        throw std::runtime_error("Number of outputs doesn't match with number of "
                                 "metadata outputs");
    }

    const auto &out_buf_names = meta_.fused_tensors.at("out_onnx").packed_tensors;
    auto hwout_tensors = MetaUtils::get_output_tensors(meta_);
    // 4) loop in the same order as out_buf_names & producer_ops_
    if (out_buf_names.size() != producer_ops_out_.size() && meta_.major_version < 2)
    {
        throw std::runtime_error("Mismatch: The number of final outputs doesn't match producer_ops_ "
                                 "size.");
    }

    for (int i = 0; i < out_buf_names.size(); i++)
    {
        if (outputs[i].shape != hwout_tensors[i].shape)
        {
            throw std::runtime_error("output tensor shapes doesn't match with the "
                                     "Runtime output tensor shapes");
        }
        if (outputs[i].dtype != hwout_tensors[i].dtype)
        {
            throw std::runtime_error("output tensor dtype doesn't match with the "
                                     "Runtime output tensor dtype");
        }

        const auto &tensor_info = MAP_AT(meta_.tensor_map, out_buf_names[i]);
        size_t sz = std::accumulate(outputs[i].shape.begin(), outputs[i].shape.end(), size_t{1}, std::multiplies{}) *
                    get_size_of_type(outputs[i].dtype);

        // Pointer to hardware data in output_bo_
        void *hw_data_ptr = (char *)(output_bo.data()) + tensor_info.offset;
        size_t hw_tensor_bo_sz = tensor_info.size_in_bytes;

        if (meta_.major_version > 1)
        { // new meta without compiling stage
            if (tensor_info.dtype.find("uint8") != std::string::npos)
            {
                format_output<uint8_t>(outputs[i],      // users final output
                                       hw_data_ptr,     // device buffer pointer
                                       hw_tensor_bo_sz, // device buffer size
                                       meta_.out_tensors[i]);
            }
            else if (tensor_info.dtype.find("int8") != std::string::npos)
            {
                format_output<int8_t>(outputs[i],      // users final output
                                      hw_data_ptr,     // device buffer pointer
                                      hw_tensor_bo_sz, // device buffer size
                                      meta_.out_tensors[i]);
            }
            else if (tensor_info.dtype.find("uint16") != std::string::npos)
            {
                format_output<uint16_t>(outputs[i],      // users final output
                                        hw_data_ptr,     // device buffer pointer
                                        hw_tensor_bo_sz, // device buffer size
                                        meta_.out_tensors[i]);
            }
            else if (tensor_info.dtype.find("int16") != std::string::npos)
            {
                format_output<int16_t>(outputs[i],      // users final output
                                       hw_data_ptr,     // device buffer pointer
                                       hw_tensor_bo_sz, // device buffer size
                                       meta_.out_tensors[i]);
            }
            else if (tensor_info.dtype.find("float") != std::string::npos)
            {
                format_output<float>(outputs[i],      // users final output
                                     hw_data_ptr,     // device buffer pointer
                                     hw_tensor_bo_sz, // device buffer size
                                     meta_.out_tensors[i]);
            }
            else
            {
                throw std::runtime_error("Not valid datatype");
            }
        }
        else
        {
            // Producer entry for this final output
            auto &entry = producer_ops_out_[i];
            auto &op_info = entry.op_info;
            size_t out_idx = entry.out_index; // which output of the operator

            // call format_output(...) with the correct out_idx
            if (tensor_info.dtype.find("uint8") != std::string::npos)
            {
                subgraph_op subop_ = subgraph_op<uint8_t, uint8_t, uint8_t>(
                    op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
                subop_.set_verbose(verbose_);
                subop_.format_output(outputs[i],      // users final output
                    hw_data_ptr,     // device buffer pointer
                    hw_tensor_bo_sz, // device buffer size
                    out_idx,         // which operator output index
                    op_info.attr);
            }
            else if (tensor_info.dtype.find("int8") != std::string::npos) {
                subgraph_op subop_ = subgraph_op<int8_t, int8_t, int8_t>(
                    op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
                subop_.set_verbose(verbose_);
                subop_.format_output(outputs[i],      // users final output
                                     hw_data_ptr,     // device buffer pointer
                                     hw_tensor_bo_sz, // device buffer size
                                     out_idx,         // which operator output index
                                     op_info.attr);
            }
            else if (tensor_info.dtype.find("uint16") != std::string::npos)
            {
                subgraph_op subop_ = subgraph_op<uint16_t, uint8_t, uint16_t>(
                    op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
                subop_.set_verbose(verbose_);
                subop_.format_output(outputs[i],      // users final output
                                     hw_data_ptr,     // device buffer pointer
                                     hw_tensor_bo_sz, // device buffer size
                                     out_idx,         // which operator output index
                                     op_info.attr);
            }
            else if (tensor_info.dtype.find("int16") != std::string::npos) {
                subgraph_op subop_ = subgraph_op<int16_t, int8_t, int16_t>(
                    op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
                subop_.set_verbose(verbose_);
                subop_.format_output(outputs[i],      // users final output
                    hw_data_ptr,     // device buffer pointer
                    hw_tensor_bo_sz, // device buffer size
                    out_idx,         // which operator output index
                    op_info.attr);
            }
            else {
                throw std::runtime_error("Not valid datatype");
            }
        }
    }
}

void FusionRuntime::merge_inputs(const std::vector<Tensor> &inputs, void *input_ptr, void *output_ptr)
{
    size_t n_meta_inputs = MetaUtils::get_num_inputs(meta_);
    if (inputs.size() != n_meta_inputs)
    {
        throw std::runtime_error("Number of inputs doesn't match with that of metadata");
    }

    const auto &in_buf_names = MAP_AT(meta_.fused_tensors, "in_onnx").packed_tensors;
    //auto hwin_tensors = MetaUtils::get_input_tensors(meta_);
    if (in_buf_names.size() != producer_ops_in_.size() && meta_.major_version < 2)
    {
        throw std::runtime_error("Mismatch: The number of final outputs doesn't match producer_ops_ "
                                 "size.");
    }

    for (int i = 0; i < in_buf_names.size(); i++)
    {
        const auto &tensor_info = MAP_AT(meta_.tensor_map, in_buf_names[i]);
        size_t sz = std::accumulate(inputs[i].shape.begin(), inputs[i].shape.end(), size_t{1}, std::multiplies{}) *
                    get_size_of_type(inputs[i].dtype);

        // Pointer to hardware data in output_bo_
        void *hw_data_ptr;
        size_t hw_tensor_bo_sz = tensor_info.size_in_bytes;

        if (meta_.major_version > 1)
        { // new meta without compiling stage
            LOG_VERBOSE() << "Input tensor " << i << " offset: " << tensor_info.offset << std::endl;
            if (meta_.in_tensors[i].L3_alloc[0] == 1)
            { // ifm
                LOG_VERBOSE() << "Input tensor " << i << " IFM BO_id: " << meta_.in_tensors[i].L3_alloc[0] << std::endl;
                hw_data_ptr = static_cast<char *>(input_ptr) + tensor_info.offset;
            }
            else
            { // ofm
                LOG_VERBOSE() << "Input tensor " << i << " OFM BO_id: " << meta_.in_tensors[i].L3_alloc[0] << std::endl;
                hw_data_ptr = static_cast<char *>(output_ptr) + tensor_info.offset;
            }
            if (tensor_info.dtype.find("uint8") != std::string::npos)
            {
                format_input<uint8_t>(inputs[i],       // users final output
                                      hw_data_ptr,     // device buffer pointer
                                      hw_tensor_bo_sz, // device buffer size
                                      meta_.in_tensors[i]);
            }
            else if (tensor_info.dtype.find("int8") != std::string::npos)
            {
                format_input<int8_t>(inputs[i],       // users final output
                                     hw_data_ptr,     // device buffer pointer
                                     hw_tensor_bo_sz, // device buffer size
                                     meta_.in_tensors[i]);
            }
            else if (tensor_info.dtype.find("uint16") != std::string::npos)
            {
                format_input<uint16_t>(inputs[i],       // users final output
                                       hw_data_ptr,     // device buffer pointer
                                       hw_tensor_bo_sz, // device buffer size
                                       meta_.in_tensors[i]);
            }
            else if (tensor_info.dtype.find("int16") != std::string::npos)
            {
                format_input<int16_t>(inputs[i],       // users final output
                                      hw_data_ptr,     // device buffer pointer
                                      hw_tensor_bo_sz, // device buffer size
                                      meta_.in_tensors[i]);
            }
            else if (tensor_info.dtype.find("float") != std::string::npos)
            {
                format_input<float>(inputs[i],       // users final output
                                    hw_data_ptr,     // device buffer pointer
                                    hw_tensor_bo_sz, // device buffer size
                                    meta_.in_tensors[i]);
            }
            else
            {
                throw std::runtime_error("Not valid datatype");
            }
        }
        else
        {
            hw_data_ptr = static_cast<char *>(input_ptr) + tensor_info.offset;
            // Producer entry for this final output
            auto &entry = producer_ops_in_[i];
            auto &op_info = entry.op_info;
            size_t in_idx = entry.out_index; // which input of the operator

            // call format_input(...) with the correct in_idx
            if (tensor_info.dtype.find("uint8") != std::string::npos)
            {
                waic_runner::subgraph_op subop_ = waic_runner::subgraph_op<uint8_t, uint8_t, uint8_t>(
                    op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
                subop_.set_verbose(verbose_);
                subop_.format_input(inputs[i],       // users final output
                                    hw_data_ptr,     // device buffer pointer
                                    hw_tensor_bo_sz, // device buffer size
                                    in_idx,          // which operator output index
                                    op_info.attr);
            }
            else if (tensor_info.dtype.find("int8") != std::string::npos)
            {
                waic_runner::subgraph_op subop_ = waic_runner::subgraph_op<int8_t, int8_t, int8_t>(
                    op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
                subop_.set_verbose(verbose_);
                subop_.format_input(inputs[i],       // users final output
                                    hw_data_ptr,     // device buffer pointer
                                    hw_tensor_bo_sz, // device buffer size
                                    in_idx,          // which operator output index
                                    op_info.attr);
            }
            else if (tensor_info.dtype.find("uint16") != std::string::npos)
            {
                waic_runner::subgraph_op subop_ = waic_runner::subgraph_op<uint16_t, uint8_t, uint16_t>(
                    op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
                subop_.set_verbose(verbose_);
                subop_.format_input(inputs[i],       // users final output
                                    hw_data_ptr,     // device buffer pointer
                                    hw_tensor_bo_sz, // device buffer size
                                    in_idx,          // which operator output index
                                    op_info.attr);
            }
            else if (tensor_info.dtype.find("int16") != std::string::npos) {
                waic_runner::subgraph_op subop_ = waic_runner::subgraph_op<int16_t, int8_t, int16_t>(
                    op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
                subop_.set_verbose(verbose_);
                subop_.format_input(inputs[i],      // users final output
                    hw_data_ptr,     // device buffer pointer
                    hw_tensor_bo_sz, // device buffer size
                    in_idx,         // which operator output index
                    op_info.attr);
            }
            else if (tensor_info.dtype.find("float") != std::string::npos) {
                waic_runner::subgraph_op subop_ = waic_runner::subgraph_op<uint16_t, uint8_t, float>(
                    op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
                subop_.set_verbose(verbose_);
                subop_.format_input(inputs[i],      // users final output
                    hw_data_ptr,     // device buffer pointer
                    hw_tensor_bo_sz, // device buffer size
                    in_idx,         // which operator output index
                    op_info.attr);
            }
            else {
                throw std::runtime_error("Not valid datatype");
            }
        }
    }
}

std::vector<uint8_t> FusionRuntime::merge_inputs(const std::vector<Tensor> &inputs)
{
    size_t n_meta_inputs = MetaUtils::get_num_inputs(meta_);
    if (inputs.size() != n_meta_inputs)
    {
        throw std::runtime_error("Number of inputs doesn't match with that of metadata");
    }

    const auto &in_buf_names = MAP_AT(meta_.fused_tensors, "in_onnx").packed_tensors;
    //auto hwin_tensors = MetaUtils::get_input_tensors(meta_);
    if (in_buf_names.size() != producer_ops_in_.size() && meta_.major_version < 2)
    {
        throw std::runtime_error("Mismatch: The number of final outputs doesn't match producer_ops_ "
                                 "size.");
    }

    auto input_bo_size = MAP_AT(meta_.fused_tensors, "in").size;
    std::vector<uint8_t> input_bo(input_bo_size);

    std::vector<uint8_t> output_bo;
    if (meta_.major_version > 1)
    {
        auto output_bo_size = MAP_AT(meta_.fused_tensors, "out").size;
        output_bo.resize(output_bo_size);
    }

    for (int i = 0; i < in_buf_names.size(); i++)
    {
        const auto &tensor_info = MAP_AT(meta_.tensor_map, in_buf_names[i]);
        size_t sz = std::accumulate(inputs[i].shape.begin(), inputs[i].shape.end(), size_t{1}, std::multiplies{}) *
                    get_size_of_type(inputs[i].dtype);

        // Pointer to hardware data in output_bo_
        void *hw_data_ptr;
        size_t hw_tensor_bo_sz = tensor_info.size_in_bytes;

        if (meta_.major_version > 1)
        { // new meta without compiling stage
            if (meta_.in_tensors[i].L3_alloc[0] == 1)
            { // ifm
                hw_data_ptr = (char *)(input_bo.data()) + tensor_info.offset;
            }
            else
            { // ofm
                hw_data_ptr = (char *)(output_bo.data()) + tensor_info.offset;
            }
            if (tensor_info.dtype.find("uint8") != std::string::npos)
            {
                format_input<uint8_t>(inputs[i],       // users final output
                                      hw_data_ptr,     // device buffer pointer
                                      hw_tensor_bo_sz, // device buffer size
                                      meta_.in_tensors[i]);
            }
            else if (tensor_info.dtype.find("int8") != std::string::npos)
            {
                format_input<int8_t>(inputs[i],       // users final output
                                     hw_data_ptr,     // device buffer pointer
                                     hw_tensor_bo_sz, // device buffer size
                                     meta_.in_tensors[i]);
            }
            else if (tensor_info.dtype.find("uint16") != std::string::npos)
            {
                format_input<uint16_t>(inputs[i],       // users final output
                                       hw_data_ptr,     // device buffer pointer
                                       hw_tensor_bo_sz, // device buffer size
                                       meta_.in_tensors[i]);
            }
            else if (tensor_info.dtype.find("int16") != std::string::npos)
            {
                format_input<int16_t>(inputs[i],       // users final output
                                      hw_data_ptr,     // device buffer pointer
                                      hw_tensor_bo_sz, // device buffer size
                                      meta_.in_tensors[i]);
            }
            else if (tensor_info.dtype.find("float") != std::string::npos)
            {
                format_input<float>(inputs[i],       // users final output
                    hw_data_ptr,     // device buffer pointer
                    hw_tensor_bo_sz, // device buffer size
                    meta_.in_tensors[i]);
            }
            else
            {
                throw std::runtime_error("Not valid datatype");
            }
        }
        else
        {
            hw_data_ptr = (char *)(input_bo.data()) + tensor_info.offset;
            // Producer entry for this final output
            auto &entry = producer_ops_in_[i];
            auto &op_info = entry.op_info;
            size_t in_idx = entry.out_index; // which input of the operator

            // call format_input(...) with the correct in_idx
            if (tensor_info.dtype.find("uint8") != std::string::npos)
            {
                waic_runner::subgraph_op subop_ = waic_runner::subgraph_op<uint8_t, uint8_t, uint8_t>(
                    op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
                subop_.set_verbose(verbose_);
                subop_.format_input(inputs[i],      // users final output
                    hw_data_ptr,     // device buffer pointer
                    hw_tensor_bo_sz, // device buffer size
                    in_idx,         // which operator output index
                    op_info.attr);
            }
            else if (tensor_info.dtype.find("int8") != std::string::npos) {
                waic_runner::subgraph_op subop_ = waic_runner::subgraph_op<int8_t, int8_t, int8_t>(
                    op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
                subop_.set_verbose(verbose_);
                subop_.format_input(inputs[i],       // users final output
                                    hw_data_ptr,     // device buffer pointer
                                    hw_tensor_bo_sz, // device buffer size
                                    in_idx,          // which operator output index
                                    op_info.attr);
            }
            else if (tensor_info.dtype.find("uint16") != std::string::npos)
            {
                waic_runner::subgraph_op subop_ = waic_runner::subgraph_op<uint16_t, uint8_t, uint16_t>(
                    op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
                subop_.set_verbose(verbose_);
                subop_.format_input(inputs[i],       // users final output
                                    hw_data_ptr,     // device buffer pointer
                                    hw_tensor_bo_sz, // device buffer size
                                    in_idx,          // which operator output index
                                    op_info.attr);
            }
            else if (tensor_info.dtype.find("int16") != std::string::npos) {
                waic_runner::subgraph_op subop_ = waic_runner::subgraph_op<int16_t, int8_t, int16_t>(
                    op_info.type, binfile_path_, prebuilt_bin_dir_, tilings_data_, op_info.attr, use_inmem_);
                subop_.set_verbose(verbose_);
                subop_.format_input(inputs[i],      // users final output
                    hw_data_ptr,     // device buffer pointer
                    hw_tensor_bo_sz, // device buffer size
                    in_idx,         // which operator output index
                    op_info.attr);
            }
            else {
                throw std::runtime_error("Not valid datatype");
            }
        }
    }
    if (input_bo_size > 0)
    {
        return (input_bo);
    }
    else
    {
        return (output_bo);
    }
}

// Returns a copy of `meta` whose partition list is replaced by a single
// partition: a copy of meta.partitions.at(0) with its op_range widened to
// [0, meta.op_list.size()). at(0) throws if `meta` has no partitions.
Metadata combine_partition(const Metadata &meta)
{
    // Build the merged partition from the first one so its other fields carry over.
    auto merged = meta.partitions.at(0);
    merged.op_range.first = 0;
    merged.op_range.second = meta.op_list.size();

    Metadata combined = meta;
    combined.partitions.clear();
    combined.partitions.emplace_back(merged);
    return combined;
}

// Runs the compile pipeline on meta_ and hands the results to the caller through
// byte vectors (via write_in_vec / write_ctrl_pkt_info_in_vec) rather than
// calling save_files() / save_ctrl_pkt_info() as the compile(cpcfg) overload does.
//
// cfg            - options read here: profile, eager_mode, optimize_scratch,
//                  enable_preemption, enable_fast_pm, use_inmem.
// const_bo, instr_bo, super_instr_bo
//                - forwarded to write_in_vec(); presumably filled with the
//                  generated buffers - confirm against write_in_vec.
// ctrl_pkt_info  - forwarded to write_ctrl_pkt_info_in_vec().
//
// The order of the passes below is significant; do not reorder them.
void FusionRuntime::compile(
      cpcfg cfg,
      std::vector<uint8_t> &const_bo,
      std::vector<uint8_t> &instr_bo,
      std::vector<uint8_t> &super_instr_bo,
      std::vector<uint8_t> &ctrl_pkt_info
    )
{
    LOG_VERBOSE() << "Running compiling: " << std::endl;
    LOG_VERBOSE() << "Setting profile level to " << cfg.profile << std::endl;
    LOG_VERBOSE() << "Setting eager mode to " << cfg.eager_mode << std::endl;
    LOG_VERBOSE() << "Setting optimized scratch to " << cfg.optimize_scratch << std::endl;
    LOG_VERBOSE() << "enable_preemption " << cfg.enable_preemption << std::endl;
    LOG_VERBOSE() << "enable_fast_pm " << cfg.enable_fast_pm << std::endl;
    LOG_VERBOSE() << "use_inmem " << cfg.use_inmem << std::endl;

    // Initial partitioning; the pass is run again below after the preemption and
    // profiling steps replace meta_.
    generate_pdi_partitions_pass(meta_, cfg.eager_mode);
    if (cfg.enable_fast_pm)
    {
        meta_ = insert_pm_load_nodes(meta_);
    }
    if (cfg.enable_preemption)
    {
        // Preemption nodes are only inserted when elf_flow_ is set; otherwise a
        // message is printed and compilation continues without them.
        if (elf_flow_)
        {
            meta_ = insert_preemption_nodes(meta_);
            generate_pdi_partitions_pass(meta_, cfg.eager_mode);
        }
        else
        {
            std::cout << "No elf support in xclbin file" << std::endl;
        }
    }
    if (cfg.profile)
    {
        meta_ = insert_record_timer_nodes(meta_, cfg.profile);
        generate_pdi_partitions_pass(meta_, cfg.eager_mode);
    }
    analyze_buffer_reqs(meta_, binfile_path_, prebuilt_bin_dir_, tilings_data_, cfg.use_inmem);
    if (cfg.optimize_scratch == 1)
    {
        // OPT_VER (presumably an environment variable, default "2") picks the
        // scratch optimizer: "1" -> optimize_scratch_buffer, anything else ->
        // optimize_scratch_buffer_contiguous.
        std::string opt_ver = get_env_var("OPT_VER", "2");
        if (opt_ver == "1")
        {
            optimize_scratch_buffer(meta_);
        }
        else
        {
            optimize_scratch_buffer_contiguous(meta_);
        }
    }
    allocate_host_bos(meta_);
    load_const(meta_);
    fill_super_instr(meta_);
    fill_ctrl_pkts(meta_);
    relocate_ctrl_pkt_patch_info(meta_, binfile_path_, prebuilt_bin_dir_, elf_flow_, tilings_data_, cfg.use_inmem);
    fetch_txn_bins(meta_);
    if (verbose_)
    {
        std::cout << "Before Combined partitions ..." << std::endl;
        std::cout << MetaUtils::get_partition_pmid(meta_) << std::endl;
    }
    // Outside eager mode the partitions are collapsed into a single partition
    // spanning every op (see combine_partition).
    if (!cfg.eager_mode)
    {
        meta_ = combine_partition(meta_);
    }
    fused_instr_vec_ = generate_fused_txns(meta_);
    // Export the results into the caller's vectors, then release host resources.
    write_in_vec(const_bo, instr_bo, super_instr_bo);
    write_ctrl_pkt_info_in_vec(meta_, elf_flow_, ctrl_pkt_info);
    release_host_resources();
    // need to save state for real runtime

    LOG_VERBOSE() << "Compiling is done successfully! " << std::endl;
}

// Runs the compile pipeline on meta_ and finishes by calling save_files() and
// save_ctrl_pkt_info(), in contrast to the overload that exports into
// caller-supplied vectors.
//
// cfg - options read here: profile, eager_mode, optimize_scratch,
//       enable_preemption, enable_fast_pm, use_inmem.
//
// The order of the passes below is significant; do not reorder them.
void FusionRuntime::compile(cpcfg cfg)
{
    LOG_VERBOSE() << "Running compiling: " << std::endl;
    LOG_VERBOSE() << "Setting profile level to " << cfg.profile << std::endl;
    LOG_VERBOSE() << "Setting eager mode to " << cfg.eager_mode << std::endl;
    LOG_VERBOSE() << "Setting optimized scratch to " << cfg.optimize_scratch << std::endl;
    LOG_VERBOSE() << "enable_preemption " << cfg.enable_preemption << std::endl;
    LOG_VERBOSE() << "enable_fast_pm " << cfg.enable_fast_pm << std::endl;
    LOG_VERBOSE() << "use_inmem " << cfg.use_inmem << std::endl;

    // Initial partitioning; the pass is run again below after the preemption and
    // profiling steps replace meta_.
    generate_pdi_partitions_pass(meta_, cfg.eager_mode);
    if (cfg.enable_fast_pm)
    {
        meta_ = insert_pm_load_nodes(meta_);
    }
    if (cfg.enable_preemption)
    {
        // Preemption nodes are only inserted when elf_flow_ is set; otherwise a
        // message is printed and compilation continues without them.
        if (elf_flow_)
        {
            meta_ = insert_preemption_nodes(meta_);
            generate_pdi_partitions_pass(meta_, cfg.eager_mode);
        }
        else
        {
            std::cout << "No elf support in xclbin file" << std::endl;
        }
    }
    if (cfg.profile)
    {
        meta_ = insert_record_timer_nodes(meta_, cfg.profile);
        generate_pdi_partitions_pass(meta_, cfg.eager_mode);
    }
    analyze_buffer_reqs(meta_, binfile_path_, prebuilt_bin_dir_, tilings_data_, cfg.use_inmem);
    if (cfg.optimize_scratch == 1)
    {
        // OPT_VER (presumably an environment variable, default "2") picks the
        // scratch optimizer: "1" -> optimize_scratch_buffer, anything else ->
        // optimize_scratch_buffer_contiguous.
        std::string opt_ver = get_env_var("OPT_VER", "2");
        if (opt_ver == "1")
        {
            optimize_scratch_buffer(meta_);
        }
        else
        {
            optimize_scratch_buffer_contiguous(meta_);
        }
    }
    allocate_host_bos(meta_);
    load_const(meta_);
    fill_super_instr(meta_);
    fill_ctrl_pkts(meta_);
    relocate_ctrl_pkt_patch_info(meta_, binfile_path_, prebuilt_bin_dir_, elf_flow_, tilings_data_, cfg.use_inmem);
    fetch_txn_bins(meta_);
    if (verbose_)
    {
        std::cout << "Before Combined partitions ..." << std::endl;
        std::cout << MetaUtils::get_partition_pmid(meta_) << std::endl;
    }
    // Outside eager mode the partitions are collapsed into a single partition
    // spanning every op (see combine_partition).
    if (!cfg.eager_mode)
    {
        meta_ = combine_partition(meta_);
    }
    fused_instr_vec_ = generate_fused_txns(meta_);
    // Persist the results, then release host resources.
    save_files();
    save_ctrl_pkt_info(meta_, elf_flow_);
    release_host_resources();
    // need to save state for real runtime

    LOG_VERBOSE() << "Compiling is done successfully! " << std::endl;
}

// Forwards meta_ to save_meta() together with state_name and the verbose flag.
// state_name presumably identifies the destination (e.g. a file) - confirm
// against save_meta.
void FusionRuntime::save_state(const std::string &state_name)
{
    save_meta(meta_, state_name, verbose_);
}

// In-memory variant: forwards meta_ to save_meta() with a caller-owned byte
// vector and the verbose flag. state_data is presumably filled with the
// serialized metadata - confirm against save_meta.
void FusionRuntime::save_state(std::vector<uint8_t> &state_data)
{
    save_meta(meta_, state_data, verbose_);
}

// Calls load_meta(meta_, state_name, verbose_) for legacy metadata
// (major_version < 2). For newer metadata the call does nothing.
void FusionRuntime::load_state(const std::string &state_name)
{
    const bool is_legacy_meta = meta_.major_version < 2;
    if (!is_legacy_meta)
    {
        return;
    }
    load_meta(meta_, state_name, verbose_);
}

// In-memory variant: calls load_meta_data(meta_, *state_data, verbose_) for
// legacy metadata (major_version < 2). For newer metadata the call does
// nothing and state_data is not inspected.
//
// Throws std::runtime_error if state_data is null when it would be read;
// dereferencing an empty unique_ptr is undefined behaviour.
void FusionRuntime::load_state(const std::unique_ptr<std::vector<uint8_t>> state_data)
{
    if (meta_.major_version < 2)
    {
        if (!state_data)
        {
            throw std::runtime_error("load_state: state_data is null");
        }
        load_meta_data(meta_, *state_data, verbose_);
    }
}

// Read-only access to the runtime's metadata. The returned reference aliases
// the meta_ member, so it reflects later changes made to meta_.
const Metadata &FusionRuntime::get_meta() const
{
    return meta_;
}
} // namespace waic_runner
