#ifndef TENSOR_HPP
#define TENSOR_HPP

#include <assert.h>
#include <stdlib.h>

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <fstream>
#include <iostream>
#include <random>
#include <string>

/// Converts a 32-bit float to a bfloat16 bit pattern (the upper 16 bits of
/// the float), rounding to nearest with ties to even.
///
/// NaN inputs are returned as a quiet NaN with the original sign. Without
/// that guard the rounding add below could carry through the mantissa and
/// turn a NaN into an infinity or flip its sign bit.
///
/// Assumes `float` uses the IEEE-754 binary32 layout.
///
/// @param x value to convert
/// @return the bfloat16 encoding of `x`
inline uint16_t float_to_bfloat16(float x)
{
    static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits wide");

    // Reinterpret the float's bytes as an integer without aliasing issues.
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));

    // NaN: exponent all ones and a non-zero mantissa. Truncate and force a
    // mantissa bit so the result stays a NaN even if only low bits were set.
    if ((bits & 0x7fffffffu) > 0x7f800000u) {
        return static_cast<uint16_t>((bits >> 16) | 0x0040u);
    }

    // Round to nearest even: add 0x7fff plus the lowest bit that is kept, so
    // an exact tie rounds towards the value whose kept LSB is zero.
    uint32_t const lsb = (bits >> 16) & 0x1u;
    bits += 0x7fffu + lsb;

    // The upper half of the rounded word is the bfloat16 value.
    return static_cast<uint16_t>(bits >> 16);
}

/// Expands a bfloat16 bit pattern to a 32-bit float.
///
/// The 16 input bits become the upper half of the float's bit pattern and
/// the lower half is zero, so the conversion is exact. The shift works on
/// the integer value rather than on individual bytes, which keeps the
/// result independent of host byte order.
///
/// Assumes `float` uses the IEEE-754 binary32 layout.
///
/// @param x bfloat16 bit pattern
/// @return the float whose upper 16 bits equal `x`
inline float bfloat16_to_float(uint16_t x)
{
    static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits wide");

    uint32_t const bits = static_cast<uint32_t>(x) << 16;
    float y;
    std::memcpy(&y, &bits, sizeof(y));
    return y;
}


/// Lightweight view over a caller-provided buffer interpreted as a
/// rows x cols x channels tensor. Elements are laid out row-major with the
/// channel index varying fastest. The struct neither allocates nor frees
/// `data`; the buffer must outlive every view that points at it.
template<typename T>
struct Tensor
{
    int const num_rows;
    int const num_cols;
    int const num_channels;
    T* const data;  // borrowed buffer, never freed here

    /// @param num_rows     number of rows
    /// @param num_cols     number of columns
    /// @param num_channels number of channels per (row, col) position
    /// @param data         buffer reinterpreted as an array of `T`
    Tensor(int num_rows, int num_cols, int num_channels, void* data)
        : num_rows(num_rows)
        , num_cols(num_cols)
        , num_channels(num_channels)
        , data(static_cast<T*>(data))
    {
    }

    /// Returns a mutable reference to element (row, col, channel).
    /// Indices are range-checked with `assert` (debug builds only).
    T& at(int row, int col, int channel)
    {
        return data[index_of(row, col, channel)];
    }

    /// Read-only counterpart of `at` for const tensors.
    T const& at(int row, int col, int channel) const
    {
        return data[index_of(row, col, channel)];
    }

    /// Number of bytes needed to store a tensor of the given dimensions.
    /// `num_channels` defaults to 1 so existing two-argument calls keep
    /// returning the size of a single-channel (matrix) buffer.
    static int size(int num_rows, int num_cols, int num_channels = 1)
    {
        return static_cast<int>(num_rows * num_cols * num_channels * sizeof(T));
    }

private:
    /// Flattens (row, col, channel) into an offset into `data`, rejecting
    /// negative as well as too-large indices.
    int index_of(int row, int col, int channel) const
    {
        assert(row >= 0 && row < num_rows);
        assert(col >= 0 && col < num_cols);
        assert(channel >= 0 && channel < num_channels);
        return (row * num_cols * num_channels) + col * num_channels + channel;
    }
};


