// Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.
////////////////////////////////////////////////////////////////////////
#include "model_runner.hpp"

#include <algorithm>
#include <array>
#include <chrono>
#include <cmath>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <memory>
#include <numeric>
#include <ranges>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include "aiebu/aiebu.h"
#include "aiebu/aiebu_assembler.h"
#include "aie_context.hpp"
#include "timing_logger.hpp"

// LOG_VERBOSE() << ...;  writes to std::cout only when this->config_.verbose is set.
// Usable only where `this->config_` is accessible (model_runner member functions).
// The then-branch is intentionally empty and the stream sits in the `else`, so that a
// caller's `if (c) LOG_VERBOSE() << x; else ...` keeps its own `else` instead of having it
// captured by this macro's `if` (the dangling-else hazard of a bare `if (cond) std::cout`).
#define LOG_VERBOSE()                                                                                                  \
    if (!(this->config_.verbose))                                                                                      \
    {                                                                                                                  \
    }                                                                                                                  \
    else                                                                                                               \
        std::cout
using namespace waic_runner;
namespace waic_runner
{

// Duration type used for the kernel / copy / sync timers in this file.
using time_type_us = std::chrono::microseconds;

// Prints timing statistics for a series of measured runs when cfg.verbose is set.
// It reports the mean and, for more than one sample, the median and the sample
// standard deviation (n - 1 denominator). All values are in microseconds.
// `timers` is taken by value on purpose: std::nth_element reorders it to find the
// median, and the caller's vector must not be disturbed.
// An empty series prints nothing.
void print_runtime_info(std::vector<time_type_us> timers, configuration_runner cfg)
{
    const auto sz = timers.size();

    // Nothing was measured: avoid dividing by zero and dereferencing an empty range below.
    if (sz == 0)
    {
        return;
    }

    if (sz == 1)
    {
        const time_type_us mean_in_us = timers[0];

        if (cfg.verbose)
        {
            std::cout << "XRT Run          : " << mean_in_us.count() << " us" << std::endl;
            std::cout << "Average: " << mean_in_us.count() << " us" << std::endl;
        }

        return;
    }

    using fp_us = std::chrono::duration<double, std::micro>;

    // Sum exactly in integer microseconds, then divide in floating point so the mean
    // keeps its fractional part (integer division used to truncate it).
    const fp_us total = std::accumulate(timers.begin(), timers.end(), time_type_us{});
    const fp_us mean = total / static_cast<double>(sz);

    // Sample variance computed in floating point; the previous per-element integer
    // truncation and division biased the result towards zero.
    double sum_sq = 0.0;
    for (const time_type_us &t : timers)
    {
        const double diff = (fp_us(t) - mean).count();
        sum_sq += diff * diff;
    }
    const double std_dev = std::sqrt(sum_sq / static_cast<double>(sz - 1));

    // Median via nth_element. For an even count this picks the upper of the two middle elements.
    const auto m_idx = timers.begin() + sz / 2;
    std::nth_element(timers.begin(), m_idx, timers.end());
    const time_type_us median = *m_idx;

    if (cfg.verbose)
    {
        std::cout << "XRT Run          : " << mean.count() << " us" << std::endl;
        std::cout << "Average: " << mean.count() << " us Median: " << median.count() << " us stddev: " << std_dev
                  << " us" << std::endl;
    }
}

// #pragma pack(push,1)
// struct txn_header
//{
//     uint32_t op;
//     uint32_t size_in_bytes;
// };
// #pragma pack(pop)
//
// static_assert(sizeof(txn_header) == 8);

} // namespace waic_runner

// Builds a configuration_runner filled with the default WAIC workload file names.
// Per-subgraph files get `prefix` prepended. The xclbin name and the input file
// name are used as-is.
// NOTE(review): the input file is not prefixed while the outputs are; confirm that
// is intentional.
// verbose is always switched off.
configuration_runner prepare_default_config(const std::string &workload_path, size_t num_runs,
                                            const std::string &prefix, bo_size_bytes size_bos,
                                            const std::string &device)
{
    configuration_runner cfg;

    // Run setup.
    cfg.device = device;
    cfg.workload_path = workload_path + "/";
    cfg.num_runs = num_runs;
    cfg.verbose = false;

    // File names that do not depend on the prefix.
    cfg.xclbin_filename = "4x4.xclbin";
    cfg.input_filename = "input_bo_fname.bin";

    // Prefixed, per-subgraph file names.
    // TODO: may need to have multiple instr_bo
    cfg.txnbin_filename = prefix + "instr_bo_fname.bin";
    cfg.output_filename = prefix + "output_bo_fname.bin";
    cfg.scratch_filename = prefix + "scratch_bo_fname.bin";
    cfg.wgts_filename = prefix + "const_bo_fname.bin";
    cfg.param_filename = prefix + "super_instr_bo_fname.bin";
    cfg.ctrl_filename = prefix + "ctrl_pkt.bin";
    cfg.ctrl_pkt_info_filename = prefix + "ctrl_pkt_info.json";

    // Buffer sizes supplied by the caller.
    auto &sizes = cfg.bo_sizes;
    sizes.output_size_bytes = size_bos.output_size_bytes;
    sizes.input_size_bytes = size_bos.input_size_bytes;
    sizes.scratch_size_bytes = size_bos.scratch_size_bytes;
    sizes.no_scratch_buff = size_bos.no_scratch_buff;

    return cfg;
}

// Returns the number of bytes stored in a byte vector.
inline size_t calculate_vector_size(const std::vector<uint8_t>& v)
{
    const size_t bytes_per_element = sizeof(uint8_t);
    return bytes_per_element * v.size();
}

