/*
 * Copyright (C) 2024 - 2025 Advanced Micro Devices, Inc. All rights reserved.
 */

#include <any>
#include <array>
#include <chrono>
#include <fstream>
#include <iostream>
#include <map>
#include <numeric>
#include <sstream>
#include <tuple>
#include <utility>

#include "utils.hpp"
#include "timing_logger.hpp"
#include "subgraph_op.hpp"
#include "matrix.hpp"
#include "../runtime/wgt_formatting_inmem.hpp"
#include "../runtime/txn_const_padding_inmem.hpp"

using namespace waic_runner;
using namespace waic_runtime;
// Verbose-only logging: `LOG_VERBOSE() << ...;` writes to std::cout when verbose_ is set.
// Written as `if (!verbose_) {} else` so a caller's own `else` after an unbraced
// `if (c) LOG_VERBOSE() << x;` binds to the caller's `if`, not the macro's.
#define LOG_VERBOSE()          if (!verbose_) {} else std::cout
namespace waic_runner {

    /// @brief Directory holding the pre-generated artifacts for one layer.
    /// @param prefix Layer-specific sub-directory name (typically the layer name).
    /// @return "<binfile_path_>/DataGen/Consts/<prefix>/" with native path separators,
    ///         always ending in a separator so callers can append a file stem.
    template <typename InT, typename WtT, typename OutT>
    std::string subgraph_op<InT, WtT, OutT>::get_instr_key(std::string prefix) const {
        std::filesystem::path layer_dir(binfile_path_ + "/DataGen/Consts/" + prefix + "/");
        layer_dir.make_preferred();
        return layer_dir.u8string();
    }

    /// @brief Widen a bfloat16 bit pattern to an IEEE-754 binary32 float.
    /// @param x Raw bfloat16 bits (sign, 8-bit exponent, 7-bit mantissa).
    /// @return The float whose upper 16 bits are `x` and lower 16 bits are zero.
    /// The shift is done on an integer, so the result does not depend on host
    /// byte order. Copying the two bytes into the upper half of the float's
    /// storage only works on little-endian targets.
    template <typename InT, typename WtT, typename OutT>
    float subgraph_op<InT, WtT, OutT>::bfloat_to_float(uint16_t x) {
        const uint32_t bits = static_cast<uint32_t>(x) << 16;
        float result = 0.0f;
        std::memcpy(&result, &bits, sizeof(result));
        return result;
    }

    /// @brief Narrow an IEEE-754 binary32 float to bfloat16 bits, rounding to nearest-even.
    /// @param x Input float.
    /// @return The upper 16 bits of `x` after round-to-nearest-even on the dropped
    ///         16 bits. NaN inputs yield a quiet NaN with the same sign.
    template <typename InT, typename WtT, typename OutT>
    uint16_t subgraph_op<InT, WtT, OutT>::float_to_bfloat16_1(float x) {
        uint32_t bits = 0;
        std::memcpy(&bits, &x, sizeof(bits));

        // NaN (all-ones exponent, non-zero mantissa). Adding the rounding bias could
        // carry out of the mantissa and turn the NaN into Inf or -0.0. Truncate instead,
        // and set the top bf16 mantissa bit so the result is still a (quiet) NaN.
        if ((bits & 0x7fffffffu) > 0x7f800000u) {
            return static_cast<uint16_t>((bits >> 16) | 0x0040u);
        }

        // Round to nearest, ties to even: bias by 0x7fff plus the LSB of the kept half.
        const uint32_t lsb = (bits >> 16) & 0x1u;
        bits += 0x7fffu + lsb;
        return static_cast<uint16_t>(bits >> 16);
    }

