#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <cstdint>
#include <cstring>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#if !ASM_MODE
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif // !ASM_MODE
#ifdef __AIESIM__
#if !ASM_MODE
#include "dma.hpp"
#endif // !ASM_MODE
#endif // __AIESIM__
#include "gemm_act.hpp"

#if !ASM_MODE
// Global AIE compute graph instance used by run_gemm_int16x16(): in __AIESIM__
// builds it is passed to run_cert_sim(), or init()/run_dma_layer_config()/end()
// are invoked on it. Not present in ASM_MODE builds.
// NOTE(review): ComputeGraph is presumably declared in graph.hpp — confirm.
ComputeGraph g_compute_graph;
#endif // !ASM_MODE

template<typename Ta, typename Tw, typename Tb, typename To>
int run_gemm_int16x16(
    int const Batch,
    int const M,
    int const K,
    int const N,
    int const Msubv,
    int const Ksubv,
    int const Nsubv,
    int const Mgemm_orig,
    int const Kgemm_orig,
    int const Ngemm_orig,
    int const transpose_wgts,
    int const shift_out,
    int const vector_coeff,
    int const sign_act,
    int const sign_wgt,
    int const sign_out,
    int const debug_mode,
    int const read_model_data,
    std::string const_path,
    std::string node_name
){
    if (debug_mode) {
        printf("BATCH_SIZE = %d\n", Batch);
        printf("M = %d\n", M);
        printf("K = %d\n", K);
        printf("N = %d\n", N);
        printf("Msubv = %d\n", Msubv);
        printf("Ksubv = %d\n", Ksubv);
        printf("Nsubv = %d\n", Nsubv);
        printf("shift_out = %d\n", shift_out);
        printf("vector_coeff = %d\n", vector_coeff);
        printf("sign_act = %d\n", sign_act);
        printf("sign_wgt = %d\n", sign_wgt);
        printf("sign_out = %d\n", sign_out);
    }

    using Tacc = int64_t;
    using Tc0 = float;
    using Tc1 = float;
    using Tc2 = float;

    GemmQdqint16x16_RT_Params qdq_params;
    qdq_params.shift_res = shift_out;
    qdq_params.sign_A = sign_act;
    qdq_params.sign_W = sign_wgt;
    qdq_params.sign_O = sign_out;
    qdq_params.c0 = 0;
    qdq_params.c1 = 0;
    qdq_params.c2 = 1;
    qdq_params.c3 = 0;
    qdq_params.vector_coeff = vector_coeff;

    int act1_size = ActTensor<Ta>::size(K, Batch, M);
    int act2_size = 0;
    if (transpose_wgts)
    {act2_size = ActTensor<Tw>::size(K, Batch, N);}
    else
    {act2_size = ActTensor<Tw>::size(N, Batch, Kgemm_orig);}
    int out_size = ActTensor<To>::size(N, Batch, M);
    int acc_size = ActTensor<Tacc>::size(N, Batch, M);
    int qdq_param_size = sizeof(qdq_params);
    
    if (debug_mode) {
        printf("act1_size = %d\n", act1_size);
        printf("act2_size = %d\n", act2_size);
        printf("transpose_wgts = %d\n", transpose_wgts);
        printf("acc_size = %d\n", acc_size);
        printf("out_size = %d\n", out_size);
    }
    int act_size = act1_size+act2_size;
    
#if !ASM_MODE
    void* aie_act = adf::GMIO::malloc(act_size);
    void* aie_qdqprm = adf::GMIO::malloc(qdq_param_size);
#else
    void* aie_act = malloc(act_size);
    void* aie_qdqprm = malloc(qdq_param_size);
#endif // !ASM_MODE
    void* aie_act1 = aie_act;
    void* aie_act2 = static_cast<int8_t*>(aie_act) + act1_size;

    ActTensor<Ta> act1(
        K, Batch, M,
        aie_act1
    );
    
    ActTensor<Tw> act2(
        transpose_wgts ? K : N,
        Batch,
        transpose_wgts ? N : Kgemm_orig,
        aie_act2
    );

    ActTensor<Tacc> cpu_gemm_out(
        N, Batch, M,
        malloc(acc_size)
    );

#if !ASM_MODE
    ActTensor<To> aie_out(
        N, Batch, M,
        adf::GMIO::malloc(out_size)
    );
#else
    ActTensor<To> aie_out(
        N, Batch, M,
        malloc(out_size)
    );
#endif // !ASM_MODE

    ActTensor<To> cpu_out(
        N, Batch, M,
        malloc(out_size)
    );
 
    if(!read_model_data){
        init_random_gemm_a16a16(act1, act2, qdq_params, transpose_wgts, Mgemm_orig, Kgemm_orig, Ngemm_orig);
        cpu_matmul(act1, act2, cpu_gemm_out, transpose_wgts);
        cpu_act_x_act_term_qdq(act1, act2, cpu_gemm_out, cpu_out, qdq_params, transpose_wgts);}
    else {
        // TODO: add act formatting in debug mode
        init_model_data<Ta, Tw, Tc0, Tc1, Tc2>(const_path, node_name, qdq_params, Kgemm_orig, vector_coeff, debug_mode);
    }

    memcpy(aie_qdqprm, &qdq_params, sizeof(qdq_params));

#if ASM_MODE
    write_bin_file("ifm.bin", (char*)aie_act, act_size);
    write_bin_file("wgt.bin", (char*)aie_qdqprm, qdq_param_size);
    write_bin_file("ofm.bin", (char*)cpu_out.data, out_size);
    write_external_buffer_json(act_size, qdq_param_size, out_size);
#endif // ASM_MODE
#ifdef __AIESIM__
#if !ASM_MODE
    log_tensor(act1, "ACT1 TENSOR");
    log_tensor(act2, "ACT2 TENSOR");
    log_tensor(cpu_gemm_out, "CPU GEMM OUT TENSOR");
    log_tensor(cpu_out, "CPU OUT TENSOR");
#if USE_CERT_LIBRARY
    run_cert_sim(g_compute_graph,
                 reinterpret_cast<void*>(aie_out.data), out_size,
                 reinterpret_cast<void*>(aie_act), act_size,
                 reinterpret_cast<void*>(aie_qdqprm), qdq_param_size);
#else
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_out.data, aie_act, aie_qdqprm);
    g_compute_graph.end();
