#include "meta_utils.hpp"

#include <algorithm>
#include <map>
#include <numeric>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>
namespace waic_runner {
    // Maps an OpArgType to the key this file uses to look up the matching
    // entry in Metadata::fused_tensors.
    //   INPUT                    -> "in_onnx"
    //   OUTPUT                   -> "out_onnx"
    //   SCRATCH_PAD              -> "scratch"
    //   CONST_INPUT              -> "const"
    //   CONST_KERNEL_PARAM_INPUT -> "super_instr"
    //   CTRL_PKT_BIN             -> "ctrl_pkt"
    // Throws std::runtime_error for any other enumerator value.
    std::string convert_argtype_to_string(OpArgMap::OpArgType arg_type) {
        switch (arg_type) {
        case OpArgMap::OpArgType::INPUT:
            return "in_onnx";
        case OpArgMap::OpArgType::OUTPUT:
            return "out_onnx";
        case OpArgMap::OpArgType::SCRATCH_PAD:
            return "scratch";
        case OpArgMap::OpArgType::CONST_INPUT:
            return "const";
        case OpArgMap::OpArgType::CONST_KERNEL_PARAM_INPUT:
            return "super_instr";
        case OpArgMap::OpArgType::CTRL_PKT_BIN:
            return "ctrl_pkt";
        default:
            break;
        }
        // Reached only for enumerators without a fused-buffer key.
        throw std::runtime_error("Invalid arg_type conversion to string");
    }

    // Builds a text report of meta.partitions: the partition count, one line
    // per partition with its op range and pdi_id ("pm_id"), then the sorted
    // list of distinct pdi_ids.
    std::string MetaUtils::get_partition_pmid(const Metadata& meta) {
        const auto& parts = meta.partitions;

        std::ostringstream report;
        report << "Summary of partitions\n"
               << "Total number of partitions : " << parts.size() << "\n";

        // Collect ids while listing partitions. The id is held as uint16_t so
        // a narrow integer pdi_id is streamed as a number, not a character.
        std::vector<uint16_t> ids;
        ids.reserve(parts.size());
        for (const auto& p : parts) {
            const uint16_t id = p.pdi_id;
            ids.push_back(id);
            report << "range: " << p.op_range.first
                   << ": " << p.op_range.second
                   << ", " << "pm_id: " << id << "\n";
        }

        // Sort, then drop adjacent duplicates (erase-unique idiom).
        std::sort(ids.begin(), ids.end());
        ids.erase(std::unique(ids.begin(), ids.end()), ids.end());

        report << "Unique pm_ids: " << "\n";
        for (const uint16_t id : ids) {
            report << id << ", ";
        }
        report << "\n";
        return report.str();
    }

    // Builds a human-readable report of `meta`: total op count, a histogram
    // of op types, the summed size of all fused_tensors entries, and a
    // per-buffer size breakdown. Each breakdown line uses fused_tensors.at(),
    // which throws if any of the six buffer keys below is missing.
    std::string MetaUtils::get_summary(const Metadata& meta) {
        std::ostringstream oss;
        oss << "Summary of Metadata\n";
        oss << "-------------------\n";
        oss << "Total number of Ops : " << meta.op_list.size() << "\n";

        // Op-type histogram; std::map keeps the listing sorted by type name.
        std::map<std::string, size_t> op_count;
        for (const auto& op : meta.op_list) {
            op_count[op.type]++;
        }
        for (const auto& [op_type, cnt] : op_count) {
            oss << "  #" << op_type << " : " << cnt << "\n";
        }

        // Sum of `size` over every fused_tensors entry, whatever its key.
        // NOTE(review): get_tensors() looks up "in_onnx"/"out_onnx" while the
        // breakdown below reads "in"/"out"; if fused_tensors holds both kinds
        // of entry, both contribute here — confirm that is intended.
        const size_t total_mem = std::accumulate(
            meta.fused_tensors.begin(), meta.fused_tensors.end(), size_t{ 0 },
            [](size_t accum, const auto& item) { return accum + item.second.size; });

        oss << "\n";
        oss << "Total Device Memory (B) : " << total_mem << "\n";

        // Per-buffer breakdown as {printed label, fused_tensors key}.
        struct MemEntry {
            const char* label;
            const char* key;
        };
        static constexpr MemEntry kMemEntries[] = {
            { "Input", "in" },
            { "Output", "out" },
            { "Scratch", "scratch" },
            { "Const", "const" },
            { "SuperKernel", "super_instr" },
            { "Control packets", "ctrl_pkt" },
        };
        for (const auto& entry : kMemEntries) {
            oss << "  " << entry.label << " Memory (B) : "
                << meta.fused_tensors.at(entry.key).size << "\n";
        }
        oss << "-------------------\n";

        return oss.str();
    }

