#include <string>
#include <sstream>
#include <fstream>
#include <iostream>
#include <assert.h>
#include <stdlib.h>
#include <stdio.h>
#if !ASM_MODE
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif // !ASM_MODE
#ifdef __AIESIM__
#if !ASM_MODE
#include "dma.hpp"
#endif // !ASM_MODE
#endif // __AIESIM__
#include "gemm.hpp"

#if !ASM_MODE
// Global graph instance driven by run_gemm under the AIE simulator: it is
// either handed to run_cert_sim (USE_CERT_LIBRARY) or init()/end()-ed around
// run_dma_layer_config. Not defined in ASM_MODE builds.
// NOTE(review): ComputeGraph is presumably declared in graph.hpp — confirm.
ComputeGraph g_compute_graph;
#endif // !ASM_MODE

/**
 * Build one GEMM test case, compute its CPU reference, and depending on the
 * build mode either dump binaries (ASM_MODE) or run it on the AIE simulator
 * (__AIESIM__) and compare the AIE output against the CPU reference.
 *
 * Template parameters:
 *   Ta - activation element type
 *   Tw - weight element type
 *   Tb - bias element type (only forwarded to init_model_data)
 *   To - output element type
 *
 * @param M, K, N              GEMM dimensions used to size the tensors.
 * @param Morig                Original M; currently unused in this function.
 * @param Korig, Norig         Original K/N, forwarded to init_model_data.
 * @param Msubv, Ksubv, Nsubv  Subvolume sizes. K is padded with iceil(K, Ksubv)
 *                             and N with iceil(N, Nsubv * n_split) for the
 *                             weight tensor (presumably round-up-to-multiple).
 * @param sign_act, sign_wgt, sign_out  Signedness flags stored in qdq_params.
 * @param shift_out            Output shift stored in qdq_params.
 * @param vector_coeff         Forwarded to GemmWgtTensor and init_model_data.
 * @param n_split              Split factor used when padding N.
 * @param is_int4              Non-zero selects the int4-weight init path.
 * @param read_model_data      Non-zero loads weights via init_model_data
 *                             instead of random init + CPU reference.
 * @param const_path, node_name  Forwarded to init_model_data.
 * @param debug_mode           Non-zero enables verbose printing.
 * @return Always 0; the comparison result is only printed (DI_PASS/DI_FAIL).
 */
