// Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
//
// Permission is hereby granted, free of charge, to any person obtaining a copy
// of this software and associated documentation files (the "Software"), to deal
// in the Software without restriction, including without limitation the rights
// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
// copies of the Software, and to permit persons to whom the Software is
// furnished to do so, subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in
// all copies or substantial portions of the Software.
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
// SOFTWARE.

#pragma once
#ifndef AIE_CONTEXT_HPP__
#define AIE_CONTEXT_HPP__

#include <cstdint>
#include <fstream>
#include <iostream>
#include <map>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// XRT headers
#include "xrt/xrt_bo.h"
#include "xrt/xrt_device.h"
#include "xrt/xrt_kernel.h"

#include <dd_export.hpp>
// 64-bit address offset constant (0x80000000). The name suggests it relates DDR
// addresses to the AIE address space — NOTE(review): confirm against its users.
constexpr auto DDR_AIE_ADDR_OFFSET = static_cast<std::uint64_t>(0x80000000u);

namespace {
/// Index of the XRT device opened by every aie_context.
constexpr unsigned int DEVICE_INDEX = 0;
/// Kernel name used when an xclbin contains no eligible kernel.
constexpr auto NPU_KERNEL_NAME = "DPU";

// NOTE: these helpers live in an anonymous namespace inside a header, so every
// including TU gets its own copy. They are marked `inline` so TUs that include
// the header without calling them do not trigger -Wunused-function.

/// True when `kernel_name` should be used by the runner. It must not start with
/// "XDP_KERNEL" (presumably XRT debug/profiling kernels) or with "vadd".
inline bool is_user_kernel(const std::string &kernel_name) {
  return kernel_name.rfind("XDP_KERNEL", 0) != 0 &&
         kernel_name.rfind("vadd", 0) != 0;
}

/// Returns the name of the first kernel in `xclbin_` accepted by
/// is_user_kernel(), or NPU_KERNEL_NAME when there is none.
inline std::string get_first_kernel_name(const xrt::xclbin &xclbin_) {
  for (const auto &kernel : xclbin_.get_kernels()) {
    std::string candidate_kernel_name = kernel.get_name();
    if (is_user_kernel(candidate_kernel_name)) {
      return candidate_kernel_name;
    }
  }
  return NPU_KERNEL_NAME;
}

/// Returns the names of all kernels in `xclbin_` accepted by is_user_kernel(),
/// in the order xclbin::get_kernels() reports them.
inline std::vector<std::string> get_all_kernel_names(const xrt::xclbin &xclbin_) {
  std::vector<std::string> kernel_names;
  for (const auto &kernel : xclbin_.get_kernels()) {
    std::string candidate_kernel_name = kernel.get_name();
    if (is_user_kernel(candidate_kernel_name)) {
      kernel_names.push_back(std::move(candidate_kernel_name));
    }
  }
  return kernel_names;
}

/// Logs the start of context creation and always returns true. aie_context
/// initializes its first member with this, so the message is printed before
/// the device is opened.
inline bool create_xrt_context() {
  std::cout << "Creating new context" << std::endl;
  return true;
}

} // namespace

namespace waic_runner {

