#include "conv/gemm_host_runtime.hpp"
#include "conv/conv_host_runtime.hpp"
#include "bilinear_pixel_resize_bf16/bilinear_pixel_resize_bf16_runtime.hpp"

#ifndef __TXNRT__
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif // __TXNRT__
#if defined(__AIESIM__) || defined(__TXNRT__)
#include "dma.hpp"
#endif // __AIESIM__ || __TXNRT__

using namespace std;

#ifndef __TXNRT__
// Global compute graph instance used by the ADF / AIE-simulation flow
// (presumably declared in graph.hpp, which is only included in this same
// non-__TXNRT__ configuration). The test driver calls init()/end() on it when
// running under __AIESIM__. The __TXNRT__ (transaction-binary) build has no
// graph object at all.
// NOTE(review): ADF tooling conventionally expects graph objects at global
// scope — confirm before moving this into a function.
ComputeGraph g_compute_graph;
#endif // __TXNRT__

// Write `size` raw bytes starting at `data` to `filename` in binary mode,
// truncating any existing file.
//
// Failures (cannot open the file, short/failed write, failed flush on close)
// are reported on stdout with an "Error:" message. The function does not throw
// or abort, so callers that ignore errors keep working as before. They just no
// longer fail silently.
void write_bin_file(std::string filename, char* data, size_t size)
{
    std::fstream file;
    file.open(filename, std::ios::out | std::ios::binary);
    if (!file.is_open()) {
        printf("Error: failed to open %s for writing.\n", filename.c_str());
        return;
    }
    file.write(data, static_cast<std::streamsize>(size));
    // Close explicitly (instead of relying on the destructor) so that errors
    // while flushing buffered data are observable through the stream state.
    file.close();
    if (file.fail()) {
        printf("Error: failed to write %zu bytes to %s.\n", size, filename.c_str());
    }
}