    /// @brief Build a subgraph op and parse its tensor attributes.
    /// @param op_type          Op type name. Substring matches select the op kind:
    ///                         RECORD_TIMER, PREEMPTION, PM_LOAD,
    ///                         noop/NoOp/runtime/Runtime (identity). Anything else
    ///                         is a regular NPU op.
    /// @param binfile_path     Root directory of generated artifacts ("/" is appended).
    /// @param prebuilt_bin_dir Directory of prebuilt binaries (used by PM_LOAD ops).
    /// @param tilings_data     Tiling description used by the in-memory flows.
    /// @param attr             Attribute map. Only NPU and identity ops read it.
    ///                         Missing or mistyped attributes are logged (verbose) and skipped.
    /// @param use_inmem        Generate weights/transactions in memory instead of reading files.
    template <typename InT, typename WtT, typename OutT>
    subgraph_op<InT, WtT, OutT>::subgraph_op(const std::string& op_type,
        const std::string& binfile_path,
        const std::string& prebuilt_bin_dir,
        const json& tilings_data,
        const std::map<std::string, std::any>& attr,
        bool use_inmem) {

        binfile_path_ = binfile_path + "/";
        prebuilt_bin_dir_ = prebuilt_bin_dir;
        tilings_data_ = tilings_data;
        number_inputs_ = 0;
        number_outputs_ = 0;
        // Per-tensor ranks default to 0 so no shape loop runs when the shape
        // attributes are absent. (Guards against the header leaving them
        // uninitialised — confirm against the class declaration.)
        input_shape_size_ = 0;
        output_shape_size_ = 0;
        onnx_arg_idx_[0] = 0;
        onnx_arg_idx_[1] = 0;
        onnx_arg_idx_[2] = 0;
        use_inmem_ = use_inmem;

        // Classify the op by substring match on its type name (first match wins).
        support_optype_ = NPUOP;
        if (op_type.find("RECORD_TIMER") != std::string::npos) {
            support_optype_ = RECORDTIME;
        }
        else if (op_type.find("PREEMPTION") != std::string::npos) {
            support_optype_ = PREEMPTION;
        }
        else if (op_type.find("PM_LOAD") != std::string::npos) {
            support_optype_ = PMLOAD;
        }
        else if (op_type.find("noop") != std::string::npos || op_type.find("NoOp") != std::string::npos ||
            op_type.find("runtime") != std::string::npos || op_type.find("Runtime") != std::string::npos) {
            support_optype_ = IDENTITY;
        }

        // Timer, preemption and PM-load ops take no tensor attributes.
        if (support_optype_ != NPUOP && support_optype_ != IDENTITY) {
            return;
        }

        // Typed attribute lookups. Each returns nullptr when the key is missing or
        // the stored value is not exactly the requested vector type.
        auto string_list = [&attr](const char* key) -> const std::vector<std::string>* {
            const auto it = attr.find(key);
            return it == attr.end() ? nullptr : std::any_cast<std::vector<std::string>>(&it->second);
        };
        auto int_list = [&attr](const char* key) -> const std::vector<int>* {
            const auto it = attr.find(key);
            return it == attr.end() ? nullptr : std::any_cast<std::vector<int>>(&it->second);
        };

        // Reads a string-list attribute that should hold exactly one element into `dst`.
        // Returns false only when the attribute is absent or mistyped. A wrong element
        // count is logged and leaves `dst` untouched.
        auto read_single_string = [&](const char* key, const char* label, auto& dst) {
            const auto* values = string_list(key);
            if (values == nullptr) {
                return false;
            }
            if (values->size() == 1) {
                dst = values->front();
            }
            else {
                LOG_VERBOSE() << label << " attribute does not have the expected number of elements. "
                    << "Number passed: " << values->size() << ", Expected: 1" << std::endl;
            }
            return true;
        };

        if (const auto* formats = string_list("input_format")) {
            number_inputs_ = formats->size();
            input_format_.assign(formats->begin(), formats->end());
        }
        else {
            LOG_VERBOSE() << "Input Format attribute not found or not of correct type."
                << std::endl;
        }

        if (const auto* formats = string_list("output_format")) {
            number_outputs_ = formats->size();
            output_format_.assign(formats->begin(), formats->end());
        }
        else {
            LOG_VERBOSE() << "Output Format attribute not found or not of correct type."
                << std::endl;
        }

        // Shape lists are flat: all tensors are assumed to share one rank, so the list
        // is split evenly by tensor count. That count comes from the *_format attribute.
        // When it is 0 the split is skipped rather than dividing by zero.
        if (const auto* shapes = int_list("input_shape")) {
            if (number_inputs_ > 0) {
                input_shape_size_ = shapes->size() / number_inputs_;
                if (shapes->size() % number_inputs_ != 0) {
                    LOG_VERBOSE() << "Input Shape attribute has " << shapes->size()
                        << " elements, not a multiple of " << number_inputs_ << " inputs." << std::endl;
                }
            }
            else {
                LOG_VERBOSE() << "Input Shape attribute given without input formats; per-input rank left at 0."
                    << std::endl;
            }
            inputShape_.assign(shapes->begin(), shapes->end());
        }
        else {
            LOG_VERBOSE() << "Input Shape attribute not found or not of correct type."
                << std::endl;
        }

        if (const auto* shapes = int_list("output_shape")) {
            if (number_outputs_ > 0) {
                output_shape_size_ = shapes->size() / number_outputs_;
                if (shapes->size() % number_outputs_ != 0) {
                    LOG_VERBOSE() << "Output Shape attribute has " << shapes->size()
                        << " elements, not a multiple of " << number_outputs_ << " outputs." << std::endl;
                }
            }
            else {
                LOG_VERBOSE() << "Output Shape attribute given without output formats; per-output rank left at 0."
                    << std::endl;
            }
            outputShape_.assign(shapes->begin(), shapes->end());
        }
        else {
            LOG_VERBOSE() << "Output Shape attribute not found or not of correct type."
                << std::endl;
        }

        if (const auto* types = string_list("input_datatype")) {
            if (types->size() != number_inputs_) {
                LOG_VERBOSE() << "Input Datatype attribute does not have the expected number of elements. "
                    << "Number passed: " << types->size() << ", Expected: " << number_inputs_ << std::endl;
            }
            input_datatype_.assign(types->begin(), types->end());
        }
        else {
            LOG_VERBOSE() << "Input Datatype attribute not found or not of correct type."
                << std::endl;
        }

        if (const auto* types = string_list("output_datatype")) {
            if (types->size() != number_outputs_) {
                LOG_VERBOSE() << "Output Datatype attribute does not have the expected number of elements. "
                    << "Number passed: " << types->size() << ", Expected: " << number_outputs_ << std::endl;
            }
            output_datatype_.assign(types->begin(), types->end());
        }
        else {
            LOG_VERBOSE() << "Output Datatype attribute not found or not of correct type."
                << std::endl;
        }

        if (const auto* idx = int_list("onnx_arg_idx")) {
            if (idx->size() == 3) {
                onnx_arg_idx_[0] = (*idx)[0]; // first input idx
                onnx_arg_idx_[1] = (*idx)[1]; // first wgt idx
                onnx_arg_idx_[2] = (*idx)[2]; // first output idx
            }
            else {
                LOG_VERBOSE() << "Onnx arg idx attribute does not have the expected number of elements. "
                    << "Number passed: " << idx->size() << ", Expected: 3" << std::endl;
            }
        }
        else {
            LOG_VERBOSE() << "Onnx arg idx attribute not found or not of correct type."
                << std::endl;
        }

        if (read_single_string("layer_name", "Layer name", layer_name_)) {
            // File-based artifacts are keyed by a sanitised layer name.
            if (!use_inmem_) {
                replace_symbols(layer_name_);
            }
        }
        else {
            LOG_VERBOSE() << "LayerName attribute not found or not of correct type."
                << std::endl;
        }

        if (!read_single_string("tkey", "tkey", tkey_)) {
            LOG_VERBOSE() << "tkey attribute not found or not of correct type."
                << std::endl;
        }

        // Optional attributes: default to empty, no message when absent.
        bkend_ = "";
        read_single_string("bkend", "bkend", bkend_);
        conv_special_flag_ = "";
        read_single_string("conv_special_flag", "conv_special_flag", conv_special_flag_);

        // Buffer-object slot order depends on the backend.
        if (bkend_ == "mladf") {
            for (int i = 0; i < 5; i++) {
                BO_seq_[i] = MLADF_BO_SEQUENCE[i];
            }
        }
        else {
            for (int i = 0; i < 5; i++) {
                BO_seq_[i] = WAIC_BO_SEQUENCE[i];
            }
        }
    }

