#include "conv/conv_host_runtime.hpp"
#include <cstdio>
#ifndef __TXNRT__
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif // __TXNRT__
#if defined(__AIESIM__) || defined(__TXNRT__)
#include "dma.hpp"
#endif // __AIESIM__ || __TXNRT__
#include <cstdint>
#include <cstdlib>
#include <fstream>
#include <string>
#include <type_traits>

/// Write `size` raw bytes starting at `data` to `filename` in binary mode,
/// replacing any existing file contents.
///
/// Failures (the file cannot be opened, or the write/close does not complete)
/// are reported on stderr instead of being silently ignored, so a missing or
/// truncated output file is visible in the run log. The function still returns
/// normally in that case, so callers' control flow is unchanged.
///
/// @param filename  Destination path.
/// @param data      Start of the bytes to write; must point to at least `size` bytes.
/// @param size      Number of bytes to write.
void write_bin_file(const std::string& filename, const char* data, size_t size)
{
    std::ofstream file(filename, std::ios::out | std::ios::binary | std::ios::trunc);
    if (!file) {
        std::fprintf(stderr, "write_bin_file: cannot open '%s' for writing\n",
                     filename.c_str());
        return;
    }
    // ostream::write takes a signed std::streamsize; make the conversion explicit.
    file.write(data, static_cast<std::streamsize>(size));
    // Close explicitly so that flush/close errors are also detected.
    file.close();
    if (!file) {
        std::fprintf(stderr, "write_bin_file: failed to write %zu bytes to '%s'\n",
                     size, filename.c_str());
    }
}

#ifndef __TXNRT__
// Namespace-scope ComputeGraph instance that main() initializes, passes to
// run_dma_layer_config() and ends in the __AIESIM__ flow. It is not defined in
// __TXNRT__ builds, which do not use the graph.
ComputeGraph g_compute_graph;
#endif // __TXNRT__