// File-based constructor. It proceeds in four steps:
// 1. Resolve the weight/param/output/scratch file paths under config.workload_path.
// 2. Size every buffer. A zero output/scratch size in config.bo_sizes means "use the
//    file size"; scratch is not sized from a file when no_scratch_buff is set.
// 3. Try to load the optional control-packet patch JSON. Any failure just disables
//    patching.
// 4. Load the device binaries for config.device.
// std::filesystem errors propagate when a needed file is missing.
// Throws std::runtime_error for an unsupported device.
model_runner::model_runner(configuration_runner config) : config_(config)
{
    const auto &cfg = this->config_;
    const auto &dir = cfg.workload_path;

    wgts_path_ = dir + cfg.wgts_filename;
    param_path_ = dir + cfg.param_filename;
    output_path_ = dir + cfg.output_filename;
    scratch_path_ = dir + cfg.scratch_filename;

    // Weights and parameters are always sized from their files.
    wgts_buf_size_ = std::filesystem::file_size(std::filesystem::path(wgts_path_));
    param_buf_size_ = std::filesystem::file_size(std::filesystem::path(param_path_));

    if (cfg.bo_sizes.output_size_bytes != 0)
    {
        output_buf_size_ = cfg.bo_sizes.output_size_bytes;
    }
    else
    {
        output_buf_size_ = std::filesystem::file_size(std::filesystem::path(output_path_));
    }

    const bool scratch_size_from_file =
        cfg.bo_sizes.scratch_size_bytes == 0 && cfg.bo_sizes.no_scratch_buff == 0;
    if (!scratch_size_from_file)
    {
        scratch_buf_size_ = cfg.bo_sizes.scratch_size_bytes;
    }
    else
    {
        scratch_buf_size_ = std::filesystem::file_size(std::filesystem::path(scratch_path_));
    }

    // Control-packet patching is best-effort: any read/parse failure disables it.
    try
    {
        ctrlpkt_json_ = read_json_file(dir + cfg.ctrl_pkt_info_filename);
        ctrlpkt_info_ = extract_ctrlpkt_patch_info(ctrlpkt_json_);
        is_ctrl_pkt_ = 1;
    }
    catch (...)
    {
        std::cout << "NO Ctrl Pkt Patching" << std::endl;
        is_ctrl_pkt_ = 0;
    }

    const auto &device = cfg.device;
    if (device == "QHW4")
    {
        load_elf(cfg.xclbin_filename);
        kernelName_ = "DPU:subgraph_" + cfg.kernel_name;
        return;
    }
    if (device != "WXB2")
    {
        throw std::runtime_error("The config device type is not supported.");
    }

    // WXB2: transaction buffers come from the workload directory, the xclbin from its own path.
    load_txn_bin(dir + cfg.txnbin_filename);
    load_xclbin(cfg.xclbin_filename);
}

namespace
{
// Moves the payload out of an owning pointer.
// Throws std::runtime_error if `owner` is null, so a missing input is reported
// instead of dereferencing a null pointer.
template <typename T>
T take_owned_buffer(std::unique_ptr<T> &owner, const char *what)
{
    if (!owner)
    {
        throw std::runtime_error(std::string("model_runner: null ") + what);
    }
    return std::move(*owner);
}
} // namespace

// In-memory ("context cache") constructor.
// The control-packet JSON, weights, transaction buffers and parameters arrive as
// owned vectors. Their contents are moved into the runner rather than deep-copied.
// The xclbin image is also in memory and is used only for the WXB2 device.
// File paths are still resolved under config.workload_path, exactly as in the
// file-based constructor:
//   - the output and scratch sizes can fall back to file sizes;
//   - run_init() always reads scratch data from scratch_path_.
// Throws std::runtime_error for a null input vector or an unsupported device.
model_runner::model_runner(
    configuration_runner config,
    std::unique_ptr<std::vector<char>> ctrlpkt_vec,
    std::unique_ptr<std::vector<uint8_t>> wgts_vec,
    std::unique_ptr<std::vector<uint8_t>> txnbin_vec,
    std::unique_ptr<std::vector<uint8_t>> param_vec,
    std::unique_ptr<std::vector<char>> xclbin_vec
    ) : config_(std::move(config))
      , ctrlpkt_vec_(take_owned_buffer(ctrlpkt_vec, "ctrlpkt_vec"))
      , wgts_vec_(take_owned_buffer(wgts_vec, "wgts_vec"))
      , txnbin_vec_(take_owned_buffer(txnbin_vec, "txnbin_vec"))
      , param_vec_(take_owned_buffer(param_vec, "param_vec"))
{
    // scratch_path_ used to be left empty here, so file_size("") below threw whenever the
    // scratch size had to come from the file.
    wgts_path_ = this->config_.workload_path + this->config_.wgts_filename;
    param_path_ = this->config_.workload_path + this->config_.param_filename;
    output_path_ = this->config_.workload_path + this->config_.output_filename;
    scratch_path_ = this->config_.workload_path + this->config_.scratch_filename;

    wgts_buf_size_ = calculate_vector_size(wgts_vec_);
    param_buf_size_ = calculate_vector_size(param_vec_);

    if (this->config_.bo_sizes.output_size_bytes == 0)
    {
        std::filesystem::path output_file_path(output_path_);
        output_buf_size_ = std::filesystem::file_size(output_file_path);
    }
    else
    {
        output_buf_size_ = this->config_.bo_sizes.output_size_bytes;
    }

    if (this->config_.bo_sizes.scratch_size_bytes == 0 && this->config_.bo_sizes.no_scratch_buff == 0)
    {
        std::filesystem::path scratch_file_path(scratch_path_);
        scratch_buf_size_ = std::filesystem::file_size(scratch_file_path);
    }
    else
    {
        scratch_buf_size_ = this->config_.bo_sizes.scratch_size_bytes;
    }

    is_ctrl_pkt_ = 1;

    // Control-packet patching is best-effort: any parse failure disables it.
    try
    {
        std::string ctrlpkt_json_data(ctrlpkt_vec_.data(), ctrlpkt_vec_.size());
        ctrlpkt_json_ = parse_json(ctrlpkt_json_data);
        ctrlpkt_info_ = extract_ctrlpkt_patch_info(ctrlpkt_json_);
    }
    catch (...)
    {
        std::cout << "NO Ctrl Pkt Patching" << std::endl;
        is_ctrl_pkt_ = 0;
    }

    if (this->config_.device == "WXB2")
    {
        if (!xclbin_vec)
        {
            throw std::runtime_error("model_runner: null xclbin_vec");
        }
        load_txn_bin_file(txnbin_vec_);
        load_xclbin_data(*xclbin_vec, this->config_.workload_path);
    }
    else if (this->config_.device == "QHW4")
    {
        load_elf(this->config_.xclbin_filename);
        kernelName_ = "DPU:subgraph_" + this->config_.kernel_name;
    }
    else
    {
        throw std::runtime_error("The config device type is not supported.");
    }
}

// QHW4 flow: obtains the hardware context for an ELF package from the shared
// waic_runner::aie_context instance. num_partitions_ is set to 1, since a single
// partition per subgraph is assumed.
// NOTE(review): device_ is not opened in this flow, yet run()'s error path builds an
// xrt::error from device_; confirm that is valid.
void model_runner::load_elf(const std::string &elf_filename)
{
    // Compile-time toggle. When true, the per-subgraph control ELF is loaded
    // ("<stem>_fused_hw_package_subgraph_<kernel_name><ext>") instead of elf_filename.
    constexpr bool use_per_subgraph_elf = false;

    std::string new_elf_filename = elf_filename;
    if constexpr (use_per_subgraph_elf)
    {
        const std::string fextension = std::filesystem::path(elf_filename).extension().string();
        new_elf_filename = elf_filename.substr(0, elf_filename.find_last_of('.')) + "_fused_hw_package_subgraph_" +
                           this->config_.kernel_name + fextension;
    }

    LOG_VERBOSE() << "ELF filename: " << new_elf_filename << std::endl;

    // assume single partition in each subgraph
    num_partitions_ = 1;
    LOG_VERBOSE() << "Attempting to create hw_context... (reboot device if this fails) " << std::endl;
    context_ = waic_runner::aie_context::get_instance_qhw4(new_elf_filename, this->config_.verbose)->get_context();
    LOG_VERBOSE() << "hw_context successfully created" << std::endl;
}