    /// @brief Write this op's weight (const) blob into `io` at offset 0.
    ///
    /// Only NPU ops have weights. In the in-memory flow the blob is produced by
    /// wgt_formatting_inmem from the tiling data. Otherwise it is read from
    /// "<layer consts dir>/wgt.bin". Failures are best-effort: any exception
    /// (missing file, empty in-memory blob) is reported with a printf and swallowed.
    /// @param io           Destination const buffer.
    /// @param const_params Unused here.
    /// @param attr         Unused here.
    template <typename InT, typename WtT, typename OutT>
    void subgraph_op<InT, WtT, OutT>::initialize_const_params(
        ConstBufferIO& io, const std::vector<Tensor>& const_params,
        const std::map<std::string, std::any>& attr) {
        if (support_optype_ != NPUOP) {
            return;
        }

        try {
            if (!use_inmem_) {
                // Pre-generated weights on disk.
                auto blob = ReadBinaryFile(get_instr_key(layer_name_) + "wgt" + ".bin", verbose_);
                io.write(0, blob.data(), blob.size());
                return;
            }

            auto consts_dir = std::filesystem::path(binfile_path_ + "/DataGen/Consts/").make_preferred();
            auto blob = wgt_formatting_inmem(consts_dir.u8string(), binfile_path_, tilings_data_, tkey_, layer_name_, verbose_, verbose_);
            LOG_VERBOSE() << layer_name_ << std::endl;
            LOG_VERBOSE() << "wgts_inmem = " << blob.size() << std::endl;
            if (blob.size() == 0) {
                std::cout << "Empty const data in node " << layer_name_ << std::endl;
                throw std::runtime_error("subgraph_op : Empty const data.");
            }
            io.write(0, blob.data(), blob.size());
        }
        catch (...) {
            printf("No wgt.bin for %s\n", layer_name_.c_str());
        }
    }

    /// @brief Return the transaction binary that drives this op on the device.
    ///
    /// Timer, preemption, PM-load and identity ops delegate to their dedicated
    /// generators. NPU ops either patch the transaction in memory
    /// (txn_update_inmem) or load "<layer consts dir>/txn.bin". Any other op
    /// kind yields an empty buffer.
    template <typename InT, typename WtT, typename OutT>
    const std::vector<uint8_t> subgraph_op<InT, WtT, OutT>::get_transaction_bin(
        std::vector<Tensor>& input, std::vector<Tensor>& output,
        const std::map<std::string, std::any>& attr) const {
        if (support_optype_ == RECORDTIME) {
            return timer_get_transaction_bin(input, output, attr);
        }
        if (support_optype_ == PREEMPTION) {
            return preemption_get_transaction_bin(input, output, attr);
        }
        if (support_optype_ == PMLOAD) {
            return pm_load_get_transaction_bin(input, output, attr, prebuilt_bin_dir_, verbose_);
        }
        if (support_optype_ == IDENTITY) {
            return identity_get_transaction_bin(input, output, attr);
        }
        if (support_optype_ != NPUOP) {
            return {};
        }

        if (use_inmem_) {
            auto consts_dir = binfile_path_ + "/DataGen/Consts/";
            return txn_update_inmem(consts_dir, binfile_path_, tilings_data_, tkey_, layer_name_, verbose_);
        }
        return ReadBinaryFile(get_instr_key(layer_name_) + "txn" + ".bin", verbose_);
    }

    /// @brief Load the super-kernel parameter blob for an NPU op.
    ///
    /// In-memory flow: "<binfile_path_><tkey_>/param.bin" (native separators).
    /// File flow: "<layer consts dir>/param.bin". Non-NPU ops have no parameters
    /// and get an empty buffer. `input`, `output` and `attr` are not read.
    template <typename InT, typename WtT, typename OutT>
    const std::vector<uint8_t> subgraph_op<InT, WtT, OutT>::get_super_kernel_params(
        std::vector<Tensor>& input, std::vector<Tensor>& output,
        const std::map<std::string, std::any>& attr) const {
        if (support_optype_ != NPUOP) {
            return {};
        }
        const std::string param_file = use_inmem_
            ? std::filesystem::path(binfile_path_ + tkey_ + "/param.bin").make_preferred().u8string()
            : get_instr_key(layer_name_) + "param" + ".bin";
        return ReadBinaryFile(param_file, verbose_);
    }

    /// @brief Parse the control-packet patch metadata (patch.json) for an NPU op.
    ///
    /// In-memory flow reads "<binfile_path_><tkey_>/patch.json". The file flow
    /// reads "<layer consts dir>/patch.json". For the "mladf" backend the metadata
    /// is combined with the control packets from get_ctrl_pkts.
    /// @return Patch entries; empty for non-NPU ops.
    /// @throws std::runtime_error "subgraph_op : No patch.json file." (with the
    ///         underlying cause appended when available) if the JSON cannot be
    ///         read or converted. A missing ctrl.bin reports get_ctrl_pkts' own error.
    template <typename InT, typename WtT, typename OutT>
    std::vector<CtrlPktPatchInfo>
        subgraph_op<InT, WtT, OutT>::get_ctrl_pkt_patch_info(
            std::vector<Tensor>& input, std::vector<Tensor>& output,
            const std::map<std::string, std::any>& attr) const {
        if (support_optype_ != NPUOP) {
            return {};
        }

        std::string patch_file;
        if (use_inmem_) {
            auto ps = std::filesystem::path(binfile_path_ + tkey_ + "/patch.json").make_preferred();
            patch_file = ps.u8string();
        }
        else {
            patch_file = get_instr_key(layer_name_) + "patch" + ".json";
        }

        const bool is_mladf = (bkend_ == "mladf");

        // Fetched outside the try below so a ctrl.bin failure keeps its own message
        // instead of being misreported as a missing patch.json.
        std::vector<uint8_t> ctrl_pkts;
        if (is_mladf) {
            ctrl_pkts = get_ctrl_pkts(input, output, attr);
        }

        try {
            json data = read_json_file(patch_file);
            if (is_mladf) {
                return ext_buf_json_to_ctrlpkt_patch_info(data, std::move(ctrl_pkts));
            }
            return extract_ctrlpkt_patch_info(data);
        }
        catch (const std::exception& e) {
            // Keep the runtime_error type callers expect, but preserve the cause.
            throw std::runtime_error(std::string("subgraph_op : No patch.json file. ") + e.what());
        }
        catch (...) {
            throw std::runtime_error("subgraph_op : No patch.json file.");
        }
    }

