#include <cassert>
#include <cstdint>
#include <cmath>
#include <iostream>
#include <type_traits>
#include "common.hpp"

// Maps the coordinate (n, y, x, c) to its flat offset in a contiguous
// N-Y-X-C buffer whose inner extents are Y, X and C. The N extent is not
// needed to compute the offset.
inline std::size_t idx_nyxc(int n, int y, int x, int c, int Y, int X, int C)
{
    // Horner evaluation of ((n * Y + y) * X + x) * C + c. Each step starts from
    // a std::size_t operand, so the arithmetic is carried out in std::size_t.
    const std::size_t row_slot   = static_cast<std::size_t>(n) * Y + static_cast<std::size_t>(y);
    const std::size_t pixel_slot = row_slot * X + static_cast<std::size_t>(x);
    return pixel_slot * C + static_cast<std::size_t>(c);
}

/*
 * cmp_tensor_nyxc<T>
 *
 * Element-wise comparison of two tensors. Every element whose absolute
 * difference exceeds `epsilon` is logged to std::cout and counted.
 *
 * Supported element types:
 *   - Integral types (e.g., int8_t/uint8_t/int16_t/uint16_t/int32_t/uint32_t, etc.),
 *     widened to 64 bits before comparing (uint64_t for unsigned T, int64_t otherwise).
 *   - Types for which is_float16_type_v<T> is true, compared by converting to float
 *     via Float16Traits<T>::to_float(), rounding, then taking the int64 difference.
 *
 * Layout assumed: N-Y-X-C (contiguous, C is innermost).
 *
 * Parameters:
 *   expected_void  reference tensor, read as const T*
 *   received_void  tensor under test, read as const T*
 *   N, Y, X, C     tensor extents, in layout order
 *   epsilon        largest absolute difference still accepted as a match
 *
 * Returns the number of elements whose difference exceeded epsilon.
 */
template <typename T>
inline int cmp_tensor_nyxc(
    const void* expected_void,
    const void* received_void,
    int N, int Y, int X, int C,
    int64_t epsilon
)
{
    static_assert(std::is_integral<T>::value || is_float16_type_v<T>,
                  "cmp_tensor_nyxc<T>: T must be an integral type or float16 type.");

    assert(expected_void && received_void);
    assert(N >= 0 && C >= 0 && X >= 0 && Y >= 0);
    assert(epsilon >= 0);

    const T* expected = static_cast<const T*>(expected_void);
    const T* received = static_cast<const T*>(received_void);

    int err_count = 0;

    // Compares one pair of widened values (both int64_t or both uint64_t) and
    // logs + counts the pair when it is out of tolerance.
    const auto check = [&err_count, epsilon](int n, int y, int x, int c,
                                             auto cpu_val, auto aie_val) {
        // Subtract smaller from larger in uint64_t: the magnitude of the gap
        // between two 64-bit values always fits there, whereas the signed
        // form (a - b, then negate) can overflow.
        const uint64_t diff = (cpu_val > aie_val)
            ? static_cast<uint64_t>(cpu_val) - static_cast<uint64_t>(aie_val)
            : static_cast<uint64_t>(aie_val) - static_cast<uint64_t>(cpu_val);

        // A negative epsilon is rejected by the assert above; when asserts are
        // compiled out it keeps its signed meaning (every element is reported)
        // instead of turning into a huge unsigned tolerance.
        if (epsilon < 0 || diff > static_cast<uint64_t>(epsilon)) {
            ++err_count;
            std::cout << "ERROR: [N:" << n
                      << " Y:" << y
                      << " X:" << x
                      << " C:" << c << "] "
                      << "Expected: " << cpu_val
                      << " Received: " << aie_val
                      << " Diff: " << diff << "\n";
        }
    };

    for (int n = 0; n < N; ++n) {
        for (int y = 0; y < Y; ++y) {
            for (int x = 0; x < X; ++x) {
                for (int c = 0; c < C; ++c) {

                    const std::size_t i = idx_nyxc(n, y, x, c, Y, X, C);

                    if constexpr (is_float16_type_v<T>) {
                        // NOTE(review): converting a NaN, an infinity, or a value
                        // outside the int64_t range to int64_t is undefined
                        // behaviour; such inputs are not handled here. Decide
                        // how they should be reported and guard accordingly.
                        check(n, y, x, c,
                              static_cast<int64_t>(
                                  std::round(Float16Traits<T>::to_float(expected[i]))),
                              static_cast<int64_t>(
                                  std::round(Float16Traits<T>::to_float(received[i]))));
                    } else {
                        // Widen without changing the value: unsigned element
                        // types go to uint64_t so large uint64_t values are not
                        // reinterpreted as negative numbers.
                        using wide_t = std::conditional_t<std::is_unsigned<T>::value,
                                                          uint64_t, int64_t>;
                        check(n, y, x, c,
                              static_cast<wide_t>(expected[i]),
                              static_cast<wide_t>(received[i]));
                    }
                }
            }
        }
    }

    return err_count;
}

