#ifndef __TXNRT__
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif // __TXNRT__
#if defined(__AIESIM__) 
#include "dma.hpp"
#endif // __AIESIM__ || __TXNRT__

using namespace std;

#include <iostream>
#include <vector>
#include <cstring>
#include <cassert>
#include <string>
#include <cmath>
#include <limits>
#include "../../tools/src_cppref/qdq.hpp" //Note: required for new ref function modeled from carf

// Input feature-map element type.
// QDQ_MODE 1 (quantize only): the input is bf16, held as raw 16-bit patterns
// (init_random() stores float_to_bfloat16(...).value and quant() reads uint16_t).
// Otherwise the input is fixed-point: 16-bit when FIXED_POINT_BIT_SIZE == 16, else 8-bit.
#if (QDQ_MODE == 1)
using Tin = uint16_t;
#else
    #if (FIXED_POINT_BIT_SIZE == 16)
        using Tin = uint16_t;
    #else 
        using Tin = uint8_t;
    #endif
#endif

// Output feature-map element type.
// QDQ_MODE 0 (dequantize only): the output is bf16 bit patterns (dequant() writes uint16_t).
// Otherwise the output is fixed-point: 16-bit when FIXED_POINT_BIT_SIZE == 16, else 8-bit.
#if (QDQ_MODE == 0)
using Tout = uint16_t;
#else 
    #if (FIXED_POINT_BIT_SIZE == 16)
        using Tout = uint16_t;
    #else
        using Tout = uint8_t;
    #endif
#endif



// Element type of the qdq parameter block (qdq_param / aie_wgt in main()).
using Welem = int32_t;

/**
 * Fill a contiguous h_in x w_in x c_in buffer (index = (i * w_in + j) * c_in + c)
 * with pseudo-random test data drawn from rand().
 *
 * qdq_mode 0 or 2 : integers in [min_fixed, max_fixed) (upper bound exclusive);
 *                   if max_fixed <= min_fixed every element is set to min_fixed.
 * qdq_mode 1      : bf16 bit patterns of floats in [min_bf16, max_bf16].
 * any other mode  : prints "ERROR MODE!" once and leaves data untouched.
 *
 * Elements are produced in increasing index order, one rand() call each
 * (for a non-empty range), so a fixed srand() seed reproduces the same data.
 */
template<typename T>
void init_random(T* data, int h_in, int w_in, int c_in, int qdq_mode, float min_bf16 = -1.0, float max_bf16 = 1.0, int min_fixed = 0, int max_fixed = 127)
{
    // Nothing to fill. This also stops two negative extents from producing a
    // positive element count in the flattened loop below.
    if (h_in <= 0 || w_in <= 0 || c_in <= 0) {
        return;
    }

    const bool fixed_point = (qdq_mode == 0 || qdq_mode == 2);
    if (!fixed_point && qdq_mode != 1) {
        // Report once, not once per element.
        printf("ERROR MODE!\n");
        return;
    }

    // rand() % 0 is undefined behaviour, so an empty/inverted range
    // degenerates to the constant min_fixed.
    const int fixed_span = max_fixed - min_fixed;
    const int count = h_in * w_in * c_in;

    // The original i/j/c nest visits indices 0..count-1 in order, so one linear pass is equivalent.
    for (int idx = 0; idx < count; ++idx) {
        if (fixed_point) {
            data[idx] = (fixed_span > 0) ? (rand() % fixed_span) + min_fixed : min_fixed;
        } else {
            float val = ((max_bf16 - min_bf16) * (rand() / static_cast<float>(RAND_MAX))) + min_bf16;
            data[idx] = float_to_bfloat16(val).value;
        }
    }
}


/**
 * Print `msg`, then the dense N x Y x X x C tensor `data` (C fastest) to stdout.
 * Each (n, y, x) position gets one line listing its C values as "%3d " fields.
 * A blank line follows every y plane and every n batch.
 * T must promote to int for "%d" (true for the 8/16-bit element types used here).
 */
template<typename T>
void print_mat(T* data, int N, int Y, int X, int C, const std::string& msg = "")
{
    std::cout << msg << "\n";

    // Elements are visited in storage order, so a walking pointer replaces
    // the explicit (n, y, x, c) index arithmetic.
    const T* elem = data;
    for (int batch = 0; batch < N; ++batch) {
        for (int row = 0; row < Y; ++row) {
            for (int col = 0; col < X; ++col) {
                for (int ch = 0; ch < C; ++ch) {
                    printf("%3d ", *elem++);
                }
                printf("\n");
            }
            printf("\n");
        }
        printf("\n");
    }
}
 

