#include "meta_state.hpp"
#include <variant>
namespace waic_runner
{
void save_meta(const Metadata &meta, const std::string &state_name, bool verbose)
{
    size_t size;
    std::ofstream outFile(state_name, std::ios::binary);
    if (!outFile)
    {
        std::cerr << "Error opening file for writing!" << std::endl;
        return;
    }

    // Save version
    outFile.write(reinterpret_cast<const char *>(&meta.major_version), sizeof(meta.major_version));
    outFile.write(reinterpret_cast<const char *>(&meta.minor_version), sizeof(meta.minor_version));

    // Save device string
    write_string(outFile, meta.device);

    // Save the size of the fused_tensors
    size = meta.fused_tensors.size();
    outFile.write(reinterpret_cast<const char *>(&size), sizeof(size));

    // Save each fused_tensors
    for (const auto &pair : meta.fused_tensors)
    {
        write_string(outFile, pair.first);
        // size_t keySize = pair.first.size();
        // outFile.write(reinterpret_cast<const char*>(&keySize), sizeof(keySize));
        // outFile.write(pair.first.c_str(), keySize);

        save_TensorInfo(outFile, pair.second);
    }

    // Save the size of the tensor_map
    size = meta.tensor_map.size();
    outFile.write(reinterpret_cast<const char *>(&size), sizeof(size));

    // Save each tensor_map
    for (const auto &pair : meta.tensor_map)
    {
        write_string(outFile, pair.first);
        // size_t keySize = pair.first.size();
        // outFile.write(reinterpret_cast<const char*>(&keySize), sizeof(keySize));
        // outFile.write(pair.first.c_str(), keySize);

        save_OffsetInfo(outFile, pair.second);
    }

    // Save the size of the super_instr_map
    size = meta.super_instr_map.size();
    outFile.write(reinterpret_cast<const char *>(&size), sizeof(size));

    // Save each super_instr_map
    for (const auto &pair : meta.super_instr_map)
    {
        write_string(outFile, pair.first);
        // size_t keySize = pair.first.size();
        // outFile.write(reinterpret_cast<const char*>(&keySize), sizeof(keySize));
        // outFile.write(pair.first.c_str(), keySize);

        outFile.write(reinterpret_cast<const char *>(&pair.second), sizeof(pair.second));
    }

    // Save the size of the const_map
    size = meta.const_map.size();
    outFile.write(reinterpret_cast<const char *>(&size), sizeof(size));

    // Save each const_map
    for (const auto &pair : meta.const_map)
    {
        write_string(outFile, pair.first);
        // size_t keySize = pair.first.size();
        // outFile.write(reinterpret_cast<const char*>(&keySize), sizeof(keySize));
        // outFile.write(pair.first.c_str(), keySize);

        outFile.write(reinterpret_cast<const char *>(&pair.second), sizeof(pair.second));
    }

    // Save the size of the ctrl_pkt_map
    size = meta.ctrl_pkt_map.size();
    outFile.write(reinterpret_cast<const char *>(&size), sizeof(size));

    // Save each ctrl_pkt_map
    for (const auto &pair : meta.ctrl_pkt_map)
    {
        write_string(outFile, pair.first);
        // size_t keySize = pair.first.size();
        // outFile.write(reinterpret_cast<const char*>(&keySize), sizeof(keySize));
        // outFile.write(pair.first.c_str(), keySize);

        outFile.write(reinterpret_cast<const char *>(&pair.second), sizeof(pair.second));
    }

    // Save max
    outFile.write(reinterpret_cast<const char *>(&meta.max_op_scratch_pad_size), sizeof(meta.max_op_scratch_pad_size));
    outFile.write(reinterpret_cast<const char *>(&meta.max_tensor_padding_sz), sizeof(meta.max_tensor_padding_sz));

    // Save the size of the partitions
    size = meta.partitions.size();
    outFile.write(reinterpret_cast<const char *>(&size), sizeof(size));

    // Save each partitions
    for (const auto &pd : meta.partitions)
    {
        outFile.write(reinterpret_cast<const char *>(&pd), sizeof(pd));
    }

    outFile.close();
    if (verbose)
    {
        std::cout << "Meta Data saved to file successfully!" << std::endl;
    }
}

void save_meta(const Metadata &meta, std::vector<uint8_t> &state_data, bool verbose)
{
    size_t size;
    std::ostringstream outFile(std::ios::binary);

    // Save version
    outFile.write(reinterpret_cast<const char *>(&meta.major_version), sizeof(meta.major_version));
    outFile.write(reinterpret_cast<const char *>(&meta.minor_version), sizeof(meta.minor_version));

    // Save device string
    write_string(outFile, meta.device);

    // Save the size of the fused_tensors
    size = meta.fused_tensors.size();
    outFile.write(reinterpret_cast<const char *>(&size), sizeof(size));

    // Save each fused_tensors
    for (const auto &pair : meta.fused_tensors)
    {
        write_string(outFile, pair.first);
        // size_t keySize = pair.first.size();
        // outFile.write(reinterpret_cast<const char*>(&keySize), sizeof(keySize));
        // outFile.write(pair.first.c_str(), keySize);

        save_TensorInfo(outFile, pair.second);
    }

    // Save the size of the tensor_map
    size = meta.tensor_map.size();
    outFile.write(reinterpret_cast<const char *>(&size), sizeof(size));

    // Save each tensor_map
    for (const auto &pair : meta.tensor_map)
    {
        write_string(outFile, pair.first);
        // size_t keySize = pair.first.size();
        // outFile.write(reinterpret_cast<const char*>(&keySize), sizeof(keySize));
        // outFile.write(pair.first.c_str(), keySize);

        save_OffsetInfo(outFile, pair.second);
    }

    // Save the size of the super_instr_map
    size = meta.super_instr_map.size();
    outFile.write(reinterpret_cast<const char *>(&size), sizeof(size));

    // Save each super_instr_map
    for (const auto &pair : meta.super_instr_map)
    {
        write_string(outFile, pair.first);
        // size_t keySize = pair.first.size();
        // outFile.write(reinterpret_cast<const char*>(&keySize), sizeof(keySize));
        // outFile.write(pair.first.c_str(), keySize);

        outFile.write(reinterpret_cast<const char *>(&pair.second), sizeof(pair.second));
    }

    // Save the size of the const_map
    size = meta.const_map.size();
    outFile.write(reinterpret_cast<const char *>(&size), sizeof(size));

    // Save each const_map
    for (const auto &pair : meta.const_map)
    {
        write_string(outFile, pair.first);
        // size_t keySize = pair.first.size();
        // outFile.write(reinterpret_cast<const char*>(&keySize), sizeof(keySize));
        // outFile.write(pair.first.c_str(), keySize);

        outFile.write(reinterpret_cast<const char *>(&pair.second), sizeof(pair.second));
    }

    // Save the size of the ctrl_pkt_map
    size = meta.ctrl_pkt_map.size();
    outFile.write(reinterpret_cast<const char *>(&size), sizeof(size));

    // Save each ctrl_pkt_map
    for (const auto &pair : meta.ctrl_pkt_map)
    {
        write_string(outFile, pair.first);
        // size_t keySize = pair.first.size();
        // outFile.write(reinterpret_cast<const char*>(&keySize), sizeof(keySize));
        // outFile.write(pair.first.c_str(), keySize);

        outFile.write(reinterpret_cast<const char *>(&pair.second), sizeof(pair.second));
    }

    // Save max
    outFile.write(reinterpret_cast<const char *>(&meta.max_op_scratch_pad_size), sizeof(meta.max_op_scratch_pad_size));
    outFile.write(reinterpret_cast<const char *>(&meta.max_tensor_padding_sz), sizeof(meta.max_tensor_padding_sz));

    // Save the size of the partitions
    size = meta.partitions.size();
    outFile.write(reinterpret_cast<const char *>(&size), sizeof(size));

    // Save each partitions
    for (const auto &pd : meta.partitions)
    {
        outFile.write(reinterpret_cast<const char *>(&pd), sizeof(pd));
    }

    std::string str = outFile.str();
    state_data.insert(state_data.end(), str.begin(), str.end());
    if (verbose)
    {
        std::cout << "Meta Data saved to file successfully!" << std::endl;
    }
}

void load_meta(Metadata &meta, const std::string &state_name, bool verbose)
{
    std::ifstream inFile(state_name, std::ios::binary);
    if (!inFile)
    {
        std::cerr << "Error opening file for reading!" << std::endl;
        return;
    }

    // Read major_version
    inFile.read(reinterpret_cast<char *>(&meta.major_version), sizeof(meta.major_version));
    inFile.read(reinterpret_cast<char *>(&meta.minor_version), sizeof(meta.minor_version));

    // Read device string
    meta.device = read_string(inFile);

    // Read the size of the fused_tensors
    size_t mapSize;
    inFile.read(reinterpret_cast<char *>(&mapSize), sizeof(mapSize));

    // Read each fused_tensors
    for (size_t i = 0; i < mapSize; ++i)
    {
        std::string key = read_string(inFile);

        Metadata::TensorInfo value = load_TensorInfo(inFile);
        meta.fused_tensors[key] = value;
    }

    // Read the size of the tensor_map
    inFile.read(reinterpret_cast<char *>(&mapSize), sizeof(mapSize));

    // Read each tensor_map
    for (size_t i = 0; i < mapSize; ++i)
    {
        std::string key = read_string(inFile);

        Metadata::OffsetInfo value = load_OffsetInfo(inFile);
        meta.tensor_map[key] = value;
    }

    // Read the size of the super_instr_map
    inFile.read(reinterpret_cast<char *>(&mapSize), sizeof(mapSize));

    // Read each super_instr_map
    for (size_t i = 0; i < mapSize; ++i)
    {
        std::string key = read_string(inFile);

        Metadata::Span value;
        inFile.read(reinterpret_cast<char *>(&value), sizeof(value));
        meta.super_instr_map[key] = value;
    }

    // Read the size of the const_map
    inFile.read(reinterpret_cast<char *>(&mapSize), sizeof(mapSize));

    // Read each const_map
    for (size_t i = 0; i < mapSize; ++i)
    {
        std::string key = read_string(inFile);

        Metadata::Span value;
        inFile.read(reinterpret_cast<char *>(&value), sizeof(value));
        meta.const_map[key] = value;
    }

    // Read the size of the ctrl_pkt_map
    inFile.read(reinterpret_cast<char *>(&mapSize), sizeof(mapSize));

    // Read each ctrl_pkt_map
    for (size_t i = 0; i < mapSize; ++i)
    {
        std::string key = read_string(inFile);

        Metadata::Span value;
        inFile.read(reinterpret_cast<char *>(&value), sizeof(value));
        meta.ctrl_pkt_map[key] = value;
    }

    // Read max
    inFile.read(reinterpret_cast<char *>(&meta.max_op_scratch_pad_size), sizeof(meta.max_op_scratch_pad_size));
    inFile.read(reinterpret_cast<char *>(&meta.max_tensor_padding_sz), sizeof(meta.max_tensor_padding_sz));

    // Read the size of the partitions
    size_t vecSize;
    inFile.read(reinterpret_cast<char *>(&vecSize), sizeof(vecSize));

    // Read each partitions
    for (size_t i = 0; i < vecSize; ++i)
    {
        Partition value;
        inFile.read(reinterpret_cast<char *>(&value), sizeof(value));
        meta.partitions.push_back(value);
    }

    inFile.close();
    if (verbose)
    {
        std::cout << "Data loaded from file successfully!" << std::endl;
    }
}

void load_meta_data(Metadata &meta, const std::vector<uint8_t> &state_data, bool verbose)
{
    std::istringstream inFile(std::string(reinterpret_cast<const char*>(state_data.data()), state_data.size()), std::ios::binary);
    if (!inFile)
    {
        std::cerr << "Error opening file for reading!" << std::endl;
        return;
    }

    // Read major_version
    inFile.read(reinterpret_cast<char *>(&meta.major_version), sizeof(meta.major_version));
    inFile.read(reinterpret_cast<char *>(&meta.minor_version), sizeof(meta.minor_version));

    // Read device string
    meta.device = read_string(inFile);

    // Read the size of the fused_tensors
    size_t mapSize;
    inFile.read(reinterpret_cast<char *>(&mapSize), sizeof(mapSize));

    // Read each fused_tensors
    for (size_t i = 0; i < mapSize; ++i)
    {
        std::string key = read_string(inFile);

        Metadata::TensorInfo value = load_TensorInfo(inFile);
        meta.fused_tensors[key] = value;
    }

    // Read the size of the tensor_map
    inFile.read(reinterpret_cast<char *>(&mapSize), sizeof(mapSize));

    // Read each tensor_map
    for (size_t i = 0; i < mapSize; ++i)
    {
        std::string key = read_string(inFile);

        Metadata::OffsetInfo value = load_OffsetInfo(inFile);
        meta.tensor_map[key] = value;
    }

    // Read the size of the super_instr_map
    inFile.read(reinterpret_cast<char *>(&mapSize), sizeof(mapSize));

    // Read each super_instr_map
    for (size_t i = 0; i < mapSize; ++i)
    {
        std::string key = read_string(inFile);

        Metadata::Span value;
        inFile.read(reinterpret_cast<char *>(&value), sizeof(value));
        meta.super_instr_map[key] = value;
    }

    // Read the size of the const_map
    inFile.read(reinterpret_cast<char *>(&mapSize), sizeof(mapSize));

    // Read each const_map
    for (size_t i = 0; i < mapSize; ++i)
    {
        std::string key = read_string(inFile);

        Metadata::Span value;
        inFile.read(reinterpret_cast<char *>(&value), sizeof(value));
        meta.const_map[key] = value;
    }

    // Read the size of the ctrl_pkt_map
    inFile.read(reinterpret_cast<char *>(&mapSize), sizeof(mapSize));

    // Read each ctrl_pkt_map
    for (size_t i = 0; i < mapSize; ++i)
    {
        std::string key = read_string(inFile);

        Metadata::Span value;
        inFile.read(reinterpret_cast<char *>(&value), sizeof(value));
        meta.ctrl_pkt_map[key] = value;
    }

    // Read max
    inFile.read(reinterpret_cast<char *>(&meta.max_op_scratch_pad_size), sizeof(meta.max_op_scratch_pad_size));
    inFile.read(reinterpret_cast<char *>(&meta.max_tensor_padding_sz), sizeof(meta.max_tensor_padding_sz));

    // Read the size of the partitions
    size_t vecSize;
    inFile.read(reinterpret_cast<char *>(&vecSize), sizeof(vecSize));

    // Read each partitions
    for (size_t i = 0; i < vecSize; ++i)
    {
        Partition value;
        inFile.read(reinterpret_cast<char *>(&value), sizeof(value));
        meta.partitions.push_back(value);
    }

    if (verbose)
    {
        std::cout << "Data loaded from file successfully!" << std::endl;
    }
}
} // namespace waic_runner
