#pragma once

#include <cmath>
#include <cstddef>
#include <iostream>
#include <vector>

#include "op_common.hpp"

namespace waic_runner {
    /// Compare two integer buffers element-wise and report error statistics.
    ///
    /// @tparam T             element type of both buffers (intended for integer data;
    ///                       use check_result_float for floating-point buffers).
    /// @param golden         reference values.
    /// @param result         values under test.
    /// @param size           length of each buffer in BYTES; the element count is
    ///                       size / sizeof(T). A non-positive size compares nothing.
    /// @param err_tolerance  largest absolute difference that still counts as a match.
    /// @param enable_logging if true, print one line per mismatching element.
    /// @return number of elements whose absolute difference exceeds err_tolerance.
    ///
    /// Always prints to std::cout: the maximum absolute difference, the maximum
    /// relative difference in % (relative to |golden|), the L2 norm of the
    /// (truncated) difference vector, and that norm divided by the element count.
    template <typename T>
    int check_result(const T* golden, const T* result, int size, int err_tolerance = 2,
        bool enable_logging = false) {
        // A negative size would otherwise wrap to a huge unsigned element count.
        const std::size_t count =
            size > 0 ? static_cast<std::size_t>(size) / sizeof(T) : 0;
        int err_count = 0;
        long long max_diff = 0;
        // Accumulate squares in double: int diff * diff overflows for 16-bit data.
        double sum_sq = 0.0;
        float max_relative_diff = 0;
        for (std::size_t i = 0; i < count; ++i) {
            // Subtract in double so 32-bit T cannot overflow and unsigned T does
            // not wrap (or hit an ambiguous std::abs overload). The result is
            // truncated toward zero, matching the previous int conversion.
            const double expected = static_cast<double>(golden[i]);
            const long long act_diff =
                static_cast<long long>(expected - static_cast<double>(result[i]));
            const long long diff = act_diff < 0 ? -act_diff : act_diff;
            // Normalise by |golden| so negative references produce a positive
            // ratio; a zero reference is replaced by a small epsilon.
            const float relative_diff = static_cast<float>(
                static_cast<double>(diff) /
                (expected == 0 ? 0.00001 : std::abs(expected)));
            sum_sq += static_cast<double>(diff) * static_cast<double>(diff);
            if (diff > max_diff) {
                max_diff = diff;
            }
            if (relative_diff > max_relative_diff) {
                max_relative_diff = relative_diff;
            }
            if (diff > err_tolerance) {
                if (enable_logging) {
                    std::cout << "ERROR: Y[" << i << "]: "
                        << "Expected: " << static_cast<long long>(golden[i]) << ", "
                        << "Received: " << static_cast<long long>(result[i]) << ", "
                        << "Relative Diff %: " << relative_diff * 100 << ", "
                        << "Diff: " << act_diff << "\n";
                }
                err_count++;
            }
        }
        const double L2_norm = std::sqrt(sum_sq);
        // Avoid 0/0 when there is nothing to compare.
        const double L2_per_element =
            count > 0 ? L2_norm / static_cast<double>(count) : 0.0;
        std::cout << "max_diff is " << max_diff << std::endl;
        std::cout << "max_relative_diff % is " << max_relative_diff * 100 << std::endl;
        std::cout << "L2_norm is " << L2_norm << std::endl;
        std::cout << "L2_norm per element is " << L2_per_element << std::endl;
        return err_count;
    }

    /// Compare two floating-point buffers element-wise and report error statistics.
    ///
    /// @tparam T             element type of both buffers (must be convertible to float).
    /// @param golden         reference values.
    /// @param result         values under test.
    /// @param size           length of each buffer in BYTES; the element count is
    ///                       size / sizeof(T). A non-positive size compares nothing.
    /// @param err_tolerance  largest absolute difference that still counts as a match.
    /// @param enable_logging if true, print one line per mismatching element.
    /// @return number of elements whose absolute difference exceeds err_tolerance
    ///         or is NaN.
    ///
    /// Always prints to std::cout: the maximum absolute difference, the maximum
    /// relative difference in % (relative to |golden|, clamped away from zero),
    /// the L2 norm of the difference vector, and that norm divided by the element count.
    template <typename T>
    int check_result_float(const T* golden, const T* result, int size, float err_tolerance = 0.01,
        bool enable_logging = false) {
        // A negative size would otherwise wrap to a huge unsigned element count.
        const std::size_t count =
            size > 0 ? static_cast<std::size_t>(size) / sizeof(T) : 0;
        int err_count = 0;
        float max_diff = 0;
        double sum_sq = 0.0;  // double keeps precision over large buffers
        float max_relative_diff = 0;
        for (std::size_t i = 0; i < count; ++i) {
            // Equal values are an exact match. This also keeps equal infinities,
            // whose difference would be NaN, from being reported as errors.
            const float diff = golden[i] == result[i]
                ? 0.0f
                : static_cast<float>(std::abs(golden[i] - result[i]));
            // Normalise by |golden| (clamped to 1e-7) so negative references are
            // measured against their magnitude rather than against the epsilon.
            const float reference = std::abs(static_cast<float>(golden[i]));
            const float relative_diff = static_cast<float>(
                static_cast<double>(diff) /
                (reference < 0.0000001 ? 0.0000001 : static_cast<double>(reference)));
            sum_sq += static_cast<double>(diff) * static_cast<double>(diff);
            if (diff > max_diff) {
                max_diff = diff;
            }
            if (relative_diff > max_relative_diff) {
                max_relative_diff = relative_diff;
            }
            // Written as !(diff <= tol) so that a NaN difference counts as an error.
            if (!(diff <= err_tolerance)) {
                if (enable_logging) {
                    std::cout << "ERROR: Y[" << i << "]: "
                        << "Expected: " << (golden[i]) << ", "
                        << "Received: " << (result[i]) << ", "
                        << "Relative Diff %: " << relative_diff * 100 << ", "
                        << "Diff: " << (diff) << "\n";
                }
                err_count++;
            }
        }
        const double L2_norm = std::sqrt(sum_sq);
        // Avoid 0/0 when there is nothing to compare.
        const double L2_per_element =
            count > 0 ? L2_norm / static_cast<double>(count) : 0.0;
        std::cout << "max_diff is " << max_diff << std::endl;
        std::cout << "max_relative_diff % is " << max_relative_diff * 100 << std::endl;
        std::cout << "L2_norm is " << L2_norm << std::endl;
        std::cout << "L2_norm per element is " << L2_per_element << std::endl;
        return err_count;
    }

    /// Declared here; the definition is not part of this header.
    /// NOTE(review): presumably compares `out_tensors` against reference data described
    /// by `meta`, with `debug_flag` enabling extra diagnostics — confirm against the
    /// definition before relying on this description.
    void compare_results(const Metadata& meta, const std::vector<Tensor>& out_tensors, bool debug_flag=false);
}
