#ifndef GEMM_HPP
#define GEMM_HPP

#include <cstddef>
#include <random>
#include <stdexcept>
#include <vector>

#include <nlohmann/json.hpp>

#include "common.hpp"
#include "qdq_utils_aie4.hpp"

// NOTE(review): a using-directive at header scope makes every name in
// waic_runtime_aie4 visible in each file that includes this header. The
// definitions below rely on it for their unqualified names.
using namespace waic_runtime_aie4;
/// Runtime QDQ parameter record that GemmWgtTensor stores at the end of every
/// weight sub-volume.
///
/// The record is padded with `reserved` so that it occupies exactly 128 bytes.
struct GemmQdqint16x8_RT_Params
{
    /// Shift handed to quantize_float_to_int16 when producing the output.
    int16_t shift_out = 0;
    /// Sign flag for the input feature map (not read anywhere in this header).
    uint8_t ifm_sign = 0;
    /// Sign flag for the weights (not read anywhere in this header).
    uint8_t wgt_sign = 0;
    /// Output sign flag; cpu_3term_qdq treats the value 1 as "signed output".
    uint8_t ofm_sign = 0;
    // NOTE: Used to align the struct size to 128 bytes.
    uint8_t reserved[123] = {0};
};

// The sub-volume layout computed by GemmWgtTensor depends on this exact size,
// so fail the build if a field is ever added without shrinking `reserved`.
static_assert(sizeof(GemmQdqint16x8_RT_Params) == 128,
              "GemmQdqint16x8_RT_Params must occupy exactly 128 bytes");


template<typename Tw, typename Tc0, typename Tc1, typename Tc2>
struct GemmWgtTensor
{
    static const int subv_align_bytes = 128;
    
    int const K;
    int const N;
    int const Ksubv;
    int const Nsubv;
    int const vect_coeff;
    int const is_int4;
    int const subv_wgt_size;
    int const subv_c0_size;
    int const subv_c1_size;
    int const subv_c2_size;
    int const subv_qdq_size;
    int const subv_size;
    int const data_size;
    Tw* const data;

    GemmWgtTensor(int K, int N, int Ksubv, int Nsubv, int vect_coeff, int is_int4, void* data)
        : K(K)
        , N(N)
        , Ksubv(Ksubv)
        , Nsubv(Nsubv)
        , vect_coeff(vect_coeff)
        , is_int4(is_int4)
        , subv_wgt_size(compute_subv_wgt_size(Ksubv, Nsubv,is_int4))
        , subv_c0_size(compute_subv_c0_size(Nsubv))
        , subv_c1_size(compute_subv_c1_size(Nsubv))
        , subv_c2_size(compute_subv_c2_size(Nsubv))
        , subv_qdq_size(compute_subv_qdq_size())
        , subv_size(compute_subv_size(Ksubv, Nsubv,is_int4))
        , data_size(size(K, N, Ksubv, Nsubv, is_int4))
        , data(static_cast<Tw*>(data))
    {
        assert((Ksubv % 64 == 0) && "Ksubv must be multiple of 64");
        assert((Nsubv == 64) && "Nsubv must be 64");
        assert((K % Ksubv == 0) && "K should be divisible by Ksubv");
        assert((N % Nsubv == 0) && "N should be divisible by Nsubv");
    }

    char* subv_ptr(int k, int n)
    {
        assert((0 <= k) && (k < K));
        assert((0 <= n) && (n < N));
        int offset = subv_size * (
            ((k / Ksubv) * (N / Nsubv)) +
            (n / Nsubv)
        );
        return (char *)data + offset; // NOLINT
    }

    Tw& at(int k, int n)
    {
        assert((0 <= k) && (k < K));
        assert((0 <= n) && (n < N));
        // K:Ksubv N:64 K:64
        int subv_idx = 
            (((k % Ksubv) / 64 ) * Nsubv * 64) + 
            ((n % 64) * 64) + 
            (k % 64);
        auto ptr = reinterpret_cast<Tw*>(subv_ptr(k, n));
        return ptr[subv_idx];
    }

