#ifndef CONV_HPP
#define CONV_HPP

#include <random>
#include "common.hpp"
#include "qdq_utils_aie4.hpp"

using namespace waic_runtime_aie4;

// Selector for the activation stage of the convolution.
// NOTE(review): the meaning of each enumerator is inferred from its name only
// (shift-round-saturate, ReLU, ReLU6, leaky ReLU, hard-swish); no code in this
// header reads the enum, so confirm against the kernel before relying on it.
// The numeric values are the implicit ones (0..4), spelled out explicitly.
enum ActivationConfig {
    AC_SRS    = 0,
    AC_RELU   = 1,
    AC_RELU6  = 2,
    AC_LRELU  = 3,
    AC_HSWISH = 4
};

// Runtime parameter record stored at the end of every weight sub-volume of a
// ConvWgtTensor_noqdq. The record must occupy exactly 128 bytes; this is
// enforced at compile time below.
struct ConvWgtTensor_noqdq_RT_Params {
    int16_t lrelu_alpha;    // NOTE(review): presumably the leaky-ReLU slope; not read in this header
    int8_t max_value;       // upper clip passed to srs_relu_int8 by the CPU reference conv
    int8_t shift_bias;      // power-of-two scale applied to the bias before accumulation
    int8_t shift_lrelu_in;  // NOTE(review): presumably a pre-shift for leaky ReLU; not read in this header
    int8_t shift_out;       // right shift applied to the accumulator in srs_relu_int8
    struct Control {
        uint8_t sign_A:1;   // 1 => activations treated as signed by the random initialiser
        uint8_t sign_W:1;   // 1 => weights treated as signed by the random initialiser
        uint8_t sign_O:1;   // 1 => output is signed (selects the clip range in srs_relu_int8)
    } ctrl;
    uint8_t reserved[121]; // Padding to align the struct size to 128 bytes.
};

// Lock in the layout contract stated above so a field change cannot silently
// shift the data that follows the record in the packed buffer.
static_assert(sizeof(ConvWgtTensor_noqdq_RT_Params) == 128,
              "ConvWgtTensor_noqdq_RT_Params must be exactly 128 bytes");

// Runtime parameter record stored at the end of every weight sub-volume of a
// ConvWgtTensor_qdq. The record must occupy exactly 128 bytes; this is
// enforced at compile time below.
struct ConvWgtTensor_qdq_RT_Params
{
    int16_t shift_out;  // forwarded to quantize_float_to_int16 by cpu_3term_qdq
    uint8_t ifm_sign;   // 1 => input feature map treated as signed by the random initialiser
    uint8_t wgt_sign;   // 1 => weights treated as signed by the random initialiser
    uint8_t ofm_sign;   // 1 => output is quantized as a signed value
    // NOTE: Used to align the struct size to 128 bytes.
    uint8_t reserved[123];
};

// Lock in the layout contract stated above so a field change cannot silently
// shift the data that follows the record in the packed buffer.
static_assert(sizeof(ConvWgtTensor_qdq_RT_Params) == 128,
              "ConvWgtTensor_qdq_RT_Params must be exactly 128 bytes");

/// View over a packed convolution weight buffer for the non-QDQ path.
///
/// The buffer is a sequence of equally sized sub-volumes, one per
/// (ci / Cis, co / Cos) pair, ordered Ci-shard major / Co-shard minor
/// (see subv_ptr). Each sub-volume holds, in this order:
///   1. the weights                         (subv_wgt_size bytes)
///   2. Cos bias values of type Tb          (subv_bias_size bytes)
///   3. one ConvWgtTensor_noqdq_RT_Params   (subv_qdq_size bytes)
/// Sections 1 and 2 are sized with iceil(..., subv_align_bytes).
///
/// The view stores the caller's pointer and never allocates or frees it here;
/// the caller is expected to provide at least `data_size` bytes.
///
/// @tparam Tw weight element type
/// @tparam Tb bias element type
template<typename Tw, typename Tb>
struct ConvWgtTensor_noqdq
{
    // constexpr (implicitly inline in C++17) so that passing it to a helper
    // can never require an out-of-class definition.
    static constexpr int subv_align_bytes = 128;
    int const Ci_gran;          // 8 when Ci < 64, otherwise 64
    int const Co;               // number of output channels (padded)
    int const Ci;               // number of input channels (padded)
    int const Ky;               // kernel height
    int const Kx;               // kernel width
    int const Kx_padded;        // ceildiv(64, Cis) when Ci < 64, otherwise Kx
    int const Cos;              // output channels per sub-volume
    int const Cis;              // input channels per sub-volume
    int const subv_wgt_size;    // bytes of the weight section of one sub-volume
    int const subv_bias_size;   // bytes of the bias section of one sub-volume
    int const subv_qdq_size;    // bytes of the runtime-parameter record
    int const subv_size;        // total bytes of one sub-volume
    int const data_size;        // total bytes of the whole packed buffer
    Tw* const data;             // start of the packed buffer

    /// Builds the view and derives all section sizes.
    /// The initialisers are listed in member declaration order, which is the
    /// order in which they actually run; Kx_padded must be set before the size
    /// members that read it.
    ConvWgtTensor_noqdq(int Co, int Ci, int Ky, int Kx, int Cis, int Cos, void* data)
        : Ci_gran((Ci < 64) ? 8 : 64) // If Ci < 64, we use a granularity of 8, otherwise 64
        , Co(Co)
        , Ci(Ci)
        , Ky(Ky)
        , Kx(Kx)
        , Kx_padded((Ci < 64) ? ceildiv(64, Cis) : Kx) // Small-Ci layout: Kx_padded * Cis covers 64 elements
        , Cos(Cos)
        , Cis(Cis)
        , subv_wgt_size(compute_subv_wgt_size(Ky, Kx_padded, Cis, Cos))
        , subv_bias_size(compute_subv_bias_size(Cos))
        , subv_qdq_size(compute_subv_qdq_size())
        , subv_size(compute_subv_size(Ky, Kx_padded, Cis, Cos))
        , data_size(size(Co, Ci, Ky, Kx_padded, Cis, Cos))
        , data(static_cast<Tw*>(data))
    {
        if(Ci < 64){
            assert((Cis == 8) && "When Ci < 64, Cis must be 8");
        }
        assert((Cos == 64) && "Cos must be 64");
        assert((Co % 64 == 0) && "Co must be multiple of 64");
        assert((Ci % Ci_gran == 0) && "Ci must be multiple of Ci_gran");
        assert((Co % Cos == 0) && "Co must be multiple of Cos");
        assert((Ci % Cis == 0) && "Ci must be multiple of Cis");
        assert((Ky > 0) && "Ky must be positive");
        assert((Kx > 0) && "Kx must be positive");
    }

    /// Returns the first byte of the sub-volume that contains (co, ci).
    char* subv_ptr(int co, int ci)
    {
        assert((0 <= co) && (co < Co));
        assert((0 <= ci) && (ci < Ci));
        int offset = subv_size * (
            ((ci / Cis) * (Co / Cos)) +
            (co / Cos)
        );
        return (char *)data + offset; // NOLINT
    }

    /// Returns a reference to the weight at (co, ci, ky, kx).
    /// kx may address the padded columns up to Kx_padded.
    Tw& wgt_at(int co, int ci, int ky, int kx)
    {
        assert((0 <= co) && (co < Co));
        assert((0 <= ci) && (ci < Ci));
        assert((0 <= ky) && (ky < Ky));
        assert((0 <= kx) && (kx < Kx_padded));
        int subv_idx = 0;
        if(Ci_gran == 8)
        {
        //
        // Ky Co:64 Kx:Kx_padded Ci:8
        //
            subv_idx =
                (ky * 64 * 64) +
                ((co % 64) * 64) +
                ((kx % Kx_padded) * 8) +
                (ci % 8);
        } else {
        //
        // Ci:Cis Ky Kx Co:64 Ci:64
        //
            subv_idx =
                (((ci % Cis) / 64) * Cos * Ky * Kx * 64) +
                (ky * Kx * 64 * 64) +
                (kx * 64 * 64) +
                ((co % 64) * 64) +
                (ci % 64);
        }
        auto ptr = reinterpret_cast<Tw*>(subv_ptr(co, ci)); // NOLINT
        return ptr[subv_idx]; // NOLINT
    }

