#ifndef ADD_BDCAST_HPP
#define ADD_BDCAST_HPP

#include <cfenv>
#include <cmath>
#include <cstdint>
#include <filesystem>
#include <iostream>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>

#include "common.hpp"
#include "qdq_utils_aie4.hpp"

// Element-wise binary operation codes, passed as `op_type` (e.g. to broadcast_16).
constexpr int ADD = 0;
constexpr int SUB = 1;
constexpr int MUL = 2;
constexpr int DIV = 3;

// Shape, quantization and model-data location parameters for a broadcast
// binary op; consumed by read_broadcast_model_data and
// broadcast_op_init_model_data. Member order matters for aggregate
// initialization, so do not reorder fields.
struct BroadcastParams16 {
    // Operand A dimensions (*_chs_orig presumably the unpadded channel count - confirm).
    int ifm_a_n;
    int ifm_a_chs;
    int ifm_a_chs_orig;
    int ifm_a_cols;
    int ifm_a_rows;

    // Operand B dimensions.
    int ifm_b_n;
    int ifm_b_chs;
    int ifm_b_chs_orig;
    int ifm_b_cols;
    int ifm_b_rows;

    // Output dimensions.
    int ofm_n;
    int ofm_chs;
    int ofm_chs_orig;
    int ofm_cols;
    int ofm_rows;

    int ifm_bytes;
    int ofm_bytes;

    int wgt_size;
    int dq_buf_offset;
    int q_buf_offset;

    // Dequantization parameters for A and B, quantization parameters for the output.
    float dq_a_zp;
    float dq_a_sc;
    float dq_b_zp;
    float dq_b_sc;
    float q_zp;
    float q_sc;
    bool dq_enable;
    bool q_enable;

    int read_ifm;
    int op_type;
    int read_model_data;

    // Model data is read from <const_path>/<node_name>/ (node_name is
    // passed through replace_symbols first).
    std::string const_path;
    std::string node_name;

    int sign_A;
    int sign_O;
    int a_on_wgt;
    int b_on_wgt;
    int has_scalar_broadcast;
    int debug_mode; // 1 enables CPU reference path
};

/// Shift-round-saturate: returns acc / 2^shift rounded to nearest with ties to
/// even, saturated to [INT8_MIN, INT8_MAX].
///
/// The rounding matches the previous nearbyint-under-FE_TONEAREST behaviour,
/// but it is done in exact integer arithmetic. As a result:
///  - the result is exact for every int64 input, with no double-conversion loss
///    and no out-of-range double->int64 conversion;
///  - the thread's floating-point rounding mode is no longer modified as a
///    side effect.
///
/// @param acc   accumulator to scale down
/// @param shift right-shift amount, expected in [0, 62]; shift <= 0 is treated as 0
inline int8_t srs_int8(int64_t acc, int shift)
{
    int64_t q = acc;
    if (shift > 0) {
        const int64_t divisor = int64_t(1) << shift;
        const int64_t half = divisor / 2;
        q = acc / divisor;          // truncates toward zero
        int64_t r = acc % divisor;  // carries the sign of acc
        if (r < 0) {                // switch to floor division so 0 <= r < divisor
            r += divisor;
            --q;
        }
        if (r > half || (r == half && q % 2 != 0)) {
            ++q;                    // round half to even
        }
    }
    return (q > INT8_MAX) ? INT8_MAX :
           (q < INT8_MIN) ? INT8_MIN :
                            static_cast<int8_t>(q);
}

/// Element-wise int8 addition with broadcasting over C, Y and X.
///
/// Any operand dimension of extent 1 is broadcast against ofm's extent. The
/// operands are scaled by 2^shift_in and 2^shift_in1 and summed in 64-bit.
/// The sum is then passed to srs_int8 with shift_res, which right-shifts with
/// round-half-to-even and saturates to int8.
///
/// `inline` because this non-template function is defined in a header.
inline void add_broadcast_8(ActTensor<int8_t> matA, ActTensor<int8_t> matB, ActTensor<int8_t> ofm, int shift_in, int shift_in1, int shift_res) {
    // Multiply by 2^shift rather than writing `value << shift`. Left-shifting a
    // negative int8 (promoted to int) is undefined before C++20. The int64 sum
    // also cannot overflow the way the previous int32 sum could for large shifts.
    const int64_t scale_a = int64_t(1) << shift_in;
    const int64_t scale_b = int64_t(1) << shift_in1;

    for (int y = 0; y < ofm.Y; ++y) {
        const int yA = (matA.Y == 1) ? 0 : y;
        const int yB = (matB.Y == 1) ? 0 : y;
        for (int x = 0; x < ofm.X; ++x) {
            const int xA = (matA.X == 1) ? 0 : x;
            const int xB = (matB.X == 1) ? 0 : x;
            for (int c = 0; c < ofm.C; ++c) {
                const int cA = (matA.C == 1) ? 0 : c;
                const int cB = (matB.C == 1) ? 0 : c;

                const int64_t val_a = int64_t(matA.at(cA, yA, xA)) * scale_a;
                const int64_t val_b = int64_t(matB.at(cB, yB, xB)) * scale_b;
                ofm.at(c, y, x) = srs_int8(val_a + val_b, shift_res);
            }
        }
    }
}

