#ifndef __TXNRT__
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif // __TXNRT__
#if defined(__AIESIM__) || defined(__TXNRT__)
#include "dma.hpp"
#endif // __AIESIM__ || __TXNRT__

using namespace std;

#include <cassert>
#include <cmath>
#include <cstring>
#include <fstream>
#include <iostream>
#include <limits>
#include <vector>

// Tin: element type of the input feature-map buffer, chosen at compile time.
// QDQ_MODE == 1 always selects a 16-bit element (in that mode init_random_mat
// fills the buffer with bfloat16 bit patterns); every other mode follows
// FIXED_POINT_BIT_SIZE (16 -> uint16_t, anything else -> uint8_t).
#if (QDQ_MODE == 1)
using Tin = uint16_t;
#else
    #if (FIXED_POINT_BIT_SIZE == 16)
        using Tin = uint16_t;
    #else 
        using Tin = uint8_t;
    #endif
#endif

// Tout: element type of the output feature-map buffer, chosen at compile time.
// QDQ_MODE == 0 always selects a 16-bit element (dequant() writes bfloat16 bit
// patterns into a uint16_t buffer); every other mode follows
// FIXED_POINT_BIT_SIZE (16 -> uint16_t, anything else -> uint8_t).
#if (QDQ_MODE == 0)
using Tout = uint16_t;
#else 
    #if (FIXED_POINT_BIT_SIZE == 16)
        using Tout = uint16_t;
    #else
        using Tout = uint8_t;
    #endif
#endif



// Welem: element type of the QDQ parameter array (see init_wgt_mat and the
// qdq_param array in main).
using Welem = int32_t;

// Plain wrapper around a single 16-bit value.
// NOTE(review): presumably meant to hold a bfloat16 bit pattern, but nothing in
// the visible part of this file uses it; the conversion helpers below work on
// raw uint16_t values instead — confirm whether it is still needed.
struct bfloat16_t
{
    uint16_t value;
};

// Writes `size` bytes starting at `data` to `filename` as raw binary,
// replacing any existing file content.
//
// filename: path of the file to create or truncate.
// data:     first byte of the buffer to dump.
// size:     number of bytes to write.
//
// Failures to open or write the file are reported on stderr instead of being
// silently ignored; the function still returns normally so the caller's flow
// is unchanged.
void write_bin_file(std::string filename, char* data, size_t size) {
    std::ofstream file(filename, std::ios::out | std::ios::binary);
    if (!file) {
        std::cerr << "ERROR: cannot open '" << filename << "' for writing\n";
        return;
    }

    file.write(data, static_cast<std::streamsize>(size));
    // Close explicitly so that buffered-write failures are visible in the
    // stream state before we check it.
    file.close();
    if (!file) {
        std::cerr << "ERROR: failed to write " << size << " bytes to '"
                  << filename << "'\n";
    }
}

// Converts an IEEE-754 binary32 value to a bfloat16 bit pattern using
// round-to-nearest-even on the 16 discarded mantissa bits.
//
// x: value to convert.
// Returns the upper 16 bits of the rounded 32-bit representation.
//
// NaN inputs are handled separately: adding the rounding bias to a NaN can
// carry into the exponent/sign bits and turn it into infinity or zero, so a
// NaN is truncated and its quiet bit is forced instead.
uint16_t float_to_bfloat16(float x)
{
    static_assert(sizeof(float) == sizeof(uint32_t), "float must be 32 bits wide");

    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));  // well-defined way to read the object representation

    // NaN: exponent all ones and a non-zero mantissa.
    if ((bits & 0x7fffffffu) > 0x7f800000u) {
        return static_cast<uint16_t>((bits >> 16) | 0x0040u);
    }

    // Round to nearest even: bias is 0x7fff, plus one when the bit that will
    // become the result's LSB is already set (breaks ties towards even).
    const uint32_t lsb = (bits >> 16) & 0x1u;
    bits += 0x7fffu + lsb;

    return static_cast<uint16_t>(bits >> 16);
}

// Reinterprets the 32 bits of `i` as a float (bit-for-bit copy, no numeric
// conversion).
//
// i: raw 32-bit pattern.
// Returns the float whose object representation equals that pattern.
inline float uint_to_float(uint32_t i)
{
    float reinterpreted = 0;
    std::memcpy(&reinterpreted, &i, sizeof(uint32_t));
    return reinterpreted;
}