int run_bilinear_pixel_resize_bf16(
    int const half_pixel, int const align_corner, int const asymmetric,
    int const Co, int const Xo, int const Yo,
    int const Ci, int const Xi, int const Yi,
    int const Cos, int const Xos, int const Yos, 
    int const Cis, int const Xis, int const Yis,
    int const Xis_step, int const Yis_step,
    int const Xis_offset, int const Yis_offset,
    int const Y_split, int const X_split,
    int const dq_enable_tb, int const q_enable_tb
) {
    using Ta = uint16_t;
    using To = uint16_t;
    using Tw = uint16_t;
    bool read_model_data = false;

    CoordinateTransformationMode mode;
    if (half_pixel) {
        mode = CoordinateTransformationMode::HalfPixel;
    } else if (align_corner) {
        mode = CoordinateTransformationMode::AlignCorners;
    } else if (asymmetric) {
        mode = CoordinateTransformationMode::Asymmetric;
    } else {
        printf("Error: no coordinate transformation mode is set.\n");
    }

    int ifm_size = ActTensor<Ta>::size(Ci, Yi, Xi);
    int ofm_size = ActTensor<Ta>::size(Co, Yo, Xo);
    int wgt_size = bilinearWGT<Tw>::size(Yo, Xo, Yos, Xos);
    int qdq_params_size = sizeof(BilinearQDQParams);
    int const_bo_size = wgt_size + qdq_params_size;
    printf("ifm_size = %d\n", ifm_size);
    printf("ofm_size = %d\n", ofm_size);
    printf("wgt_size = %d\n", wgt_size);
    printf("qdq_params_size = %d\n", qdq_params_size);
    printf("const_bo_size = %d\n", const_bo_size);
    bilinearWGT<Tw> wgt(
        mode, Yo, Xo, Yos, Xos,
        Yi, Xi, Yis, Xis,
        Yis_step, Xis_step,
        Yis_offset, Xis_offset,
        static_cast<Tw*>(malloc(wgt_size))
    );
    wgt.set_wgt();
    wgt.print_raw_indices();
    auto const_bo = malloc(const_bo_size);
#ifdef __TXNRT__
    ActTensor<Ta> aie_ifm(Ci, Yi, Xi, malloc(ifm_size));
    ActTensor<Ta> aie_ofm(Co, Yo, Xo, malloc(ofm_size));
#else
    ActTensor<Ta> aie_ifm(Ci, Yi, Xi, adf::GMIO::malloc(ifm_size));
    ActTensor<Ta> aie_ofm(Co, Yo, Xo, adf::GMIO::malloc(ofm_size));
    /* 
    NOTE: In the dataflow, the ifm and wgt are fetched on the same shim channel
    Due to AIESIM limitation, both the ifm and wgt must be part of the same GMIO buffer
    However there is no such limitation in the hardware and the ifm and wgt can be fetched
    on different BOs on DDR.
    */
    auto aie_ifm_const = adf::GMIO::malloc(ifm_size+const_bo_size);
#endif // __TXNRT__
    // This tensor is to store the CPU model output
    ActTensor<Ta> cpu_ofm(Co, Yo, Xo, malloc(ofm_size));
    BilinearQDQParams qdq_params;
    // DQ params
    uint16_t dq_enable = dq_enable_tb;
    float dq_sc = 1.0f;
    uint16_t dq_zp = 0;
    // Q params
    uint16_t q_enable = q_enable_tb;
    float q_sc = 1.0f;
    uint16_t q_zp = 0;

    if (read_model_data) {
        auto raw_ifm = malloc(ifm_size*2);
        auto raw_ofm = malloc(ofm_size);
        // read model data from binary files
        std::string folder = "resize_4";
        printf("Reading model data from %s data folder\n", folder.c_str());
        read_bin_file(folder+"/"+"input_0.bin", (char*)raw_ifm);
        read_bin_file(folder+"/"+"output_0.bin", (char*)raw_ofm);
        dq_enable = 1;
        dq_sc = 0.000030517578125;
        dq_zp = 32768;
        q_enable = 1;
        q_sc = 1/0.000030517578125;
        q_zp = 32768;
        // Copy raw_ifm data in C Y X order to aie_ifm in Y X C order
        for (int c = 0; c < Ci; c++) {
            for (int y = 0; y < Yi; y++) {
                for (int x = 0; x < Xi; x++) {
                    aie_ifm.at(c, y, x) = ((Ta*)raw_ifm)[c * Yi * Xi + y * Xi + x]; 
                }
            }
        }
        // Copy raw_ofm data in C Y X order to aie_ofm in Y X C order
        for (int c = 0; c < Co; c++) {
            for (int y = 0; y < Yo; y++) {
                for (int x = 0; x < Xo; x++) {
                    cpu_ofm.at(c, y, x) = ((Ta*)raw_ofm)[c * Yo * Xo + y * Xo + x];
                }
            }
        }
        init_bilinear_qdq_params(
            qdq_params,
            dq_enable, dq_sc, dq_zp,
            q_enable, q_sc, q_zp
        );
        free(raw_ifm);
        free(raw_ofm);
    } else {
        init_bilinear_qdq_params(
            qdq_params,
            dq_enable, dq_sc, dq_zp,
            q_enable, q_sc, q_zp
        );

        init_bilinear_input_tensor(aie_ifm);

        cpu_model(
            mode,
            aie_ifm, cpu_ofm,
            qdq_params,
            wgt
        );
    }
    // Pack the weights and qdq params into const_bo
    memcpy(const_bo, (void*)wgt.data, wgt_size);
    memcpy((char*)const_bo + wgt_size, &qdq_params, qdq_params_size);
    #ifndef __TXNRT__
        memcpy(aie_ifm_const, (void*)aie_ifm.data, ifm_size);
        memcpy((char*)aie_ifm_const + ifm_size, const_bo, const_bo_size);
    #endif
    

#if defined(__AIESIM__) || defined(__TXNRT__)
    #ifdef __TXNRT__
            DmaBins bins = run_dma_layer_config();
            bins.save();
            write_bin_file("ifm.bin", reinterpret_cast<char*>(aie_ifm.data), ifm_size);
            write_bin_file("wgt.bin", reinterpret_cast<char*>(const_bo), const_bo_size);
            write_bin_file("ofm.bin", reinterpret_cast<char*>(cpu_ofm.data), ofm_size);
    #else
        aie_ifm.print("aie_ifm: ");
        cpu_ofm.print("cpu_ofm: ");
        g_compute_graph.init();
        run_dma_layer_config(g_compute_graph, aie_ofm.data, aie_ifm_const, nullptr);
        g_compute_graph.end();
        check_bilinear_output(cpu_ofm, aie_ofm, q_enable, 1.0f);
    #endif // __TXNRT__
#endif // __AIESIM__ || __TXNRT__

#ifdef __TXNRT__
    free(aie_ifm.data);
    free(aie_ofm.data);
#else
    adf::GMIO::free(aie_ifm.data);
    adf::GMIO::free(aie_ofm.data);
    adf::GMIO::free(aie_ifm_const);
#endif // __TXNRT__
    free(const_bo);
    free(wgt.data);
    free(wgt.raw_indices_X);
    free(wgt.raw_indices_Y);
    free(cpu_ofm.data);
    return 0;
}

