#if !ASM_MODE
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif // !ASM_MODE
#ifdef __AIESIM__
#if !ASM_MODE
#include "dma.hpp"
#endif // !ASM_MODE
#endif // __AIESIM__

using namespace std;

#include <stdexcept>
#include <iostream>
#include <vector>
#include <cstring>
#include <cassert>
#include <cstdlib> 
#include <stdint.h>
#include <assert.h>
#include <math.h>
#include <fenv.h>

#include "common.hpp"
#include "binary.hpp"

#if !ASM_MODE
// Global AIE compute graph instance (ComputeGraph is presumably declared in
// graph.hpp — confirm). The test drivers below hand it to run_cert_sim or
// init()/run_dma_layer_config()/end() in their __AIESIM__ paths. It is not
// built in ASM_MODE, where the drivers only dump .bin files.
ComputeGraph g_compute_graph;
#endif // !ASM_MODE

// Files holding operands A and B, read when READ_IFM is set in binary_cfg.json.
const std::string ifm_bin_path1{"../intermediate_bins/ifm1.bin"};
const std::string ifm_bin_path2{"../intermediate_bins/ifm2.bin"};

// OP_TYPE codes, matched against the value main() reads from binary_cfg.json.
constexpr int ADD = 0;
constexpr int SUB = 1;
constexpr int MUL = 2;
constexpr int DIV = 3;


/**
 * Host-side test driver for the int8 binary (element-wise) layer.
 *
 * Flow:
 *  - Builds the IFM / WGT / OFM buffers.
 *  - Stores shift_in, shift_in1 and shift_res as three ints at the start of WGT.
 *  - Fills operands A and B: from ifm_bin_path1 / ifm_bin_path2 when read_ifm
 *    is set, randomly otherwise.
 *  - Computes the CPU reference with add().
 *  - In ASM_MODE, dumps the buffers to ifm.bin / wgt.bin / ofm.bin. Under
 *    __AIESIM__, runs the AIE graph and prints DI_PASS / DI_FAIL.
 *
 * ifm_chs, ifm_bytes and ofm_bytes are accepted but unused.
 *
 * @param a_on_wgt  route operand A through the WGT buffer instead of IFM
 * @param b_on_wgt  route operand B through the WGT buffer instead of IFM
 * @param op_type   operation code; only ADD is implemented for int8
 * @return 0
 * @throws std::invalid_argument  both operands routed to WGT, unsupported
 *         op_type, or WGT too small for the three shift values
 * @throws std::runtime_error     an IFM bin file could not be read
 */
