#include <string>
#include <sstream>
#include <fstream>
#include <assert.h>
#include <stdlib.h>
#include <stdio.h>
#if !ASM_MODE
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif // !ASM_MODE
#ifdef __AIESIM__
#if !ASM_MODE
#include "dma.hpp"
#endif // !ASM_MODE
#endif // __AIESIM__
#include "maxpool.hpp"

#if !ASM_MODE
// Global compute graph instance. In the code visible here it is only used by
// run_maxpool_noqdq_a8 in __AIESIM__ builds, where it is either handed to
// run_cert_sim (USE_CERT_LIBRARY) or driven through init() /
// run_dma_layer_config() / end(). It is not defined in ASM_MODE builds.
ComputeGraph g_compute_graph;
#endif // !ASM_MODE

/// Runs one MAXPOOL_NOQDQ_A8 test case.
///
/// The input tensor is either loaded from "../intermediate_bins/ifm1.bin" or
/// filled by init_random_maxpool_noqdq_a8, and the CPU reference output is
/// computed with cpu_maxpool_noqdq_a8. What happens next depends on the build:
///   - ASM_MODE:                 the IFM, parameter block and reference OFM are
///                               written to ifm.bin / wgt.bin / ofm.bin together
///                               with the external-buffer JSON.
///   - __AIESIM__ && !ASM_MODE:  the graph is executed, its output is compared
///                               with the reference and a DI_PASS / DI_FAIL
///                               line is printed.
///
/// @param Yo,Xo,Co  Output tensor dimensions; Co is also used as the channel
///                  count of the input tensor.
/// @param Yi,Xi     Input tensor dimensions.
/// @param Ky,Kx     Kernel size, forwarded to cpu_maxpool_noqdq_a8.
/// @param Sy,Sx     Stride, forwarded to cpu_maxpool_noqdq_a8.
/// @param Py,Px     Padding, forwarded to cpu_maxpool_noqdq_a8.
/// @param sign      Forwarded to the init, reference and compare helpers.
/// @param read_ifm  Non-zero: read the IFM from the bin file; zero: random init.
/// @return Always 0. The comparison result is only reported on stdout.
///         NOTE(review): consider propagating the compare result to callers.
int run_maxpool_noqdq_a8(
    int const Yo, int const Xo, int const Co,
    int const Yi, int const Xi,
    int const Ky, int const Kx,
    int const Sy, int const Sx,
    int const Py, int const Px,
    int const sign,
    int const read_ifm
){
    int ifm_size = ActTensor<int8_t>::size(Co, Yi, Xi);
    int ofm_size = ActTensor<int8_t>::size(Co, Yo, Xo);
    int wgt_size = static_cast<int>(sizeof(MaxPool_NoQdqParams_A8));

    printf("DDR IFM SIZE = %d \n", ifm_size);
    printf("DDR OFM SIZE = %d \n", ofm_size);

    // Buffers that the graph accesses come from adf::GMIO::malloc; in ASM_MODE
    // nothing is executed, so plain malloc is used and no AIE output buffer is
    // needed.
#if !ASM_MODE
    ActTensor<int8_t> aie_ifm(
        Co, Yi, Xi,
        adf::GMIO::malloc(ifm_size)
    );
    ActTensor<int8_t> aie_ofm(
        Co, Yo, Xo,
        adf::GMIO::malloc(ofm_size)
    );
    auto qdq_params = reinterpret_cast<MaxPool_NoQdqParams_A8*>(adf::GMIO::malloc(wgt_size));
#else
    ActTensor<int8_t> aie_ifm(
        Co, Yi, Xi,
        malloc(ifm_size)
    );
    auto qdq_params = reinterpret_cast<MaxPool_NoQdqParams_A8*>(malloc(wgt_size));
#endif // !ASM_MODE
    // The CPU reference output always lives in ordinary heap memory.
    ActTensor<int8_t> cpu_ofm(
        Co, Yo, Xo,
        malloc(ofm_size)
    );

    std::string const ifm_bin_path = "../intermediate_bins/ifm1.bin";
    if (read_ifm) {
        // NOTE(review): on this path nothing in this function writes to
        // qdq_params before it is dumped to wgt.bin / handed to the graph.
        // Confirm that the parameter block content is irrelevant here.
        read_bin_file(ifm_bin_path, reinterpret_cast<char*>(aie_ifm.data), ifm_size);
    } else {
        init_random_maxpool_noqdq_a8(aie_ifm, qdq_params, sign);
    }
    cpu_maxpool_noqdq_a8(aie_ifm, cpu_ofm, Sy, Sx, Py, Px, Ky, Kx, sign);
#if ASM_MODE
    write_bin_file("ifm.bin", reinterpret_cast<char*>(aie_ifm.data), ifm_size);
    write_bin_file("wgt.bin", reinterpret_cast<char*>(qdq_params), wgt_size);
    write_bin_file("ofm.bin", reinterpret_cast<char*>(cpu_ofm.data), ofm_size);
    write_external_buffer_json(ofm_size, ifm_size, wgt_size);
#endif // ASM_MODE
#ifdef __AIESIM__
#if !ASM_MODE
    log_tensor(aie_ifm, "ifm");
    log_tensor(cpu_ofm, "cpu_ofm");
#if USE_CERT_LIBRARY
    run_cert_sim(g_compute_graph,
                 reinterpret_cast<void*>(aie_ofm.data), ofm_size,
                 reinterpret_cast<void*>(aie_ifm.data), ifm_size,
                 reinterpret_cast<void*>(qdq_params), wgt_size);
#else
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_ofm.data, aie_ifm.data, qdq_params);
    g_compute_graph.end();