// WXB2 flow: loads the xclbin at `xclbin_filename` and sets the runner up to use it.
// Steps: open device 0, register the xclbin with it, obtain the hardware context
// from the shared waic_runner::aie_context instance, and record whether the xclbin
// supports the ELF flow.
// Throws std::runtime_error if it does not: WAIC requires the ELF flow, so the
// legacy instruction-buffer kernel (KERNEL_NAME) is never created here.
void model_runner::load_xclbin(const std::string &xclbin_filename)
{
    LOG_VERBOSE() << "XCLBIN filename: " << xclbin_filename << std::endl;
    constexpr unsigned int device_index = 0;
    device_ = xrt::device(device_index);
    xclbin_ = xrt::xclbin(xclbin_filename);

    LOG_VERBOSE() << "Registering xclbin to device... uuid: " << xclbin_.get_uuid() << std::endl;
    device_.register_xclbin(xclbin_);

    LOG_VERBOSE() << "Attempting to create hw_context... (reboot device if this fails) " << std::endl;

    context_ = waic_runner::aie_context::get_instance(xclbin_filename)->get_context();

    LOG_VERBOSE() << "hw_context successfully created" << std::endl;
    elf_flow_ = check_elf_flow(xclbin_);
    if (!elf_flow_)
    {
        // assume WAIC always support elf_flow
        throw std::runtime_error("The xclbin doesn't support elf flow.");
    }
}

// WXB2 in-memory flow: same as load_xclbin, but the xclbin image is supplied as bytes.
// The shared context is keyed on "<xclbin_path>out.xclbin".
// Throws std::runtime_error if the xclbin does not support the ELF flow (required
// by WAIC), so the legacy KERNEL_NAME kernel is never created here.
void model_runner::load_xclbin_data(const std::vector<char> &xclbin_data, const std::string &xclbin_path)
{
    constexpr unsigned int device_index = 0;
    device_ = xrt::device(device_index);
    xclbin_ = xrt::xclbin(xclbin_data);

    LOG_VERBOSE() << "Registering xclbin to device... uuid: " << xclbin_.get_uuid() << std::endl;
    device_.register_xclbin(xclbin_);

    LOG_VERBOSE() << "Attempting to create hw_context... (reboot device if this fails) " << std::endl;

    context_ = waic_runner::aie_context::get_instance(xclbin_path + "out.xclbin", xclbin_data)->get_context();

    LOG_VERBOSE() << "hw_context successfully created" << std::endl;
    elf_flow_ = check_elf_flow(xclbin_);
    if (!elf_flow_)
    {
        // assume WAIC always support elf_flow
        throw std::runtime_error("The xclbin doesn't support elf flow.");
    }
}

// Reads serialized transaction (instruction) buffers from a file and appends them to
// instr_bufs_. num_partitions_ is set to the stored buffer count.
// Layout (host byte order): [size_t count], then `count` x ([size_t len][len bytes]).
// If the file cannot be opened, an error is printed and nothing is loaded.
// Throws std::runtime_error if the file ends before the declared data; previously
// the uninitialized sizes were used.
void model_runner::load_txn_bin(const std::string &txnbin_filename)
{
    std::ifstream inFile(txnbin_filename, std::ios::binary);
    if (!inFile)
    {
        std::cerr << "Error opening file for reading!" << std::endl;
        return;
    }

    size_t vec_size = 0;
    if (!inFile.read(reinterpret_cast<char *>(&vec_size), sizeof(vec_size)))
    {
        throw std::runtime_error("Truncated txn bin file: " + txnbin_filename);
    }
    num_partitions_ = vec_size;

    for (size_t i = 0; i < vec_size; ++i)
    {
        size_t each_vec_size = 0;
        if (!inFile.read(reinterpret_cast<char *>(&each_vec_size), sizeof(each_vec_size)))
        {
            throw std::runtime_error("Truncated txn bin file: " + txnbin_filename);
        }
        if (each_vec_size > MAX_INSTR_BUFSZ)
        {
            LOG_VERBOSE() << "WARNING: Instruction buffer too large: " << each_vec_size << std::endl;
        }
        std::vector<uint8_t> each_instr(each_vec_size);
        if (!inFile.read(reinterpret_cast<char *>(each_instr.data()), static_cast<std::streamsize>(each_vec_size)))
        {
            throw std::runtime_error("Truncated txn bin file: " + txnbin_filename);
        }
        instr_bufs_.push_back(std::move(each_instr));
    }
}

// In-memory counterpart of load_txn_bin. It parses the same layout
// ([size_t count] then `count` x ([size_t len][len bytes]), host byte order) from
// `txnbin_vec`, appends the buffers to instr_bufs_, and sets num_partitions_.
// Throws std::runtime_error if the data ends before the declared contents;
// previously the uninitialized sizes were used.
void model_runner::load_txn_bin_file(const std::vector<uint8_t> &txnbin_vec)
{
    std::istringstream in(std::string(reinterpret_cast<const char *>(txnbin_vec.data()), txnbin_vec.size()),
                          std::ios::binary);
    if (!in)
    {
        std::cerr << "Error creating memory string!" << std::endl;
        return;
    }

    size_t vec_size = 0;
    if (!in.read(reinterpret_cast<char *>(&vec_size), sizeof(vec_size)))
    {
        throw std::runtime_error("Truncated txn bin buffer");
    }
    num_partitions_ = vec_size;

    for (size_t i = 0; i < vec_size; ++i)
    {
        size_t each_vec_size = 0;
        if (!in.read(reinterpret_cast<char *>(&each_vec_size), sizeof(each_vec_size)))
        {
            throw std::runtime_error("Truncated txn bin buffer");
        }
        if (each_vec_size > MAX_INSTR_BUFSZ)
        {
            LOG_VERBOSE() << "WARNING: Instruction buffer too large: " << each_vec_size << std::endl;
        }
        std::vector<uint8_t> each_instr(each_vec_size);
        if (!in.read(reinterpret_cast<char *>(each_instr.data()), static_cast<std::streamsize>(each_vec_size)))
        {
            throw std::runtime_error("Truncated txn bin buffer");
        }
        instr_bufs_.push_back(std::move(each_instr));
    }
}