/*
 * print_tensor_nyxc<T>
 *
 * Writes a tensor to std::cout: one "N = n" heading per batch index, then one
 * line per (y, x) position holding its C channel values separated by ", ".
 *
 * Supported element types:
 *   - Integral types (int8_t/uint8_t/int16_t/uint16_t/int32_t/uint32_t, etc.),
 *     printed as 64-bit numbers (uint64_t for unsigned T, int64_t otherwise).
 *   - Types for which is_float16_type_v<T> is true, printed by converting to float
 *     via Float16Traits<T>::to_float() and rounding to an integer.
 *
 * Layout assumed: N-Y-X-C (contiguous, C is innermost).
 *
 * Parameters:
 *   tensor_void  tensor to print, read as const T*
 *   N, C, X, Y   tensor extents. NOTE: they are taken in the order N, C, X, Y,
 *                which is neither the memory layout order nor the N, Y, X, C
 *                order used by cmp_tensor_nyxc.
 *   max_print    maximum number of elements to print; a negative value prints
 *                everything. When elements are left unprinted because of the
 *                limit, the output ends with "... (truncated)".
 */
template <typename T>
inline void print_tensor_nyxc(
    const void* tensor_void,
    int N, int C, int X, int Y,
    int max_print = -1   // -1 = print all elements
)
{
    static_assert(std::is_integral<T>::value || is_float16_type_v<T>,
                  "print_tensor_nyxc<T>: T must be an integral type or float16 type.");

    assert(tensor_void);
    assert(N >= 0 && C >= 0 && X >= 0 && Y >= 0);

    const T* tensor = static_cast<const T*>(tensor_void);

    const bool limited = max_print >= 0;
    const std::size_t limit = limited ? static_cast<std::size_t>(max_print) : 0;

    // Element count, used to tell a real truncation from a limit that merely
    // coincides with the end of the tensor. Zero when any extent is not positive.
    const std::size_t total =
        (N > 0 && Y > 0 && X > 0 && C > 0)
            ? static_cast<std::size_t>(N) * static_cast<std::size_t>(Y)
                  * static_cast<std::size_t>(X) * static_cast<std::size_t>(C)
            : 0;

    // A limit of zero means no element may be printed at all.
    if (limited && limit == 0 && total > 0) {
        std::cout << "... (truncated)\n";
        return;
    }

    std::size_t printed = 0;

    for (int n = 0; n < N; ++n) {
        std::cout << "N = " << n << "\n";
        for (int y = 0; y < Y; ++y) {
            for (int x = 0; x < X; ++x) {

                std::cout << "  [Y:" << y << " X:" << x << "] ";

                for (int c = 0; c < C; ++c) {
                    // Flat N-Y-X-C offset: ((n * Y + y) * X + x) * C + c.
                    const std::size_t idx =
                        ((static_cast<std::size_t>(n) * Y + static_cast<std::size_t>(y)) * X
                         + static_cast<std::size_t>(x)) * C
                        + static_cast<std::size_t>(c);

                    if constexpr (is_float16_type_v<T>) {
                        // NOTE(review): converting a NaN, an infinity, or a value
                        // outside the int64_t range to int64_t is undefined
                        // behaviour; such inputs are not handled here.
                        std::cout << static_cast<int64_t>(
                            std::round(Float16Traits<T>::to_float(tensor[idx])));
                    } else {
                        // Widening keeps char-sized types printing as numbers and
                        // keeps large uint64_t values from showing up as negative.
                        using wide_t = std::conditional_t<std::is_unsigned<T>::value,
                                                          uint64_t, int64_t>;
                        std::cout << static_cast<wide_t>(tensor[idx]);
                    }
                    if (c + 1 < C) std::cout << ", ";

                    ++printed;
                    // Stop only when the limit is reached AND elements remain.
                    if (limited && printed >= limit && printed < total) {
                        std::cout << "\n... (truncated)\n";
                        return;
                    }
                }
                std::cout << "\n";
            }
        }
    }
}
