#include <assert.h>
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <fstream>
#include <iostream>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>
#include "elemwise_qdq.hpp"
#ifndef __TXNRT__
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "super.hh"
#include "graph.hpp"
#endif // __TXNRT__
#if defined(__AIESIM__) || defined(__TXNRT__)
#include "dma.hpp"
#endif // __AIESIM__ || __TXNRT__

/* NOTE: the input/output element types are selected at compile time.
1. INT_16 == 1: Tin = Tout = uint16_t (QDQ_MODE is not inspected in this case)
2. otherwise (int8), depending on QDQ_MODE:
    1) QDQ_MODE = 3 (dq and q both disabled)  Tin = Tout = uint8_t
    2) QDQ_MODE = 0 (dq only)  Tin = uint8_t; Tout = uint16_t
    3) QDQ_MODE = 1 ( q only)  Tin = uint16_t; Tout = uint8_t
    4) QDQ_MODE = 2 (dq + q)   Tin = Tout = uint8_t
   Any other QDQ_MODE value is rejected with a compile error.
*/

#if INT_16 == 1
    using Tin = uint16_t;
    using Tout = uint16_t;
#else
  #if QDQ_MODE == 3
    using Tin = uint8_t;
    using Tout = uint8_t;
  #elif QDQ_MODE == 0
    using Tin = uint8_t;
    using Tout = uint16_t;
  #elif QDQ_MODE == 1
    using Tin = uint16_t;
    using Tout = uint8_t;
  #elif QDQ_MODE == 2
    using Tin = uint8_t;
    using Tout = uint8_t;
  #else
    #error "INVALID QDQ_MODE"
  #endif

#endif

// Element type of the QDQ parameter block (see qdq_param / aie_wgt in main()).
using Welem = int32_t;

// Rounds x up to the next multiple of 8; values already on a multiple of 8 are
// returned unchanged. Despite the name, no lower bound of 64 is applied
// (that variant, std::max(64, ...), is intentionally not used).
int max_64_W8(int x){
    const int granule = 8;
    const int granule_count = (x + granule - 1) / granule;
    return granule_count * granule;
}



/* QDQ packing (one entry per element of T; 4 bytes each when T is Welem):
    1. Entries [2*i] and [2*i + 1] hold the dequant zero-point and scale of input i
    2. Entries [2*num_inputs] and [2*num_inputs + 1] hold the quant zero-point and
       the reciprocal of the quant scale
    3. Entries [2*num_inputs + 2] and [2*num_inputs + 3] hold the dq ENABLE/DISABLE
       and q ENABLE/DISABLE flags
   Scales are stored as the value returned by float_to_bfloat16().
*/

// Fills `data` with the QDQ parameter block described above.
//   data       : destination, must have room for 2*num_inputs + 4 entries
//   qdq_mode   : 0 = dq only, 1 = q only, 2 = dq + q, 3 = both disabled;
//                any other value prints "ERROR MODE!" and leaves the flags untouched
//   num_inputs : number of concat inputs
template<typename T>
void init_wgt_mat(T* data, int qdq_mode, int num_inputs) {
    int32_t zp;
    float scale;
    // The calibration range can be arbitrary test data.
    // NOTE: using fmin = 0.0f and fmax = 65535.0f instead is meant to force
    // zp = 0 and scale = 1 (a kind of bypass).
    float fmin = 0.0f;
    float fmax = 128.0f;

    // Dequant section.
    // NOTE: zp and scale should ideally differ per input; the same range is
    // passed for every input here.
    for (int input = 0; input < num_inputs; ++input) {
        compute_scale_and_zp<Tin>(fmin, fmax, scale, zp);
        T* entry = data + 2 * input;
        entry[0] = zp;
        entry[1] = float_to_bfloat16(scale);
    }

    // Quant section: q-only mode derives the parameters for Tout, every other
    // mode derives them for Tin.
    const bool quant_only = (qdq_mode == 1);
    if (quant_only) {
        compute_scale_and_zp<Tout>(fmin, fmax, scale, zp);
    } else {
        compute_scale_and_zp<Tin>(fmin, fmax, scale, zp);
    }
    T* tail = data + 2 * num_inputs;
    tail[0] = zp;
    tail[1] = float_to_bfloat16(1/scale);  // the reciprocal is what gets packed

    // Enable flags: tail[2] = dq, tail[3] = q.
    switch (qdq_mode) {
        case 0:
            tail[2] = 1; //dq ENABLE
            tail[3] = 0; //q DISABLE
            break;
        case 1:
            tail[2] = 0; //dq DISABLE
            tail[3] = 1; //q ENABLE
            break;
        case 2:
            tail[2] = 1; //dq ENABLE
            tail[3] = 1; //q ENABLE
            break;
        case 3:
            tail[2] = 0; //dq DISABLE
            tail[3] = 0; //q DISABLE
            break;
        default:
            std::cout << "ERROR MODE!" << std::endl;
            break;
    }
}

