#include <assert.h>
#include <stdio.h>
#include <stdlib.h>

#include <cstdint>
#include <exception>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <map>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>

#if !ASM_MODE
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif // !ASM_MODE
#ifdef __AIESIM__
#if !ASM_MODE
#include "dma.hpp"
#endif // !ASM_MODE
#endif // __AIESIM__
#include "common.hpp"

#include "uniop.hpp"
#if !ASM_MODE
// Global AIE compute graph (ComputeGraph is presumably declared in graph.hpp,
// which is only included when !ASM_MODE). process_uniop drives it when the
// binary is built for __AIESIM__ without ASM_MODE.
ComputeGraph g_compute_graph;
#endif // !ASM_MODE

// NOTE(review): file-scope using-directive; code below relies on it for
// unqualified names such as string and unordered_map.
using namespace std;

// DEBUG_PRINT(...) forwards to printf when DEBUG_MODE is defined and expands
// to nothing otherwise, so its arguments are not evaluated in normal builds.
#ifdef DEBUG_MODE
#define DEBUG_PRINT(...) printf(__VA_ARGS__)
#else
#define DEBUG_PRINT(...)
#endif


/// Return the size in bytes of one element of the named data type.
///
/// @param s  Type name as written in uniop_cfg.json, e.g. "uint8", "int16",
///           "float32" or "bfloat16".
/// @return   The element size: 1, 2 or 4.
/// @throws std::invalid_argument if @p s is not a recognised type name. The
///         offending name is printed to stdout before throwing. Short or empty
///         names take this path too, instead of tripping an assert.
int get_bytes_per_elem(const std::string& s)
{
    // Built once on first use rather than on every call.
    static const std::unordered_map<std::string, int> bytes_per_elem = {
        {"uint8", 1},
        {"int8", 1},
        {"uint16", 2},
        {"int16", 2},
        {"uint32", 4},
        {"int32", 4},
        {"uint", 4},
        {"int", 4},
        {"float", 4},
        {"float32", 4},
        {"float16", 2},
        {"bfloat16", 2}
    };

    // Single lookup instead of count() followed by operator[].
    const auto it = bytes_per_elem.find(s);
    if (it == bytes_per_elem.end())
    {
        printf("s = %s\n", s.c_str());
        throw std::invalid_argument("non-supported input/output type!");
    }
    return it->second;
}

