// Copyright 2022-2024 Advanced Micro Devices, Inc. All Rights Reserved.
////////////////////////////////////////////////////////////////////////
#pragma once
#include <iostream>
#include <fstream>
#include <vector>
#include "op_common.hpp"
namespace waic_runner {
  // Serialises `str` as a length-prefixed byte string: first the raw in-memory
  // bytes of a size_t holding the character count, then the characters
  // themselves (no terminator). read_string consumes this same layout.
  // T only needs a write(const char *, count) member.
  template <typename T>
  static void write_string(T &file, const std::string &str)
  {
      const size_t byte_count = str.size();
      const char *prefix_bytes = reinterpret_cast<const char *>(&byte_count);
      file.write(prefix_bytes, sizeof byte_count);
      file.write(str.c_str(), byte_count);
  }

  // Writes one TensorInfo record in the layout load_TensorInfo reads back:
  //   size, xrt_arg_idx   -- raw bytes of each field
  //   packed_tensors      -- size_t element count, then each entry via write_string
  template <typename T>
  static void save_TensorInfo(T &outFile, const Metadata::TensorInfo &tinfo)
  {
      outFile.write(reinterpret_cast<const char *>(&tinfo.size), sizeof(tinfo.size));
      outFile.write(reinterpret_cast<const char *>(&tinfo.xrt_arg_idx), sizeof(tinfo.xrt_arg_idx));

      const size_t packed_count = tinfo.packed_tensors.size();
      outFile.write(reinterpret_cast<const char *>(&packed_count), sizeof(packed_count));
      for (const auto &entry : tinfo.packed_tensors)
          write_string(outFile, entry);
  }

  // Writes one OffsetInfo record. The field order below is the on-disk layout
  // and mirrors the read order in load_OffsetInfo; keep the two in sync.
  // Strings go through write_string (size_t length prefix + bytes); every other
  // field is written as its raw in-memory bytes.
  template <typename T>
  static void save_OffsetInfo(T &outFile, const Metadata::OffsetInfo &oinfo)
  {
      write_string(outFile, oinfo.parent_name);

      outFile.write(reinterpret_cast<const char *>(&oinfo.offset), sizeof(oinfo.offset));
      outFile.write(reinterpret_cast<const char *>(&oinfo.additional_offset), sizeof(oinfo.additional_offset));
      outFile.write(reinterpret_cast<const char *>(&oinfo.ref_idx), sizeof(oinfo.ref_idx));
      outFile.write(reinterpret_cast<const char *>(&oinfo.xrt_arg_idx), sizeof(oinfo.xrt_arg_idx));

      write_string(outFile, oinfo.dtype);

      // shape: a size_t rank, then each dimension at the width of the shape's
      // own element type (sizeof(dim)).
      const size_t rank = oinfo.shape.size();
      outFile.write(reinterpret_cast<const char *>(&rank), sizeof(rank));
      for (const auto &dim : oinfo.shape)
          outFile.write(reinterpret_cast<const char *>(&dim), sizeof(dim));

      outFile.write(reinterpret_cast<const char *>(&oinfo.size_in_bytes), sizeof(oinfo.size_in_bytes));

      write_string(outFile, oinfo.format);
      write_string(outFile, oinfo.file_name);

      outFile.write(reinterpret_cast<const char *>(&oinfo.file_size), sizeof(oinfo.file_size));
  }

  // Reads a string in the layout produced by write_string: the raw bytes of a
  // size_t length prefix, followed by that many characters.
  // T only needs a read(char *, count) member.
  // Returns the decoded string; if the length prefix cannot be read (e.g. an
  // exhausted std::istream), the prefix stays 0 and an empty string is returned.
  template <typename T>
  static std::string read_string(T &inFile)
  {
      // Must be initialised: on an exhausted std::istream, read() stores
      // nothing, and an indeterminate length would be UB (and could request
      // an arbitrarily large allocation below).
      size_t length = 0;
      inFile.read(reinterpret_cast<char *>(&length), sizeof(length));

      std::string str(length, '\0');
      inFile.read(&str[0], length);
      return str;
  }

