#ifndef COMMON_HPP
#define COMMON_HPP

#include <assert.h>
#include <fenv.h>
#include <math.h>
#include <stdint.h>

#include <algorithm>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <iostream>
#include <map>
#include <random>
#include <stdexcept>
#include <string>
#include <type_traits>
#include <vector>

#include <nlohmann/json.hpp>

#if USE_CERT_LIBRARY
#include <adf.h>
#include <adf/adf_api/AIERuntimeControl.h>
#include "uc_sim_wrapper.h"
#endif //USE_CERT_LIBRARY

// Build-time switch for the QDQ float format. When the build does not define
// it, it defaults to 0, which makes QDQFloatType (defined further down in this
// file) an alias of bfloat16_t; a non-zero value selects float16_t instead.
#ifndef __IS_QDQ_FP16__
#define __IS_QDQ_FP16__ 0
#endif

// Non-owning view over a dense 3-D buffer of T.
// Elements are laid out with the C index varying fastest, then X, then Y
// (see the offset computed in at()). The view never allocates or frees `data`.
template<typename T>
struct ActTensor
{
    int const C;
    int const Y;
    int const X;
    T* const data;

    // Wraps an existing buffer; `data` is reinterpreted as an array of T.
    ActTensor(int C, int Y, int X, void* data)
        : C(C), Y(Y), X(X), data(static_cast<T*>(data))
    {}

    // Returns a reference to element (c, y, x). Indices are bounds-checked
    // with assert only, so the checks vanish in NDEBUG builds.
    T& at(int c, int y, int x)
    {
        assert((c >= 0) && (c < C));
        assert((y >= 0) && (y < Y));
        assert((x >= 0) && (x < X));
        int const offset = (((y * X) + x) * C) + c;
        return data[offset]; // NOLINT
    }

    // Number of bytes occupied by a C x Y x X tensor of T.
    static int size(int C, int Y, int X)
    {
        auto const bytes = C * Y * X * sizeof(T);
        return static_cast<int>(bytes);
    }
};

// Storage-only wrapper for a bfloat16 bit pattern. The conversion helpers in
// this file treat `value` as the upper 16 bits of a 32-bit float.
struct bfloat16_t
{
    uint16_t value;
};

// IEEE 754 half-precision (float16) - portable struct-based implementation.
// Only the raw 16-bit pattern is stored; arithmetic is done by converting to
// float with the helpers in this file.
struct float16_t
{
    uint16_t value;
};

// 16-bit float type used by the quantize/dequantize helpers in this file:
// float16_t when __IS_QDQ_FP16__ is non-zero, bfloat16_t otherwise.
#if __IS_QDQ_FP16__
    using QDQFloatType = float16_t;
#else
    using QDQFloatType = bfloat16_t;
#endif

// Type trait: true exactly when T is one of the two 16-bit float wrappers
// (bfloat16_t or float16_t). Cv-qualified or reference types do not match,
// because std::is_same compares the types exactly.
template<typename T>
constexpr bool is_float16_type_v = std::is_same<T, bfloat16_t>::value || std::is_same<T, float16_t>::value;

// Forward declarations for float16 conversion functions. They are defined
// later in this file; declaring them here lets the Float16Traits
// specializations call them inline.
inline float bfloat16_to_float(bfloat16_t bf);
inline float float16_to_float(float16_t hf);

// Helper trait to get the appropriate conversion functions for float16 types.
// The primary template covers every other T and converts with plain
// static_cast in both directions.
template<typename T>
struct Float16Traits {
    static float to_float(T val) { return static_cast<float>(val); }
    static T from_float(float val) { return static_cast<T>(val); }
};

// bfloat16_t specialization. from_float is only declared here; its definition
// appears later in the file, after float_to_bfloat16 is available.
template<>
struct Float16Traits<bfloat16_t> {
    static float to_float(bfloat16_t val) { return bfloat16_to_float(val); }
    static bfloat16_t from_float(float val);
};

// float16_t specialization. from_float is only declared here; its definition
// appears later in the file, after float_to_float16 is available.
template<>
struct Float16Traits<float16_t> {
    static float to_float(float16_t val) { return float16_to_float(val); }
    static float16_t from_float(float val);
};

// Returns ceil(x / d) for x >= 0 and d > 0 (both checked with assert).
// Computed from quotient and remainder so that no intermediate sum can
// overflow, even when x is close to INT_MAX.
inline int ceildiv(int x, int d)
{
    assert(x >= 0);
    assert(d > 0);
    int const quotient = x / d;
    int const remainder = x % d;
    return quotient + ((remainder != 0) ? 1 : 0);
}

// Rounds x up to the nearest multiple of d (x >= 0, d > 0, checked by assert).
inline int iceil(int x, int d)
{
    assert(x >= 0);
    assert(d > 0);
    int const chunks = ceildiv(x, d);
    return chunks * d;
}

// Prints every element of `tensor` to std::cout as an int, preceded by `name`.
// One output line per (y, x) position lists all channels; a blank line
// separates consecutive y rows. 16-bit float wrapper types are converted to
// float first and then truncated to int.
template<typename T>
inline void log_tensor(
    ActTensor<T> tensor,
    std::string const& name)
{
    // Fetches one element and converts it to the integer that gets printed.
    auto printable = [&tensor](int ch, int row, int col) -> int {
        if constexpr (is_float16_type_v<T>) {
            return int(Float16Traits<T>::to_float(tensor.at(ch, row, col)));
        } else {
            return int(tensor.at(ch, row, col));
        }
    };

    std::cout << name << ": \n";
    for (int row = 0; row < tensor.Y; ++row) {
        for (int col = 0; col < tensor.X; ++col) {
            for (int ch = 0; ch < tensor.C; ++ch) {
                std::cout << printable(ch, row, col) << " ";
            }
            std::cout << "\n";
        }
        std::cout << "\n";
    }
}