    /// @brief Load the control-packet binary for this op.
    ///
    /// NPU ops:
    ///   - In-memory flow reads "<binfile_path_><tkey_>/ctrl.bin".
    ///   - File flow reads "<layer consts dir>/ctrl.bin".
    /// PM-load ops: read "<prebuilt_bin_dir_>/pm_<pm_id>.bin", where pm_id is a
    /// uint8_t attribute. Other ops get an empty buffer.
    /// @throws std::runtime_error
    ///   - "subgraph_op : No ctrl.bin file." (plus the cause when available) if
    ///     an NPU op's file cannot be read.
    ///   - When pm_id is missing or is not stored as uint8_t.
    template <typename InT, typename WtT, typename OutT>
    std::vector<uint8_t> subgraph_op<InT, WtT, OutT>::get_ctrl_pkts(
        std::vector<Tensor>& input, std::vector<Tensor>& output,
        const std::map<std::string, std::any>& attr) const {
        if (support_optype_ == NPUOP) {
            try {
                if (use_inmem_) {
                    auto ps = std::filesystem::path(binfile_path_ + tkey_ + "/ctrl.bin").make_preferred();
                    return ReadBinaryFile(ps.u8string(), verbose_);
                }
                return ReadBinaryFile(get_instr_key(layer_name_) + "ctrl" + ".bin", verbose_);
            }
            catch (const std::exception& e) {
                throw std::runtime_error(std::string("subgraph_op : No ctrl.bin file. ") + e.what());
            }
            catch (...) {
                throw std::runtime_error("subgraph_op : No ctrl.bin file.");
            }
        }

        if (support_optype_ == PMLOAD) {
            const auto pm_it = attr.find("pm_id");
            if (pm_it == attr.end()) {
                throw std::runtime_error("Can't find pm_id in attrs");
            }
            // Checked cast: a mistyped pm_id becomes a descriptive runtime_error
            // instead of escaping as std::bad_any_cast.
            const auto* pm_id = std::any_cast<uint8_t>(&pm_it->second);
            if (pm_id == nullptr) {
                throw std::runtime_error("pm_id attribute is not of type uint8_t");
            }

            auto bin_path = std::filesystem::path(prebuilt_bin_dir_ + "/").make_preferred();
            std::string filename = bin_path.u8string() + "pm_" + std::to_string(static_cast<unsigned int>(*pm_id)) + ".bin";

            LOG_VERBOSE() << "Loading pm.bin from " << filename << std::endl;

            return ReadBinaryFile(filename, verbose_);
        }

        return {};
    }

    /// @brief Describe the buffer objects (BOs) this op binds, one OpArgMap per argument.
    ///
    /// Each entry is built as {arg type, BO slot from BO_seq_, onnx arg index,
    /// byte offset, byte size}. (Field meanings inferred from usage — confirm
    /// against the OpArgMap declaration.)
    ///   - NPU ops: inputs, weights (CONST_INPUT), outputs, super-kernel params,
    ///     control packets. A weight size that cannot be determined becomes 0.
    ///   - PM-load ops: a single CTRL_PKT_BIN entry.
    ///   - Identity ops: inputs then outputs.
    ///   - Others: empty.
    template <typename InT, typename WtT, typename OutT>
    std::vector<OpArgMap> subgraph_op<InT, WtT, OutT>::get_buffer_reqs(
        std::vector<Tensor>& input, std::vector<Tensor>& output,
        const std::map<std::string, std::any>& attr) const {
        std::vector<OpArgMap> arg_map;

        // One INPUT entry per graph input, packed back-to-back in slot BO_seq_[1].
        // Size = element bytes x product of that input's dims.
        auto append_inputs = [&]() {
            size_t input_offset = 0;
            for (size_t n = 0; n < number_inputs_; n++) {
                size_t input_bo_size = get_size_of_type(input_datatype_[n]);
                for (size_t i = 0; i < input_shape_size_; i++) {
                    input_bo_size *= inputShape_[n * input_shape_size_ + i];
                }
                arg_map.push_back(OpArgMap{ OpArgMap::OpArgType::INPUT, BO_seq_[1], onnx_arg_idx_[0] + n,
                                            input_offset, input_bo_size });
                input_offset += input_bo_size;
            }
        };

        // One OUTPUT entry per graph output, packed back-to-back in slot BO_seq_[0].
        auto append_outputs = [&]() {
            size_t output_offset = 0;
            for (size_t n = 0; n < number_outputs_; n++) {
                size_t output_bo_size = get_size_of_type(output_datatype_[n]);
                for (size_t i = 0; i < output_shape_size_; i++) {
                    output_bo_size *= outputShape_[n * output_shape_size_ + i];
                }
                arg_map.push_back(OpArgMap{ OpArgMap::OpArgType::OUTPUT, BO_seq_[0], onnx_arg_idx_[2] + n,
                                            output_offset, output_bo_size });
                output_offset += output_bo_size;
            }
        };

        if (support_optype_ == NPUOP) {
            // Weight size. Any failure (missing file, in-memory error) falls back to a
            // zero-sized dummy weight buffer rather than aborting.
            size_t wgt_size = 0;
            try {
                if (use_inmem_) {
                    wgt_size = get_wgts_size(binfile_path_, tilings_data_, tkey_);
                    LOG_VERBOSE() << layer_name_ << std::endl;
                    LOG_VERBOSE() << "wgts_size = " << wgt_size << std::endl;
                    if (wgt_size == 0) {
                        std::cout << "wgts size is 0, from " << layer_name_ << std::endl;
                    }
                }
                else {
                    auto wgt_bo_key = get_instr_key(layer_name_) + "wgt";
                    wgt_size = ReadBinaryFile(wgt_bo_key + ".bin", verbose_).size();
                }
            }
            catch (...) {
                wgt_size = 0; // dummy wgt
            }
            const size_t super_kernel_size = get_super_kernel_params(input, output, attr).size();
            const size_t ctrl_pkt_size = get_ctrl_pkts(input, output, attr).size();

            append_inputs();
            arg_map.push_back(OpArgMap{ OpArgMap::OpArgType::CONST_INPUT, BO_seq_[2], onnx_arg_idx_[1], 0,
                                        wgt_size });
            append_outputs();
            arg_map.push_back(OpArgMap{ OpArgMap::OpArgType::CONST_KERNEL_PARAM_INPUT, BO_seq_[3], 0, 0,
                                        super_kernel_size });
            arg_map.push_back(OpArgMap{ OpArgMap::OpArgType::CTRL_PKT_BIN, BO_seq_[4], 0, 0, ctrl_pkt_size });
            return arg_map;
        }

        if (support_optype_ == PMLOAD) {
            const size_t ctrl_pkt_size = get_ctrl_pkts(input, output, attr).size();
            arg_map.push_back(OpArgMap{ OpArgMap::OpArgType::CTRL_PKT_BIN, 0, 0, 0, ctrl_pkt_size });
            return arg_map;
        }

        if (support_optype_ == IDENTITY) {
            append_inputs();
            append_outputs();
            return arg_map;
        }

        return {};
    }