/// Test bench for the depthwise-convolution layer.
///
/// Fills a random IFM and random weights (after a fixed srand() seed), computes
/// a CPU reference with cpu_conv_dw(), and then, depending on the build:
///  - __TXNRT__              : saves the DmaBins returned by run_dma_layer_config()
///                             and dumps ifm.bin, wgt.bin and ofm.bin (the CPU
///                             reference output) to the working directory;
///  - __AIESIM__ (no TXNRT)  : prints the tensors, runs g_compute_graph, compares the
///                             AIE output with the CPU reference and prints
///                             "DI: PASS"/"DI: FAIL" and the error count;
///  - neither                : nothing beyond the CPU reference is run.
/// Returns 1 if any tensor buffer cannot be allocated.
int main(void)
{
    // Activations/outputs are 8-bit when IS_A8W8 == 1, otherwise 16-bit.
    using Ta = std::conditional_t<IS_A8W8==1, uint8_t, uint16_t>;
    using Tw = uint8_t;
    using Tb = uint16_t;
    using To = std::conditional_t<IS_A8W8==1, uint8_t, uint16_t>;
    // Fixed seed for reproducible test data
    // (assumes init_random() draws from rand() — TODO confirm).
    srand(0xABCD);

    // Layer geometry, all taken from build-time macros.
    int Ci = C_IN;
    int Yi = Y_IN;
    int Xi = X_IN;
    int Co = C_OUT;
    int Yo = Y_OUT;
    int Xo = X_OUT;
    int Ky = KERNEL_Y;
    int Kx = KERNEL_X;
    int Sy = STRIDE_Y;
    int Sx = STRIDE_X;
    int Py_b = PAD_Y_BEFORE;
    int Px_b = PAD_X_BEFORE;
    int Py_a = PAD_Y_AFTER;
    int Px_a = PAD_X_AFTER;
    int constexpr Cos = C_OUT_SUBV;
    int constexpr Co_split = C_OUT_SPLIT;

    // The weight tensor is sized for Co padded to a multiple of C_OUT_SUBV * C_OUT_SPLIT.
    // The IFM/OFM use the unpadded channel counts.
    int Cop = round_up_to_multiple(Co, Cos * Co_split);

    // NOTE(review): a dilation of (Ky - 1) / 2 (relative to a 3-tap base filter)
    // used to be computed here but was never used. cpu_conv_dw() below receives no
    // dilation argument. Confirm the reference model is correct for Ky/Kx > 3.

    int ifm_size = ActTensor<Ta>::size(Ci, Yi, Xi);
    int wgt_size = DwcWgtTensor<Tw, Cos>::size(Cop, Ky, Kx);
    int ofm_size = ActTensor<To>::size(Co, Yo, Xo);

#ifdef __TXNRT__
    ActTensor<Ta> aie_ifm(Ci, Yi, Xi, malloc(ifm_size));
    DwcWgtTensor<Tw, Cos> aie_wgt(Cop, Ky, Kx, malloc(wgt_size));
    ActTensor<To> aie_ofm(Co, Yo, Xo, malloc(ofm_size));
#else
    ActTensor<Ta> aie_ifm(Ci, Yi, Xi, adf::GMIO::malloc(ifm_size));
    DwcWgtTensor<Tw, Cos> aie_wgt(Cop, Ky, Kx, adf::GMIO::malloc(wgt_size));
    ActTensor<To> aie_ofm(Co, Yo, Xo, adf::GMIO::malloc(ofm_size));
#endif // __TXNRT__
    ActTensor<To> cpu_ofm(Co, Yo, Xo, malloc(ofm_size));

    // Fail fast instead of filling or reading through a null buffer below.
    // (The early return deliberately skips cleanup; the process is exiting.)
    if (aie_ifm.data == nullptr || aie_wgt.data == nullptr ||
        aie_ofm.data == nullptr || cpu_ofm.data == nullptr) {
        fprintf(stderr, "ERROR: failed to allocate tensor buffers\n");
        return 1;
    }

    aie_ifm.init_random();
    aie_wgt.init_random();
    // Override the quantization parameters with fixed values: c0 = 0 for every
    // output channel, c1 = 0, c2 = 1, all shifts and the weight zero point = 0
    // (presumably a simple, easily-checked requantization — confirm).
    for (int co = 0; co < aie_wgt.C; ++co) {
        aie_wgt.set_qdq_c0(co, 0);
    }
    aie_wgt.set_qdq_c1(0);
    aie_wgt.set_qdq_c2(1);
    aie_wgt.set_shift_tdm(0);
    aie_wgt.set_shift_res(0);
    aie_wgt.set_zp_wgt(0);
    int shift = 0;
    cpu_conv_dw(aie_ifm, aie_wgt, cpu_ofm, Sy, Sx, Py_b, Px_b, Py_a, Px_a, shift);

#if defined(__AIESIM__) || defined(__TXNRT__)
    #ifdef __TXNRT__
    // Transaction-runtime flow: emit the DMA configuration and the test vectors.
    DmaBins bins = run_dma_layer_config();
    bins.save();
    write_bin_file("ifm.bin", reinterpret_cast<char*>(aie_ifm.data), ifm_size);
    write_bin_file("wgt.bin", reinterpret_cast<char*>(aie_wgt.data), wgt_size);
    write_bin_file("ofm.bin", reinterpret_cast<char*>(cpu_ofm.data), ofm_size);
    #else
    // Simulation flow: run the graph and compare against the CPU reference.
    aie_ifm.print("AIE IFM =\n");
    aie_wgt.print("AIE WGT =\n");
    cpu_ofm.print("CPU OFM =\n");
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_ofm.data, aie_ifm.data, aie_wgt.data);
    g_compute_graph.end();
    aie_ofm.print("AIE OFM =\n");
    int err_count = check_result(cpu_ofm, aie_ofm);
    if (err_count == 0) {
        printf("DI: PASS\n");
    } else {
        printf("DI: FAIL\n");
    }
    printf("Error Count = %d\n", err_count);
    #endif // __TXNRT__
#endif // __AIESIM__ || __TXNRT__

    #ifdef __TXNRT__
    free(aie_ifm.data);
    free(aie_wgt.data);
    free(aie_ofm.data);
    #else
    adf::GMIO::free(aie_ifm.data);
    adf::GMIO::free(aie_wgt.data);
    adf::GMIO::free(aie_ofm.data);
    #endif // __TXNRT__
    free(cpu_ofm.data);

    // NOTE(review): in every non-__TXNRT__ build this aborts unconditionally
    // (unless NDEBUG is defined), even after "DI: PASS". The pass/fail result is
    // therefore reported only on stdout, not via the exit code. Presumably this is
    // intentional for the simulator flow — confirm.
    #ifndef __TXNRT__
    assert(false);
    #endif // __TXNRT__
    return 0;
}
