// Copyright 2022-2024 Advanced Micro Devices, Inc. All Rights Reserved.
////////////////////////////////////////////////////////////////////////

#pragma once
#include <vector>
#include <iostream>
#include <stdexcept>
#include <cstdint>

#include <sstream>
#include <vector>

namespace adf_utils {

// Shape of a one-dimensional tensor.
struct tensor_1d
{
    // Extent of the single dimension, i.e. the element count (defaults to 1)
    uint32_t x = 1;

    // Rank of this tensor type; always 1
    static constexpr unsigned dims() { return 1; }

    // Element count, which for a 1D tensor is just x
    constexpr uint32_t size() const { return x; }

    // Pack the extent into a one-element vector that can be used in ADF APIs
    std::vector<uint32_t> to_vector() const
    {
        return std::vector<uint32_t>{x};
    }

    // Serialize the shape as a JSON object with the single key "x"
    std::string to_json() const
    {
        std::ostringstream json;
        json << "{\"x\": " << x << " }";
        return json.str();
    }

    // Two 1D shapes compare equal when their extents match
    bool operator==(const tensor_1d &other) const
    {
        return other.x == x;
    }
};

// Shape of a two-dimensional tensor.
struct tensor_2d
{
    // Extent of the inner dimension (defaults to 1)
    uint32_t x = 1;

    // Extent of the outer dimension (defaults to 1)
    uint32_t y = 1;

    // Rank of this tensor type; always 2
    static constexpr unsigned dims() { return 2; }

    // Element count: product of both extents (plain uint32_t arithmetic)
    constexpr uint32_t size() const
    {
        return y * x;
    }

    // The larger of the two extents
    constexpr uint32_t max() const
    {
        return (y < x) ? x : y;
    }

    // Pack the extents as {x, y} so they can be used in ADF APIs
    std::vector<uint32_t> to_vector() const
    {
        return std::vector<uint32_t>{x, y};
    }

    // Serialize the shape as a JSON object with keys "x" and "y"
    std::string to_json() const
    {
        std::ostringstream json;
        json << "{ \"x\": " << x << ", \"y\": " << y << " }";
        return json.str();
    }

    // Two 2D shapes compare equal when both extents match
    bool operator==(const tensor_2d &other) const
    {
        if (x != other.x) {
            return false;
        }
        return y == other.y;
    }
};

// Symbolic names for the four axes of an HCWc tensor. The numeric value of
// each enumerator is the position of that axis in the vectors returned by
// the to_vector() helpers (c_inner first, H last).
enum class HCWc_dim : unsigned
{
    c_inner = 0u,
    W       = 1u,
    C       = 2u,
    H       = 3u
};

// Compile-time mapping from an HCWc_dim enumerator to its integral axis
// index. The primary template is only declared, so asking for the traits of
// a value without a specialization below fails to compile.
template<HCWc_dim DIM> struct HCWc_dim_traits;

// c_inner -> axis 0
template<> struct HCWc_dim_traits<HCWc_dim::c_inner> { static constexpr int index = 0; };

// W -> axis 1
template<> struct HCWc_dim_traits<HCWc_dim::W>       { static constexpr int index = 1; };

// C -> axis 2
template<> struct HCWc_dim_traits<HCWc_dim::C>       { static constexpr int index = 2; };

// H -> axis 3
template<> struct HCWc_dim_traits<HCWc_dim::H>       { static constexpr int index = 3; };

// Four-dimensional tensor shape with named extents.
//
// Members are declared in the order c_inner, W, C, H. That same order is used
// by to_vector(), by integral indexing (0..3) and by aggregate initialization,
// so it must not be changed. Every extent defaults to 1.
struct tensor_HCWc
{
    // Inner dimension in all AIE kernels, typically 8
    uint32_t c_inner = 1;

    // Width
    uint32_t W = 1;

    // Total number of channels divided by c
    uint32_t C = 1;

    // Height
    uint32_t H = 1;

    // Number of dimensions in the tensor
    static constexpr unsigned dims()
    {
        return 4;
    }

    // Total number of channels (c_inner * C)
    constexpr uint32_t channels() const
    {
        return c_inner * C;
    }