int run_binary_int8(
    int const ifm_chs, int const ifm_chs_pad, int const ifm_cols, int const ifm_rows,
    int const ofm_chs, int const ofm_chs_pad, int const ofm_cols, int const ofm_rows,
    int const ifm_bytes, int const ofm_bytes,
    int const wgt_size,
    int const a_on_wgt, int const b_on_wgt,
    int const shift_in, int const shift_in1, int const shift_res,
    int const read_ifm, int const op_type
){
    if (a_on_wgt && b_on_wgt) {
        throw std::invalid_argument("Both A_ON_WGT and B_ON_WGT cannot be enabled simultaneously");
    }
    // Only ADD has an int8 CPU reference. Reject other ops before allocating,
    // rather than comparing or dumping an uninitialized cpu_ofm.
    if (op_type != ADD) {
        throw std::invalid_argument("run_binary_int8: unsupported OP_TYPE (only ADD)");
    }

    int ifm_size = ActTensor<int8_t>::size(ifm_chs_pad, ifm_cols, ifm_rows);
    int total_ifm_size = 0;
    int total_wgt_size = wgt_size;
    if (a_on_wgt || b_on_wgt) {
        total_ifm_size += ifm_size;
        total_wgt_size += ifm_size;
    } else {
        total_ifm_size += ifm_size * 2;
    }
    // cpu_ofm / aie_ofm are constructed with the OFM shape, so size them from it.
    int ofm_size = ActTensor<int8_t>::size(ofm_chs_pad, ofm_cols, ofm_rows);

    // shift_in, shift_in1 and shift_res are written as ints at the start of WGT.
    int const shift_param_bytes = 3 * static_cast<int>(sizeof(int));
    if (total_wgt_size < shift_param_bytes) {
        throw std::invalid_argument("WGT buffer too small to hold the shift parameters");
    }

#if !ASM_MODE
    auto aie_ifm = static_cast<int8_t*>(adf::GMIO::malloc(total_ifm_size));
    auto aie_wgt = static_cast<int8_t*>(adf::GMIO::malloc(total_wgt_size));
#else
    auto aie_ifm = static_cast<int8_t*>(malloc(total_ifm_size));
    auto aie_wgt = static_cast<int8_t*>(malloc(total_wgt_size));
#endif // !ASM_MODE

    // Zero WGT so bytes not written below are deterministic in wgt.bin
    // (run_binary_16 does the same).
    memset(aie_wgt, 0, total_wgt_size);

    // IFM holds A then B unless one of them is routed through WGT.
    int8_t* aie_matA = a_on_wgt ? aie_wgt : aie_ifm;
    int8_t* aie_matB = b_on_wgt ? aie_wgt : aie_ifm + (ifm_size * !a_on_wgt);

    // NOTE(review): an operand routed to WGT also starts at offset 0, so filling
    // it below overwrites these shift values. Confirm the WGT layout the kernel
    // expects before relying on A_ON_WGT / B_ON_WGT here.
    int* shift_in_val = reinterpret_cast<int*>(aie_wgt);
    shift_in_val[0] = shift_in;
    shift_in_val[1] = shift_in1;
    shift_in_val[2] = shift_res;

    ActTensor<int8_t> aie_ifm_a(
        ifm_chs_pad, ifm_cols, ifm_rows,
        aie_matA
    );

    ActTensor<int8_t> aie_ifm_b(
        ifm_chs_pad, ifm_cols, ifm_rows,
        aie_matB
    );

    ActTensor<int8_t> cpu_ofm(
        ofm_chs_pad, ofm_cols, ofm_rows,
        malloc(ofm_size)
    );

#if !ASM_MODE
    ActTensor<int8_t> aie_ofm(
        ofm_chs_pad, ofm_cols, ofm_rows,
        adf::GMIO::malloc(ofm_size)
    );
#else
    ActTensor<int8_t> aie_ofm(
        ofm_chs_pad, ofm_cols, ofm_rows,
        malloc(ofm_size)
    );
#endif // !ASM_MODE

    if (read_ifm) {
        // Same failure check as run_binary_16: a zero-byte read means the file is unusable.
        if (read_bin_file(ifm_bin_path1, reinterpret_cast<char*>(aie_matA), ifm_size) == 0) {
            throw std::runtime_error("Error reading IFM bin file 1");
        }
        if (read_bin_file(ifm_bin_path2, reinterpret_cast<char*>(aie_matB), ifm_size) == 0) {
            throw std::runtime_error("Error reading IFM bin file 2");
        }
    }
    else {
        init_tensor_random(aie_ifm_a, ofm_chs, 0, 4);
        init_tensor_random(aie_ifm_b, ofm_chs, 0, 4);
    }

    // op_type was validated above, so the reference output is always written.
    switch (op_type) {
        case ADD:
            add(aie_ifm_a, aie_ifm_b, cpu_ofm, shift_in, shift_in1, shift_res);
            break;
    }

#if ASM_MODE
    write_bin_file("ifm.bin", reinterpret_cast<char*>(aie_ifm), total_ifm_size);
    write_bin_file("wgt.bin", reinterpret_cast<char*>(aie_wgt), total_wgt_size);
    write_bin_file("ofm.bin", reinterpret_cast<char*>(cpu_ofm.data), ofm_size);
#endif

#ifdef __AIESIM__
#if !ASM_MODE
    log_tensor(aie_ifm_a, "IFM_A =\n");
    log_tensor(aie_ifm_b, "IFM_B =\n");
    log_tensor(cpu_ofm, "CPU OFM =\n");
#if USE_CERT_LIBRARY
    // Pass the real IFM allocation size: it is only ifm_size bytes when one
    // operand lives in WGT, so ifm_size*2 would read past the buffer.
    run_cert_sim(g_compute_graph,
                 reinterpret_cast<void*>(aie_ofm.data), ofm_size,
                 reinterpret_cast<void*>(aie_ifm), total_ifm_size,
                 reinterpret_cast<void*>(aie_wgt), total_wgt_size);
#else
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_ofm.data, aie_ifm, aie_wgt);
    g_compute_graph.end();