    /// Reads a whole binary file into a vector of T.
    ///
    /// The file is treated as a packed array of T. Trailing bytes that do not
    /// fill a complete T are ignored.
    ///
    /// @tparam T        element type of the returned vector (default: char).
    /// @param  filename path of the file to read.
    /// @return file_size / sizeof(T) elements, in file order.
    /// @throws std::runtime_error if the file cannot be opened, its size cannot
    ///         be determined, or fewer bytes than expected could be read.
    template <typename T = char>
    std::vector<T> read_bin_file(const std::string& filename) {

        std::cout << "Opening file: " << filename << std::endl;
        std::ifstream ifs(filename, std::ios::binary);
        if (!ifs.is_open()) {
            throw std::runtime_error("Couldn't open file for reading");
        }

        std::cout << "Loading data from " << filename << std::endl;
        std::vector<T> dst;

        try {
            ifs.seekg(0, std::ios::end);
            const std::streamoff size = ifs.tellg();
            // tellg() reports -1 on failure; never size the buffer from that.
            if (size < 0) {
                throw std::runtime_error("Couldn't determine file size");
            }
            ifs.seekg(0, std::ios::beg);
            dst.resize(static_cast<std::size_t>(size) / sizeof(T));

            // Read exactly the bytes dst can hold. Reading `size` bytes would
            // overflow dst whenever size is not a multiple of sizeof(T).
            const auto nbytes = static_cast<std::streamsize>(dst.size() * sizeof(T));
            ifs.read(reinterpret_cast<char*>(dst.data()), nbytes);
            if (ifs.gcount() != nbytes) {
                throw std::runtime_error("Short read");
            }
        }
        catch (const std::exception&) {
            throw std::runtime_error("Failed to read contents from file");
        }
        std::cout << "Loading data from " << filename << "DONE" << std::endl;

        return dst;
    }

/// Numeric id that distinguishes several contexts built from the same file.
using context_id_t = std::uint32_t;
/// Context-cache key type; get_instance* build it as
/// "<file name>_context_id_<id>".
using xrt_key_id_t = std::string;
/// Size of the xclbin-based context cache at which unused entries are evicted.
constexpr std::uint32_t MAX_NUM_XRT_CONTEXTS{15};
/// Size of the ELF-based (qhw4) context cache at which unused entries are evicted.
constexpr std::uint32_t MAX_NUM_QHW4_CONTEXTS{128};
class aie_context {
private:
  static DYNAMIC_DISPATCH_API
      std::unordered_map<xrt_key_id_t, std::shared_ptr<aie_context>>
          ctx_map_;
  static DYNAMIC_DISPATCH_API
      std::unordered_map<xrt_key_id_t, std::shared_ptr<aie_context>>
      qhw4_ctx_map_;
  static DYNAMIC_DISPATCH_API std::mutex aie_ctx_mutex_;
  bool init_{};
  xrt::device device_;
  xrt::xclbin xclbin_;
  xrt::uuid uuid_;
  xrt::hw_context context_;
  std::unordered_map<std::string, xrt::kernel> kernels_;
  xrt::kernel kernel_;
  std::map<std::string, std::uint32_t> qos_;
  xrt_key_id_t xrt_key_;

public:
  aie_context() = default;

  aie_context(const std::vector<char>& xclbin,
              const std::map<std::string, std::uint32_t> &qos,
              xrt_key_id_t xrt_key)
      : init_(create_xrt_context()), device_(DEVICE_INDEX), xclbin_(xclbin),
        uuid_(device_.register_xclbin(xclbin_)),
        context_(device_, xclbin_.get_uuid(), qos),
        kernel_(context_, get_first_kernel_name(xclbin_)) {

    std::vector<std::string> kernel_names = get_all_kernel_names(xclbin_);
    for (const auto &kernel_name : kernel_names) {
      kernels_[kernel_name] = xrt::kernel(context_, kernel_name);
    }
    std::cout << "Created new context" << std::endl;
  }

  aie_context(const xrt::elf& elf,
      xrt_key_id_t xrt_key)
      : init_(create_xrt_context()), device_(DEVICE_INDEX)
      {
      context_ = xrt::hw_context{ device_, elf };

      std::cout << "Created new context" << std::endl;
  }
  static void destroy_ctx_map() { ctx_map_.clear(); }

  static void destroy_qhw4_ctx_map() { qhw4_ctx_map_.clear(); }

