#include <assert.h>

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>
#ifndef __TXNRT__
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif // __TXNRT__
#if defined(__AIESIM__) || defined(__TXNRT__)
#include "dma.hpp"
#endif // __AIESIM__ || __TXNRT__



#ifndef __TXNRT__
// Graph object that main() drives (init / run_dma_layer_config / end) in
// non-TXNRT builds. ComputeGraph is not declared in this file; presumably it
// comes from one of the project headers included above.
ComputeGraph g_compute_graph;
#endif
// Element type of the input and output buffers allocated in main().
using Telem = uint16_t;
// Type of the row indices read from idxs.bin and consumed by cpu_model().
using Twgts = int32_t;

/// Converts a 32-bit float to bfloat16 (the upper 16 bits of the float
/// encoding), rounding to nearest with ties to even.
///
/// @param x  Value to convert.
/// @return   The bfloat16 bit pattern. NaN inputs yield a quiet NaN with the
///           sign bit preserved; values that round past the largest finite
///           bfloat16 become infinity.
uint16_t float_to_bfloat16(float x)
{
    static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits wide");

    // Reinterpret the float's object representation as an integer.
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));

    // NaN = exponent all ones with a non-zero mantissa. It must bypass the
    // rounding step: adding the bias can carry through the exponent into the
    // sign bit (0x7FFFFFFF would become -0), and a payload living only in the
    // low 16 bits would be dropped (0x7F800001 would become +inf).
    if ((bits & 0x7fffffffu) > 0x7f800000u) {
        // Keep sign and upper payload; force the quiet bit so the result
        // always has a non-zero mantissa.
        return static_cast<uint16_t>((bits >> 16) | 0x0040u);
    }

    // Round to nearest even: add 0x7fff plus the current LSB of the result,
    // so exact ties round towards an even result.
    const uint32_t lsb = (bits >> 16) & 0x1u;
    bits += 0x7fffu + lsb;

    // The upper half of the rounded value is the bfloat16 encoding.
    return static_cast<uint16_t>(bits >> 16);
}


/// Fills a buffer of N * Y * X rows, each C_P lanes wide, with pseudo-random
/// test data drawn from rand().
///
/// In every row the first C lanes receive random values and the remaining
/// lanes (C <= j < C_P) are set to zero.
///
/// @param data              Destination buffer holding N * Y * X * C_P elements.
/// @param N, Y, X           Outer dimensions; their product is the row count.
/// @param C                 Number of valid lanes per row.
/// @param C_P               Row stride, i.e. lane count including padding.
/// @param qdq_mode          0, 2 or 3: integers in [min_uint16, max_uint16];
///                          1: bfloat16 encodings of floats in
///                          [min_bf16, max_bf16]. Any other value prints an
///                          error and leaves the buffer untouched.
/// @param min_uint16        Inclusive lower bound for the integer modes.
/// @param max_uint16        Inclusive upper bound for the integer modes.
/// @param min_bf16          Lower bound for mode 1.
/// @param max_bf16          Upper bound for mode 1.
/// @param safe_quant_scale  Not used by the generation logic; kept so existing
///                          call sites keep compiling.
template<typename T>
void init_random_mat(T* data, int N, int Y, int X, int C, int C_P, int qdq_mode,
                     int min_uint16 = 0, int max_uint16 = 128,
                     float min_bf16 = 3.3, float max_bf16 = 7.7,
                     float safe_quant_scale = 0.1f) // assume scale ~0.1 unless overridden
{
    static_cast<void>(safe_quant_scale);

    // The mode never changes while filling, so validate it once up front
    // instead of once per element.
    const bool integer_mode = (qdq_mode == 0 || qdq_mode == 2 || qdq_mode == 3);
    const bool bf16_mode = (qdq_mode == 1);
    if (!integer_mode && !bf16_mode) {
        printf("ERROR: Invalid qdq_mode %d\n", qdq_mode);
        return;
    }

    // Number of distinct integer values; it is the modulus applied to rand(),
    // so an empty range would otherwise be a modulo by zero.
    const int int_span = max_uint16 - min_uint16 + 1;
    if (integer_mode && int_span <= 0) {
        printf("ERROR: Invalid integer range [%d, %d]\n", min_uint16, max_uint16);
        return;
    }

    // Wide arithmetic for the row count and offsets so large shapes do not
    // overflow int.
    const long long row_count = static_cast<long long>(N) * Y * X;
    for (long long row = 0; row < row_count; ++row) {
        T* const lanes = data + row * C_P;
        for (int j = 0; j < C_P; ++j) {
            if (integer_mode) {
                // rand() is drawn only for valid lanes in the integer modes.
                lanes[j] = (j < C) ? static_cast<T>((rand() % int_span) + min_uint16)
                                   : static_cast<T>(0);
            } else {
                // In mode 1 a value is drawn for every lane, padding included;
                // this determines which values a given srand() seed produces.
                const float rand_float =
                    ((max_bf16 - min_bf16) * (rand() / (float) RAND_MAX)) + min_bf16;
                lanes[j] = (j < C) ? static_cast<T>(float_to_bfloat16(rand_float))
                                   : static_cast<T>(0);
            }
        }
    }
}

