#ifndef __TXNRT__
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif // __TXNRT__
#if defined(__AIESIM__) || defined(__TXNRT__)
#include "dma.hpp"
#endif // __AIESIM__ || __TXNRT__
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

using namespace std;

// Writes `size` raw bytes starting at `data` to `filename`, replacing any
// existing file.
//
// Throws std::runtime_error if the file cannot be opened or the write does not
// complete, so a missing or truncated .bin artifact is reported instead of
// being produced silently.
void write_bin_file(const std::string& filename, const char* data, size_t size) {
    std::ofstream file(filename, std::ios::out | std::ios::binary | std::ios::trunc);
    if (!file) {
        throw std::runtime_error("write_bin_file: cannot open '" + filename + "' for writing");
    }
    file.write(data, static_cast<std::streamsize>(size));
    if (!file) {
        throw std::runtime_error("write_bin_file: failed to write " + std::to_string(size) +
                                 " bytes to '" + filename + "'");
    }
}

// Fills a dense row-major [dim1][dim2][dim3][dim4] buffer with the sequence
// 0, 1, 2, ... in storage order. Despite the name, no randomness is involved:
// each element receives its own flat index, converted to T (so a narrow T may
// wrap for large tensors). Unique, ordered values make permutation results
// easy to inspect.
template <typename T>
void init_random_4D_matrix(T* data, int dim1, int dim2, int dim3, int dim4) {
    int flat = 0;  // tracks the row-major offset of the current element
    for (int a = 0; a < dim1; ++a)
        for (int b = 0; b < dim2; ++b)
            for (int c = 0; c < dim3; ++c)
                for (int e = 0; e < dim4; ++e, ++flat)
                    data[flat] = flat;
}

// Prints a dense row-major [dim1][dim2][dim3][dim4] buffer to std::cout.
//
// Layout of the output: `msg` on its own line; then each run of dim4 innermost
// values on one line (every value followed by a space); a blank line after
// each dim2 slice; and "----" after each dim1 slice.
//
// Rows end with '\n' rather than std::endl so the stream is not flushed on
// every row (slow for large tensors). It is flushed once at the end instead.
// Strides are 64-bit so the flat offset cannot overflow int.
template <typename T>
void print_4D_matrix(T* data, int dim1, int dim2, int dim3, int dim4, const std::string& msg) {
    std::ostream& os = std::cout;
    const long long stride3 = dim4;
    const long long stride2 = static_cast<long long>(dim3) * dim4;
    const long long stride1 = static_cast<long long>(dim2) * stride2;

    os << msg << '\n';
    for (int i = 0; i < dim1; ++i) {
        for (int j = 0; j < dim2; ++j) {
            for (int k = 0; k < dim3; ++k) {
                const T* row = data + i * stride1 + j * stride2 + k * stride3;
                for (int l = 0; l < dim4; ++l) {
                    os << row[l] << ' ';
                }
                os << '\n';
            }
            os << '\n';
        }
        os << "----\n";
    }
    os << std::flush;
}

