#include "elemwise_qdq.hpp"
#ifndef __TXNRT__
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif // __TXNRT__
#if defined(__AIESIM__) || defined(__TXNRT__)
#include "dma.hpp"
#endif // __AIESIM__ || __TXNRT__

using namespace std;

#include <iostream>
#include <vector>
#include <cstring>
#include <cassert>
#include <cstdint>
#include <fstream>
#include <string>
#include <algorithm>
#include <cmath>
#include <sstream>


/*  ------- QDQ mode and bit width
    if is_int16:
        ifm_bits = 16
        ofm_bits = ifm_bits
        has_scratch_buf = False
        transpose_bits = 16
    else: # int8
        if qdq_mode == 0:  #dq only
            #NOTE: sequence:
            # 1. first do the transpose (8-bit input) -> 8-bit data in the 2nd half of the output buffer;
            # 2. then do dq, from that 8-bit 2nd half of the output buffer to the 16-bit output buffer
            # scratch buf elem:  0
            ifm_bits = 8
            ofm_bits = 16
            has_scratch_buf = False
            scratch_buf_bits = 8
            transpose_bits = 8
        elif qdq_mode == 1: #q only
            #NOTE: sequence:
            # 1. first do q (16-bit input buffer) -> 8 bits written back to the same buffer;
            # 2. then do the transpose, from the 8-bit input buffer to the 8-bit output buffer
            # scratch buf elem:  0
            ifm_bits = 16
            ofm_bits = 8
            has_scratch_buf = False # q output use ifm buffer
            scratch_buf_bits = 8
            transpose_bits = 8
        elif qdq_mode == 2:
            #NOTE: sequence:
            # 1. first do dq (8-bit input buffer) -> 16 bits to the scratch buffer;
            # 2. second do q (16-bit scratch buffer) -> 8 bits to the scratch buffer;
            # 3. then do the transpose, from the 8-bit scratch buffer to the 8-bit output buffer
            # scratch buf elem:  same as ifm
            ifm_bits = 8
            ofm_bits = 8
            has_scratch_buf = True
            scratch_buf_bits = 16
            transpose_bits = 8
        elif qdq_mode == 3:
            #NOTE: sequence:
            # 1. do the transpose from the 8-bit input buffer to the 8-bit output buffer
            # scratch buf elem:  0
            ifm_bits = 8
            ofm_bits = 8
            has_scratch_buf = False
            scratch_buf_bits = 8
            transpose_bits = 8

*/


// Element types for this build, selected from IS_INT16 / QDQ_MODE / IS_SIGNED
// following the bit-width table above. The bit width is decided first, then
// the signedness (IS_SIGNED == 0 picks the unsigned variants).
#if IS_INT16 == 1
    // int16 build: QDQ_MODE is not consulted; input and output share the 16-bit type.
    #if IS_SIGNED == 0
        using Tin  = uint16_t;
        using Tout = uint16_t;
    #else
        using Tin  = int16_t;
        using Tout = int16_t;
    #endif
#elif QDQ_MODE == 0
    // dq only: 8-bit in, 16-bit out.
    #if IS_SIGNED == 0
        using Tin  = uint8_t;
        using Tout = uint16_t;
    #else
        using Tin  = int8_t;
        using Tout = int16_t;
    #endif
#elif QDQ_MODE == 1
    // q only: 16-bit in, 8-bit out.
    #if IS_SIGNED == 0
        using Tin  = uint16_t;
        using Tout = uint8_t;
    #else
        using Tin  = int16_t;
        using Tout = int8_t;
    #endif
#elif QDQ_MODE == 2 || QDQ_MODE == 3
    // dq + q, or plain transpose: 8-bit in and out.
    #if IS_SIGNED == 0
        using Tin  = uint8_t;
        using Tout = uint8_t;
    #else
        using Tin  = int8_t;
        using Tout = int8_t;
    #endif
#else
    #error "INVALID QDQ_MODE"
#endif

// Element type of the QDQ parameter buffer.
using Welem = int32_t;