// Writes `size` bytes starting at `data` to `filename` in binary mode,
// replacing any existing file content. Failures to open or write the file are
// reported on std::cerr instead of being silently ignored.
void write_bin_file(std::string filename, char* data, size_t size) {
    std::ofstream file(filename, std::ios::out | std::ios::binary);
    if (!file.is_open()) {
        std::cerr << "Unable to open file " << filename << " for writing" << std::endl;
        return;
    }
    file.write(data, static_cast<std::streamsize>(size));
    // Close explicitly so that flush errors are visible through the stream state.
    file.close();
    if (file.fail()) {
        std::cerr << "Failed to write " << size << " bytes to " << filename << std::endl;
    }
}


// Fills a num_rows x num_cols x num_chs_p matrix (channel index fastest) with
// pseudo-random test data.
//   - Channels below num_chs get a rand()-based sample in [min_f, max_f];
//     the remaining padding channels use the value 0 (no rand() call is made for them).
//   - qdq_mode == 1 stores float_to_bfloat16(sample); every other mode stores
//     quantize<T>(sample, scale, zp).
//   - input_idx and safe_quant_scale are accepted but not used.
template<typename T>
void init_random_mat(int input_idx, T* data, int num_rows, int num_cols, int num_chs, int num_chs_p,
                     int qdq_mode, uint16_t zp, float scale,
                     float min_f = 0.0f, float max_f = 64.0f, // this range will cover both int8 and int16 with relatively good dynamic
                     float safe_quant_scale = 0.1f) // assumes quant scale ~0.1
{
    const bool store_bf16 = (qdq_mode == 1);
    const float span = max_f - min_f;
    T* cursor = data;  // elements are laid out in exactly the order the loops visit them
    for (int row = 0; row < num_rows; ++row) {
        for (int col = 0; col < num_cols; ++col) {
            for (int ch = 0; ch < num_chs_p; ++ch) {
                float sample = 0;
                if (ch < num_chs) {
                    sample = (span * (rand() / float(RAND_MAX))) + min_f;
                }
                *cursor++ = store_bf16 ? float_to_bfloat16(sample) : quantize<T>(sample, scale, zp);
            }
        }
    }
}


// Prints `msg` followed by the matrix contents: one text line per
// (row, column) pixel listing its num_chs channel values (width 6 each), a
// blank line after every row and one more blank line at the end.
template<typename T>
void print_mat(T* data, int num_rows, int num_cols, int num_chs, std::string msg = "") {
    std::cout << msg;
    for (int row = 0; row < num_rows; ++row) {
        for (int col = 0; col < num_cols; ++col) {
            // First channel of the current pixel.
            const T* pixel = data + (row * num_chs * num_cols) + (col * num_chs);
            for (int ch = 0; ch < num_chs; ++ch) {
                printf("%6d", pixel[ch]);
            }
            printf("\n");
        }
        printf("\n");
    }
    printf("\n");
}