#endif // USE_CERT_LIBRARY
    log_tensor(aie_ofm, "AIE OFM =\n");

    int epsilon = 1;
    int err = cmp_tensor(cpu_ofm, aie_ofm, 1, epsilon);
    if (err == 0)
        printf("DI_PASS\n");
    else
        printf("DI_FAIL\n");
#endif // !ASM_MODE
#endif // __AIESIM__

#if !ASM_MODE
    adf::GMIO::free(aie_ifm);
    adf::GMIO::free(aie_wgt);
    adf::GMIO::free(aie_ofm.data);
#else
    free(aie_ifm);
    free(aie_wgt);
    free(aie_ofm.data); // was leaked in ASM_MODE
#endif // !ASM_MODE
    free(cpu_ofm.data);
    return 0;
}




// quantized types
template<typename Ti, typename To, typename tzp>
int run_binary_16(
    int const ifm_chs, int const ifm_chs_pad, int const ifm_cols, int const ifm_rows,
    int const ofm_chs, int const ofm_chs_pad, int const ofm_cols, int const ofm_rows,
    int const ifm_bytes, int const ofm_bytes,
    int const wgt_size, int const a_on_wgt, int const b_on_wgt,
    float dq_a_zp, float dq_a_sc,
    float dq_b_zp, float dq_b_sc,
    float q_zp, float q_sc,
    bool dq_enable, bool q_enable,
    int const read_ifm, int const op_type,
    int const read_model_data,
    std::string const_path,
    std::string node_name
){
    if (a_on_wgt && b_on_wgt) {
        throw std::invalid_argument("Both A_ON_WGT and B_ON_WGT cannot be enabled simultaneously");
    }

    int ifm_size = ActTensor<Ti>::size(ifm_chs_pad, ifm_cols, ifm_rows);
    int total_ifm_size = 0;
    int total_wgt_size = wgt_size;
    if (a_on_wgt || b_on_wgt) {
        total_ifm_size += ifm_size;
        total_wgt_size += ifm_size;
    } else {
        total_ifm_size += ifm_size * 2;
    }
    int ofm_size = ifm_size; 

#if !ASM_MODE
    auto aie_ifm = static_cast<void*>(adf::GMIO::malloc(total_ifm_size));
    auto aie_wgt = static_cast<void*>(adf::GMIO::malloc(total_wgt_size));
#else
    auto aie_ifm = static_cast<void*>(malloc(total_ifm_size));
    auto aie_wgt = static_cast<void*>(malloc(total_wgt_size));
#endif // !ASM_MODE

    Ti* aie_matA = static_cast<Ti*>(a_on_wgt ? aie_wgt : aie_ifm);
    Ti* aie_matB = b_on_wgt ? (Ti*)(reinterpret_cast<int8_t*>(aie_wgt)) : (Ti*)(reinterpret_cast<int8_t*>(aie_ifm) + (ifm_size * !a_on_wgt));

    assert(total_wgt_size > 24); // need at least 24 bytes for qdq
    // initialize wgt buffer to zero
    for (int i = 0; i < total_wgt_size; i++) {
        reinterpret_cast<int8_t*>(aie_wgt)[i] = 0;
    }

    ActTensor<Ti> aie_ifm_a(
        ifm_chs_pad, ifm_cols, ifm_rows,
        aie_matA
    );

    ActTensor<Ti> aie_ifm_b(
        ifm_chs_pad, ifm_cols, ifm_rows,
        aie_matB
    );

    ActTensor<To> cpu_ofm(
        ofm_chs_pad, ofm_cols, ofm_rows,
        malloc(ofm_size)
    );

#if !ASM_MODE
    ActTensor<To> aie_ofm(
        ofm_chs_pad, ofm_cols, ofm_rows,
        adf::GMIO::malloc(ofm_size)
    );
#else
    ActTensor<To> aie_ofm(
        ofm_chs_pad, ofm_cols, ofm_rows,
        malloc(ofm_size)
    );
#endif // !ASM_MODE

    std::string const ifm_bin_path1 = "../intermediate_bins/ifm1.bin";
    std::string const ifm_bin_path2 = "../intermediate_bins/ifm2.bin";
    if (read_ifm) {
        int bytes;
        bytes = read_bin_file(ifm_bin_path1, reinterpret_cast<char*>(aie_matA), ifm_size);
        if (bytes == 0) {
            throw std::runtime_error("Error reading IFM bin file 1");
        }
        bytes = read_bin_file(ifm_bin_path2, reinterpret_cast<char*>(aie_matB), ifm_size);
        if (bytes == 0) {
            throw std::runtime_error("Error reading IFM bin file 2");
        }
    }
    else {
        init_tensor_random<Ti>(aie_ifm_a, ofm_chs, 1, 16);
        init_tensor_random<Ti>(aie_ifm_b, ofm_chs, 1, 16);
    }

    BinaryQDQParams* qdq_prm = reinterpret_cast<BinaryQDQParams*>(b_on_wgt ? ((Ti*) (reinterpret_cast<int8_t*>(aie_wgt) + ifm_size)) : aie_wgt);
    if (!read_model_data) {
        qdq_prm->dq_a_zp = dq_a_zp;
        qdq_prm->dq_a_sc = dq_a_sc;
        qdq_prm->dq_b_zp = dq_b_zp;
        qdq_prm->dq_b_sc = dq_b_sc;
        qdq_prm->q_zp = q_zp;
        qdq_prm->q_sc = 1.0f / q_sc;
        qdq_prm->q_enable = q_enable;
        qdq_prm->dq_enable = dq_enable;
    }
    else {
        binary_op_init_model_data<Ti>(
            const_path,
            node_name,
            qdq_prm,
            dq_enable,
            q_enable,
            aie_matA,
            aie_matB,
            ifm_chs_pad, ifm_chs, ifm_rows, ifm_cols, a_on_wgt, b_on_wgt
        );
    }

    if (!read_model_data) {
        ActTensor<QDQFloatType> *aie_ifm_t0, *aie_ifm_t1, *cpu_ofm_t2;

        // create temporary buffers if needed
        if (qdq_prm->dq_enable){
            QDQFloatType* T0 = static_cast<QDQFloatType*>(malloc(ifm_size));
            QDQFloatType* T1 = static_cast<QDQFloatType*>(malloc(ifm_size));
            dequant<uint16_t*, QDQFloatType*, float>((uint16_t*)aie_ifm_a.data, T0,  ifm_rows, ifm_cols, ifm_chs_pad, qdq_prm->dq_a_sc, float(qdq_prm->dq_a_zp));
            dequant<uint16_t*, QDQFloatType*, float>((uint16_t*)aie_ifm_b.data, T1,  ifm_rows, ifm_cols, ifm_chs_pad, qdq_prm->dq_b_sc, float(qdq_prm->dq_b_zp));
            aie_ifm_t0 = new ActTensor<QDQFloatType>( ifm_chs_pad, ifm_cols, ifm_rows, T0);
            aie_ifm_t1 = new ActTensor<QDQFloatType>( ifm_chs_pad, ifm_cols, ifm_rows, T1);
        } else{ // use existing buffers if not needed
            aie_ifm_t0 = reinterpret_cast<ActTensor<QDQFloatType>*>(&aie_ifm_a);
            aie_ifm_t1 = reinterpret_cast<ActTensor<QDQFloatType>*>(&aie_ifm_b);
        }

        if (qdq_prm->q_enable){ //create temp buf if needed
            QDQFloatType* T2 = static_cast<QDQFloatType*>(malloc(ofm_size));
            cpu_ofm_t2 = new ActTensor<QDQFloatType> (ofm_chs_pad, ofm_cols, ofm_rows, T2 );
        } else { //use existing buf if not needed
            cpu_ofm_t2 = new ActTensor<QDQFloatType> (ofm_chs_pad, ofm_cols, ofm_rows, cpu_ofm.data );
        }


        switch(op_type){
            case ADD:
                add(*aie_ifm_t0, *aie_ifm_t1, *cpu_ofm_t2); break;
            case MUL:
                mul(*aie_ifm_t0, *aie_ifm_t1, *cpu_ofm_t2); break;
        }

        // remove temporary buffers if they were made
        if(qdq_prm->dq_enable){
            free(aie_ifm_t0);
            free(aie_ifm_t1);
        }

        if (qdq_prm->q_enable){
            quant_bfloat16_to_uint16<QDQFloatType*, float>(cpu_ofm_t2->data, (uint16_t*)cpu_ofm.data, ofm_rows, ofm_cols, ofm_chs_pad, qdq_prm->q_sc, float(qdq_prm->q_zp));
            free(cpu_ofm_t2);
        }
    }
#if ASM_MODE
    write_bin_file("ifm.bin", reinterpret_cast<char*>(aie_ifm), total_ifm_size);
    write_bin_file("wgt.bin", reinterpret_cast<char*>(aie_wgt), total_wgt_size);
    write_bin_file("ofm.bin", reinterpret_cast<char*>(cpu_ofm.data), ofm_size);
#endif // ASM_MODE

#ifdef __AIESIM__
#if !ASM_MODE
    log_tensor(aie_ifm_a, "IFM A=\n");
    log_tensor(aie_ifm_b, "IFM B=\n");
    log_tensor(cpu_ofm, "CPU OFM =\n");
#if USE_CERT_LIBRARY
    run_cert_sim(g_compute_graph,
                 reinterpret_cast<void*>(aie_ofm.data), ofm_size,
                 reinterpret_cast<void*>(aie_ifm), ifm_size*2,
                 reinterpret_cast<void*>(aie_wgt), total_wgt_size);
#else
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_ofm.data, aie_ifm, aie_wgt);
    g_compute_graph.end();