    /// Returns a reference to the bias of output channel `co`.
    Tb& bias_at(int co)
    {
        assert((0 <= co) && (co < Co));
        int subv_idx = co % Cos;
        // NOTE: The bias is referenced from the very first Cin subvol of the co shard.
        auto ptr = reinterpret_cast<Tb*>(subv_ptr(co, 0) + subv_wgt_size); // NOLINT
        return ptr[subv_idx]; // NOLINT
    }

    /// Writes `qdq_params` into the parameter record of every sub-volume.
    void set_qdq_params(ConvWgtTensor_noqdq_RT_Params qdq_params)
    {
        // NOTE: The qdq_params are loaded for every Co shard in the conv kernel along with bias.
        for(int co=0; co<Co; co++){
            for(int ci=0; ci<Ci; ci+=Cis){ // NOTE: The qdq params are copied to all the subvols
                auto qdq_params_ptr = reinterpret_cast<ConvWgtTensor_noqdq_RT_Params*>(subv_ptr(co, ci) + subv_wgt_size + subv_bias_size); // NOLINT
                *qdq_params_ptr = qdq_params;
            }
        }
    }

    /// Reads the parameter record back from the first sub-volume.
    ConvWgtTensor_noqdq_RT_Params get_qdq_params(){
        auto qdq_params_ptr = reinterpret_cast<ConvWgtTensor_noqdq_RT_Params*>(subv_ptr(0, 0) + subv_wgt_size + subv_bias_size); // NOLINT
        return *qdq_params_ptr;
    }

    /// Bytes of the weight section of one sub-volume.
    static int compute_subv_wgt_size(int Ky, int Kx, int Cis, int Cos)
    {
        return iceil(Cos * Cis * Ky * Kx * int(sizeof(Tw)), subv_align_bytes);
    }

    /// Bytes of the bias section of one sub-volume.
    static int compute_subv_bias_size(int Cos)
    {
        return iceil(Cos * int(sizeof(Tb)), subv_align_bytes);
    }

    /// Bytes of the runtime-parameter record.
    static int compute_subv_qdq_size()
    {
        return static_cast<int>(sizeof(ConvWgtTensor_noqdq_RT_Params));
    }

    /// Total bytes of one sub-volume (weights + bias + parameters).
    static int compute_subv_size(int Ky, int Kx, int Cis, int Cos)
    {
        int subv_wgt_size = compute_subv_wgt_size(Ky, Kx, Cis, Cos);
        int subv_bias_size = compute_subv_bias_size(Cos);
        int subv_qdq_size = compute_subv_qdq_size();
        int subv_size = subv_wgt_size + subv_bias_size + subv_qdq_size;
        return subv_size;
    }

    /// Total bytes of the packed buffer. `Kx` is the padded kernel width.
    static int size(int Co, int Ci, int Ky, int Kx, int Cis, int Cos)
    {
        int num_subv = (Co / Cos) * (Ci / Cis);
        int subv_size = compute_subv_size(Ky, Kx, Cis, Cos);
        int total_size = num_subv * subv_size;
        return total_size;
    }
};

/// View over a packed convolution weight buffer for the QDQ path.
///
/// The buffer is a sequence of equally sized sub-volumes, one per
/// (ci / Cis, co / Cos) pair, ordered Ci-shard major / Co-shard minor
/// (see subv_ptr). Each sub-volume holds, in this order:
///   1. the weights                        (subv_wgt_size bytes)
///   2. Cos C0 coefficients of type Tc0    (subv_c0_size bytes)
///   3. Cos C1 coefficients of type Tc1    (subv_c1_size bytes)
///   4. Cos C2 coefficients of type Tc2    (subv_c2_size bytes)
///   5. one ConvWgtTensor_qdq_RT_Params    (subv_qdq_size bytes)
/// Sections 1-4 are sized with iceil(..., subv_align_bytes).
///
/// The view stores the caller's pointer and never allocates or frees it here;
/// the caller is expected to provide at least `data_size` bytes.
///
/// NOTE(review): wgt_at uses the same hard-coded small-Ci layout as
/// ConvWgtTensor_noqdq, which asserts Cis == 8 when Ci < 64; this constructor
/// does not. Confirm whether the same precondition applies here.
template<typename Tw, typename Tc0, typename Tc1, typename Tc2>
struct ConvWgtTensor_qdq
{
    // constexpr (implicitly inline in C++17) so that passing it to a helper
    // can never require an out-of-class definition.
    static constexpr int subv_align_bytes = 128;
    int const Ci_gran;          // 8 when Ci < 64, otherwise 64
    int const Co;               // number of output channels (padded)
    int const Ci;               // number of input channels (padded)
    int const Ky;               // kernel height
    int const Kx;               // kernel width
    int const Kx_padded;        // ceildiv(64, Cis) when Ci < 64, otherwise Kx
    int const Cos;              // output channels per sub-volume
    int const Cis;              // input channels per sub-volume
    int const vect_coeff;       // stored as given; not read by this struct
    int const subv_wgt_size;    // bytes of the weight section of one sub-volume
    int const subv_c0_size;     // bytes of the C0 section
    int const subv_c1_size;     // bytes of the C1 section
    int const subv_c2_size;     // bytes of the C2 section
    int const subv_qdq_size;    // bytes of the runtime-parameter record
    int const subv_size;        // total bytes of one sub-volume
    int const data_size;        // total bytes of the whole packed buffer
    Tw* const data;             // start of the packed buffer

    /// Builds the view and derives all section sizes.
    /// The initialisers are listed in member declaration order, which is the
    /// order in which they actually run; Kx_padded must be set before the size
    /// members that read it.
    ConvWgtTensor_qdq(int Co, int Ci, int Ky, int Kx, int Cis, int Cos, int vect_coeff, void* data)
        : Ci_gran((Ci < 64) ? 8 : 64) // If Ci < 64, we use a granularity of 8, otherwise 64
        , Co(Co)
        , Ci(Ci)
        , Ky(Ky)
        , Kx(Kx)
        , Kx_padded((Ci < 64) ? ceildiv(64, Cis) : Kx) // Small-Ci layout: Kx_padded * Cis covers 64 elements
        , Cos(Cos)
        , Cis(Cis)
        , vect_coeff(vect_coeff)
        , subv_wgt_size(compute_subv_wgt_size(Ky, Kx_padded, Cis, Cos))
        , subv_c0_size(compute_subv_c0_size(Cos))
        , subv_c1_size(compute_subv_c1_size(Cos))
        , subv_c2_size(compute_subv_c2_size(Cos))
        , subv_qdq_size(compute_subv_qdq_size())
        , subv_size(compute_subv_size(Ky, Kx_padded, Cis, Cos))
        , data_size(size(Co, Ci, Ky, Kx_padded, Cis, Cos))
        , data(static_cast<Tw*>(data))
    {
        assert((Cos == 64) && "Cos must be 64");
        assert((Co % 64 == 0) && "Co must be multiple of 64");
        assert((Ci % Ci_gran == 0) && "Ci must be multiple of Ci_gran");
        assert((Co % Cos == 0) && "Co must be multiple of Cos");
        assert((Ci % Cis == 0) && "Ci must be multiple of Cis");
        assert((Ky > 0) && "Ky must be positive");
        assert((Kx > 0) && "Kx must be positive");
    }

    /// Returns the first byte of the sub-volume that contains (co, ci).
    char* subv_ptr(int co, int ci)
    {
        assert((0 <= co) && (co < Co));
        assert((0 <= ci) && (ci < Ci));
        int offset = subv_size * (
            ((ci / Cis) * (Co / Cos)) +
            (co / Cos)
        );
        return (char *)data + offset; // NOLINT
    }

