#ifndef ADD_HPP
#define ADD_HPP

#include <cfenv>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <filesystem>
#include <iostream>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>

#include "common.hpp"
#include "qdq_utils_aie4.hpp"

/**
 * @brief Shift-round-saturate an accumulator down to int8.
 *
 * Divides `acc` by 2^`shift`, rounds to the nearest integer with ties going to
 * the even neighbour (FE_TONEAREST + nearbyint), then saturates the result to
 * [INT8_MIN, INT8_MAX].
 *
 * @param acc    Signed accumulator value.
 * @param shift  Right-shift amount; must be in [0, 62] so that 1 << shift is
 *               representable in int64_t.
 * @return       The rounded and saturated int8 value.
 *
 * The caller's floating-point rounding mode is restored before returning.
 * Declared `inline` because it is defined in a header that may be included
 * from several translation units.
 */
inline int8_t srs_int8(int64_t acc, int shift)
{
    // nearbyint honours the current rounding mode: force round-half-to-even for
    // the division, then put the caller's mode back so no global state leaks.
    const int saved_mode = std::fegetround();
    std::fesetround(FE_TONEAREST);
    const double rounded =
        std::nearbyint(static_cast<double>(acc) / static_cast<double>(int64_t(1) << shift));
    std::fesetround(saved_mode);

    // Saturate while still in the floating-point domain: converting an
    // out-of-range double to an integer type is undefined behaviour.
    if (rounded > INT8_MAX) {
        return INT8_MAX;
    }
    if (rounded < INT8_MIN) {
        return INT8_MIN;
    }
    return static_cast<int8_t>(rounded);
}

/**
 * @brief Reference int8 element-wise add with per-input power-of-two scaling.
 *
 * For every element:
 *   ofm = srs_int8(a * 2^shift_in + b * 2^shift_in1, shift_res)
 * i.e. both operands are scaled up, summed in a wide accumulator, then rounded
 * (ties-to-even) and saturated back to int8 by srs_int8.
 *
 * Loop extents (C, Y, X) are taken from matA; matB and ofm are indexed with the
 * same coordinates and are assumed to be at least that large (not checked here).
 *
 * @param shift_in   Scale exponent for matA; must be in [0, 54] so the scaled
 *                   sum of two int8 values fits in int64_t.
 * @param shift_in1  Scale exponent for matB; same range as shift_in.
 * @param shift_res  Right shift applied by srs_int8 to the sum.
 */
inline void add(ActTensor<int8_t> matA, ActTensor<int8_t> matB, ActTensor<int8_t> ofm, int shift_in, int shift_in1, int shift_res) {
    // Scale by multiplication rather than `<<` on the int8 value: left-shifting a
    // negative operand is undefined behaviour before C++20. int64_t gives enough
    // headroom that neither the scaling nor the sum can overflow.
    const int64_t scale_a = int64_t(1) << shift_in;
    const int64_t scale_b = int64_t(1) << shift_in1;
    for (int i = 0; i < matA.C; ++i) {
        for (int y = 0; y < matA.Y; ++y) {
            for (int j = 0; j < matA.X; ++j) {
                const int64_t val_a = static_cast<int64_t>(matA.at(i, y, j)) * scale_a;
                const int64_t val_b = static_cast<int64_t>(matB.at(i, y, j)) * scale_b;
                ofm.at(i, y, j) = srs_int8(val_a + val_b, shift_res);
            }
        }
    }
}

/**
 * Element-wise sum of two 16-bit-float tensors (ofm = matA + matB).
 *
 * Each operand is widened with Float16Traits<T>::to_float, the sum is formed in
 * single precision, and the result is narrowed with Float16Traits<T>::from_float.
 * Loop extents (C, Y, X) come from matA; matB and ofm are indexed with the same
 * coordinates and are assumed to be at least that large (not checked).
 */