// CPU reference for DepthToSpace on a dense NHWC tensor (channels innermost).
//   input : [batch][height][width][depth]
//   output: [batch][height*blockSize][width*blockSize][depth/(blockSize^2)]
//
// Output pixel (h, w) reads input pixel (h / bs, w / bs). Its position inside
// the bs x bs block is (oh, ow) = (h % bs, w % bs). The input channel
// follows the ONNX DepthToSpace / TensorFlow depth_to_space definitions:
//   DCR (DCR == true):  inC = (oh * bs + ow) * newDepth + d   (block offset = high-order part)
//   CRD (DCR == false): inC = d * bs * bs + oh * bs + ow      (block offset = low-order part)
// NOTE(review): make sure the device kernel's PERMUTE_MODE uses this same convention.
//
// Throws std::invalid_argument if blockSize <= 0 or depth is not divisible by
// blockSize^2.
template <typename T>
void depthToSpace(T* input,
                  T* output,
                  int batch, int height, int width, int depth, int blockSize,
                  bool DCR = true) {  // default is DCR(true), CRD(false)
    if (blockSize <= 0) {
        throw std::invalid_argument("blockSize must be positive.");
    }
    if (depth % (blockSize * blockSize) != 0) {
        throw std::invalid_argument("Depth must be divisible by blockSize^2.");
    }

    const int newDepth = depth / (blockSize * blockSize);
    const int newHeight = height * blockSize;
    const int newWidth = width * blockSize;

    // Flat offsets in size_t so large tensors do not overflow int arithmetic.
    const size_t inRowStride = static_cast<size_t>(width) * depth;
    const size_t inBatchStride = static_cast<size_t>(height) * inRowStride;
    const size_t outRowStride = static_cast<size_t>(newWidth) * newDepth;
    const size_t outBatchStride = static_cast<size_t>(newHeight) * outRowStride;

    for (int b = 0; b < batch; ++b) {
        for (int h = 0; h < newHeight; ++h) {
            const int inHeight = h / blockSize;
            const int offsetH = h % blockSize;
            for (int w = 0; w < newWidth; ++w) {
                const int inWidth = w / blockSize;
                const int offsetW = w % blockSize;
                const int blockOffset = offsetH * blockSize + offsetW;

                const T* src = input + b * inBatchStride + inHeight * inRowStride +
                               static_cast<size_t>(inWidth) * depth;
                T* dst = output + b * outBatchStride + h * outRowStride +
                         static_cast<size_t>(w) * newDepth;

                for (int d = 0; d < newDepth; ++d) {
                    const int inDepth = DCR ? (blockOffset * newDepth + d)                 // DCR
                                            : (d * blockSize * blockSize + blockOffset);   // CRD
                    dst[d] = src[inDepth];
                }
            }
        }
    }
}

// Compares two dense row-major [dim1][dim2][dim3][dim4] buffers element by
// element. Each mismatch is reported on std::cout with its 4-D coordinate and
// both values. Returns the number of mismatching elements (0 means identical).
template <typename T>
int check_result(T* expected, T* received, int dim1, int dim2, int dim3, int dim4) {
    int mismatches = 0;
    int flat = 0;  // row-major offset of the current [i][j][k][l] element
    for (int i = 0; i < dim1; ++i)
        for (int j = 0; j < dim2; ++j)
            for (int k = 0; k < dim3; ++k)
                for (int l = 0; l < dim4; ++l, ++flat) {
                    const T want = expected[flat];
                    const T got = received[flat];
                    if (want != got) {
                        std::cout << "ERROR: [" << i << ", " << j << ", " << k << ", " << l << "]: "
                                  << "Expected: " << want << ", "
                                  << "Received: " << got << "\n";
                        ++mismatches;
                    }
                }
    return mismatches;
}

#ifndef __TXNRT__
// Global ADF graph instance. main() drives it with init() / end() around
// run_dma_layer_config() in the __AIESIM__ flow. ComputeGraph is presumably
// declared in graph.hpp (included above for non-__TXNRT__ builds) — TODO confirm.
// It is not built for the transaction-runtime (__TXNRT__) flow, which only emits
// DMA configuration and data .bin files.
ComputeGraph g_compute_graph;
#endif // __TXNRT__