    /// Returns a reference to the weight at (co, ci, ky, kx).
    /// kx may address the padded columns up to Kx_padded.
    Tw& wgt_at(int co, int ci, int ky, int kx)
    {
        assert((0 <= co) && (co < Co));
        assert((0 <= ci) && (ci < Ci));
        assert((0 <= ky) && (ky < Ky));
        assert((0 <= kx) && (kx < Kx_padded));
        int subv_idx = 0;
        if(Ci_gran == 8)
        {
        //
        // Ky Co:64 Kx:Kx_padded Ci:8
        //
            subv_idx =
                (ky * 64 * 64) +
                ((co % 64) * 64) +
                ((kx % Kx_padded) * 8) +
                (ci % 8);
        } else {
        //
        // Ci:Cis Ky Kx Co:64 Ci:64
        //
            subv_idx =
                (((ci % Cis) / 64) * Cos * Ky * Kx * 64) +
                (ky * Kx * 64 * 64) +
                (kx * 64 * 64) +
                ((co % 64) * 64) +
                (ci % 64);
        }
        auto ptr = reinterpret_cast<Tw*>(subv_ptr(co, ci)); // NOLINT
        return ptr[subv_idx]; // NOLINT
    }

    /// Returns the C0 coefficient of output channel `co`, read from the last
    /// Ci sub-volume (set_c0_at writes the same value to every Ci sub-volume).
    Tc0& c0_at(int co)
    {
        assert((0 <= co) && (co < Co));
        int subv_idx = co % Cos;
        auto ptr = reinterpret_cast<Tc0*>(subv_ptr(co, Ci - 1) + subv_wgt_size); // NOLINT
        return ptr[subv_idx]; // NOLINT
    }

    /// Writes the C0 coefficient of output channel `co` into every Ci sub-volume.
    void set_c0_at(int co, Tc0 value)
    {
        assert((0 <= co) && (co < Co));
        int subv_idx = co % Cos;
        for(int ci=0; ci<Ci; ci+=Cis){ // NOTE: c0 is copied to all the subvols
            auto ptr = reinterpret_cast<Tc0*>(subv_ptr(co, ci) + subv_wgt_size); // NOLINT
            ptr[subv_idx] = value; // NOLINT
        }
    }

    /// Returns the C1 coefficient of output channel `co`, read from the last
    /// Ci sub-volume. The section is indexed as Tc1, matching set_c1_at.
    Tc1& c1_at(int co)
    {
        assert((0 <= co) && (co < Co));
        int subv_idx = co % Cos;
        auto ptr = reinterpret_cast<Tc1*>(subv_ptr(co, Ci - 1) + subv_wgt_size + subv_c0_size); // NOLINT
        return ptr[subv_idx]; // NOLINT
    }

    /// Writes the C1 coefficient of output channel `co` into every Ci sub-volume.
    void set_c1_at(int co, Tc1 value)
    {
        assert((0 <= co) && (co < Co));
        int subv_idx = co % Cos;
        for(int ci=0; ci<Ci; ci+=Cis){ // NOTE: c1 is copied to all the subvols
            auto ptr = reinterpret_cast<Tc1*>(subv_ptr(co, ci) + subv_wgt_size + subv_c0_size); // NOLINT
            ptr[subv_idx] = value; // NOLINT
        }
    }

    /// Returns the C2 coefficient of output channel `co`, read from the last
    /// Ci sub-volume. The section is indexed as Tc2, matching set_c2_at.
    Tc2& c2_at(int co)
    {
        assert((0 <= co) && (co < Co));
        int subv_idx = co % Cos;
        auto ptr = reinterpret_cast<Tc2*>(subv_ptr(co, Ci - 1) + subv_wgt_size + subv_c0_size + subv_c1_size); // NOLINT
        return ptr[subv_idx]; // NOLINT
    }

    /// Writes the C2 coefficient of output channel `co` into every Ci sub-volume.
    void set_c2_at(int co, Tc2 value)
    {
        assert((0 <= co) && (co < Co));
        int subv_idx = co % Cos;
        for(int ci=0; ci<Ci; ci+=Cis){ // NOTE: c2 is copied to all the subvols
            auto ptr = reinterpret_cast<Tc2*>(subv_ptr(co, ci) + subv_wgt_size + subv_c0_size + subv_c1_size); // NOLINT
            ptr[subv_idx] = value; // NOLINT
        }
    }

    /// Writes `qdq_params` into the parameter record of every sub-volume.
    void set_qdq_params(ConvWgtTensor_qdq_RT_Params qdq_params)
    {
        // NOTE: The qdq_params are loaded for every Co shard in the conv kernel along with bias.
        for(int co=0; co<Co; co++){
            for(int ci=0; ci<Ci; ci+=Cis){ // NOTE: The qdq params are copied to all the subvols
                auto qdq_params_ptr = reinterpret_cast<ConvWgtTensor_qdq_RT_Params*>(subv_ptr(co, ci) + subv_wgt_size + subv_c0_size + subv_c1_size + subv_c2_size); // NOLINT
                *qdq_params_ptr = qdq_params;
            }
        }
    }

    /// Reads the parameter record back from the first sub-volume.
    ConvWgtTensor_qdq_RT_Params get_qdq_params(){
        auto qdq_params_ptr = reinterpret_cast<ConvWgtTensor_qdq_RT_Params*>(subv_ptr(0, 0) + subv_wgt_size + subv_c0_size + subv_c1_size + subv_c2_size); // NOLINT
        return *qdq_params_ptr;
    }

    /// Bytes of the weight section of one sub-volume.
    static int compute_subv_wgt_size(int Ky, int Kx, int Cis, int Cos)
    {
        return iceil(Cos * Cis * Ky * Kx * int(sizeof(Tw)), subv_align_bytes);
    }

    /// Bytes of the C0 section of one sub-volume.
    static int compute_subv_c0_size(int Cos)
    {
        return iceil(Cos * int(sizeof(Tc0)), subv_align_bytes);
    }

    /// Bytes of the C1 section of one sub-volume.
    static int compute_subv_c1_size(int Cos)
    {
        return iceil(Cos * int(sizeof(Tc1)), subv_align_bytes);
    }

    /// Bytes of the C2 section of one sub-volume.
    static int compute_subv_c2_size(int Cos)
    {
        return iceil(Cos * int(sizeof(Tc2)), subv_align_bytes);
    }

    /// Bytes of the runtime-parameter record.
    static int compute_subv_qdq_size()
    {
        return static_cast<int>(sizeof(ConvWgtTensor_qdq_RT_Params));
    }

    /// Total bytes of one sub-volume (weights + C0 + C1 + C2 + parameters).
    static int compute_subv_size(int Ky, int Kx, int Cis, int Cos)
    {
        int subv_wgt_size = compute_subv_wgt_size(Ky, Kx, Cis, Cos);
        int subv_c0_size = compute_subv_c0_size(Cos);
        int subv_c1_size = compute_subv_c1_size(Cos);
        int subv_c2_size = compute_subv_c2_size(Cos);
        int subv_qdq_size = compute_subv_qdq_size();
        int subv_size = subv_wgt_size +
                        subv_c0_size +
                        subv_c1_size +
                        subv_c2_size +
                        subv_qdq_size;
        return subv_size;
    }

    /// Total bytes of the packed buffer. `Kx` is the padded kernel width.
    static int size(int Co, int Ci, int Ky, int Kx, int Cis, int Cos)
    {
        int num_subv = (Co / Cos) * (Ci / Cis);
        int subv_size = compute_subv_size(Ky, Kx, Cis, Cos);
        int total_size = num_subv * subv_size;
        return total_size;
    }
};