  static std::shared_ptr<aie_context>
  get_instance(const std::string &xclbin_fname, context_id_t context_id = 0,
               const std::map<std::string, std::uint32_t> &qos = {},
               const std::vector<char> &xclbin_content = {}) {
  
      std::cout << "Getting context with xclbin: " << xclbin_fname << ", context_id = " << std::to_string(context_id) << std::endl;
    auto xrt_key =
        xclbin_fname + std::string("_context_id_") + std::to_string(context_id);

    auto clean_stale_contexts = [&]() {
      bool removed_contexts = false;

      std::vector<std::string> stale_xclbins;

      for (auto it = ctx_map_.begin(); it != ctx_map_.end(); ++it) {
        // check if only copy is the one in cache
        if (it->second.use_count() == 1) {
          stale_xclbins.push_back(it->first);
        }
      }

      if (stale_xclbins.size() != 0) {
        for (auto &stale_xclbin : stale_xclbins) {
          ctx_map_.erase(stale_xclbin);
        }

        removed_contexts = true;
      } else {
                std::cout <<  "[Warning] Could not find xrt context to remove from cache" << std::endl;
        //TODO: should we throw here ?? Not sure what xrt behaviour will be
      }

      return removed_contexts;
    };
    std::lock_guard<std::mutex> guard(aie_ctx_mutex_);
    if (ctx_map_.find(xrt_key) != ctx_map_.end()) {
            std::cout << "Context found in map" << std::endl;
      return ctx_map_[xrt_key];
    } else {

      if (ctx_map_.size() == MAX_NUM_XRT_CONTEXTS) {

            std::cout << "[Warning] Maximum number of xrt contexts hit" << std::endl;
        (void)clean_stale_contexts();
      }
     
      std::cout << "Context not found in map, creating new one" << std::endl;
      std::cout << "Current num contexts " << std::to_string(ctx_map_.size()) << std::endl;

      bool retry = false;
      std::uint32_t num_retries = 0;
      constexpr std::uint32_t MAX_NUM_RETRIES = 1;
      do {
        try {

          if (xclbin_content.size() == 0) {
            std::vector<char> buffer =
                read_bin_file<char>(xclbin_fname);
            ctx_map_[xrt_key] =
                std::make_shared<aie_context>(buffer, qos, xrt_key);
          } else {
            ctx_map_[xrt_key] =
                std::make_shared<aie_context>(xclbin_content, qos, xrt_key);
          }

          retry = false;
        } catch (...) {
          std::cout << "[Warning] Retrying xrt context creation and cleanup" << std::endl;
          retry = clean_stale_contexts();
          retry = retry && (num_retries < MAX_NUM_RETRIES);
          num_retries++;
        }
      } while (retry);

      return ctx_map_.at(xrt_key);
    }
  }

  static std::shared_ptr<aie_context>
  get_instance(const std::string &xclbin_fname,
               const std::vector<char> &xclbin_data, context_id_t context_id = 0,
               const std::map<std::string, std::uint32_t> &qos = {},
               const std::vector<char> &xclbin_content = {}) {

      std::cout << "Getting context with xclbin: " << xclbin_fname << ", context_id = " << std::to_string(context_id) << std::endl;
    auto xrt_key =
        xclbin_fname + std::string("_context_id_") + std::to_string(context_id);

    auto clean_stale_contexts = [&]() {
      bool removed_contexts = false;

      std::vector<std::string> stale_xclbins;

      for (auto it = ctx_map_.begin(); it != ctx_map_.end(); ++it) {
        // check if only copy is the one in cache
        if (it->second.use_count() == 1) {
          stale_xclbins.push_back(it->first);
        }
      }

      if (stale_xclbins.size() != 0) {
        for (auto &stale_xclbin : stale_xclbins) {
          ctx_map_.erase(stale_xclbin);
        }

        removed_contexts = true;
      } else {
                std::cout <<  "[Warning] Could not find xrt context to remove from cache" << std::endl;
        //TODO: should we throw here ?? Not sure what xrt behaviour will be
      }

      return removed_contexts;
    };
    std::lock_guard<std::mutex> guard(aie_ctx_mutex_);
    if (ctx_map_.find(xrt_key) != ctx_map_.end()) {
            std::cout << "Context found in map" << std::endl;
      return ctx_map_[xrt_key];
    } else {

      if (ctx_map_.size() == MAX_NUM_XRT_CONTEXTS) {

            std::cout << "[Warning] Maximum number of xrt contexts hit" << std::endl;
        (void)clean_stale_contexts();
      }

      std::cout << "Context not found in map, creating new one" << std::endl;
      std::cout << "Current num contexts " << std::to_string(ctx_map_.size()) << std::endl;

      bool retry = false;
      std::uint32_t num_retries = 0;
      constexpr std::uint32_t MAX_NUM_RETRIES = 1;
      do {
        try {

          if (xclbin_content.size() == 0) {
            ctx_map_[xrt_key] =
                std::make_shared<aie_context>(xclbin_data, qos, xrt_key);
          } else {
            ctx_map_[xrt_key] =
                std::make_shared<aie_context>(xclbin_content, qos, xrt_key);
          }

          retry = false;
        } catch (...) {
          std::cout << "[Warning] Retrying xrt context creation and cleanup" << std::endl;
          retry = clean_stale_contexts();
          retry = retry && (num_retries < MAX_NUM_RETRIES);
          num_retries++;
        }
      } while (retry);

      return ctx_map_.at(xrt_key);
    }
  }