/// Element-wise ADD/SUB/MUL/DIV on 16-bit float tensors with broadcasting.
///
/// The batch (n) dimension is folded in front of Y and indexed manually,
/// because ActTensor only exposes Y, X and C. A dimension of extent 1 in an
/// input is broadcast. For operand B's channel axis, the caller can also force
/// broadcast via has_scalar_broadcast (presumably for a scalar B whose channel
/// dim is padded - confirm). Inputs are indexed with n % ifm_*_n.
///
/// @throws std::invalid_argument if op_type is not ADD, SUB, MUL or DIV. The
///         check now happens before any element is processed.
template<typename T>
void broadcast_16(ActTensor<T> matA, ActTensor<T> matB, ActTensor<T> ofm, 
                  int ifm_a_n, int ifm_a_y, int ifm_b_n, int ifm_b_y, int ofm_n, int ofm_y,
                  int op_type, int has_scalar_broadcast) {
    static_assert(is_float16_type_v<T>, "broadcast_16 only supports bfloat16_t or float16_t");
    using Traits = Float16Traits<T>;

    // Validate once up front instead of re-checking inside the innermost loop.
    if (op_type != ADD && op_type != SUB && op_type != MUL && op_type != DIV) {
        throw std::invalid_argument("Unsupported operation type in broadcast_16");
    }

    // Row (Y) and batch (N) strides, computed in 64-bit to avoid int overflow on large tensors.
    const int64_t a_row   = int64_t(matA.X) * matA.C;
    const int64_t a_batch = int64_t(ifm_a_y) * a_row;
    const int64_t b_row   = int64_t(matB.X) * matB.C;
    const int64_t b_batch = int64_t(ifm_b_y) * b_row;
    const int64_t o_row   = int64_t(ofm.X) * ofm.C;
    const int64_t o_batch = int64_t(ofm_y) * o_row;

    // B's channel axis must also broadcast when it genuinely has extent 1.
    // Before this check, has_scalar_broadcast == 0 with matB.C == 1 indexed
    // past B's channel extent.
    const bool a_bcast_c = (matA.C == 1);
    const bool b_bcast_c = has_scalar_broadcast || (matB.C == 1);

    for (int n = 0; n < ofm_n; ++n) {
        const int64_t a_n_base = (n % ifm_a_n) * a_batch;
        const int64_t b_n_base = (n % ifm_b_n) * b_batch;
        const int64_t o_n_base = n * o_batch;
        for (int y = 0; y < ofm_y; ++y) {
            const int yA = (ifm_a_y == 1) ? 0 : y;
            const int yB = (ifm_b_y == 1) ? 0 : y;
            for (int x = 0; x < ofm.X; ++x) {
                const int xA = (matA.X == 1) ? 0 : x;
                const int xB = (matB.X == 1) ? 0 : x;
                for (int c = 0; c < ofm.C; ++c) {
                    const int cA = a_bcast_c ? 0 : c;
                    const int cB = b_bcast_c ? 0 : c;

                    const int64_t idx_a = a_n_base + yA * a_row + int64_t(xA) * matA.C + cA;
                    const int64_t idx_b = b_n_base + yB * b_row + int64_t(xB) * matB.C + cB;
                    const int64_t idx_o = o_n_base + y * o_row + int64_t(x) * ofm.C + c;

                    const auto val_a = Traits::to_float(matA.data[idx_a]);
                    const auto val_b = Traits::to_float(matB.data[idx_b]);

                    float out;
                    switch (op_type) {
                        case ADD: out = val_a + val_b; break;
                        case SUB: out = val_a - val_b; break;
                        case MUL: out = val_a * val_b; break;
                        default:  out = val_a / val_b; break; // DIV; op_type validated above
                    }

                    ofm.data[idx_o] = Traits::from_float(out);
                }
            }
        }
    }
}