// Prints `tensor` to std::cout block by block.
//
// The tensor is walked in sub-volumes of at most Cis x Yis x Xis elements.
// Each sub-volume gets a header line with its inclusive index ranges and is
// then printed in channel groups of at most Ci_gran channels: one output line
// per (y, x) position, one blank line after each channel group.
//
// Values are printed as int. 16-bit float wrapper types (bfloat16_t /
// float16_t) are converted to float first, matching the un-blocked overload.
// All four step sizes must be positive (checked with assert); a non-positive
// step would never advance its loop.
template<typename T>
inline void log_tensor(
    ActTensor<T> tensor,
    std::string const& name,
    int Yis, int Xis, int Cis, int Ci_gran)
{
    assert(Yis > 0);
    assert(Xis > 0);
    assert(Cis > 0);
    assert(Ci_gran > 0);

    std::cout << name << ": \n";
    for (int bc = 0; bc < tensor.C; bc += Cis) {
        int const c_end = std::min(bc + Cis, tensor.C);
        for (int by = 0; by < tensor.Y; by += Yis) {
            int const y_end = std::min(by + Yis, tensor.Y);
            for (int bx = 0; bx < tensor.X; bx += Xis) {
                int const x_end = std::min(bx + Xis, tensor.X);
                std::cout << "subvol Block [C: " << bc << "-" << c_end - 1
                          << ", Y: " << by << "-" << y_end - 1
                          << ", X: " << bx << "-" << x_end - 1 << "]:\n";
                for (int c = bc; c < c_end; c += Ci_gran) {
                    // Clamp to the block end so a channel group never spills
                    // into the next block when Ci_gran does not divide Cis.
                    int const group_end = std::min(c + Ci_gran, c_end);
                    for (int y = by; y < y_end; ++y) {
                        for (int x = bx; x < x_end; ++x) {
                            for (int cc = c; cc < group_end; ++cc) {
                                if constexpr (is_float16_type_v<T>) {
                                    std::cout << int(Float16Traits<T>::to_float(tensor.at(cc, y, x))) << " ";
                                } else {
                                    std::cout << int(tensor.at(cc, y, x)) << " ";
                                }
                            }
                            std::cout << "\n";
                        }
                    }
                    std::cout << "\n";
                }
            }
        }
    }
}

// Compares two tensors of identical shape element by element and prints one
// line per element to std::cout.
//
// Each element is reduced to an int64_t first (16-bit float wrapper types are
// converted to float and rounded). The absolute difference is then classified:
//   > epsilon       -> "ERROR"   (counted)
//   > 0, <= epsilon -> "WARNING"
//   == 0            -> "PASS"
//
// `ofm_sign` is accepted but not used by this function.
// Returns the number of ERROR elements.
template<typename T>
inline int cmp_tensor(
    ActTensor<T> expected,
    ActTensor<T> received,
    int ofm_sign,
    int64_t epsilon = 0)
{
    assert(expected.C == received.C);
    assert(expected.Y == received.Y);
    assert(expected.X == received.X);

    int mismatches = 0;
    for (int y = 0; y < expected.Y; ++y) {
        for (int x = 0; x < expected.X; ++x) {
            for (int c = 0; c < expected.C; ++c) {
                int64_t cpu_val, aie_val;
                if constexpr (is_float16_type_v<T>) {
                    cpu_val = static_cast<int64_t>(std::round(Float16Traits<T>::to_float(expected.at(c, y, x))));
                    aie_val = static_cast<int64_t>(std::round(Float16Traits<T>::to_float(received.at(c, y, x))));
                } else {
                    cpu_val = int64_t(expected.at(c, y, x));
                    aie_val = int64_t(received.at(c, y, x));
                }

                int64_t const delta = cpu_val - aie_val;
                int64_t const magnitude = (delta < 0) ? -delta : delta;

                // Pick the line prefix once; the rest of the line is the same
                // for all three outcomes.
                char const* prefix = "PASS: [Y: ";
                if (magnitude > epsilon) {
                    prefix = "ERROR: [Y: ";
                    mismatches += 1;
                } else if (magnitude > 0) {
                    prefix = "WARNING: [Y: ";
                }

                std::cout << prefix << y << ", X: " << x << ", C: " << c << "]: "
                          << "Expected: " << cpu_val << ", "
                          << "Received: " << aie_val << "\n";
            }
        }
    }
    return mismatches;
}

// Writes `size` bytes from `data` to `filename` in binary mode, replacing any
// existing file. Failures to open or write are reported on std::cerr; the
// function still returns normally because it has no error channel.
inline void write_bin_file(const std::string& filename, char* data, size_t size) {
    std::ofstream file(filename, std::ios::out | std::ios::binary);
    if (!file) {
        std::cerr << "Error: Unable to open file " << filename << " for writing" << std::endl;
        return;
    }
    file.write(data, static_cast<std::streamsize>(size));
    // Flush explicitly so that a failing write is detected here rather than
    // being lost in the destructor.
    file.flush();
    if (!file) {
        std::cerr << "Error: Unable to write file " << filename << std::endl;
    }
}