  // Reads one TensorInfo record in the layout written by save_TensorInfo:
  //   size, xrt_arg_idx   -- raw bytes of each field
  //   packed_tensors      -- size_t element count, then each entry via read_string
  // T only needs a read(char *, count) member.
  template <typename T>
  static Metadata::TensorInfo load_TensorInfo(T &inFile)
  {
      // Value-initialise so scalar fields are zero, not indeterminate, if a
      // read comes up short.
      Metadata::TensorInfo tinfo{};
      inFile.read(reinterpret_cast<char *>(&tinfo.size), sizeof(tinfo.size));
      inFile.read(reinterpret_cast<char *>(&tinfo.xrt_arg_idx), sizeof(tinfo.xrt_arg_idx));

      // packed_tensors. The count is zero-initialised so a failed read yields
      // an empty list instead of looping over a garbage count.
      size_t count = 0;
      inFile.read(reinterpret_cast<char *>(&count), sizeof(count));
      for (size_t i = 0; i < count; ++i)
      {
          tinfo.packed_tensors.push_back(read_string(inFile));
      }
      return tinfo;
  }

  // Reads one OffsetInfo record in the exact field order written by
  // save_OffsetInfo. Strings use read_string (size_t length prefix + bytes);
  // every other field is read as its raw in-memory bytes.
  // T only needs a read(char *, count) member.
  template <typename T>
  static Metadata::OffsetInfo load_OffsetInfo(T &inFile)
  {
      // Value-initialise so scalar fields are zero, not indeterminate, if a
      // read comes up short.
      Metadata::OffsetInfo oinfo{};
      oinfo.parent_name = read_string(inFile);

      inFile.read(reinterpret_cast<char *>(&oinfo.offset), sizeof(oinfo.offset));
      inFile.read(reinterpret_cast<char *>(&oinfo.additional_offset), sizeof(oinfo.additional_offset));
      inFile.read(reinterpret_cast<char *>(&oinfo.ref_idx), sizeof(oinfo.ref_idx));
      inFile.read(reinterpret_cast<char *>(&oinfo.xrt_arg_idx), sizeof(oinfo.xrt_arg_idx));

      oinfo.dtype = read_string(inFile);

      // shape: a size_t rank, then one entry per dimension. save_OffsetInfo
      // writes each entry with sizeof(element of oinfo.shape), so read each one
      // at that same width rather than assuming size_t. Otherwise the stream
      // desynchronises whenever the element type is not size_t-sized.
      using dim_type = decltype(oinfo.shape)::value_type;
      size_t rank = 0;
      inFile.read(reinterpret_cast<char *>(&rank), sizeof(rank));
      for (size_t i = 0; i < rank; ++i)
      {
          dim_type dim{};
          inFile.read(reinterpret_cast<char *>(&dim), sizeof(dim));
          oinfo.shape.push_back(dim);
      }

      inFile.read(reinterpret_cast<char *>(&oinfo.size_in_bytes), sizeof(oinfo.size_in_bytes));
      oinfo.format = read_string(inFile);
      oinfo.file_name = read_string(inFile);
      inFile.read(reinterpret_cast<char *>(&oinfo.file_size), sizeof(oinfo.file_size));

      return oinfo;
  }

	// Metadata persistence entry points. Only declared here; the definitions
	// live outside this header. Presumably they build on the TensorInfo /
	// OffsetInfo helpers above (NOTE(review): confirm against the .cpp).

	// Serialise `meta` to the file named by `state_name` (assumed a path -- confirm).
	void save_meta(const Metadata &meta, const std::string &state_name, bool verbose);
	// Serialise `meta` into the in-memory buffer `state_data`.
	void save_meta(const Metadata &meta, std::vector<uint8_t> &state_data, bool verbose);
	// Populate `meta` from the file named by `state_name` (assumed a path -- confirm).
	void load_meta(Metadata &meta, const std::string &state_name, bool verbose);
	// Populate `meta` from the in-memory buffer `state_data`.
	void load_meta_data(Metadata &meta, const std::vector<uint8_t> &state_data, bool verbose);
}