template<typename T>
void init_wgt_mat(T* data, int qdq_mode) {
    // Fills the QDQ parameter buffer:
    //   data[0] dq zero point   data[1] dq scale (float_to_bfloat16)
    //   data[2] q zero point    data[3] 1/q scale (float_to_bfloat16)
    //   data[4] dq enable flag  data[5] q enable flag
    // Both parameter sets come from compute_scale_and_zp over the float range
    // [0, 128]; the q set uses Tout's type in mode 1 (q only) and Tin otherwise.
    float range_lo = 0.0f; // the range endpoints are arbitrary test values
    float range_hi = 128.0f;
    int32_t zero_point;
    float step;

    compute_scale_and_zp<Tin>(range_lo, range_hi, step, zero_point);
    data[0] = zero_point;
    data[1] = float_to_bfloat16(step);

    if (qdq_mode != 1) {
        compute_scale_and_zp<Tin>(range_lo, range_hi, step, zero_point);
    } else {
        compute_scale_and_zp<Tout>(range_lo, range_hi, step, zero_point);
    }
    data[2] = zero_point;
    data[3] = float_to_bfloat16(1 / step);

    switch (qdq_mode) {
    case 0: data[4] = 1; data[5] = 0; break; // dq only
    case 1: data[4] = 0; data[5] = 1; break; // q only
    case 2: data[4] = 1; data[5] = 1; break; // dq then q
    case 3: data[4] = 0; data[5] = 0; break; // neither (plain transpose)
    default:
        std::cout << "ERROR MODE!" << std::endl;
        break;
    }
}



// Writes `size` raw bytes from `data` to `filename`, replacing any existing
// file. A failure to open or to write is reported on stderr rather than being
// silently ignored, so a missing or truncated bin dump is noticed.
void write_bin_file(std::string filename, char* data, size_t size) {
    std::ofstream file(filename, std::ios::out | std::ios::binary);
    if (!file) {
        std::cerr << "write_bin_file: unable to open " << filename << std::endl;
        return;
    }
    file.write(data, static_cast<std::streamsize>(size));
    if (!file) {
        std::cerr << "write_bin_file: failed writing " << size
                  << " bytes to " << filename << std::endl;
    }
}
// For x >= 0 and m > 0, rounds x up to the nearest multiple of m.
// (Uses C++ truncating division, so negative x is not rounded toward +inf.)
int iceil(int x, int m)
{
    const int chunks = (x + m - 1) / m;
    return chunks * m;
}

// Fills `data` with batch_size back-to-back row-major tensors of padded shape
// [dim1_p, dim2_p, dim3_p, dim4_p]. Positions inside the logical shape
// [dim1, dim2, dim3, dim4] get a rand()-based value in [min_f, max_f];
// padding positions use 0.0f. Each value is then stored as
// float_to_bfloat16(value) when QDQ_MODE == 1, otherwise as
// quantize<T>(value, scale, zp).
template <typename T>
void init_random_5D_matrix(T* data, int batch_size,
    int dim1, int dim2, int dim3, int dim4,
    int dim1_p, int dim2_p, int dim3_p, int dim4_p,
    uint16_t zp, float scale,
    float min_f = 0.0f, float max_f = 127.0f
) {
    int offset = 0; // row-major position; advances once per visited element
    for (int b = 0; b < batch_size; ++b) {
        for (int n = 0; n < dim1_p; ++n) {
            for (int y = 0; y < dim2_p; ++y) {
                for (int x = 0; x < dim3_p; ++x) {
                    for (int c = 0; c < dim4_p; ++c, ++offset) {
                        const bool is_real = n < dim1 && y < dim2 && x < dim3 && c < dim4;
                        float value = 0.0f;
                        if (is_real) {
                            value = ((max_f - min_f) * (rand() / float(RAND_MAX))) + min_f;
                        }
                        data[offset] = (QDQ_MODE == 1) ? float_to_bfloat16(value)
                                                       : quantize<T>(value, scale, zp);
                    }
                }
            }
        }
    }
}

// Prints `msg`, then the batch_size row-major tensors of shape
// [dim1, dim2, dim3, dim4] stored back to back in `data`. The innermost dim is
// printed one row per line; blank lines and "----" separate the outer dims.
template <typename T>
void print_5D_matrix(T* data, int batch_size, int dim1, int dim2, int dim3, int dim4, const std::string& msg) {
    std::cout << msg << std::endl;
    const T* cursor = data; // walks the buffer in the same row-major order as the loops
    for (int b = 0; b < batch_size; ++b) {
        for (int i = 0; i < dim1; ++i) {
            for (int j = 0; j < dim2; ++j) {
                for (int k = 0; k < dim3; ++k) {
                    for (int l = 0; l < dim4; ++l) {
                        const T& value = *cursor++;
                        std::cout << +value << " "; // unary + prints 8-bit types as numbers
                    }
                    std::cout << std::endl;
                }
                std::cout << std::endl;
            }
            std::cout << "----" << std::endl;
        }
        std::cout << "----" << std::endl;
    }
}