// Expands a bfloat16 bit pattern to a float by placing it in the upper 16 bits
// of a 32-bit word (lower 16 bits zero) and reinterpreting that word.
//
// bf: bfloat16 bit pattern.
// Returns the corresponding float value.
inline float bfloat16_to_float(uint16_t bf)
{
    const uint32_t widened = static_cast<uint32_t>(bf) << 16;
    return uint_to_float(widened);
}

// Rounds `x` up to a multiple of 8 (for non-negative x; the integer division
// truncates towards zero, which is preserved here).
// Despite the "64" in the name, no minimum of 64 is applied.
int max_64_W8(int x){
    const int alignment = 8;
    const int bumped = x + alignment - 1;
    // bumped - bumped % alignment == (bumped / alignment) * alignment
    return bumped - (bumped % alignment);
}


// Fills an N x Y x X x C_P buffer with pseudo-random test data from rand().
// Channels [0, C) receive random values; padding channels [C, C_P) are zeroed.
//
// data:               destination buffer, N * Y * X * C_P elements.
// N, Y, X:            outer dimensions.
// C:                  number of valid channels.
// C_P:                channel stride (padded channel count).
// qdq_mode:           0, 2, 3 -> integers in [min_uint16, max_uint16];
//                     1       -> bfloat16 patterns of floats in [min_bf16, max_bf16];
//                     anything else prints an error and returns.
// min_uint16/max_uint16, min_bf16/max_bf16: value ranges for the two cases.
// safe_quant_scale:   currently unused; kept so the signature stays unchanged.
//
// Note: in integer mode rand() is consumed only for valid channels, whereas in
// bfloat16 mode it is consumed for every channel including padding.
template<typename T>
void init_random_mat(T* data, int N, int Y, int X, int C, int C_P, int qdq_mode,
                     int min_uint16 = 0, int max_uint16 = 127,
                     float min_bf16 = 3.3, float max_bf16 = 7.7,
                     float safe_quant_scale = 0.1f) // assume scale ~0.1 unless overridden
{
    const bool integer_mode = (qdq_mode == 0 || qdq_mode == 2 || qdq_mode == 3);
    const bool bf16_mode = (qdq_mode == 1);
    const int pixel_count = N * Y * X;

    for (int pixel = 0; pixel < pixel_count; ++pixel) {
        const int base = pixel * C_P;
        for (int ch = 0; ch < C_P; ++ch) {
            const bool valid_channel = (ch < C);

            if (integer_mode) {
                if (valid_channel) {
                    data[base + ch] = (rand() % (max_uint16 - min_uint16 + 1)) + min_uint16;
                } else {
                    data[base + ch] = 0;
                }
            } else if (bf16_mode) {
                // Drawn unconditionally so the random sequence matches
                // regardless of how many channels are padding.
                const float unit = rand() / (float) RAND_MAX;
                const float sample = ((max_bf16 - min_bf16) * unit) + min_bf16;
                data[base + ch] = valid_channel ? float_to_bfloat16(sample) : 0;
            } else {
                printf("ERROR: Invalid qdq_mode %d\n", qdq_mode);
                return;
            }
        }
    }
}


// Populates the first six entries of the QDQ parameter array.
//
// data:     parameter array; must hold at least 6 elements.
//           [0] = 1 and [1] = bfloat16(1.0)   (dequant pair; dequant() reads
//                                              [0] as zero point, [1] as scale)
//           [2] = 1 and [3] = bfloat16(1/1.0) (quant pair; quant() reads
//                                              [2] as zero point, [3] as inverse scale)
//           [4] = dequant enable, [5] = quant enable
// qdq_mode: 0 -> dequant only, 1 -> quant only, 2 -> both, 3 -> neither.
//           Any other value prints "ERROR MODE!" and leaves [4] and [5] untouched.
template<typename T>
void init_wgt_mat(T* data, int qdq_mode) {
    const float scale = 1.0f;

    // Dequant parameters.
    data[0] = 1;
    data[1] = float_to_bfloat16(scale);
    // Quant parameters.
    data[2] = 1;
    data[3] = float_to_bfloat16(1.0f / scale);

    int dq_enable = 0;
    int q_enable = 0;
    switch (qdq_mode) {
        case 0:
            dq_enable = 1;
            q_enable = 0;
            break;
        case 1:
            dq_enable = 0;
            q_enable = 1;
            break;
        case 2:
            dq_enable = 1;
            q_enable = 1;
            break;
        case 3:
            dq_enable = 0;
            q_enable = 0;
            break;
        default:
            std::cout << "ERROR MODE!" << std::endl;
            return;
    }

    data[4] = dq_enable;
    data[5] = q_enable;
}