#endif //USE_CERT_LIBRARY

    int epsilon = 1;
    int err = cmp_tensor(cpu_out, aie_out, qdq_params.sign_O, epsilon);
    if (err == 0)
        printf("GEMM_A16A16 DI_PASS: M=%d, K=%d, N=%d, Msubv=%d, Ksubv=%d, Nsubv=%d SHIFT_OUT=%d, VECTOR_COEFF=%d \n", M, K, N, Msubv, Ksubv, Nsubv, shift_out, vector_coeff);
    else
        printf("GEMM_A16A16 DI_FAIL: M=%d, K=%d, N=%d, Msubv=%d, Ksubv=%d, Nsubv=%d SHIFT_OUT=%d, VECTOR_COEFF=%d \n", M, K, N, Msubv, Ksubv, Nsubv, shift_out, vector_coeff);
#endif // ASM_MODE
#endif // __AIESIM__

#if !ASM_MODE
    adf::GMIO::free(aie_act);
    adf::GMIO::free(aie_qdqprm);
    adf::GMIO::free(aie_out.data);
#else
    free(aie_act);
    free(aie_qdqprm);
    free(aie_out.data);
#endif // !ASM_MODE
    free(cpu_out.data);
    free(cpu_gemm_out.data);
    return 0;
}

/**
 * Test entry point: reads the GEMM configuration from gemm_cfg.json and runs
 * run_gemm_int16x16() with unsigned 16-bit types (SIGN_ACT/WGT/OUT all 0)
 * or signed 16-bit types (all 1).
 *
 * @return the status returned by run_gemm_int16x16() (0 on success), or
 *         EXIT_FAILURE when the sign combination is not supported.
 */