// CPU reference for the concat layer with optional dequant/quant.
//
//   in_mat              : N input matrices, each laid out row-major with the
//                         channel index fastest and a channel stride of
//                         in_channels_p[n]
//   in_rows             : number of rows shared by all inputs
//   in_cols             : per-input width (channel concat uses in_cols[0] for all)
//   in_channels         : per-input unpadded channel count
//   in_channels_p       : per-input padded channel count (storage stride)
//   N                   : number of inputs
//   concat_mode         : 0 = concat on channels, 1 = concat on width;
//                         any other value leaves out_mat untouched
//   kernel_but_no_depad : channel concat copies the padded channel count of
//                         every input when true, the unpadded count otherwise
//   out_mat             : destination, out_rows x out_cols x out_channels
//   out_rows            : accepted but not used (in_rows drives the loops)
//   qdq_mode            : 0 = dq only, 1 = q only, 2 = dq + q, 3 = plain copy
//   qdq_param           : parameter block; entries [2n], [2n+1] are the dequant
//                         zero-point / scale of input n and entries [2N], [2N+1]
//                         the quant zero-point / reciprocal scale
//
// Width concat never copies more than out_channels channels per pixel, so a
// padded input cannot spill into the neighbouring pixel or past the end of
// out_mat when the output is unpadded.
template<typename Tin, typename Tout>
void concat_mats(Tin** in_mat, int in_rows, std::vector<int>& in_cols, std::vector<int>& in_channels,
                        std::vector<int>& in_channels_p,
                        int N, int concat_mode, bool kernel_but_no_depad,
                        Tout* out_mat, int out_rows, int out_cols, int out_channels,
                        int qdq_mode, Welem* qdq_param)
{
    // quant coeff
    const uint16_t q_z = qdq_param[N*2];
    // NOTE: the packed quant scale is 1/scale, so it is inverted again here
    // before being handed to quantize().
    const float q_inv_or_s = (1/bfloat16_to_float(qdq_param[N*2 + 1]));

    // Converts one input element according to qdq_mode, using the per-input
    // dequant coefficients. Unsupported modes yield 0 instead of an
    // indeterminate value.
    auto convert = [&](Tin raw, float dq_inv_or_s, uint16_t dq_z) -> Tout {
        switch (qdq_mode) {
            case 0: {
                const float real = dequantize<Tout>(raw, dq_inv_or_s, dq_z);
                return float_to_bfloat16(real);
            }
            case 1:
                // NOTE(review): quantizes with Tin as the target type, as
                // before -- confirm whether Tout was intended.
                return quantize<Tin>(bfloat16_to_float(raw), q_inv_or_s, q_z);
            case 2: {
                const float real = dequantize<Tout>(raw, dq_inv_or_s, dq_z);
                return quantize<Tout>(real, q_inv_or_s, q_z);
            }
            case 3:
                return raw;
            default:
                return Tout{};
        }
    };

    if (concat_mode == 0) {  // Concat on channels
        int ch_offset = 0;
        const int IN_W = in_cols[0];  // Assuming same width for all inputs
        for (int n = 0; n < N; ++n) {
            // per input dequant coeff
            const uint16_t dq_z = qdq_param[n*2];
            const float dq_inv_or_s = bfloat16_to_float(qdq_param[n*2 + 1]);

            const int in_stride = in_channels_p[n];
            const int valid_ch = kernel_but_no_depad ? in_stride : in_channels[n];
            for (int h = 0; h < in_rows; ++h) {
                for (int w = 0; w < IN_W; ++w) {
                    const Tin* src = in_mat[n] + (h * IN_W * in_stride) + (w * in_stride);
                    Tout* dst = out_mat + (h * out_cols * out_channels) + (w * out_channels) + ch_offset;
                    for (int c = 0; c < valid_ch; ++c) {
                        dst[c] = convert(src[c], dq_inv_or_s, dq_z);
                    }
                }
            }
            ch_offset += valid_ch;
        }
    } else if (concat_mode == 1) {  // Concat on width
        int col_offset = 0;
        const int IN_C = in_channels_p[0];  // Assuming same channels for all inputs
        // Only the channels that exist in the output are copied.
        const int copy_ch = std::min(IN_C, out_channels);
        for (int n = 0; n < N; ++n) {
            // per input dequant coeff
            const uint16_t dq_z = qdq_param[n*2];
            const float dq_inv_or_s = bfloat16_to_float(qdq_param[n*2 + 1]);

            const int in_w = in_cols[n];
            for (int h = 0; h < in_rows; ++h) {
                for (int w = 0; w < in_w; ++w) {
                    const Tin* src = in_mat[n] + (h * in_w * IN_C) + (w * IN_C);
                    Tout* dst = out_mat + (h * out_cols * out_channels) + ((col_offset + w) * out_channels);
                    for (int c = 0; c < copy_ch; ++c) {
                        dst[c] = convert(src[c], dq_inv_or_s, dq_z);
                    }
                }
            }
            col_offset += in_w;
        }
    }
}