/// Fills every element of `tensor` with pseudo-random data.
///
/// When `fixed` is non-zero each element is drawn from a uniform integer
/// distribution over [`min`, `max`] (both bounds are converted to int16_t).
/// Otherwise each element is a bfloat16 encoding of a value in [-1, 1]
/// derived from `rand()`.
///
/// The integer engine is default-constructed on every call, so the integer
/// sequence is the same each time.
///
/// @param tensor tensor-like object exposing num_rows/num_cols/num_channels
///               and at(row, col, channel)
/// @param min    lower bound for the integer mode
/// @param max    upper bound for the integer mode
/// @param fixed  non-zero selects integer mode, zero selects bfloat16 mode
template<typename T>
void rand_tensor(T tensor, int min= -32768, int max = 32767, int fixed=0)
{
    std::default_random_engine engine;
    std::uniform_int_distribution<int16_t> int_dist(min, max);

    for (int row = 0; row < tensor.num_rows; ++row)
    {
        for (int col = 0; col < tensor.num_cols; ++col)
        {
            for (int ch = 0; ch < tensor.num_channels; ++ch)
            {
                if (!fixed)
                {
                    // Map rand() onto [0, 1] in float, then stretch to [-1, 1].
                    float const unit = rand() / static_cast<float>(RAND_MAX);
                    tensor.at(row, col, ch) = float_to_bfloat16(2.0 * unit - 1.0);
                    continue;
                }
                tensor.at(row, col, ch) = int_dist(engine);
            }
        }
    }
}


/// Loads integer values from a text file into `tensor`.
///
/// The file is expected to hold one integer per line. Values are stored in
/// the order: all columns of a row, then the next row, and once every row of
/// a channel is filled, the next channel.
///
/// Robustness:
///  - if the file cannot be opened, an error is written to std::cerr and the
///    tensor is left untouched;
///  - blank (whitespace-only) lines are skipped instead of being handed to
///    std::stoi;
///  - reading stops once every channel is filled, so surplus lines can never
///    write outside the tensor.
///
/// std::stoi still throws std::invalid_argument / std::out_of_range for
/// lines that are not valid integers.
///
/// @param tensor tensor-like object exposing num_rows/num_cols/num_channels
///               and at(row, col, channel)
/// @param path   path of the text file to read
template<typename T>
void read_model_tensor(T tensor, std::string path)
{
    std::ifstream ifm(path);
    if (!ifm.is_open())
    {
        std::cerr << "ERROR: could not open tensor file '" << path << "'\n";
        return;
    }

    // Nothing can be stored in an empty tensor.
    if (tensor.num_rows <= 0 || tensor.num_cols <= 0 || tensor.num_channels <= 0)
    {
        return;
    }

    int h = 0;
    int w = 0;
    int c = 0;
    std::string input_line;

    while (c < tensor.num_channels && std::getline(ifm, input_line))
    {
        // Skip blank lines (e.g. a trailing newline at the end of the file).
        if (input_line.find_first_not_of(" \t\r") == std::string::npos)
        {
            continue;
        }

        int16_t const curr = static_cast<int16_t>(std::stoi(input_line));
        tensor.at(h, w, c) = curr;

        // Advance column first, then row, then channel.
        if (++w == tensor.num_cols)
        {
            w = 0;
            if (++h == tensor.num_rows)
            {
                h = 0;
                ++c;
            }
        }
    }
}


/// Prints `tensor` to std::cout, one text line per tensor row.
///
/// Within a row, elements are written column by column with all channels of
/// a column adjacent, each followed by a single space. When `fixed` is
/// non-zero the elements are printed as integers; otherwise each element is
/// decoded with bfloat16_to_float before printing.
///
/// @param tensor tensor-like object exposing num_rows/num_cols/num_channels
///               and at(row, col, channel)
/// @param msg    optional heading printed on its own line first
/// @param fixed  non-zero prints raw integers, zero prints decoded floats
template<typename T>
void print_matrix(T tensor, const char* msg = nullptr, int fixed=0)
{
    std::ostream& out = std::cout;

    if (msg)
    {
        out << msg << "\n";
    }

    for (int row = 0; row < tensor.num_rows; ++row)
    {
        for (int col = 0; col < tensor.num_cols; ++col)
        {
            for (int ch = 0; ch < tensor.num_channels; ++ch)
            {
                if (!fixed)
                {
                    out << bfloat16_to_float(tensor.at(row, col, ch)) << " ";
                    continue;
                }
                out << static_cast<int64_t>(tensor.at(row, col, ch)) << " ";
            }
        }
        out << "\n";
    }
}