// Copies the contents of the file at `path` into `bo`, starting at byte offset
// `seek`, in 4 KiB chunks.
// If the file cannot be opened, a message is printed and `buf_size` zero bytes are
// written at `seek` instead. The zero-fill previously always went to offset 0.
void model_runner::read_file_to_bo(xrt::bo &bo, std::string path, size_t seek, size_t buf_size)
{
    std::ifstream file(path, std::ios::binary);
    if (!file)
    {
        std::cout << "Unable to open .bin file" << std::endl;
        const std::vector<uint8_t> dummy_buffer(buf_size);
        bo.write(dummy_buffer.data(), dummy_buffer.size(), seek);
        return;
    }

    constexpr size_t chunk_size = 4096;
    std::vector<char> buffer(chunk_size);

    while (file)
    {
        file.read(buffer.data(), static_cast<std::streamsize>(chunk_size));
        const std::streamsize bytes_read = file.gcount(); // may be short on the last chunk
        if (bytes_read <= 0)
        {
            break; // end of file (or read error): nothing more to copy
        }
        bo.write(buffer.data(), static_cast<size_t>(bytes_read), seek);
        seek += static_cast<size_t>(bytes_read);
    }
}

// Copies the bytes of `data` into `bo`, starting at byte offset `seek`.
// If `data` is empty, a message is printed and `buf_size` zero bytes are written at
// `seek` instead. The zero-fill previously always went to offset 0.
void model_runner::read_data_to_bo(xrt::bo &bo, std::vector<uint8_t> &data, size_t seek, size_t buf_size)
{
    if (data.empty())
    {
        std::cout << "Unable to read data from buffer" << std::endl;
        const std::vector<uint8_t> dummy_buffer(buf_size);
        bo.write(dummy_buffer.data(), dummy_buffer.size(), seek);
        return;
    }

    // The data is already contiguous in memory, so one write covers the whole range.
    // No need to copy it in 4 KiB pieces.
    bo.write(data.data(), data.size(), seek);
}

// void model_runner::read_file_to_bo(xrt::bo &bo, std::string path, size_t seek, size_t buf_size) {
//     FILE* file = fopen(path.c_str(), "rb");
//     if (!file) {
//       std::vector<uint8_t> dummy_buffer;
//       dummy_buffer.resize(buf_size);
//       bo.write(dummy_buffer.data(), dummy_buffer.size(), 0);
//       return;
//     }
//     constexpr size_t buffer_size = 4096;
//     char buffer[buffer_size];
//     size_t read = 0;
//     while ((read = fread(buffer, 1, buffer_size, file)) > 0) {
//         bo.write(buffer, read, seek);
//         seek += read;
//     }
//     fclose(file);
// }

// Binds the kernel arguments of `run` for partition `buffer_set`.
// ELF flow: ELF_OPCODE and two zero placeholders, then the ifm/ofm/scratch/wgts/param
// BOs as arguments 3..7.
// Otherwise: OPCODE, the partition's instruction BO and its length in 32-bit words,
// then the device addresses of ifm, ofm, scratch, wgts and param as arguments 3..7.
// Each address is the BO address plus DDR_AIE_ADDR_OFFSET.
void model_runner::assign_kernel_args(xrt::run &run, std::size_t buffer_set) const
{
    if (elf_flow_)
    {
        run.set_arg(0, ELF_OPCODE);
        run.set_arg(1, 0);
        run.set_arg(2, 0);
        run.set_arg(3, ifm_bos_);
        run.set_arg(4, ofm_bos_);
        run.set_arg(5, scratch_bos_);
        run.set_arg(6, wgts_bos_);
        run.set_arg(7, param_bos_);
        return;
    }

    const std::array<uint64_t, 5> ddr_addrs = {
        ifm_bos_.address() + DDR_AIE_ADDR_OFFSET,     ofm_bos_.address() + DDR_AIE_ADDR_OFFSET,
        scratch_bos_.address() + DDR_AIE_ADDR_OFFSET, wgts_bos_.address() + DDR_AIE_ADDR_OFFSET,
        param_bos_.address() + DDR_AIE_ADDR_OFFSET,
    };

    const auto &instr_bo = instr_bos_[buffer_set];
    run.set_arg(0, OPCODE);
    run.set_arg(1, instr_bo);
    run.set_arg(2, instr_bo.size() / sizeof(uint32_t));

    int arg_index = 3;
    for (const uint64_t addr : ddr_addrs)
    {
        run.set_arg(arg_index++, addr);
    }
}

// Writes DDR buffer addresses into the control-packet data held in param_bos_.
// For every entry in ctrlpkt_info_, it takes the address of the buffer selected by
// patch.xrt_arg_idx, adds DDR_AIE_ADDR_OFFSET, and adds patch.bo_offset.
// The result is written at byte offset patch.offset of param_bos_ as the low 32 bits
// followed by the next 16 bits, in host byte order.
// param_bos_ is then synced to the device.
// Throws std::runtime_error for an unknown buffer class, or for a patch that would
// write past the end of param_bos_.
void model_runner::patch_ctrl_pkt()
{
    LOG_VERBOSE() << "Contrl packet pach Init .." << std::endl;

    uint8_t *bo_map = param_bos_.map<uint8_t *>();
    const size_t bo_size = param_bos_.size();

    const uint64_t input_bo_addr = ifm_bos_.address() + DDR_AIE_ADDR_OFFSET;
    const uint64_t output_bo_addr = ofm_bos_.address() + DDR_AIE_ADDR_OFFSET;
    const uint64_t scratch_bo_addr = scratch_bos_.address() + DDR_AIE_ADDR_OFFSET;
    const uint64_t const_bo_addr = wgts_bos_.address() + DDR_AIE_ADDR_OFFSET;
    const uint64_t super_instr_bo_addr = param_bos_.address() + DDR_AIE_ADDR_OFFSET;

    constexpr size_t patch_bytes = sizeof(uint32_t) + sizeof(uint16_t);

    // Patch offsets are arbitrary byte offsets, so direct uint32_t/uint16_t stores
    // through cast pointers could be misaligned and would violate strict aliasing.
    // memcpy is used instead.
    auto patch_bd_addr = [](uint8_t *dest, uint64_t ddr_addr) {
        const uint32_t addr_low = static_cast<uint32_t>(ddr_addr & 0xFFFFFFFFULL);
        const uint16_t addr_high = static_cast<uint16_t>((ddr_addr & 0x0000FFFF00000000ULL) >> 32);
        std::memcpy(dest, &addr_low, sizeof(addr_low));
        std::memcpy(dest + sizeof(addr_low), &addr_high, sizeof(addr_high));
    };

    for (const auto &patch : ctrlpkt_info_)
    {
        uint64_t base_addr = 0;
        switch (patch.xrt_arg_idx)
        {
        case BO_CLASS::BO_IFM:
            base_addr = input_bo_addr;
            break;
        case BO_CLASS::BO_CONST:
            base_addr = const_bo_addr;
            break;
        case BO_CLASS::BO_PARAM:
            base_addr = super_instr_bo_addr;
            break;
        case BO_CLASS::BO_SCRATH:
            base_addr = scratch_bo_addr;
            break;
        case BO_CLASS::BO_OFM:
            base_addr = output_bo_addr;
            break;
        default:
            throw std::runtime_error("Unknow arg type for op");
        }

        // Refuse to write outside the mapped parameter buffer.
        const size_t offset = static_cast<size_t>(patch.offset);
        if (offset > bo_size || bo_size - offset < patch_bytes)
        {
            throw std::runtime_error("Ctrl pkt patch offset out of range: " + std::to_string(offset));
        }

        patch_bd_addr(bo_map + offset, patch.bo_offset + base_addr);
    }

    param_bos_.sync(XCL_BO_SYNC_BO_TO_DEVICE);
    LOG_VERBOSE() << "Contrl packet pach Done .." << std::endl;
}

