#include "meta_utils.hpp"
#include "passes.hpp"
#include "ipu_hw_config.hpp"
#include "utils.hpp"

#include <txn_utils.hpp>

#include <xaiengine.h>
namespace waic_runner {
    // record_timer ops
    const std::vector<uint8_t> timer_get_transaction_bin(
        std::vector<Tensor>& input, std::vector<Tensor>& output,
        const std::map<std::string, std::any>& attr) {

        // Get timer id
        uint32_t timer_id;
        if (attr.find("timer_id") != attr.end()) {
            timer_id = std::any_cast<uint32_t>(attr.find("timer_id")->second);
        }
        else {
            throw std::runtime_error("Can't find timer_id in attrs");
        }

        // Initialize AIE Driver. Hardcode for STRIX for now
        XAie_Config ConfigPtr{
            XAIE_DEV_GEN_AIE2P,      XAIE_BASE_ADDR,          XAIE_COL_SHIFT,
            XAIE_ROW_SHIFT,          XAIE_NUM_ROWS,           XAIE_NUM_COLS,
            XAIE_SHIM_ROW,           XAIE_MEM_TILE_ROW_START, XAIE_MEM_TILE_NUM_ROWS,
            XAIE_AIE_TILE_ROW_START, XAIE_AIE_TILE_NUM_ROWS,  {0} };

        XAie_InstDeclare(DevInst, &ConfigPtr);
        XAie_CfgInitialize(&(DevInst), &ConfigPtr);

        XAie_StartTransaction(&DevInst, XAIE_TRANSACTION_DISABLE_AUTO_FLUSH);

        record_timer_op_t timer_op;
        timer_op.id = timer_id;

        XAie_AddCustomTxnOp(&DevInst, XAIE_IO_CUSTOM_OP_RECORD_TIMER, &timer_op,
            sizeof(timer_op));

        uint8_t* txn_ptr = XAie_ExportSerializedTransaction(&DevInst, 0, 0);
        XAie_TxnHeader* Hdr = (XAie_TxnHeader*)txn_ptr;
        auto size = Hdr->TxnSize;

        std::vector<uint8_t> txn(size, 0);
        memcpy((void*)txn.data(), (void*)txn_ptr, size);

        // check if there is an API to free txn pointer
        free(txn_ptr);
        XAie_Finish(&DevInst);

        return txn;
    }

    const std::vector<uint8_t> preemption_get_transaction_bin(
        std::vector<Tensor>& input, std::vector<Tensor>& output,
        const std::map<std::string, std::any>& attr) {

        // Get preemption id

        // Initialize AIE Driver. Hardcode for STRIX for now
        XAie_Config ConfigPtr{
            XAIE_DEV_GEN_AIE2P,      XAIE_BASE_ADDR,          XAIE_COL_SHIFT,
            XAIE_ROW_SHIFT,          XAIE_NUM_ROWS,           XAIE_NUM_COLS,
            XAIE_SHIM_ROW,           XAIE_MEM_TILE_ROW_START, XAIE_MEM_TILE_NUM_ROWS,
            XAIE_AIE_TILE_ROW_START, XAIE_AIE_TILE_NUM_ROWS,  {0} };

        XAie_InstDeclare(DevInst, &ConfigPtr);
        XAie_CfgInitialize(&(DevInst), &ConfigPtr);

        XAie_StartTransaction(&DevInst, XAIE_TRANSACTION_DISABLE_AUTO_FLUSH);

        XAie_PreemptHdr preemption_op;
        preemption_op.Op = XAie_TxnOpcode::XAIE_IO_PREEMPT;
        preemption_op.Preempt_level = XAie_Preempt_level::NOOP;

        // XAie_AddCustomTxnOp(&DevInst, XAIE_IO_PREEMPT, &preemption_op,
        //                     sizeof(preemption_op));
        XAie_Txn_Preempt(&DevInst, &preemption_op);

        uint8_t* txn_ptr = XAie_ExportSerializedTransaction(&DevInst, 0, 0);
        XAie_TxnHeader* Hdr = (XAie_TxnHeader*)txn_ptr;
        auto size = Hdr->TxnSize;

        std::vector<uint8_t> txn(size, 0);
        memcpy((void*)txn.data(), (void*)txn_ptr, size);

        // check if there is an API to free txn pointer
        free(txn_ptr);
        XAie_Finish(&DevInst);

        return txn;
    }

    const std::vector<uint8_t> pm_load_get_transaction_bin(
        std::vector<Tensor>& input, std::vector<Tensor>& output,
        const std::map<std::string, std::any>& attr,
        const std::string& bin_path,
        bool verbose) {

        // Get pm id
        uint8_t pm_id;
        if (attr.find("pm_id") != attr.end()) {
            pm_id = std::any_cast<uint8_t>(attr.find("pm_id")->second);
        }
        else {
            throw std::runtime_error("Can't find pm_id in attrs");
        }

        std::string filename = (std::filesystem::path(bin_path) /
            (std::string{ "txn_pm_" }
                + std::to_string(static_cast<unsigned int>(pm_id))
                + std::string{ ".bin" })).u8string();

        if (verbose) {
            std::cout << "Loading pm txn bin from " << filename << std::endl;
        }

        return ReadBinaryFile(filename, verbose);
    }

    const std::vector<uint8_t> identity_get_transaction_bin(
        std::vector<Tensor>& input, std::vector<Tensor>& output,
        const std::map<std::string, std::any>& attr) {
        std::vector<uint8_t> txn_vec(sizeof(XAie_TxnHeader), 0);
        XAie_TxnHeader* Hdr = (XAie_TxnHeader*)txn_vec.data();
        Hdr->TxnSize = uint32_t(sizeof(
            XAie_TxnHeader)); // transactions header size without any instructions
        Hdr->NumOps = 0;
        return txn_vec;
    }
}