// Reference (CPU) transpose of a batch of 4-D tensors, applying the QDQ step
// selected by qdq_mode to each element on the way through.
//
//   in_data   : batch_size back-to-back row-major tensors of shape in_dims.
//   out_data  : batch_size back-to-back row-major tensors; output dim d has
//               extent in_dims[perm[d]] and is indexed by input index perm[d].
//   perm      : 4 entries, each expected in [0, 3].
//   qdq_mode  : 0 = dq only (result stored via float_to_bfloat16),
//               1 = q only (input read via bfloat16_to_float),
//               2 = dq then q, 3 = plain copy.
//   qdq_param : [0] dq zero point, [1] dq scale (bf16), [2] q zero point,
//               [3] 1/q-scale (bf16) -- the layout init_wgt_mat writes.
// An unsupported qdq_mode is reported on stderr and yields zeros.
template <typename Tin, typename Tout>
void transpose_5D_matrix(
    const Tin* in_data,
    Tout* out_data,
    int batch_size,
    const std::vector<int>& in_dims,     // [D0, D1, D2, D3]
    const std::vector<int>& perm,        // length = 4, values in [0..3]
    int qdq_mode,
    Welem* qdq_param)
{
    if (qdq_mode < 0 || qdq_mode > 3) {
        std::cerr << "transpose_5D_matrix: unsupported qdq_mode " << qdq_mode
                  << ", writing zeros" << std::endl;
    }

    // Original dims
    const int D0 = in_dims[0];
    const int D1 = in_dims[1];
    const int D2 = in_dims[2];
    const int D3 = in_dims[3];

    // Permuted dims
    const std::vector<int> out_dims = {
        in_dims[perm[0]],
        in_dims[perm[1]],
        in_dims[perm[2]],
        in_dims[perm[3]]
    };

    // Row-major strides for the input
    const int stride0 = D1 * D2 * D3;
    const int stride1 = D2 * D3;
    const int stride2 = D3;

    // Row-major strides for the output
    const int out_stride0 = out_dims[1] * out_dims[2] * out_dims[3];
    const int out_stride1 = out_dims[2] * out_dims[3];
    const int out_stride2 = out_dims[3];

    // QDQ parameters. [3] stores 1/scale, so invert it back to the q scale.
    // NOTE(review): zero points are narrowed to uint16_t here (as in main);
    // confirm that is what quantize/dequantize expect for negative zero points.
    const float q_inv_or_s = 1.0f / bfloat16_to_float(qdq_param[3]);
    const uint16_t q_z = qdq_param[2];

    const float dq_inv_or_s = bfloat16_to_float(qdq_param[1]);
    const uint16_t dq_z = qdq_param[0];

    const size_t in_batch_elems  = (size_t)D0 * D1 * D2 * D3;
    const size_t out_batch_elems = (size_t)out_dims[0] * out_dims[1] * out_dims[2] * out_dims[3];

    for (int b = 0; b < batch_size; ++b)
    {
        const Tin* batch_in = in_data  + (size_t)b * in_batch_elems;
        Tout* batch_out     = out_data + (size_t)b * out_batch_elems;

        for (int i0 = 0; i0 < D0; ++i0)
        for (int i1 = 0; i1 < D1; ++i1)
        for (int i2 = 0; i2 < D2; ++i2)
        for (int i3 = 0; i3 < D3; ++i3)
        {
            const int in_index = i0 * stride0 + i1 * stride1 + i2 * stride2 + i3;

            // Output index d takes input index perm[d].
            const int idx_in[4] = {i0, i1, i2, i3};
            const int out_index =
                idx_in[perm[0]] * out_stride0 +
                idx_in[perm[1]] * out_stride1 +
                idx_in[perm[2]] * out_stride2 +
                idx_in[perm[3]];

            const Tin in_val = batch_in[in_index];
            Tout out_val = static_cast<Tout>(0);

            switch (qdq_mode)
            {
                case 0: // dq only: dequantize the Tin value, emit it as bf16
                {
                    const float ftmp = dequantize<Tin>(in_val, dq_inv_or_s, dq_z);
                    out_val = float_to_bfloat16(ftmp);
                    break;
                }
                case 1: // q only: bf16 input quantized to the OUTPUT type so it
                        // saturates at Tout's range instead of wrapping on narrowing
                    out_val = quantize<Tout>(bfloat16_to_float(in_val), q_inv_or_s, q_z);
                    break;
                case 2: // dq then q
                {
                    const float ftmp = dequantize<Tin>(in_val, dq_inv_or_s, dq_z);
                    out_val = quantize<Tout>(ftmp, q_inv_or_s, q_z);
                    break;
                }
                case 3: // copy
                    out_val = in_val;
                    break;
                default: // unsupported mode (reported above): keep zero
                    break;
            }

            batch_out[out_index] = out_val;
        }
    }
}