// Performs the run on hardware
int model_runner::run()
{
    if (this->config_.num_runs < 1)
    {
        return -1;
    }

    int ret_val = 0;

    struct pending_run
    {
        bool active = false;
        xrt::run run;
        std::chrono::time_point<std::chrono::high_resolution_clock> start_time;
        std::chrono::time_point<std::chrono::high_resolution_clock> end_time;

        void start()
        {
            start_time = std::chrono::high_resolution_clock::now();
            run.start();
            active = true;
        }

        void sync(xrt::device &device, std::vector<time_type_us> &timers)
        {
            run.wait2();

            auto run_state = run.state();
            end_time = std::chrono::high_resolution_clock::now();
            timers.push_back(std::chrono::duration_cast<time_type_us>(end_time - start_time));

            if (run_state != 4)
            {
                xrt::error error(device, XRT_ERROR_CLASS_AIE);
                std::cout << std::endl
                          << "---" << std::endl
                          << "XRT error details:" << std::endl
                          << error.to_string() << std::endl
                          << "---" << std::endl
                          << std::endl;
                throw std::runtime_error("xrt::run() failed, run_state = " + std::to_string(run_state));
            }

            // run = {};
            active = false;
        }
    };

    std::vector<pending_run> pending_runs(num_partitions_);

    for (size_t i = 0; i < num_partitions_; i++)
    {
        const std::size_t buffer_set = i;
        std::size_t kernel_idx = 0;
        if (elf_flow_)
        {
            kernel_idx = buffer_set;
        }
        auto &pending_run = pending_runs[buffer_set];

        if (pending_run.active)
        {
            LOG_VERBOSE() << "Waiting for kernel " << buffer_set << std::endl;
            pending_run.sync(device_, this->kernel_run_times_);
        }

        pending_run.run = runs_[buffer_set]; // xrt::run(kernels_[kernel_idx]);
        // assign_kernel_args(pending_run.run, buffer_set);
    }

    for (size_t i = 0; i < this->config_.num_runs; i++)
    {
        for (auto &pending_run : pending_runs)
        {
            pending_run.start();
            pending_run.sync(device_, this->kernel_run_times_);
        }
    }

    return ret_val;
}

// WXB2: creates one xrt::run per partition and binds its kernel arguments.
// In the ELF flow partition i runs on kernels_[i]; otherwise every partition shares kernels_[0].
void model_runner::allocate_runs()
{
    runs_.resize(num_partitions_);

    for (size_t part = 0; part < num_partitions_; ++part)
    {
        const std::size_t kernel_idx = elf_flow_ ? part : 0;
        auto &part_run = runs_[part];
        part_run = xrt::run(kernels_[kernel_idx]);
        assign_kernel_args(part_run, part);
    }
}

// QHW4: builds the single run on kernels_[0] and binds its buffer arguments:
// 0 = ofm, 1 = ifm, 2 = wgts, 3 = param.
void model_runner::allocate_runs_qhw4()
{
    runs_.resize(1);

    auto &qhw4_run = runs_[0];
    qhw4_run = xrt::run(kernels_[0]);

    qhw4_run.set_arg(0, ofm_bos_);
    qhw4_run.set_arg(1, ifm_bos_);
    qhw4_run.set_arg(2, wgts_bos_);
    qhw4_run.set_arg(3, param_bos_);
}

// Allocates a buffer object of `sz` bytes in hardware context `ctx`.
// WXB2 in the ELF flow: an xrt::ext::bo.
// WXB2 otherwise: an xrt::bo with `flag` in memory group `grp`.
// Any other device: an xrt::bo with `flag` in memory group 0 (`grp` is ignored).
xrt::bo model_runner::allocate_xrt_buffer(const xrt::hw_context &ctx, const size_t &sz, xrt::bo::flags flag,
                                          xrt::memory_group grp)
{
    const bool is_wxb2 = (this->config_.device == "WXB2");
    if (!is_wxb2)
    {
        return xrt::bo(ctx, sz, flag, 0);
    }
    if (!elf_flow_)
    {
        return xrt::bo(ctx, sz, flag, grp);
    }
    return xrt::ext::bo(ctx, sz);
}

// Assembles each transaction buffer in instr_bufs_ into an ELF with aiebu.
// Each ELF is loaded as an xrt::module and stored in elf_mods_, one module per
// buffer and in the same order.
// Throws std::runtime_error("convert_to_elf failed") if any step fails; the
// underlying error message is printed first.
void model_runner::convert_to_elf()
{
    elf_mods_.clear();

    for (const auto &txn : instr_bufs_)
    {
        try
        {
            // aiebu takes the transaction blob as std::vector<char>.
            std::vector<char> txn_buf(txn.begin(), txn.end());
            // Serialize the JSON object to a string
            // std::string json_str = ctrlpkt_json_.dump();
            // std::vector<char> patch_json(json_str.begin(), json_str.end());
            // auto ctrl_pkt_bin = i8_vec_to_char_vector(ctrl_buf_);
            aiebu::aiebu_assembler as(aiebu::aiebu_assembler::buffer_type::blob_instr_transaction, txn_buf, {}, {}, {},
                                      {}, {});
            auto elf = as.get_elf();
            std::istringstream elf_stream(std::string(elf.begin(), elf.end()));
            xrt::elf elf_obj(elf_stream);
            elf_mods_.emplace_back(xrt::module(elf_obj));
        }
        catch (const std::exception &e)
        {
            // Catch by reference: catching by value sliced derived exceptions, so
            // what() returned the generic base-class text instead of the real error.
            std::cout << e.what() << std::endl;
            throw std::runtime_error("convert_to_elf failed");
        }
    }
}