#ifdef _WIN32
// Opens `filename` with _wfopen after converting it to an absolute path that
// carries the Windows extended-length prefix:
//   \\server\share\...  ->  \\?\UNC\server\share\...
//   C:\dir\file         ->  \\?\C:\dir\file
// Paths that already start with \\?\ or \\.\ are passed through unchanged;
// prefixing them again would produce an invalid name.
// Returns the FILE* from _wfopen (null on failure).
inline FILE* wfopen_long(const std::string& filename, const wchar_t* mode) {
    std::wstring w = std::filesystem::absolute(filename).wstring();
    const bool already_prefixed =
        (w.size() >= 4) && (w[0] == L'\\') && (w[1] == L'\\') &&
        ((w[2] == L'?') || (w[2] == L'.')) && (w[3] == L'\\');
    if (!already_prefixed) {
        if (w.rfind(L"\\\\", 0) == 0)
            w = L"\\\\?\\UNC\\" + w.substr(2);
        else
            w = L"\\\\?\\" + w;
    }
    return _wfopen(w.c_str(), mode);
}
#endif

// Reads exactly `size` bytes from the binary file `filename` into `data`.
//
// The file must be exactly `size` bytes long. A mismatch is reported on
// std::cerr and trips an assert in debug builds; in NDEBUG builds the function
// returns 0 instead of reading a file of the wrong size.
//
// Returns the number of bytes read (== size) on success and 0 on any failure
// (cannot open, cannot determine size, size mismatch, read error). Note that a
// successful read of an empty file also returns 0.
inline int read_bin_file(const std::string& filename, char* data, size_t size) {
    std::ifstream file(filename, std::ios::binary | std::ios::ate);
    if (!file) {
        std::cerr << "Error: Unable to open file " << filename << std::endl;
        return 0;
    }

    // The stream was opened at the end, so tellg() is the file size; it
    // returns -1 when the position cannot be determined.
    const std::streamoff filesize = file.tellg();
    const bool size_ok = (filesize >= 0) && (static_cast<size_t>(filesize) == size);
    if (!size_ok) {
        std::cerr << "Error: Size mismatch for file " << filename << std::endl;
        std::cerr << "  Expected size: " << size << " bytes" << std::endl;
        std::cerr << "  Actual size:   " << filesize << " bytes" << std::endl;
    }
    assert(size_ok);  // Prevent size mismatch
    if (!size_ok) {
        return 0;  // reached only when asserts are compiled out
    }

    file.seekg(0, std::ios::beg);
    if (!file.read(data, static_cast<std::streamsize>(size))) {
        std::cerr << "Error: Unable to read file " << filename << std::endl;
        return 0;
    }
    return static_cast<int>(file.gcount());
}

// Convenience overload for std::filesystem::path: converts the path to a
// std::string and forwards to the string-based read_bin_file.
inline size_t read_bin_file(const std::filesystem::path& p, char* data, size_t size) {
    const std::string as_string = p.string();
    return static_cast<size_t>(read_bin_file(as_string, data, size));
}

// Clamps `val` into [min, max]: values below `min` become `min`, values above
// `max` become `max`.
template<typename T>
T saturate(T val, T min, T max) {
    T const raised = (val < min) ? min : val;
    return (max < raised) ? max : raised;
}

// Quantizes x / 2^shift to a 16-bit integer and returns it converted to T.
//
//   sign == true : round half away from zero, saturate to [INT16_MIN, INT16_MAX]
//   sign == false: negative inputs become 0, round half up, saturate to
//                  [0, UINT16_MAX]
//
// Saturation is done in the float domain before the integer conversion, so
// arbitrarily large inputs (including +/-infinity) saturate instead of hitting
// the undefined float-to-int conversion. NaN maps to 0. The scaling uses
// std::ldexp, which gives the same result as dividing by (1 << shift) without
// the undefined shift for shift >= 31.
template<typename T>
inline T quantize_float_to_int16(float x, int shift, bool sign = true) {
    float const scaled = std::ldexp(x, -shift);
    if (std::isnan(scaled)) {
        return static_cast<T>(0);
    }

    if (sign) {
        if (scaled >= static_cast<float>(INT16_MAX)) {
            return static_cast<T>(static_cast<int16_t>(INT16_MAX));
        }
        if (scaled <= static_cast<float>(INT16_MIN)) {
            return static_cast<T>(static_cast<int16_t>(INT16_MIN));
        }
        // In range: adding +/-0.5 and truncating rounds half away from zero.
        float const biased = scaled + (scaled >= 0 ? 0.5F : -0.5F);
        return static_cast<T>(static_cast<int16_t>(static_cast<int32_t>(biased)));
    }

    if (scaled >= static_cast<float>(UINT16_MAX)) {
        return static_cast<T>(static_cast<uint16_t>(UINT16_MAX));
    }
    if (scaled < 0.0f) {
        return static_cast<T>(static_cast<uint16_t>(0));
    }
    return static_cast<T>(static_cast<uint16_t>(static_cast<uint32_t>(scaled + 0.5F)));
}

// Overlays a 32-bit unsigned integer and a float on the same storage so the
// bit pattern of one can be read as the other (used by bfloat2float below).
union Float32Bits
{
  uint32_t u;  // raw bit pattern
  float f;     // the same bits viewed as a float
};