    Tw at(int k, int n, int is_int4)
    {
        assert((0 <= k) && (k < K));
        assert((0 <= n) && (n < N));
        // K:Ksubv N:64 K:64
        int subv_idx = 
            (((k % Ksubv) / 64 ) * Nsubv * 64) + 
            ((n % 64) * 64) + 
            (k % 64);
        // For int4, two values are packed per byte
        int byte_idx = subv_idx / 2;
        int nibble_pos = subv_idx % 2;  // 0 = low nibble, 1 = high nibble
        auto ptr = reinterpret_cast<uint8_t*>(subv_ptr(k, n));
        auto byte_value = ptr[byte_idx];
        // Return the value based on the nibble position
        uint8_t nibble = nibble_pos ? static_cast<Tw>((byte_value & 0xF0) >> 4) : static_cast<Tw>(byte_value & 0x0F);

        // Perform sign extension based on Tw type
        if constexpr (std::is_signed_v<Tw>) {
            // For signed types (int8_t): sign extend from 4-bit to 8-bit
            // If bit 3 is set (value >= 8), extend with 1s to make negative
            return static_cast<Tw>((nibble & 0x08) ? (nibble | 0xF0) : nibble);
        } else {
            // For unsigned types (uint8_t): no sign extension needed
            return static_cast<Tw>(nibble);
        }
    }

    void set(int k, int n, int is_int4, Tw value)
    {
        assert((0 <= k) && (k < K));
        assert((0 <= n) && (n < N));

        int subv_idx = 
            (((k % Ksubv) / 64 ) * Nsubv * 64) + 
            ((n % 64) * 64) + 
            (k % 64);

        int byte_idx = subv_idx / 2;
        int nibble_pos = subv_idx % 2;
        auto ptr = reinterpret_cast<uint8_t*>(subv_ptr(k, n));
        if (nibble_pos == 0) {
            // Set low nibble, preserve high nibble
            ptr[byte_idx] = (ptr[byte_idx] & 0xF0) | (value & 0x0F);
        } else {
            // Set high nibble, preserve low nibble
            ptr[byte_idx] = (ptr[byte_idx] & 0x0F) | ((value & 0x0F) << 4);
        }
    }

    Tc0& c0_at(int c)
    {
        assert((0 <= c) && (c < N));
        int subv_idx = c % Nsubv;
        auto ptr = reinterpret_cast<Tc0*>(subv_ptr(K-1, c) + subv_wgt_size); // NOLINT
        return ptr[subv_idx]; // NOLINT

    }

    Tc1& c1_at(int c)
    {
        assert((0 <= c) && (c < N));
        int subv_idx = c % Nsubv;
        auto ptr = reinterpret_cast<Tc1*>(subv_ptr(K-1, c) + subv_wgt_size + subv_c0_size); // NOLINT
        return ptr[subv_idx]; // NOLINT
    }

    Tc2& c2_at(int c)
    {
        assert((0 <= c) && (c < N));
        int subv_idx = c % Nsubv;
        auto ptr = reinterpret_cast<Tc2*>(subv_ptr(K-1, c) + subv_wgt_size + subv_c0_size + subv_c1_size); // NOLINT
        return ptr[subv_idx]; // NOLINT
    }

    void set_qdq_params(GemmQdqint16x8_RT_Params qdq_params)
    {
        for(int c=0; c<N; c++){
            for(int r=0; r<K; r+=Ksubv){ // NOTE: THe qdq params are copied to all the subvols
                auto qdq_params_ptr = reinterpret_cast<GemmQdqint16x8_RT_Params*>(subv_ptr(r, c) + subv_wgt_size + subv_c0_size + subv_c1_size + subv_c2_size); // NOLINT
                *qdq_params_ptr = qdq_params;
            }
        }
    }

    GemmQdqint16x8_RT_Params get_qdq_params(){
        auto qdq_params_ptr = reinterpret_cast<GemmQdqint16x8_RT_Params*>(subv_ptr(0, 0) + subv_wgt_size + subv_c0_size + subv_c1_size + subv_c2_size); // NOLINT
        return *qdq_params_ptr;
    }