void model_runner::run_init(const bool &is_context_cache)
{
    ifm_copy_total_ = time_type_us{};
    wts_copy_total_ = time_type_us{};
    ofm_copy_total_ = time_type_us{};
    scratch_copy_total_ = time_type_us{};
    ifm_sync_total_ = time_type_us{};
    wts_sync_total_ = time_type_us{};
    ofm_sync_total_ = time_type_us{};
    scratch_sync_total_ = time_type_us{};
    ofm_pre_sync_total_ = time_type_us{};
    kernel_run_total_ = time_type_us{};

    // create vectors for input/output data
    std::vector<uint32_t> ifmTensor;
    std::vector<uint32_t> WTS;

    LOG_VERBOSE() << "Loading data from file" << std::endl;

    // create kernels for elf flow
    if (this->config_.device == "WXB2")
    {
        if (elf_flow_)
        {
            kernels_.resize(num_partitions_);
            convert_to_elf();
            for (size_t i = 0; i < num_partitions_; i++)
            {
                kernels_[i] = xrt::ext::kernel(context_, elf_mods_[i], KERNEL_NAME_ELF);
            }
        }
    }
    else if (this->config_.device == "QHW4")
    {
        kernels_.resize(1);
        LOG_VERBOSE() << "Create kernel: " << kernelName_ << std::endl;
        kernels_[0] = xrt::ext::kernel{context_, kernelName_};
    }
    else
    {
        throw std::runtime_error("The config device type is not supported.");
    }

    LOG_VERBOSE() << "Create kernel Done" << std::endl;

    int group_id;
    if (this->config_.device == "WXB2") {
        group_id = kernels_[0].group_id(HOST_BO_GROUP_ID);
    }
    else if (this->config_.device == "QHW4")
    {
        group_id = 0;
    }
    else
    {
        throw std::runtime_error("The config device type is not supported.");
    }
    size_t input_buf_size = this->config_.bo_sizes.input_size_bytes;
    if (input_buf_size < XRT_BO_MIN_SIZE)
    {
        size_t buf_size = XRT_BO_MIN_SIZE;
        ifm_bos_ =
            allocate_xrt_buffer(context_, buf_size, xrt::bo::flags::host_only, group_id);
        memset(ifm_bos_.map(), XRT_BO_INIT_VALUE, ifm_bos_.size());
        ifm_bos_.sync(XCL_BO_SYNC_BO_TO_DEVICE);
    }
    else
    {
        ifm_bos_ = allocate_xrt_buffer(context_, input_buf_size, xrt::bo::flags::host_only, group_id);
    }
    LOG_VERBOSE() << "Allocate buffer for ifm" << std::endl;

    if (output_buf_size_ < XRT_BO_MIN_SIZE)
    {
        size_t buf_size = XRT_BO_MIN_SIZE;
        ofm_bos_ =
            allocate_xrt_buffer(context_, buf_size, xrt::bo::flags::host_only, group_id);
        memset(ofm_bos_.map(), XRT_BO_INIT_VALUE, ofm_bos_.size());
        ofm_bos_.sync(XCL_BO_SYNC_BO_TO_DEVICE);
    }
    else
    {
        ofm_bos_ = allocate_xrt_buffer(context_, output_buf_size_, xrt::bo::flags::host_only, group_id);
    }
    LOG_VERBOSE() << "Allocate buffer for ofm" << std::endl;

    if (!config_.bo_sizes.no_scratch_buff)
    {
        if (scratch_buf_size_ < XRT_BO_MIN_SIZE)
        {
            size_t buf_size = XRT_BO_MIN_SIZE;
            scratch_bos_ = allocate_xrt_buffer(context_, buf_size, xrt::bo::flags::host_only, group_id);
            memset(scratch_bos_.map(), XRT_BO_INIT_VALUE, scratch_bos_.size());
            scratch_bos_.sync(XCL_BO_SYNC_BO_TO_DEVICE);
        }
        else
        {
            scratch_bos_ = allocate_xrt_buffer(context_, scratch_buf_size_, xrt::bo::flags::host_only, group_id);
        }
        LOG_VERBOSE() << "Allocate buffer for scratch" << std::endl;
    }

    if (wgts_buf_size_ < XRT_BO_MIN_SIZE)
    {
        size_t buf_size = XRT_BO_MIN_SIZE;
        wgts_bos_ =
            allocate_xrt_buffer(context_, buf_size, xrt::bo::flags::host_only, group_id);
        memset(wgts_bos_.map(), XRT_BO_INIT_VALUE, wgts_bos_.size());
        wgts_bos_.sync(XCL_BO_SYNC_BO_TO_DEVICE);
    }
    else
    {
        wgts_bos_ = allocate_xrt_buffer(context_, wgts_buf_size_, xrt::bo::flags::host_only, group_id);
    }
    LOG_VERBOSE() << "Allocate buffer for wgt" << std::endl;

    if (param_buf_size_ < XRT_BO_MIN_SIZE)
    {
        size_t buf_size = XRT_BO_MIN_SIZE;
        param_bos_ =
            allocate_xrt_buffer(context_, buf_size, xrt::bo::flags::host_only, group_id);
        memset(param_bos_.map(), XRT_BO_INIT_VALUE, param_bos_.size());
        param_bos_.sync(XCL_BO_SYNC_BO_TO_DEVICE);
    }
    else
    {
        param_bos_ = allocate_xrt_buffer(context_, param_buf_size_, xrt::bo::flags::host_only, group_id);
    }
    LOG_VERBOSE() << "Allocate buffer for param" << std::endl;
    
    if (!elf_flow_)
    {
        instr_bos_.resize(num_partitions_);
        for (size_t i = 0; i < num_partitions_; ++i)
        {
            auto each_buf = instr_bufs_[i];
            instr_bos_[i] =
                allocate_xrt_buffer(context_, each_buf.size(), xrt::bo::flags::cacheable, kernels_[0].group_id(1));
        }
        LOG_VERBOSE() << "Writing instruction bo" << std::endl;
        for (size_t i = 0; i < num_partitions_; ++i)
        {
            auto each_buf = instr_bufs_[i];
            instr_bos_[i].write(each_buf.data(), each_buf.size(), 0);
            instr_bos_[i].sync(XCL_BO_SYNC_BO_TO_DEVICE);
        }
    }

    LOG_VERBOSE() << "Writing control packet bo" << std::endl;

    if (is_context_cache) {
      read_data_to_bo(param_bos_, param_vec_, 0, param_buf_size_);
    } else {
      read_file_to_bo(param_bos_, param_path_, 0, param_buf_size_);
    }
    param_bos_.sync(XCL_BO_SYNC_BO_TO_DEVICE);
    // add ctrlpkt patching
    if (is_ctrl_pkt_)
    { //&& !elf_flow_
        patch_ctrl_pkt();
    }

    // flush ofm buffer cache
    auto ofm_pre_sync_start = std::chrono::high_resolution_clock::now();
    ofm_bos_.sync(XCL_BO_SYNC_BO_TO_DEVICE);
    auto ofm_pre_sync_end = std::chrono::high_resolution_clock::now();
    ofm_pre_sync_total_ += std::chrono::duration_cast<time_type_us>(ofm_pre_sync_end - ofm_pre_sync_start);

    // copy + sync wts
    auto wts_copy_start = std::chrono::high_resolution_clock::now();
    if (is_context_cache) {
      read_data_to_bo(wgts_bos_, wgts_vec_, 0, wgts_buf_size_);
    } else {
      read_file_to_bo(wgts_bos_, wgts_path_, 0, wgts_buf_size_);
    }
    auto wts_copy_end = std::chrono::high_resolution_clock::now();
    wts_copy_total_ += std::chrono::duration_cast<time_type_us>(wts_copy_end - wts_copy_start);

    auto wts_sync_start = std::chrono::high_resolution_clock::now();
    wgts_bos_.sync(XCL_BO_SYNC_BO_TO_DEVICE);
    auto wts_sync_end = std::chrono::high_resolution_clock::now();
    wts_sync_total_ += std::chrono::duration_cast<time_type_us>(wts_sync_end - wts_sync_start);

    // copy + sync scratch
    if (scratch_buf_size_ > 0)
    {
        auto scratch_copy_start = std::chrono::high_resolution_clock::now();
        read_file_to_bo(scratch_bos_, scratch_path_, 0, scratch_buf_size_);
        auto scratch_copy_end = std::chrono::high_resolution_clock::now();
        scratch_copy_total_ += std::chrono::duration_cast<time_type_us>(scratch_copy_end - scratch_copy_start);

        auto scratch_sync_start = std::chrono::high_resolution_clock::now();
        scratch_bos_.sync(XCL_BO_SYNC_BO_TO_DEVICE);
        auto scratch_sync_end = std::chrono::high_resolution_clock::now();
        scratch_sync_total_ += std::chrono::duration_cast<time_type_us>(scratch_sync_end - scratch_sync_start);
    }
    if (this->config_.device == "WXB2")
    {
        allocate_runs();
    }
    else if (this->config_.device == "QHW4")
    {
        allocate_runs_qhw4();
        LOG_VERBOSE() << "Set run arg" << std::endl;
    }
    else
    {
        throw std::runtime_error("The config device type is not supported.");
    }
}