    // Total number of elements. Computed in uint32_t, so very large shapes
    // wrap around modulo 2^32.
    constexpr uint32_t size() const
    {
        return c_inner * W * C * H;
    }

    // Initialize dimensions from 1D tensor. Use lowest order dimension;
    // every other extent is reset to 1.
    constexpr tensor_HCWc &operator=(const tensor_1d &tensor)
    {
        c_inner = tensor.x;
        W       = 1;
        C       = 1;
        H       = 1;

        return *this;
    }

    // Initialize dimensions from 2D tensor. Use lowest order dimensions
    // (x -> c_inner, y -> W); C and H are reset to 1.
    constexpr tensor_HCWc &operator=(const tensor_2d &tensor)
    {
        c_inner = tensor.x;
        W       = tensor.y;
        C       = 1;
        H       = 1;

        return *this;
    }

    // Promote largest order dimension to the H dimension.
    //
    // Only acts when H == 1: the first of C, W, c_inner (checked in that
    // order) that is greater than 1 is moved into H and its old slot is set
    // to 1. Because H is known to be 1 here, this is exactly a swap, written
    // as two assignments so that no <utility> facility is required.
    constexpr void promote_H()
    {
        if (H != 1) {
            return;
        }

        if (C > 1) {
            H = C;
            C = 1;
        }
        else if (W > 1) {
            H = W;
            W = 1;
        }
        else if (c_inner > 1) {
            H = c_inner;
            c_inner = 1;
        }
    }

    // Get a dimension vector that can be used in ADF APIs.
    // Order: { c_inner, W, C, H }.
    std::vector<uint32_t> to_vector() const
    {
        return std::vector<uint32_t>{c_inner, W, C, H};
    }

    // Read-only access by integral axis index (0 = c_inner, 1 = W, 2 = C,
    // 3 = H).
    //
    // Throws std::invalid_argument for any other index.
    //
    // The previous implementation forwarded to operator[](index) on a
    // non-const `this`; since an unsigned value does not convert to the
    // scoped enum HCWc_dim, that call resolved back to this very overload and
    // recursed without end.
    const uint32_t &operator[](unsigned int index) const
    {
        switch (index) {
            case 0:  return c_inner;
            case 1:  return W;
            case 2:  return C;
            case 3:  return H;
            default: throw std::invalid_argument("Invalid dimension!");
        }
    }

    // True when no extent is zero
    constexpr bool valid() const
    {
        return c_inner > 0 && W > 0 && C > 0 && H > 0;
    }

    // Throws std::runtime_error when any extent is zero
    constexpr void check_consistency() const
    {
        if (!valid())
        {
            throw std::runtime_error("tensor_HCWc dimensions cannot be zero");
        }
    }

    // Mutable access by named axis.
    // Throws std::invalid_argument when dim is not one of the four
    // enumerators (e.g. a value produced by a cast).
    uint32_t &operator[](HCWc_dim dim)
    {
        switch (dim) {
            case HCWc_dim::c_inner: return c_inner;
            case HCWc_dim::W:       return W;
            case HCWc_dim::C:       return C;
            case HCWc_dim::H:       return H;
            default:                throw std::invalid_argument("Invalid dimension!");
        }
    }

    // Read access by named axis; returns the extent by value.
    // Throws std::invalid_argument for an invalid dim, like the mutable
    // overload.
    uint32_t operator[](HCWc_dim dim) const
    {
        // The enumerator values are the integral axis indices, so the
        // integral accessor performs both the lookup and the range check.
        return (*this)[static_cast<unsigned>(dim)];
    }

    // Mutable access to the extent selected at compile time.
    // Instantiating with a DIM that is not a valid axis is a compile error.
    template<HCWc_dim DIM>
    uint32_t &get()
    {
        static_assert(static_cast<unsigned>(DIM) < dims(), "Invalid dimension!");

        if constexpr(DIM == HCWc_dim::c_inner) {
            return c_inner;
        }
        else if constexpr(DIM == HCWc_dim::W) {
            return W;
        }
        else if constexpr(DIM == HCWc_dim::C) {
            return C;
        }
        else {
            return H;
        }
    }