// Copies an h_in x w_in x c_in matrix (channel index fastest) into an
// h_in x w_in x c_in_p matrix, filling channels c_in .. c_in_p-1 of every
// pixel with 0.
//
// The source is only read for channels that exist (c < c_in), so the last
// pixel never reads past the end of in_data.
template<typename T>
void concat_mats_padding(T* in_data, T* out_data, int h_in, int w_in, int c_in, int c_in_p)
{
    for(int i = 0; i < h_in; ++i) {
        for (int j = 0; j < w_in; ++j) {
            const T* src = in_data + (i * w_in * c_in) + (j * c_in);
            T* dst = out_data + (i * w_in * c_in_p) + (j * c_in_p);
            for (int c = 0; c < c_in_p; ++c) {
                dst[c] = (c < c_in) ? src[c] : T(0);
            }
        }
    }
}


// Compares two num_rows x num_cols x num_chs matrices element by element.
// An element whose absolute difference exceeds `threshold` is printed as an
// ERROR and counted; a non-zero difference within the threshold is printed as
// a WARNING only. Returns the number of ERROR elements.
// qdq_mode is accepted but not used.
template<typename T>
int check_result(T* expected, T* received,
                 int num_rows,
                 int num_cols,
                 int num_chs,
                 int qdq_mode,
                 int threshold = 8) {
    // Shared printer for both severities; `prefix` carries the severity tag.
    auto report = [](const char* prefix, int h, int w, int c, T want, T got, float delta) {
        std::cout << prefix << h << ", w=" << w << ", c=" << c << "]: ";
        std::cout << "Expected: " << +want
                  << ", Received: " << +got
                  << ", Diff: " << delta << "\n";
    };

    int mismatches = 0;
    int pos = 0;  // flat index; advances in the same order as the loops
    for (int h = 0; h < num_rows; ++h) {
        for (int w = 0; w < num_cols; ++w) {
            for (int c = 0; c < num_chs; ++c, ++pos) {
                const T want = expected[pos];
                const T got = received[pos];
                const float delta = std::abs(static_cast<float>(want) - static_cast<float>(got));

                if (delta > threshold) {
                    report("ERROR: [h=", h, w, c, want, got, delta);
                    ++mismatches;
                } else if (delta > 0) {
                    report("WARNING: [h=", h, w, c, want, got, delta);
                }
            }
        }
    }
    return mismatches;
}



#ifndef __TXNRT__
// Graph instance that main() initializes, passes to run_dma_layer_config()
// and ends in the non-TXNRT (simulation) flow.
ComputeGraph g_compute_graph;
#endif // __TXNRT__