/// Shift-round-saturate with an upper clip, returning an 8-bit result.
///
/// The accumulator is divided by 2^shift, rounded to nearest (ties to even,
/// via nearbyint under FE_TONEAREST) and clamped to [clip_min, clip_max]:
///   - ofm_sign == 0: clip_min is 0, the representable maximum is 255
///   - otherwise:     clip_min is -128, the representable maximum is 127
/// clip_max is `max_value`, additionally limited to the representable maximum
/// so that a signed result saturates instead of wrapping when
/// `max_value` exceeds 127.
///
/// @param acc        accumulator value
/// @param max_value  upper clip (e.g. clip_max / scale == 6 / 0.0625 == 96)
/// @param shift      right-shift amount; must be small enough that
///                   int64_t(1) << shift is representable
/// @param ofm_sign   0 for an unsigned output, non-zero for a signed output
/// @return the clamped value; for an unsigned output the 8-bit pattern of the
///         0..255 value is returned in an int8_t.
///
/// Side effect: sets the floating-point rounding mode to FE_TONEAREST.
inline int8_t srs_relu_int8(int64_t acc, uint8_t max_value, uint8_t shift, int ofm_sign)
{
    // NOTE: In conv kernel the RELU activation function is controlled by the OFM sign bit and NOT the act_type in qdq params.
    // Even though the act_type is set to RELU, the OFM sign bit will override it.
    fesetround(FE_TONEAREST);
    acc = int64_t(nearbyint(acc / double(int64_t(1) << shift)));

    const bool unsigned_out = (ofm_sign == 0);
    const int64_t int_min = unsigned_out ? int64_t(std::numeric_limits<uint8_t>::min())
                                         : int64_t(std::numeric_limits<int8_t>::min());
    const int64_t int_max = unsigned_out ? int64_t(std::numeric_limits<uint8_t>::max())
                                         : int64_t(std::numeric_limits<int8_t>::max());

    // The lower clip is the bottom of the output range (0 for unsigned output).
    const int64_t clip_min = int_min;
    // The upper clip may never exceed what the output type can represent.
    const int64_t requested_max = int64_t(max_value);
    const int64_t clip_max = (requested_max < int_max) ? requested_max : int_max;

    if (acc < clip_min) {
        return int8_t(clip_min);
    }
    if (acc > clip_max) {
        return int8_t(clip_max);
    }
    return int8_t(acc);
}

/// CPU reference for the non-QDQ 2D convolution.
///
/// For every output element the products of activations and weights over
/// (ci, ky, kx) are accumulated in 64 bits, the bias scaled by
/// 2^shift_bias is added, and the result is passed through srs_relu_int8 with
/// the max_value / shift_out / sign_O stored in the weight tensor.
/// Input positions outside the feature map contribute zero.
///
/// @param ifm  input feature map, indexed at(c, y, x)
/// @param wgt  packed weights, bias and runtime parameters
/// @param ofm  output feature map, written at(c, y, x)
/// @param Sy,Sx stride in y / x
/// @param Py,Px padding applied before the first row / column
///
/// NOTE(review): activations and weights are read into int8_t regardless of
/// Ta / Tw, so values outside [-128, 127] are reinterpreted. This is kept
/// unchanged; confirm whether it is intentional for unsigned inputs.
template<typename Ta, typename Tw, typename Tb, typename To>
inline void cpu_iconv_2d(
    ActTensor<Ta> ifm,
    ConvWgtTensor_noqdq<Tw, Tb> wgt,
    ActTensor<To> ofm,
    int Sy, int Sx,
    int Py, int Px)
{
    ConvWgtTensor_noqdq_RT_Params qdq_params = wgt.get_qdq_params();
    // Scale the bias by multiplication rather than `bias << shift_bias`:
    // left-shifting a negative value is undefined behaviour before C++20,
    // while the product is well defined and gives the same result.
    // NOTE(review): shift_bias is assumed to be non-negative.
    const int64_t bias_scale = int64_t(1) << qdq_params.shift_bias;
    for (int co = 0; co < ofm.C; ++co) {
        const int64_t scaled_bias = static_cast<int64_t>(wgt.bias_at(co)) * bias_scale;
        for (int yi = -Py, yo = 0; yo < ofm.Y; yi += Sy, ++yo) {
            for (int xi = -Px, xo = 0; xo < ofm.X; xi += Sx, ++xo) {
                int64_t acc = 0;
                for (int ci = 0; ci < ifm.C; ++ci) {
                    for (int ky = 0; ky < wgt.Ky; ++ky) {
                        for (int kx = 0; kx < wgt.Kx; ++kx) {
                            int y = yi + ky;
                            int x = xi + kx;
                            // Zero padding outside the input feature map.
                            int8_t a = (0 <= y && y < ifm.Y &&
                                    0 <= x && x < ifm.X) ? ifm.at(ci, y, x)
                                                         : 0;
                            int8_t w = wgt.wgt_at(co, ci, ky, kx);
                            acc += a * w;
                        }
                    }
                }
                acc += scaled_bias;
                ofm.at(co, yo, xo) = srs_relu_int8(acc, qdq_params.max_value, qdq_params.shift_out, qdq_params.ctrl.sign_O);
            }
        }
    }
}
/// CPU reference for the raw (un-quantized) accumulation of the QDQ 2D
/// convolution: every output element is the sum over (ci, ky, kx) of
/// activation * weight, stored directly as Tacc. Input positions outside the
/// feature map contribute zero. The C0/C1/C2 terms are not applied here.
///
/// @param ifm  input feature map, indexed at(c, y, x)
/// @param wgt  packed weights and coefficients
/// @param ofm  accumulator output, written at(c, y, x)
/// @param Sy,Sx stride in y / x
/// @param Py,Px padding applied before the first row / column
template<typename Ta, typename Tw, typename Tc0, typename Tc1, typename Tc2, typename Tacc>
inline void cpu_iconv_2d(
    ActTensor<Ta> ifm,
    ConvWgtTensor_qdq<Tw, Tc0, Tc1, Tc2> wgt,
    ActTensor<Tacc> ofm,
    int Sy, int Sx,
    int Py, int Px)
{
    // Read back as in the original implementation; the value is not used.
    ConvWgtTensor_qdq_RT_Params qdq_params = wgt.get_qdq_params();
    (void)qdq_params;

    for (int out_c = 0; out_c < ofm.C; ++out_c) {
        for (int out_y = 0; out_y < ofm.Y; ++out_y) {
            // Top row of the receptive field in input coordinates.
            const int row0 = (out_y * Sy) - Py;
            for (int out_x = 0; out_x < ofm.X; ++out_x) {
                // Left column of the receptive field in input coordinates.
                const int col0 = (out_x * Sx) - Px;
                Tacc total = 0;
                for (int in_c = 0; in_c < ifm.C; ++in_c) {
                    for (int fy = 0; fy < wgt.Ky; ++fy) {
                        const int in_y = row0 + fy;
                        const bool row_ok = (0 <= in_y) && (in_y < ifm.Y);
                        for (int fx = 0; fx < wgt.Kx; ++fx) {
                            const int in_x = col0 + fx;
                            const bool inside = row_ok && (0 <= in_x) && (in_x < ifm.X);
                            // Zero padding outside the input feature map.
                            Ta a = inside ? ifm.at(in_c, in_y, in_x) : 0;
                            Tw w = wgt.wgt_at(out_c, in_c, fy, fx);
                            total += a * w;
                        }
                    }
                }
                ofm.at(out_c, out_y, out_x) = total;
            }
        }
    }
}