int main(void)
{
    auto cfg = load_json("gemm_cfg.json");
    int const Batch = extract_json(cfg, "B_GEMM_A16W8");
    int const Mgemm = extract_json(cfg, "M_GEMM_A16W8");
    int const Kgemm = extract_json(cfg, "K_GEMM_A16W8");
    int const Ngemm = extract_json(cfg, "N_GEMM_A16W8");
    int const is_int4 = extract_json(cfg, "IS_INT4_WGT");
    int const Msubv = extract_json(cfg, "M_SUBV_A16W8");
    int const Ksubv = extract_json(cfg, "K_SUBV_A16W8");
    int const Nsubv = extract_json(cfg, "N_SUBV_A16W8");
    int const transpose_wgts = extract_json(cfg, "TRANSPOSE_WGTS");

    int const Mgemm_orig = extract_json(cfg, "M_GEMM_ORIG");
    int const Kgemm_orig = extract_json(cfg, "K_GEMM_ORIG");
    int const Ngemm_orig = extract_json(cfg, "N_GEMM_ORIG");

    int const vector_coeff = extract_json(cfg, "COEFF_VECTOR");

    int const read_ifm = extract_json(cfg, "READ_IFM");
    int const read_wgt = extract_json(cfg, "READ_WGT");

    int const n_split = extract_json(cfg, "N_SPLIT");

    // Read from the config (same order as before) but not used by this test.
    (void)is_int4;
    (void)read_ifm;
    (void)read_wgt;
    (void)n_split;

    int const sign_act = extract_json(cfg, "SIGN_ACT");
    int const sign_wgt = extract_json(cfg, "SIGN_WGT");
    int const sign_out = extract_json(cfg, "SIGN_OUT");
    int const debug_mode = extract_json(cfg, "DEBUG");
    int const read_md = extract_json(cfg, "READ_MD");

    std::string const node_name = extract_json_str(cfg, "NODE_NAME");
    std::string const md_path = extract_json_str(cfg, "MD_PATH");
    std::string const dtype_act = extract_json_str(cfg, "DTYPE_ACT");
    std::string const dtype_wgt = extract_json_str(cfg, "DTYPE_WGT");
    std::string const dtype_ofm = extract_json_str(cfg, "DTYPE_OFM");

    // Random-data runs use an output shift of 8; model-data runs use 0.
    int const shift_out = read_md ? 0 : 8;

    if (debug_mode) {
        std::cout << "=== GEMM WGT Formatting ===\n"
            << "NODE_NAME   : " << node_name  << '\n'
            << "MD_PATH     : " << md_path    << '\n'
            << "READ_MD     : " << read_md    << '\n'
            << "DTYPE_ACT   : " << dtype_act  << '\n'
            << "DTYPE_WGT   : " << dtype_wgt  << '\n'
            << "DTYPE_OFM   : " << dtype_ofm  << '\n'
            << "==========================\n";
    }

    int status = 0;
    if (sign_act == 0 && sign_wgt == 0 && sign_out == 0) {
        status = run_gemm_int16x16<uint16_t, uint16_t, uint16_t, uint16_t>(
            Batch,
            Mgemm, Kgemm, Ngemm,
            Msubv, Ksubv, Nsubv,
            Mgemm_orig, Kgemm_orig, Ngemm_orig,
            transpose_wgts,
            shift_out, vector_coeff,
            sign_act, sign_wgt, sign_out,
            debug_mode,
            read_md,
            md_path,
            node_name
        );
    } else if (sign_act == 1 && sign_wgt == 1 && sign_out == 1) {
        status = run_gemm_int16x16<int16_t, int16_t, int16_t, int16_t>(
            Batch,
            Mgemm, Kgemm, Ngemm,
            Msubv, Ksubv, Nsubv,
            Mgemm_orig, Kgemm_orig, Ngemm_orig,
            transpose_wgts,
            shift_out, vector_coeff,
            sign_act, sign_wgt, sign_out,
            debug_mode,
            read_md,
            md_path,
            node_name
        );
    } else {
        std::cout << "ERROR: Unsupported formatting for MatMul_actxact type for <Ta, Tw, To> " <<
            dtype_act << ", " << dtype_wgt << ", " << dtype_ofm << std::endl;
        // An unsupported configuration must not look like a successful run.
        return EXIT_FAILURE;
    }

    return status;
}