// Expands a raw bfloat16 bit pattern to float by placing it in the upper 16
// bits of a 32-bit word (the lower 16 mantissa bits become zero).
// Declared inline because this header may be included by several translation
// units; the bits are copied with memcpy rather than read through a union.
inline float bfloat2float(uint16_t bfloatBits)
{
  const uint32_t kF32BfMantiBitDiff = 16;
  const uint32_t widened = static_cast<uint32_t>(bfloatBits) << kF32BfMantiBitDiff;
  float result = 0.0f;
  std::memcpy(&result, &widened, sizeof(result));
  return result;
}


// Converts a raw IEEE 754 half-precision bit pattern to float.
//
// Handles signed zero, subnormals, normal numbers, infinity and NaN (the NaN
// payload is shifted into the float mantissa). Every half value is exactly
// representable as a float, so the conversion is exact.
// Declared inline because this header may be included by several translation
// units.
inline float fp16_to_fp32(uint16_t h)
{
    const uint32_t f_sign = static_cast<uint32_t>(h & 0x8000u) << 16;
    const uint32_t h_exp = (static_cast<uint32_t>(h) >> 10) & 0x1Fu;
    uint32_t h_frac = static_cast<uint32_t>(h) & 0x03FFu;

    uint32_t f_exp = 0;
    uint32_t f_frac = 0;

    if (h_exp == 0)
    {
        // Zero keeps f_exp == f_frac == 0; only subnormals need work.
        if (h_frac != 0)
        {
            // Normalize: shift the fraction left until its leading 1 reaches
            // the implicit-bit position (bit 10).
            int shift = 0;
            while ((h_frac & 0x0400u) == 0)
            {
                h_frac <<= 1;
                ++shift;
            }
            h_frac &= 0x03FFu;
            // A half subnormal is frac * 2^-24, i.e. it has the exponent of
            // the smallest normal (biased 1). Each normalization shift lowers
            // it by one, hence "+ 1 - shift".
            f_exp = static_cast<uint32_t>(127 - 15 + 1 - shift) << 23;
            f_frac = h_frac << 13;
        }
    }
    else if (h_exp == 0x1Fu)
    {
        // Inf or NaN
        f_exp = 0xFFu << 23;
        f_frac = h_frac << 13;
    }
    else
    {
        // Normalized number: re-bias the exponent from 15 to 127.
        f_exp = (h_exp - 15 + 127) << 23;
        f_frac = h_frac << 13;
    }

    const uint32_t f_bits = f_sign | f_exp | f_frac;
    float f;
    std::memcpy(&f, &f_bits, sizeof(f));
    return f;
}

// Returns the integer stored under key `k` of JSON object `j`.
// Throws std::runtime_error when the key is absent or its value is not an
// integer number.
inline int extract_json(const nlohmann::json& j, const char* k) {
    const bool usable = j.contains(k) && j[k].is_number_integer();
    if (usable) {
        return j[k].get<int>();
    }
    throw std::runtime_error(std::string("Missing/non-int key: ") + k);
}

// Returns the number stored under key `k` of JSON object `j` as a float.
//
// Any JSON number is accepted, including integer-valued ones such as
// `"scale": 1`, which the parser stores as an integer rather than a float.
// Throws std::runtime_error when the key is absent or its value is not a
// number.
inline float extract_json_float(const nlohmann::json& j, const char* k) {
    if (!j.contains(k) || !j[k].is_number()) {
        throw std::runtime_error(std::string("Missing/non-float key: ") + k);
    }
    return j[k].get<float>();
}

// Returns the string stored under `key` of `json_obj` (an empty string is a
// valid result).
// Throws std::runtime_error with a distinct message when the key is missing,
// when its value is null, or when its value is not a string.
inline std::string extract_json_str(const nlohmann::json& json_obj, const char* key) {
    // Builds the exception for one failure kind; `what` is the message prefix.
    auto failure = [key](const char* what) {
        return std::runtime_error(std::string(what) + key);
    };

    if (!json_obj.contains(key)) {
        throw failure("Missing key: ");
    }

    const auto& entry = json_obj[key];
    if (entry.is_null()) {
        throw failure("Null value for key: ");
    }
    if (!entry.is_string()) {
        throw failure("Non-string value for key: ");
    }
    return entry.get<std::string>();
}

// Callable that parses the JSON file at `path` and returns the parsed document.
// The stream is handed to nlohmann::json::parse without checking that the file
// could be opened, so errors surface as whatever parse() raises.
// Declared as an inline variable so that including this header from several
// translation units does not produce multiple definitions.
inline auto load_json = [](const char* path) {
    return nlohmann::json::parse(std::ifstream{path});
};

void write_external_buffer_json(std::size_t arg0, // OFM
                                std::size_t arg1, // IFM
                                std::size_t arg2, // WGT
                                const std::string& path = "external_buffer_id.json")
{
    nlohmann::json j;

    j["external_buffers"]["buffer0"] = {
        {"xrt_id", 0},
        {"logical_id", 0},
        {"size_in_bytes", arg0},
        {"name", "g.ofm_ddr"}
    };

    j["external_buffers"]["buffer1"] = {
        {"xrt_id", 1},
        {"logical_id", 1},
        {"size_in_bytes", arg1},
        {"name", "g.ifm_ddr"}
    };

    j["external_buffers"]["buffer2"] = {
        {"xrt_id", 2},
        {"logical_id", 2},
        {"size_in_bytes", arg2},
        {"name", "g.wts_ddr"}
    };

    j["external_buffers"]["buffer3"] = {
        {"xrt_id", 3},
        {"logical_id", 3},
        {"size_in_bytes", 12288},
        {"name", "g.param_ddr"}
    };

    std::ofstream f(path);
    f << j.dump(4) << std::endl;
}