/// Fills `tensor` with pseudo-random integers in [min, max] drawn from rand().
/// Channels c >= c_pad are set to zero. For 16-bit float T, each integer is
/// converted through float with Float16Traits<T>::from_float.
///
/// The c -> y -> x loop order is kept on purpose. It determines which rand()
/// draw lands in which element, so changing it would change the generated data
/// for a given srand() seed.
///
/// Note: rand() only yields [0, RAND_MAX], so ranges wider than RAND_MAX + 1
/// are not fully covered.
///
/// @throws std::invalid_argument if max < min. Previously the modulus
///         max - min + 1 became zero (undefined behaviour) or negative (values
///         outside [min, max]).
template<typename T> 
void init_tensor_random(ActTensor<T>& tensor, int c_pad, int64_t min, int64_t max)
{
    if (max < min) {
        throw std::invalid_argument("init_tensor_random: max must be >= min");
    }
    const int64_t range = max - min + 1;

    for (int c = 0; c < tensor.C; ++c) {
        const bool is_pad = (c >= c_pad);
        for (int y = 0; y < tensor.Y; ++y) {
            for (int x = 0; x < tensor.X; ++x) {
                if constexpr (is_float16_type_v<T>) {
                    const float val = is_pad ? 0.0f : static_cast<float>((rand() % range) + min);
                    tensor.at(c, y, x) = Float16Traits<T>::from_float(val);
                } else {
                    tensor.at(c, y, x) = is_pad ? T(0) : T((rand() % range) + min);
                }
            }
        }
    }
}


/// Loads the operands flagged as stored with the weights (a_on_wgt / b_on_wgt)
/// from `<const_path>/<node_name>/A.bin` and `.../B.bin`.
/// node_name is first passed through waic_runtime_aie4::replace_symbols.
/// The batch is folded into rows: load_ifm receives n * rows as the row count.
///
/// @throws std::runtime_error if a requested operand file does not exist.
template <typename Ti>
void read_broadcast_model_data(
    const BroadcastParams16& params,
    Ti *ifm_a_ptr,
    Ti *ifm_b_ptr,
    int a_on_wgt,
    int b_on_wgt
) {
    std::string node_name = params.node_name;
    waic_runtime_aie4::replace_symbols(node_name);

    const std::string node_dir = params.const_path + "/" + node_name + "/";
    const std::filesystem::path ifm_a_file{node_dir + "A.bin"};
    const std::filesystem::path ifm_b_file{node_dir + "B.bin"};

    // Fails loudly when the file is missing; otherwise hands it to load_ifm.
    const auto load_operand = [](const std::filesystem::path& file, const char* missing_msg,
                                 Ti* dst, int chs, int chs_orig, int rows, int cols) {
        if (!std::filesystem::exists(file)) {
            throw std::runtime_error(missing_msg + file.string());
        }
        load_ifm(file.string(), dst, chs, chs_orig, rows, cols);
    };

    if (a_on_wgt) {
        load_operand(ifm_a_file, "IFM A file not found: ", ifm_a_ptr,
                     params.ifm_a_chs, params.ifm_a_chs_orig,
                     params.ifm_a_n * params.ifm_a_rows, params.ifm_a_cols);
    }
    if (b_on_wgt) {
        load_operand(ifm_b_file, "IFM B file not found: ", ifm_b_ptr,
                     params.ifm_b_chs, params.ifm_b_chs_orig,
                     params.ifm_b_n * params.ifm_b_rows, params.ifm_b_cols);
    }
}

template <typename Ti>
inline void broadcast_op_init_model_data(
    const BroadcastParams16& params,
    BinaryQDQParams* ptr_qdq_prm,
    Ti *ifm_a_ptr,
    Ti *ifm_b_ptr,
    int a_on_wgt,
    int b_on_wgt
) {
    std::string const_path = params.const_path;
    std::string node_name = params.node_name;
    bool dequant_enable = params.dq_enable;
    bool quant_enable = params.q_enable;

    // get scale, zp
    std::vector<nlohmann::json> scale_zp = waic_runtime_aie4::get_scale_zp_vector(const_path, node_name);

    if(dequant_enable)
    {
        float in_a_s = scale_zp[0];
        float in_a_zp = scale_zp[1];
        float in_b_s = scale_zp[2];
        float in_b_zp = scale_zp[3];

        ptr_qdq_prm->dq_a_zp = in_a_zp;
        ptr_qdq_prm->dq_a_sc = in_a_s;
        ptr_qdq_prm->dq_b_zp = in_b_zp;
        ptr_qdq_prm->dq_b_sc = in_b_s;
        ptr_qdq_prm->dq_enable = 1;
    }

    if(quant_enable)
    {
        float o_s;
        float o_zp;
        if (dequant_enable)
        {
            o_s = scale_zp[4];
            o_zp = scale_zp[5];
        } else
        {
            o_s = scale_zp[0];
            o_zp = scale_zp[1];
        }
        
        ptr_qdq_prm->q_zp = o_zp;
        ptr_qdq_prm->q_sc = 1.0f/o_s;
        ptr_qdq_prm->q_enable = 1;
    }

    read_broadcast_model_data(params, ifm_a_ptr, ifm_b_ptr, a_on_wgt, b_on_wgt);
}