    /// @brief Interpret a rank-1..4 shape as (N, H, W, C).
    ///
    /// Dimensions are right-aligned: missing leading dims are 1. For example,
    /// {H, W, C} -> (1, H, W, C) and {C} -> (1, 1, 1, C).
    /// @throws std::runtime_error for an empty shape or a rank above 4.
    template <typename InT, typename WtT, typename OutT>
    std::tuple<size_t, size_t, size_t, size_t>
        subgraph_op<InT, WtT, OutT>::extract_NHWC(const std::vector<size_t> shape) {
        const size_t rank = shape.size();
        if (rank == 0 || rank > 4) {
            throw std::runtime_error("Unsupported padded input shape");
        }

        std::array<size_t, 4> nhwc{ 1, 1, 1, 1 };
        const size_t lead = 4 - rank;
        for (size_t d = 0; d < rank; ++d) {
            nhwc[lead + d] = shape[d];
        }
        return std::make_tuple(nhwc[0], nhwc[1], nhwc[2], nhwc[3]);
    }

    /// @brief Interpret a shape as (batch B, rows M, cols N).
    ///
    /// Leading unit dimensions are dropped first; dims from the first non-1
    /// onward are kept. The remaining rank then selects:
    ///   - 4 -> (d0*d1, d2, d3)
    ///   - 3 -> (d0, d1, d2)
    ///   - 2 -> (1, d0, d1)
    ///   - 1 -> (1, 1, d0)
    /// A non-empty shape made only of 1s is a single element and yields (1, 1, 1).
    /// @throws std::runtime_error for an empty shape or more than 4 significant dims.
    template <typename InT, typename WtT, typename OutT>
    std::tuple<size_t, size_t, size_t>
        subgraph_op<InT, WtT, OutT>::extract_BMN(const std::vector<size_t> shape) {
        size_t first = 0;
        while (first < shape.size() && shape[first] == 1) {
            ++first;
        }

        std::vector<size_t> dims;
        for (size_t i = first; i < shape.size(); ++i) {
            dims.push_back(shape[i]);
        }
        // All-ones shape: stripping left nothing, but it still describes one element.
        if (dims.empty() && !shape.empty()) {
            dims.push_back(1);
        }

        switch (dims.size()) {
        case 4:
            return std::make_tuple(dims[0] * dims[1], dims[2], dims[3]);
        case 3:
            return std::make_tuple(dims[0], dims[1], dims[2]);
        case 2:
            return std::make_tuple(size_t{ 1 }, dims[0], dims[1]);
        case 1:
            return std::make_tuple(size_t{ 1 }, size_t{ 1 }, dims[0]);
        default:
            throw std::runtime_error("Unsupported raw input for batch  shape");
        }
    }

    /// @brief Interpret a rank-1..4 shape as (B, M, N) for matmul-style layouts.
    ///
    /// Mapping by rank:
    ///   - 4: (s0, s1*s2, s3)
    ///   - 3: (1, s1, s2) when s0 == 1, otherwise (1, s0, s1*s2)
    ///   - 2: (1, s0, s1)
    ///   - 1: (1, 1, s0)
    /// @throws std::runtime_error for an empty shape or a rank above 4.
    template <typename InT, typename WtT, typename OutT>
    std::tuple<size_t, size_t, size_t>
        subgraph_op<InT, WtT, OutT>::extract_MN(const std::vector<size_t> shape) {
        switch (shape.size()) {
        case 4:
            return std::make_tuple(shape[0], shape[1] * shape[2], shape[3]);
        case 3:
            if (shape[0] == 1) {
                return std::make_tuple(shape[0], shape[1], shape[2]);
            }
            return std::make_tuple(size_t{ 1 }, shape[0], shape[1] * shape[2]);
        case 2:
            return std::make_tuple(size_t{ 1 }, shape[0], shape[1]);
        case 1:
            return std::make_tuple(size_t{ 1 }, size_t{ 1 }, shape[0]);
        default:
            throw std::runtime_error("Unsupported raw input shape");
        }
    }

