#pragma once

#include "file_ptr.hpp"
#include "json_reader.hpp"
#include "tfunc_impl.hpp"
#include <any>
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <new>
#include <numeric>
#include <set>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>
namespace waic_runner
{
// Equivalent to .at() method of std::vector/std::array.
// Forwards the container, the index, the stringified container expression (#x) and the
// call site (__FILE__/__LINE__) to vector_get_value_at(), which is not defined in this file
// (the extra arguments are presumably used for error reporting - confirm in the helper).
#define ARRAY_AT(x, idx) vector_get_value_at(x, idx, #x, __FILE__, __LINE__)

// Equivalent to index access of a new/malloc buffer.
// Forwards the pointer, the caller-supplied size `sz`, the index, the stringified pointer
// expression and the call site to ptr_get_at(), which is not defined in this file.
#define PTR_AT(ptr, sz, idx) ptr_get_at(ptr, sz, idx, #ptr, __FILE__, __LINE__)

// Equivalent to .at() method of std::map/std::unordered_map.
// Forwards the map, the key, the stringified map expression and the call site to
// map_get_value_at(), which is not defined in this file.
#define MAP_AT(x, key) map_get_value_at(x, key, #x, __FILE__, __LINE__)
    // BO sequence: OUTPUT, INPUT, CONST, PARAM, CTRL_PKT.
    // Both tables hold one entry per slot named above.
    // NOTE(review): whether an entry is the slot's position or the slot stored at that
    // position is not visible in this file - confirm at the use site.
    constexpr std::array<size_t, 5> WAIC_BO_SEQUENCE{0, 1, 2, 3, 4};
    constexpr std::array<size_t, 5> MLADF_BO_SEQUENCE{2, 0, 1, 4, 3};

/// Plain tensor descriptor: a raw data pointer together with its shape and dtype.
/// The struct declares no destructor, so it never releases `data` itself.
struct Tensor
{
    void *data = nullptr;      // raw pointer to the tensor contents; null until assigned
    std::vector<size_t> shape; // tensor shape
    std::string dtype;         // data type, kept as a string
};

/// Describes one argument of an op: its kind, the argument indices it maps between and
/// the byte range it occupies.
///
/// Every member carries a default so a default-constructed OpArgMap is fully determined
/// (previously only `padding_offset` was initialised). The struct stays an aggregate in
/// C++14 and later, so brace initialisation in member order keeps working.
struct OpArgMap
{
    enum OpArgType
    {
        INPUT,
        OUTPUT,
        SCRATCH_PAD,
        CONST_INPUT,
        CONST_KERNEL_PARAM_INPUT,
        CTRL_PKT_BIN,
    };
    OpArgType arg_type = INPUT; // first enumerator, i.e. the value-initialised state
    size_t xrt_arg_idx = 0;
    size_t onnx_arg_idx = 0;
    size_t offset = 0;
    size_t size = 0; // in bytes
    size_t padding_offset = 0;
};

/// Polymorphic handle to a block of bytes. The base implementation exposes no storage;
/// derived classes override ptr() to hand out their buffer.
class ConstArray
{
  public:
    /// @return pointer to the underlying bytes; the base class always returns nullptr.
    virtual char *ptr()
    {
        return nullptr;
    }

    virtual ~ConstArray() = default;
};

/// Abstract interface for reading and writing a constants buffer.
/// All offsets and sizes are size_t values; how they are interpreted is up to the
/// implementation (the file-backed TmpFileConst treats offsets as relative to the
/// current file position).
class ConstBufferIO
{
  public:
    /// Shift the position that later offsets refer to by `offset`.
    virtual void update_offset(size_t offset) = 0;

    /// Store `size` bytes taken from `src` at `offset`.
    virtual void write(size_t offset, void *src, size_t size) = 0;

    /// Obtain a ConstArray of `size` bytes associated with `offset`; the caller owns it.
    virtual std::unique_ptr<ConstArray> get_buffer(size_t offset, size_t size) = 0;

    /// Fetch `size` bytes starting at `offset`.
    virtual std::vector<char> read(size_t offset, size_t size) = 0;

    // Interfaces may be owned and destroyed through a base pointer; without a virtual
    // destructor that is undefined behaviour.
    virtual ~ConstBufferIO() = default;
};

/// Writes `size` bytes from `src` into `file` at `offset` bytes past the file's current
/// position, then restores the position that was current on entry.
/// The results of the seek and write calls are not inspected.
static void write_to_file(FILE *file, size_t offset, void *src, size_t size)
{
    const auto saved_pos = ftell64(file);

    // The offset is applied relative to the current position (SEEK_CUR), not the file start.
    fseek64(file, offset, SEEK_CUR);
    [[maybe_unused]] const auto bytes_written = fwrite(src, 1, size, file);

    // Leave the stream where the caller had it.
    fseek64(file, saved_pos, SEEK_SET);
}