/// CPU reference for the three-term QDQ stage that follows the raw
/// convolution:
///
///     Y = conv_out * C2 + ifm_sum * C1 + C0
///
/// where ifm_sum(y, x) is the sum of all activations inside the receptive
/// field of output pixel (y, x) over every input channel (positions outside
/// the input contribute zero). The float result is quantized with
/// quantize_float_to_int16 using shift_out / ofm_sign from the weight tensor.
///
/// @param act        input feature map, indexed at(c, y, x)
/// @param conv_out   raw convolution accumulators, indexed at(c, y, x)
/// @param wgt        packed weights; supplies Ky, Kx, C0/C1/C2 and parameters
/// @param out        quantized output, written at(c, y, x)
/// @param Sy,Sx      stride in y / x
/// @param Py,Px      padding applied before the first row / column
/// @param debug_mode when true, prints every intermediate value
template<typename Ta, typename Tw, typename Tc0, typename Tc1, typename Tc2, typename To, typename Tacc>
inline void cpu_3term_qdq(
    ActTensor<Ta> act,
    ActTensor<Tacc> conv_out,
    ConvWgtTensor_qdq<Tw, Tc0, Tc1, Tc2> wgt,
    ActTensor<To> out,
    int Sy, int Sx,
    int Py, int Px,
    bool debug_mode
)
{
    // Scratch storage for the per-pixel activation sums. Held in a vector so
    // it is released on every exit path (the previous malloc'd buffer was
    // never freed and never checked for allocation failure).
    std::vector<Tacc> ifm_sum_storage(static_cast<size_t>(out.Y) * static_cast<size_t>(out.X));
    ActTensor<Tacc> ifm_sum(1, out.Y, out.X, static_cast<void*>(ifm_sum_storage.data()));

    for (int yi = -Py, yo = 0; yo < out.Y; yi += Sy, ++yo) {
        for (int xi = -Px, xo = 0; xo < out.X; xi += Sx, ++xo) {
            Tacc sum = 0;
            for (int c = 0; c < act.C; ++c) {
                for (int ky = 0; ky < wgt.Ky; ++ky) {
                    for (int kx = 0; kx < wgt.Kx; ++kx) {
                        int y = yi + ky;
                        int x = xi + kx;
                        sum += (0 <= y && y < act.Y &&
                            0 <= x && x < act.X) ? act.at(c, y, x) : 0;
                    }
                }
            }
            ifm_sum.at(0, yo, xo) = sum;
        }
    }

    ConvWgtTensor_qdq_RT_Params qdq_params = wgt.get_qdq_params();
    for (int y = 0; y < out.Y; ++y){
        for (int r = 0; r < out.X; ++r) {
            for (int c = 0; c < out.C; ++c) {
                Tacc conv_val = conv_out.at(c, y, r);
                if(debug_mode) std::cout << "conv_val = " << conv_val << std::endl;
                Tacc ifm_sum_val = ifm_sum.at(0, y, r);
                if(debug_mode) std::cout << "ifm_sum_val = " << ifm_sum_val << std::endl;
                float res = (conv_val * wgt.c2_at(c)) + (ifm_sum_val * wgt.c1_at(c)) + wgt.c0_at(c);
                // Explicit casts keep the variadic printf arguments matched to
                // their format specifiers whatever Tc0/Tc1/Tc2/To are.
                if(debug_mode) printf("res = %f, C2 = %f, C1 = %f, c0 = %f\n",
                                      static_cast<double>(res),
                                      static_cast<double>(wgt.c2_at(c)),
                                      static_cast<double>(wgt.c1_at(c)),
                                      static_cast<double>(wgt.c0_at(c)));
                // Quantize the result to the output type
                out.at(c, y, r) = quantize_float_to_int16<To>(res, qdq_params.shift_out, qdq_params.ofm_sign == 1);
                if(debug_mode) printf("out.at(%d, %d, %d) = %d\n", c, y, r, static_cast<int>(out.at(c, y, r)));
            }
        }
    }
}

template<typename Tw, typename Tb>
inline void init_conv_noqdq_model_data(std::string const md_path,
                                    std::string node_name,
                                    ConvWgtTensor_noqdq<Tw, Tb> wgt,
                                    ConvWgtTensor_noqdq_RT_Params qdq_params,
                                    int Co_no_pad, int Ci_no_pad,
                                    int const vec_coeff,
                                    int const debug_mode)
{
    replace_symbols(node_name);
    std::filesystem::path wgt_file = {md_path + "/" + node_name + "/" + "B.bin"};
    std::filesystem::path bias_file = {md_path + "/" + node_name + "/" + "Bias.bin"};

    uint32_t num_wgt_elements = wgt.Kx * wgt.Ky * Co_no_pad * Ci_no_pad;
    size_t wgt_buf_size = sizeof(Tw) * num_wgt_elements;
    char* wgt_buf = (char *)allocate(wgt_buf_size);
    if (wgt_buf == NULL) {
        std::cout << "Unable to allocate memory for reading wgt data" << std::endl;
        return;
    }
    size_t bias_buf_size = sizeof(Tb) * Co_no_pad;
    char* bias_buf = (char *)allocate(bias_buf_size);
    if (bias_buf == NULL) {
        std::cout << "Unable to allocate memory for reading bias data" << std::endl;
        deallocate(wgt_buf);
        return;
    }
    size_t wgt_bytes_read = read_bin_file(wgt_file, wgt_buf, wgt_buf_size);
    size_t bias_bytes_read = read_bin_file(bias_file, bias_buf, bias_buf_size);

    Tb* bias_buf_ptr = (Tb *)bias_buf;

    //
    // Transpose wgt buffer from Kx, ky, Ci, Co to Co, Ci, Ky, Kx
    //
    std::vector<Tw> wgt_transposed(num_wgt_elements);
	transpose_conv_wgt<Tw>(reinterpret_cast<Tw*>(wgt_buf), wgt_transposed, wgt.Kx, wgt.Ky, wgt.Ci, wgt.Co);

    //
    // Copy transposed wgt data read from wgt bin to wgt buffer
    for (int ci = 0; ci < wgt.Ci; ++ci) {
        for (int ky = 0; ky < wgt.Ky; ++ky) {
            for (int kx = 0; kx < wgt.Kx_padded; ++kx) {
                for (int co = 0; co < wgt.Co; ++co) {
                    if ((co < Co_no_pad) &&
                        (ci < Ci_no_pad) &&
                        (kx < wgt.Kx) &&
                        (ky < wgt.Ky))
                        {
                            wgt.wgt_at(co, ci, ky, kx) = wgt_transposed[co * wgt.Ci * wgt.Ky * wgt.Kx + ci * wgt.Ky * wgt.Kx + ky * wgt.Kx + kx];;
                        }
                        else
                        {
                            wgt.wgt_at(co, ci, ky, kx) = 0;
                        }
                }
            }
        }
    }
    //
    // Copy bias data read from bias bin to wgt buffer
    //
    for (int co = 0; co < wgt.Co; ++co) 
    {
        if (co < Co_no_pad)
        {
            wgt.bias_at(co) = *bias_buf_ptr;
            bias_buf_ptr++;
        }
        else
        {
            wgt.bias_at(co) = 0;
        }
    }
    //
    // Set qdq params
    //
    wgt.set_qdq_params(qdq_params);

    deallocate(wgt_buf);
    deallocate(bias_buf);
}

/// Fills the input feature map, the weights and the bias with rand()-based
/// test data and stores `qdq_params` in the weight tensor.
///
/// Value ranges (upper bound exclusive):
///   - activations: [-16, 16) when ctrl.sign_A == 1, otherwise [0, 8)
///   - weights:     [-16, 16) when ctrl.sign_W == 1, otherwise [0, 8)
///   - bias:        same range as the weights
/// Weights outside the unpadded region (co >= Co_no_pad, ci >= Ci_no_pad or
/// kx >= wgt.Kx) and padded bias entries are written as zero without
/// consuming a random number.
template<typename Ta, typename Tw, typename Tb>
inline void init_random_conv_noqdq_a8w8(
    ActTensor<Ta> ifm,
    ConvWgtTensor_noqdq<Tw, Tb> wgt,
    ConvWgtTensor_noqdq_RT_Params qdq_params,
    int Co_no_pad, int Ci_no_pad)
{
    const bool act_signed = (qdq_params.ctrl.sign_A == 1);
    const bool wgt_signed = (qdq_params.ctrl.sign_W == 1);

    const int64_t ifm_min = act_signed ? -16 : 0;
    const int64_t ifm_max = act_signed ? +16 : 8;
    const int64_t wgt_min = wgt_signed ? -16 : 0;
    const int64_t wgt_max = wgt_signed ? +16 : 8;

    const int64_t ifm_span = ifm_max - ifm_min;
    const int64_t wgt_span = wgt_max - wgt_min;

    // Activations: one random draw per element, channel innermost.
    for (int row = 0; row < ifm.Y; ++row) {
        for (int col = 0; col < ifm.X; ++col) {
            for (int ch = 0; ch < ifm.C; ++ch) {
                ifm.at(ch, row, col) = int8_t((rand() % ifm_span) + ifm_min);
            }
        }
    }

    // Weights: a random draw only for elements inside the unpadded region.
    for (int in_ch = 0; in_ch < wgt.Ci; ++in_ch) {
        for (int fy = 0; fy < wgt.Ky; ++fy) {
            for (int fx = 0; fx < wgt.Kx_padded; ++fx) {
                for (int out_ch = 0; out_ch < wgt.Co; ++out_ch) {
                    const bool is_padding =
                        (out_ch >= Co_no_pad) ||
                        (in_ch >= Ci_no_pad) ||
                        (fx >= wgt.Kx) ||
                        (fy >= wgt.Ky);
                    int64_t sample = 0;
                    if (!is_padding) {
                        sample = (rand() % wgt_span) + wgt_min;
                    }
                    wgt.wgt_at(out_ch, in_ch, fy, fx) = int8_t(sample);
                }
            }
        }
    }

    // Bias: drawn from the weight range, zero for padded channels.
    for (int out_ch = 0; out_ch < wgt.Co; ++out_ch) {
        int64_t sample = 0;
        if (out_ch < Co_no_pad) {
            sample = (rand() % wgt_span) + wgt_min;
        }
        wgt.bias_at(out_ch) = int16_t(sample);
    }

    wgt.set_qdq_params(qdq_params);
}