// Prints an N x Y x X x C buffer to stdout: `msg` first, then one text line
// per (n, y, x) position holding its C channel values, with a blank line after
// each row of X positions and another after each of the N outer slices.
//
// data:       buffer of N * Y * X * C elements, innermost dimension C.
// N, Y, X, C: dimensions.
// msg:        heading printed before the data.
//
// Elements are passed straight to printf("%3d "), so T is expected to be a
// type that promotes to int.
template<typename T>
void print_mat(T* data, int N, int Y, int X, int C, const std::string& msg = "")
{
    std::cout << msg << "\n";

    // The loops visit elements in storage order, so a running offset equals
    // the full 4-D index.
    int offset = 0;
    for (int batch = 0; batch < N; ++batch) {
        for (int row = 0; row < Y; ++row) {
            for (int col = 0; col < X; ++col) {
                for (int ch = 0; ch < C; ++ch, ++offset) {
                    printf("%3d ",  data[offset]) ;
                }
                printf("\n");
            }
            printf("\n");
        }
        printf("\n");
    }
}
 

// Reference dequantization: out = bfloat16((in - zero_point) * scale) for every
// element of an n_in x h_in x w_in x c_in buffer.
//
// in_data:   quantized input values.
// out_data:  receives bfloat16 bit patterns, same element count as in_data.
// qdq_param: parameter array; [0] is the zero point, [1] the scale stored as a
//            bfloat16 bit pattern.
template<typename Tin>
void dequant(Tin* in_data, uint16_t* out_data, int n_in, int h_in, int w_in, int c_in, Welem* qdq_param)
{
    const uint16_t zero_point = qdq_param[0];
    const float scale = bfloat16_to_float(qdq_param[1]);

    // Elements are visited in storage order, so a running offset equals the
    // full 4-D index.
    int offset = 0;
    for (int batch = 0; batch < n_in; ++batch) {
        for (int row = 0; row < h_in; ++row) {
            for (int col = 0; col < w_in; ++col) {
                for (int ch = 0; ch < c_in; ++ch, ++offset) {
                    const float centered = static_cast<float>(in_data[offset]) - zero_point;
                    const float scaled = centered * scale;
                    out_data[offset] = float_to_bfloat16(scaled);
                }
            }
        }
    }
}

// Copies an n_in x h_in x w_in x c_in buffer into a buffer whose channel
// dimension is rounded up to the next multiple of 8, zero-filling the extra
// channels.
//
// in_data:  source, n_in * h_in * w_in * c_in elements.
// out_data: destination, n_in * h_in * w_in * c_pad elements where
//           c_pad = ((c_in + 7) / 8) * 8.
template<typename T>
void pad_channel_dim(T* in_data, T* out_data, int n_in, int h_in, int w_in, int c_in) {
    // Next multiple of 8 (no larger minimum is enforced).
    const int c_pad = ((c_in + 7) / 8) * 8;

    // Both buffers are walked in storage order, so running offsets are enough.
    int src = 0;
    int dst = 0;
    for (int batch = 0; batch < n_in; ++batch) {
        for (int row = 0; row < h_in; ++row) {
            for (int col = 0; col < w_in; ++col) {
                if (c_pad <= 0) {
                    continue;  // nothing to write for this position
                }
                // Real channels first ...
                for (int ch = 0; ch < c_in; ++ch) {
                    out_data[dst++] = in_data[src++];
                }
                // ... then zero padding up to c_pad.
                for (int ch = c_in; ch < c_pad; ++ch) {
                    out_data[dst++] = static_cast<T>(0);
                }
            }
        }
    }
}