template<typename Ta, typename Tw, typename Tb, typename To>
int run_gemm(
    int const M,
    int const K,
    int const N,
    int const Morig,
    int const Korig,
    int const Norig,
    int const Msubv,
    int const Ksubv,
    int const Nsubv,
    int const sign_act,
    int const sign_wgt,
    int const sign_out,
    int const shift_out,
    int const vector_coeff,
    int const n_split,
    int const is_int4,
    int const read_model_data,
    std::string const_path,
    std::string node_name,
    int const debug_mode
){
    (void)Morig; // kept for interface compatibility; not needed here

    if (debug_mode) {
        printf("M = %d\n", M);
        printf("K = %d\n", K);
        printf("N = %d\n", N);
        printf("Msubv = %d\n", Msubv);
        printf("Ksubv = %d\n", Ksubv);
        printf("Nsubv = %d\n", Nsubv);
    }

    int sign_A = sign_act;
    int sign_W = sign_wgt;
    int sign_O = sign_out;

    using Tacc = int64_t;
    using Tc0 = float;
    using Tc1 = float;
    using Tc2 = float;

    GemmQdqint16x8_RT_Params qdq_params;
    qdq_params.shift_out = shift_out;
    qdq_params.ifm_sign = sign_A;
    qdq_params.wgt_sign = sign_W;
    qdq_params.ofm_sign = sign_O;
    int N_padded = iceil(N, (Nsubv * n_split));
    int K_padded = iceil(K, Ksubv);

    if (debug_mode) {
        printf("N_padded = %d\n", N_padded);
        printf("K_padded = %d\n", K_padded);
    }

    int act_size = ActTensor<Ta>::size(K, 1, M);
    int wgt_size = GemmWgtTensor<Tw, Tc0, Tc1, Tc2>::size(K_padded, N_padded, Ksubv, Nsubv, is_int4);
    int out_size = ActTensor<To>::size(N, 1, M);
    int acc_size = ActTensor<Tacc>::size(N, 1, M);
    if (debug_mode) {
        printf("act_size = %d\n", act_size);
        printf("wgt_size = %d\n", wgt_size);
        printf("acc_size = %d\n", acc_size);
        printf("out_size = %d\n", out_size);
    }

#if !ASM_MODE
    // Device-visible buffers come from adf::GMIO::malloc and must be
    // released with adf::GMIO::free below.
    ActTensor<Ta> aie_act(
        K, 1, M,
        adf::GMIO::malloc(act_size)
    );

    GemmWgtTensor<Tw, Tc0, Tc1, Tc2> aie_wgt(
        K_padded, N_padded, Ksubv, Nsubv, vector_coeff, is_int4,
        adf::GMIO::malloc(wgt_size)
    );

    ActTensor<To> aie_out(
        N, 1, M,
        adf::GMIO::malloc(out_size)
    );
#else
    // ASM_MODE has no AIE runtime: plain host buffers, no aie_out tensor.
    ActTensor<Ta> aie_act(
        K, 1, M,
        malloc(act_size)
    );

    GemmWgtTensor<Tw, Tc0, Tc1, Tc2> aie_wgt(
        K_padded, N_padded, Ksubv, Nsubv, vector_coeff, is_int4,
        malloc(wgt_size)
    );
#endif // !ASM_MODE

    // Host-side reference buffers (plain malloc in every mode).
    ActTensor<Tacc> cpu_gemm_out(
        N, 1, M,
        malloc(acc_size)
    );

    ActTensor<To> cpu_out(
        N, 1, M,
        malloc(out_size)
    );

    // Print the subv sizes
    if (debug_mode) {
        printf("subv_wgt_size = %d \n", aie_wgt.subv_wgt_size);
        printf("subv_c0_size = %d \n", aie_wgt.subv_c0_size);
        printf("subv_c1_size = %d \n", aie_wgt.subv_c1_size);
        printf("subv_c2_size = %d \n", aie_wgt.subv_c2_size);
        printf("subv_qdq_size = %d \n", aie_wgt.subv_qdq_size);
        printf("total_subv_size = %d \n", aie_wgt.subv_size);
    }

    if(!read_model_data){
        if(is_int4){
            init_random_gemm_a16w4(aie_act, aie_wgt, qdq_params, is_int4, K, N);
        }
        else{
            init_random_gemm_a16w8(aie_act, aie_wgt, qdq_params, K, N);
        }
        cpu_matmul(aie_act, aie_wgt, cpu_gemm_out, is_int4, K, N);
        cpu_3term_qdq(aie_act, cpu_gemm_out, aie_wgt, cpu_out, debug_mode);
    }
    else {
        // TODO: add act formatting in debug mode
        // NOTE(review): on this path aie_act and cpu_out are never written,
        // so ifm.bin/ofm.bin and cmp_tensor below see uninitialised memory —
        // confirm this is intended.
        init_model_data<Ta, Tw, Tb, Tc0, Tc1, Tc2>(const_path, node_name, aie_wgt, qdq_params, is_int4, Korig, Norig, vector_coeff, debug_mode);
    }

#if ASM_MODE
    write_bin_file("ifm.bin", (char*)aie_act.data, act_size);
    write_bin_file("wgt.bin", (char*)aie_wgt.data, wgt_size);
    write_bin_file("ofm.bin", (char*)cpu_out.data, out_size);
    write_external_buffer_json(out_size, act_size, wgt_size);
#endif // ASM_MODE
#ifdef __AIESIM__
#if !ASM_MODE
    log_tensor(aie_act, "ACT TENSOR");
    log_tensor(aie_wgt, "WGT TENSOR", is_int4);
    log_tensor(cpu_gemm_out, "CPU GEMM OUT TENSOR");
    log_tensor(cpu_out, "CPU OUT TENSOR");
#if USE_CERT_LIBRARY
    // Output buffer first, then input activations, then weights.
    run_cert_sim(g_compute_graph,
                 reinterpret_cast<void*>(aie_out.data), out_size,
                 reinterpret_cast<void*>(aie_act.data), act_size,
                 reinterpret_cast<void*>(aie_wgt.data), wgt_size);
#else
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_out.data, aie_act.data, aie_wgt.data);
    g_compute_graph.end();