//function to depad the cpu_ofm
//
// Copies batch_size back-to-back row-major tensors of shape in_dims into
// out_data, keeping only indices [slice_start, slice_end) along `axis`. Each
// output tensor is row-major with in_dims[axis] replaced by
// slice_end - slice_start, and tensors are packed back to back per batch.
// Output positions with no matching input index (slice_end > in_dims[axis])
// are left unwritten.
template <typename T>
void depad_5D_matrix(const T* in_data, T* out_data,
                     int batch_size,
                     const std::vector<int>& in_dims,   // [N, H, W, C]
                     int axis,
                     int slice_end,
                     int slice_start = 0)
{
    std::vector<int> out_dims(in_dims);
    out_dims[axis] = slice_end - slice_start;

    // Row-major strides of one tensor, built from the innermost dim outward.
    int in_stride[4];
    int out_stride[4];
    in_stride[3] = 1;
    out_stride[3] = 1;
    for (int d = 2; d >= 0; --d) {
        in_stride[d]  = in_stride[d + 1]  * in_dims[d + 1];
        out_stride[d] = out_stride[d + 1] * out_dims[d + 1];
    }
    const int in_tensor  = in_stride[0]  * in_dims[0];
    const int out_tensor = out_stride[0] * out_dims[0];

    for (int b = 0; b < batch_size; ++b) {
        const T* src = in_data + b * in_tensor;
        T* dst = out_data + b * out_tensor;

        for (int n = 0; n < in_dims[0]; ++n)
        for (int h = 0; h < in_dims[1]; ++h)
        for (int w = 0; w < in_dims[2]; ++w)
        for (int c = 0; c < in_dims[3]; ++c) {
            const int pos[4] = {n, h, w, c};
            const int kept = pos[axis];
            const bool inside = kept >= slice_start && kept < slice_end;
            if (!inside) {
                continue;
            }
            // Same position in both tensors, except the sliced axis is shifted.
            int src_off = 0;
            int dst_off = 0;
            for (int d = 0; d < 4; ++d) {
                src_off += pos[d] * in_stride[d];
                dst_off += (d == axis ? kept - slice_start : pos[d]) * out_stride[d];
            }
            dst[dst_off] = src[src_off];
        }
    }
}



// Copy 4D matrix into padded version along inner-most dimension
//
// `input` holds batch_size*n*h*w contiguous rows of c elements. Each row is
// written to `output` as a row of c + pad elements: positions below c are
// copied, the rest are set to zero. A negative pad therefore truncates rows.
template <typename T>
void zero_padding_5D_matrix(T* input, T* output, int batch_size, int n, int h, int w, int c, int pad) {
    // No element is visited when any outer extent is non-positive.
    if (batch_size <= 0 || n <= 0 || h <= 0 || w <= 0) {
        return;
    }
    const int row_len = c + pad;
    const int row_count = batch_size * n * h * w;
    for (int row = 0; row < row_count; ++row) {
        for (int ci = 0; ci < row_len; ++ci) {
            output[row * row_len + ci] = (ci < c) ? input[row * c + ci] : static_cast<T>(0);
        }
    }
}

