#ifndef __TXNRT__
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif // __TXNRT__
#if defined(__AIESIM__) || defined(__TXNRT__)
#include "dma.hpp"
#endif // __AIESIM__ || __TXNRT__

using namespace std;

#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>

// Telem: 16-bit storage type of the input/output matrix buffers allocated in
// main(). Depending on the mode the bits are treated as an unsigned integer or
// as a bfloat16 bit pattern by the helpers below.
using Telem = uint16_t;
// Welem: 32-bit storage type of the (zero point, scale) parameter array that
// main() passes to slice_mat().
using Welem = int32_t;

// Thin wrapper around a raw 16-bit pattern.
// NOTE(review): not referenced by any code visible in this file; the helpers
// below pass bfloat16 values around as plain uint16_t instead.
struct bfloat16_t
{
    uint16_t value;
};

// Writes `size` bytes starting at `data` to `filename` as a binary file,
// truncating any existing file of that name.
//
// Failures to open or write the file are reported on std::cerr; the function
// still returns normally so callers keep their existing control flow.
void write_bin_file(std::string filename, char* data, size_t size) {
    // Output-only stream: the file is never read back here.
    std::ofstream file(filename, std::ios::out | std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "ERROR: could not open '" << filename << "' for writing\n";
        return;
    }
    file.write(data, static_cast<std::streamsize>(size));
    if (!file) {
        std::cerr << "ERROR: failed to write " << size << " bytes to '" << filename << "'\n";
    }
    // The stream is flushed and closed by its destructor.
}

// Converts a 32-bit float to a bfloat16 bit pattern (the upper 16 bits of the
// float encoding) using round-to-nearest-even on the discarded lower 16 bits.
//
// NaN inputs are handled separately: adding the rounding bias to a NaN can
// carry out of the mantissa and turn it into an infinity or a signed zero, so
// a NaN is returned as a quiet NaN with the input's sign instead.
uint16_t float_to_bfloat16(float x)
{
    static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits wide");

    // Reinterpret the float's bytes as an integer without aliasing issues.
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));

    // NaN: exponent all ones and a non-zero mantissa.
    if ((bits & 0x7fffffffu) > 0x7f800000u) {
        // Keep sign/exponent/top mantissa bits and force the quiet bit so the
        // truncated mantissa can never become zero (which would be infinity).
        return uint16_t((bits >> 16) | 0x0040u);
    }

    // Round to nearest even: bias is 0x7fff, plus one when the bit that will
    // become the result's LSB is already set.
    const uint32_t lsb = (bits >> 16) & 0x1u;
    bits += 0x7fffu + lsb;

    // The upper half of the rounded value is the bfloat16 encoding.
    return uint16_t(bits >> 16);
}

// Reinterprets the four bytes of `i` as a float (bit-for-bit copy, no numeric
// conversion) and returns the result.
inline float uint_to_float(uint32_t i)
{
    float result = 0;
    // Copy exactly four bytes, matching the width of the source integer.
    std::memcpy(&result, &i, 4);
    return result;
}

// Expands a bfloat16 bit pattern to a float by placing it in the upper 16 bits
// of a 32-bit word (lower 16 bits zero) and reinterpreting that word as float.
inline float bfloat16_to_float(uint16_t bf)
{
    const uint32_t widened = uint32_t(bf) << 16;
    return uint_to_float(widened);
}

// Fills a row-major num_rows x num_cols matrix with pseudo-random values drawn
// from rand().
//
//   qdq_mode 0 or 2: integer values in [min_uint16, max_uint16).
//   qdq_mode 1     : bfloat16 bit patterns of floats starting at min_bf16 and
//                    spanning roughly (max_bf16 - min_bf16).
//   anything else  : nothing is written; "error mode!" is printed for every
//                    element.
template<typename T>
void init_random_mat(T* data, int num_rows, int num_cols, int qdq_mode,
                    int min_uint16 = 0, int max_uint16 = 2048, int min_bf16 = -128, int max_bf16 = 128)
{
    const bool integer_fill = (qdq_mode == 0) || (qdq_mode == 2);
    const bool bf16_fill = (qdq_mode == 1);

    for (int row = 0; row < num_rows; ++row) {
        for (int col = 0; col < num_cols; ++col) {
            const int flat = (row * num_cols) + col;
            if (integer_fill) {
                const int sample = (rand() % (max_uint16 - min_uint16)) + min_uint16;
                data[flat] = sample;
            } else if (bf16_fill) {
                // Scale rand() into the requested float span before narrowing to bfloat16.
                const float divisor = static_cast<float>(RAND_MAX / (max_bf16 - min_bf16));
                const float sample = min_bf16 + static_cast<float>(std::rand()) / divisor;
                data[flat] = float_to_bfloat16(sample);
            } else {
                printf("error mode!");
            }
        }
    }
}