    static int compute_subv_wgt_size(int Ksubv, int Nsubv, int is_int4)
    {   
        int subv_wgt_size = iceil(Ksubv * Nsubv * int(sizeof(Tw)), subv_align_bytes);
        if(is_int4) subv_wgt_size = subv_wgt_size / 2;
        return subv_wgt_size;
    }

    static int compute_subv_c0_size(int Nsubv)
    {
        return iceil(Nsubv * int(sizeof(Tc0)), subv_align_bytes);
    }

    static int compute_subv_c1_size(int Nsubv)
    {
        return iceil(Nsubv * int(sizeof(Tc1)), subv_align_bytes);
    }

    static int compute_subv_c2_size(int Nsubv)
    {
        return iceil(Nsubv * int(sizeof(Tc2)), subv_align_bytes);
    }

    static int compute_subv_qdq_size()
    {
        return sizeof(GemmQdqint16x8_RT_Params);
    }

    static int compute_subv_size(int Ksubv, int Nsubv, int is_int4)
    {
        int subv_wgt_size = compute_subv_wgt_size(Ksubv, Nsubv, is_int4);
        int subv_c0_size = compute_subv_c0_size(Nsubv);
        int subv_c1_size = compute_subv_c1_size(Nsubv);
        int subv_c2_size = compute_subv_c2_size(Nsubv);
        int subv_qdq_size = compute_subv_qdq_size();
        int subv_size = subv_wgt_size + 
                        subv_c0_size + 
                        subv_c1_size + 
                        subv_c2_size + 
                        subv_qdq_size;
        return subv_size;
    }

    static int size(int K, int N, int Ksubv, int Nsubv, int is_int4)
    {
        int num_subv = (K / Ksubv) * (N / Nsubv);
        int subv_size = compute_subv_size(Ksubv, Nsubv, is_int4);
        int total_size = num_subv * subv_size;
        return total_size;

    }

};

/// Reference matrix multiply on the CPU.
///
/// For every row `m` of `ifm` and every column `n` of `Wgt`, accumulates
/// ifm.at(k, 0, m) * weight(k, n) over k in a Tacc and stores the sum in
/// ofm.at(n, 0, m). Each operand is converted to int before multiplying.
///
/// @param ifm      input activations
/// @param Wgt      packed weights
/// @param ofm      receives the accumulated products
/// @param is_int4  non-zero selects the nibble-packed weight accessor
/// @param K        expected ifm.C (checked with assert only)
/// @param N        expected ofm.C (checked with assert only)
template<typename Ta, typename Tw, typename Tc0, typename Tc1, typename Tc2, typename Tacc>
void cpu_matmul(
    ActTensor<Ta> ifm,
    GemmWgtTensor<Tw, Tc0, Tc1, Tc2> Wgt,
    ActTensor<Tacc> ofm,
    int is_int4,
    int K,
    int N
)
{
    assert(ifm.C == K);
    assert(ofm.X == ifm.X);
    assert(ofm.C == N);

    for (int row = 0; row < ofm.X; ++row) {
        for (int col = 0; col < ofm.C; ++col) {
            Tacc dot = 0;
            for (int depth = 0; depth < ifm.C; ++depth) {
                int const act_val = ifm.at(depth, 0, row);
                int wgt_val;
                if (is_int4) {
                    wgt_val = Wgt.at(depth, col, is_int4);
                } else {
                    wgt_val = Wgt.at(depth, col);
                }
                dot += act_val * wgt_val;
            }
            ofm.at(col, 0, row) = dot;
        }
    }
}

