#ifndef DWC_HPP
#define DWC_HPP

#include <random>
#include <vector>
#include "common.hpp"

/// Runtime QDQ parameters that DwcWgtTensor_qdq stores after the c2
/// coefficients of every weight subvolume.
struct DwcWgtTensor_qdq_RT_Params
{
    int32_t shift_out = 0;         // passed to the output quantization step as its shift argument
    int32_t zp_w = 0;              // weight zero-point used by the QDQ formula
    int32_t reserved[30] = {0};    // Padding so that the struct occupies exactly 128 bytes
};

// The subvolume layout adds sizeof(DwcWgtTensor_qdq_RT_Params) to regions that
// are rounded to 128 bytes, so lock the size in at compile time.
static_assert(sizeof(DwcWgtTensor_qdq_RT_Params) == 128,
              "DwcWgtTensor_qdq_RT_Params must be exactly 128 bytes");

template<typename Tw, typename Tc0, typename Tc2>
struct DwcWgtTensor_qdq
{
    static const int subv_align_bytes = 128;
    int const Co;
    int const Ky;
    int const Kx;
    int const Kx_padded;
    int const Ky_padded;
    int const Cos;
    int const vect_coeff;
    int const subv_wgt_size;
    int const subv_c0_size;
    int const subv_c2_size;
    int const subv_qdq_size;
    int const subv_size;
    int const data_size;
    Tw* const data;

    DwcWgtTensor_qdq(int Co, int Ky, int Kx, int Cos, int vect_coeff, void* data)
        : Co(Co)
        , Ky(Ky)
        , Kx(Kx)
        , Cos(Cos)
        , vect_coeff(vect_coeff)
        , data(static_cast<Tw*>(data))
        , Kx_padded((Kx < 4) ? 4 : Kx) // Pad Kx to the next multiple of 4
        , Ky_padded((Ky < 3) ? 3 : Ky) // Pad Kx to the next multiple of 4
        , subv_wgt_size(compute_subv_wgt_size(Ky, Kx_padded, Cos))
        , subv_c0_size(compute_subv_c0_size(Cos))
        , subv_c2_size(compute_subv_c2_size(Cos))
        , subv_qdq_size(compute_subv_qdq_size())
        , subv_size(compute_subv_size(Ky_padded, Kx_padded, Cos))
        , data_size(size(Co, Ky_padded, Kx_padded, Cos))
    {
        assert((Cos % 64 == 0) && "Co must be multiple of 64");
        assert((Co % Cos == 0) && "Co must be multiple of Cos");
        assert((Ky > 0) && "Ky must be positive");
        assert((Kx > 0) && "Kx must be positive");
    }

    char* subv_ptr(int co)
    {
        assert((0 <= co) && (co < Co));
        int offset = subv_size * (co / Cos);
        return (char *)data + offset; // NOLINT
    }

    Tw& wgt_at(int co, int ky, int kx)
    {
        assert((0 <= co) && (co < Co));
        assert((0 <= ky) && (ky < Ky_padded));
        assert((0 <= kx) && (kx < Kx_padded));
        int subv_idx = 0;
        subv_idx =
            (((co % Cos) / 32) * (Ky_padded * Kx_padded * 32)) +
            (ky * (Kx_padded * 32)) +
            (kx * 32) +
            (co % 32);
        auto ptr = reinterpret_cast<Tw*>(subv_ptr(co)); // NOLINT
        return ptr[subv_idx]; // NOLINT
    }

     Tc0& c0_at(int co)
    {
        assert((0 <= co) && (co < Co));
        int subv_idx = co % Cos;
        auto ptr = reinterpret_cast<Tc0*>(subv_ptr(co) + subv_wgt_size); // NOLINT
        return ptr[subv_idx]; // NOLINT
    }

    void set_c0_at(int co, Tc0 value)
    {
        assert((0 <= co) && (co < Co));
        int subv_idx = co % Cos;
        auto ptr = reinterpret_cast<Tc0*>(subv_ptr(co) + subv_wgt_size); // NOLINT
        ptr[subv_idx] = value; // NOLINT
    }


