// Copyright 2022-2024 Advanced Micro Devices, Inc. All Rights Reserved.
////////////////////////////////////////////////////////////////////////
#pragma once
#include <algorithm>
#include <cassert>
#include <cstdlib>
#include <iomanip>
#include <iterator>

#include "utils.hpp"

// XRT headers
#include "experimental/xrt_elf.h"
#include "experimental/xrt_error.h"
#include "experimental/xrt_ext.h"
#include "experimental/xrt_module.h"
#include "json_reader.hpp"
#include "xrt/xrt_bo.h"
#include "xrt/xrt_device.h"
#include "xrt/xrt_kernel.h"

namespace waic_runner
{
/// Category tag for a buffer object (stored in `bo::type`).
///
/// This is an unscoped enum, so the enumerators are visible directly in the
/// enclosing namespace and convert implicitly to integers. The numeric values
/// are spelled out so that adding the alias below cannot shift any of them.
enum BO_CLASS
{
    BO_IFM = 0,
    BO_OFM = 1,
    // Historical spelling (missing the second 'C'); kept so existing code
    // that refers to BO_SCRATH keeps compiling.
    BO_SCRATH = 2,
    // Correctly spelled alias for BO_SCRATH; same value, prefer it in new code.
    BO_SCRATCH = BO_SCRATH,
    BO_CONST = 3,
    BO_PARAM = 4
};

struct bo
{
    uint64_t sz;
    enum BO_CLASS type;
    unsigned xrt_id;
};

/// Byte sizes for the runner's buffers, plus a flag concerning the scratch
/// buffer. All fields default to zero / false.
///
/// Member order matters: aggregate initialization such as `{a, b, c}` fills
/// input, scratch, output in that order and leaves the flag at its default.
struct bo_size_bytes
{
    size_t input_size_bytes{0};   // input buffer size, in bytes
    size_t scratch_size_bytes{0}; // scratch buffer size, in bytes
    size_t output_size_bytes{0};  // output buffer size, in bytes
    bool no_scratch_buff{false};  // NOTE(review): presumably "skip the scratch buffer" — confirm against users
};

/// Settings bundle handed to `model_runner`.
///
/// Every scalar member now carries a default so that a default-constructed
/// `configuration_runner` has no indeterminate fields (`num_runs` and `verbose`
/// previously had none, unlike `dump_data` / `is_profiling`). The struct is
/// still an aggregate and the member order is unchanged.
struct configuration_runner
{
    std::string device;        // device identifier string
    std::string kernel_name;   // kernel name
    size_t num_runs = 0;       // number of runs; 0 until the caller sets it
    std::string workload_path; // workload location; NOTE(review): presumably a directory — confirm

    // File names of the individual artifacts.
    // NOTE(review): whether these are absolute or relative to workload_path is
    // not visible from this header — confirm against the code that fills them.
    std::string xclbin_filename;
    std::string txnbin_filename;
    std::string input_filename;
    std::string output_filename;
    std::string wgts_filename;
    std::string param_filename;
    std::string ctrl_filename;
    std::string scratch_filename;
    std::string ctrl_pkt_info_filename;

    bo_size_bytes bo_sizes; // buffer sizes in bytes (all zero by default)

    bool verbose = false; // verbosity switch; off unless the caller enables it

    bool dump_data = false;   // NOTE(review): presumably enables dumping buffers — confirm
    bool is_profiling = true; // profiling switch; on by default
};

/// Holds the XRT handles (device, xclbin, hardware context, kernels, modules,
/// runs and buffer objects) together with the data needed to execute a
/// workload, and exposes `run_init` / `run_execute` as its entry points.
///
/// Only declarations live in this header. Scalar and duration members carry
/// in-class default initializers so they are never indeterminate; a
/// constructor's own member initializer for the same member still takes
/// precedence over these defaults.
class model_runner
{

  public:
    /// Constructs a runner from a configuration only.
    model_runner(configuration_runner config);

    /// Constructs a runner from a configuration plus caller-supplied buffers.
    /// Ownership of every `unique_ptr` argument is transferred to the call.
    /// NOTE(review): presumably these replace reading the corresponding files
    /// named in `config` — confirm against the definition.
    model_runner
      (
        configuration_runner config,
        std::unique_ptr<std::vector<char>> ctrlpkt_vec,
        std::unique_ptr<std::vector<uint8_t>> wgts_vec,
        std::unique_ptr<std::vector<uint8_t>> txn_bin_vec,
        std::unique_ptr<std::vector<uint8_t>> param_vec,
        std::unique_ptr<std::vector<char>> xclbin_vec
      );

    /// Initialization step; `is_context_cache` defaults to false.
    /// NOTE(review): expected to be called before run_execute — confirm.
    void run_init(const bool &is_context_cache=false);

    /// Performs the run on hardware, taking input from `input_data` and
    /// receiving output in `output_data`.
    /// NOTE(review): meaning of the int result is not visible here — confirm.
    int run_execute(const std::vector<uint8_t> &input_data, std::vector<uint8_t> &output_data);

    /// Overload without explicit input/output vectors.
    /// NOTE(review): presumably operates on the buffers reachable through
    /// get_inputbo_ptr / get_outputbo_ptr — confirm.
    int run_execute();