/// Applies the three-term QDQ formula to a GEMM result on the CPU.
///
///     Y = gemm_out * C2 + ifm_sum * C1 + C0
///
/// where `ifm_sum` is the sum of the activations of the row over all channels
/// and C0/C1/C2 are the per-column coefficients stored in `wgt`. The float
/// result is quantized with quantize_float_to_int16 using the shift and output
/// sign found in the tensor's QDQ parameter record.
///
/// @param act         input activations, used only to form the row sums
/// @param gemm_out    accumulated matrix product
/// @param wgt         packed weights providing C0/C1/C2 and the QDQ record
/// @param out         receives the quantized values
/// @param debug_mode  non-zero prints every intermediate value
template<typename Ta, typename Tw, typename Tc0, typename Tc1, typename Tc2, typename To, typename Tacc>
inline void cpu_3term_qdq(
    ActTensor<Ta> act,
    ActTensor<Tacc> gemm_out,
    GemmWgtTensor<Tw, Tc0, Tc1, Tc2> wgt,
    ActTensor<To> out,
    int const debug_mode
)
{
    // Row sums live in a vector so the storage is released on every exit path
    // (the former malloc'ed buffer was never freed).
    std::vector<Tacc> ifm_sum(static_cast<std::size_t>(act.X));
    for (int r = 0; r < act.X; ++r) {
        Tacc sum = 0;
        for (int c = 0; c < act.C; ++c) {
            sum += act.at(c, 0, r);
        }
        ifm_sum[static_cast<std::size_t>(r)] = sum;
    }

    GemmQdqint16x8_RT_Params const qdq_params = wgt.get_qdq_params();
    for (int r = 0; r < out.X; ++r) {
        Tacc const ifm_sum_val = ifm_sum[static_cast<std::size_t>(r)];
        for (int c = 0; c < out.C; ++c) {
            Tacc const gemm_val = gemm_out.at(c, 0, r);
            float const res = (gemm_val * wgt.c2_at(c)) + (ifm_sum_val * wgt.c1_at(c)) + wgt.c0_at(c);
            // Quantize the result to the output type
            out.at(c, 0, r) = quantize_float_to_int16<To>(res, qdq_params.shift_out, qdq_params.ofm_sign == 1);
            if (debug_mode) {
                // The arguments are template-typed, so cast each one to the
                // exact type its printf conversion expects.
                printf("gemm_val = %lld\n", static_cast<long long>(gemm_val));
                printf("ifm_sum_val = %lld\n", static_cast<long long>(ifm_sum_val));
                printf("res = %f, C2 = %f, C1 = %f, c0 = %f\n",
                       static_cast<double>(res),
                       static_cast<double>(wgt.c2_at(c)),
                       static_cast<double>(wgt.c1_at(c)),
                       static_cast<double>(wgt.c0_at(c)));
                printf("out.at(%d, %d) = %d\n", r, c, static_cast<int>(out.at(c, 0, r)));
            }
        }
    }
}