    Tc2& c2_at(int co)
    {
        assert((0 <= co) && (co < Co));
        int subv_idx = co % Cos;
        auto ptr = reinterpret_cast<Tc2*>(subv_ptr(co) + subv_wgt_size + subv_c0_size); // NOLINT
        return ptr[subv_idx]; // NOLINT
    }

    void set_c2_at(int co, Tc2 value)
    {
        assert((0 <= co) && (co < Co));
        int subv_idx = co % Cos;
        auto ptr = reinterpret_cast<Tc2*>(subv_ptr(co) + subv_wgt_size + subv_c0_size); // NOLINT
        ptr[subv_idx] = value; // NOLINT
    }

    void set_qdq_params(DwcWgtTensor_qdq_RT_Params qdq_params)
    {
        // NOTE: The qdq_params are loaded for every Co shard in the conv kernel along with bias.
        for(int co=0; co<Co; co++){
            auto qdq_params_ptr = reinterpret_cast<DwcWgtTensor_qdq_RT_Params*>(subv_ptr(co) + subv_wgt_size + subv_c0_size + subv_c2_size); // NOLINT
            *qdq_params_ptr = qdq_params;
        }
    }

    DwcWgtTensor_qdq_RT_Params get_qdq_params(){
        auto qdq_params_ptr = reinterpret_cast<DwcWgtTensor_qdq_RT_Params*>(subv_ptr(0) + subv_wgt_size + subv_c0_size + subv_c2_size); // NOLINT
        return *qdq_params_ptr;
    }

    static int compute_subv_wgt_size(int Ky, int Kx, int Cos)
    {
        return iceil(Cos * Ky * Kx * int(sizeof(Tw)), subv_align_bytes);
    }

    static int compute_subv_c0_size(int Cos)
    {
        return iceil(Cos * int(sizeof(Tc0)), subv_align_bytes);
    }


    static int compute_subv_c2_size(int Cos)
    {
        return iceil(Cos * int(sizeof(Tc2)), subv_align_bytes);
    }

    static int compute_subv_qdq_size()
    {
        return sizeof(DwcWgtTensor_qdq_RT_Params);
    }

    static int compute_subv_size(int Ky, int Kx, int Cos)
    {
        int subv_wgt_size = compute_subv_wgt_size(Ky, Kx, Cos);
        int subv_c0_size = compute_subv_c0_size(Cos);
        int subv_c2_size = compute_subv_c2_size(Cos);
        int subv_qdq_size = compute_subv_qdq_size();
        int subv_size = subv_wgt_size + 
                        subv_c0_size +
                        subv_c2_size +
                        subv_qdq_size; 
        return subv_size;
    }

    static int size(int Co, int Ky, int Kx, int Cos)
    {
        int num_subv = (Co / Cos);
        int subv_size = compute_subv_size(Ky, Kx, Cos);
        int total_size = num_subv * subv_size;
        return total_size;
    }

};


/**
 * Fill an activation tensor, a depthwise weight tensor and its runtime QDQ
 * parameters with test data drawn from rand().
 *
 * - Every activation gets an integer in [0, 16).
 * - Every weight inside the unpadded extent (co < Co_no_pad, ky < Ky_no_pad,
 *   kx < Kx_no_pad) gets an integer in [0, 16); all other weights are 0.
 * - c0 is set to 1.0 for every channel.
 * - c2 is random in [0, 1.828] for unpadded channels (0 for padded ones) when
 *   wgt.vect_coeff > 0, otherwise 1.0 for every channel.
 * - qdq_params.zp_w is set to an integer in [1, 4) and the parameters are
 *   then written into the weight tensor.
 *
 * The number and order of rand() calls is part of the behavior: callers that
 * seed the generator get reproducible data.
 *
 * @param ifm         activation tensor to fill
 * @param wgt         weight tensor to fill
 * @param qdq_params  runtime parameters; zp_w is overwritten, the rest is stored as given
 * @param Co_no_pad   number of real (unpadded) output channels
 * @param Ky_no_pad   real kernel height
 * @param Kx_no_pad   real kernel width
 */