int model_runner::run_execute(const std::vector<uint8_t> &input_data, std::vector<uint8_t> &output_data)
{
    ifm_copy_total_ = time_type_us{};
    ofm_copy_total_ = time_type_us{};
    ifm_sync_total_ = time_type_us{};
    ofm_sync_total_ = time_type_us{};
    ofm_pre_sync_total_ = time_type_us{};
    kernel_run_total_ = time_type_us{};

    // create vectors for input/output data
    std::vector<uint32_t> ifmTensor;
    std::vector<uint32_t> WTS;

    LOG_VERBOSE() << "Loading input" << std::endl;

    // load input tensor
    auto ifm_copy_start = std::chrono::high_resolution_clock::now();
    if (input_data.empty() || nullptr == input_data.data())
    {
        std::cerr << "Empty or invalid input data" << std::endl;
    }
    else
    {
        ifm_bos_.write(input_data.data(), input_data.size(), 0);
    }
    auto ifm_copy_end = std::chrono::high_resolution_clock::now();
    ifm_copy_total_ += std::chrono::duration_cast<time_type_us>(ifm_copy_end - ifm_copy_start);

    auto ifm_sync_start = std::chrono::high_resolution_clock::now();
    ifm_bos_.sync(XCL_BO_SYNC_BO_TO_DEVICE);
    auto ifm_sync_end = std::chrono::high_resolution_clock::now();
    ifm_sync_total_ += std::chrono::duration_cast<time_type_us>(ifm_sync_end - ifm_sync_start);

    kernel_run_times_.reserve(config_.num_runs);
    auto run_start = std::chrono::high_resolution_clock::now();
    int err_check = run();
    auto run_end = std::chrono::high_resolution_clock::now();
    kernel_run_total_ += std::chrono::duration_cast<time_type_us>(run_end - run_start);
    if (err_check != 0)
    {
        return err_check;
    }
    auto ofm_sync_start = std::chrono::high_resolution_clock::now();
    ofm_bos_.sync(XCL_BO_SYNC_BO_FROM_DEVICE);
    auto ofm_sync_end = std::chrono::high_resolution_clock::now();
    ofm_sync_total_ += std::chrono::duration_cast<time_type_us>(ofm_sync_end - ofm_sync_start);

    auto ofm_copy_start = std::chrono::high_resolution_clock::now();
    ofm_bos_.read(output_data.data(), output_data.size(), 0);
    auto ofm_copy_end = std::chrono::high_resolution_clock::now();
    ofm_copy_total_ += std::chrono::duration_cast<time_type_us>(ofm_copy_end - ofm_copy_start);

    // save output and scratch buff
    if (config_.dump_data)
    {
        std::vector<uint8_t> scratch_buf_tmp;
        scratch_buf_tmp.resize(scratch_buf_size_);
        write_bin_file(config_.workload_path + "aie_out.bin", (char *)output_data.data(), output_data.size());
        scratch_bos_.sync(XCL_BO_SYNC_BO_FROM_DEVICE);
        scratch_bos_.read(scratch_buf_tmp.data(), scratch_buf_tmp.size(), 0);
        write_bin_file(config_.workload_path + "scratch_bo_fname.bin", (char *)scratch_buf_tmp.data(),
                       scratch_buf_tmp.size());
    }
    LOG_VERBOSE() << std::endl;
    if (config_.verbose)
    {
        print_runtime_info(this->kernel_run_times_, config_);
    }

    LOG_VERBOSE() << "WTS Copy to BO   : " << wts_copy_total_.count() << " us (once)" << std::endl;
    LOG_VERBOSE() << "WTS Sync BO      : " << wts_sync_total_.count() << " us (once)" << std::endl;
    LOG_VERBOSE() << "Scratch Copy to BO   : " << scratch_copy_total_.count() << " us (once)" << std::endl;
    LOG_VERBOSE() << "Scratch Sync to BO   : " << scratch_sync_total_.count() << " us (once)" << std::endl;
    LOG_VERBOSE() << "IFM Sync BO      : " << ifm_sync_total_.count() << " us" << std::endl;
    LOG_VERBOSE() << "IFM Copy to BO   : " << ifm_copy_total_.count() << " us (once)" << std::endl;
    LOG_VERBOSE() << "IFM Sync BO      : " << ifm_sync_total_.count() << " us" << std::endl;
    LOG_VERBOSE() << "OFM Pre-Sync BO  : " << ofm_pre_sync_total_.count() << " us" << std::endl;
    LOG_VERBOSE() << "OFM Sync BO      : " << ofm_sync_total_.count() << " us" << std::endl;
    LOG_VERBOSE() << "OFM Copy from BO : " << ofm_copy_total_.count() << " us" << std::endl;
    LOG_VERBOSE() << "Kernel Run total : " << kernel_run_total_.count() << " us" << std::endl;

    return 0;
}

