#ifndef GEMM_HPP
#define GEMM_HPP

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <random>
#include <stdexcept>
#include <string>
#include <tuple>
#include <vector>

#include <nlohmann/json.hpp>

#include "common.hpp"
#include "qdq_utils_aie4.hpp"

using namespace waic_runtime_aie4;
// Returns a uniformly distributed float, nominally in [0, 1).
// Each thread owns one Mersenne Twister engine, seeded once from
// std::random_device on first use in that thread, so calls never share
// engine state across threads.
inline float random_float() {
    thread_local std::mt19937 engine(std::random_device{}());
    thread_local std::uniform_real_distribution<float> unit_interval(0.0f, 1.0f);
    const float sample = unit_interval(engine);
    return sample;
}

// Runtime parameter blob for the (presumably int16 x int8) GEMM QDQ variant.
// The struct is padded to exactly 128 bytes; the static_assert below turns the
// size requirement into a compile-time check instead of a comment.
struct GemmQdqint16x8_RT_Params
{
    int16_t shift_out;
    uint8_t ifm_sign;
    uint8_t wgt_sign;
    uint8_t ofm_sign;
    // NOTE: Used to align the struct size to 128 bytes (2 + 3 + 123).
    uint8_t reserved[123];
};
static_assert(sizeof(GemmQdqint16x8_RT_Params) == 128,
              "GemmQdqint16x8_RT_Params must be exactly 128 bytes");

//TODO: Add support for vectorized C1 and C2 coeffs

// Runtime parameter blob for the int16 x int16 GEMM QDQ kernel, padded to
// exactly 128 bytes (enforced by the static_assert below).
//
// The CPU reference in this header evaluates
//     Y = gemm_out * c2 + ifm_sum * c1 + wgt_sum * c3 + c0
// and quantizes Y using shift_res and sign_O.
struct GemmQdqint16x16_RT_Params
{
    int16_t shift_res = 0;     // output shift passed to quantize_float_to_int16
    uint8_t sign_A = 0;        // 1 => first operand (act1) is signed
    uint8_t sign_W = 0;        // 1 => second operand (act2) is signed
    uint8_t sign_O = 0;        // 1 => output is signed
    // Coefficient mode: 0: scalar, 1: vectorized C1, 2: vectorized C2.
    // NOTE(review): only scalar coefficients are produced in this header.
    uint8_t vector_coeff = 0;
    uint8_t pad[2] = {0};      // explicit padding up to the first float (offset 8)
    float c0 = 0.0f;           // constant term
    float c1 = 0.0f;           // coefficient of the activation row sum
    float c2 = 0.0f;           // coefficient of the raw GEMM accumulator
    float c3 = 0.0f;           // coefficient of the weight column sum
    // NOTE: Used to align the struct size to 128 bytes (24 + 104).
    uint8_t reserved[104] = {0};
};
static_assert(sizeof(GemmQdqint16x16_RT_Params) == 128,
              "GemmQdqint16x16_RT_Params must be exactly 128 bytes");

// Reference (CPU) batched matrix multiply: ofm = act1 x act2 per batch.
//
// Tensors are indexed as at(c, y, x) with y the batch index:
//   act1: C = K (reduction), X = M (rows)
//   act2: transpose_wgts != 0 -> C = K, X = N   (read as act2.at(k, b, n))
//         transpose_wgts == 0 -> C = N, X = K   (read as act2.at(n, b, k))
//   ofm : C = N, X = M
// Without transposition the reduction runs over min(act1.C, act2.X), so
// padded operands may disagree on K.
//
// Each operand is converted to Tacc before multiplying. Widening only to
// `int` (as before) overflows for uint16 x uint16 products (up to ~4.29e9),
// which is undefined behavior; a 64-bit or floating Tacc holds them exactly.
template<typename Ta, typename Tw, typename Tacc>
void cpu_matmul(
    ActTensor<Ta> act1,
    ActTensor<Tw> act2,
    ActTensor<Tacc> ofm,
    int const transpose_wgts
)
{
    assert(ofm.X == act1.X);
    if (transpose_wgts) {
        assert(act1.C == act2.C);
        assert(ofm.C == act2.X);
    } else {
        // act1.C and act2.X may differ (padding); see the reduction bound below.
        assert(ofm.C == act2.C);
    }

    // Reduction length, hoisted out of the inner loop.
    int const K = transpose_wgts ? act1.C : std::min(act1.C, act2.X);

    for (int b = 0; b < ofm.Y; ++b) {
        for (int m = 0; m < ofm.X; ++m) {
            for (int n = 0; n < ofm.C; ++n) {
                Tacc acc = 0;
                for (int k = 0; k < K; ++k) {
                    Tacc const a = static_cast<Tacc>(act1.at(k, b, m));
                    Tacc const w = static_cast<Tacc>(
                        transpose_wgts ? act2.at(k, b, n) : act2.at(n, b, k));
                    acc += a * w;
                }
                ofm.at(n, b, m) = acc;
            }
        }
    }
}