template<typename Ta, typename Tw, typename Tc0, typename Tc2>
inline void init_random_dwc_qdq_a16w8(
    ActTensor<Ta>& ifm,
    DwcWgtTensor_qdq<Tw, Tc0, Tc2>& wgt,
    DwcWgtTensor_qdq_RT_Params& qdq_params,
    int const Co_no_pad,
    int const Ky_no_pad,
    int const Kx_no_pad
)
{
    // Half-open integer ranges [lo, hi).
    int64_t const act_lo = 0;
    int64_t const act_hi = 16;
    int64_t const w_lo = 0;
    int64_t const w_hi = 16;
    int32_t const zp_lo = 1;
    int32_t const zp_hi = 4;
    // Closed float range for the c2 coefficient.
    float const c2_lo = 0.0F;
    float const c2_hi = +1.828F;

    // Converted once at compile time so the draws below stay in float arithmetic.
    constexpr auto rand_max_as_float = static_cast<float>(RAND_MAX);

    auto const draw_int = [](int64_t lo, int64_t hi) -> int64_t {
        return (rand() % (hi - lo)) + lo;
    };
    auto const draw_c2 = [&]() -> float {
        return c2_lo + static_cast<float>(rand()) / rand_max_as_float * (c2_hi - c2_lo);
    };

    // Activations: one draw per element, visited row by row with channel innermost.
    for (int row = 0; row < ifm.Y; ++row) {
        for (int col = 0; col < ifm.X; ++col) {
            for (int ch = 0; ch < ifm.C; ++ch) {
                ifm.at(ch, row, col) = Ta(draw_int(act_lo, act_hi));
            }
        }
    }

    // Weights: only real taps consume a random number; padding is zero.
    for (int ky = 0; ky < wgt.Ky_padded; ++ky) {
        for (int kx = 0; kx < wgt.Kx_padded; ++kx) {
            bool const real_tap = (kx < Kx_no_pad) && (ky < Ky_no_pad);
            for (int co = 0; co < wgt.Co; ++co) {
                bool const real_weight = (co < Co_no_pad) && real_tap;
                int64_t const w = real_weight ? draw_int(w_lo, w_hi) : int64_t{0};
                wgt.wgt_at(co, ky, kx) = Tw(w);
            }
        }
    }

    // NOTE(review): c0 is pinned to 1.0 for every channel, padded ones
    // included, and no random number is drawn for it - confirm this is the
    // intended test configuration rather than a leftover debugging override.
    for (int co = 0; co < wgt.Co; ++co) {
        wgt.set_c0_at(co, 1.0F);
    }

    if (wgt.vect_coeff > 0) {
        // NOTE: This indicates per channel QDQ on C2
        for (int co = 0; co < wgt.Co; ++co) {
            float const c2 = (co < Co_no_pad) ? draw_c2() : static_cast<float>(0);
            wgt.set_c2_at(co, c2);
        }
    } else {
        // Scalar-c2 mode: one value is still drawn so the generator advances
        // exactly as before, but the coefficient written is the constant 1.0.
        (void)draw_c2();
        for (int co = 0; co < wgt.Co; ++co) {
            wgt.set_c2_at(co, 1.0F);
        }
    }

    qdq_params.zp_w = int32_t((rand() % (zp_hi - zp_lo)) + zp_lo);
    wgt.set_qdq_params(qdq_params);
}


/**
 * Print a weight tensor to std::cout: first the weights, grouped by subvolume
 * and by 32-channel chunk in (ky, kx) order, then the c0 and c2 coefficients
 * of every channel, then the runtime QDQ parameters.
 *
 * @param tensor  tensor view to print (taken by value; the view is copied)
 * @param name    label printed on the first line
 */