// Allocates `num_bytes` bytes and returns the pointer from the underlying
// allocator.
//
// When ASM_MODE is zero or undefined, `gmio_alloc` selects adf::GMIO::malloc
// (true, the default) or plain malloc (false). When ASM_MODE is non-zero,
// plain malloc is always used and `gmio_alloc` is ignored.
// Memory must be released with deallocate() using the matching flag.
// Declared inline because this header may be included by several translation
// units.
inline void* allocate(int num_bytes, bool gmio_alloc=true)
{
#if !ASM_MODE
    if (gmio_alloc) {
        return adf::GMIO::malloc(num_bytes);
    }
    return malloc(num_bytes);
#else
    (void)gmio_alloc;  // unused in this configuration
    return malloc(num_bytes);
#endif
}

// Releases memory obtained from allocate().
//
// When ASM_MODE is zero or undefined, `gmio_free` selects adf::GMIO::free
// (true, the default) or plain free (false); it must match the flag used for
// the allocation. When ASM_MODE is non-zero, plain free is always used and
// `gmio_free` is ignored.
// Declared inline because this header may be included by several translation
// units.
inline void deallocate(void* ptr, bool gmio_free=true)
{
#if !ASM_MODE
    if (gmio_free) {
        adf::GMIO::free(ptr);
    } else {
        free(ptr);
    }
#else
    (void)gmio_free;  // unused in this configuration
    free(ptr);
#endif
}

#if USE_CERT_LIBRARY
// Reads `bytes` bytes from the binary file `filename` into `dst`.
// If the file cannot be opened, or holds fewer than `bytes` bytes, the problem
// is reported on std::cerr; in that case `dst` is left untouched (open
// failure) or only partially filled (short file).
inline void load_param_bin(const std::string& filename, void* dst, size_t bytes) {
    std::ifstream in(filename, std::ios::binary);
    if (!in) {
        std::cerr << "Error: Unable to open file " << filename << std::endl;
        return;
    }
    in.read(static_cast<char*>(dst), static_cast<std::streamsize>(bytes));
    if (!in) {
        std::cerr << "Error: Unable to read " << bytes << " bytes from file " << filename
                  << " (got " << in.gcount() << ")" << std::endl;
    }
}

// Declaration only; the definition is not in this part of the file.
// Called by run_cert_sim below with the control ELF path, the list of column
// indices and the symbol table mapping buffer ids to `patch` entries.
// NOTE(review): declared `inline`, so a definition must be visible in every
// translation unit that calls it — presumably provided by uc_sim_wrapper.h;
// confirm.
void inline executeCERTSim(std::string const& control_elf_path,
                           std::vector<uint32_t> const& columns,
                           std::map<std::string, patch> const& sym_tbl);

// Symbol table type: maps a buffer id string (e.g. "0") to its `patch` entry.
using SymbolsMap = std::map<std::string, patch>;

// Builds the symbol table for four buffers: keys "0".."3" map to a `patch`
// brace-initialized from the corresponding address and `isCERTPatching`.
inline SymbolsMap initSymbols(uint64_t addr0,
                              uint64_t addr1,
                              uint64_t addr2,
                              uint64_t addr3,
                              bool isCERTPatching)
{
    uint64_t addresses[] = {addr0, addr1, addr2, addr3};

    SymbolsMap table;
    int slot = 0;
    for (uint64_t address : addresses) {
        table[std::to_string(slot)] = {address, isCERTPatching};
        ++slot;
    }
    return table;
}

// Looks up the DDR address for a GMIO buffer via adf::GMIO::get_ddr_address
// and stores it in `phyAddr`. The outcome is logged to std::cout; on failure
// only a message is printed and the caller gets no error indication.
//
// The address is printed in hex, and the stream's format flags are restored
// afterwards so later integer output on std::cout is not switched to hex.
// Declared inline because this header may be included by several translation
// units.
inline void GetPhyAddr(void* virtualAddr, int size, uint64_t& phyAddr)
{
    if (adf::GMIO::get_ddr_address(virtualAddr, size, phyAddr) == adf::return_code::ok) {
        const std::ios_base::fmtflags saved_flags = std::cout.flags();
        std::cout << "Found phy address:" << std::hex << phyAddr << std::endl;
        std::cout.flags(saved_flags);
    } else {
        std::cout << "Error phy address not found\n";
    }
}