template <typename TG, typename TB, typename TGDQ, typename TBDQ, typename TLUT>
int process_uniop(std::string function,
                int GammaBeta_dim,
                uint32_t lut_len,
                int32_t lut_start_val,
                int32_t lut_end_val,
                int gen_io)
{
    auto cfg = load_json("uniop_cfg.json");

    int const _Y = extract_json(cfg, "_Y");
    int const _X = extract_json(cfg, "_X");
    int const _C = extract_json(cfg, "_C");
    int const rC = extract_json(cfg, "_trueC");
    std::cout << "Y = " << _Y << " X = " << _X << " C = " << _C << std::endl;

    std::string test_data_dir = cfg["TEST_DATA_DIR"];

    DEBUG_PRINT("test_data_dir : %s\n", test_data_dir.c_str());
    DEBUG_PRINT("function : %s\n", function.c_str());
    DEBUG_PRINT("(%d, %d, %d)\n", _Y, _X, _C);

    string dtype_act = cfg["DTYPE_ACT"], dtype_out = cfg["DTYPE_OUT"];
    int act_bytes_per_elem = get_bytes_per_elem(dtype_act);
    int out_bytes_per_elem = get_bytes_per_elem(dtype_out);

    if(act_bytes_per_elem == -1 or out_bytes_per_elem == -1)
    {
        DEBUG_PRINT("Error! Non-supported input/output type!\n");
        exit(1);
    }
    
    typedef uint16_t act_type;
	typedef uint16_t out_type;
    
    int const act_size = _Y * _X * _C * act_bytes_per_elem;
    int const out_size = _Y * _X * _C * out_bytes_per_elem;

    int const abo_size = act_size;
    int const bbo_size = UnaryOpWgtTensor<TGDQ, TBDQ, TLUT>::size(GammaBeta_dim, lut_len);
    int const obo_size = out_size;

    void* aie_a_bo = allocate(abo_size);
    UnaryOpWgtTensor<TGDQ, TBDQ, TLUT> aie_b_bo(GammaBeta_dim, lut_len, allocate(bbo_size));

    void* aie_o_bo = allocate(out_size);
    void* cpu_o_bo = allocate(out_size, false);
    
    int const read_model_data = extract_json(cfg, "READ_MD");
    bool dequant_enable = (extract_json(cfg, "DEQUANT_ENABLE") == 1);
    bool quant_enable = (extract_json(cfg, "QUANT_ENABLE") == 1);
    bool nlf_enable = (extract_json(cfg, "NLF_ENABLE") == 1);
    float pwla_alpha = extract_json_float(cfg, "PWLA_ALPHA");
    if (!read_model_data) 
    {
        if (gen_io) {
            read_bin_file(test_data_dir+"input_0.bin" , (char*)aie_a_bo, act_size);
            read_bin_file(test_data_dir+"output_0.bin", (char*)cpu_o_bo, out_size);
            if(GammaBeta_dim)
            {
                read_bin_file(test_data_dir+"input_1.bin" , (char*)(&(aie_b_bo.Beta_at(0))), 2 * GammaBeta_dim * sizeof(uint16_t));
            }
        }

        if(dequant_enable)
        {
            string out_dtype = cfg["DTYPE_OUT"];
            int dq_float_type = (out_dtype=="float32")? uniop_qdq_param::FloatType::FLOAT32 : uniop_qdq_param::FloatType::FLOAT16;
            float zp_i = extract_json_float(cfg, "DEQUANT_ZERO_POINT");
            float sc_i = extract_json_float(cfg, "DEQUANT_SCALE");
            aie_b_bo.set_dq(zp_i, sc_i, cfg["SIGN_ACT"], dq_float_type);
        }

        if(quant_enable)
        {
            string act_dtype = cfg["DTYPE_ACT"];
            int q_float_type = (act_dtype=="float32")? uniop_qdq_param::FloatType::FLOAT32 : uniop_qdq_param::FloatType::FLOAT16;
            float zp_o = extract_json_float(cfg, "QUANT_ZERO_POINT");
            float sc_o = 1.0f/extract_json_float(cfg, "QUANT_SCALE");
            aie_b_bo.set_q(zp_o, sc_o, cfg["SIGN_OUT"], q_float_type);
        }
    } 
    else 
    {
        std::string const node_name = extract_json_str(cfg, "NODE_NAME");
        std::string const md_path = extract_json_str(cfg, "MD_PATH");
        int sign_A = cfg["SIGN_ACT"];
        int sign_O = cfg["SIGN_OUT"];
        std::string const dq_float_str = extract_json_str(cfg, "DTYPE_OUT");
        std::string const q_float_str = extract_json_str(cfg, "DTYPE_ACT");
        uniop_init_model_data<TGDQ, TBDQ, TLUT>(md_path, node_name, &aie_b_bo, dequant_enable, quant_enable, sign_A, sign_O, dq_float_str, q_float_str);
        if (GammaBeta_dim)
        {
            uniop_init_gamma_beta_model_data<TG, TB, TGDQ, TBDQ>(function, md_path, node_name, &aie_b_bo, dequant_enable, rC, GammaBeta_dim);
        }
    }
    
    // Calculate LUT if lut len is greater than 0
    //
    if (lut_len)
    {
        uniop_calculate_lut<TGDQ, TBDQ, TLUT>(function, lut_start_val, lut_end_val, lut_len, &aie_b_bo);
    }

#if 1//ASM_MODE
    write_bin_file("ifm.bin", (char*)aie_a_bo, abo_size);
    write_bin_file("wgt.bin", (char*)aie_b_bo.data, bbo_size);
    write_bin_file("ofm.bin", (char*)cpu_o_bo, obo_size);
    write_external_buffer_json(obo_size, abo_size, bbo_size);
#endif // ASM_MODE

#ifdef __AIESIM__
#if !ASM_MODE
#if USE_CERT_LIBRARY
    run_cert_sim(g_compute_graph,
                 reinterpret_cast<void*>(aie_o_bo), obo_size,
                 reinterpret_cast<void*>(aie_a_bo), abo_size,
                 reinterpret_cast<void*>(aie_b_bo.data), bbo_size);
#else
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_o_bo, aie_a_bo, aie_b_bo.data);
    g_compute_graph.end();
#endif //USE_CERT_LIBRARY

    float const relative_error = 1.0;
    
    // in case to dump the input too : 
    typedef uint16_t act_type;
    ActTensor<act_type> aie_act(_C, 1, _Y*_X, aie_a_bo);
    float err_cnt_i = check_result_rmse<ActTensor<act_type>, ActTensor<act_type>>(aie_act, aie_act, relative_error, !dequant_enable);

    if(cfg["DTYPE_OUT"] != "float32")  // "float16" or "uint16"
    {
        typedef uint16_t out_type;
        ActTensor<out_type> aie_out(_C, 1, _Y*_X, aie_o_bo);
        ActTensor<out_type> cpu_out(_C, 1, _Y*_X, cpu_o_bo);
        float err_cnt_o = check_result_rmse<ActTensor<out_type>, ActTensor<out_type>>(cpu_out, aie_out, relative_error, !quant_enable);
    }
    else
    {
        ActTensor<float> aie_out(_C, 1, _Y*_X, aie_o_bo);
        ActTensor<float> cpu_out(_C, 1, _Y*_X, cpu_o_bo);
        float err_cnt_o = check_result_rmse<ActTensor<float>, ActTensor<float>>(cpu_out, aie_out, relative_error, true);
    }