template<typename Tw, typename Tc0, typename Tc2>
inline void log_tensor(DwcWgtTensor_qdq<Tw, Tc0, Tc2> tensor, std::string const& name)
{
    constexpr int chunk_len = 32;  // channel granularity of the weight layout

    std::cout << name << ": \n";

    // One pass per subvolume (Cos channels each).
    for (int subv_first = 0; subv_first < tensor.Co; subv_first += tensor.Cos) {
        int const subv_end = std::min(subv_first + tensor.Cos, tensor.Co);
        std::cout << "Subvol Block [Co: " << subv_first << "-" << (subv_end - 1) << "]:\n";

        // Within a subvolume the order is: 32-channel chunk, ky, kx, channel in chunk.
        for (int chunk_first = subv_first; chunk_first < subv_end; chunk_first += chunk_len) {
            int const chunk_end = std::min(chunk_first + chunk_len, tensor.Co);
            std::cout << "Co_gran block [" << chunk_first << "-" << (chunk_end - 1) << "]:\n";

            for (int ky = 0; ky < tensor.Ky_padded; ++ky) {
                for (int kx = 0; kx < tensor.Kx_padded; ++kx) {
                    std::cout << "  Weights [Ky: " << ky << ", Kx: " << kx << "]: ";
                    for (int co = chunk_first; co < chunk_end; ++co) {
                        // Widen to int so narrow weight types print as numbers.
                        std::cout << static_cast<int>(tensor.wgt_at(co, ky, kx)) << " ";
                    }
                    std::cout << "\n";
                }
            }
        }
        std::cout << "\n";
    }

    // Per-channel coefficients.
    std::cout << "C0_coeff: \n";
    for (int co = 0; co < tensor.Co; ++co) {
        std::cout << tensor.c0_at(co) << " ";
    }
    std::cout << "\n\nC2_coeff: \n";
    for (int co = 0; co < tensor.Co; ++co) {
        std::cout << tensor.c2_at(co) << " ";
    }
    std::cout << "\n";

    // Runtime parameters as read back from the tensor.
    std::cout << "QDQ Params: \n";
    DwcWgtTensor_qdq_RT_Params const rt = tensor.get_qdq_params();
    std::cout << "  shift_out: " << rt.shift_out << "\n";
    std::cout << "  zp_w: " << rt.zp_w << "\n";
}


/**
 * CPU reference depthwise convolution: each output channel is the sum of
 * activation * weight products over the Ky x Kx window of the same input
 * channel. Window positions that fall outside the input contribute zero.
 * Only the raw accumulation is produced here; no QDQ coefficients are applied.
 *
 * @param ifm  input activations
 * @param wgt  weights; the unpadded wgt.Ky x wgt.Kx extent is used
 * @param out  receives the accumulator for every (channel, y, x) of its extent
 * @param Sy   vertical stride
 * @param Sx   horizontal stride
 * @param Py   rows of zero padding above the input
 * @param Px   columns of zero padding left of the input
 */
template<typename Ta, typename Tw, typename Tc0, typename Tc2, typename Tacc>
inline void cpu_dwc(
    ActTensor<Ta>& ifm,
    DwcWgtTensor_qdq<Tw, Tc0, Tc2>& wgt,
    ActTensor<Tacc>& out,
    int const Sy, int const Sx,
    int const Py, int const Px
) {
    for (int ch = 0; ch < out.C; ++ch) {
        for (int oy = 0; oy < out.Y; ++oy) {
            // Top row of the input window for this output row.
            int const win_y0 = (oy * Sy) - Py;
            for (int ox = 0; ox < out.X; ++ox) {
                // Left column of the input window for this output column.
                int const win_x0 = (ox * Sx) - Px;

                Tacc acc = 0;
                for (int ky = 0; ky < wgt.Ky; ++ky) {
                    int const in_y = win_y0 + ky;
                    bool const row_inside = (0 <= in_y) && (in_y < ifm.Y);
                    for (int kx = 0; kx < wgt.Kx; ++kx) {
                        int const in_x = win_x0 + kx;
                        bool const inside = row_inside && (0 <= in_x) && (in_x < ifm.X);
                        // Out-of-range taps read as zero (implicit zero padding).
                        Ta a = inside ? ifm.at(ch, in_y, in_x) : 0;
                        Tw w = wgt.wgt_at(ch, ky, kx);
                        acc += a * w;
                    }
                }
                out.at(ch, oy, ox) = acc;
            }
        }
    }
}


