#include "conv/conv_host_runtime.hpp"
#include <cassert>
#include <cmath>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <string>
#ifndef __TXNRT__
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif // __TXNRT__
#if defined(__AIESIM__) || defined(__TXNRT__)
#include "dma.hpp"
#endif // __AIESIM__ || __TXNRT__

#include "conv_xint8/kernel_setup/lcp_compute.h"

#ifndef __TXNRT__
// Global compute graph instance. run_conv_testbench() calls init(), passes it to
// run_dma_layer_config() and calls end() on it in the __AIESIM__ build.
// It is not declared at all in the transaction-runtime (__TXNRT__) build.
ComputeGraph g_compute_graph;
#endif // __TXNRT__
// Writes `size` raw bytes starting at `data` to `filename` in binary mode,
// truncating any existing file.
// Failures to open or to write the file are reported on stderr instead of being
// silently ignored, so a missing or short .bin dump is visible in the test log.
void write_bin_file(std::string filename, char* data, size_t size)
{
    std::ofstream file(filename, std::ios::out | std::ios::binary);
    if (!file) {
        fprintf(stderr, "write_bin_file: failed to open %s\n", filename.c_str());
        return;
    }
    file.write(data, static_cast<std::streamsize>(size));
    if (!file) {
        fprintf(stderr, "write_bin_file: failed to write %zu bytes to %s\n",
                size, filename.c_str());
    }
}

// Allocates `size` BYTES (not elements) and returns them typed as T*.
// The ADF build obtains the memory from adf::GMIO::malloc; the transaction-runtime
// (__TXNRT__) build uses plain malloc. Release with deallocate_memory().
template<typename T>
T* allocate_memory(size_t size) {
#ifndef __TXNRT__
    void* raw = adf::GMIO::malloc(size);
#else
    void* raw = malloc(size);
#endif
    return static_cast<T*>(raw);
}

// Releases a buffer obtained from allocate_memory<T>().
// The release routine mirrors the allocator chosen for the build:
// adf::GMIO::free in the ADF build, plain free under __TXNRT__.
template<typename T>
void deallocate_memory(T* ptr) {
#ifndef __TXNRT__
    adf::GMIO::free(ptr);
#else
    free(ptr);
#endif
}

// Returns the accumulator right-shift used by the testbench.
// The result is floor(log2(acc_size / 255)) when `quantized` is true (XINT8 / A8W8
// builds) and acc_size / 255 > 1; otherwise it is 0.
// Integer arithmetic avoids std::log2(0) == -inf (acc_size < 255) and int8_t overflow
// of the ratio (> 127). Converting either of those values to int8_t is undefined behavior.
static int8_t compute_acc_shift_res(int acc_size, bool quantized)
{
    if (!quantized) {
        return 0;
    }
    int ratio = acc_size / 255;
    int8_t shift = 0;
    while (ratio > 1) {
        ratio >>= 1;
        ++shift;
    }
    return shift;
}