    size_t MetaUtils::get_num_inputs(const Metadata& meta) {
        return MetaUtils::get_num_tensors(meta, OpArgMap::OpArgType::INPUT);
    }

    size_t MetaUtils::get_num_outputs(const Metadata& meta) {
        return MetaUtils::get_num_tensors(meta, OpArgMap::OpArgType::OUTPUT);
    }

    std::vector<Tensor> MetaUtils::get_input_tensors(const Metadata& meta) {
        return MetaUtils::get_tensors(meta, OpArgMap::OpArgType::INPUT);
    }

    std::vector<std::string> MetaUtils::get_input_files(const Metadata& meta) {
        return MetaUtils::get_file_names(meta, OpArgMap::OpArgType::INPUT);
    }

    std::vector<Metadata::OpInfo> MetaUtils::get_input_info(const Metadata& meta) {
        return MetaUtils::get_op_info(meta, OpArgMap::OpArgType::INPUT);
    }

    std::vector<Tensor> MetaUtils::get_output_tensors(const Metadata& meta) {
        return MetaUtils::get_tensors(meta, OpArgMap::OpArgType::OUTPUT);
    }

    std::vector<std::string> MetaUtils::get_output_files(const Metadata& meta) {
        return MetaUtils::get_file_names(meta, OpArgMap::OpArgType::OUTPUT);
    }

    std::vector<Metadata::OpInfo> MetaUtils::get_output_info(const Metadata& meta) {
        return MetaUtils::get_op_info(meta, OpArgMap::OpArgType::OUTPUT);
    }

    std::vector<Tensor> MetaUtils::get_const_tensors(const Metadata& meta) {
        return MetaUtils::get_tensors(meta, OpArgMap::OpArgType::CONST_INPUT);
    }

    // Converts a channels-last shape to channels-first order when the
    // tensor's format string contains "NCHW"; any other format is returned
    // unchanged. The reordering moves the last dimension forward:
    //   rank 2 : {A, B}         -> {B, A}
    //   rank 3 : {A, B, C}      -> {A, C, B}
    //   rank>=4: {..., H, W, C} -> {..., C, H, W}
    // Ranks 0 and 1 have nothing to reorder and are returned as-is; the
    // previous code underflowed `size - 3` for them and threw out_of_range.
    static std::vector<size_t> update_tensor_shape(const Metadata::OffsetInfo& tensor) {
        std::vector<size_t> shape(tensor.shape.begin(), tensor.shape.end());
        const size_t rank = shape.size();
        if (tensor.format.find("NCHW") == std::string::npos || rank < 2) {
            return shape;
        }

        // Start of the trailing window whose last element is rotated to the
        // window's front (0 for rank 2, 1 for rank 3, rank-3 otherwise).
        const size_t first = (rank == 2) ? 0 : (rank == 3) ? 1 : rank - 3;
        using diff_t = std::vector<size_t>::difference_type;
        std::rotate(shape.begin() + static_cast<diff_t>(first),
                    shape.end() - 1, shape.end());
        return shape;
    }