/// Writes a 4-D tensor to std::cout.
///
/// The tensor is read as a dense array whose extents are input_shape[0..3],
/// with the last extent varying fastest. For every position of the first
/// three axes a " W INDEX <w>" header line is printed, followed by one line
/// holding that position's innermost-axis values, each followed by a space.
///
/// @param input_shape  The four extents of the tensor.
/// @param input        Pointer to the first element.
template <typename T>
void print_mat(std::vector<int> input_shape, T * input) {
    for (int n = 0; n < input_shape[0]; ++n) {
        const int height = input_shape[1];
        for (int h = 0; h < height; ++h) {
            const int width = input_shape[2];
            for (int w = 0; w < width; ++w) {
                std::cout << " W INDEX " << w << std::endl;

                const int channels = input_shape[3];
                // Offset of the first value at position (n, h, w).
                const int base = n * height * width * channels +
                                 h * width * channels +
                                 w * channels;
                for (int c = 0; c < channels; ++c) {
                    std::cout << input[base + c] << " ";
                }
                std::cout << std::endl;
            }
        }
    }
}

/// Fills idxs[0 .. num_indices) with values drawn from rand(), each reduced
/// modulo input_shape[2].
///
/// @param num_indices  Number of entries to generate.
/// @param idxs         Destination array with room for num_indices entries.
/// @param input_shape  Shape vector; only element 2 is read.
template <typename T>
void init_random_idxs(int num_indices, T * idxs, std::vector<int> input_shape) 
{
    std::generate_n(idxs, num_indices, [&input_shape]() {
        return rand() % input_shape[2];
    });
}



void cpu_model(std::vector<int> input_shape, Twgts* idxs, Telem* input, Telem* output, int num_indices)
{
    for (int idx_n = 0; idx_n < num_indices; ++idx_n)
    {
        for (int elem_n = 0; elem_n < input_shape[3]; ++elem_n)
        {
            output[idx_n*input_shape[3] + elem_n] = input[idxs[idx_n] * input_shape[3] + elem_n];
        }
    }
}


/// Reads up to `size` bytes from the start of a binary file into `data`.
///
/// If the file cannot be opened, or holds fewer than `size` bytes, an error is
/// printed to std::cerr and the bytes that could not be read are set to zero,
/// so the caller never sees uninitialised buffer contents.
///
/// @param filename  Path of the file to read.
/// @param data      Destination buffer with room for `size` bytes.
/// @param size      Number of bytes requested.
void read_bin_file(std::string filename, char* data, size_t size)
{
    size_t bytes_read = 0;

    std::ifstream file(filename, std::ios::in | std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "ERROR: cannot open '" << filename << "' for reading" << std::endl;
    } else {
        file.read(data, static_cast<std::streamsize>(size));
        // gcount() reports how many bytes the last read actually delivered.
        bytes_read = static_cast<size_t>(file.gcount());
        if (bytes_read != size) {
            std::cerr << "ERROR: read " << bytes_read << " of " << size
                      << " bytes from '" << filename << "'" << std::endl;
        }
    }

    // The function has no return value to signal failure, so make the
    // unread tail well-defined instead.
    std::fill(data + bytes_read, data + size, '\0');
}