void *model_runner::get_inputbo_ptr()
{
    // Host-side mapping of the input buffer object (ifm_bos_).
    void *const host_ptr = ifm_bos_.map();
    return host_ptr;
}

void *model_runner::get_outputbo_ptr()
{
    // Host-side mapping of the output buffer object (ofm_bos_).
    void *const host_ptr = ofm_bos_.map();
    return host_ptr;
}

int model_runner::run_execute()
{
    ifm_copy_total_ = time_type_us{};
    ofm_copy_total_ = time_type_us{};
    ifm_sync_total_ = time_type_us{};
    ofm_sync_total_ = time_type_us{};
    ofm_pre_sync_total_ = time_type_us{};
    kernel_run_total_ = time_type_us{};

    // create vectors for input/output data
    std::vector<uint32_t> ifmTensor;
    std::vector<uint32_t> WTS;

    LOG_VERBOSE() << "Loading input" << std::endl;

    // load input tensor    
    auto ifm_sync_start = std::chrono::high_resolution_clock::now();
    ifm_bos_.sync(XCL_BO_SYNC_BO_TO_DEVICE);
    auto ifm_sync_end = std::chrono::high_resolution_clock::now();
    ifm_sync_total_ += std::chrono::duration_cast<time_type_us>(ifm_sync_end - ifm_sync_start);

    if (config_.dump_data)
    {
        write_bin_file(config_.workload_path + "input_bo_fname.bin", (char*)ifm_bos_.map(), this->config_.bo_sizes.input_size_bytes);
    }

    kernel_run_times_.reserve(config_.num_runs);
    auto run_start = std::chrono::high_resolution_clock::now();
    int err_check = run();
    auto run_end = std::chrono::high_resolution_clock::now();
    kernel_run_total_ += std::chrono::duration_cast<time_type_us>(run_end - run_start);
    if (err_check != 0)
    {
        return err_check;
    }
    auto ofm_sync_start = std::chrono::high_resolution_clock::now();
    ofm_bos_.sync(XCL_BO_SYNC_BO_FROM_DEVICE);
    auto ofm_sync_end = std::chrono::high_resolution_clock::now();
    ofm_sync_total_ += std::chrono::duration_cast<time_type_us>(ofm_sync_end - ofm_sync_start);

    // save output and scratch buff
    if (config_.dump_data)
    {
        write_bin_file(config_.workload_path + "aie_out.bin", (char*)ofm_bos_.map(), output_buf_size_);

        if (scratch_buf_size_ > 0) {
            std::vector<uint8_t> scratch_buf_tmp;
            scratch_buf_tmp.resize(scratch_buf_size_);
            scratch_bos_.sync(XCL_BO_SYNC_BO_FROM_DEVICE);
            scratch_bos_.read(scratch_buf_tmp.data(), scratch_buf_tmp.size(), 0);
            write_bin_file(config_.workload_path + "scratch_bo_fname.bin", (char*)scratch_buf_tmp.data(),
                scratch_buf_tmp.size());
        }
    }
    LOG_VERBOSE() << std::endl;
    if (config_.verbose)
    {
        print_runtime_info(this->kernel_run_times_, config_);
    }

    LOG_VERBOSE() << "WTS Copy to BO   : " << wts_copy_total_.count() << " us (once)" << std::endl;
    LOG_VERBOSE() << "WTS Sync BO      : " << wts_sync_total_.count() << " us (once)" << std::endl;
    //GLOBAL_LATENCY_STREAM() << "WTS_Copy_to_BO: " << wts_copy_total_.count() << " us\n";
    GLOBAL_LATENCY_STREAM() << "TESTAPP::WTS_Sync_BO: " << wts_sync_total_.count() << " us\n";
    
    LOG_VERBOSE() << "Scratch Copy to BO   : " << scratch_copy_total_.count() << " us (once)" << std::endl;
    LOG_VERBOSE() << "Scratch Sync to BO   : " << scratch_sync_total_.count() << " us (once)" << std::endl;
    //GLOBAL_LATENCY_STREAM() << "Scratch_Copy_to_BO: " << scratch_copy_total_.count() << " us\n";
    GLOBAL_LATENCY_STREAM() << "TESTAPP::Scratch_Sync_BO: " << scratch_sync_total_.count() << " us\n";
    
    LOG_VERBOSE() << "IFM Sync BO      : " << ifm_sync_total_.count() << " us" << std::endl;
    LOG_VERBOSE() << "IFM Copy to BO   : " << ifm_copy_total_.count() << " us (once)" << std::endl;
    LOG_VERBOSE() << "IFM Sync BO      : " << ifm_sync_total_.count() << " us" << std::endl;
    //GLOBAL_LATENCY_STREAM() << "IFM_Copy_to_BO: " << ifm_copy_total_.count() << " us\n";
    GLOBAL_LATENCY_STREAM() << "TESTAPP::IFM_Sync_BO: " << ifm_sync_total_.count() << " us\n";
    
    LOG_VERBOSE() << "OFM Pre-Sync BO  : " << ofm_pre_sync_total_.count() << " us" << std::endl;
    LOG_VERBOSE() << "OFM Sync BO      : " << ofm_sync_total_.count() << " us" << std::endl;
    LOG_VERBOSE() << "OFM Copy from BO : " << ofm_copy_total_.count() << " us" << std::endl;
    GLOBAL_LATENCY_STREAM() << "TESTAPP::OFM_Pre_Sync_BO: " << ofm_pre_sync_total_.count() << " us\n";
    GLOBAL_LATENCY_STREAM() << "TESTAPP::OFM_Sync_BO: " << ofm_sync_total_.count() << " us\n";
    //GLOBAL_LATENCY_STREAM() << "OFM_Copy_from_BO: " << ofm_copy_total_.count() << " us\n";
    
    LOG_VERBOSE() << "Kernel Run total : " << kernel_run_total_.count() << " us" << std::endl;
    GLOBAL_LATENCY_STREAM() << "TESTAPP::Kernel_Run_total: " << kernel_run_total_.count() << " us\n";

    return 0;
}