// Compares two buffers of batch_size back-to-back row-major tensors of shape
// [dim1, dim2, dim3, dim4]. Returns how many elements differ by more than
// `threshold`. Each such element is printed as ERROR; any other element with a
// positive difference is printed as WARNING.
template <typename T>
int check_result(T* expected, T* received, int batch_size,
    int dim1, int dim2, int dim3, int dim4, int threshold = 8
) {
    int failures = 0;
    int pos = 0; // flat offset, advancing in the same row-major order as the loops
    for (int b = 0; b < batch_size; ++b)
    for (int i = 0; i < dim1; ++i)
    for (int j = 0; j < dim2; ++j)
    for (int k = 0; k < dim3; ++k)
    for (int l = 0; l < dim4; ++l, ++pos) {
        const T want = expected[pos];
        const T got = received[pos];
        const float gap = std::abs(static_cast<float>(want) - static_cast<float>(got));

        const bool over = gap > threshold;
        if (!over && !(gap > 0)) {
            continue;
        }
        std::cout << (over ? "ERROR: [b=" : "WARNING: [b=")
                  << b << ", n=" << i << ", h=" << j << ", w=" << k << ", c=" << l << "]: ";
        std::cout << "Expected: " << +want
                  << ", Received: " << +got
                  << ", Diff: " << gap << "\n";
        if (over) {
            ++failures;
        }
    }
    return failures;
}

#ifndef __TXNRT__
ComputeGraph g_compute_graph;
#endif // __TXNRT__