// Runs one CERT simulation pass over caller-provided OFM/IFM/WGT buffers.
//
// Steps, in order:
//   1. initialise the compute graph,
//   2. allocate a parameter buffer of `prm_size` bytes with adf::GMIO::malloc
//      and fill it from the file "param.bin" in the current directory,
//   3. resolve the address of each of the four buffers with GetPhyAddr,
//   4. build the symbol table (ids "0".."3" = OFM, IFM, WGT, params) with
//      isCERTPatching = false,
//   5. execute "./Work_AIE4/control.elf" on columns 0, 1 and 2, bracketed by
//      adf::syncPSToGM() / adf::syncPSFromGM(),
//   6. free the parameter buffer.
//
// The result of the parameter allocation is not checked, and GetPhyAddr only
// logs failures, so an unresolved address stays 0.
// NOTE(review): aie_ofm / aie_ifm / aie_wgt are presumably GMIO allocations
// (they are passed to adf::GMIO::get_ddr_address) — confirm against callers.
inline void run_cert_sim(ComputeGraph compute_graph,
                         void* aie_ofm, std::size_t ofm_size,
                         void* aie_ifm, std::size_t ifm_size,
                         void* aie_wgt, std::size_t wgt_size,
                         std::size_t prm_size = 12288) {
    // Init graph enables isolation
    compute_graph.init();

    // Load param.bin to a GMIO virtual_addr
    auto aie_prm = static_cast<int8_t*>(adf::GMIO::malloc(prm_size));
    load_param_bin("param.bin", aie_prm, prm_size);

    // Get a corresponding physical_addr for GMIO virtual_addr
    uint64_t ofm_phys = 0, ifm_phys = 0, wgt_phys = 0, prm_phys = 0;
    GetPhyAddr(aie_ofm, ofm_size, ofm_phys);
    GetPhyAddr(aie_ifm, ifm_size, ifm_phys);
    GetPhyAddr(aie_wgt, wgt_size, wgt_phys);
    GetPhyAddr(aie_prm, prm_size, prm_phys);

    // Generate a map of xrt_id to physical_addr
    std::map<std::string, patch> symbolTable;
    symbolTable = initSymbols(ofm_phys, ifm_phys, wgt_phys, prm_phys, false);
    std::vector<uint32_t> columns = {0,1,2};

    // Execute CERT_SIM
    adf::syncPSToGM();
    executeCERTSim("./Work_AIE4/control.elf", columns, symbolTable);
    adf::syncPSFromGM();

    // Release the temporary parameter buffer allocated above.
    adf::GMIO::free(aie_prm);
}
#endif // USE_CERT_LIBRARY

// Plain aggregate of quantize/dequantize parameters for a two-input operator.
// NOTE(review): the field meanings below are inferred from the names only
// (zp = zero point, sc = scale); nothing in this file reads them — confirm
// against the code that consumes this struct.
struct BinaryQDQParams {
    float dq_a_zp;   // presumably dequant zero point of input A
    float dq_a_sc;   // presumably dequant scale of input A
    float dq_b_zp;   // presumably dequant zero point of input B
    float dq_b_sc;   // presumably dequant scale of input B
    float q_zp;      // presumably output quant zero point
    float q_sc;      // presumably output quant scale
    bool dq_enable;  // presumably enables the dequant stage
    bool q_enable;   // presumably enables the quant stage
};

// Returns a random float drawn uniformly from the range [min, max).
//
// The Mersenne Twister engine is seeded once per thread from
// std::random_device and then reused, instead of opening the random device
// and rebuilding the engine on every call. thread_local keeps concurrent
// callers independent, as they were when the engine was a plain local.
// Declared inline because this header may be included by several translation
// units.
inline float generateRandomFloat(float min, float max) {
    thread_local std::mt19937 gen{std::random_device{}()};
    std::uniform_real_distribution<float> dist(min, max);
    return dist(gen);
}

// Returns the raw 32-bit pattern of `f` (a byte-wise copy, no numeric
// conversion).
inline uint32_t float_to_uint(float f)
{
    uint32_t bits = 0;
    std::memcpy(&bits, &f, 4);
    return bits;
}

// Converts a float to bfloat16 with round-to-nearest, ties-to-even.
//
// Rounding adds 0x7FFF plus the lowest kept bit before dropping the lower 16
// bits. NaN is handled separately: the rounding add could otherwise carry out
// of the mantissa and turn a NaN into an infinity or a signed zero. NaN inputs
// keep their sign and upper payload bits and get the quiet bit set.
inline bfloat16_t float_to_bfloat16(float fp)
{
    uint32_t const bits = float_to_uint(fp);
    if ((bits & 0x7FFFFFFFu) > 0x7F800000u) {
        // NaN: truncate and force a non-zero (quiet) mantissa.
        return bfloat16_t{uint16_t((bits >> 16) | 0x0040u)};
    }
    uint32_t const lsb = (bits >> 16) & 0x1;
    uint32_t const bias = 0x7FFF + lsb;
    uint32_t const rnd = bits + bias;
    return bfloat16_t{uint16_t(rnd >> 16)};
}

// Convert float32 to float16 (IEEE 754 half precision) and return the raw
// 16-bit pattern.
//
//   - Inf stays Inf; any NaN becomes a NaN with mantissa 0x200.
//   - Finite values too large for half precision become infinity.
//   - Values below the smallest half subnormal become (signed) zero.
//   - Mantissa bits that do not fit are dropped (truncation, not
//     round-to-nearest).
// NOTE(review): float_to_bfloat16 in this file rounds to nearest-even while
// this conversion truncates — confirm the difference is intended.
//
// Declared inline because this header may be included by several translation
// units.
inline uint16_t float32_to_float16(float value) {
    uint32_t bits;
    std::memcpy(&bits, &value, sizeof(bits));

    const uint32_t sign     = (bits >> 31) & 0x1;
    const uint32_t exponent = (bits >> 23) & 0xFF;
    const uint32_t mantissa = bits & 0x7FFFFF;

    uint32_t hexponent = 0;
    uint32_t hmantissa = 0;

    if (exponent == 255) {
        // Inf or NaN
        hexponent = 0x1F;
        hmantissa = (mantissa ? 0x200 : 0);
    } else if (exponent > 142) {
        // Overflow -> Inf
        hexponent = 0x1F;
        hmantissa = 0;
    } else if (exponent < 113) {
        // Subnormal or zero; exponents below 103 are too small and stay zero.
        if (exponent >= 103) {
            // Subnormal half precision: restore the implicit leading 1 and
            // shift it into the 10-bit fraction.
            const uint32_t shift = 113 - exponent;
            hmantissa = (mantissa | 0x800000) >> (shift + 13);
        }
    } else {
        // Normalized number: re-bias the exponent from 127 to 15.
        hexponent = exponent - 112;
        hmantissa = mantissa >> 13;
    }

    return (uint16_t)((sign << 15) | (hexponent << 10) | hmantissa);
}