template<typename T, typename = std::enable_if_t<is_float16_type_v<T>>>
void add(ActTensor<T> matA, ActTensor<T> matB, ActTensor<T> ofm) {
    using Traits = Float16Traits<T>;
    for (int ch = 0; ch < matA.C; ++ch) {
        for (int row = 0; row < matA.Y; ++row) {
            for (int col = 0; col < matA.X; ++col) {
                const float lhs = Traits::to_float(matA.at(ch, row, col));
                const float rhs = Traits::to_float(matB.at(ch, row, col));
                const float sum = lhs + rhs;
                ofm.at(ch, row, col) = Traits::from_float(sum);
            }
        }
    }
}

/**
 * Element-wise product of two 16-bit-float tensors (ofm = matA * matB).
 *
 * Operands are converted to float with Float16Traits<T>::to_float, multiplied
 * in single precision, and converted back with Float16Traits<T>::from_float.
 * Loop extents (C, Y, X) come from matA; matB and ofm are indexed with the same
 * coordinates and are assumed to be at least that large (not checked).
 */
template<typename T, typename = std::enable_if_t<is_float16_type_v<T>>>
void mul(ActTensor<T> matA, ActTensor<T> matB, ActTensor<T> ofm) {
    using Cvt = Float16Traits<T>;
    const auto channels = matA.C;
    const auto height = matA.Y;
    const auto width = matA.X;
    for (int c = 0; c < channels; ++c) {
        for (int r = 0; r < height; ++r) {
            for (int k = 0; k < width; ++k) {
                const float a = Cvt::to_float(matA.at(c, r, k));
                const float b = Cvt::to_float(matB.at(c, r, k));
                const float product = a * b;
                ofm.at(c, r, k) = Cvt::from_float(product);
            }
        }
    }
}

/**
 * @brief Fill a tensor with pseudo-random integers, zero-padding extra channels.
 *
 * Elements in channels c < c_orig receive `(rand() % (max - min + 1)) + min`,
 * an integer in [min, max]. Elements in channels c >= c_orig are set to 0 and
 * do not consume a rand() draw. The value is stored through
 * Float16Traits<T>::from_float when T is a 16-bit float type, otherwise via T(value).
 *
 * Notes:
 *  - Uses the C library rand(): results depend on the global srand() seed.
 *  - rand() yields at most RAND_MAX, so if max - min + 1 exceeds RAND_MAX + 1
 *    not every value in [min, max] is reachable. The modulo also introduces
 *    a small bias.
 *  - The drawn value is held in an `int`, so min/max are expected to lie
 *    within the int range.
 *
 * @param tensor  Tensor to fill in place.
 * @param c_orig  Number of leading channels that receive random data.
 * @param min     Inclusive lower bound.
 * @param max     Inclusive upper bound; must be >= min.
 * @throws std::invalid_argument if max < min.
 */
template<typename T> 
void init_tensor_random(ActTensor<T>& tensor, int64_t c_orig, int64_t min, int64_t max)
{
    // max < min would make the modulus zero (division by zero) or negative.
    if (max < min) {
        throw std::invalid_argument("init_tensor_random: max must be >= min");
    }
    const int64_t span = max - min + 1;

    for (int c = 0; c < tensor.C; ++c) {
        // Channels beyond c_orig are padding and are always zero.
        const bool in_orig = (c < c_orig);
        for (int y = 0; y < tensor.Y; ++y) {
            for (int x = 0; x < tensor.X; ++x) {
                // rand() is only called for real channels, keeping the draw sequence
                // independent of how much padding the tensor carries.
                const int val = in_orig ? static_cast<int>((rand() % span) + min) : 0;

                if constexpr (is_float16_type_v<T>) {
                    tensor.at(c, y, x) = Float16Traits<T>::from_float(static_cast<float>(val));
                } else {
                    tensor.at(c, y, x) = T(val);
                }
            }
        }
    }
}