/**
 * CPU reference for dequantisation over a contiguous n_in x h_in x w_in x c_in tensor:
 *     out = bf16( (float(in) - zero_point) * scale )
 * where scale = bfloat16_to_float(qdq_param[1]) and
 * zero_point = qdq_param[0] (narrowed to uint16_t).
 * Results are stored as raw 16-bit bf16 patterns in out_data.
 */
template<typename Tin>
void dequant(Tin* in_data, uint16_t* out_data, int n_in, int h_in, int w_in, int c_in, Welem* qdq_param)
{
    const float scale = bfloat16_to_float(qdq_param[1]);
    const uint16_t zero_point = qdq_param[0];

    // The n/h/w/c nest walks indices 0..N-1 in order, so one linear pass is
    // equivalent. A non-positive extent means there is no work.
    if (n_in <= 0 || h_in <= 0 || w_in <= 0 || c_in <= 0) {
        return;
    }
    const int element_count = n_in * h_in * w_in * c_in;

    for (int pos = 0; pos < element_count; ++pos) {
        const float centered = static_cast<float>(in_data[pos]) - zero_point;
        const float scaled = centered * scale;
        out_data[pos] = float_to_bfloat16(scaled).value;
    }
}


/**
 * CPU reference for quantising bf16 data over a contiguous n_in x h_in x w_in x c_in tensor:
 *     out = saturate<Tout>( round(bfloat16_to_float(in) * inv_scale) + zero_point )
 * where:
 *   - inv_scale = bfloat16_to_float(qdq_param[3]). The value is multiplied in,
 *     so it acts as the reciprocal of the quantisation scale.
 *   - zero_point = qdq_param[2], narrowed to uint16_t.
 *
 * The result is saturated to [numeric_limits<Tout>::lowest(), max()]. Clamping
 * both ends avoids an out-of-range float->integer conversion, which is
 * undefined behaviour. For the unsigned Tout used in this file the lower bound
 * is 0. A NaN maps to the lower bound.
 */
template<typename Tout>
void quant(uint16_t* in_data, Tout* out_data, int n_in, int h_in, int w_in, int c_in, Welem* qdq_param)
{
    const float inv_s = bfloat16_to_float(qdq_param[3]);
    const uint16_t z = qdq_param[2];

    const Tout out_min = std::numeric_limits<Tout>::lowest();
    const Tout out_max = std::numeric_limits<Tout>::max();
    const float lo = static_cast<float>(out_min);
    // For 32-bit Tout this may round up past out_max; the ">= hi" branch
    // below therefore assigns out_max directly rather than casting.
    const float hi = static_cast<float>(out_max);

    // Dense layout visited in storage order; nothing to do for empty extents.
    if (n_in <= 0 || h_in <= 0 || w_in <= 0 || c_in <= 0) {
        return;
    }
    const int count = n_in * h_in * w_in * c_in;

    for (int idx = 0; idx < count; ++idx) {
        const float val = bfloat16_to_float(in_data[idx]);
        const float q = std::round(val * inv_s) + z;
        if (!(q > lo)) {
            // At/below the minimum, or NaN (every comparison with NaN is false).
            out_data[idx] = out_min;
        } else if (q >= hi) {
            out_data[idx] = out_max;
        } else {
            out_data[idx] = static_cast<Tout>(q);
        }
    }
}
/**
 * Compare `received` against `expected` over a contiguous h_in x w_in x c_in
 * tensor. Prints one "ERROR: ..." line per element whose absolute difference
 * exceeds `threshold`.
 *
 * qdq_mode 0 : elements are bf16 bit patterns and are compared as floats.
 * otherwise  : elements are compared as (widened) integers.
 *
 * @return number of mismatching elements (0 means pass).
 */