// Test entry point.
//
// Reads the layer configuration from compile-time macros: HALF_PIXEL,
// ALIGN_CORNERS, ASYMMETRIC, COUT/XOUT/YOUT, CIN/XIN/YIN, COS/XOS/YOS,
// CIS/XIS/YIS, XIS_STEP/YIS_STEP, XIS_OFFSET/YIS_OFFSET, Y_SPLIT/X_SPLIT,
// DQ_ENABLE and Q_ENABLE. It then validates that exactly one coordinate
// transformation mode is selected, prints the configuration and runs the
// bilinear resize test.
//
// Returns -1 when the mode selection is invalid. Otherwise it returns the
// status of run_bilinear_pixel_resize_bf16 (0 for a completed run).
int main(void) {
    int const half_pixel = HALF_PIXEL;
    int const align_corner = ALIGN_CORNERS;
    int const asymmetric = ASYMMETRIC;
    int const Co = COUT;
    int const Xo = XOUT;
    int const Yo = YOUT;
    int const Ci = CIN;
    int const Xi = XIN;
    int const Yi = YIN;
    int const Cos = COS;
    int const Xos = XOS;
    int const Yos = YOS;
    int const Cis = CIS;
    int const Xis = XIS;
    int const Yis = YIS;
    int const Xis_step = XIS_STEP;
    int const Yis_step = YIS_STEP;
    int const Xis_offset = XIS_OFFSET;
    int const Yis_offset = YIS_OFFSET;
    int const Y_split = Y_SPLIT;
    int const X_split = X_SPLIT;
    int const dq_enable_tb = DQ_ENABLE;
    int const q_enable_tb = Q_ENABLE;

    // Exactly one mode flag must be set: both "none" and "several" are rejected,
    // so the message must not claim that no mode is set.
    if (half_pixel + align_corner + asymmetric != 1) {
        printf("Error: exactly one coordinate transformation mode must be set "
               "(half_pixel = %d, align_corner = %d, asymmetric = %d).\n",
               half_pixel, align_corner, asymmetric);
        return -1;
    }

    printf("half_pixel = %d\n", half_pixel);
    printf("align_corner = %d\n", align_corner);
    printf("asymmetric = %d\n", asymmetric);
    printf("IFM Dimensions: YIN x XIN x CIN = %d x %d x %d\n", Yi, Xi, Ci);
    printf("OFM Dimensions: YOUT x XOUT x COUT = %d x %d x %d\n", Yo, Xo, Co);
    printf("IFM subvolume Dimensions: YIS x XIS x CIS = %d x %d x %d\n", Yis, Xis, Cis);
    printf("OFM subvolume Dimensions: YOS x XOS x COS = %d x %d x %d\n", Yos, Xos, Cos);
    printf("IFM subvolume step: YIS_STEP x XIS_STEP = %d x %d\n", Yis_step, Xis_step);
    printf("IFM subvolume offset: YIS_OFFSET x XIS_OFFSET = %d x %d\n", Yis_offset, Xis_offset);
    printf("Y split = %d, X split = %d\n", Y_split, X_split);
    printf("DQ_ENABLE = %d, Q_ENABLE = %d\n", dq_enable_tb, q_enable_tb);

    // Propagate the driver's status instead of unconditionally reporting success.
    return run_bilinear_pixel_resize_bf16(
        half_pixel, align_corner, asymmetric,
        Co, Xo, Yo,
        Ci, Xi, Yi,
        Cos, Xos, Yos,
        Cis, Xis, Yis,
        Xis_step, Yis_step,
        Xis_offset, Yis_offset,
        Y_split, X_split,
        dq_enable_tb, q_enable_tb
    );
}