template<typename Tw, typename Tc0, typename Tc1, typename Tc2, typename Tb>
inline void init_conv_qdq_model_data(std::string const md_path,
                                    std::string node_name,
                                    ConvWgtTensor_qdq<Tw, Tc0, Tc1, Tc2> wgt,
                                    ConvWgtTensor_qdq_RT_Params qdq_params,
                                    int Co_no_pad, int Ci_no_pad,
                                    int const vec_coeff,
                                    int const debug_mode)
{
    replace_symbols(node_name);
    //
    // Read wgt binary file (B.bin)
    //
    std::filesystem::path wgt_file = {md_path + "/" + node_name + "/" + "B.bin"};

    uint32_t num_wgt_elements = wgt.Kx * wgt.Ky * Co_no_pad * Ci_no_pad;
    size_t wgt_buf_size = sizeof(Tw) * num_wgt_elements;
    char* wgt_buf = (char*)allocate(wgt_buf_size);
    if (wgt_buf == NULL) {
        std::cout << "Error: Unable to allocate memory for reading wgt data" << std::endl;
        return;
    }

    size_t wgt_bytes_read = read_bin_file(wgt_file, wgt_buf, wgt_buf_size);

    //
    // create vector from raw bin data
    //
    std::vector<Tw> wgt_vec(num_wgt_elements);
    std::memcpy(wgt_vec.data(), wgt_buf, wgt_buf_size);

    //
    // Transpose wgt buffer from Kx, ky, Ci, Co to Co, Ci, Ky, Kx
    //
    std::vector<Tw> wgt_transposed(num_wgt_elements);
	transpose_conv_wgt<Tw>(reinterpret_cast<Tw*>(wgt_buf), wgt_transposed, wgt.Kx, wgt.Ky, Ci_no_pad, Co_no_pad);

    //
    // get scale, zp
    //
    std::vector<json> scale_zp = get_scale_zp_vector(md_path, node_name);
	
    //
    // Copy transposed wgt data to wgt buffer
    //
	Tw w_zp;
    if (scale_zp[3].is_array())
    { 
		w_zp = (Tw)(scale_zp[3][0]);
	} else 
	{
		w_zp = (Tw)(scale_zp[3]);
	}
    for (int o = 0; o < wgt.Co; ++o)
    {
		for (int i = 0; i < wgt.Ci; ++i)
        {
			for (int y = 0; y < wgt.Ky; ++y)
            {
				for (int x = 0; x < wgt.Kx; ++x)
                {
                    //
                    // Check if index is the no pad region
                    //
                    if ((o < Co_no_pad) &&
                        (i < Ci_no_pad) &&
                        (x < wgt.Kx) &&
                        (y < wgt.Ky))
                    {
					    wgt.wgt_at(o, i, y, x) = wgt_transposed[o * Ci_no_pad * wgt.Ky * wgt.Kx + i * wgt.Ky * wgt.Kx + y * wgt.Kx + x];
                    }
                    else
                    {
                        wgt.wgt_at(o, i, y, x) = w_zp;
                    }
				}
			}
		}
    }

    //
    // Create a two dimensional matrix from wgt buf vector. Shape = Ky * Kx * Ci, Co
    //
    std::vector<int64_t> w_shape = {wgt.Kx * wgt.Ky * Ci_no_pad, Co_no_pad};
    std::vector<std::vector<Tw>> wgt_2D = fold2D<Tw>(wgt_vec, w_shape);

    //
    // Check if bias is preset
    //
	int scale_zp_size = 6;
    bool is_bias = true;
    if (scale_zp.size() == scale_zp_size + 2)
    {
        is_bias = true;
    }
    else if (scale_zp.size() == scale_zp_size)
    {
        is_bias = false;
    }
    else
    {
        std::cout << "scale_zp vector has wrong size for node "
            << node_name << std::endl;
    }

    //
    // Calculate qdq values based on presence of bias
    //
    std::tuple<std::vector<float>, std::vector<float>, std::vector<float>> qdq_values;
    if (is_bias) 
    {
        //
        // If bias is present read bias binary file
        //
	    std::filesystem::path bias_file {md_path + "/" + node_name + "/" + "Bias.bin"};
        int raw_bias_size = Co_no_pad * sizeof(Tb);
        void* raw_bias_data = allocate(raw_bias_size);
        if (raw_bias_data == NULL) {
            std::cout << "Error: unable to allocate memeory for reading bias data" << std::endl;
            deallocate(wgt_buf);
            return;
        }
        read_bin_file(bias_file, reinterpret_cast<char*>(raw_bias_data), raw_bias_size);

        //
        // Create a vector from raw bias buffer
        //
	    std::vector<Tb> bias(Co_no_pad);
	    std::memcpy(bias.data(), raw_bias_data, raw_bias_size);
	    
        //
        // record various scale zp values
        //
        float in_s = scale_zp[0];
        uint16_t in_zp = scale_zp[1];
        float o_s = scale_zp[6];
        uint16_t o_zp = scale_zp[7];
        
        //
        // If vec_coeff is > 1 perform channel wise qdq calculation
        //
	    if (vec_coeff == 0)
        {
            float w_s = scale_zp[2];
            uint8_t w_zp = scale_zp[3];
            float b_s = scale_zp[4];
            int32_t b_zp = scale_zp[5];
            //
            // calculate c0, c1, c2
	        // Note: qdq calculation for 8A8W_bias, 16A16W_bias and 16A8W_bias conv_bias should be the same
            //
            qdq_values = dq_uint16A_uint8W_bias_conv_q_param_gen<Tw, Tb>(
                         in_s, in_zp, wgt_2D, w_s, w_zp, bias, b_s, b_zp, o_s, o_zp);
        }
        else if (vec_coeff == 1)
        {
            std::cout << "ERROR: Unsupported channelwise formatting for vec_coeff = 1" << std::endl;
        }
        else if (vec_coeff > 1)
        {
            std::vector<float> w_s = scale_zp[2];
            std::vector<uint16_t> w_zp = scale_zp[3];
            std::vector<float> b_s = scale_zp[4];
            std::vector<uint16_t> b_zp = scale_zp[5];
            qdq_values = dq_uint16A_int8W_bias_conv_q_param_gen_chwise<Tw, Tb>(
                          in_s, in_zp, wgt_2D, w_s, w_zp, bias, b_s, b_zp, o_s, o_zp);
	    }
        deallocate(raw_bias_data);
    }
    else
    {
	    //
        // record various scale zp values
        //
    	float in_s = scale_zp[0];
    	uint16_t in_zp = scale_zp[1];
    	float o_s = scale_zp[4];
    	uint16_t o_zp = scale_zp[5];
	    
        //
        // If vec_coeff is greater than 1 perform channel wise qdq calculation
        //
        if (vec_coeff == 0)
        {
    	    float w_s = scale_zp[2];
    	    uint8_t w_zp = scale_zp[3];
            // calculate c0, c1, c2
	        // Note: qdq calculation for 8A8W, 16A16W and 16A8W matmul_nobias should be the same
            qdq_values = calculate_conv_qdq_params_no_bias<Tw, Tb>(
	    		            wgt_2D, in_s, in_zp, w_s, w_zp, o_s, o_zp);
	    }
        else if (vec_coeff == 1)
        {
            std::cout << "ERROR: Unsupported channelwise formatting for vec_coeff = 1" << std::endl;
	    }
        else if (vec_coeff > 1)
        {
            std::vector<float> w_s = scale_zp[2];
            std::vector<uint16_t> w_zp = scale_zp[3];
	        // create empty bias info
	        std::vector<Tb> bias;
            std::vector<float> b_s;
            std::vector<uint16_t> b_zp;
            qdq_values = dq_uint16A_int8W_bias_conv_q_param_gen_chwise<Tw, Tb>(
	        		        in_s, in_zp, wgt_2D, w_s, w_zp, bias, b_s, b_zp, o_s, o_zp);
	    }	
    }

    // Add the vectors to make it compiled.
    std::vector<float> C0_vec = std::get<0>(qdq_values);
    std::vector<float> C1_vec = std::get<1>(qdq_values);
    std::vector<float> C2_vec = std::get<2>(qdq_values);

    if (debug_mode)
    {
        uint32_t N = C0_vec.size();
        printf("C0 values\n");
        for (int i = 0; i < N; ++i)
        {
            printf("%f\n", C0_vec[i]);
        }
        printf("C1 values\n");
        for (int i = 0; i < N; ++i)
        {
            printf("%f\n", C1_vec[i]);
        }
        printf("C2 values\n");
        for (int i = 0; i < N; ++i)
        {
            printf("%f\n", C2_vec[i]);
        }
    }

    for (int n = 0; n < wgt.Co; ++n)
    {
        bool index_valid = (n < Co_no_pad) ? true : false;
        Tc0 tc0_value = (index_valid) ? C0_vec[n] : static_cast<Tc0>(0);
        wgt.set_c0_at(n, tc0_value);

        Tc1 tc1_value;
        if (vec_coeff > 0)
        {
            tc1_value = (index_valid) ? C1_vec[n] : static_cast<Tc1>(0);
        }
        else
        {
            tc1_value = (index_valid) ? C1_vec[0] : static_cast<Tc1>(0);
        }
        wgt.set_c1_at(n, tc1_value);

        Tc2 tc2_value;
        if (vec_coeff > 1)
        {
            tc2_value = (index_valid) ? C2_vec[n] : static_cast<Tc2>(0);
        }
        else
        {
            tc2_value = (index_valid) ? C2_vec[0] : static_cast<Tc2>(0);
        }
        wgt.set_c2_at(n, tc2_value);
    }
    wgt.set_qdq_params(qdq_params);

    deallocate(wgt_buf);
}