template<typename T>
int check_result(
    T* expected,
    T* received,
    int h_in, int w_in, int c_in,
    int qdq_mode,
    int threshold = 8)
{
    const bool bf16_values = (qdq_mode == 0);
    int err_count = 0;
    for (int i = 0; i < h_in; ++i) {
        for (int j = 0; j < w_in; ++j) {
            for (int c = 0; c < c_in; ++c) {
                const int idx = (i * w_in * c_in) + (j * c_in) + c;
                const T e = expected[idx];
                const T r = received[idx];

                float diff;
                if (bf16_values) {
                    diff = std::abs(bfloat16_to_float(e) - bfloat16_to_float(r));
                } else {
                    // Widen to long long. Narrowing to int16_t wrapped 16-bit
                    // values above 32767 and could hide large mismatches.
                    const long long delta = static_cast<long long>(e) - static_cast<long long>(r);
                    diff = static_cast<float>(delta < 0 ? -delta : delta);
                }

                // Written as !(diff <= threshold) so a NaN difference is
                // reported as an error instead of silently passing.
                if (!(diff <= threshold)) {
                    if (bf16_values) {
                        std::cout << "ERROR: [" << i << ", " << j << ", " << c <<"]: "
                                << "Expected: " << bfloat16_to_float(e) << ", "
                                << "Received: " << bfloat16_to_float(r) << ", "
                                << "Diff    : " << diff << "\n";
                    } else {
                        // Stream as numbers: a raw uint8_t would print as a character.
                        std::cout << "ERROR: [" << i << ", " << j << ", " << c <<"]: "
                                << "Expected: " << static_cast<long long>(e) << ", "
                                << "Received: " << static_cast<long long>(r) << ", "
                                << "Diff    : " << diff << "\n";
                    }
                    err_count += 1;
                }
            }
        }
    }
    return err_count;
}

#ifndef __TXNRT__
// Global AIE compute graph driven by main() in the __AIESIM__ path
// (init() -> run_dma_layer_config() -> end()). Not built for __TXNRT__,
// where main() only writes the reference .bin files.
// NOTE(review): ComputeGraph is presumably declared in graph.hpp / super.hh — confirm.
ComputeGraph g_compute_graph;
#endif // __TXNRT__

/**
 * DI test driver for the QDQ kernel.
 *
 * 1. Builds the qdq parameter block and random input data:
 *    - the zero points and bf16 scales;
 *    - the dq/q enable flags selected by QDQ_MODE.
 * 2. Computes the CPU reference output with dequant()/quant().
 * 3. Then, depending on the build:
 *    - __TXNRT__ : writes ifm.bin / wgt.bin / ofm.bin.
 *    - __AIESIM__: runs the AIE graph, prints IFM / CPU OFM / AIE OFM,
 *      compares them and prints "DI: PASS" or "DI: FAIL".
 *
 * Returns EXIT_FAILURE for an unsupported QDQ_MODE, or for a failed comparison
 * when the trailing assert is compiled out (NDEBUG). Otherwise returns 0.
 */