// CPU reference for the QDQ epilogue of an activation x activation GEMM.
//
// QDQ formula, evaluated per output element in float and then quantized with
// quantize_float_to_int16<To>(res, params.shift_res, params.sign_O == 1):
//     Y = gemm_out * params.c2 + ifm_sum * params.c1 + wgt_sum * params.c3 + params.c0
// where, for batch b, row m and output channel n:
//     ifm_sum[b][m] = sum_k act1.at(k, b, m)
//     wgt_sum[b][n] = sum_k act2.at(k, b, n)   (transpose_wgts != 0)
//                   = sum_k act2.at(n, b, k)   (transpose_wgts == 0)
// k runs over the same reduction length cpu_matmul uses, so the correction
// sums line up with the products accumulated in gemm_out.
//
// Parameters:
//   act1           first operand, C = K, X = M
//   act2           second operand; C = K, X = N when transposed, else C = N, X = K
//   gemm_out       raw accumulator tensor (C = N, X = M)
//   out            destination for the quantized result (C = N, X = M)
//   params         QDQ coefficients and output shift/sign
//   transpose_wgts selects the act2 layout described above
template<typename Ta, typename Tw, typename To, typename Tacc>
inline void cpu_act_x_act_term_qdq(
    ActTensor<Ta> act1,
    ActTensor<Tw> act2,
    ActTensor<Tacc> gemm_out,
    ActTensor<To> out,
    GemmQdqint16x16_RT_Params params,
    int const transpose_wgts
)
{
    // The reduction dimension lives in act2.C when weights are transposed and
    // in act2.X otherwise. Using act2.X unconditionally (as before) summed only
    // min(K, N) activations in the transposed case.
    int const K = transpose_wgts ? std::min(act1.C, act2.C)
                                 : std::min(act1.C, act2.X);

    // Row sums of act1: one entry per (batch, row). The vector owns the
    // storage, so it is released on every exit path (no manual free()).
    std::vector<Tacc> ifm_sum_storage(
        static_cast<std::size_t>(act1.X) * static_cast<std::size_t>(act1.Y));
    ActTensor<Tacc> ifm_sum(1, act1.Y, act1.X, ifm_sum_storage.data());
    for (int b = 0; b < act1.Y; ++b) {
        for (int r = 0; r < act1.X; ++r) {
            Tacc sum = 0;
            for (int k = 0; k < K; ++k) {
                sum += act1.at(k, b, r);
            }
            ifm_sum.at(0, b, r) = sum;
        }
    }

    // Column sums of act2: one entry per output channel (N == out.C) per batch.
    std::vector<Tacc> wgt_sum_storage(
        static_cast<std::size_t>(out.C) * static_cast<std::size_t>(act2.Y));
    ActTensor<Tacc> wgt_sum(out.C, act2.Y, 1, wgt_sum_storage.data());
    for (int b = 0; b < act2.Y; ++b) {
        for (int oc = 0; oc < out.C; ++oc) {
            Tacc sum = 0;
            for (int k = 0; k < K; ++k) {
                sum += transpose_wgts ? act2.at(k, b, oc) : act2.at(oc, b, k);
            }
            wgt_sum.at(oc, b, 0) = sum;
        }
    }

    for (int b = 0; b < out.Y; ++b) {
        for (int r = 0; r < out.X; ++r) {
            Tacc const ifm_sum_val = ifm_sum.at(0, b, r);
            for (int c = 0; c < out.C; ++c) {
                Tacc const gemm_val = gemm_out.at(c, b, r);
                Tacc const wgt_sum_val = wgt_sum.at(c, b, 0);
                float const res = (gemm_val * params.c2) +
                                  (ifm_sum_val * params.c1) +
                                  (wgt_sum_val * params.c3) +
                                  params.c0;
                out.at(c, b, r) = quantize_float_to_int16<To>(res, params.shift_res, params.sign_O == 1);
            }
        }
    }
}