// Reference quantization: out = saturate(round(in * inv_scale) + zero_point)
// for every element of an n_in x h_in x w_in x c_in buffer.
//
// in_data:   bfloat16 bit patterns.
// out_data:  receives the quantized codes, same element count as in_data.
// qdq_param: parameter array; [2] is the zero point, [3] the inverse scale
//            stored as a bfloat16 bit pattern.
//
// The result is clamped to [0, max(Tout)]. Converting an out-of-range float to
// an integer type is undefined behaviour, so the upper bound is enforced
// explicitly as well as the lower one; NaN maps to 0.
template<typename Tout>
void quant(uint16_t* in_data, Tout* out_data, int n_in, int h_in, int w_in, int c_in, Welem* qdq_param)
{
    const float inv_s = bfloat16_to_float(qdq_param[3]);
    const uint16_t z = qdq_param[2];

    const Tout max_code = std::numeric_limits<Tout>::max();
    const float max_code_f = static_cast<float>(max_code);

    for (int n = 0; n < n_in; ++n) {
        for (int i = 0; i < h_in; ++i) {
            for (int j = 0; j < w_in; ++j) {
                for (int c = 0; c < c_in; ++c) {
                    const int idx = ((n * h_in * w_in * c_in) +
                                     (i * w_in * c_in) +
                                     (j * c_in) + c);
                    const float val = bfloat16_to_float(in_data[idx]);
                    const float val1 = std::round(val * inv_s) + z;

                    if (!(val1 > 0)) {
                        // Negative, zero or NaN.
                        out_data[idx] = static_cast<Tout>(0);
                    } else if (val1 >= max_code_f) {
                        // Saturate instead of relying on an out-of-range cast.
                        out_data[idx] = max_code;
                    } else {
                        out_data[idx] = static_cast<Tout>(val1);
                    }
                }
            }
        }
    }
}

// Copies a slice of a 4-D (N, Y, X, C) buffer into an output buffer.
//
// matI, Nin..Cin:    source buffer and its dimensions (only Yin, Xin, Cin are
//                    used, as strides).
// matO, Nout..Cout:  destination buffer and its dimensions; every destination
//                    element is written.
// axis:              dimension being sliced (0 = N, 1 = Y, 2 = X, 3 = C).
// out_start/out_stop: half-open source range along `axis`; its length must
//                    equal the destination size along that axis (asserted).
//
// Along every other axis the source is read from index 0, so a destination
// that is smaller than the source on a non-sliced axis simply takes the
// leading part.
template<typename Tin>
void slice_mat(
    Tin* matI, int Nin, int Yin, int Xin, int Cin,
    Tin* matO, int Nout, int Yout, int Xout, int Cout,
    int axis, int out_start, int out_stop)
{
    assert((axis >= 0 && axis <= 3));
    int slice_size = out_stop - out_start;
    assert((axis == 0 && slice_size == Nout) ||
           (axis == 1 && slice_size == Yout) ||
           (axis == 2 && slice_size == Xout) ||
           (axis == 3 && slice_size == Cout));

    // Source offset per dimension: out_start on the sliced axis, 0 elsewhere.
    const int first_n = (axis == 0) ? out_start : 0;
    const int first_y = (axis == 1) ? out_start : 0;
    const int first_x = (axis == 2) ? out_start : 0;
    const int first_c = (axis == 3) ? out_start : 0;

    // Source strides in elements.
    const int stride_x = Cin;
    const int stride_y = Xin * Cin;
    const int stride_n = Yin * Xin * Cin;

    // The destination is written in storage order.
    int dst = 0;
    for (int on = 0; on < Nout; ++on) {
        for (int oy = 0; oy < Yout; ++oy) {
            for (int ox = 0; ox < Xout; ++ox) {
                const int src_base = ((first_n + on) * stride_n) +
                                     ((first_y + oy) * stride_y) +
                                     ((first_x + ox) * stride_x) +
                                     first_c;
                for (int oc = 0; oc < Cout; ++oc) {
                    matO[dst++] = matI[src_base + oc];
                }
            }
        }
    }
}