// Runs one convolution test case end to end.
//
// Shapes, subvolume sizes, strides, padding and the fused op come from compile-time
// macros (C_IN, Y_IN, X_IN, C_OUT, ..., FUSED_OP, ENABLE_ADD, IS_XINT8, IS_A8W8).
// The function:
//   1. fills random IFM, weight and skip-add tensors;
//   2. packs the weight buffer;
//   3. computes a CPU reference with cpu_conv_2d, plus cpu_add_2d when element-wise
//      add is enabled;
//   4. then, depending on the build:
//      * __TXNRT__ : saves the DMA bins and writes ifm.bin / wgt.bin / ofm.bin;
//      * __AIESIM__: runs g_compute_graph and compares its output with the
//                    reference, printing "DI CHECK(<shape>): PASS/FAIL" and the
//                    error count.
// With neither macro defined, only the CPU reference is computed.
//
// Template parameters:
//   Ta: activation element type (ActTensor<Ta>).
//   Tw: weight element type.
//   Tb: third type forwarded to ConvWgtTensor / cpu_conv_2d (presumably the bias type).
template<typename Ta, typename Tw, typename Tb>
void run_conv_testbench() {
    /* NOTE: without real QDQ applied, the uint8 maximum is 255, so the accumulator has to be shifted. */
    // 1. Calculate the accumulation size.
    // 2. IFM and WGT are both limited to the range [0, 1]; if the range is bigger,
    //    acc_size has to be adjusted accordingly.
    int acc_size = round_up_to_multiple(C_IN, C_IN_SUBV) * KERNEL_Y * KERNEL_X;
    int8_t shift_res = compute_acc_shift_res(acc_size, IS_XINT8 == 1 || IS_A8W8 == 1);
    int8_t shift_tdm = 0;
    TestConfig cfg = {
        .Ci = C_IN, .Yi = Y_IN, .Xi = X_IN,
        .Yis = YIS, .Xis = XIS, .Yos = YOS, .Xos = XOS,
        .Co = C_OUT, .Yo = Y_OUT, .Xo = X_OUT,
        .Ky = KERNEL_Y, .Kx = KERNEL_X,
        .Sy = STRIDE_Y, .Sx = STRIDE_X,
        .Py_b = PAD_Y_BEFORE, .Px_b = PAD_X_BEFORE,
        .Py_a = PAD_Y_AFTER, .Px_a = PAD_X_AFTER,
        .Cis = C_IN_SUBV, .Cos = C_OUT_SUBV,
        .Co_split = C_OUT_SPLIT,
        .shift_res = shift_res,
        .shift_tdm = shift_tdm,
#if IS_XINT8
        .lrelu_alpha = 205,
        .lrelu_shift = 11,
        .enable_matAdd = ENABLE_ADD,
        .elw_shift_ifm1 = 8, .elw_shift_ifm2 = 8, .elw_shift_ofm = 8,
        .epsilon = 1,
        .ParamSize = 256,
        .fused_op = QUOTE(FUSED_OP),
#else
        .lrelu_alpha = 1,
        .lrelu_shift = 0,
        .enable_matAdd = false,
        .elw_shift_ifm1 = 0, .elw_shift_ifm2 = 0, .elw_shift_ofm = 0,
        .epsilon = 0,
        .ParamSize = 0,
#endif
    };

    cfg.Cop = round_up_to_multiple(cfg.Co, cfg.Cos * cfg.Co_split);
    cfg.Cip = round_up_to_multiple(cfg.Ci, cfg.Cis);

#if IS_XINT8
    // Compare by value. If fused_op is stored as a C string, `==` against a literal
    // would only compare pointers; converting to std::string works for either case.
    const std::string fused_op(cfg.fused_op);
    if (fused_op == "Relu") {
        cfg.lrelu_alpha = 0;
        cfg.lrelu_shift = 0;
    } else if (fused_op == "None") {
        cfg.lrelu_alpha = 1;
        cfg.lrelu_shift = 0;
    }
#endif

    int ifm_size = ActTensor<Ta>::size(cfg.Ci, cfg.Yi, cfg.Xi);
    int ofm_size = ActTensor<Ta>::size(cfg.Co, cfg.Yo, cfg.Xo);
    int wgt_size = ConvWgtTensor<Tw, Tb, IS_XINT8, IS_A8W8>::size(
            cfg.Cop, cfg.Cip,
            cfg.Ky, cfg.Kx, cfg.Cos, cfg.Cis);
    const size_t aie_skipadd_size = ActTensor<Ta>::size(cfg.Co, cfg.Yo, cfg.Xo);
    ActTensor<Ta> aie_ifm(cfg.Ci, cfg.Yi, cfg.Xi, allocate_memory<Ta>(ifm_size));
    ActTensor<Ta> aie_ofm(cfg.Co, cfg.Yo, cfg.Xo, allocate_memory<Ta>(ofm_size));
    ConvWgtTensor<Tw, Tb, IS_XINT8, IS_A8W8> aie_wgt(
            cfg.Cop, cfg.Cip,
            cfg.Ky, cfg.Kx,
            cfg.Cos, cfg.Cis,
            allocate_memory<Tw>(wgt_size));
    ActTensor<Ta> aie_skipadd_ifm(cfg.Co, cfg.Yo, cfg.Xo, allocate_memory<Ta>(aie_skipadd_size));
    ActTensor<Ta> cpu_ofm(cfg.Co, cfg.Yo, cfg.Xo, malloc(ofm_size));
    // Input buffer handed to the graph: the IFM bytes, followed by room for the
    // optional skip-add tensor.
    int8_t* ifm_add_buffer = allocate_memory<int8_t>(ifm_size + aie_skipadd_size);
    // Weight buffer: cfg.ParamSize header bytes (kernel params in XINT8 builds)
    // followed by the packed weights.
    void* weights_bias_convP = allocate_memory<void>(cfg.ParamSize + wgt_size);
    aie_ifm.init_random();
    aie_wgt.init_random();
    aie_skipadd_ifm.init_random();
    memcpy(ifm_add_buffer, aie_ifm.data, ifm_size);

#if IS_XINT8
    arch_params_lcp_t conv1x1_kernel_params = compute_conv_kernel_params( cfg.Xis,  cfg.Cis,  cfg.Yis,  cfg.Ky,  cfg.Kx,
                    cfg.Xos,  cfg.Cos,  cfg.Yos,  cfg.Sx,  cfg.Sy, cfg.shift_res,  cfg.elw_shift_ifm1,
                    cfg.elw_shift_ifm2,  cfg.elw_shift_ofm, cfg.lrelu_alpha, cfg.lrelu_shift);
    // The params must fit in the ParamSize header. Otherwise the weight copy below
    // overwrites their tail, or this memcpy overruns the buffer.
    assert(sizeof(conv1x1_kernel_params) <= static_cast<size_t>(cfg.ParamSize));
    memcpy(weights_bias_convP, (void *)&conv1x1_kernel_params, sizeof(conv1x1_kernel_params)); //TODO This line need to be removed once the new kernel is available.
    memcpy(static_cast<int8_t*>(weights_bias_convP) + cfg.ParamSize, aie_wgt.data, wgt_size); //TODO +256 Need to be removed once the new kernel is available.
    wgt_size += cfg.ParamSize;
#else
    for (int co = 0; co < aie_wgt.Co; ++co) {
        aie_wgt.set_qdq_c0(co, 0);
    }
    aie_wgt.set_qdq_c1(0);
    aie_wgt.set_qdq_c2(1);
    aie_wgt.set_shift_tdm(cfg.shift_tdm);
    aie_wgt.set_shift_res(cfg.shift_res);
    aie_wgt.set_zp_wgt(0);
    memcpy(weights_bias_convP, aie_wgt.data, wgt_size);
#endif

    cpu_conv_2d<Ta, Tw, Tb, IS_XINT8, IS_A8W8>(aie_ifm, aie_wgt, cpu_ofm, cfg);

    if (cfg.enable_matAdd) {
        // Append the skip-add tensor right after the IFM. Copy exactly the byte size
        // reserved for it above, which is also the amount ifm_size grows by.
        memcpy(ifm_add_buffer + ifm_size, aie_skipadd_ifm.data, aie_skipadd_size);
        ifm_size += aie_skipadd_size;
        cpu_add_2d(aie_skipadd_ifm, cpu_ofm, cpu_ofm,
                8 - cfg.elw_shift_ifm1, 8 - cfg.elw_shift_ifm2, 8 - cfg.elw_shift_ofm);
    }
#if defined(__AIESIM__) || defined(__TXNRT__)
    #ifdef __TXNRT__
    DmaBins bins = run_dma_layer_config();
    bins.save();
    write_bin_file("ifm.bin", reinterpret_cast<char*>(ifm_add_buffer), (ifm_size));
    write_bin_file("wgt.bin", reinterpret_cast<char*>(weights_bias_convP), wgt_size );
    write_bin_file("ofm.bin", reinterpret_cast<char*>(cpu_ofm.data), ofm_size);
    #else
    aie_ifm.print("AIE IFM =\n");
    aie_wgt.print("AIE WGT =\n");
    cpu_ofm.print("CPU OFM =\n");
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_ofm.data, ifm_add_buffer, weights_bias_convP);
    g_compute_graph.end();
    aie_ofm.print("AIE OFM =\n");
    int err_count = check_result(cpu_ofm, aie_ofm,cfg.epsilon);
    std::string shape_name = "conv_"+std::to_string(cfg.Ci)+"x"+std::to_string(cfg.Yi)+"x"+std::to_string(cfg.Xi)+"_"+std::to_string(cfg.Co)+"x"+std::to_string(cfg.Yo)+"x"+std::to_string(cfg.Xo)+"_"+std::to_string(cfg.Kx)+"x"+std::to_string(cfg.Ky)+"_"+QUOTE(FUSED_OP)+"_"+std::to_string(ENABLE_ADD);
    if (err_count == 0) {
        printf("DI CHECK(%s): PASS\n", shape_name.c_str());
    } else {
        printf("DI CHECK(%s): FAIL\n", shape_name.c_str());
    }
    printf("Error Count = %d\n", err_count);
    #endif // __TXNRT__
#endif
    // Cleanup
    deallocate_memory(aie_ifm.data);
    deallocate_memory(aie_ofm.data);
    deallocate_memory(aie_wgt.data);
    deallocate_memory(aie_skipadd_ifm.data);
    deallocate_memory(weights_bias_convP);
    deallocate_memory(ifm_add_buffer);
    free(cpu_ofm.data);
}

// Testbench entry point: seeds the C PRNG with a fixed value and runs the
// conv testbench with the element types that match the configured precision.
int main() {
    // Fixed seed so the random tensors are reproducible across runs
    // (assumes init_random() draws from rand() — confirm).
    srand(0xABCD);

    // Activation / weight / third (presumably bias) element types per precision.
#if IS_XINT8
    run_conv_testbench<int8_t, int8_t, int16_t>();
#elif IS_A8W8
    run_conv_testbench<uint8_t, uint8_t, uint8_t>();
#else
    run_conv_testbench<uint16_t, uint8_t, uint8_t>();
#endif

#ifndef __TXNRT__
    // NOTE(review): outside the transaction-runtime build this always trips an
    // assertion (a no-op when NDEBUG is defined). It looks deliberate; confirm
    // why before removing it.
    assert(false);
#endif
    return 0;
}