    class TmpFileBuffer : public ConstArray {
    public:
        TmpFileBuffer(FILE* file, size_t offset, size_t size) {
            file_ = file;
            offset_ = offset;
            size_ = size;
            data_ = (char*)malloc(size);
            memset(data_, 0, size);
        }
        char* ptr() override final { return data_; }
        virtual ~TmpFileBuffer() override {
            write_to_file(file_, offset_, data_, size_);
            free(data_);
        }

    private:
        char* data_;
        FILE* file_;
        size_t offset_;
        size_t size_;
    };

/// ConstBufferIO implementation on top of a caller-supplied FILE stream.
/// The stream is only borrowed: this class never opens or closes it. Every offset is
/// applied relative to the stream's current position (SEEK_CUR).
class TmpFileConst : public ConstBufferIO
{
  public:
    TmpFileConst(FILE *file_ptr) : file_(file_ptr)
    {
    }

    /// Moves the stream position by `offset` bytes from where it currently is.
    void update_offset(size_t offset) override final
    {
        fseek64(file_, offset, SEEK_CUR);
    }

    /// Returns a zero-filled TmpFileBuffer of `size` bytes that writes itself to the
    /// stream at `offset` when it is destroyed.
    std::unique_ptr<ConstArray> get_buffer(size_t offset, size_t size) override final
    {
        return std::make_unique<TmpFileBuffer>(file_, offset, size);
    }

    /// Writes `size` bytes from `src` at `offset`; the stream position is restored afterwards.
    void write(size_t offset, void *src, size_t size) override final
    {
        write_to_file(file_, offset, src, size);
    }

    /// Reads `size` bytes at `offset` and restores the stream position afterwards.
    /// The fread result is not inspected; the returned vector always has `size`
    /// elements and starts out zero-filled.
    std::vector<char> read(size_t offset, size_t size) override final
    {
        const auto saved_pos = ftell64(file_);
        fseek64(file_, offset, SEEK_CUR);

        std::vector<char> bytes(size);
        [[maybe_unused]] const auto items_read = fread(bytes.data(), size, 1, file_);

        fseek64(file_, saved_pos, SEEK_SET);
        return bytes;
    }

  private:
    FILE *file_;
};

/// One partition of the op list together with the PDI it is assigned to.
struct Partition
{
    // describes [start, end) interval
    std::pair<size_t, size_t> op_range;
    // Defaulted so a default-constructed Partition is fully determined; matches the
    // default of Metadata::OpInfo::pdi_id.
    uint8_t pdi_id = 0;
};

/// Per-tensor record used for Metadata::in_tensors / Metadata::out_tensors.
/// Carries an ONNX-side and a hardware-side shape/dtype/format triple plus file details.
/// NOTE(review): the exact meaning of the hw_* and L3_alloc fields is not visible in
/// this file - confirm with the code that fills them.
struct FullTensorInfo
{
    std::string name;
    std::vector<size_t> onnx_shape;
    std::string onnx_dtype;
    std::string onnx_format;
    std::vector<size_t> hw_shape;
    std::string hw_dtype;
    std::string hw_format;
    std::vector<size_t> L3_alloc;
    std::string file_name;
    size_t file_size = 0; // defaulted so a default-constructed record is fully determined
};

/// Aggregated description of a compiled subgraph: its ops, tensors, buffer spans and
/// PDI partitioning.
///
/// All scalar members carry defaults so that a default-constructed Metadata (and its
/// nested records) never exposes indeterminate values. The types remain aggregates in
/// C++14 and later, so existing brace initialisation keeps working.
struct Metadata
{
    uint32_t major_version = 0;
    uint32_t minor_version = 0;
    std::string device;
    struct OpInfo
    {
        std::string name;
        std::string type;
        std::vector<std::string> in_args;
        std::vector<std::string> const_args;
        std::vector<std::string> out_args;
        std::vector<uint8_t> txn_bin;
        std::vector<CtrlPktPatchInfo> ctrl_pkt_patch_info;
        std::map<std::string, std::any> attr;
        std::uint8_t pdi_id = 0;
    };
    struct TensorInfo
    {
        size_t size = 0;
        size_t xrt_arg_idx = 0;
        std::vector<std::string> packed_tensors;
    };
    struct OffsetInfo
    {
        std::string parent_name;      // Parent packed_tensor's name
        size_t offset = 0;            // Offset in the parent tensor
        size_t additional_offset = 0; // Additional offset based on NPU tensor
        int ref_idx = 0;              // idx in NPU tensor
        size_t xrt_arg_idx = 0;
        std::string dtype;
        std::vector<size_t> shape;
        size_t size_in_bytes = 0; // Final size as per the kernel's reqs.
        std::string format;
        std::string file_name;
        size_t file_size = 0;
    };