/// CPU reference for nearest-neighbour upscaling.
///
/// Every output element (h, w, c) is copied from the input element
/// (h / num_interpolations, w / num_interpolations, c) using integer
/// division, so each input pixel is replicated into a
/// num_interpolations x num_interpolations patch per channel.
///
/// The loops write output rows up to num_rows * num_interpolations and
/// columns up to num_cols * num_interpolations, so `output_tensor` must be
/// at least that large and have at least as many channels as the input.
///
/// @param input_tensor       source tensor
/// @param output_tensor      destination view; written in place
/// @param num_interpolations integer scale factor applied to rows and columns
template<typename T>
void cpu_nni(
    Tensor<T> input_tensor,
    Tensor<T> output_tensor,
    int num_interpolations
)
{
    int const out_rows = input_tensor.num_rows * num_interpolations;
    int const out_cols = input_tensor.num_cols * num_interpolations;

    for (int h = 0; h < out_rows; ++h)
    {
        // Integer division maps every row of a replicated patch to its source row.
        int const src_row = h / num_interpolations;

        for (int w = 0; w < out_cols; ++w)
        {
            int const src_col = w / num_interpolations;

            for (int c = 0; c < input_tensor.num_channels; ++c)
            {
                output_tensor.at(h, w, c) = input_tensor.at(src_row, src_col, c);
            }
        }
    }
}


/// Compares a device result against the CPU reference element by element and
/// prints a PASS/ERROR line for every element.
///
/// In fixed mode (`fixed` non-zero) elements are compared as integers and
/// must match exactly. Otherwise elements are decoded with
/// bfloat16_to_float and must agree within an absolute tolerance of 1e-2.
///
/// The float comparison is written so that a NaN on only one side is
/// reported as an error: any ordered comparison against NaN is false, so a
/// plain `diff >= tolerance` test would let such a mismatch pass. Equal
/// values (including equal infinities) and NaN on both sides still pass.
///
/// Iteration uses the dimensions of `cpu_Y`; `aie_Y` is indexed with the
/// same coordinates.
///
/// @param cpu_Y expected values
/// @param aie_Y values under test
/// @param fixed non-zero selects exact integer comparison
/// @return 0 if every element passed, 1 if at least one element failed
template<typename T>
int check_result(
    Tensor<T> cpu_Y,
    Tensor<T> aie_Y,
    int fixed=0
    )
{
    int fail = 0;
    for (int h = 0; h < cpu_Y.num_rows; ++h) {
        for (int w = 0; w < cpu_Y.num_cols; ++w) {
            for (int c = 0; c < cpu_Y.num_channels; ++c) {
                T const expected_raw = cpu_Y.at(h, w, c);
                T const received_raw = aie_Y.at(h, w, c);

                if (fixed) {
                    int32_t const diff = static_cast<int32_t>(expected_raw) - static_cast<int32_t>(received_raw);
                    if (diff != 0) {
                        std::cout << "ERROR: Y[" << h << ", " << w << "," << c << "]: "
                            << "Expected: " << static_cast<int64_t>(expected_raw) << ", "
                            << "Received: " << static_cast<int64_t>(received_raw) << "\n";
                        fail = 1;
                    } else {
                        std::cout << "PASS: Y[" << h << ", " << w << "," << c << "]: "
                            << "Expected: " << static_cast<int64_t>(expected_raw) << ", "
                            << "Received: " << static_cast<int64_t>(received_raw) << "\n";
                    }
                }
                else {
                    float const expected = bfloat16_to_float(expected_raw);
                    float const received = bfloat16_to_float(received_raw);
                    float const diff = std::abs(expected - received);

                    // `diff < tol` is false when diff is NaN, so a NaN on one
                    // side fails. Equal values cover matching infinities,
                    // whose difference is NaN.
                    bool const both_nan = std::isnan(expected) && std::isnan(received);
                    bool const ok = (expected == received) || both_nan || (diff < 1e-2);

                    if (!ok) {
                        printf("ERROR: Y[%d][%d][%d] Expected: %f, Received: %f \n", h, w, c, expected, received);
                        fail = 1;
                    } else {
                        printf("PASS: Y[%d][%d][%d] Expected: %f, Received: %f \n", h, w, c, expected, received);
                    }
                }
            }
        }
    }
    return fail;
}

#endif // TENSOR_HPP