// Reinterprets the 32-bit pattern `i` as a float (a byte-wise copy, no numeric
// conversion).
inline float uint_to_float(uint32_t i)
{
    float value = 0;
    std::memcpy(&value, &i, 4);
    return value;
}
// Expands a bfloat16 to float: its 16 bits become the upper half of the float
// bit pattern and the lower half is zero.
inline float bfloat16_to_float(bfloat16_t bf)
{
    uint32_t const widened = uint32_t(bf.value) << 16;
    return uint_to_float(widened);
}

// Converts a float to float16_t by wrapping the bit pattern produced by
// float32_to_float16.
inline float16_t float_to_float16(float fp)
{
    float16_t half{};
    half.value = float32_to_float16(fp);
    return half;
}

// Converts a float16_t (IEEE 754 half precision) to float. Handles signed
// zero, subnormals, normal numbers, infinity and NaN.
inline float float16_to_float(float16_t hf)
{
    uint32_t const raw = hf.value;
    uint32_t const sign_bit = (raw & 0x8000u) << 16;
    uint32_t const exp_field = (raw >> 10) & 0x1Fu;
    uint32_t fraction = raw & 0x3FFu;

    // Inf or NaN: all-ones exponent, fraction carried over.
    if (exp_field == 31) {
        return uint_to_float(sign_bit | 0x7F800000u | (fraction << 13));
    }

    // Normal number: re-bias the exponent from 15 to 127.
    if (exp_field != 0) {
        return uint_to_float(sign_bit | ((exp_field + 127 - 15) << 23) | (fraction << 13));
    }

    // Signed zero.
    if (fraction == 0) {
        return uint_to_float(sign_bit);
    }

    // Subnormal: start from the smallest normal exponent and lower it by one
    // for every shift needed to bring the leading 1 up to bit 10.
    int32_t unbiased = 1;
    while ((fraction & 0x400u) == 0) {
        fraction <<= 1;
        --unbiased;
    }
    fraction &= 0x3FFu;
    uint32_t const exp_bits = static_cast<uint32_t>(unbiased + 127 - 15) << 23;
    return uint_to_float(sign_bit | exp_bits | (fraction << 13));
}

// Deferred implementations of Float16Traits::from_float (needs float_to_bfloat16/float_to_float16).
// They are defined here, outside the class bodies, because the conversion
// functions they forward to are only defined at this point in the file.
inline bfloat16_t Float16Traits<bfloat16_t>::from_float(float val) { return float_to_bfloat16(val); }
inline float16_t Float16Traits<float16_t>::from_float(float val) { return float_to_float16(val); }

// Dequantizes an h_in x w_in x c_in buffer (channel index fastest):
//   out = from_float((in - z) * s)
// where the result is stored as QDQFloatType.
//
// `in_data` must point to a non-float16 element type; for bfloat16_t /
// float16_t input the function throws std::invalid_argument as soon as the
// first element is visited (so an empty volume does not throw).
template<typename T1, typename T2, typename T3>
void dequant(T1 in_data, T2 out_data, int h_in, int w_in, int c_in, float s, T3 z)
{
    using ElemT = std::remove_pointer_t<T1>;

    for (int row = 0; row < h_in; ++row) {
        for (int col = 0; col < w_in; ++col) {
            for (int ch = 0; ch < c_in; ++ch) {
                if constexpr (is_float16_type_v<ElemT>) {
                    throw std::invalid_argument("dequant should not be called with float16-type input");
                } else {
                    int const offset = (row * w_in * c_in) + (col * c_in) + ch;
                    float const centered = in_data[offset] - z;
                    float const scaled = centered * s;
                    out_data[offset] = Float16Traits<QDQFloatType>::from_float(scaled);
                }
            }
        }
    }
}