int main() {
    srand(0xABCD);

    int constexpr n_in = 1;
    int constexpr h_in = H_IN;
    int constexpr w_in = W_IN;
    int constexpr c_in = C_IN;

    // Bytes of qdq_param copied into the AIE parameter ("weights") buffer.
    int constexpr wgt_size  = 64;
    int constexpr ifm_size = h_in * w_in * c_in * sizeof(Tin);
    int constexpr ofm_size = h_in * w_in * c_in * sizeof(Tout);
    // One bf16 (uint16_t) per element: intermediate of the QDQ_MODE 2 dequant -> quant chain.
    int constexpr scratch_size = h_in * w_in * c_in * sizeof(uint16_t);

    int qdq_mode = QDQ_MODE;
    int exit_status = EXIT_SUCCESS;

    // Zero-initialised: wgt_size bytes (16 Welem entries) are copied out below.
    // Entries beyond those explicitly set must be defined for wgt.bin to be reproducible.
    Welem qdq_param[64] = {};
    static_assert(static_cast<std::size_t>(wgt_size) <= sizeof(qdq_param),
                  "wgt_size exceeds qdq_param storage");
    qdq_param[0] = 0;                                       // dq_zp
    qdq_param[1] = float_to_bfloat16(float(1.02)).value;    // dq_sc (bf16 bits)
    qdq_param[2] = 0;                                       // q_zp
    qdq_param[3] = float_to_bfloat16(float(1.0)).value;     // q_sc (bf16 bits; quant() multiplies by it)

    // qdq_param[4] = dq_enable, qdq_param[5] = q_enable.
    // Validated before any allocation so the error exit cannot leak buffers.
    if (qdq_mode == 0) {
        qdq_param[4] = 1;
        qdq_param[5] = 0;
    } else if (qdq_mode == 1) {
        qdq_param[4] = 0;
        qdq_param[5] = 1;
    } else if (qdq_mode == 2) {
        qdq_param[4] = 1;
        qdq_param[5] = 1;
    } else {
        std::cerr << "Error: " << "Check qdq_params and DI qdq_mode" << std::endl;
        return EXIT_FAILURE;
    }

#ifdef __TXNRT__
    auto aie_ifm = static_cast<Tin*>(malloc(ifm_size));
    auto aie_wgt = static_cast<Welem*>(malloc(wgt_size));
    auto aie_ofm = static_cast<Tout*>(malloc(ofm_size));
#else
    auto aie_ifm = static_cast<Tin*>(adf::GMIO::malloc(ifm_size));
    auto aie_wgt = static_cast<Welem*>(adf::GMIO::malloc(wgt_size));
    auto aie_ofm = static_cast<Tout*>(adf::GMIO::malloc(ofm_size));
#endif // __TXNRT__
    auto cpu_ofm = static_cast<Tout*>(malloc(ofm_size));
    auto scratch_buffer = static_cast<uint16_t*>(malloc(scratch_size));

    init_random<Tin>(aie_ifm, h_in, w_in, c_in, qdq_mode, 3.7, 7.7);
    memcpy(aie_wgt, qdq_param, wgt_size);

    // CPU reference output.
#if QDQ_MODE == 0
    dequant<Tin>(aie_ifm, cpu_ofm, n_in, h_in, w_in, c_in, qdq_param);
#elif QDQ_MODE == 1
    quant<Tout>(aie_ifm, cpu_ofm, n_in, h_in, w_in, c_in, qdq_param);
#elif QDQ_MODE == 2
    dequant<Tin>(aie_ifm, scratch_buffer, n_in, h_in, w_in, c_in, qdq_param);
    quant<Tout>(scratch_buffer, cpu_ofm, n_in, h_in, w_in, c_in, qdq_param);
#endif

#if defined(__AIESIM__) || defined(__TXNRT__)
    #ifdef __TXNRT__
            //DmaBins bins = run_dma_layer_config();
            //bins.save();
            write_bin_file("ifm.bin", reinterpret_cast<char*>(aie_ifm), ifm_size);
            write_bin_file("wgt.bin", reinterpret_cast<char*>(aie_wgt), wgt_size);
            write_bin_file("ofm.bin", reinterpret_cast<char*>(cpu_ofm), ofm_size);
    #else
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_ofm, aie_ifm, aie_wgt);
    g_compute_graph.end();

    print_mat<Tin>(aie_ifm, 1, h_in, w_in, c_in, "AIE IFM =\n");
    print_mat<Tout>(cpu_ofm, 1, h_in, w_in, c_in, "CPU OFM =\n");
    print_mat<Tout>(aie_ofm, 1, h_in, w_in, c_in, "AIE OFM =\n");

    // bf16-producing modes tolerate a wider absolute difference.
    int threshold = (qdq_mode == 0 || qdq_mode == 2) ? 8: 1;
    int err_count = check_result<Tout>(cpu_ofm, aie_ofm, h_in, w_in, c_in, qdq_mode, threshold);
    if (err_count == 0) {
        printf("DI: PASS\n");
    } else {
        printf("DI: FAIL\n");
        exit_status = EXIT_FAILURE;
    }
    printf("Error Count = %d\n", err_count);
    #endif // __TXNRT__
#endif // __AIESIM__ || __TXNRT__

#ifdef __TXNRT__
    free(aie_ifm);
    free(aie_wgt);
    free(aie_ofm);
#else
    adf::GMIO::free(aie_ifm);
    adf::GMIO::free(aie_wgt);
    adf::GMIO::free(aie_ofm);
#endif // __TXNRT__
    free(cpu_ofm);
    free(scratch_buffer);

#ifndef __TXNRT__
    // NOTE(review): this aborts every non-TXNRT run (unless NDEBUG is defined),
    // even after "DI: PASS". Kept unchanged; confirm it is intentional for the
    // simulation flow.
    assert(false);
#endif // __TXNRT__
    return exit_status;
}