    /// @brief Extract one tensor's shape from a flat multi-tensor shape list and
    ///        reduce it to 4-D.
    /// @param shape         Concatenated shapes of `tensor_number` equal-rank tensors.
    /// @param tensor_number Number of tensors described by `shape` (must be > 0).
    /// @param tensor_idx    Which tensor to extract (must be < tensor_number).
    /// @return reduce_shape_to_4d applied to the tensor's dims, with leading 1s removed.
    /// @throws std::runtime_error when tensor_number is 0 or tensor_idx is out of range.
    template <typename InT, typename WtT, typename OutT>
    std::vector<size_t>
        subgraph_op<InT, WtT, OutT>::get_pad_shape(const std::vector<int64_t> shape,
            size_t tensor_number,
            size_t tensor_idx) {
        // Checked before the division below, which would otherwise be by zero or
        // index past the end of `shape`.
        if (tensor_number == 0 || tensor_idx >= tensor_number) {
            throw std::runtime_error("subgraph_op : invalid tensor_number/tensor_idx in get_pad_shape.");
        }
        assert(shape.size() % tensor_number == 0);

        const size_t dim = shape.size() / tensor_number;
        const size_t begin = dim * tensor_idx;
        const size_t end = begin + dim;

        size_t i = begin;
        // Skip this tensor's leading unit dims; later 1s are kept.
        while (i < end && shape[i] == 1) {
            ++i;
        }

        std::vector<size_t> pad_shape;
        pad_shape.reserve(end - i);
        for (; i < end; ++i) {
            pad_shape.push_back(static_cast<size_t>(shape[i]));
        }

        return reduce_shape_to_4d(pad_shape);
    }

    /**
     * Copy the NPU output buffer for output `tensor_idx` into the CPU tensor `out_tensor`.
     *
     * The hardware buffer is read as channel-last data (channel is the fastest-varying index).
     *  - If `sz` equals the ONNX byte size: a plain memcpy, or an NHWC -> NCHW transpose when
     *    the output format contains "NCHW" and SKIP_NCHW_CONVERSION is "0".
     *  - Otherwise: each channel row in the hardware buffer is taken to be padded up to a
     *    multiple of 8 elements, and that padding is dropped (optionally combined with the
     *    NHWC -> NCHW transpose).
     * When the format contains "bf2f", elements are copied as sizeof(OutT) bytes; no numeric
     * bf16 -> float conversion is performed in this function.
     *
     * @param out_tensor  destination; its shape defines the CPU (ONNX) layout.
     * @param hw_out_ptr  NPU output buffer.
     * @param sz          size of hw_out_ptr in bytes.
     * @param tensor_idx  index into output_format_, output_datatype_ and outputShape_.
     * @param attr        not used by this function.
     * @throws std::runtime_error if sz is smaller than the padded NPU shape or the ONNX shape,
     *         if the padded layout does not fit the NPU shape, or if the tensor rank is too
     *         small for its format.
     */
    template <typename InT, typename WtT, typename OutT>
    void subgraph_op<InT, WtT, OutT>::format_output(
        const Tensor& out_tensor, void* hw_out_ptr, size_t sz, size_t tensor_idx,
        const std::map<std::string, std::any>& attr) {
        auto format_output_start = GET_TIMESTAMP();

        const std::string skip_nchw_conversion = get_env_var("SKIP_NCHW_CONVERSION", "0");
        const bool to_nchw =
            output_format_[tensor_idx].find("NCHW") != std::string::npos && skip_nchw_conversion == "0";

        // CPU (ONNX) shape. The channel axis is the last dim, or for NCHW outputs the dim
        // 2 (rank < 4) or 3 (rank >= 4) positions from the end.
        const auto& cpu_output_shape = out_tensor.shape;
        const size_t rank = cpu_output_shape.size();
        const size_t ch_from_end = !to_nchw ? 1 : (rank < 4 ? 2 : 3);
        if (rank < ch_from_end) {
            // Without this, rank - ch_from_end would underflow and index far out of bounds.
            throw std::runtime_error("subgraph_op : output tensor rank is too small for its format.");
        }
        const size_t ch_idx = rank - ch_from_end;
        const size_t cpu_ch_size = static_cast<size_t>(cpu_output_shape[ch_idx]);

        size_t elem_size = get_size_of_type(output_datatype_[tensor_idx]);
        if (output_format_[tensor_idx].find("bf2f") != std::string::npos) {
            elem_size = sizeof(OutT);
        }
        const size_t cpu_output_size =
            std::accumulate(cpu_output_shape.begin(), cpu_output_shape.end(), size_t{1},
                std::multiplies<size_t>()) * elem_size;

        // Element count of the padded NPU shape for this output.
        size_t hw_elems = 1;
        for (size_t i = 0; i < output_shape_size_; ++i) {
            hw_elems *= static_cast<size_t>(outputShape_[tensor_idx * output_shape_size_ + i]);
        }
        const size_t npu_output_size = hw_elems * elem_size;

        if (npu_output_size > sz) {
            throw std::runtime_error("subgraph_op : The size of hw_out is not correct.");
        }
        if (cpu_output_size > sz) {
            throw std::runtime_error("subgraph_op : The size of hw_out is smaller than onnx shape.");
        }

        const uint8_t* src = static_cast<const uint8_t*>(hw_out_ptr);
        uint8_t* dst = reinterpret_cast<uint8_t*>(out_tensor.data);

        // outer_size: product of dims before the channel axis.
        // inner_size: product of dims after it (1 unless converting to NCHW).
        size_t outer_size = 1;
        size_t inner_size = 1;
        for (size_t i = 0; i < rank; ++i) {
            if (i < ch_idx) {
                outer_size *= static_cast<size_t>(cpu_output_shape[i]);
            }
            else if (i > ch_idx) {
                inner_size *= static_cast<size_t>(cpu_output_shape[i]);
            }
        }
        const size_t pixel_count = outer_size * inner_size;

        // Transpose channel-last source rows of `src_ch_stride` elements into dense NCHW.
        // Each outer index occupies inner_size * src_ch_stride source elements.
        auto nhwc_to_nchw = [&](size_t src_ch_stride) {
            const size_t src_batch_stride = inner_size * src_ch_stride;
            const size_t dst_batch_stride = cpu_ch_size * inner_size;
            for (size_t n = 0; n < outer_size; ++n) {
                for (size_t c = 0; c < cpu_ch_size; ++c) {
                    for (size_t p = 0; p < inner_size; ++p) {
                        const size_t src_elem = n * src_batch_stride + p * src_ch_stride + c;
                        const size_t dst_elem = n * dst_batch_stride + c * inner_size + p;
                        std::memcpy(dst + dst_elem * elem_size, src + src_elem * elem_size, elem_size);
                    }
                }
            }
        };

        if (cpu_output_size == sz) {
            if (to_nchw) {
                auto nchw_no_depad_start = GET_TIMESTAMP();
                nhwc_to_nchw(cpu_ch_size);
                auto nchw_no_depad_end = GET_TIMESTAMP();
                LOG_GLOBAL_TIMING("SUBGRAPH::NCHW_format_no_depad", nchw_no_depad_start, nchw_no_depad_end);
            }
            else {
                auto direct_copy_start = GET_TIMESTAMP();
                std::memcpy(dst, src, cpu_output_size);
                auto direct_copy_end = GET_TIMESTAMP();
                LOG_GLOBAL_TIMING("SUBGRAPH::output_direct_copy", direct_copy_start, direct_copy_end);
            }
        }
        else {
            // Only the channel (innermost hardware) dim is depadded, rounded up to a multiple of 8.
            const size_t padded_ch_size = ((cpu_ch_size + 7) / 8) * 8;
            if (pixel_count * padded_ch_size > hw_elems) {
                std::cout << "Padded size: " << pixel_count * padded_ch_size << ", HW size: " << hw_elems << std::endl;
                std::cout << "Padded channel size: " << padded_ch_size << std::endl;
                throw std::runtime_error("Output Tensor pad size doesn't match.");
            }

            if (to_nchw) {
                auto nchw_depad_start = GET_TIMESTAMP();
                nhwc_to_nchw(padded_ch_size);
                auto nchw_depad_end = GET_TIMESTAMP();
                LOG_GLOBAL_TIMING("SUBGRAPH::NCHW_format_depad", nchw_depad_start, nchw_depad_end);
            }
            else {
                auto depad_start = GET_TIMESTAMP();
                const size_t row_bytes = cpu_ch_size * elem_size;
                const size_t padded_row_bytes = padded_ch_size * elem_size;
                for (size_t p = 0; p < pixel_count; ++p) {
                    std::memcpy(dst + p * row_bytes, src + p * padded_row_bytes, row_bytes);
                }
                auto depad_end = GET_TIMESTAMP();
                LOG_GLOBAL_TIMING("SUBGRAPH::depad_only", depad_start, depad_end);
            }
        }

        auto format_output_end = GET_TIMESTAMP();
        LOG_GLOBAL_TIMING("SUBGRAPH::format_output_total", format_output_start, format_output_end);
    }