// Writes one (zero point, scale) parameter pair into data[0..1].
//
// For qdq_mode 0 and 1 the zero point is 0 and the second entry is the
// bfloat16 encoding of 1.0. Any other mode leaves `data` untouched.
template<typename T>
void init_wgt_mat(T* data, int qdq_mode) {
    int num_inputs = 1;
    assert (num_inputs <= 32);

    switch (qdq_mode) {
    case 0:
        data[0] = 0;
        data[1] = float_to_bfloat16 (float(1.0));
        break;
    case 1:
        // Written as a reciprocal in the original; evaluates to 1.0 as well.
        data[0] = 0;
        data[1] = float_to_bfloat16 (float(1/ 1.0));
        break;
    default:
        break;
    }
}

// Prints `msg`, then the row-major num_rows x num_cols matrix to std::cout:
// one line per row, each element followed by a single space, and a blank line
// after the last row.
template<typename T>
void print_mat(T* data, int num_rows, int num_cols, std::string msg = "")
{
    std::cout << msg;
    for (int row = 0; row < num_rows; ++row) {
        int col = 0;
        while (col < num_cols) {
            const int flat = (row * num_cols) + col;
            std::cout << data[flat] << " ";
            ++col;
        }
        std::cout << "\n";
    }
    std::cout << "\n";
}

// CPU reference: copies columns [Wout_start, Wout_stop) of the row-major
// Hin x Win matrix `matI` into the Hout x Wout matrix `matO`, transforming
// every element according to `qdq_mode`:
//
//   0: out = bfloat16((in - z) * s)             (input read as an integer)
//   1: out = int16(round(float(in) * s) + z)    (input read as bfloat16)
//   2: out = -in
//   other: `matO` is left untouched.
//
// z is qdq_prm[0] narrowed to 16 bits and s is qdq_prm[1] interpreted as a
// bfloat16 bit pattern. Both are read and printed for every mode, including
// mode 2 where they are not used in the computation.
// Requires Hin == Hout and Wout_stop - Wout_start == Wout (asserted).
template<typename T>
void slice_mat(
    T* matI, int Hin, int Win,
    T* matO, int Hout, int Wout,
    int Wout_start, int Wout_stop,
    int qdq_mode, Welem* qdq_prm)
{
    assert(Hin == Hout);
    assert((Wout_stop - Wout_start) == Wout);

    // Only the first (zero point, scale) pair is consumed.
    const int n = 0;
    uint16_t z = qdq_prm[n*2];
    float inv_or_s = bfloat16_to_float(qdq_prm[n*2 + 1]);
    printf("qdq_mode = %2d, n = %2d, z = %2d, s = %1.1f\n", qdq_mode, n, z, inv_or_s);

    for (int row = 0; row < Hin; ++row) {
        for (int col = 0; col < Wout; ++col) {
            const int src = (row * Win) + Wout_start + col;
            const int dst = (row * Wout) + col;
            switch (qdq_mode) {
            case 0: {
                T tmp = matI[src];
                matO[dst] = float_to_bfloat16((float)(tmp - z) * inv_or_s);
                break;
            }
            case 1: {
                T tmp = matI[src];
                matO[dst] = static_cast<int16_t>(std::round(bfloat16_to_float(tmp) * inv_or_s) + z);
                break;
            }
            case 2:
                matO[dst] = -(matI[src]);
                break;
            default:
                break;
            }
        }
    }
}

// Compares two row-major num_rows x num_cols matrices element by element and
// returns the number of mismatching elements, printing one "ERROR" line per
// mismatch.
//
//   qdq_mode 0    : elements are bfloat16 bit patterns, compared as floats.
//   any other mode: elements are compared as signed 16-bit integers.
//
// An element mismatches when the absolute difference exceeds `threshold`.
// In addition, a difference that is NaN (e.g. one side is NaN) counts as a
// mismatch unless both bit patterns are identical; a plain `diff > threshold`
// test would silently accept such elements because comparisons with NaN are
// always false.
template<typename T>
int check_result(
    T* expected,
    T* received,
    int num_rows, int num_cols,
    int qdq_mode,
    int threshold = 8)
{
    int err_count = 0;
    for (int i = 0; i < num_rows; ++i) {
        for (int j = 0; j < num_cols; ++j) {
            T e = expected[(i * num_cols) + j];
            T r = received[(i * num_cols) + j];

            float diff;
            if (qdq_mode == 0) {
                diff = std::abs(bfloat16_to_float(e) - bfloat16_to_float(r));
            } else {
                diff = std::abs(static_cast<int16_t>(e) - static_cast<int16_t>(r));
            }

            // NaN never compares greater than the threshold, so detect it explicitly.
            const bool nan_mismatch = std::isnan(diff) && (e != r);
            if (diff > threshold || nan_mismatch) {
                if (qdq_mode == 0) {
                    std::cout << "ERROR: [" << i << ", " << j << "]: "
                            << "Expected: " << bfloat16_to_float(e) << ", "
                            << "Received: " << bfloat16_to_float(r) << ", "
                            << "Diff    : " << diff << "\n";
                } else {
                    std::cout << "ERROR: [" << i << ", " << j << "]: "
                            << "Expected: " << e << ", "
                            << "Received: " << r << ", "
                            << "Diff    : " << diff << "\n";
                }
                err_count += 1;
            }
        }
    }
    return err_count;
}