/// Fills an activation tensor and a byte-per-weight weight tensor with
/// pseudo-random test data drawn from rand().
///
/// Activations and valid weights are integers in [0, 64). C0/C1/C2 are floats
/// in [0, 16.828], [0, 2.828] and [0, 1.828]. Weights outside the logical
/// K x N region and coefficients of columns >= N are set to zero.
/// wgt.vect_coeff > 0 draws C1 per column, otherwise one value is shared;
/// wgt.vect_coeff > 1 does the same for C2. Finally `qdq_params` is copied
/// into every sub-volume.
///
/// @param act         activations to fill
/// @param wgt         weight tensor to fill
/// @param qdq_params  record stored in every sub-volume
/// @param K           logical row count (<= wgt.K)
/// @param N           logical column count (<= wgt.N)
template<typename Ta, typename Tw, typename Tc0, typename Tc1, typename Tc2>
inline void init_random_gemm_a16w8(
    ActTensor<Ta> act,
    GemmWgtTensor<Tw, Tc0, Tc1, Tc2> wgt,
    GemmQdqint16x8_RT_Params qdq_params,
    int K,
    int N
)
{
    // Integer draw in [lo, hi).
    auto const draw_int = [](int64_t lo, int64_t hi) -> int64_t {
        return (rand() % (hi - lo)) + lo;
    };
    // Float draw in [lo, hi].
    auto const draw_float = [](float lo, float hi) -> float {
        // Use constexpr to ensure compile-time conversion
        constexpr auto RAND_MAX_F = static_cast<float>(RAND_MAX);
        return lo + static_cast<float>(rand()) / RAND_MAX_F * (hi - lo);
    };

    int64_t const ifm_lo = 0;
    int64_t const ifm_hi = 64;
    int64_t const wgt_lo = 0;
    int64_t const wgt_hi = 64;

    float const c0_lo = 0.0F;
    float const c0_hi = +16.828F;
    float const c1_lo = 0.0F;
    float const c1_hi = +2.828F;
    float const c2_lo = 0.0F;
    float const c2_hi = +1.828F;

    // Activations.
    for (int row = 0; row < act.X; ++row) {
        for (int ch = 0; ch < act.C; ++ch) {
            act.at(ch, 0, row) = Ta(draw_int(ifm_lo, ifm_hi));
        }
    }

    // Weights: random inside the logical region, zero in the padding.
    for (int k = 0; k < wgt.K; ++k) {
        for (int n = 0; n < wgt.N; ++n) {
            if ((k < K) && (n < N)) {
                wgt.at(k, n) = Tw(draw_int(wgt_lo, wgt_hi));
            } else {
                wgt.at(k, n) = Tw(0);
            }
        }
    }

    // C0 is always drawn per column.
    for (int n = 0; n < wgt.N; ++n) {
        wgt.c0_at(n) = (n < N) ? draw_float(c0_lo, c0_hi) : static_cast<float>(0);
    }

    // C1: per column when vect_coeff > 0, otherwise one shared value.
    if (wgt.vect_coeff > 0) {
        for (int n = 0; n < wgt.N; ++n) {
            wgt.c1_at(n) = (n < N) ? draw_float(c1_lo, c1_hi) : static_cast<float>(0);
        }
    } else {
        float const shared_c1 = draw_float(c1_lo, c1_hi);
        for (int n = 0; n < wgt.N; ++n) {
            wgt.c1_at(n) = (n < N) ? shared_c1 : static_cast<float>(0);
        }
    }

    // C2: per column when vect_coeff > 1, otherwise one shared value.
    if (wgt.vect_coeff > 1) {
        for (int n = 0; n < wgt.N; ++n) {
            wgt.c2_at(n) = (n < N) ? draw_float(c2_lo, c2_hi) : static_cast<float>(0);
        }
    } else {
        float const shared_c2 = draw_float(c2_lo, c2_hi);
        for (int n = 0; n < wgt.N; ++n) {
            wgt.c2_at(n) = (n < N) ? shared_c2 : static_cast<float>(0);
        }
    }

    wgt.set_qdq_params(qdq_params);
}