// Fills act1/act2 with small random integers for an a16 x a16 GEMM test and
// draws random QDQ coefficients.
//
// Value ranges (upper bound exclusive) depend on the sign flags:
//   act1: [-16, 16) if qdq_params.sign_A == 1, else [0, 64)
//   act2: [-16, 16) if qdq_params.sign_W == 1, else [0, 64)
// Elements outside the real (unpadded) problem are set to zero:
//   act1 (C = K):            channels c >= Kgemm_orig
//   act2, transpose_wgts:    C = K, X = N -> c >= Kgemm_orig or r >= Ngemm_orig
//   act2, otherwise:         C = N, X = K -> c >= Ngemm_orig or r >= Kgemm_orig
// Integers come from rand(), so the caller's srand() controls reproducibility.
// Coefficients come from random_float().
//
// Mgemm_orig is currently unused; act1 rows are not padded here.
template<typename Ta, typename Tw>
inline void init_random_gemm_a16a16(
    ActTensor<Ta> act1,
    ActTensor<Tw> act2,
    GemmQdqint16x16_RT_Params qdq_params,
    int const transpose_wgts,
    int const Mgemm_orig,
    int const Kgemm_orig,
    int const Ngemm_orig
)
{
    static_cast<void>(Mgemm_orig); // kept for interface compatibility

    //TODO: given the input range find out ideal values
    float const c0_min = (qdq_params.sign_O == 1) ? -16.828F : 0.0F;
    float const c0_max = +16.828F;
    float const c1_min = (qdq_params.sign_O == 1) ? -2.828F : 0.0F;
    float const c1_max = +2.828F;
    float const c2_min = (qdq_params.sign_O == 1) ? -1.828F : 0.0F;
    float const c2_max = +1.828F;

    int64_t const ifm_min = (qdq_params.sign_A == 1) ? -16 : 0;
    int64_t const ifm_max = (qdq_params.sign_A == 1) ? +16 : 64;
    int64_t const wgt_min = (qdq_params.sign_W == 1) ? -16 : 0;
    int64_t const wgt_max = (qdq_params.sign_W == 1) ? +16 : 64;

    // act1: zero-pad the reduction (channel) dimension past the real K.
    for (int b = 0; b < act1.Y; ++b) {
        for (int r = 0; r < act1.X; ++r) {
            for (int c = 0; c < act1.C; ++c) {
                act1.at(c, b, r) = (c < Kgemm_orig)
                    ? Ta((rand() % (ifm_max - ifm_min)) + ifm_min)
                    : Ta(0);
            }
        }
    }

    // act2: draw from the weight range. The previous code used the activation
    // range and ignored sign_W. It also padded by the channel index against N
    // even when the channel dimension holds K (transposed layout).
    for (int b = 0; b < act2.Y; ++b) {
        for (int r = 0; r < act2.X; ++r) {
            for (int c = 0; c < act2.C; ++c) {
                int const k_idx = transpose_wgts ? c : r;
                int const n_idx = transpose_wgts ? r : c;
                bool const is_real = (k_idx < Kgemm_orig) && (n_idx < Ngemm_orig);
                act2.at(c, b, r) = is_real
                    ? Tw((rand() % (wgt_max - wgt_min)) + wgt_min)
                    : Tw(0);
            }
        }
    }

    auto const uniform = [](float lo, float hi) {
        return lo + random_float() * (hi - lo);
    };

    // NOTE(review): qdq_params is passed by value, so these coefficients only
    // update this function's local copy and never reach the caller. The
    // signature is kept as-is for compatibility; pass the params by reference
    // if the caller is meant to receive them.
    qdq_params.c0 = uniform(c0_min, c0_max);
    // Vectorized coefficients are not implemented yet: both vector_coeff modes
    // previously drew the same single scalar value, so no branching is needed.
    qdq_params.c1 = uniform(c1_min, c1_max);
    qdq_params.c2 = uniform(c2_min, c2_max);
}

template<typename Ta, typename Tw, typename Tc0, typename Tc1, typename Tc2>
inline void init_model_data(
    std::string const_path,
    std::string node_name,
    GemmQdqint16x16_RT_Params& qdq_params,
    int K,
    int const vec_coeff,
    int const debug_mode
) {
    // Need to check whether K is padded, if no calculate padded shape
    int ch_size = K;
    // get scale, zp
    std::vector<json> scale_zp = get_scale_zp_vector(const_path, node_name);
    int scale_zp_size = 6;

    if (scale_zp.size() != scale_zp_size) {
        std::cout << "scale_zp vector has wrong size for node "
            << node_name << std::endl;
    }
    std::tuple<float, float, float, float> qdq_values;
    float in_s = scale_zp[0];
    uint16_t in_zp = scale_zp[1];
    float w_s = scale_zp[2];
    uint16_t w_zp = scale_zp[3];
    float o_s = scale_zp[4];
    uint16_t o_zp = scale_zp[5];
    qdq_values = qdq_act_matmul_uint16_uint16_cstm(in_s, in_zp, ch_size, w_s, w_zp, o_s, o_zp);

    // Get qdq values
    float C0 = std::get<0>(qdq_values);
    float C1 = std::get<1>(qdq_values);
    float C2 = std::get<2>(qdq_values);
    float C3 = std::get<3>(qdq_values);
    if (debug_mode) {
        printf("C0 value = %.8f\n", C0);
        printf("C1 value = %.8f\n", C1);
        printf("C2 value = %.8f\n", C2);
        printf("C3 value = %.8f\n", C3);
    }

    // Set qdq values
    qdq_params.c0 = C0;
    qdq_params.c1 = C1;
    qdq_params.c2 = C2;
    qdq_params.c3 = C3;
}

#endif //GEMM_HPP