/// Writes `size` bytes from `data` to a binary file, replacing any existing
/// content.
///
/// Failure to open the file or to write all bytes is reported on std::cerr.
///
/// @param filename  Path of the file to create or overwrite.
/// @param data      Source buffer holding `size` bytes.
/// @param size      Number of bytes to write.
void write_bin_file(std::string filename, char* data, size_t size) {
    std::ofstream file(filename, std::ios::out | std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "ERROR: cannot open '" << filename << "' for writing" << std::endl;
        return;
    }

    file.write(data, static_cast<std::streamsize>(size));
    // Flush here so a failed write is visible in the stream state; errors
    // raised when the destructor closes the file cannot be observed.
    file.flush();
    if (!file) {
        std::cerr << "ERROR: failed to write " << size << " bytes to '"
                  << filename << "'" << std::endl;
    }
}

/// Compares two 4-D tensors element by element and reports the result on
/// std::cout.
///
/// Both buffers are read as dense arrays whose extents are output_shape[0..3],
/// with the last extent varying fastest. Each mismatch prints one "ERROR" line
/// with its W and C indices and both values. Afterwards "DI PASS" (no
/// mismatch) or "DI FAIL" is printed, followed by "ERROR COUNT: <n>".
///
/// @param cpu_ofm       Expected values.
/// @param aie_ofm       Values under test.
/// @param output_shape  The four extents shared by both buffers.
template<typename T>
void check_result(T* cpu_ofm, T* aie_ofm, std::vector<int> output_shape)
{
        const int batches  = output_shape[0];
        const int height   = output_shape[1];
        const int width    = output_shape[2];
        const int channels = output_shape[3];

        int error_count = 0;
        for (int n = 0; n < batches; ++n)
        {
            for (int h = 0; h < height; ++h)
            {
                for (int w = 0; w < width; ++w)
                {
                    // Offset of the first value at position (n, h, w).
                    const int base = ((n * height + h) * width + w) * channels;
                    for (int c = 0; c < channels; ++c)
                    {
                        const int idx = base + c;
                        if (cpu_ofm[idx] != aie_ofm[idx])
                        {
                                // A space precedes "C INDEX" so the W value and
                                // the label do not run together.
                                std::cout << "ERROR W INDEX " << w << " C INDEX " << c << " "
                                          << "AIE OFM: " << aie_ofm[idx]
                                          << " "
                                          << "CPU OFM: " << cpu_ofm[idx]
                                          << std::endl;
                                ++error_count;
                        }
                    }
                }
            }
        }

        // The verdict follows directly from the mismatch count.
        if (error_count == 0) {
            std::cout << "DI PASS" << std::endl;
        } else {
            std::cout << "DI FAIL" << std::endl;
        }

        std::cout << "ERROR COUNT: " << error_count << std::endl;
}