// Re-packs an n_in x h_in x w_in x c_in buffer with a channel stride of c_in_p.
// Channels below c_in are copied; channels in [c_in, c_in_p) are set to zero.
// If c_in_p < c_in only the first c_in_p channels of each position are kept.
//
// in_data:  source, n_in * h_in * w_in * c_in elements.
// out_data: destination, n_in * h_in * w_in * c_in_p elements.
//
// The source is only read for channels that actually exist: loading
// in_data[... + c] for c >= c_in would read the next position's data and, for
// the last position, run past the end of the buffer.
template<typename Tout>
void slice_mat_padding(Tout* in_data, Tout* out_data, int n_in, int h_in, int w_in, int c_in, int c_in_p)
{
    for (int n = 0; n < n_in; ++n) {
        for (int h = 0; h < h_in; ++h) {
            for (int w = 0; w < w_in; ++w) {
                const int src_base = (n * h_in * w_in * c_in) + (h * w_in * c_in) + (w * c_in);
                const int dst_base = (n * h_in * w_in * c_in_p) + (h * w_in * c_in_p) + (w * c_in_p);
                for (int c = 0; c < c_in_p; ++c) {
                    if (c < c_in) {
                        out_data[dst_base + c] = in_data[src_base + c];
                    } else {
                        out_data[dst_base + c] = static_cast<Tout>(0);
                    }
                }
            }
        }
    }
}

// Compares two N x Y x X x C buffers element by element and reports every
// element whose absolute difference exceeds `threshold`.
//
// expected, received: buffers to compare.
// qdq_mode: selects how elements are interpreted:
//             0      -> bfloat16 bit patterns, compared as floats;
//             2, 3   -> integer codes, compared numerically;
//             other  -> integer codes, compared after a cast to int16_t.
// threshold: largest absolute difference that is still accepted.
// Returns the number of mismatching elements; each one is also printed.
//
// Fixes relative to the previous version:
//  * modes 2 and 3 used the signed difference, so any element where
//    received > expected passed regardless of magnitude; the absolute
//    difference is used now;
//  * integer codes are printed as numbers (an 8-bit T was streamed as a
//    character in mode 3, and mode 2 decoded its integer codes as bfloat16).
//
// NOTE(review): the int16_t cast in the remaining branch mis-compares 16-bit
// unsigned codes that straddle 32767/32768; left unchanged pending
// confirmation of whether that output is meant to be signed.
template<typename T>
int check_result(
    T* expected,
    T* received,
    int N, int Y, int X, int C,
    int qdq_mode,
    int threshold = 8)
{
    const bool is_bf16 = (qdq_mode == 0);
    const bool is_raw_int = (qdq_mode == 2 || qdq_mode == 3);

    int err_count = 0;

    for (int n = 0; n < N; ++n) {
        for (int y = 0; y < Y; ++y) {
            for (int x = 0; x < X; ++x) {
                for (int c = 0; c < C; ++c) {
                    const int idx = ((n * Y * X * C) + (y * X * C) + (x * C) + c);

                    const T e = expected[idx];
                    const T r = received[idx];

                    float diff;
                    if (is_raw_int) {
                        diff = std::fabs(static_cast<float>(e) - static_cast<float>(r));
                    } else if (is_bf16) {
                        diff = std::abs(bfloat16_to_float(e) - bfloat16_to_float(r));
                    } else {
                        diff = std::abs(static_cast<int16_t>(e) - static_cast<int16_t>(r));
                    }

                    if (diff > threshold) {
                        std::cout << "ERROR: [n=" << n << ", y=" << y << ", x=" << x << ", c=" << c << "]: ";
                        if (is_bf16) {
                            std::cout << "Expected: " << bfloat16_to_float(e)
                                      << ", Received: " << bfloat16_to_float(r)
                                      << ", Diff: " << diff << "\n";
                        } else {
                            // Cast so narrow integer types print as numbers.
                            std::cout << "Expected: " << static_cast<int>(e)
                                      << ", Received: " << static_cast<int>(r)
                                      << ", Diff: " << diff << "\n";
                        }
                        ++err_count;
                    }
                }
            }
        }
    }
    return err_count;
}


// Global graph instance used by main() (init / run_dma_layer_config / end).
// Only defined for non-TXNRT builds, which are the builds that include
// graph.hpp at the top of this file.
#ifndef __TXNRT__
ComputeGraph g_compute_graph;
#endif // __TXNRT__

