#pragma once

#include <array>
#include <assert.h>
#include <fenv.h>
#include <math.h>
#include <stdint.h>
#include <string.h>

#include <fstream>
#include <iostream>
#include <limits>
#include <type_traits>
#include <vector>

#include "op_common.hpp"

namespace waic_runner
{
/**
 * Indexed view over a caller-provided buffer of T, addressed with two to five
 * coordinates (batch, frame, channel, row, col). Elements are contiguous with
 * col varying fastest, followed by row, channel, frame and batch.
 *
 * The struct only stores the pointer it is given: it defines no destructor and
 * never allocates or frees the buffer. Dimensions that a constructor does not
 * take are fixed to 1.
 */
template <typename T> struct ActMatrix
{
    size_t const num_width;
    size_t const num_height;
    size_t const num_channel;
    size_t const num_frame;
    size_t const num_batch;
    T *const data;          // first element of the viewed buffer
    size_t const data_size; // extent of the view in bytes (see size())

    /** 2-D view: num_height rows by num_width columns. */
    ActMatrix(size_t num_height, size_t num_width, void *data)
        : ActMatrix(size_t{1}, size_t{1}, size_t{1}, num_height, num_width, data)
    {
    }

    /** 3-D view: (channel, row, col). */
    ActMatrix(size_t num_channel, size_t num_height, size_t num_width, void *data)
        : ActMatrix(size_t{1}, size_t{1}, num_channel, num_height, num_width, data)
    {
    }

    /** 4-D view: (frame, channel, row, col). */
    ActMatrix(size_t num_frame, size_t num_channel, size_t num_height, size_t num_width, void *data)
        : ActMatrix(size_t{1}, num_frame, num_channel, num_height, num_width, data)
    {
    }

    /** 5-D view: (batch, frame, channel, row, col). All other constructors forward here. */
    ActMatrix(size_t num_batch, size_t num_frame, size_t num_channel, size_t num_height, size_t num_width, void *data)
        : num_width(num_width), num_height(num_height), num_channel(num_channel), num_frame(num_frame),
          num_batch(num_batch), data(static_cast<T *>(data)),
          data_size(size(num_batch, num_frame, num_channel, num_height, num_width))
    {
    }

    /** Element at (row, col) of the first channel/frame/batch. */
    T &at(size_t row, size_t col)
    {
        assert(row < num_height);
        assert(col < num_width);
        return data[(row * num_width) + col];
    }

    /** Element at (channel, row, col) of the first frame/batch. */
    T &at(size_t channel, size_t row, size_t col)
    {
        assert(channel < num_channel);
        assert(row < num_height);
        assert(col < num_width);
        size_t const plane_stride = num_width * num_height;
        return data[(channel * plane_stride) + (row * num_width) + col];
    }

    /** Element at (frame, channel, row, col) of the first batch. */
    T &at(size_t frame, size_t channel, size_t row, size_t col)
    {
        assert(frame < num_frame);
        assert(channel < num_channel);
        assert(row < num_height);
        assert(col < num_width);
        size_t const plane_stride = num_width * num_height;
        size_t const frame_stride = num_channel * plane_stride;
        return data[(frame * frame_stride) + (channel * plane_stride) + (row * num_width) + col];
    }

    /** Element at (batch, frame, channel, row, col). */
    T &at(size_t batch, size_t frame, size_t channel, size_t row, size_t col)
    {
        assert(batch < num_batch);
        assert(frame < num_frame);
        assert(channel < num_channel);
        assert(row < num_height);
        assert(col < num_width);
        size_t const plane_stride = num_width * num_height;
        size_t const frame_stride = num_channel * plane_stride;
        size_t const batch_stride = num_frame * frame_stride;
        return data[(batch * batch_stride) + (frame * frame_stride) + (channel * plane_stride) +
                    (row * num_width) + col];
    }

    /** Bytes needed for a num_height x num_width matrix of T. */
    static size_t size(size_t num_height, size_t num_width)
    {
        return num_height * num_width * sizeof(T);
    }

    /** Bytes needed for num_channel matrices of num_height x num_width. */
    static size_t size(size_t num_channel, size_t num_height, size_t num_width)
    {
        return num_channel * size(num_height, num_width);
    }

    /** Bytes needed for num_frame stacks of num_channel matrices. */
    static size_t size(size_t num_frame, size_t num_channel, size_t num_height, size_t num_width)
    {
        return num_frame * size(num_channel, num_height, num_width);
    }

    /** Bytes needed for the full (batch, frame, channel, height, width) extent. */
    static size_t size(size_t num_batch, size_t num_frame, size_t num_channel, size_t num_height, size_t num_width)
    {
        return num_batch * size(num_frame, num_channel, num_height, num_width);
    }
};

/**
 * View of a C x Y x X activation tensor over a caller-provided buffer.
 * Storage is channel-innermost: element (c, y, x) lives at flat index
 * (y * X * C) + (x * C) + c.
 *
 * The struct only stores the pointer it is given: it defines no destructor and
 * never allocates or frees the buffer. data_size is C * Y * X * sizeof(T).
 */
template <typename T> struct Conv_ActTensor
{
    int const C;
    int const Y;
    int const X;
    T *const data;
    int const data_size; // size of the tensor in bytes

    Conv_ActTensor(int C, int Y, int X, void *data)
        : C(C), Y(Y), X(X), data(static_cast<T *>(data)), data_size(size(C, Y, X))
    {
    }

    /**
     * Reference to element (c, y, x).
     * Indices are signed, so the lower bound is asserted as well: a negative
     * index would otherwise satisfy the `< dim` checks and address memory
     * before the start of the buffer.
     */
    T &at(int c, int y, int x)
    {
        assert(c >= 0 && c < C);
        assert(y >= 0 && y < Y);
        assert(x >= 0 && x < X);
        int idx = (y * X * C) + (x * C) + c;
        assert(idx < C * Y * X);
        return data[idx];
    }

    /**
     * Writes the tensor to std::cout, optionally preceded by msg.
     * One line per row (y), with a blank line after each channel. Integral
     * element types are widened to int64_t so that char-sized values print as
     * numbers rather than characters.
     */
    void print(char const *msg = nullptr)
    {
        if (msg != nullptr)
        {
            std::cout << msg;
        }
        for (int c = 0; c < C; ++c)
        {
            for (int y = 0; y < Y; ++y)
            {
                for (int x = 0; x < X; ++x)
                {
                    if (std::is_integral<T>::value)
                    {
                        std::cout << static_cast<int64_t>(at(c, y, x)) << " ";
                    }
                    else
                    {
                        std::cout << at(c, y, x) << " ";
                    }
                }
                std::cout << "\n";
            }
            std::cout << "\n";
        }
    }

    /**
     * Writes the raw buffer (data_size bytes) to filename in binary mode.
     * Prints a message to std::cerr and terminates the process with exit
     * status 1 if the file cannot be opened.
     */
    void save(std::string filename)
    {
        std::fstream file;
        file.open(filename, std::ios::out | std::ios::binary);
        if (!file.is_open())
        {
            std::cerr << "failed to open " << filename << "!\n";
            exit(1);
        }
        file.write(reinterpret_cast<char *>(data), data_size);
    }

    /**
     * Fills every element with a pseudo-random integer drawn via rand() from
     * the half-open range [min, max).
     * When max == min the range is empty and every element is set to min;
     * the previous code evaluated `rand() % 0` in that case, which is
     * undefined behaviour. max < min is not a meaningful range and keeps the
     * existing arithmetic.
     */
    void init_random(int64_t min = -4, int64_t max = 4)
    {
        int64_t const span = max - min;
        for (int c = 0; c < C; ++c)
        {
            for (int y = 0; y < Y; ++y)
            {
                for (int x = 0; x < X; ++x)
                {
                    // Guard the modulo: a zero divisor is undefined behaviour.
                    int64_t const offset = (span != 0) ? (rand() % span) : 0;
                    at(c, y, x) = static_cast<T>(offset + min);
                }
            }
        }
    }

    /** Bytes needed for a C x Y x X tensor of T. */
    static int size(int C, int Y, int X)
    {
        return C * Y * X * sizeof(T);
    }
};

/**
 * Collapses a shape with more than four dimensions down to exactly four by
 * folding all leading dimensions into one:
 *   [d0, ..., dk, a, b, c] -> [d0 * ... * dk, a, b, c]
 * Shapes with four or fewer dimensions are returned unchanged.
 *
 * @param shape input dimensions, outermost first
 * @return the reduced shape (a copy of shape when shape.size() <= 4)
 */
static std::vector<size_t> reduce_shape_to_4d(const std::vector<size_t> &shape)
{
    size_t const rank = shape.size();
    if (rank <= 4)
    {
        return shape;
    }

    // Index of the first of the three trailing dimensions that are kept as-is.
    size_t const tail_begin = rank - 3;

    size_t outer_dim = 1;
    for (size_t i = 0; i < tail_begin; ++i)
    {
        outer_dim *= shape[i];
    }

    std::vector<size_t> shape_4d;
    shape_4d.reserve(4);
    shape_4d.push_back(outer_dim);
    for (size_t i = tail_begin; i < rank; ++i)
    {
        shape_4d.push_back(shape[i]);
    }
    return shape_4d;
}

/**
 * Reorders a shape from channel-first to channel-last when `format` contains
 * the substring "NCHW"; otherwise returns a copy of the shape.
 *
 * For an "NCHW" format:
 *   - rank 2: the two dimensions are swapped;
 *   - rank 3: the last two dimensions are swapped;
 *   - rank >= 4: leading dimensions are kept and the trailing (c, h, w)
 *     triple becomes (h, w, c);
 *   - rank < 2: nothing to permute, the shape is returned unchanged. The
 *     previous code computed `size - 3` on an unsigned value here, which
 *     wrapped around and indexed far out of bounds.
 *
 * @param tensor input dimensions, outermost first
 * @param format layout string; only the presence of "NCHW" is examined
 * @return the converted shape
 */
static std::vector<size_t> convert_nchw_shape(const std::vector<size_t> &tensor, const std::string &format)
{
    size_t const rank = tensor.size();
    if (format.find("NCHW") == std::string::npos || rank < 2)
    {
        return tensor;
    }

    // Start from a copy so that untouched leading dimensions carry over as-is.
    std::vector<size_t> shape_nhwc(tensor);
    if (rank == 2)
    {
        shape_nhwc[0] = tensor[1];
        shape_nhwc[1] = tensor[0];
    }
    else if (rank == 3)
    {
        shape_nhwc[1] = tensor[2];
        shape_nhwc[2] = tensor[1];
    }
    else
    {
        shape_nhwc[rank - 3] = tensor[rank - 2]; // h
        shape_nhwc[rank - 2] = tensor[rank - 1]; // w
        shape_nhwc[rank - 1] = tensor[rank - 3]; // c
    }
    return shape_nhwc;
}
/**
 * Tensor overload: builds a channel-last shape from tensor.shape when `format`
 * contains the substring "NCHW"; otherwise copies tensor.shape element by
 * element.
 *
 * For an "NCHW" format:
 *   - rank 2: the two dimensions are swapped;
 *   - rank 3: the last two dimensions are swapped;
 *   - any other rank: leading dimensions are kept and the trailing (c, h, w)
 *     triple becomes (h, w, c).
 *
 * NOTE: with an "NCHW" format and rank 0 or 1 the general branch is taken and
 * `rank - 3` wraps around; the bounds-checked at() on the result vector then
 * throws std::out_of_range.
 */
static std::vector<size_t> convert_nchw_shape(const Tensor &tensor, const std::string &format)
{
    std::vector<size_t> result;

    // Formats without "NCHW" need no permutation: plain element-wise copy.
    if (format.find("NCHW") == std::string::npos)
    {
        for (size_t dim = 0; dim < tensor.shape.size(); ++dim)
        {
            result.push_back(tensor.shape.at(dim));
        }
        return result;
    }

    size_t const rank = tensor.shape.size();
    result.resize(rank);
    switch (rank)
    {
    case 2:
        result.at(0) = tensor.shape.at(1);
        result.at(1) = tensor.shape.at(0);
        break;
    case 3:
        result.at(0) = tensor.shape.at(0);
        result.at(1) = tensor.shape.at(2);
        result.at(2) = tensor.shape.at(1);
        break;
    default:
        // Leading dimensions ahead of the trailing triple are carried over.
        for (size_t dim = 0; dim < rank - 3; ++dim)
        {
            result.at(dim) = tensor.shape.at(dim);
        }
        result.at(rank - 3) = tensor.shape.at(rank - 2); // h
        result.at(rank - 2) = tensor.shape.at(rank - 1); // w
        result.at(rank - 1) = tensor.shape.at(rank - 3); // c
        break;
    }
    return result;
}
} // namespace waic_runner