/// Fills an activation tensor and a nibble-packed weight tensor with
/// pseudo-random test data drawn from rand().
///
/// Activations are integers in [0, 64) and valid weights integers in [0, 4);
/// weights outside the logical K x N region are stored as zero. C0/C1/C2 are
/// floats in [0, 16.828], [0, 2.828] and [0, 1.828].
/// wgt.vect_coeff > 0 draws C1 per column, otherwise one value is shared;
/// wgt.vect_coeff > 1 does the same for C2. Finally `qdq_params` is copied
/// into every sub-volume.
///
/// NOTE(review): unlike init_random_gemm_a16w8, coefficients of columns >= N
/// are not zeroed here - confirm whether that is intended.
///
/// @param act         activations to fill
/// @param wgt         weight tensor to fill
/// @param qdq_params  record stored in every sub-volume
/// @param is_int4     forwarded to GemmWgtTensor::set
/// @param K           logical row count (<= wgt.K)
/// @param N           logical column count (<= wgt.N)
template<typename Ta, typename Tw, typename Tc0, typename Tc1, typename Tc2>
inline void init_random_gemm_a16w4(
    ActTensor<Ta> act,
    GemmWgtTensor<Tw, Tc0, Tc1, Tc2> wgt,
    GemmQdqint16x8_RT_Params qdq_params,
    int is_int4,
    int K,
    int N
)
{
    // Integer draw in [lo, hi).
    auto const draw_int = [](int64_t lo, int64_t hi) -> int64_t {
        return (rand() % (hi - lo)) + lo;
    };
    // Float draw in [lo, hi].
    auto const draw_float = [](float lo, float hi) -> float {
        // Use constexpr to ensure compile-time conversion
        constexpr auto RAND_MAX_F = static_cast<float>(RAND_MAX);
        return lo + static_cast<float>(rand()) / RAND_MAX_F * (hi - lo);
    };

    int64_t const ifm_lo = 0;
    int64_t const ifm_hi = 64;
    int64_t const wgt_lo = 0;
    int64_t const wgt_hi = 4;

    float const c0_lo = 0.0F;
    float const c0_hi = +16.828F;
    float const c1_lo = 0.0F;
    float const c1_hi = +2.828F;
    float const c2_lo = 0.0F;
    float const c2_hi = +1.828F;

    // Activations.
    for (int row = 0; row < act.X; ++row) {
        for (int ch = 0; ch < act.C; ++ch) {
            act.at(ch, 0, row) = Ta(draw_int(ifm_lo, ifm_hi));
        }
    }

    // Weights: random inside the logical region, zero in the padding.
    for (int k = 0; k < wgt.K; ++k) {
        for (int n = 0; n < wgt.N; ++n) {
            Tw nibble = Tw(0);
            if ((k < K) && (n < N)) {
                nibble = Tw(draw_int(wgt_lo, wgt_hi));
            }
            wgt.set(k, n, is_int4, nibble);
        }
    }

    // C0 is always drawn per column.
    for (int n = 0; n < wgt.N; ++n) {
        wgt.c0_at(n) = draw_float(c0_lo, c0_hi);
    }

    // C1: per column when vect_coeff > 0, otherwise one shared value.
    if (wgt.vect_coeff > 0) {
        for (int n = 0; n < wgt.N; ++n) {
            wgt.c1_at(n) = draw_float(c1_lo, c1_hi);
        }
    } else {
        float const shared_c1 = draw_float(c1_lo, c1_hi);
        for (int n = 0; n < wgt.N; ++n) {
            wgt.c1_at(n) = shared_c1;
        }
    }

    // C2: per column when vect_coeff > 1, otherwise one shared value.
    if (wgt.vect_coeff > 1) {
        for (int n = 0; n < wgt.N; ++n) {
            wgt.c2_at(n) = draw_float(c2_lo, c2_hi);
        }
    } else {
        float const shared_c2 = draw_float(c2_lo, c2_hi);
        for (int n = 0; n < wgt.N; ++n) {
            wgt.c2_at(n) = shared_c2;
        }
    }

    wgt.set_qdq_params(qdq_params);
}