// Populates an a16w8 QDQ convolution test case with pseudo-random data from rand().
//
// Value ranges are selected by the sign flags in qdq_params:
//   - ifm elements : [-16, 15] when ifm_sign == 1, otherwise [0, 63]
//   - wgt elements : [-16, 15] when wgt_sign == 1, otherwise [0, 14]
//   - c0 / c1 / c2 : [-max, +max] when ofm_sign == 1, otherwise [0, +max],
//                    with max = 16.828 / 2.828 / 1.828 respectively
//
// Parameters:
//   ifm        - activation tensor; every (c, y, x) element is overwritten.
//   wgt        - weight tensor; weights, c0/c1/c2 coefficients and the QDQ
//                runtime parameters are written through its accessors.
//   qdq_params - runtime parameters; read for the sign flags and stored into wgt.
//   Co_no_pad  - number of real output channels; entries at co >= Co_no_pad are zeroed.
//   Ci_no_pad  - number of real input channels; weights at ci >= Ci_no_pad are zeroed.
//
// rand() is consumed only for non-padding entries, and always in the order
// ifm -> weights -> c0 -> c1 -> c2, so results are reproducible for a given seed.
// NOTE(review): both tensors are taken by value; presumably they are views over
// caller-owned buffers, otherwise the writes would not be visible - confirm.
template<typename Ta, typename Tw, typename Tc0, typename Tc1, typename Tc2>
inline void init_random_conv_qdq_a16w8(
    ActTensor<Ta> ifm,
    ConvWgtTensor_qdq<Tw, Tc0, Tc1, Tc2> wgt,
    ConvWgtTensor_qdq_RT_Params qdq_params,
    int Co_no_pad, int Ci_no_pad)
{
    bool const ifm_signed = (qdq_params.ifm_sign == 1);
    bool const wgt_signed = (qdq_params.wgt_sign == 1);
    bool const ofm_signed = (qdq_params.ofm_sign == 1);

    // Integer ranges are expressed as (lowest value, number of distinct values).
    int64_t const ifm_lo = ifm_signed ? -16 : 0;
    int64_t const ifm_span = (ifm_signed ? +16 : 64) - ifm_lo;
    int64_t const wgt_lo = wgt_signed ? -16 : 0;
    int64_t const wgt_span = (wgt_signed ? +16 : 15) - wgt_lo;

    // Coefficient ranges: the upper bound is fixed, the lower bound depends on the output sign.
    float const c0_hi = +16.828F;
    float const c1_hi = +2.828F;
    float const c2_hi = +1.828F;
    float const c0_lo = ofm_signed ? -16.828F : 0.0F;
    float const c1_lo = ofm_signed ? -2.828F : 0.0F;
    float const c2_lo = ofm_signed ? -1.828F : 0.0F;

    // Draws one float in [lo, hi] using exactly one rand() call.
    auto const draw_uniform = [](float lo, float hi) -> float {
        // Converted at compile time so the division below is pure float arithmetic.
        constexpr auto rand_max_f = static_cast<float>(RAND_MAX);
        return lo + static_cast<float>(rand()) / rand_max_f * (hi - lo);
    };

    // Activations: one random value per element.
    for (int y = 0; y < ifm.Y; ++y) {
        for (int x = 0; x < ifm.X; ++x) {
            for (int c = 0; c < ifm.C; ++c) {
                int64_t const sample = (rand() % ifm_span) + ifm_lo;
                ifm.at(c, y, x) = Ta(sample);
            }
        }
    }

    // Weights: random inside the unpadded region, zero in any padding.
    for (int ci = 0; ci < wgt.Ci; ++ci) {
        for (int ky = 0; ky < wgt.Ky; ++ky) {
            for (int kx = 0; kx < wgt.Kx_padded; ++kx) {
                for (int co = 0; co < wgt.Co; ++co) {
                    bool const in_real_region =
                        (co < Co_no_pad) && (ci < Ci_no_pad) && (kx < wgt.Kx) && (ky < wgt.Ky);
                    int64_t sample = 0;
                    if (in_real_region) {
                        sample = (rand() % wgt_span) + wgt_lo;
                    }
                    wgt.wgt_at(co, ci, ky, kx) = Tw(sample);
                }
            }
        }
    }

    // C0 always gets an independent value per output channel.
    for (int co = 0; co < wgt.Co; ++co) {
        float c0 = 0.0F;
        if (co < Co_no_pad) {
            c0 = draw_uniform(c0_lo, c0_hi);
        }
        wgt.set_c0_at(co, c0);
    }

    // C1 is per channel when vect_coeff > 0; otherwise one value is shared by all real channels.
    if (wgt.vect_coeff > 0) {
        for (int co = 0; co < wgt.Co; ++co) {
            float c1 = 0.0F;
            if (co < Co_no_pad) {
                c1 = draw_uniform(c1_lo, c1_hi);
            }
            wgt.set_c1_at(co, c1);
        }
    } else {
        float const shared_c1 = draw_uniform(c1_lo, c1_hi);
        for (int co = 0; co < wgt.Co; ++co) {
            wgt.set_c1_at(co, (co < Co_no_pad) ? shared_c1 : 0.0F);
        }
    }

    // C2 is per channel when vect_coeff > 1; otherwise one value is shared by all real channels.
    if (wgt.vect_coeff > 1) {
        for (int co = 0; co < wgt.Co; ++co) {
            float c2 = 0.0F;
            if (co < Co_no_pad) {
                c2 = draw_uniform(c2_lo, c2_hi);
            }
            wgt.set_c2_at(co, c2);
        }
    } else {
        float const shared_c2 = draw_uniform(c2_lo, c2_hi);
        for (int co = 0; co < wgt.Co; ++co) {
            wgt.set_c2_at(co, (co < Co_no_pad) ? shared_c2 : 0.0F);
        }
    }

    wgt.set_qdq_params(qdq_params);
}


