
#include "debug_utils.hpp"

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <filesystem>
#include <functional>
#include <iostream>
#include <numeric>
#include <stdexcept>
#include <string>
#include <vector>

#include "meta_utils.hpp"
#include "utils.hpp"
namespace waic_runner {
    void compare_results(const Metadata& meta, const std::vector<Tensor>& out_tensors, bool debug_flag) {
        std::vector<std::string> out_paths = MetaUtils::get_output_files(meta);
        auto out_ops = MetaUtils::get_output_info(meta);
        bool output_files_exist = true;
        for (auto& out_path : out_paths) {
            if (!std::filesystem::exists(out_path)) {
                output_files_exist = false;
                break;
            }
        }
        
        int err_count = 0;
        if (output_files_exist) {
            int ix = 0;
            for (auto& tensor : out_tensors) {
                err_count = 0;
                size_t sz = std::accumulate(tensor.shape.begin(), tensor.shape.end(),
                    size_t{ 1 }, std::multiplies{}) *
                    get_size_of_type(tensor.dtype);
                std::cout << "output tensor size is " << sz << std::endl;
                std::vector<int8_t> out(sz, 0);
                read_bin_file(out_paths[ix++], reinterpret_cast<char*>(out.data()));

                std::cout << "Result check for output " << ix << std::endl;
                if (tensor.dtype.find("uint8") != std::string::npos) {
                    err_count = check_result(reinterpret_cast<uint8_t*>(out.data()),
                        reinterpret_cast<uint8_t*>(tensor.data),
                        (int)out.size(), 2, debug_flag);
                }
                else if (tensor.dtype.find("int8") != std::string::npos) {
                    err_count = check_result(reinterpret_cast<int8_t*>(out.data()),
                        reinterpret_cast<int8_t*>(tensor.data),
                        (int)out.size(), 2, debug_flag);
                }
                else if (tensor.dtype.find("uint16") != std::string::npos) {
                    err_count = check_result(reinterpret_cast<uint16_t*>(out.data()),
                        reinterpret_cast<uint16_t*>(tensor.data),
                        (int)out.size(), 2, debug_flag);
                }
                else if (tensor.dtype.find("int16") != std::string::npos) {
                    err_count = check_result(reinterpret_cast<int16_t*>(out.data()),
                        reinterpret_cast<int16_t*>(tensor.data),
                        (int)out.size(), 2, debug_flag);
                }
                else if (tensor.dtype.find("float") != std::string::npos) {
                    err_count = check_result_float(reinterpret_cast<float*>(out.data()),
                        reinterpret_cast<float*>(tensor.data),
                        (int)out.size(), 0.001, debug_flag);
                }
                else {
                    throw std::runtime_error("Not Valid datatype");
                }

                std::cout << "Error Count is " << err_count << std::endl;
            }
        }
        else {
            std::cout << "NO Valid output file" << std::endl;
        }
    }
}