// Test driver for the slice (+ optional quantize/dequantize) layer.
//
// Steps:
//   1. generate a random input feature map and the QDQ parameter block;
//   2. compute the reference output on the host (slice, optional QDQ, channel
//      padding);
//   3. TXNRT builds: dump ifm/wgt/ofm binaries;
//      AIESIM builds: run the graph, print the buffers and compare the device
//      output against the reference.
//
// All configuration comes from compile-time macros (N_IN, AXIS, QDQ_MODE, ...).
//
// Fixes relative to the previous version:
//   * qdq_param is zero-initialised, so the bytes copied to the device buffer
//     and written to wgt.bin no longer contain uninitialised stack data;
//   * the copy into aie_wgt never reads past the end of qdq_param when
//     WGT_SIZE is larger than the array;
//   * the QDQ reference buffer is only padded when it was actually produced
//     (it is uninitialised in mode 3);
//   * every host allocation is released.
int main() {
    // Fixed seed keeps the stimulus identical from run to run.
    srand(0xABCD);

    int aie_rows = AIE_ROWS;
    int aie_cols = AIE_COLS;
    int n_in = N_IN;
    int y_in = Y_IN;
    int x_in = X_IN;
    int c_in = C_IN;
    int n_out = N_OUT;
    int y_out = Y_OUT;
    int x_out = X_OUT;
    int c_out = C_OUT;
    int axis = AXIS;
    int out_start = OUT_START;
    int out_stop = OUT_STOP;
    
    int qdq_mode = QDQ_MODE; //0: DEQUANT; 1: QUANT; 2: BOTH; 3: NONE
    int enable_cout_pad = COUT_PAD;
    int c_in_padded, c_out_padded;
    if (enable_cout_pad == 1) {
        c_in_padded = max_64_W8(c_in);
        c_out_padded = max_64_W8(c_out);
    } else  {
        c_in_padded = c_in;
        c_out_padded = c_out;
    }

    // Buffer sizes in bytes.
    int wgt_size  = WGT_SIZE; 
    uint32_t ifm_size = n_in * y_in * x_in * c_in_padded * sizeof(Tin);
    uint32_t cpu_ofm_size = n_out * x_out * y_out * c_out_padded * sizeof(Tin);
    uint32_t cpu_ofm_size_no_pad = n_out * x_out * y_out * c_out * sizeof(Tin);
    uint32_t aie_ofm_size = n_out * x_out * y_out * c_out_padded * sizeof(Tout);
    uint32_t scratch_size = n_out * x_out * y_out * c_out * sizeof(uint16_t);
    uint32_t qdq_output_no_pad_size = n_out * x_out * y_out * c_out * sizeof(Tout);


    std::cout << "TYPE OF INPUT " << sizeof(Tin) << std::endl;
    std::cout << "TYPE OF OUTPUT " << sizeof(Tout) << std::endl;


#ifdef __TXNRT__
    auto aie_ifm = static_cast<Tin*>(malloc(ifm_size));
    auto aie_wgt = static_cast<Welem*>(malloc(wgt_size));
    auto aie_ofm = static_cast<Tout*>(malloc(aie_ofm_size));
#else
    auto aie_ifm = static_cast<Tin*>(adf::GMIO::malloc(ifm_size));
    auto aie_wgt = static_cast<Welem*>(adf::GMIO::malloc(wgt_size));
    auto aie_ofm = static_cast<Tout*>(adf::GMIO::malloc(aie_ofm_size));
#endif // __TXNRT__
    auto cpu_ofm_nopad = static_cast<Tin*>(malloc(cpu_ofm_size_no_pad));
    auto cpu_ofm = static_cast<Tin*>(malloc(cpu_ofm_size));
    auto qdq_output = static_cast<Tout*>(malloc(aie_ofm_size));
    auto qdq_output_no_pad = static_cast<Tout*>(malloc(qdq_output_no_pad_size));
    auto scratch_buffer = static_cast<uint16_t*>(malloc(scratch_size));



    Tin* in_mat = aie_ifm;
    Tout* aie_out_mat = aie_ofm;
    Tin* cpu_out_mat = cpu_ofm;

    init_random_mat<Tin>(in_mat, n_in, y_in, x_in, c_in, c_in_padded, qdq_mode);

    // Zero-initialised: init_wgt_mat only sets the first six entries, and the
    // whole block is copied to the device buffer below.
    Welem qdq_param[64] = {};
    init_wgt_mat<Welem>(qdq_param, qdq_mode);
    printf("qdq_mode : %d\n", qdq_mode);

    // Copy at most sizeof(qdq_param) bytes; anything beyond that in aie_wgt
    // stays zero instead of being read from past the end of the array.
    const size_t wgt_bytes = static_cast<size_t>(wgt_size);
    const size_t wgt_copy_bytes = (wgt_bytes < sizeof(qdq_param)) ? wgt_bytes : sizeof(qdq_param);
    memset(aie_wgt, 0, wgt_bytes);
    memcpy(aie_wgt, (void*)qdq_param, wgt_copy_bytes);

    // Host reference: slice first, then pad the channel dimension.
    slice_mat<Tin>(
        in_mat, n_in, y_in, x_in, c_in_padded,
        cpu_ofm_nopad, n_out, y_out, x_out, c_out,
        axis, out_start, out_stop);
    slice_mat_padding<Tin>(
            cpu_ofm_nopad, cpu_out_mat,
            n_out, y_out, x_out, c_out, c_out_padded
        );
    

        
#if QDQ_MODE == 0
        dequant<Tin>(cpu_ofm_nopad, qdq_output_no_pad, n_out, y_out, x_out, c_out, qdq_param);
#elif QDQ_MODE == 1
        quant<Tout>(cpu_ofm_nopad, qdq_output_no_pad, n_out, y_out, x_out, c_out, qdq_param);        
#elif QDQ_MODE == 2
        dequant<Tin>(cpu_ofm_nopad, scratch_buffer, n_out, y_out, x_out, c_out, qdq_param);
        quant<Tout>(scratch_buffer, qdq_output_no_pad, n_out, y_out, x_out, c_out, qdq_param);
#endif

    // In mode 3 no QDQ reference is produced (and qdq_output is not used
    // below), so there is nothing valid to pad.
#if (QDQ_MODE != 3)
    slice_mat_padding<Tout>(
        qdq_output_no_pad, qdq_output,
        n_out, y_out, x_out, c_out, c_out_padded
    );
#endif


#if defined(__AIESIM__) || defined(__TXNRT__)
    #ifdef __TXNRT__
            DmaBins bins = run_dma_layer_config();
            bins.save();
            write_bin_file("ifm.bin", reinterpret_cast<char*>(aie_ifm), ifm_size);
            write_bin_file("wgt.bin", reinterpret_cast<char*>(aie_wgt), wgt_size);
#if (QDQ_MODE == 3)
            write_bin_file("ofm.bin", reinterpret_cast<char*>(cpu_ofm), aie_ofm_size);
#else
            write_bin_file("ofm.bin", reinterpret_cast<char*>(qdq_output), aie_ofm_size);
#endif
    #else
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_ofm, aie_ifm, aie_wgt);
    g_compute_graph.end();

    print_mat(in_mat, n_in, y_in, x_in, c_in_padded, "AIE IFM =\n");