int main(void)
{
    srand(0xABCD);
    constexpr int WGT_ELEMENT = 64;
    int aie_rows = AIE_ROWS;
    int aie_cols = AIE_COLS;
    int concat_mode = CONCAT_MODE;
    int num_inputs = NUM_INPUTS;
    int qdq_mode = QDQ_MODE; //0: DEQUANT; 1: QUANT; 2: BOTH; 3: NONE
    int qdq_size = WGT_ELEMENT * sizeof(Welem);
    bool padding_enable = bool(PADDING_EN);
    bool is_kernel = bool(IS_KERNEL);
    bool is_kernel_depad_avail = bool(IS_KERNEL_DEPAD);
    bool is_16 = bool(INT_16);
    // int is_const_input = IS_CONST_INPUT;
    // printf("is_const_input: %d", is_const_input);
    std::vector<int> input_rows, input_cols, input_chs;
    std::vector<std::string> input_types;
    std::vector<int> input_chs_p(NUM_INPUTS);
    std::ifstream file("shapes.txt");
    std::string line;

    if (file.is_open()) {
        int line_idx = 0;

        while (std::getline(file, line)) {
            std::istringstream ss(line);
            std::string value;

            while (std::getline(ss, value, ',')) {
                switch (line_idx) {
                    case 0:  // rows
                        input_rows.push_back(std::stoi(value));
                        break;
                    case 1:  // cols
                        input_cols.push_back(std::stoi(value));
                        break;
                    case 2:  // channels
                        input_chs.push_back(std::stoi(value));
                        break;
                    case 3:  // input types
                        input_types.push_back(value);
                        break;
                    default:
                        break;
                }
            }
            ++line_idx;
        }
        file.close();

        // ---- Default behavior if 4th line is missing ----
        if (input_types.empty()) {
            input_types.resize(input_rows.size(), "act");
        }

    } else {
        std::cerr << "Unable to open file shapes.txt" << std::endl;
    }

    for (int i = 0; i < num_inputs; ++i) {
        input_chs_p[i] = max_64_W8(input_chs[i]);
    }

    assert (num_inputs == int(input_rows.size()) && int(input_rows.size()) == int(input_cols.size()) &&
                            int(input_cols.size()) == int(input_chs.size()));

    int output_rows = input_rows[0];
    int output_cols = (concat_mode == 1 ) ? std::accumulate(input_cols.begin(), input_cols.end(), 0) \
                                            : input_cols[0];
    //output_chs : no padding
    //output_chs_p: padding
    // bool is_qdq = bool(qdq_mode != 3);
    int output_chs;
    if (is_kernel){
        if (is_kernel_depad_avail){
            output_chs= (concat_mode == 0 ) ? std::accumulate(input_chs.begin(), input_chs.end(), 0) \
                                        : input_chs[0];
        }else{
            output_chs= (concat_mode == 0 ) ? std::accumulate(input_chs_p.begin(), input_chs_p.end(), 0) \
                                        : input_chs_p[0];
        }

    }else{
        output_chs= (concat_mode == 0 ) ? std::accumulate(input_chs.begin(), input_chs.end(), 0) \
                                        : input_chs_p[0];
        }

    int output_chs_p = max_64_W8(output_chs);

    int ifm_size = 0;
    int wgt_size = qdq_size;
    for (int i = 0; i < num_inputs; ++i) {
        if (input_types[i] == "act")
            ifm_size += input_rows[i] * input_cols[i] * input_chs_p[i] * sizeof(Tin);
        else if (input_types[i] == "const")
            wgt_size += (input_rows[i] * input_cols[i] * input_chs_p[i]) * sizeof(Tin);
    }
    // ifm_size = ifm_size * sizeof(Tin);

    int ofm_size = output_rows * output_cols * output_chs * sizeof(Tout);
    int ofm_p_size = output_rows * output_cols * output_chs_p * sizeof(Tout);

#ifdef __TXNRT__
    auto aie_ifm = static_cast<Tin*>(malloc(ifm_size));
    auto aie_wgt = static_cast<Welem*>(malloc(wgt_size));
    auto aie_ofm = static_cast<Tout*>(malloc(ofm_p_size));
#else
    auto aie_ifm = static_cast<Tin*>(adf::GMIO::malloc(ifm_size));
    auto aie_wgt = static_cast<Welem*>(adf::GMIO::malloc(wgt_size));
    auto aie_ofm = static_cast<Tout*>(adf::GMIO::malloc(ofm_p_size));
#endif // __TXNRT__
    auto cpu_ofm_nopad = static_cast<Tout*>(malloc(ofm_size));
    auto cpu_ofm = static_cast<Tout*>(malloc(ofm_p_size));
    // const int NUM_ACT_INPUTS = num_inputs - is_const_input;

    uint8_t* matrix[NUM_INPUTS];
    size_t temp_act = 0;                   // Tin elements
    size_t temp_const_bytes = qdq_size;    // bytes (after int32 weights)
    for (int i = 0; i < num_inputs; ++i) {
        size_t elem_cnt = input_rows[i] * input_cols[i] * input_chs_p[i];
        if (input_types[i] == "act") {
            // aie_ifm is Tin*
            matrix[i] = reinterpret_cast<uint8_t*>(&aie_ifm[temp_act]);
            temp_act += elem_cnt;
        } else {
            // aie_wgt is int32_t*, const region is Tin-packed
            matrix[i] = reinterpret_cast<uint8_t*>(aie_wgt) + temp_const_bytes;
            temp_const_bytes += elem_cnt * sizeof(Tin);
        }
    }
    Tout* aie_out_mat = aie_ofm;
    Tout* cpu_out_mat = cpu_ofm;

    Welem qdq_param[WGT_ELEMENT];
    // Tin* const_input_ptr = nullptr;
    if (qdq_mode == 0 || qdq_mode == 1 || qdq_mode == 2 || qdq_mode == 3){
        init_wgt_mat<Welem>(qdq_param, qdq_mode, num_inputs);
        printf("qdq_mode : %d\n", qdq_mode);
        printf("is_int16 : %d\n", INT_16);
        const size_t qdq_bytes = 64 * sizeof(Welem);
        memcpy(aie_wgt, (void*)qdq_param, qdq_bytes);

        // if (is_const_input) {
        //     const_input_ptr = reinterpret_cast<Tin*>(
        //         reinterpret_cast<uint8_t*>(aie_wgt) + qdq_bytes
        //     );
        // }
    }

    int counter_init = 0;
    for (int i = 0; i < num_inputs; ++i) {
        uint16_t zp = qdq_param[2*i];
        float scale = bfloat16_to_float(qdq_param[2*i+1]);
        init_random_mat(i, reinterpret_cast<Tin*>(matrix[i]), input_rows[i], input_cols[i], input_chs[i], input_chs_p[i], qdq_mode, zp, scale);
        counter_init += input_rows[i] * input_cols[i] * input_chs[i];
    }

    // if (is_const_input) {
    //     const int ci = num_inputs - 1;  // const input is the LAST input
    //     uint16_t zp = qdq_param[2 * ci];
    //     float scale = bfloat16_to_float(qdq_param[2 * ci + 1]);

    //     // Fill const input payload inside aie_wgt
    //     init_random_mat(ci,
    //                     const_input_ptr,
    //                     input_rows[ci], input_cols[ci],
    //                     input_chs[ci], input_chs_p[ci],
    //                     qdq_mode, zp, scale);
    //     counter_init += input_rows[ci] * input_cols[ci] * input_chs[ci];
    // }

    bool kernel_but_no_depad = is_kernel == true && is_kernel_depad_avail == false;

    // std::vector<Tin*> concat_inputs;
    // concat_inputs.reserve(num_inputs);

    // for (int i = 0; i < num_inputs; ++i) {
    //     concat_inputs.push_back(reinterpret_cast<Tin*>(matrix[i]));
    // }
    // if (is_const_input) {
    //     concat_inputs.push_back(const_input_ptr);
    // }

    print_mat(reinterpret_cast<Tin*>(matrix[0]), input_rows[0], input_cols[0], input_chs_p[0], "input0 =\n");
    print_mat(reinterpret_cast<Tin*>(matrix[1]), input_rows[1], input_cols[1], input_chs_p[1], "input1 =\n");

    concat_mats<Tin, Tout>(
        reinterpret_cast<Tin**>(matrix),
        input_rows[0], input_cols, input_chs, input_chs_p,
        num_inputs, concat_mode, kernel_but_no_depad,
        cpu_ofm_nopad,
        output_rows, output_cols, output_chs,
        qdq_mode, qdq_param);
    // print_mat(cpu_ofm_nopad, output_rows, output_cols, output_chs, "CPU cpu_ofm_nopad =\n");

    concat_mats_padding<Tout>(
        cpu_ofm_nopad, cpu_out_mat,
        output_rows, output_cols, output_chs, output_chs_p
    );
    // print_mat(cpu_out_mat, output_rows, output_cols, output_chs_p, "CPU cpu_ofm_padding =\n");


#if defined(__AIESIM__) || defined(__TXNRT__)
    #ifdef __TXNRT__
    DmaBins bins = run_dma_layer_config();
    bins.save();
    write_bin_file("ifm.bin", reinterpret_cast<char*>(aie_ifm), ifm_size);
    write_bin_file("wgt.bin", reinterpret_cast<char*>(aie_wgt), wgt_size);
    write_bin_file("ofm.bin", reinterpret_cast<char*>(cpu_ofm), ofm_p_size);
#else
    print_mat(cpu_ofm, output_rows, output_cols, output_chs_p, "CPU OFM =\n");
    g_compute_graph.init();
    run_dma_layer_config(g_compute_graph, aie_ofm, aie_ifm, aie_wgt);
    g_compute_graph.end();
    print_mat(aie_ofm, output_rows, output_cols, output_chs_p, "AIE OFM =\n");
    int threshold = (qdq_mode != 2) ? 1: (is_16 ? 128 : 8);
    int err_count = check_result<Tout>(cpu_out_mat, aie_out_mat, output_rows, output_cols, output_chs_p, qdq_mode, threshold);
    if (err_count == 0) {
        printf("DI: PASS\n");
    } else {
        printf("DI: FAIL\n");
    }
    printf("Error Count = %d\n", err_count);
    #endif // __TXNRT__
#endif // __AIESIM__

    #ifdef __TXNRT__
    free(aie_ifm);
    free(aie_wgt);
    free(aie_ofm);
    #else
    adf::GMIO::free(aie_ifm);
    adf::GMIO::free(aie_wgt);
    adf::GMIO::free(aie_ofm);
    #endif // __TXNRT__
    free(cpu_ofm_nopad);
    free(cpu_ofm);

    assert(false);
    return 0;

}