/**
 * CPU reference for the QDQ stage that follows the depthwise convolution.
 *
 * The activation sum is taken per channel over the kernel window (it is not
 * accumulated across channels), and the weight zero-point plays the role of
 * the C1 term. Since C1 = -(scale * zp_w), that term is subtracted:
 *
 *     res = (conv_out - ifm_sum * zp_w) * C2 + C0
 *
 * `res` is then quantized with quantize_float_to_int16<To>(res, shift_out,
 * ofm_sign) and stored in `out`.
 *
 * NOTE(review): the tensors are taken by value while `out` is written to, so
 * ActTensor is assumed to be a non-owning view over caller memory - confirm
 * against common.hpp.
 *
 * @param act         input activations
 * @param conv_out    raw convolution accumulators (one per output element)
 * @param wgt         weight tensor supplying c0, c2 and the runtime parameters
 * @param out         destination for the quantized results
 * @param Ky, Kx      kernel window size used for the activation sum
 * @param Sy, Sx      strides
 * @param Py, Px      zero padding above / left of the input
 * @param debug_mode  non-zero enables per-element trace output
 * @param ofm_sign    forwarded to quantize_float_to_int16
 */
template<typename Ta, typename Tw, typename Tc0, typename Tc2, typename To, typename Tacc>
inline void cpu_3term_qdq(
    ActTensor<Ta> act,
    ActTensor<Tacc> conv_out,
    DwcWgtTensor_qdq<Tw, Tc0, Tc2> wgt,
    ActTensor<To> out,
    int Ky, int Kx,
    int Sy, int Sx,
    int Py, int Px,
    int debug_mode,
    int ofm_sign
)
{
    // Scratch storage for the per-window activation sums. The vector owns the
    // memory, so it is released on every exit path (the previous malloc'd
    // buffer was never freed).
    std::vector<Tacc> ifm_sum_storage(
        static_cast<std::size_t>(out.C) *
        static_cast<std::size_t>(out.Y) *
        static_cast<std::size_t>(out.X));
    ActTensor<Tacc> ifm_sum(out.C, out.Y, out.X, ifm_sum_storage.data());

    // Pass 1: sum of the activations under each kernel window, per channel.
    for (int co = 0; co < out.C; ++co) {
        for (int yi = -Py, yo = 0; yo < out.Y; yi += Sy, ++yo) {
            for (int xi = -Px, xo = 0; xo < out.X; xi += Sx, ++xo) {
                Tacc sum = 0;
                for (int ky = 0; ky < Ky; ++ky) {
                    for (int kx = 0; kx < Kx; ++kx) {
                        int const y = yi + ky;
                        int const x = xi + kx;
                        // Taps outside the input contribute zero.
                        sum += (0 <= y && y < act.Y &&
                                0 <= x && x < act.X) ? act.at(co, y, x) : 0;
                    }
                }
                ifm_sum.at(co, yo, xo) = sum;
            }
        }
    }

    // Pass 2: apply the three-term formula and quantize.
    DwcWgtTensor_qdq_RT_Params const qdq_params = wgt.get_qdq_params();
    for (int y = 0; y < out.Y; ++y) {
        for (int x = 0; x < out.X; ++x) {
            for (int c = 0; c < out.C; ++c) {
                float const conv_val = static_cast<float>(conv_out.at(c, y, x));
                Tacc const ifm_sum_val = ifm_sum.at(c, y, x);
                if (debug_mode) {
                    std::cout << "conv_val = " << conv_val << std::endl;
                    // Trace output belongs behind debug_mode like the other prints.
                    std::cout << "ifm_sum_val = " << ifm_sum_val << std::endl;
                }

                float const res = ((conv_val - (ifm_sum_val * qdq_params.zp_w)) * wgt.c2_at(c)) + wgt.c0_at(c);
                if (debug_mode) {
                    // Explicit conversions keep the varargs in step with the format specifiers.
                    printf("res = %f, C2 = %f, zp_weights = %d, c0 = %f\n",
                           static_cast<double>(res),
                           static_cast<double>(wgt.c2_at(c)),
                           static_cast<int>(qdq_params.zp_w),
                           static_cast<double>(wgt.c0_at(c)));
                }

                // Quantize the result to the output type
                out.at(c, y, x) = quantize_float_to_int16<To>(res, qdq_params.shift_out, ofm_sign);
                if (debug_mode) {
                    printf("out.at(%d, %d, %d) = %d\n", c, y, x, static_cast<int>(out.at(c, y, x)));
                }
            }
        }
    }
}

#endif // DWC_HPP