  static std::shared_ptr<aie_context>
      get_instance_qhw4(const std::string& elf_fname, bool verbose, context_id_t context_id = 0,
          const std::map<std::string, std::uint32_t>& qos = {},
          const std::vector<char>& xclbin_content = {}) {

      if (verbose) {
          std::cout << "Getting context with elf_file: " << elf_fname << ", context_id = " << std::to_string(context_id) << std::endl;
      }
      auto xrt_key =
          elf_fname + std::string("_context_id_") + std::to_string(context_id);

      auto clean_stale_contexts = [&]() {
          bool removed_contexts = false;

          std::vector<std::string> stale_xclbins;

          for (auto it = qhw4_ctx_map_.begin(); it != qhw4_ctx_map_.end(); ++it) {
              // check if only copy is the one in cache
              if (it->second.use_count() == 1) {
                  stale_xclbins.push_back(it->first);
              }
          }

          if (stale_xclbins.size() != 0) {
              for (auto& stale_xclbin : stale_xclbins) {
                  qhw4_ctx_map_.erase(stale_xclbin);
              }

              removed_contexts = true;
          }
          else {
              std::cout << "[Warning] Could not find xrt context to remove from cache" << std::endl;
              //TODO: should we throw here ?? Not sure what xrt behaviour will be
          }

          return removed_contexts;
          };
      std::lock_guard<std::mutex> guard(aie_ctx_mutex_);
      if (qhw4_ctx_map_.find(xrt_key) != qhw4_ctx_map_.end()) {
          if (verbose) {
              std::cout << "Context found in map" << std::endl;
          }
          return qhw4_ctx_map_[xrt_key];
      }
      else {

          if (qhw4_ctx_map_.size() == MAX_NUM_QHW4_CONTEXTS) {
              if (verbose) {
                  std::cout << "[Warning] Maximum number of xrt contexts hit" << std::endl;
              }
              (void)clean_stale_contexts();
          }
          if (verbose) {
              std::cout << "Context not found in map, creating new one" << std::endl;
              std::cout << "Current num contexts " << std::to_string(qhw4_ctx_map_.size()) << std::endl;
          }

          bool retry = false;
          std::uint32_t num_retries = 0;
          constexpr std::uint32_t MAX_NUM_RETRIES = 1;
          do {
              try {

                  if (xclbin_content.size() == 0) {     
                      xrt::elf elf{ elf_fname };
                      qhw4_ctx_map_[xrt_key] =
                          std::make_shared<aie_context>(elf, xrt_key);
                  }

                  retry = false;
              }
              catch (...) {
                  if (verbose) {
                      std::cout << "[Warning] Retrying xrt context creation and cleanup" << std::endl;
                  }
                  retry = clean_stale_contexts();
                  retry = retry && (num_retries < MAX_NUM_RETRIES);
                  num_retries++;
              }
          } while (retry);

          return qhw4_ctx_map_.at(xrt_key);
      }
  }

  //xrt::device &get_device() { return device_; }
  xrt::hw_context &get_context() { return context_; }
  //xrt::xclbin &get_xclbin() { return xclbin_; }
  //const xrt_key_id_t &get_xrt_context_id() { return xrt_key_; }

};

} // namespace waic_runner

#endif