#ifndef __TXNRT__
// Global graph instance used by main() in non-TXNRT builds (init / run / end).
// ComputeGraph is declared in one of the project headers included above.
ComputeGraph g_compute_graph;
#endif // __TXNRT__

int main() {
    srand(0xABCD);

    int aie_rows = AIE_ROWS;
    int aie_cols = AIE_COLS;
    int constexpr h_in = H_IN;
    int constexpr w_in = W_IN;
    int constexpr w_out_start = W_OUT_START;
    int constexpr w_out_stop = W_OUT_STOP;
    int constexpr h_out = H_OUT;
    int constexpr w_out = W_OUT;
    
    int qdq_neg_mode = 2; 2: NEGATIVE

    int wgt_size  = (qdq_neg_mode == 2) ? 1:  WGT_SIZE; 
    int ifm_size = h_in * w_in * sizeof(Telem);
    int ofm_size = h_out * w_out * sizeof(Telem);

#ifdef __TXNRT__
    auto aie_ifm = static_cast<Telem*>(malloc(ifm_size));
    auto aie_wgt = static_cast<Telem*>(malloc(wgt_size));
    auto aie_ofm = static_cast<Telem*>(malloc(ofm_size));
#else
    auto aie_ifm = static_cast<Telem*>(adf::GMIO::malloc(ifm_size));
    auto aie_wgt = static_cast<Telem*>(adf::GMIO::malloc(wgt_size));
    auto aie_ofm = static_cast<Telem*>(adf::GMIO::malloc(ofm_size));
#endif // __TXNRT__
    auto cpu_ofm = static_cast<Telem*>(malloc(ofm_size));

    Telem* in_mat = aie_ifm;
    Telem* aie_out_mat = aie_ofm;
    Telem* cpu_out_mat = cpu_ofm;

    init_random_mat(in_mat, h_in, w_in, qdq_neg_mode);

    Welem qdq_param[64];
    if (qdq_neg_mode == 0 || qdq_neg_mode == 1){
        init_wgt_mat<Welem>(qdq_param, qdq_neg_mode);
        printf("qdq_mode : %d\n", qdq_neg_mode);
        memcpy(aie_wgt, (void*)qdq_param, wgt_size);
    }

    slice_mat(
        in_mat, h_in, w_in,
        cpu_out_mat, h_out, w_out,
        w_out_start, w_out_stop,
        qdq_neg_mode, qdq_param);

#if defined(__AIESIM__) || defined(__TXNRT__)
    #ifdef __TXNRT__
            DmaBins bins = run_dma_layer_config();
            bins.save();
            write_bin_file("ifm.bin", reinterpret_cast<char*>(aie_ifm), ifm_size);
            write_bin_file("wgt.bin", reinterpret_cast<char*>(aie_wgt), wgt_size);
            write_bin_file("ofm.bin", reinterpret_cast<char*>(cpu_ofm), ofm_size);
    #else
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_ofm, aie_ifm, aie_wgt);
    g_compute_graph.end();

    print_mat(in_mat, h_in, w_in, "AIE IFM =\n");
    print_mat(cpu_out_mat, h_out, w_out, "CPU OFM =\n");
    print_mat(aie_out_mat, h_out, w_out, "AIE OFM =\n");

    int threshold = (qdq_neg_mode == 0 || qdq_neg_mode == 2) ? 8: 1;
    int err_count = check_result(cpu_out_mat, aie_out_mat, h_out, w_out, qdq_neg_mode, threshold);
    if (err_count == 0) {
        printf("DI: PASS\n");
    } else {
        printf("DI: FAIL\n");
    }
    printf("Error Count = %d\n", err_count);
    #endif // __TXNRT__
#endif // __AIESIM__ || __TXNRT__
    #ifdef __TXNRT__
        free(aie_ifm);
        free(aie_wgt);
        free(aie_ofm);
    #else
        adf::GMIO::free(aie_ifm);
        adf::GMIO::free(aie_wgt);
        adf::GMIO::free(aie_ofm);
    #endif // __TXNRT__
    free(cpu_ofm);

    #ifndef __TXNRT__
    assert(false);
    #endif // __TXNRT__
}