// Quantizes an h_in x w_in x c_in buffer of QDQFloatType (channel index
// fastest):
//   q = round(to_float(in) * inv_s) + z
// and stores q through quantize_float_to_int16 with shift 0.
//
//   OutT int16_t / int8_t   -> signed path, saturated to the int16 range
//   OutT uint16_t / uint8_t -> unsigned path, saturated to the uint16 range
//   any other OutT          -> nothing is written
//
// NOTE(review): for the 8-bit output types the 16-bit result is narrowed with
// static_cast, i.e. it is not saturated to the 8-bit range — confirm this is
// intended.
template<typename T1, typename T2, typename OutT>
void quant_float16_to_uint16(T1 in_data, OutT* out_data, int h_in, int w_in, int c_in, float inv_s, T2 z)
{
    constexpr bool signed_out = std::is_same<OutT, int16_t>::value || std::is_same<OutT, int8_t>::value;
    constexpr bool unsigned_out = std::is_same<OutT, uint16_t>::value || std::is_same<OutT, uint8_t>::value;

    for (int row = 0; row < h_in; ++row) {
        for (int col = 0; col < w_in; ++col) {
            for (int ch = 0; ch < c_in; ++ch) {
                int const offset = (row * w_in * c_in) + (col * c_in) + ch;
                float val = Float16Traits<QDQFloatType>::to_float(in_data[offset]);
                val = std::round(val * inv_s) + z;
                if constexpr (signed_out) {
                    out_data[offset] = static_cast<OutT>(quantize_float_to_int16<int16_t>(val, 0, true));
                } else if constexpr (unsigned_out) {
                    out_data[offset] = static_cast<OutT>(quantize_float_to_int16<uint16_t>(val, 0, false));
                }
            }
        }
    }
}

// Backwards-compatibility alias: forwards all arguments unchanged to
// quant_float16_to_uint16. Despite the name, the input element type follows
// QDQFloatType, exactly as in the function it forwards to.
template<typename T1, typename T2, typename OutT>
void quant_bfloat16_to_uint16(T1 in_data, OutT* out_data, int h_in, int w_in, int c_in, float inv_s, T2 z)
{
    quant_float16_to_uint16(in_data, out_data, h_in, w_in, c_in, inv_s, z);
}

// Returns `num_in` evenly spaced floats from `start_in` to `end_in`, both
// endpoints included.
//
//   num_in <= 0 -> empty vector
//   num_in == 1 -> { start }
//   otherwise   -> start, start + delta, ..., end, with
//                  delta = (end - start) / (num_in - 1)
// The last element is set to `end` exactly rather than computed, so it is not
// affected by accumulated rounding.
template <typename T>
std::vector<float> linspace(T start_in, T end_in, int num_in)
{
    std::vector<float> linspaced;
    if (num_in <= 0) {
        return linspaced;
    }

    const float start = static_cast<float>(start_in);
    const float end = static_cast<float>(end_in);
    linspaced.reserve(static_cast<std::vector<float>::size_type>(num_in));

    if (num_in == 1) {
        linspaced.push_back(start);
        return linspaced;
    }

    const float delta = (end - start) / (static_cast<float>(num_in) - 1);
    for (int i = 0; i < num_in - 1; ++i) {
        linspaced.push_back(start + delta * i);
    }
    linspaced.push_back(end);

    return linspaced;
}

// Loads an input feature map from a binary file and writes it to `ifm_ptr`
// with the channel dimension padded (or cropped) to `chs`.
//
// The file must hold exactly rows * cols * chs_orig elements of Ti, with the
// channel index varying fastest. For every (row, col) position `chs` elements
// are written to the destination: the first min(chs, chs_orig) are copied from
// the file and any remaining ones are set to zero. If chs < chs_orig, the
// surplus source channels of each position are skipped.
//
// `ifm_ptr` must have room for rows * cols * chs elements.
//
// Throws std::invalid_argument for a negative dimension and
// std::runtime_error when the scratch buffer cannot be allocated or the file
// does not deliver the expected number of bytes. The scratch buffer is
// released on every exit path.
template <typename Ti>
void load_ifm(const std::string& file_path, Ti* ifm_ptr, int chs, int chs_orig, int rows, int cols)
{
    if (chs < 0 || chs_orig < 0 || rows < 0 || cols < 0)
    {
        throw std::invalid_argument("[ERR] load_ifm called with a negative dimension");
    }

    // Widen before multiplying so the byte count cannot overflow int.
    const size_t size_unpadded = static_cast<size_t>(chs_orig) * static_cast<size_t>(rows) *
                                 static_cast<size_t>(cols) * sizeof(Ti);

    char* buf = static_cast<char*>(allocate(static_cast<int>(size_unpadded)));
    if (buf == nullptr)
    {
        throw std::runtime_error("[ERR] Unable to allocate IFM buffer");
    }

    // Releases the scratch buffer when the function returns or throws.
    struct ScratchGuard {
        void* ptr;
        ~ScratchGuard() { deallocate(ptr); }
    } guard{buf};

    const size_t bytes_read = read_bin_file(file_path, buf, size_unpadded);
    if (bytes_read != size_unpadded)
    {
        throw std::runtime_error(
            "IFM file read size mismatch: expected " +
            std::to_string(size_unpadded) +
            ", received " +
            std::to_string(bytes_read)
        );
    }

    const Ti* const src = reinterpret_cast<const Ti*>(buf);
    for (int i = 0; i < rows; i++) {
        for (int j = 0; j < cols; j++) {
            // Start of this position's channels in the unpadded source.
            const Ti* const src_px =
                src + (static_cast<size_t>(i) * static_cast<size_t>(cols) + static_cast<size_t>(j)) *
                          static_cast<size_t>(chs_orig);
            for (int k = 0; k < chs; k++) {
                if (k < chs_orig) {
                    *ifm_ptr = src_px[k];
                }
                else {
                    // Padding channel.
                    if constexpr (is_float16_type_v<Ti>) {
                        ifm_ptr->value = 0;
                    }
                    else {
                        *ifm_ptr = 0;
                    }
                }
                ifm_ptr++;
            }
        }
    }
}

#endif // COMMON_HPP