#if (QDQ_MODE == 3)
    print_mat<Tout>(cpu_out_mat, n_out, y_out, x_out, c_out_padded, "CPU OFM =\n");
#else
    print_mat<Tout>(qdq_output, n_out, y_out, x_out, c_out_padded, "CPU QDQ OFM =\n");
#endif
    print_mat<Tout>(aie_out_mat, n_out, y_out, x_out, c_out_padded, "AIE OFM =\n");

    int threshold = 8;
    int err_count = 0;
    

#if (QDQ_MODE == 3)
        err_count = check_result<Tout>(cpu_out_mat, aie_out_mat, n_out, y_out, x_out, c_out_padded, qdq_mode, threshold);
#else
        err_count = check_result<Tout>(qdq_output, aie_out_mat, n_out, y_out, x_out, c_out_padded, qdq_mode, threshold);
#endif


    if (err_count == 0) {
        printf("DI: PASS\n");
    } else {
        printf("DI: FAIL\n");
    }
    printf("Error Count = %d\n", err_count);
    #endif // __TXNRT__
#endif // __AIESIM__ || __TXNRT__
    #ifdef __TXNRT__
        free(aie_ifm);
        free(aie_wgt);
        free(aie_ofm);
    #else
        adf::GMIO::free(aie_ifm);
        adf::GMIO::free(aie_wgt);
        adf::GMIO::free(aie_ofm);
    #endif // __TXNRT__
    free(cpu_ofm);
    free(cpu_ofm_nopad);
    free(qdq_output);
    free(qdq_output_no_pad);
    free(scratch_buffer);

    // NOTE(review): this aborts every non-TXNRT build that does not define
    // NDEBUG, after the results have been printed. Presumably intentional
    // (e.g. to force the simulation to stop) — confirm.
    #ifndef __TXNRT__
    assert(false);
    #endif // __TXNRT__
}