// Prints the contents of a ConvWgtTensor_noqdq to std::cout for debugging.
//
// Output layout:
//   1. `name`, then the weights grouped by (Cis x Cos) sub-volume block. Inside
//      a block the input channels are walked in chunks of up to 64, and for each
//      (chunk, ky, kx, co) one line lists the weights of that chunk.
//   2. The bias value of every output channel.
//   3. The runtime parameters returned by get_qdq_params().
//
// Parameters:
//   tensor - weight tensor to dump (taken by value, as in the original interface).
//   name   - heading printed before the dump.
template<typename Tw, typename Tb>
inline void log_tensor(ConvWgtTensor_noqdq<Tw, Tb> tensor, std::string const& name)
{
    // Maximum number of input channels printed on one "Weights" line.
    constexpr int ci_chunk = 64;

    std::cout << name << ": \n";
    for (int bi = 0; bi < tensor.Ci; bi += tensor.Cis) { // Loop over Bi (blocks of Cis)
        int const ci_block_end = std::min(bi + tensor.Cis, tensor.Ci);
        for (int bc = 0; bc < tensor.Co; bc += tensor.Cos) { // Loop over Bc (blocks of Cos)
            int const co_block_end = std::min(bc + tensor.Cos, tensor.Co);
            std::cout << "Subvol Block [Ci: " << bi << "-" << ci_block_end - 1
                      << ", Co: " << bc << "-" << co_block_end - 1 << "]:\n";
            for (int cis = bi; cis < ci_block_end; cis += ci_chunk) { // Loop over Cis
                // Clamp the chunk to the end of the current block (not to Ci):
                // otherwise a block whose Cis is not a multiple of 64 would also
                // print channels that belong to, and are printed again by, later blocks.
                int const ci_chunk_end = std::min(cis + ci_chunk, ci_block_end);
                for (int ky = 0; ky < tensor.Ky; ++ky) { // Loop over Ky
                    for (int kx = 0; kx < tensor.Kx_padded; ++kx) { // Loop over Kx
                        for (int cos = bc; cos < co_block_end; ++cos) { // Loop over Cos
                            std::cout << "Weights [Ci: " << cis << "-" << ci_chunk_end - 1
                                      << ", Co: " << cos << ", Ky: " << ky << ", Kx: " << kx << "]: ";
                            for (int ci = cis; ci < ci_chunk_end; ++ci) { // Loop over the channels of this chunk
                                // Print as int so narrow weight types show up as numbers, not characters.
                                std::cout << static_cast<int>(tensor.wgt_at(cos, ci, ky, kx)) << " ";
                            }
                            std::cout << "\n";
                        }
                    }
                }
            }
            std::cout << "\n";
        }
    }

    std::cout << "BIAS: \n";
    for (int co = 0; co < tensor.Co; ++co) {
        std::cout << tensor.bias_at(co) << " ";
    }
    std::cout << "\n";

    std::cout << "QDQPARAMS: \n";
    auto qdq_params = tensor.get_qdq_params();
    // The 8-bit fields are widened to int so they are printed as numbers.
    std::cout << "OutShift: " << int(qdq_params.shift_out) << "\n";
    std::cout << "BiasShift: " << int(qdq_params.shift_bias) << "\n";
    std::cout << "IfmSign: " << int(qdq_params.ctrl.sign_A) << "\n";
    std::cout << "WgtSign: " << int(qdq_params.ctrl.sign_W) << "\n";
    std::cout << "OfmSign: " << int(qdq_params.ctrl.sign_O) << "\n";
    std::cout << "LReLUAlpha: " << qdq_params.lrelu_alpha << "\n";
    std::cout << "MaxValue: " << int(qdq_params.max_value) << "\n";
    // The three sign bits again, concatenated in A, W, O order.
    std::cout << "Ctrl: " << int(qdq_params.ctrl.sign_A) << int(qdq_params.ctrl.sign_W) << int(qdq_params.ctrl.sign_O) << "\n";
}

// Prints the contents of a ConvWgtTensor_qdq to std::cout for debugging.
//
// Output layout:
//   1. `name`, then the weights grouped by (Cis x Cos) sub-volume block. Inside
//      a block the input channels are walked in chunks of up to 64, and for each
//      (chunk, ky, kx, co) one line lists the weights of that chunk.
//   2. The C0, C1 and C2 coefficient of every output channel.
//   3. The runtime parameters returned by get_qdq_params().
//
// Parameters:
//   tensor - weight tensor to dump (taken by value, as in the original interface).
//   name   - heading printed before the dump.
template<typename Tw, typename Tc0, typename Tc1, typename Tc2>
inline void log_tensor(ConvWgtTensor_qdq<Tw, Tc0, Tc1, Tc2> tensor, std::string const& name)
{
    // Maximum number of input channels printed on one "Weights" line.
    constexpr int ci_chunk = 64;

    std::cout << name << ": \n";
    for (int bi = 0; bi < tensor.Ci; bi += tensor.Cis) { // Loop over Bi (blocks of Cis)
        int const ci_block_end = std::min(bi + tensor.Cis, tensor.Ci);
        for (int bc = 0; bc < tensor.Co; bc += tensor.Cos) { // Loop over Bc (blocks of Cos)
            int const co_block_end = std::min(bc + tensor.Cos, tensor.Co);
            std::cout << "Subvol Block [Ci: " << bi << "-" << ci_block_end - 1
                      << ", Co: " << bc << "-" << co_block_end - 1 << "]:\n";
            for (int cis = bi; cis < ci_block_end; cis += ci_chunk) { // Loop over Cis
                // Clamp the chunk to the end of the current block (not to Ci):
                // otherwise a block whose Cis is not a multiple of 64 would also
                // print channels that belong to, and are printed again by, later blocks.
                int const ci_chunk_end = std::min(cis + ci_chunk, ci_block_end);
                for (int ky = 0; ky < tensor.Ky; ++ky) { // Loop over Ky
                    for (int kx = 0; kx < tensor.Kx_padded; ++kx) { // Loop over Kx
                        for (int cos = bc; cos < co_block_end; ++cos) { // Loop over Cos
                            std::cout << "Weights [Ci: " << cis << "-" << ci_chunk_end - 1
                                      << ", Co: " << cos << ", Ky: " << ky << ", Kx: " << kx << "]: ";
                            for (int ci = cis; ci < ci_chunk_end; ++ci) { // Loop over the channels of this chunk
                                // Print as int so narrow weight types show up as numbers, not characters.
                                std::cout << static_cast<int>(tensor.wgt_at(cos, ci, ky, kx)) << " ";
                            }
                            std::cout << "\n";
                        }
                    }
                }
            }
            std::cout << "\n";
        }
    }

    // Label fixed from "Co_coeff": these are the C0 coefficients, matching the C1/C2 labels below.
    std::cout << "C0_coeff: \n";
    for (int co = 0; co < tensor.Co; ++co) {
        std::cout << tensor.c0_at(co) << " ";
    }
    std::cout << "\n C1_coeff: \n";
    for (int co = 0; co < tensor.Co; ++co) {
        std::cout << tensor.c1_at(co) << " ";
    }
    std::cout << "\n C2_coeff: \n";
    for (int co = 0; co < tensor.Co; ++co) {
        std::cout << tensor.c2_at(co) << " ";
    }

    std::cout << "\n QDQ Params: \n";
    auto qdq_params = tensor.get_qdq_params();
    std::cout << "Shift_out: " << qdq_params.shift_out << "\n";
    // The 8-bit sign fields are widened to int so they are printed as numbers.
    std::cout << "ifm_sign: " << int(qdq_params.ifm_sign) << "\n";
    std::cout << "wgt_sign: " << int(qdq_params.wgt_sign) << "\n";
    std::cout << "ofm_sign: " << int(qdq_params.ofm_sign) << "\n";
}

#endif // CONV_HPP