/**
 * @brief Load the file-backed operands of a binary op from `<const_path>/<node>/{A,B}.bin`.
 *
 * `node_name` is first passed through `waic_runtime_aie4::replace_symbols`
 * (this works on the local copy; the caller's string is untouched). Only the
 * operands whose `*_on_wgt` flag is non-zero are loaded. Each one is read with
 * `load_ifm<Ti>` into the corresponding destination buffer, A before B.
 *
 * @param const_path  Root directory containing one sub-directory per node.
 * @param node_name   Graph node name used to form the sub-directory.
 * @param ifm_a_ptr   Destination for operand A; must be non-null when a_on_wgt is set.
 * @param ifm_b_ptr   Destination for operand B; must be non-null when b_on_wgt is set.
 * @param chs, chs_orig, rows, cols  Forwarded unchanged to load_ifm.
 * @param a_on_wgt    Non-zero to load A.bin into ifm_a_ptr.
 * @param b_on_wgt    Non-zero to load B.bin into ifm_b_ptr.
 *
 * @throws std::invalid_argument if a requested operand has a null destination.
 * @throws std::runtime_error    if a requested .bin file does not exist.
 * @note An exception raised inside load_ifm is reported on stderr and the
 *       process exits with EXIT_FAILURE. This is existing behaviour, kept for
 *       callers that rely on it.
 */
template <typename Ti>
void read_binary_model_data(
    std::string const_path,
    std::string node_name,
    Ti *ifm_a_ptr,
    Ti *ifm_b_ptr,
    int chs,
    int chs_orig,
    int rows,
    int cols,
    int a_on_wgt,
    int b_on_wgt
) {
    waic_runtime_aie4::replace_symbols(node_name);
    const std::string node_dir = const_path + "/" + node_name + "/";

    // Shared load path for both operands. `tag` ("A" or "B") selects the file
    // name and the wording of the error messages.
    const auto load_operand = [&](const std::string& tag, Ti* dst) {
        const std::filesystem::path file{node_dir + tag + ".bin"};
        if (dst == nullptr) {
            throw std::invalid_argument("Destination buffer for " + tag + ".bin is null");
        }
        if (!std::filesystem::exists(file)) {
            throw std::runtime_error(tag + ".bin not found at " + file.string());
        }
        try {
            load_ifm<Ti>(file.string(), dst, chs, chs_orig, rows, cols);
        }
        catch (const std::exception& e) {
            std::cerr << "[ERR] Failed to read IFM " << tag << ":" << e.what() << std::endl;
            std::exit(EXIT_FAILURE);
        }
    };

    if (a_on_wgt) {
        load_operand("A", ifm_a_ptr);
    }
    if (b_on_wgt) {
        load_operand("B", ifm_b_ptr);
    }
}


template <typename Ti>
inline void binary_op_init_model_data(
    std::string const_path,
    std::string node_name,
    BinaryQDQParams* ptr_qdq_prm,
    bool dequant_enable,
    bool quant_enable,
    Ti *ifm_a_ptr,
    Ti *ifm_b_ptr,
    int chs,
    int chs_orig,
    int rows,
    int cols,
    int a_on_wgt,
    int b_on_wgt
) {
    // get scale, zp
    std::vector<nlohmann::json> scale_zp = waic_runtime_aie4::get_scale_zp_vector(const_path, node_name);

    if(dequant_enable)
    {
        float in_a_s = scale_zp[0];
        float in_a_zp = scale_zp[1];
        float in_b_s = scale_zp[2];
        float in_b_zp = scale_zp[3];

        ptr_qdq_prm->dq_a_zp = in_a_zp;
        ptr_qdq_prm->dq_a_sc = in_a_s;
        ptr_qdq_prm->dq_b_zp = in_b_zp;
        ptr_qdq_prm->dq_b_sc = in_b_s;
        ptr_qdq_prm->dq_enable = 1;
    }

    if(quant_enable)
    {
        float o_s;
        float o_zp;
        if (dequant_enable)
        {
            o_s = scale_zp[4];
            o_zp = scale_zp[5];
        } else
        {
            o_s = scale_zp[0];
            o_zp = scale_zp[1];
        }
        o_s = 1.0f / o_s;
        ptr_qdq_prm->q_zp = o_zp;
        ptr_qdq_prm->q_sc = o_s;
        ptr_qdq_prm->q_enable = 1;
    }

    read_binary_model_data(const_path, node_name, ifm_a_ptr, ifm_b_ptr, chs, chs_orig, rows, cols, a_on_wgt, b_on_wgt);

}

#endif // ADD_HPP