    std::vector<Tensor> MetaUtils::get_tensors(const Metadata& meta,
        OpArgMap::OpArgType arg_type) {
        std::vector<Tensor> res;
        const auto tensor_name = convert_argtype_to_string(arg_type);
        for (const auto& inp : meta.fused_tensors.at(tensor_name).packed_tensors) {
            const auto& tensor = meta.tensor_map.at(inp);
            std::vector<size_t> shape_nchw = update_tensor_shape(tensor);
            
            Tensor t{/*data*/ nullptr,
                /*shape*/ shape_nchw,
                /*dtype*/ tensor.dtype };
            res.push_back(t);
        }
        return res;
    }

    std::vector<std::string>
        MetaUtils::get_file_names(const Metadata& meta, OpArgMap::OpArgType arg_type) {
        // file_name from tensor_map for each tensor packed into the fused
        // buffer of `arg_type`, in packing order.
        const auto& packed =
            meta.fused_tensors.at(convert_argtype_to_string(arg_type)).packed_tensors;

        std::vector<std::string> names;
        names.reserve(packed.size());
        for (const auto& tensor_name : packed) {
            names.push_back(meta.tensor_map.at(tensor_name).file_name);
        }
        return names;
    }

    // For each tensor packed into the fused buffer of `arg_type`, appends
    // every op in meta.op_list that references it: through in_args when
    // arg_type is INPUT, through out_args for every other kind. An op is
    // appended once per matching tensor, so duplicates are possible. Results
    // are ordered by packing order first, then op_list order.
    std::vector<Metadata::OpInfo>
        MetaUtils::get_op_info(const Metadata& meta, OpArgMap::OpArgType arg_type) {
        std::vector<Metadata::OpInfo> res;
        const auto tensor_name = convert_argtype_to_string(arg_type);
        // Loop-invariant: which argument list of each op to search.
        const bool match_inputs = (arg_type == OpArgMap::OpArgType::INPUT);

        // Bind by reference: the op list and each op's args were previously
        // copied (the args once per inner iteration).
        for (const auto& inp : meta.fused_tensors.at(tensor_name).packed_tensors) {
            for (const auto& op : meta.op_list) {
                const auto& args = match_inputs ? op.in_args : op.out_args;
                if (std::find(args.begin(), args.end(), inp) != args.end()) {
                    res.push_back(op);
                }
            }
        }
        return res;
    }

    size_t MetaUtils::get_num_tensors(const Metadata& meta,
        OpArgMap::OpArgType arg_type) {
        // Count of tensors packed into the fused buffer for `arg_type`;
        // fused_tensors.at() throws if that buffer key is absent.
        const auto& fused = meta.fused_tensors.at(convert_argtype_to_string(arg_type));
        return fused.packed_tensors.size();
    }

    // Builds one Tensor per argument returned by get_op_args(op_info), using
    // the shape and dtype recorded in meta.tensor_map. The shape is used
    // as-is here (no update_tensor_shape conversion).
    // When `const_buffer_ptrs` is non-empty, tensors whose parent_name is
    // "const" get the pointer mapped to their name; every other tensor, and
    // every tensor when the map is empty, gets a null data pointer.
    std::vector<Tensor> MetaUtils::collect_op_tensors(
        const Metadata& meta, const Metadata::OpInfo& op_info,
        const std::map<std::string, void*>& const_buffer_ptrs) {
        const bool enable_real_const_buffer_ptr = !const_buffer_ptrs.empty();
        const auto args = get_op_args(op_info);

        std::vector<Tensor> tensors;
        tensors.reserve(args.size());
        for (const auto& tensor_name : args) {
            // Single lookup per tensor; previously the same key was looked
            // up three times.
            const auto& tinfo = MAP_AT(meta.tensor_map, tensor_name);

            void* tensor_ptr = nullptr;
            if (enable_real_const_buffer_ptr && tinfo.parent_name == "const") {
                tensor_ptr = MAP_AT(const_buffer_ptrs, tensor_name);
            }

            tensors.push_back({ tensor_ptr, tinfo.shape, tinfo.dtype });
        }
        return tensors;
    }
}