#endif // USE_CERT_LIBRARY
    log_tensor(aie_ofm, "AIE OFM =\n");

    int epsilon = 1;
    int err = cmp_tensor(cpu_ofm, aie_ofm, 1, epsilon);
    if (err == 0)
        printf("DI_PASS\n");
    else
        printf("DI_FAIL\n");
#endif // !ASM_MODE
#endif // __AIESIM__

#if !ASM_MODE
    adf::GMIO::free(aie_ifm);
    adf::GMIO::free(aie_wgt);
    adf::GMIO::free(aie_ofm.data);
#else
    free(aie_ifm);
    free(aie_wgt);
#endif // !ASM_MODE
    free(cpu_ofm.data);
    return 0;
}


/**
 * Test entry point: reads binary_cfg.json and runs the matching driver.
 *
 * - IFM_BYTES == 1 runs run_binary_int8 with all shifts fixed at 0.
 * - IFM_BYTES == 2 runs run_binary_16. Its operand / result element types are
 *   chosen from DQ_ENABLE / Q_ENABLE.
 * - Any other IFM_BYTES is rejected.
 *
 * QDQ parameters are randomized:
 * - q_zp comes from generateRandomFloat(-10, 0).
 * - q_sc comes from generateRandomFloat(0.2, 0.5).
 * - Both dequant zero points equal q_zp, and both dequant scales are 1/q_sc.
 */