template<typename Ta, typename Tw, typename Tb, typename Tc0, typename Tc1, typename Tc2>
inline void init_model_data(
    std::string const_path,
    std::string node_name,
    GemmWgtTensor<Tw, Tc0, Tc1, Tc2> wgt,
    GemmQdqint16x8_RT_Params qdq_params,
    int is_int4,
    int K,
    int N,
    int const vec_coeff,
    int const debug_mode
) {

    int raw_wgt_size = (K * N * sizeof(Tw));
    void* raw_wgt_data = malloc(raw_wgt_size);

    // get wgt data
    replace_symbols(node_name);
    std::filesystem::path wgt_file {const_path + "/" + node_name + "/" + "B.bin"};
    read_bin_file(wgt_file, reinterpret_cast<char*>(raw_wgt_data), raw_wgt_size);
    std::vector<Tw> wf(K * N);
    std::memcpy(wf.data(), raw_wgt_data, raw_wgt_size);
    std::vector<int64_t> w_shape = {K, N};
    std::string wgt_type = is_int4 ? "int4" : "int8";
    std::vector<std::vector<Tw>> weight = fold2D<Tw>(wf, w_shape, wgt_type);
    for (int k = 0; k < wgt.K; ++k) {
        for (int n = 0; n < wgt.N; ++n) {
            bool valid_idx = (k < K) && (n < N);
	    Tw value = valid_idx ? weight[k][n] : Tw(0);
	    if (is_int4) {
                wgt.set(k, n, is_int4, value);
	    } else {
                wgt.at(k, n) = valid_idx ? value : Tw(0);
	    }
        }
    }

    // get scale, zp
    std::vector<json> scale_zp = get_scale_zp_vector(const_path, node_name);
    int scale_zp_size = 6; // for matmul + other op, this value is 10

    bool is_bias = true;
    if (scale_zp.size() == scale_zp_size + 2) {
        is_bias = true;
    }
    else if (scale_zp.size() == scale_zp_size) {
        is_bias = false;
    }
    else {
        std::cout << "scale_zp vector has wrong size for node "
            << node_name << std::endl;
    }
    std::tuple<std::vector<float>, std::vector<float>, std::vector<float>> qdq_values;
    if (is_bias) {
	    std::filesystem::path bias_file {const_path + "/" + node_name + "/" + "Bias.bin"};
	// get bias data
        int raw_bias_size = N * sizeof(Tb);
        void* raw_bias_data = malloc(raw_bias_size);
        read_bin_file(bias_file, reinterpret_cast<char*>(raw_bias_data), raw_bias_size);
	std::vector<Tb> bias(N);
	std::memcpy(bias.data(), raw_bias_data, raw_bias_size);
	// get scale/zp
        float in_s = scale_zp[0];
        uint16_t in_zp = scale_zp[1];
        float o_s = scale_zp[6];
        uint16_t o_zp = scale_zp[7];
	if (vec_coeff == 0) {
            float w_s = scale_zp[2];
            uint8_t w_zp = scale_zp[3];
            float b_s = scale_zp[4];
            int32_t b_zp = scale_zp[5];
            // calculate c0, c1, c2
	    // Note: qdq calculation for 8A8W_bias, 16A16W_bias and 16A8W_bias matmul_bias should be the same
            qdq_values = dq_uint16A_uint8W_bias_matmul_q_param_gen<Tw, Tb>(
                         in_s, in_zp, weight, w_s, w_zp, bias, b_s, b_zp, o_s, o_zp);
	} else if (vec_coeff == 1) {
            std::cout << "ERROR: Unsupported channelwise formatting for vec_coeff = 1" << std::endl;
	} else if (vec_coeff > 1) {
            std::vector<float> w_s = scale_zp[2];
            std::vector<uint16_t> w_zp = scale_zp[3];
            if (is_int4) {
                float b_s = scale_zp[4];
                int32_t b_zp = scale_zp[5];
		qdq_values = dq_uint16A_int4W_bias_matmul_q_param_gen<Tw, Tb>(
			    in_s, in_zp, weight, w_s, w_zp, bias, b_s, b_zp, o_s, o_zp);
	    } else {
                std::vector<float> b_s = scale_zp[4];
                std::vector<uint16_t> b_zp = scale_zp[5];
                qdq_values = dq_uint16A_int8W_bias_matmul_q_param_gen_chwise<Tw, Tb>(
                            in_s, in_zp, weight, w_s, w_zp, bias, b_s, b_zp, o_s, o_zp);
	    }
	}	
        free(raw_bias_data);
    } else {
	// get scale/zp
    	float in_s = scale_zp[0];
    	uint16_t in_zp = scale_zp[1];
    	float o_s = scale_zp[4];
    	uint16_t o_zp = scale_zp[5];
	if (vec_coeff == 0) {
    	    float w_s = scale_zp[2];
    	    uint8_t w_zp = scale_zp[3];
            // calculate c0, c1, c2
	    // Note: qdq calculation for 8A8W, 16A16W and 16A8W matmul_nobias should be the same
            qdq_values = calculate_matmul_qdq_params_no_bias<Tw, Tb>(
	    		weight, in_s, in_zp, w_s, w_zp, o_s, o_zp);
	} else if (vec_coeff == 1) {
            std::cout << "ERROR: Unsupported channelwise formatting for vec_coeff = 1" << std::endl;
	} else if (vec_coeff > 1) {
            std::vector<float> w_s = scale_zp[2];
            std::vector<uint16_t> w_zp = scale_zp[3];
            if (is_int4) {
                qdq_values = calculate_matmul_qdq_params_uint16_int4<Tw, Tb>(
	        		weight, in_s, in_zp, w_s, w_zp, o_s, o_zp);
            } else {
	        // create empty bias info
	        std::vector<Tb> bias;
                std::vector<float> b_s;
                std::vector<uint16_t> b_zp;
                qdq_values = dq_uint16A_int8W_bias_matmul_q_param_gen_chwise<Tw, Tb>(
	        		in_s, in_zp, weight, w_s, w_zp, bias, b_s, b_zp, o_s, o_zp);
	    }
	}	
    }
    // Add the vectors to make it compiled.
    std::vector<float> C0_vec = std::get<0>(qdq_values);
    std::vector<float> C1_vec = std::get<1>(qdq_values);
    std::vector<float> C2_vec = std::get<2>(qdq_values);
    if (debug_mode) {
        printf("C0 values\n");
        for (int i = 0; i < N; ++i) {
            printf("%f\n", C0_vec[i]);
        }
        printf("C1 values\n");
        for (int i = 0; i < N; ++i) {
            printf("%f\n", C1_vec[i]);
        }
        printf("C2 values\n");
        for (int i = 0; i < N; ++i) {
            printf("%f\n", C2_vec[i]);
        }
    }

    for (int n = 0; n < N; ++n) {
        wgt.c0_at(n) = C0_vec[n];
        if (vec_coeff > 0) {
            wgt.c1_at(n) = C1_vec[n];
        } else {
            wgt.c1_at(n) = C1_vec[0];
        }
        if (vec_coeff > 1) {
            wgt.c2_at(n) = C2_vec[n];
        } else {
            wgt.c2_at(n) = C2_vec[0];
        } 
    }
    for (int n = N; n < wgt.N; ++n) {
        wgt.c0_at(n) = 0.0f;
        wgt.c1_at(n) = 0.0f;
        wgt.c2_at(n) = 0.0f;
    }
    wgt.set_qdq_params(qdq_params);

    free(raw_wgt_data);
}