#endif //USE_CERT_LIBRARY

    int epsilon = 1;
    int err = cmp_tensor(cpu_out, aie_out, qdq_params.ofm_sign, epsilon);
    if (err == 0)
        printf("GEMM_A16W8 DI_PASS: M=%d, K=%d, N=%d, Msubv=%d, Ksubv=%d, Nsubv=%d SHIFT_OUT=%d, VECTOR_COEFF=%d is_wgt_int4=%d \n", M, K, N, Msubv, Ksubv, Nsubv, shift_out, vector_coeff, is_int4);
    else
        printf("GEMM_A16W8 DI_FAIL: M=%d, K=%d, N=%d, Msubv=%d, Ksubv=%d, Nsubv=%d SHIFT_OUT=%d, VECTOR_COEFF=%d is_wgt_int4=%d \n", M, K, N, Msubv, Ksubv, Nsubv, shift_out, vector_coeff, is_int4);
#endif // !ASM_MODE
#endif // __AIESIM__

#if !ASM_MODE
    adf::GMIO::free(aie_act.data);
    adf::GMIO::free(aie_wgt.data);
    adf::GMIO::free(aie_out.data);
#else
    free(aie_act.data);
    free(aie_wgt.data);
#endif // !ASM_MODE

    // Release both host reference buffers (cpu_gemm_out was previously leaked).
    free(cpu_gemm_out.data);
    free(cpu_out.data);
    return 0;
}