int main(void)
{
    auto cfg = load_json("binary_cfg.json");
    int ifm_rows  = extract_json(cfg, "IFM_ROWS");
    int ifm_cols  = extract_json(cfg, "IFM_COLS");
    int ifm_chs   = extract_json(cfg, "IFM_CHS");
    int ifm_chs_pad   = extract_json(cfg, "IFM_CHS_PAD");
    int ofm_chs_pad   = ifm_chs_pad;
    int ifm_bytes = extract_json(cfg, "IFM_BYTES");
    int ofm_rows  = extract_json(cfg, "OFM_ROWS");
    int ofm_cols  = extract_json(cfg, "OFM_COLS");
    int ofm_chs   = extract_json(cfg, "OFM_CHS");
    int ofm_bytes = extract_json(cfg, "OFM_BYTES");
    int wgt_size  = extract_json(cfg, "WGT_SIZE");
    int a_on_wgt  = extract_json(cfg, "A_ON_WGT");
    int b_on_wgt  = extract_json(cfg, "B_ON_WGT");
    bool q_enable  = bool(extract_json(cfg, "Q_ENABLE"));
    bool dq_enable  = bool(extract_json(cfg, "DQ_ENABLE"));
    int op_type  =  extract_json(cfg, "OP_TYPE");
    // Requantization shifts for the int8 path; this test always uses 0.
    int const shift_in  = 0;
    int const shift_in1 = 0;
    int const shift_res = 0;

    float q_zp = generateRandomFloat(-10, 0);
    float q_sc = generateRandomFloat(0.2, 0.5);
    float dq_a_zp = q_zp;
    float dq_b_zp = q_zp;
    float dq_a_sc = 1.0f / q_sc;
    float dq_b_sc = 1.0f / q_sc;

    int read_ifm = extract_json(cfg, "READ_IFM");
    int const read_md = extract_json(cfg, "READ_MD");
    std::string const node_name = extract_json_str(cfg, "NODE_NAME");
    std::string const md_path = extract_json_str(cfg, "MD_PATH");

    if (ifm_bytes == 1){
        run_binary_int8(
            ifm_chs, ifm_chs_pad, ifm_cols, ifm_rows,
            ofm_chs, ofm_chs_pad, ofm_cols, ofm_rows,
            ifm_bytes, ofm_bytes,
            wgt_size, a_on_wgt, b_on_wgt,
            shift_in, shift_in1, shift_res, read_ifm,
            op_type
        );
        return 0;
    }
    // Was an assert, which disappears under NDEBUG.
    if (ifm_bytes != 2) {
        throw std::invalid_argument("Unsupported IFM_BYTES (expected 1 or 2)");
    }

    // A dequantized input is stored as uint16_t, otherwise as QDQFloatType.
    // A quantized output is stored as uint16_t, otherwise as QDQFloatType.
    // a_on_wgt is passed explicitly; it was previously omitted, which shifted
    // every later argument onto the wrong parameter.
    if (dq_enable && q_enable) {
        run_binary_16<uint16_t, uint16_t, int16_t>( ifm_chs, ifm_chs_pad, ifm_cols, ifm_rows, ofm_chs, ofm_chs_pad, ofm_cols, ofm_rows, ifm_bytes, ofm_bytes, wgt_size, a_on_wgt, b_on_wgt, dq_a_zp, dq_a_sc, dq_b_zp, dq_b_sc, q_zp, q_sc, dq_enable, q_enable, read_ifm, op_type, read_md, md_path, node_name);
    }
    else if (dq_enable) {
        run_binary_16<uint16_t, QDQFloatType, int16_t>( ifm_chs, ifm_chs_pad, ifm_cols, ifm_rows, ofm_chs, ofm_chs_pad, ofm_cols, ofm_rows, ifm_bytes, ofm_bytes, wgt_size, a_on_wgt, b_on_wgt, dq_a_zp, dq_a_sc, dq_b_zp, dq_b_sc, q_zp, q_sc, dq_enable, q_enable, read_ifm, op_type, read_md, md_path, node_name);
    }
    else if (q_enable) {
        run_binary_16<QDQFloatType, uint16_t, int16_t>( ifm_chs, ifm_chs_pad, ifm_cols, ifm_rows, ofm_chs, ofm_chs_pad, ofm_cols, ofm_rows, ifm_bytes, ofm_bytes, wgt_size, a_on_wgt, b_on_wgt, dq_a_zp, dq_a_sc, dq_b_zp, dq_b_sc, q_zp, q_sc, dq_enable, q_enable, read_ifm, op_type, read_md, md_path, node_name);
    }
    else {
        run_binary_16<QDQFloatType, QDQFloatType, int16_t>( ifm_chs, ifm_chs_pad, ifm_cols, ifm_rows, ofm_chs, ofm_chs_pad, ofm_cols, ofm_rows, ifm_bytes, ofm_bytes, wgt_size, a_on_wgt, b_on_wgt, dq_a_zp, dq_a_sc, dq_b_zp, dq_b_sc, q_zp, q_sc, dq_enable, q_enable, read_ifm, op_type, read_md, md_path, node_name);
    }

    return 0;
}
