#include "op_utils.hpp"

#include <algorithm>
#include <iterator>
#include <stdexcept>
#include <string>
namespace waic_runner {
    // Comparator ordering OpArgMap entries by ascending `xrt_arg_idx`
    // (a strict weak ordering, suitable for std::sort / std::max_element).
    // No other field takes part in the comparison.
    static auto OpArgMapLT = [](const OpArgMap& a, const OpArgMap& b) -> bool {
        const bool a_before_b = a.xrt_arg_idx < b.xrt_arg_idx;
        return a_before_b;
    };

    // Input argmap contains multiple args with different xrt_arg_ids.
    // Partition it to multiple slots based on each xrt_arg_id
    // And sort each partition for binary search.
    std::vector<std::vector<OpArgMap>>
        partition_argmap(const std::vector<OpArgMap>& arg_map) {
        std::vector<std::vector<OpArgMap>> res;
        if (arg_map.size() == 0) {
            // std::cout << "Operator with arg_map size 0, skipping partition_argmap" 
            //     << std::endl;
            return res;
        }
        auto max_xrt_arg_id =
            *std::max_element(arg_map.begin(), arg_map.end(), OpArgMapLT);
        for (size_t i = 0; i <= max_xrt_arg_id.xrt_arg_idx; ++i) {
            std::vector<OpArgMap> args;
            std::copy_if(arg_map.begin(), arg_map.end(), std::back_inserter(args),
                [i](const OpArgMap& arg) { return arg.xrt_arg_idx == i; });
            std::sort(args.begin(), args.end(), OpArgMapLT);
            res.push_back(std::move(args));
        }
        return res;
    }

    // Finds the block (OpArgMap) of partition `argmaps[xrt_arg_id]` that `offset`
    // falls into: the last entry whose start `offset` is <= the requested offset.
    // Only start offsets are compared; the block's end is not checked here.
    //
    // Precondition: each partition is sorted by ascending `offset`; otherwise
    // the binary search result is meaningless.
    //
    // Returns a reference into `argmaps`, valid while `argmaps` is alive and
    // unmodified.
    // Throws std::out_of_range if `xrt_arg_id` has no partition, or if the
    // partition is empty or every entry in it starts after `offset`.
    const OpArgMap& find_op_arg(const std::vector<std::vector<OpArgMap>>& argmaps,
        size_t xrt_arg_id, size_t offset) {
        const auto& partition = argmaps.at(xrt_arg_id);

        // First entry that starts strictly after `offset`; the containing block
        // (if any) is the one immediately before it.
        const auto next = std::upper_bound(
            partition.begin(), partition.end(), offset,
            [](size_t val, const OpArgMap& arg) { return val < arg.offset; });

        // Previously this case underflowed `idx - 1` and relied on .at() throwing
        // an uninformative out_of_range; report it explicitly instead.
        if (next == partition.begin()) {
            throw std::out_of_range(
                "find_op_arg: no OpArgMap block for xrt_arg_id " +
                std::to_string(xrt_arg_id) + " at offset " + std::to_string(offset));
        }
        return *std::prev(next);
    }
    
    // Reads attr[key] as a list of offsets converted to size_t.
    //
    // Returns the converted values when `key` is present and holds exactly a
    // std::vector<int>. Otherwise (missing key, or any other stored type) it
    // returns an empty vector. Each int is converted to size_t the same way a
    // plain assignment would, so a negative value wraps modulo 2^N.
    // NOTE(review): confirm negative offsets can never be stored here.
    static std::vector<size_t> read_offset_attr(
        const std::map<std::string, std::any>& attr, const std::string& key) {
        // One lookup instead of count() followed by two at() calls.
        const auto it = attr.find(key);
        if (it == attr.end()) {
            return {};
        }
        // The pointer form of any_cast returns nullptr unless the stored type is
        // exactly std::vector<int>, i.e. the same check as type() == typeid(...).
        const auto* values = std::any_cast<std::vector<int>>(&it->second);
        if (values == nullptr) {
            return {};
        }
        return std::vector<size_t>(values->begin(), values->end());
    }

    // Returns the "runtime_in_offset" attribute as size_t offsets, or an empty
    // vector if it is absent or not stored as std::vector<int>.
    std::vector<size_t> get_runtime_in_offset(const std::map<std::string, std::any>& attr) {
        return read_offset_attr(attr, "runtime_in_offset");
    }

    // Returns the "runtime_out_offset" attribute as size_t offsets, or an empty
    // vector if it is absent or not stored as std::vector<int>.
    std::vector<size_t> get_runtime_out_offset(const std::map<std::string, std::any>& attr) {
        return read_offset_attr(attr, "runtime_out_offset");
    }

}