#endif // !ASM_MODE
#endif // __AIESIM__

    deallocate(aie_a_bo);
    deallocate(aie_b_bo.data);
    deallocate(aie_o_bo);
    deallocate(cpu_o_bo, false);

    DEBUG_PRINT("Test Ended\n");
    return 0;
}

namespace
{

// Arguments forwarded unchanged to every process_uniop instantiation.
struct UniopArgs
{
    std::string function;
    int GammaBeta_dim;
    uint32_t lut_length;
    int32_t lut_start_value;
    int32_t lut_end_value;
    int gen_io;
};

// Instantiate process_uniop for one (gamma, beta) type pair. The dequant,
// quant and LUT element types are always uint16_t.
template <typename TG, typename TB>
int run_uniop(const UniopArgs& args)
{
    using TGDQ = uint16_t;
    using TBDQ = uint16_t;
    using TLUT = uint16_t;
    return process_uniop<TG, TB, TGDQ, TBDQ, TLUT>(args.function,
                                                   args.GammaBeta_dim,
                                                   args.lut_length,
                                                   args.lut_start_value,
                                                   args.lut_end_value,
                                                   args.gen_io);
}

// True for the operators that are evaluated through a lookup table.
bool function_uses_lut(const std::string& function)
{
    return function == "tanh" ||
           function == "swish" ||
           function == "sigmoid" ||
           function == "silu" ||
           function == "gelu" ||
           function == "elu";
}

} // namespace

/// Entry point: read FUNCTION / _C / gen_io (and, for layernorm/groupnorm,
/// DTYPE_GAMMA / DTYPE_BETA) from uniop_cfg.json, then run the matching
/// process_uniop instantiation.
///
/// @return The process_uniop status, or EXIT_FAILURE if the gamma/beta type
///         pair is unsupported or an exception escapes.
int main(void)
{
    // TODO: read the LUT length and range from the config file.
    constexpr uint32_t LUT_LEN = 512;
    constexpr int32_t LUT_START_VAL = -5;
    constexpr int32_t LUT_END_VAL = 5;

    // Supported (DTYPE_GAMMA, DTYPE_BETA) combinations.
    // "void"/"void" is used for operators without gamma/beta: GammaBeta_dim
    // stays 0, so the int32 template arguments are placeholders and no
    // gamma/beta work happens.
    using Runner = int (*)(const UniopArgs&);
    const std::map<std::pair<std::string, std::string>, Runner> runners = {
        {{"uint8",  "uint8"},  &run_uniop<uint8_t,  uint8_t>},
        {{"uint8",  "int8"},   &run_uniop<uint8_t,  int8_t>},
        {{"uint8",  "int32"},  &run_uniop<uint8_t,  int32_t>},
        {{"int8",   "uint8"},  &run_uniop<int8_t,   uint8_t>},
        {{"int8",   "int8"},   &run_uniop<int8_t,   int8_t>},
        {{"int8",   "int32"},  &run_uniop<int8_t,   int32_t>},
        {{"int32",  "uint8"},  &run_uniop<int32_t,  uint8_t>},
        {{"int32",  "int8"},   &run_uniop<int32_t,  int8_t>},
        {{"int32",  "int32"},  &run_uniop<int32_t,  int32_t>},
        {{"void",   "void"},   &run_uniop<int32_t,  int32_t>},
        {{"uint16", "uint16"}, &run_uniop<uint16_t, uint16_t>},
    };

    try
    {
        auto cfg = load_json("uniop_cfg.json");
        std::string function = cfg["FUNCTION"];
        int const _C = extract_json(cfg, "_C");
        int const gen_io = extract_json(cfg, "gen_io");

        std::string gamma_type = "void";
        std::string beta_type = "void";
        int GammaBeta_dim = 0;

        if (function == "layernorm" || function == "groupnorm")
        {
            // Normalisation ops take their gamma/beta types from the config
            // and use _C gamma/beta entries.
            gamma_type = extract_json_str(cfg, "DTYPE_GAMMA");
            beta_type = extract_json_str(cfg, "DTYPE_BETA");
            GammaBeta_dim = _C;
        }

        uint32_t lut_length = 0;
        int32_t lut_start_value = 0;
        int32_t lut_end_value = 0;
        if (function_uses_lut(function))
        {
            lut_length = LUT_LEN;
            lut_start_value = LUT_START_VAL;
            lut_end_value = LUT_END_VAL;
        }

        const auto it = runners.find({gamma_type, beta_type});
        if (it == runners.end())
        {
            std::cout << "Unsupported data type" << std::endl;
            // Report the failure through the exit status so scripts/CI notice.
            return EXIT_FAILURE;
        }

        const UniopArgs args{function,
                             GammaBeta_dim,
                             lut_length,
                             lut_start_value,
                             lut_end_value,
                             gen_io};
        return it->second(args);
    }
    catch (const std::exception& e)
    {
        // e.g. std::invalid_argument from get_bytes_per_elem: exit cleanly
        // with a non-zero status instead of std::terminate.
        std::cerr << "uniop test failed: " << e.what() << std::endl;
        return EXIT_FAILURE;
    }
}