    // Read access to the extent selected at compile time; returns by value.
    // Instantiating with a DIM that is not a valid axis is a compile error.
    template<HCWc_dim DIM>
    uint32_t get() const
    {
        static_assert(static_cast<unsigned>(DIM) < dims(), "Invalid dimension!");

        if constexpr(DIM == HCWc_dim::c_inner) {
            return c_inner;
        }
        else if constexpr(DIM == HCWc_dim::W) {
            return W;
        }
        else if constexpr(DIM == HCWc_dim::C) {
            return C;
        }
        else {
            return H;
        }
    }

    // Serialize the shape as a JSON object with keys "c_inner", "W", "C", "H"
    std::string to_json() const
    {
        std::stringstream ss;
        ss << "{ \"c_inner\": " << c_inner;
        ss << ", \"W\": "       << W;
        ss << ", \"C\": "       << C;
        ss << ", \"H\": "       << H;
        ss << " }";
        return ss.str();
    }

    // Two shapes compare equal when all four extents match
    bool operator==(const tensor_HCWc &other) const {
        return c_inner == other.c_inner
            && W == other.W
            && C == other.C
            && H == other.H;
    }
};

// Signed version of tensor_HCWc, to be used for offsets.
// Members are declared in the order c_inner, W, C, H and all default to 0.
struct offset_HCWc
{
    // Inner dimension in all AIE kernels
    int32_t c_inner = 0;

    // Width
    int32_t W = 0;

    // Total number of channels divided by c
    int32_t C = 0;

    // Height
    int32_t H = 0;

    // Number of dimensions in the tensor
    static constexpr unsigned dims()
    {
        return 4;
    }

    // Get a dimension vector that can be used in ADF APIs.
    // Order: { c_inner, W, C, H }.
    std::vector<int32_t> to_vector() const
    {
        return std::vector<int32_t>{c_inner, W, C, H};
    }

    // True when the offset is zero along every axis
    bool is_zero() const
    {
        return c_inner == 0 && W == 0 && C == 0 && H == 0;
    }

    // Mutable access by named axis.
    // Throws std::invalid_argument when dim is not one of the four
    // enumerators (e.g. a value produced by a cast).
    int32_t &operator[](HCWc_dim dim)
    {
        switch (dim) {
            case HCWc_dim::c_inner: return c_inner;
            case HCWc_dim::W:       return W;
            case HCWc_dim::C:       return C;
            case HCWc_dim::H:       return H;
            default:                throw std::invalid_argument("Invalid dimension!");
        }
    }

    // Read access by named axis for const offsets; returns the component by
    // value, mirroring the const accessor of tensor_HCWc.
    // Throws std::invalid_argument for an invalid dim.
    int32_t operator[](HCWc_dim dim) const
    {
        switch (dim) {
            case HCWc_dim::c_inner: return c_inner;
            case HCWc_dim::W:       return W;
            case HCWc_dim::C:       return C;
            case HCWc_dim::H:       return H;
            default:                throw std::invalid_argument("Invalid dimension!");
        }
    }
};

// Padding information from ONNX. This determines when zero-padding is required. Order is the same as in ONNX.
// All edges default to 0 so that a default-constructed object holds defined
// values, in line with the other shape structs in this header.
struct image_padding
{
    // Padding on the top edge
    unsigned top = 0;

    // Padding on the left edge
    unsigned left = 0;

    // Padding on the bottom edge
    unsigned bottom = 0;

    // Padding on the right edge
    unsigned right = 0;

    //! @note Returns true if padding is non-zero on any edge, and false when
    //! all four edges are 0.
    //! NOTE(review): this is the opposite of what the name suggests (and of
    //! offset_HCWc::is_zero()). The behaviour is kept as-is because existing
    //! callers may rely on it; confirm against call sites before renaming or
    //! inverting.
    bool is_zero() const
    {
        return (left != 0 || top != 0 || right != 0 || bottom != 0);
    }
};

// Filter information for convolutions: kernel size and stride along W and H.
// Every field defaults to 1.
struct filter_2d
{
    // Kernel extent along W
    unsigned size_W = 1;

    // Kernel extent along H
    unsigned size_H = 1;

    // Stride along W
    unsigned stride_W = 1;

    // Stride along H
    unsigned stride_H = 1;

    // Half of (size_W - 1), rounded down. The subtraction is unsigned, so it
    // is only meaningful when size_W > 0.
    constexpr unsigned halo_W() const
    {
        const unsigned span = size_W - 1;
        return span / 2;
    }