#endif //USE_CERT_LIBRARY
    // Log the tensor produced by the graph (previously the CPU reference was
    // logged a second time under the "aie_ofm" label).
    log_tensor(aie_ofm, "aie_ofm");

    int epsilon = 1;
    int err = cmp_tensor(cpu_ofm, aie_ofm, sign, epsilon);
    // Co is printed for both Ci and Co: the input tensor is built with Co channels.
    if (err == 0) {
        printf("MAXPOOL_NOQDQ_A8 DI_PASS: Yi=%d, Xi=%d, Ci=%d, Yo=%d, Xo=%d, Co=%d, Ky=%d, Kx=%d, Sy=%d, Sx=%d\n", Yi, Xi, Co, Yo, Xo, Co, Ky, Kx, Sy, Sx);
    } else {
        printf("MAXPOOL_NOQDQ_A8 DI_FAIL: Yi=%d, Xi=%d, Ci=%d, Yo=%d, Xo=%d, Co=%d, Ky=%d, Kx=%d, Sy=%d, Sx=%d\n", Yi, Xi, Co, Yo, Xo, Co, Ky, Kx, Sy, Sx);
    }
#endif // !ASM_MODE
#endif // __AIESIM__

    // Release every buffer with the deallocator matching its allocator,
    // including the parameter block, which was previously leaked.
#if !ASM_MODE
    adf::GMIO::free(aie_ifm.data);
    adf::GMIO::free(aie_ofm.data);
    adf::GMIO::free(qdq_params);
#else
    free(aie_ifm.data);
    free(qdq_params);
#endif // !ASM_MODE
    free(cpu_ofm.data);
    return 0;
}

/// Entry point: reads the test-case description from "maxpool_cfg.json",
/// echoes it to stdout and runs a single MAXPOOL_NOQDQ_A8 test case.
/// Returns 0; the result of the test run is not reflected in the return value.
int main()
{
    auto config = load_json("maxpool_cfg.json");

    // Tensor shape. Only C_OUT is read; it is used for both the input and the
    // output channel count below.
    const int y_in  = extract_json(config, "Y_IN");
    const int x_in  = extract_json(config, "X_IN");
    const int c_out = extract_json(config, "C_OUT");
    const int y_out = extract_json(config, "Y_OUT");
    const int x_out = extract_json(config, "X_OUT");

    // Pooling window.
    const int kernel_y = extract_json(config, "KERNEL_Y");
    const int kernel_x = extract_json(config, "KERNEL_X");

    // Window step.
    const int stride_y = extract_json(config, "STRIDE_Y");
    const int stride_x = extract_json(config, "STRIDE_X");

    // Border padding.
    const int pad_y = extract_json(config, "PAD_Y");
    const int pad_x = extract_json(config, "PAD_X");

    // Remaining switches are passed straight through to the test runner.
    const int sign_flag     = extract_json(config, "SIGN");
    const int read_ifm_flag = extract_json(config, "READ_IFM");

    // Echo the configuration so it appears in the run log.
    printf("IFM dimension: YIN x XIN x CIN = %d x %d x %d \n", y_in, x_in, c_out);
    printf("OFM dimension: YOUT x XOUT x COUT = %d x %d x %d \n", y_out, x_out, c_out);
    printf("Kernel dimension: KY x KX = %d x %d \n", kernel_y, kernel_x);
    printf("Stride: SY x SX = %d x %d \n", stride_y, stride_x);
    printf("Padding: PY x PX = %d x %d \n", pad_y, pad_x);
    printf("Sign: %d \n", sign_flag);

    // The runner's return value is intentionally discarded, as before.
    (void)run_maxpool_noqdq_a8(
        y_out, x_out, c_out,
        y_in, x_in,
        kernel_y, kernel_x,
        stride_y, stride_x,
        pad_y, pad_x,
        sign_flag,
        read_ifm_flag);

    return 0;
}