int main(void)
{
    auto cfg = load_json("gemm_cfg.json");
    int const Mgemm = extract_json(cfg, "M_GEMM_A16W8");
    int const Kgemm = extract_json(cfg, "K_GEMM_A16W8");
    int const Ngemm = extract_json(cfg, "N_GEMM_A16W8");
	int const Morig = extract_json(cfg, "M_GEMM_ORIG");
    int const Korig = extract_json(cfg, "K_GEMM_ORIG");
    int const Norig = extract_json(cfg, "N_GEMM_ORIG");
    int const is_int4 = extract_json(cfg, "IS_INT4_WGT");
    int const Msubv = extract_json(cfg, "M_SUBV_A16W8");
    int const Ksubv = extract_json(cfg, "K_SUBV_A16W8");
    int const Nsubv = extract_json(cfg, "N_SUBV_A16W8");

    int shift_out = 8;
    int const vector_coeff = extract_json(cfg, "COEFF_VECTOR"); 

    int const n_split = extract_json(cfg, "N_SPLIT"); 

    int sign_act = extract_json(cfg, "SIGN_ACT");
    int sign_wgt = extract_json(cfg, "SIGN_WGT");
    int sign_out = extract_json(cfg, "SIGN_OUT");
    int debug_mode = extract_json(cfg, "DEBUG");

    std::string const node_name = extract_json_str(cfg, "NODE_NAME");
    std::string const md_path = extract_json_str(cfg, "MD_PATH");
    std::string const dtype_act = extract_json_str(cfg, "DTYPE_ACT");
    std::string const dtype_wgt = extract_json_str(cfg, "DTYPE_WGT");
    std::string const dtype_bias = extract_json_str(cfg, "DTYPE_BIAS");
    std::string const dtype_ofm = extract_json_str(cfg, "DTYPE_OFM");
    int const read_md = extract_json(cfg, "READ_MD");
    if (read_md) {
        shift_out = 0;
    }

    if (debug_mode) {
        std::cout << "=== GEMM WGT Formatting ===\n"
            << "NODE_NAME   : " << node_name  << '\n'
            << "MD_PATH     : " << md_path    << '\n'
            << "READ_MD     : " << read_md    << '\n'
            << "DTYPE_ACT   : " << dtype_act  << '\n'
            << "DTYPE_WGT   : " << dtype_wgt  << '\n'
            << "DTYPE_BIAS  : " << dtype_bias << '\n'
            << "DTYPE_OFM   : " << dtype_ofm  << '\n'
            << "==========================\n";
    }

    // uint8 weights
    if (dtype_act == "uint16" &&
        dtype_wgt == "uint8" &&
	(dtype_bias == "int32" || dtype_bias == "0") &&
	dtype_ofm == "uint16") {
        run_gemm<uint16_t, uint8_t, int32_t, uint16_t>(
            Mgemm, Kgemm, Ngemm, Morig, Korig, Norig, Msubv, Ksubv, Nsubv, sign_act, sign_wgt,
            sign_out, shift_out, vector_coeff, n_split, is_int4, read_md,
            md_path, node_name, debug_mode);
    } else if (dtype_act == "uint16" &&
        dtype_wgt == "uint8" &&
	(dtype_bias == "uint16" || dtype_bias == "0") &&
	dtype_ofm == "uint16") {
        run_gemm<uint16_t, uint8_t, uint16_t, uint16_t>(
            Mgemm, Kgemm, Ngemm, Morig, Korig, Norig, Msubv, Ksubv, Nsubv, sign_act, sign_wgt,
            sign_out, shift_out, vector_coeff, n_split, is_int4, read_md,
            md_path, node_name, debug_mode);
    } else if (dtype_act == "uint16" &&
        dtype_wgt == "uint8" &&
	(dtype_bias == "uint8" || dtype_bias == "0") &&
	dtype_ofm == "uint16") {
        run_gemm<uint16_t, uint8_t, uint8_t, uint16_t>(
            Mgemm, Kgemm, Ngemm, Morig, Korig, Norig, Msubv, Ksubv, Nsubv, sign_act, sign_wgt,
            sign_out, shift_out, vector_coeff, n_split, is_int4, read_md,
            md_path, node_name, debug_mode);
    // int8 weights
    } else if (dtype_act == "uint16" &&
        (dtype_wgt == "int8" || dtype_wgt == "int4") &&
	(dtype_bias == "int32" || dtype_bias == "0") &&
	dtype_ofm == "uint16") {
        run_gemm<uint16_t, int8_t, int32_t, uint16_t>(
            Mgemm, Kgemm, Ngemm, Morig, Korig, Norig, Msubv, Ksubv, Nsubv, sign_act, sign_wgt,
            sign_out, shift_out, vector_coeff, n_split, is_int4, read_md,
            md_path, node_name, debug_mode);
    } else if (dtype_act == "uint16" &&
        (dtype_wgt == "int8" || dtype_wgt == "int4") &&
	(dtype_bias == "uint16" || dtype_bias == "0") &&
	dtype_ofm == "uint16") {
        run_gemm<uint16_t, int8_t, uint16_t, uint16_t>(
            Mgemm, Kgemm, Ngemm, Morig, Korig, Norig, Msubv, Ksubv, Nsubv, sign_act, sign_wgt,
            sign_out, shift_out, vector_coeff, n_split, is_int4, read_md,
            md_path, node_name, debug_mode);
    } else if (dtype_act == "uint16" &&
        (dtype_wgt == "int8" || dtype_wgt == "int4") &&
	(dtype_bias == "uint8" || dtype_bias == "0") &&
	dtype_ofm == "uint16") {
        run_gemm<uint16_t, int8_t, uint8_t, uint16_t>(
            Mgemm, Kgemm, Ngemm, Morig, Korig, Norig, Msubv, Ksubv, Nsubv, sign_act, sign_wgt,
            sign_out, shift_out, vector_coeff, n_split, is_int4, read_md,
            md_path, node_name, debug_mode);
    } else {
        std::cout << "ERROR: Unsupported MatMul type for <Ta, Tw, Tb, To> " <<
		dtype_act << ", " << dtype_wgt << ", " << dtype_bias << ", " <<
		dtype_ofm << std::endl;
    }

    return 0;
}