    // Half of (size_H - 1), rounded down. The subtraction is unsigned, so it
    // is only meaningful when size_H > 0.
    constexpr unsigned halo_H() const
    {
        const unsigned span = size_H - 1;
        return span / 2;
    }

    // True when none of the sizes or strides is zero
    constexpr bool valid() const
    {
        return !(size_W == 0 || size_H == 0 || stride_W == 0 || stride_H == 0);
    }

    // Throws std::runtime_error when valid() does not hold
    constexpr void check_consistency() const
    {
        if (valid())
        {
            return;
        }
        throw std::runtime_error("filter2d dimensions cannot be zero");
    }
};

// Base class of the errors raised when tensors cannot be concatenated.
// It adds nothing to std::domain_error beyond a distinct type to catch, and
// inherits all of its constructors.
class concat_error : public std::domain_error
{
public:
    using domain_error::domain_error;
};

// Raised when a concatenation axis lies outside the half-open range [0, max).
class invalid_axis final : public concat_error
{
public:
    // axis: the rejected axis value
    // max:  exclusive upper bound of the accepted axis values
    invalid_axis(unsigned axis, unsigned max)
        : concat_error(message(axis, max))
    {
    }

private:
    // Build the human-readable text carried by the exception
    static std::string message(unsigned axis, unsigned max)
    {
        std::ostringstream text;
        text << "Invalid axis " << axis;
        text << ". Accepted values: [0, " << max << ")";
        return text.str();
    }
};

// Raised when two tensors differ along an axis other than the one being
// concatenated.
class incompatible_tensors final : public concat_error
{
public:
    // axis:     the axis along which the two extents differ
    // lhs, rhs: the offending tensors; both are printed into the message, so
    //           Tensor must be streamable with operator<<
    template <typename Tensor>
    explicit incompatible_tensors(unsigned axis, const Tensor &lhs, const Tensor &rhs)
        : concat_error(message(axis, lhs, rhs))
    {
    }

private:
    // Formats the exception text; defined after the stream operators at the
    // end of this header.
    template <typename Tensor>
    static std::string message(unsigned axis, const Tensor &lhs, const Tensor &rhs);
};

// Concatenates two tensor_1d: the result's extent is the sum of both extents
static constexpr tensor_1d concat(const tensor_1d &lhs, const tensor_1d &rhs)
{
    tensor_1d joined = lhs;
    joined.x += rhs.x;
    return joined;
}

// Concatenates two tensor_2d over a specific axis (0 = x, 1 = y).
//
// The extents along `axis` are added; the extent along the other axis must be
// identical in both inputs.
// Throws invalid_axis when axis >= 2, and incompatible_tensors (carrying the
// mismatching axis) when the other extents differ.
static constexpr tensor_2d concat(const tensor_2d &lhs, const tensor_2d &rhs, unsigned axis)
{
    if (!(axis < tensor_2d::dims())) {
        throw invalid_axis(axis, tensor_2d::dims());
    }

    if (axis == 0) {
        // Growing along x, so y has to agree
        if (lhs.y != rhs.y) {
            throw incompatible_tensors(1u, lhs, rhs);
        }
        return tensor_2d{lhs.x + rhs.x, lhs.y};
    }

    // Growing along y, so x has to agree
    if (lhs.x != rhs.x) {
        throw incompatible_tensors(0u, lhs, rhs);
    }
    return tensor_2d{lhs.x, lhs.y + rhs.y};
}

// Concatenates two tensors over a specific axis
// (0 = c_inner, 1 = W, 2 = C, 3 = H).
//
// The extents along `axis` are added; all other extents must be identical in
// both inputs.
// Throws invalid_axis when axis >= 4, and incompatible_tensors (carrying the
// lowest mismatching axis) when any other extent differs.
static constexpr tensor_HCWc concat(const tensor_HCWc &lhs,
                                    const tensor_HCWc &rhs,
                                    unsigned axis)
{
    if (!(axis < tensor_HCWc::dims())) {
        throw invalid_axis(axis, tensor_HCWc::dims());
    }

    // Member selected by each axis index
    constexpr uint32_t tensor_HCWc::*extent[] = {
        &tensor_HCWc::c_inner,
        &tensor_HCWc::W,
        &tensor_HCWc::C,
        &tensor_HCWc::H
    };

    // Every axis except the concatenation axis has to agree
    for (unsigned d = 0; d < tensor_HCWc::dims(); ++d)
    {
        if (d == axis) {
            continue;
        }
        if (lhs.*extent[d] != rhs.*extent[d]) {
            throw incompatible_tensors(d, lhs, rhs);
        }
    }

    tensor_HCWc merged = lhs;
    merged.*extent[axis] += rhs.*extent[axis];
    return merged;
}

// Concatenates two tensors over a specific named axis. Forwards to the
// overload taking an integral axis, using the enumerator's numeric value.
static constexpr tensor_HCWc concat(const tensor_HCWc &lhs,
                                    const tensor_HCWc &rhs,
                                    HCWc_dim axis)
{
    const unsigned axis_index = static_cast<unsigned>(axis);
    return concat(lhs, rhs, axis_index);
}

// Construct HCWc from a 1D tensor. Use lowest order dimension: x becomes
// c_inner and the remaining extents are 1.
static constexpr tensor_HCWc make_tensor_HCWc(const tensor_1d &tensor)
{
    // Aggregate order: c_inner, W, C, H
    return tensor_HCWc{tensor.x, 1, 1, 1};
}

// Construct HCWc from a 2D tensor. Use lowest order dimensions: x becomes
// c_inner, y becomes W, and the remaining extents are 1.
static constexpr tensor_HCWc make_tensor_HCWc(const tensor_2d &tensor)
{
    // Aggregate order: c_inner, W, C, H
    return tensor_HCWc{tensor.x, tensor.y, 1, 1};
}

// Stream a tensor_1d in designated-initializer style: "{ .x = N }".
// Declared inline rather than static: this is a header, and an inline
// function gives one shared definition across translation units instead of a
// private (and possibly unused) copy in each one.
inline std::ostream &operator<<(std::ostream &os, const tensor_1d &dim)
{
    return os << "{ .x = " << dim.x << " }";
}
// Stream a tensor_2d in designated-initializer style: "{ .x = N, .y = M }".
// Declared inline rather than static: this is a header, and the operator is
// reached from the externally linked incompatible_tensors::message template,
// which needs to see the same function in every translation unit.
inline std::ostream &operator<<(std::ostream &os, const tensor_2d &dim)
{
    os << "{ .x = " << dim.x;
    os << ", .y = " << dim.y;
    os << " }";
    return os;
}

// Stream a tensor_HCWc in designated-initializer style:
// "{ .c_inner = a, .W = b, .C = c, .H = d }".
// Declared inline rather than static: this is a header, and the operator is
// reached from the externally linked incompatible_tensors::message template,
// which needs to see the same function in every translation unit.
inline std::ostream &operator<<(std::ostream &os, const tensor_HCWc &dim)
{
    os << "{ .c_inner = " << dim.c_inner;
    os << ", .W = "       << dim.W;
    os << ", .C = "       << dim.C;
    os << ", .H = "       << dim.H;
    os << " }";
    return os;
}

// Stream an offset_HCWc in designated-initializer style:
// "{ .c_inner = a, .W = b, .C = c, .H = d }" (components are signed).
// Declared inline rather than static: this is a header, and an inline
// function gives one shared definition across translation units instead of a
// private (and possibly unused) copy in each one.
inline std::ostream &operator<<(std::ostream &os, const offset_HCWc &dim)
{
    os << "{ .c_inner = " << dim.c_inner;
    os << ", .W = "       << dim.W;
    os << ", .C = "       << dim.C;
    os << ", .H = "       << dim.H;
    os << " }";
    return os;
}

// Formats the text of an incompatible_tensors exception: the mismatching axis
// followed by both tensors, one per line. Tensor must be streamable with
// operator<<.
template <typename Tensor>
std::string incompatible_tensors::message(unsigned axis,
                                          const Tensor &lhs,
                                          const Tensor &rhs)
{
    std::ostringstream text;
    text << "Tensor sizes don't match along axis " << axis;
    text << ".\nA: " << lhs;
    text << "\nB: " << rhs;
    return text.str();
}

} // namespace adf_utils