    /// Raw pointer associated with the input buffer object.
    /// NOTE(review): lifetime/ownership not visible here — presumably tied to this runner.
    void *get_inputbo_ptr();

    /// Raw pointer associated with the output buffer object (same caveat as above).
    void *get_outputbo_ptr();

    // std::string get_input_ref_filepath();

  private:
    configuration_runner config_;
    bool is_ctrl_pkt_ = false;
    bool elf_flow_ = true;

    // XRT handles.
    xrt::device device_;
    xrt::xclbin xclbin_;
    std::vector<xrt::kernel> kernels_;
    xrt::hw_context context_;
    std::vector<xrt::module> elf_mods_;
    std::vector<xrt::run> runs_;

    // create vectors for input/output data
    // std::vector<uint32_t> ref_golden;

    // TODO Factor out into execution_context or so
    xrt::bo ifm_bos_;
    xrt::bo ofm_bos_;
    std::vector<xrt::bo> instr_bos_;
    xrt::bo param_bos_;
    xrt::bo scratch_bos_;
    xrt::bo wgts_bos_;
    size_t num_partitions_ = 0;

    std::string kernelName_;
    std::string scratch_path_;
    std::string output_path_;
    std::string wgts_path_;
    std::string param_path_;

    std::vector<std::vector<uint8_t>> instr_bufs_;
    size_t scratch_buf_size_ = 0;
    size_t output_buf_size_ = 0;
    size_t wgts_buf_size_ = 0;
    size_t param_buf_size_ = 0;
    std::vector<char> ctrlpkt_vec_;
    std::vector<uint8_t> wgts_vec_;
    std::vector<uint8_t> txnbin_vec_;
    std::vector<uint8_t> param_vec_;
    json ctrlpkt_json_;
    std::vector<CtrlPktPatchInfo> ctrlpkt_info_;

    // Timing accumulators. A default-constructed std::chrono duration holds an
    // indeterminate count, so each one is explicitly started at zero.
    // Declaration order is the same as before.
    std::chrono::microseconds ifm_copy_total_{0};
    std::chrono::microseconds wts_copy_total_{0};
    std::chrono::microseconds ofm_copy_total_{0};
    std::chrono::microseconds scratch_copy_total_{0};
    std::chrono::microseconds ifm_sync_total_{0};
    std::chrono::microseconds wts_sync_total_{0};
    std::chrono::microseconds ofm_sync_total_{0};
    std::chrono::microseconds scratch_sync_total_{0};
    std::chrono::microseconds ofm_pre_sync_total_{0};
    std::chrono::microseconds kernel_run_total_{0};

    // std::vector<uint8_t> patched_txn_for_buffers(xrt::bo &ifm, xrt::bo &ofm,
    // xrt::bo &ctrl_pkts) const; std::vector<uint8_t>
    // patched_ctrl_pkts_for_buffers(xrt::bo &ifm, xrt::bo &ofm) const;

    std::vector<std::chrono::microseconds> kernel_run_times_;

    // Loaders: the *_filename overloads take a path, the others take data
    // that is already in memory.
    void load_xclbin(const std::string &xclbin_filename);
    void load_xclbin_data(const std::vector<char> &xclbin_data, const std::string &xclbin_filename);
    void load_txn_bin(const std::string &txnbin_filename);
    void load_txn_bin_file(const std::vector<uint8_t> &txnbin_vec);
    void load_elf(const std::string &elf_filename);

    // Fill a buffer object from a file / from a byte vector.
    // NOTE(review): `seek` is presumably an offset and `buf_size` the byte
    // count — confirm which side (source or bo) the offset applies to.
    void read_file_to_bo(xrt::bo &bo, std::string path, size_t seek, size_t buf_size);
    void read_data_to_bo(xrt::bo &bo, std::vector<uint8_t> &data, size_t seek, size_t buf_size);

    void assign_kernel_args(xrt::run &run, std::size_t buffer_set) const;
    void allocate_runs();
    void allocate_runs_qhw4();
    int run();
    void patch_ctrl_pkt();

    // Allocates an xrt::bo of `sz` for context `ctx` with the given flag and memory group.
    xrt::bo allocate_xrt_buffer(const xrt::hw_context &ctx, const size_t &sz, xrt::bo::flags flag,
                                xrt::memory_group grp);
    void convert_to_elf();
};
} // namespace waic_runner

/// Produces a waic_runner::configuration_runner by value. Only the declaration
/// is visible in this header; the definition lives elsewhere.
///
/// @param workload_path  Workload location.
///                       NOTE(review): presumably the directory holding the workload's files — confirm.
/// @param num_runs       Number of runs.
/// @param prefix         NOTE(review): role not visible from this header; presumably a
///                       file-name prefix — confirm against the definition.
/// @param size_bos       Buffer sizes in bytes. The default `{0, 0, 0}` sets input, scratch
///                       and output sizes to zero (member order of bo_size_bytes) and leaves
///                       `no_scratch_buff` at its own default.
/// @param device         Device name; defaults to "WXB2".
/// @return The configuration object.
waic_runner::configuration_runner prepare_default_config(const std::string &workload_path, size_t num_runs,
                                                         const std::string &prefix,
                                                         waic_runner::bo_size_bytes size_bos = {0, 0, 0},
                                                         const std::string &device = "WXB2");