/// Fills the DQ and Q regions of a QDQ buffer from the BinaryQDQParams header
/// stored at the start of that same buffer.
///
/// @param qdq_buf      buffer whose start is interpreted as BinaryQDQParams
/// @param dqbuf_offset byte offset of the DQ region (assumed non-negative and a
///                     multiple of sizeof(float) - TODO confirm)
/// @param qbuf_offset  byte offset of the Q region (same assumption)
///
/// Each value is replicated 32 times. Float indices written:
///   DQ: [0,32) A zp, [32,64) B zp, [64,96) A scale, [96,128) B scale
///   Q : [0,32) output zp, [64,96) output scale ([32,64) is left untouched)
///
/// `inline` because this non-template function is defined in a header.
inline void initialize_qdq_buffer(float* qdq_buf, int dqbuf_offset, int qbuf_offset ) {
    constexpr int kLanes = 32; // replication count per parameter
    const BinaryQDQParams* qdq_params = reinterpret_cast<const BinaryQDQParams*>(qdq_buf);
    float* dq_buf = qdq_buf + (dqbuf_offset / sizeof(float));
    float* q_buf  = qdq_buf + (qbuf_offset  / sizeof(float));
    for (int i = 0; i < kLanes; ++i) {
        // The dq kernel expects zero points in bytes 0-128 and scales from byte
        // 256 onward, relative to a given offset. The A and B zero points and
        // scales are packed together.
        dq_buf[i             ] = qdq_params->dq_a_zp;
        dq_buf[i +     kLanes] = qdq_params->dq_b_zp;
        dq_buf[i + 2 * kLanes] = qdq_params->dq_a_sc;
        dq_buf[i + 3 * kLanes] = qdq_params->dq_b_sc;
        q_buf [i             ] = qdq_params->q_zp;
        q_buf [i + 2 * kLanes] = qdq_params->q_sc;
    }
}

// Prints `label` followed by the 32 floats starting at buf[offset], as
// " v0, v1, ..., v31", then a newline, to std::cout.
inline void print_qdq_section(const char* label, const float* buf, int offset) {
    const float* section = buf + offset;
    std::cout << label << ' ' << section[0];
    for (int k = 1; k < 32; ++k) {
        std::cout << ", " << section[k];
    }
    std::cout << '\n';
}

/// Dumps a QDQ buffer to std::cout: first the BinaryQDQParams header stored at
/// its start, then each 32-float DQ/Q section at the same indices that
/// initialize_qdq_buffer writes.
///
/// `inline` because this non-template function is defined in a header. The
/// buffer is only read, so it is accessed through const views.
inline void print_qdq_buffer(float* qdq_buf, int dqbuf_offset, int qbuf_offset) {
    const BinaryQDQParams* qdq_params = reinterpret_cast<const BinaryQDQParams*>(qdq_buf);
    const float* dq_buf = qdq_buf + (dqbuf_offset / sizeof(float));
    const float* q_buf  = qdq_buf + (qbuf_offset  / sizeof(float));

    std::cout << "QDQ params: dq_a_zp=" << qdq_params->dq_a_zp
              << ", dq_b_zp=" << qdq_params->dq_b_zp
              << ", dq_a_sc=" << qdq_params->dq_a_sc
              << ", dq_b_sc=" << qdq_params->dq_b_sc
              << ", q_zp=" << qdq_params->q_zp
              << ", q_sc=" << qdq_params->q_sc << '\n';

    print_qdq_section("DQ A zero points:", dq_buf, 0);
    print_qdq_section("DQ B zero points:", dq_buf, 32);
    print_qdq_section("DQ A scales:", dq_buf, 64);
    print_qdq_section("DQ B scales:", dq_buf, 96);
    print_qdq_section("Q zero points:", q_buf, 0);
    print_qdq_section("Q scales:", q_buf, 64);
}

#endif // ADD_BDCAST_HPP