    struct Span
    {
        size_t offset = 0;
        size_t size = 0;
    };

    std::vector<OpInfo> op_list;
    std::map<std::string, TensorInfo> fused_tensors; // fused_tensor.name --> TensorInfo
    std::map<std::string, OffsetInfo> tensor_map;    // onnxtensor.name --> OffsetInfo
    std::map<std::string, Span> super_instr_map;     // op.name --> Op's super buffer
    std::map<std::string, Span> const_map;           // op.name --> Op's const buffer
    std::map<std::string, Span> ctrl_pkt_map;        // op.name --> Op's ctrl pkt buffer
    std::set<std::string> scratch_op_set;            // set of ops which require internal scratch pad
    size_t max_op_scratch_pad_size = 0;              // max internal scratch pad for all op
    size_t max_tensor_padding_sz = 0;                // max padding for input tensor of op

    // Placeholder to keep any extra info
    std::map<std::string, std::any> aux_info;

    std::string json_path;

    // Information on PDI partitioning
    std::vector<Partition> partitions;

    // for new flow
    std::vector<FullTensorInfo> in_tensors;
    std::vector<FullTensorInfo> out_tensors;
    std::string subgraph_idx;
};

/// Pairs a tensor id with its usage count.
/// Both members default to 0 so a default-constructed tidInfo is fully determined.
struct tidInfo
{
    int tid = 0;
    int num_usage = 0; // number of inputs/outputs used in op_list
};

/// A buffer span (offset and size) together with the tensors packed into it, listed both
/// by tid and by name.
/// `offset` and `size` default to 0 so a default-constructed GroupBuff is fully determined.
struct GroupBuff
{
    size_t offset = 0;
    size_t size = 0;
    std::vector<int> packed_tids;
    std::vector<std::string> packed_tensors;
};

/// Compile-time options; load_rtcfg() overrides these defaults from the optional
/// "Compile_cfg" section of the runtime config.
struct cpcfg
{
    // pass profile level:
    //   0 - None, 1 - subgraph, 2 - subgraph + PDI partition, 3 - subgraph + PDI partition + ops
    uint32_t profile{0};
    bool optimize_scratch{true};
    // use fused transaction, but run each op serially
    bool eager_mode{false};
    // enable preemption in elf flow
    bool enable_preemption{true};
    // enable fast pm load
    bool enable_fast_pm{true};
    // use inmem
    bool use_inmem{false};
};

/// Debug options; load_rtcfg() overrides these defaults from the optional "Debug_cfg"
/// section of the runtime config.
struct dbgcfg
{
    // dump the data for debug
    bool dump_data{false};
    // use to save trace/log
    bool enable_trace{false};
    // to get detailed timing info
    bool is_profiling{true};
};

/// Runtime configuration for a single subgraph, as produced by load_rtcfg().
/// The two integer flags default to 0 so a default-constructed rtcfg is fully determined.
struct rtcfg
{
    std::string xclbin_path;      // JSON key "xclbin"
    std::string HWbin_path;       // JSON key "HWbin_path"
    std::string tilings_json;     // JSON key "Tilings_json"
    std::string meta_json_path;   // JSON key "meta_json" (top level or per subgraph)
    std::string cache_path;       // JSON key "Cache_dir"
    std::string prebuilt_bin_dir; // JSON key "prebuilt_bin_dir"
    int compile_flag = 0;         // JSON key "Compile"
    int runtime_flag = 0;         // JSON key "Runtime"
    cpcfg compile_cfg;            // optional JSON section "Compile_cfg"
    dbgcfg debug_cfg;             // optional JSON section "Debug_cfg"
    std::string prefix;           // "default_" for a single graph, otherwise the subgraph key
};

/// @brief Concatenate one or more vectors, in argument order, into a new vector.
/// @param vec0 first vector; its type is also the result type
/// @param vecs further vectors, each consumed as a `const Vec &`
/// @return a new Vec holding the elements of vec0 followed by those of every vecs argument
template <typename Vec, typename... Vecs> static Vec concat_vectors(const Vec &vec0, const Vecs &...vecs)
{
    const auto length_of = [](const Vec &part) -> size_t { return part.size(); };

    Vec joined;
    // Reserve the combined length up front so the appends below do not reallocate.
    joined.reserve((length_of(vec0) + ... + length_of(vecs)));

    const auto append = [&joined](const Vec &part) { joined.insert(joined.end(), part.begin(), part.end()); };
    append(vec0);
    (append(vecs), ...);

    return joined;
}

/// @brief Collect every argument name of an op into one list.
/// @return in_args, then const_args, then out_args, in that order
static std::vector<std::string> get_op_args(const Metadata::OpInfo &op_info)
{
    const auto &inputs = op_info.in_args;
    const auto &consts = op_info.const_args;
    const auto &outputs = op_info.out_args;
    return concat_vectors(inputs, consts, outputs);
}

inline std::vector<rtcfg> load_rtcfg(const std::string rtcfg_file)
{
    std::vector<rtcfg> cfgs;
    std::ifstream ifs(rtcfg_file);
    if (!ifs.is_open())
    {
        throw std::runtime_error("Can not open cfg json file");
    }

    json data;
    try
    {
        data = json::parse(ifs);
        ifs.close();
    }
    catch (std::exception& e)
    {
        std::cout << e.what() << std::endl;
        ifs.close();
        throw std::runtime_error("Can not parse json file");
    }

    int num_subgraphs = 0;
    bool single_graph = 0;
    std::vector<std::string> excludeVector = { "xclbin",  "HWbin_path", "Tilings_json", "Cache_dir", "prebuilt_bin_dir",
                                              "Compile", "Runtime",    "Compile_cfg",  "Debug_cfg" };
    std::vector<std::string> skeys = {};
    if (data.contains("meta_json"))
    {
        num_subgraphs = 1;
        single_graph = 1;
    }
    else
    {
        skeys = get_keys(data, excludeVector);
        num_subgraphs = skeys.size();
    }
    for (int i = 0; i < num_subgraphs; i++)
    {
        rtcfg cfg;
        if (single_graph == 1)
        {
            cfg.prefix = "default_";
            cfg.meta_json_path = data.at("meta_json");
        }
        else
        {
            cfg.prefix = skeys[i]; // +"_";
            cfg.meta_json_path = data[skeys[i]]["meta_json"];
        }
        cfg.xclbin_path = data.at("xclbin");
        cfg.HWbin_path = data.at("HWbin_path");
        cfg.tilings_json = data.at("Tilings_json");
        cfg.cache_path = data.at("Cache_dir");
        cfg.prebuilt_bin_dir = data.at("prebuilt_bin_dir");
        cfg.compile_flag = data.at("Compile");
        cfg.runtime_flag = data.at("Runtime");

        if (data.contains("Compile_cfg"))
        {
            auto temp = data["Compile_cfg"];
            if (temp.contains("profile"))
            {
                cfg.compile_cfg.profile = temp["profile"];
            }
            if (temp.contains("eager_mode"))
            {
                if (temp["eager_mode"] == 0)
                {
                    cfg.compile_cfg.eager_mode = false;
                }
                else
                {
                    cfg.compile_cfg.eager_mode = true;
                }
            }
            if (temp.contains("optimize_scratch"))
            {
                if (temp["optimize_scratch"] == 0)
                {
                    cfg.compile_cfg.optimize_scratch = false;
                }
                else
                {
                    cfg.compile_cfg.optimize_scratch = true;
                }
            }
            if (temp.contains("enable_preemption"))
            {
                if (temp["enable_preemption"] == 0)
                {
                    cfg.compile_cfg.enable_preemption = false;
                }
                else
                {
                    cfg.compile_cfg.enable_preemption = true;
                }
            }
            if (temp.contains("enable_fast_pm"))
            {
                if (temp["enable_fast_pm"] == 0)
                {
                    cfg.compile_cfg.enable_fast_pm = false;
                }
                else
                {
                    cfg.compile_cfg.enable_fast_pm = true;
                }
            }
            if (temp.contains("use_inmem"))
            {
                if (temp["use_inmem"] == 0)
                {
                    cfg.compile_cfg.use_inmem = false;
                }
                else
                {
                    cfg.compile_cfg.use_inmem = true;
                }
            }
        }

        if (data.contains("Debug_cfg"))
        {
            auto temp = data["Debug_cfg"];
            if (temp.contains("dump_data"))
            {
                if (temp["dump_data"] == 0)
                {
                    cfg.debug_cfg.dump_data = false;
                }
                else
                {
                    cfg.debug_cfg.dump_data = true;
                }
            }
            if (temp.contains("enable_trace"))
            {
                if (temp["enable_trace"] == 0)
                {
                    cfg.debug_cfg.enable_trace = false;
                }
                else
                {
                    cfg.debug_cfg.enable_trace = true;
                }
            }
            if (temp.contains("is_profiling"))
            {
                if (temp["is_profiling"] == 0)
                {
                    cfg.debug_cfg.is_profiling = false;
                }
                else
                {
                    cfg.debug_cfg.is_profiling = true;
                }
            }
        }
        cfgs.push_back(cfg);
    }

    return (cfgs);
}
} // namespace waic_runner