// Test driver for the DepthToSpace layer.
//
// Builds a sequential NHWC input of shape [N_BATCH][Y_HEIGHT][X_WIDTH][C_DEPTH]
// and computes the CPU reference with depthToSpace(). Then:
//   * __TXNRT__ : generates the DMA configuration bins (run_dma_layer_config /
//                 DmaBins::save) and dumps ifm.bin, wgt.bin and ofm.bin. The CPU
//                 reference output is what gets written as ofm.bin.
//   * __AIESIM__: runs the ADF graph on the same buffers and compares the AIE
//                 output with the CPU reference, printing DI: PASS / DI: FAIL.
// The configuration macros (N_BATCH, C_DEPTH, BLOCK_SIZE, PERMUTE_MODE, ...) are
// not defined in this file; presumably they come from the build or super.hh.
int main() {
    srand(0xABCD);  // NOTE(review): no rand() use is visible; the input fill is sequential.
    using Telem = uint16_t;

    const int batch = N_BATCH;
    const int depth = C_DEPTH;
    const int height = Y_HEIGHT;
    const int width = X_WIDTH;
    const int blockSize = BLOCK_SIZE;
    // PERMUTE_MODE == 1 selects DCR, any other value CRD (see depthToSpace).
    const bool permuteMode = (PERMUTE_MODE == 1);

    // Validate dimensions before allocating anything.
    if (blockSize <= 0) {
        throw std::invalid_argument("blockSize must be positive.");
    }
    if (depth % (blockSize * blockSize) != 0) {
        throw std::invalid_argument("Depth must be divisible by blockSize^2.");
    }

    // Output dimensions
    const int outDepth = depth / (blockSize * blockSize);
    const int outHeight = height * blockSize;
    const int outWidth = width * blockSize;

    // Buffer sizes in bytes, computed in size_t. The IFM holds
    // batch * height * width * depth elements. DepthToSpace only rearranges
    // elements, so the OFM has exactly the same size.
    const size_t ifm_size = static_cast<size_t>(batch) * height * width * depth * sizeof(Telem);
    const size_t wgt_size = 1;  // placeholder buffer: no weight data is produced for this layer
    const size_t ofm_size = ifm_size;

#ifdef __TXNRT__
    auto aie_ifm = static_cast<Telem*>(malloc(ifm_size));
    auto aie_wgt = static_cast<Telem*>(malloc(wgt_size));
    auto aie_ofm = static_cast<Telem*>(malloc(ofm_size));
#else
    auto aie_ifm = static_cast<Telem*>(adf::GMIO::malloc(ifm_size));
    auto aie_wgt = static_cast<Telem*>(adf::GMIO::malloc(wgt_size));
    auto aie_ofm = static_cast<Telem*>(adf::GMIO::malloc(ofm_size));
#endif // __TXNRT__
    auto cpu_ofm = static_cast<Telem*>(malloc(ofm_size));
    if (aie_ifm == nullptr || aie_wgt == nullptr || aie_ofm == nullptr || cpu_ofm == nullptr) {
        std::cerr << "ERROR: buffer allocation failed\n";
        return 1;
    }
    // Give deterministic contents to buffers that are otherwise never written
    // by the host: the weight placeholder is dumped to wgt.bin, and the AIE OFM
    // is compared against the reference.
    std::memset(aie_wgt, 0, wgt_size);
    std::memset(aie_ofm, 0, ofm_size);

    Telem* in_mat = aie_ifm;
    Telem* aie_out_mat = aie_ofm;
    Telem* cpu_out_mat = cpu_ofm;

    init_random_4D_matrix<Telem>(in_mat, batch, height, width, depth);
    print_4D_matrix<Telem>(in_mat, batch, height, width, depth, "AIE IFM =\n");

    depthToSpace<Telem>(in_mat, cpu_out_mat, batch, height, width, depth, blockSize, permuteMode);

    print_4D_matrix<Telem>(cpu_out_mat, batch, outHeight, outWidth, outDepth, "CPU OFM =\n");

#if defined(__AIESIM__) || defined(__TXNRT__)
#ifdef __TXNRT__
    DmaBins bins = run_dma_layer_config();
    bins.save();
    write_bin_file("ifm.bin", reinterpret_cast<char*>(aie_ifm), ifm_size);
    write_bin_file("wgt.bin", reinterpret_cast<char*>(aie_wgt), wgt_size);
    write_bin_file("ofm.bin", reinterpret_cast<char*>(cpu_ofm), ofm_size);
#else
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_ofm, aie_ifm, aie_wgt);
    g_compute_graph.end();

    print_4D_matrix<Telem>(aie_out_mat, batch, outHeight, outWidth, outDepth, "AIE OFM =\n");

    int err_count = check_result<Telem>(cpu_out_mat, aie_out_mat, batch, outHeight, outWidth, outDepth);
    if (err_count == 0) {
        printf("DI: PASS\n");
    } else {
        printf("DI: FAIL\n");
    }
    printf("Error Count = %d\n", err_count);
#endif // __TXNRT__
#endif // __AIESIM__ || __TXNRT__

#ifdef __TXNRT__
    free(aie_ifm);
    free(aie_wgt);
    free(aie_ofm);
#else
    adf::GMIO::free(aie_ifm);
    adf::GMIO::free(aie_wgt);
    adf::GMIO::free(aie_ofm);
#endif // __TXNRT__
    free(cpu_ofm);

#ifndef __TXNRT__
    // NOTE(review): this aborts every non-__TXNRT__ build unconditionally
    // (unless NDEBUG compiles assert out). It is presumably a deliberate way to
    // terminate the simulator run — confirm before removing.
    assert(false);
#endif // __TXNRT__
    return 0;
}