// Test driver for the elementwise QDQ + transpose layer.
//
// shapes.txt: line 1 is the logical (unpadded) input shape, comma-separated;
// every following line is appended to the permutation. Padded shapes come from
// the BATCH_SIZE / N_IP.. / N_OP.. build macros. The driver builds a random
// padded IFM and computes the CPU reference: QDQ + transpose, zero padding of
// the innermost output dim, then depad along the axis that received the
// input's innermost dim. It then either dumps ifm/wgt/ofm bins (__TXNRT__) or
// runs the AIE graph and compares against the reference (__AIESIM__).
// Returns 1 when shapes.txt is missing, or does not describe a 4-D input and
// 4 permutation values in [0, 3].
int main() {
    srand(0xABCD);

    std::vector<int> input, perm;
    std::ifstream file("shapes.txt");
    if (!file.is_open()) {
        std::cerr << "Unable to open file shapes.txt" << std::endl;
        return 1;
    }
    {
        std::vector<int>* current_vector = &input;
        std::string line;
        while (std::getline(file, line)) {
            std::istringstream ss(line);
            std::string value;
            while (std::getline(ss, value, ',')) {
                current_vector->push_back(std::stoi(value));
            }
            // First line is the input shape; every following line feeds perm.
            if (current_vector == &input) {
                current_vector = &perm;
            }
        }
    }
    file.close();

    // Everything below indexes input[0..3] and input[perm[i]] without checks.
    bool shapes_ok = input.size() >= 4 && perm.size() == 4;
    for (size_t i = 0; shapes_ok && i < perm.size(); ++i) {
        shapes_ok = perm[i] >= 0 && perm[i] < 4;
    }
    if (!shapes_ok) {
        std::cerr << "shapes.txt must hold a 4-D input shape on line 1 followed by "
                     "4 permutation values in [0, 3]" << std::endl;
        return 1;
    }

    //the real input based on the graph
    int Ni = input[0];
    int Yi = input[1];
    int Xi = input[2];
    int Ci = input[3];
    //the padded input by host or previous layer
    int batch_size = BATCH_SIZE;
    int Nip = N_IP;
    int Yip = Y_IP;
    int Xip = X_IP;
    int Cip = C_IP;

    int Nop = N_OP;
    int Yop = Y_OP;
    int Xop = X_OP;
    int Cop = C_OP;
    //the input to AIE transpose OP
    std::vector<int> input_pad = {Nip, Yip, Xip, Cip};
    //the output from AIE transpose OP
    std::vector<int> output_pad = {Nop, Yop, Xop, Cop};

    // Transposed dims: with the IFM padding, and with the OFM padding.
    std::vector<int> out_dims_no_depad(4), out_dims_pad(4);
    for (size_t i = 0; i < perm.size(); ++i) {
        out_dims_no_depad[i] = input_pad[perm[i]];
        out_dims_pad[i] = output_pad[perm[i]];
    }

    int ifm_size_pad = batch_size * Nip * Yip * Xip * Cip * sizeof(Tin);
    int ofm_size_from_ifm = batch_size * Nip * Yip * Xip * Cip * sizeof(Tout);
    int wgt_size = 64;
    int ofm_size_pad = batch_size * Nop * Yop * Xop * Cop * sizeof(Tout);
    int cpu_ofm_size_pad_no_depad = batch_size * Nop * Yop * Xop * Cip * sizeof(Tout);

    printf("------------------------------------------------------------\n");
    printf("ifm_size_pad: %d\n", ifm_size_pad);
    printf("ofm_size_pad: %d\n", ofm_size_pad);
    printf("cpu_ofm_size_pad_no_depad: %d\n", cpu_ofm_size_pad_no_depad);

    printf("batch_size: %d\n", batch_size);
    printf("Ni: %d\n", Ni);
    printf("Yi: %d\n", Yi);
    printf("Xi: %d\n", Xi);
    printf("Ci: %d\n", Ci);
    printf("Nip: %d\n", Nip);
    printf("Yip: %d\n", Yip);
    printf("Xip: %d\n", Xip);
    printf("Cip: %d\n", Cip);
    printf("Nop: %d\n", Nop);
    printf("Yop: %d\n", Yop);
    printf("Xop: %d\n", Xop);
    printf("Cop: %d\n", Cop);

    printf("perm[0]: %d\n", perm[0]);
    printf("perm[1]: %d\n", perm[1]);
    printf("perm[2]: %d\n", perm[2]);
    printf("perm[3]: %d\n", perm[3]);

    printf("is_signed: %d\n", IS_SIGNED);
    printf("is_int16: %d\n", IS_INT16);
    printf("qdq_mode: %d\n", QDQ_MODE);
    printf("------------------------------------------------------------\n");

    // Reference-side geometry, computed before allocation so the CPU buffers
    // can be sized from it.
    // pad: growth of the innermost output dim from the IFM padding to the OFM padding.
    int pad = out_dims_pad[3] - out_dims_no_depad[3];
    // axis: output position that received the input's innermost dim (perm value 3).
    auto it = std::find(perm.begin(), perm.end(), 3);
    int axis = 3;
    if (it != perm.end())
        axis = std::distance(perm.begin(), it);
    // Layout of cpu_ofm_pad (written by zero padding, read by depad).
    std::vector<int> out_dims_depad = {
        out_dims_no_depad[0], out_dims_no_depad[1], out_dims_no_depad[2],
        out_dims_no_depad[3] + pad};
    // Layout written by depad: `axis` sliced down to out_dims_pad[axis].
    std::vector<int> out_dims_sliced = out_dims_depad;
    out_dims_sliced[axis] = out_dims_pad[axis];

    // The *_size_* byte counts above assume the padding lines up with the
    // permutation; never allocate less than the reference helpers will write.
    auto tout_bytes = [batch_size](const std::vector<int>& dims) {
        size_t elems = batch_size > 0 ? static_cast<size_t>(batch_size) : 0;
        for (int d : dims) {
            elems *= d > 0 ? static_cast<size_t>(d) : 0;
        }
        return elems * sizeof(Tout);
    };
    size_t cpu_ofm_pad_bytes =
        std::max(static_cast<size_t>(cpu_ofm_size_pad_no_depad), tout_bytes(out_dims_depad));
    size_t cpu_ofm_depad_bytes =
        std::max(static_cast<size_t>(ofm_size_pad), tout_bytes(out_dims_sliced));

#ifdef __TXNRT__
    auto aie_ifm = static_cast<Tin*>(malloc(ifm_size_pad));
    auto aie_wgt = static_cast<Welem*>(malloc(wgt_size));
    auto aie_ofm = static_cast<Tout*>(malloc(ofm_size_pad));
#else
    auto aie_ifm = static_cast<Tin*>(adf::GMIO::malloc(ifm_size_pad));
    auto aie_wgt = static_cast<Welem*>(adf::GMIO::malloc(wgt_size));
    auto aie_ofm = static_cast<Tout*>(adf::GMIO::malloc(ofm_size_pad));
#endif // __TXNRT__
    auto cpu_ofm = static_cast<Tout*>(malloc(ofm_size_from_ifm));
    auto cpu_ofm_pad = static_cast<Tout*>(malloc(cpu_ofm_pad_bytes));
    auto cpu_ofm_depad = static_cast<Tout*>(malloc(cpu_ofm_depad_bytes));

    Tin* in_mat = aie_ifm;
    Tout* aie_out_mat = aie_ofm;
    Tout* cpu_out_mat = cpu_ofm;

    // Zero-initialised: init_wgt_mat fills only the leading entries, while the
    // first wgt_size bytes are copied into aie_wgt (and dumped to wgt.bin).
    Welem qdq_param[64] = {};
    if (QDQ_MODE == 0 || QDQ_MODE == 1 || QDQ_MODE == 2 || QDQ_MODE == 3){
        init_wgt_mat<Welem>(qdq_param, QDQ_MODE);
        printf("qdq_mode : %d\n", QDQ_MODE);
        memcpy(aie_wgt, (void*)qdq_param, wgt_size);
    }

    //we take assumption that the ifm is padded by host or one of the previous layers
    // if ip = i, then no padding
    uint16_t zp = qdq_param[0];
    float scale = bfloat16_to_float(qdq_param[1]);
    init_random_5D_matrix(in_mat, batch_size, Ni, Yi, Xi, Ci, Nip, Yip, Xip, Cip, zp, scale);
    // Each buffer is printed with the dims of its actual (padded) layout.
    print_5D_matrix(in_mat, batch_size, Nip, Yip, Xip, Cip, "AIE IFM =\n");

    transpose_5D_matrix<Tin, Tout>(in_mat, cpu_out_mat, batch_size, input_pad, perm, QDQ_MODE, qdq_param);
    print_5D_matrix(cpu_out_mat, batch_size,
        out_dims_no_depad[0], out_dims_no_depad[1], out_dims_no_depad[2], out_dims_no_depad[3], "CPU OFM =\n");

    //zero pad the inner-most dimension up to the OFM padding
    zero_padding_5D_matrix(cpu_out_mat, cpu_ofm_pad, batch_size,
        out_dims_no_depad[0], out_dims_no_depad[1], out_dims_no_depad[2], out_dims_no_depad[3], pad);
    print_5D_matrix(cpu_ofm_pad, batch_size,
        out_dims_depad[0], out_dims_depad[1], out_dims_depad[2], out_dims_depad[3], "CPU OFM PAD =\n");

    //depad the axis holding the original C dim down to its OFM-padded size
    depad_5D_matrix(cpu_ofm_pad, cpu_ofm_depad, batch_size, out_dims_depad, axis, out_dims_pad[axis]);

#if defined(__AIESIM__) || defined(__TXNRT__)
    #ifdef __TXNRT__
            DmaBins bins = run_dma_layer_config();
            bins.save();
            write_bin_file("ifm.bin", reinterpret_cast<char*>(aie_ifm), ifm_size_pad);
            write_bin_file("wgt.bin", reinterpret_cast<char*>(aie_wgt), wgt_size);
            write_bin_file("ofm.bin", reinterpret_cast<char*>(cpu_ofm_depad), ofm_size_pad);
    #else
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_ofm, aie_ifm, aie_wgt);
    g_compute_graph.end();

    print_5D_matrix(aie_out_mat, batch_size, out_dims_pad[0], out_dims_pad[1], out_dims_pad[2], out_dims_pad[3], "AIE OFM =\n");
    int threshold = (QDQ_MODE != 2) ? 1: (IS_INT16 ? 128 : 8);
    int err_count = check_result(cpu_ofm_depad, aie_out_mat, batch_size,
        out_dims_pad[0], out_dims_pad[1], out_dims_pad[2], out_dims_pad[3], threshold
    );
    if (err_count == 0) {
        printf("DI: PASS\n");
    } else {
        printf("DI: FAIL\n");
    }
    printf("Error Count = %d\n", err_count);
    #endif // __TXNRT__
#endif // __AIESIM__ || __TXNRT__
    #ifdef __TXNRT__
        free(aie_ifm);
        free(aie_wgt);
        free(aie_ofm);
    #else
        adf::GMIO::free(aie_ifm);
        adf::GMIO::free(aie_wgt);
        adf::GMIO::free(aie_ofm);
    #endif // __TXNRT__
    free(cpu_ofm);
    free(cpu_ofm_pad);
    free(cpu_ofm_depad);

    // NOTE(review): unconditionally aborts non-TXNRT runs (a no-op under
    // NDEBUG); kept as-is on the assumption that it is deliberate -- confirm.
    #ifndef __TXNRT__
    assert(false);
    #endif // __TXNRT__
}