    /**
     * Copy the CPU tensor `in_tensor` into the NPU input buffer for input `tensor_idx`.
     *
     * - conv_special_flag_ == "conv7x7_fold": delegates to fold_conv_ifm_7x7_pad().
     * - If `sz` equals the ONNX byte size: a plain memcpy, or an NCHW -> NHWC transpose when
     *   the input format contains "NCHW" and SKIP_NCHW_CONVERSION is "0".
     * - Otherwise: each channel row written to the hardware buffer is padded up to a multiple
     *   of 8 elements (optionally combined with the NCHW -> NHWC transpose).
     * No float -> bf16 conversion is performed in this function.
     *
     * @param in_tensor   source; its shape defines the CPU (ONNX) layout.
     * @param hw_in_ptr   NPU input buffer.
     * @param sz          size of hw_in_ptr in bytes.
     * @param tensor_idx  index into input_format_, input_datatype_ and inputShape_.
     * @param attr        not used by this function.
     * @throws std::runtime_error if sz is smaller than the padded NPU shape or the ONNX shape,
     *         if the padded layout does not fit the NPU shape, or if the tensor rank is too
     *         small for its format.
     */
    template <typename InT, typename WtT, typename OutT>
    void subgraph_op<InT, WtT, OutT>::format_input(
        const Tensor& in_tensor, void* hw_in_ptr, size_t sz, size_t tensor_idx,
        const std::map<std::string, std::any>& attr) {
        auto format_input_start = GET_TIMESTAMP();

        const std::string skip_nchw_conversion = get_env_var("SKIP_NCHW_CONVERSION", "0");
        const bool from_nchw =
            input_format_[tensor_idx].find("NCHW") != std::string::npos && skip_nchw_conversion == "0";

        // CPU (ONNX) shape. The channel axis is the last dim, or for NCHW inputs the dim
        // 2 (rank < 4) or 3 (rank >= 4) positions from the end.
        const auto& cpu_input_shape = in_tensor.shape;
        const size_t rank = cpu_input_shape.size();
        const size_t ch_from_end = !from_nchw ? 1 : (rank < 4 ? 2 : 3);
        if (rank < ch_from_end) {
            // Without this, rank - ch_from_end would underflow and index far out of bounds.
            throw std::runtime_error("subgraph_op : input tensor rank is too small for its format.");
        }
        const size_t ch_idx = rank - ch_from_end;
        const size_t cpu_ch_size = static_cast<size_t>(cpu_input_shape[ch_idx]);

        const size_t elem_size = get_size_of_type(input_datatype_[tensor_idx]);
        const size_t cpu_input_size =
            std::accumulate(cpu_input_shape.begin(), cpu_input_shape.end(), size_t{1},
                std::multiplies<size_t>()) * elem_size;

        // Element count of the padded NPU shape for this input.
        size_t hw_elems = 1;
        for (size_t i = 0; i < input_shape_size_; ++i) {
            hw_elems *= static_cast<size_t>(inputShape_[tensor_idx * input_shape_size_ + i]);
        }
        const size_t npu_input_size = hw_elems * elem_size;

        if (npu_input_size > sz) {
            throw std::runtime_error("subgraph_op : The size of hw_in is not correct.");
        }
        if (cpu_input_size > sz) {
            throw std::runtime_error("subgraph_op : The size of hw_in is smaller than onnx shape.");
        }

        const uint8_t* src = reinterpret_cast<const uint8_t*>(in_tensor.data);
        uint8_t* dst = static_cast<uint8_t*>(hw_in_ptr);

        if (conv_special_flag_ == "conv7x7_fold") {
            auto conv7x7_fold_start = GET_TIMESTAMP();
            auto raw_input_shape = in_tensor.shape;
            auto input_shape = reduce_shape_to_4d(raw_input_shape);
            std::vector<size_t> input_shape_nhwc = convert_nchw_shape(input_shape, input_format_[tensor_idx]);
            // Fold geometry: the hardware channel dim (inputShape_[3]) holds three folded
            // copies of the padded channel count, and the width is halved.
            // NOTE(review): inputShape_[3] ignores tensor_idx; this assumes the folded conv
            // input is tensor 0 — confirm.
            int YI = input_shape_nhwc[1];
            int XI = input_shape_nhwc[2];
            int CI = input_shape_nhwc[3];
            int CI_padded = inputShape_[3] / 3;
            int Cip_fold = inputShape_[3];
            int XI_fold = XI / 2;
            TestConfig cfg_pad, cfg_fold;

            cfg_pad.Ci = CI_padded;
            cfg_pad.Xi = XI;
            cfg_pad.Yi = YI;
            cfg_fold.Ci = Cip_fold;
            cfg_fold.Yi = YI;
            cfg_fold.Xi = XI_fold;

            // Padding value used while building the folded input.
            InT ifm_zero = 0;

            const InT* cpu_in = reinterpret_cast<const InT*>(in_tensor.data);
            InT* out = static_cast<InT*>(hw_in_ptr);
            fold_conv_ifm_7x7_pad<InT>(cpu_in, ifm_zero, CI, cfg_pad, cfg_fold, out);
            auto conv7x7_fold_end = GET_TIMESTAMP();
            LOG_GLOBAL_TIMING("SUBGRAPH::conv7x7_fold", conv7x7_fold_start, conv7x7_fold_end);
        }
        else {
            // outer_size: product of dims before the channel axis.
            // inner_size: product of dims after it (1 unless converting from NCHW).
            size_t outer_size = 1;
            size_t inner_size = 1;
            for (size_t i = 0; i < rank; ++i) {
                if (i < ch_idx) {
                    outer_size *= static_cast<size_t>(cpu_input_shape[i]);
                }
                else if (i > ch_idx) {
                    inner_size *= static_cast<size_t>(cpu_input_shape[i]);
                }
            }
            const size_t pixel_count = outer_size * inner_size;

            // Transpose dense NCHW source into channel-last rows of `dst_ch_stride` elements.
            // Each outer index occupies inner_size * dst_ch_stride destination elements.
            auto nchw_to_nhwc = [&](size_t dst_ch_stride) {
                const size_t src_batch_stride = cpu_ch_size * inner_size;
                const size_t dst_batch_stride = inner_size * dst_ch_stride;
                for (size_t n = 0; n < outer_size; ++n) {
                    for (size_t c = 0; c < cpu_ch_size; ++c) {
                        for (size_t p = 0; p < inner_size; ++p) {
                            const size_t dst_elem = n * dst_batch_stride + p * dst_ch_stride + c;
                            const size_t src_elem = n * src_batch_stride + c * inner_size + p;
                            std::memcpy(dst + dst_elem * elem_size, src + src_elem * elem_size, elem_size);
                        }
                    }
                }
            };

            if (cpu_input_size == sz) {
                if (from_nchw) {
                    auto nchw_no_pad_start = GET_TIMESTAMP();
                    nchw_to_nhwc(cpu_ch_size);
                    auto nchw_no_pad_end = GET_TIMESTAMP();
                    LOG_GLOBAL_TIMING("SUBGRAPH::NCHW_format_no_pad", nchw_no_pad_start, nchw_no_pad_end);
                }
                else {
                    auto direct_copy_start = GET_TIMESTAMP();
                    std::memcpy(dst, src, cpu_input_size);
                    auto direct_copy_end = GET_TIMESTAMP();
                    LOG_GLOBAL_TIMING("SUBGRAPH::input_direct_copy", direct_copy_start, direct_copy_end);
                }
            }
            else {
                // Only the channel (innermost hardware) dim is padded, rounded up to a multiple of 8.
                // NOTE(review): only lanes [0, cpu_ch_size) of each padded row are written; lanes
                // [cpu_ch_size, padded_ch_size) keep whatever hw_in_ptr already held — confirm the
                // caller zeroes the buffer or the kernel ignores those lanes.
                const size_t padded_ch_size = ((cpu_ch_size + 7) / 8) * 8;
                if (pixel_count * padded_ch_size > hw_elems) {
                    std::cout << "Padded size: " << pixel_count * padded_ch_size << ", HW size: " << hw_elems << std::endl;
                    std::cout << "Padded channel size: " << padded_ch_size << std::endl;
                    throw std::runtime_error("Input Tensor pad size doesn't match.");
                }

                if (from_nchw) {
                    auto nchw_pad_start = GET_TIMESTAMP();
                    nchw_to_nhwc(padded_ch_size);
                    auto nchw_pad_end = GET_TIMESTAMP();
                    LOG_GLOBAL_TIMING("SUBGRAPH::NCHW_format_pad", nchw_pad_start, nchw_pad_end);
                }
                else {
                    auto pad_start = GET_TIMESTAMP();
                    const size_t row_bytes = cpu_ch_size * elem_size;
                    const size_t padded_row_bytes = padded_ch_size * elem_size;
                    for (size_t p = 0; p < pixel_count; ++p) {
                        std::memcpy(dst + p * padded_row_bytes, src + p * row_bytes, row_bytes);
                    }
                    auto pad_end = GET_TIMESTAMP();
                    LOG_GLOBAL_TIMING("SUBGRAPH::pad_only", pad_start, pad_end);
                }
            }
        }

        auto format_input_end = GET_TIMESTAMP();
        LOG_GLOBAL_TIMING("SUBGRAPH::format_input_total", format_input_start, format_input_end);
    }

    // Explicit instantiations: the member definitions above live in this .cpp, so every
    // <InT, WtT, OutT> combination used elsewhere must be instantiated here.

    // Integer-only configurations.
    template class subgraph_op<int8_t, int8_t, int8_t>;
    template class subgraph_op<uint8_t, uint8_t, uint8_t>;
    template class subgraph_op<int16_t, int8_t, int16_t>;
    template class subgraph_op<uint16_t, uint8_t, uint16_t>;

    // Configurations with a float input or a float output.
    template class subgraph_op<float, uint8_t, uint16_t>;
    template class subgraph_op<uint16_t, uint8_t, float>;

} // namespace waic_runner