/// Dumps a weight tensor to std::cout: an optional heading, the K x N weight
/// matrix (one row per line), the C0, C1 and C2 coefficient rows, and the QDQ
/// parameter record.
///
/// @param mat      tensor to print
/// @param msg      heading printed first; skipped when nullptr
/// @param is_int4  non-zero reads the weights through the nibble-packed accessor
template<typename Tw, typename Tc0, typename Tc1, typename Tc2>
inline void log_tensor(GemmWgtTensor<Tw, Tc0, Tc1, Tc2> mat, const char* msg = nullptr, int is_int4 = 0)
{
    if (msg != nullptr) {
        std::cout << msg << "\n";
    }

    // Weight matrix, widened to int64_t so byte-sized types print as numbers.
    for (int row = 0; row < mat.K; ++row) {
        for (int col = 0; col < mat.N; ++col) {
            int64_t weight;
            if (is_int4) {
                weight = mat.at(row, col, is_int4);
            } else {
                weight = mat.at(row, col);
            }
            std::cout << weight << " ";
        }
        std::cout << "\n";
    }

    // Coefficient rows.
    std::cout << "Co_coeff: \n";
    for (int col = 0; col < mat.N; ++col) {
        std::cout << mat.c0_at(col) << " ";
    }

    std::cout << "\n C1_coeff: \n";
    for (int col = 0; col < mat.N; ++col) {
        std::cout << mat.c1_at(col) << " ";
    }

    std::cout << "\n C2_coeff: \n";
    for (int col = 0; col < mat.N; ++col) {
        std::cout << mat.c2_at(col) << " ";
    }

    // QDQ record; the uint8_t flags are converted so they print as numbers.
    std::cout << "\n QDQ Params: \n";
    auto const record = mat.get_qdq_params();
    std::cout << "Shift_out: " << record.shift_out << "\n";
    std::cout << "ifm_sign: " << static_cast<int>(record.ifm_sign) << "\n";
    std::cout << "wgt_sign: " << static_cast<int>(record.wgt_sign) << "\n";
    std::cout << "ofm_sign: " << static_cast<int>(record.ofm_sign) << "\n";
}


#endif //GEMM_HPP