// Test driver: builds a random input tensor, loads row indices from idxs.bin,
// computes the expected output with cpu_model(), then either exports the DMA
// configuration and data files (__TXNRT__) or runs the compute graph
// (__AIESIM__ without __TXNRT__), and optionally prints and checks the result.
//
// Nin, Hin, Win, Cin, NUM_INDICES, INPUT_BYTES, INPUT_PRINTS,
// CHECK_OUTPUT_PRINTS and TXN_MODE are not defined in this file; presumably
// they come from the project headers or the build command line.
int main(void)
{

    // Echo the configured problem size.
    std::cout << Nin << std::endl;
    std::cout << Hin << std::endl;
    std::cout << Win << std::endl;
    std::cout << Cin << std::endl;


    std::cout << NUM_INDICES << std::endl;

    // NOTE(review): idxs_shape is not used anywhere else in this function.
    std::vector<int> input_shape = {Nin, Hin, Win, Cin};
    std::vector<int> idxs_shape = {1, 1, 1, NUM_INDICES};
    std::vector<int> output_shape = {1, 1, NUM_INDICES, Cin};

    // Fixed seed so the generated input is the same on every run.
    srand(0xABCD);

    // Buffer sizes in bytes.
    int const data_size = Nin * Hin * Win * Cin * (INPUT_BYTES);
    int const output_size = output_shape[0] * output_shape[1] * output_shape[2] * output_shape[3] * (INPUT_BYTES);

    std::cout << "OUTPUT SIZE" << output_size << std::endl;
    int const idxs_size = NUM_INDICES * sizeof(Twgts);


    // Device-facing buffers use plain malloc in TXNRT builds and
    // adf::GMIO::malloc otherwise; cpu_ofm is always plain malloc.
    // NOTE(review): aie_wgt is a fixed 1024-byte buffer that this function
    // never fills before passing it on or writing it to wgt.bin — confirm
    // that its contents are meant to be irrelevant.
#ifdef __TXNRT__
    Telem* aie_src = static_cast<Telem*>(malloc(data_size));
    Telem* aie_ofm = static_cast<Telem*>(malloc(output_size));
    Twgts* aie_wgt = static_cast<Twgts*>(malloc(1024));
    Telem* cpu_ofm = static_cast<Telem*>(malloc(output_size));
#else
    Telem* aie_src = static_cast<Telem*>(adf::GMIO::malloc(data_size));
    Telem* aie_ofm = static_cast<Telem*>(adf::GMIO::malloc(output_size));
    Twgts* aie_wgt = static_cast<Twgts*>(adf::GMIO::malloc(1024));
    Telem* cpu_ofm = static_cast<Telem*>(malloc(output_size));
#endif // __TXNRT__

    // qdq_mode 3 selects the integer branch of init_random_mat; C and C_P are
    // both input_shape[3], so no padding lanes are generated.
    init_random_mat<Telem>(aie_src, input_shape[0], input_shape[1], input_shape[2], input_shape[3], input_shape[3], 3);
    Twgts* idxs = static_cast<Twgts*>(malloc(idxs_size));
    char * idxs_bin_data = static_cast<char*>(malloc(idxs_size));

    // The indices are loaded from idxs.bin into a byte staging buffer and then
    // copied into the typed array.
    read_bin_file("idxs.bin", idxs_bin_data, idxs_size);

    memcpy(idxs, idxs_bin_data, idxs_size);
    
    
    // Expected output, computed on the host.
    cpu_model(input_shape, idxs, aie_src, cpu_ofm, NUM_INDICES);


#if INPUT_PRINTS
    {
        std::cout << "-------------------- INPUT -------------------" << std::endl;
        print_mat(input_shape, aie_src);

        std::cout << "-------------------- IDXS --------------------" << std::endl;

        for (int i = 0; i < NUM_INDICES; ++i)
        {
            std::cout << idxs[i] << std::endl;
            
        }
    }
#endif




    // TXNRT: generate and save the DMA binaries, then dump the input, the
    // aie_wgt buffer and the CPU reference output (cpu_ofm) to .bin files.
    // AIESIM without TXNRT: run the graph between init() and end().
    // When neither macro is defined nothing is executed here and aie_ofm is
    // never written by this function.
#if defined(__AIESIM__) || defined(__TXNRT__)
    #ifdef __TXNRT__
            DmaBins bins = run_dma_layer_config();
            bins.save();
            write_bin_file("ifm.bin", reinterpret_cast<char*>(aie_src), data_size);
            write_bin_file("wgt.bin", reinterpret_cast<char*>(aie_wgt), 1024);
            write_bin_file("ofm.bin", reinterpret_cast<char*>(cpu_ofm), output_size);
    #else
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_ofm, aie_src, aie_wgt);
    g_compute_graph.end();
    #endif // __TXNRT__
#endif // __AIESIM__ || __TXNRT__

    // NOTE(review): this guard tests TXN_MODE while the rest of the file tests
    // __TXNRT__ — confirm both are set consistently, otherwise a TXNRT build
    // would compare against an aie_ofm buffer that was never written.
    // NOTE(review): aie_input_type and cpu_input_type are unused.
#if !(TXN_MODE) && CHECK_OUTPUT_PRINTS
    std::cout << "----------------------- AIE OUTPUT ------------------------" << std::endl;

    std::string aie_input_type = "AIE";

    print_mat(output_shape, aie_ofm);

    std::cout << "--------------------- EXPECTED OUTPUT ---------------------" << std::endl;

    std::string cpu_input_type = "CPU";

    print_mat(output_shape, cpu_ofm);
    
    check_result<Telem>(cpu_ofm, aie_ofm, output_shape);
#endif

    // Release every buffer with the deallocator matching its allocation above.
#ifdef __TXNRT__
    free(aie_src);
    free(aie_ofm);
    free(aie_wgt);
    free(cpu_ofm);
    free(idxs);
    free(idxs_bin_data);
#else
    adf::GMIO::free(aie_src);
    adf::GMIO::free(aie_ofm);
    adf::GMIO::free(aie_wgt);
    free(cpu_ofm);
    free(idxs);
    free(idxs_bin_data);
#endif

    // NOTE: This will exit the simulation if it hangs after completion.
    // The assertion always fails, so unless NDEBUG is defined the process
    // terminates here and the return below is not reached.
    assert(false);

    return 0;
}