#ifndef TENSOR_HPP
#define TENSOR_HPP

#include <assert.h>
#include <stdint.h>

#include <cmath>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <string>
#include <type_traits>
#include <vector>

// Returns the smallest multiple of m that is >= x (intended for x >= 0 and
// m > 0; other inputs follow plain integer-division truncation).
//
// Marked inline because this definition lives in a header: without it, every
// translation unit that includes the header emits its own external definition
// and the link fails with duplicate symbols.
inline int round_up_to_multiple(int x, int m)
{
    return ((x + m - 1) / m) * m;
}

// Plain aggregate describing one test case. Field order is significant for
// positional (brace) initialization and is kept exactly as declared here.
// NOTE(review): the per-field meanings below are inferred from the names and
// from how the fold_* helpers in this file read them — confirm with callers.
struct TestConfig {
    // Input tensor extents (channels, rows, columns) — read by the 7x7 fold helpers.
    int Ci;
    int Yi;
    int Xi;
    // Presumably sub-volume (tile) extents for input / output rows and columns.
    int Yis;
    int Xis;
    int Yos;
    int Xos;
    // Output tensor extents (channels, rows, columns).
    int Co;
    int Yo;
    int Xo;
    // Kernel extents (rows, columns).
    int Ky;
    int Kx;
    // Presumably strides in Y and X.
    int Sy;
    int Sx;
    // Presumably padding before (_b) and after (_a) in Y and X.
    int Py_b;
    int Px_b;
    int Py_a;
    int Px_a;
    // Presumably per-sub-volume channel counts (input, output).
    int Cis;
    int Cos;
    int Co_split;
    // Presumably padded channel counts (output, input).
    int Cop;
    int Cip;
    // Shift / activation parameters; semantics are defined by the kernels under test.
    int8_t shift_res;
    int8_t shift_tdm;
    int16_t lrelu_alpha;
    int16_t lrelu_shift;
    bool enable_matAdd;
    int elw_shift_ifm1;
    int elw_shift_ifm2;
    int elw_shift_ofm;
    int epsilon;
    int ParamSize;
    std::string fused_op;
};

// Non-owning view of a caller-supplied activation buffer with extents
// C x Y x X. Elements are laid out with C varying fastest, then X, then Y
// (see at()). The view neither allocates nor frees `data`.
template<typename T>
struct ActTensor
{
    int const C;
    int const Y;
    int const X;
    T* const data;

    // Wraps `data` as a C x Y x X tensor of T. at() addresses up to
    // C * Y * X elements, so the buffer must hold at least size(C, Y, X) bytes.
    ActTensor(int C, int Y, int X, void* data)
        : C(C)
        , Y(Y)
        , X(X)
        , data(static_cast<T*>(data))
    {}

    // Returns a reference to element (c, y, x). All three indices are
    // range-checked with assert(), including against negative values, which
    // would otherwise silently address memory before or inside the buffer.
    T& at(int c, int y, int x)
    {
        assert(0 <= c && c < C);
        assert(0 <= y && y < Y);
        assert(0 <= x && x < X);
        int idx = (y * X * C) + (x * C) + c;
        assert(idx < C * Y * X);
        return data[idx];
    }

    // Writes the optional `msg` and then every element to std::cout: one
    // text line per (c, y) pair, with a blank line after each channel.
    void print(char const* msg = nullptr)
    {
        if (msg != nullptr) {
            std::cout << msg;
        }
        for (int c = 0; c < C; ++c) {
            for (int y = 0; y < Y; ++y) {
                for (int x = 0; x < X; ++x) {
                    if (std::is_integral<T>::value) {
                        // Widen so 8-bit element types print as numbers, not characters.
                        std::cout << static_cast<int64_t>(at(c, y, x)) << " ";
                    } else {
                        std::cout << at(c, y, x) << " ";
                    }
                }
                std::cout << "\n";
            }
            std::cout << "\n";
        }
    }

    // Fills the tensor using rand(). Integral T receives values in
    // [min, max); other T receive min + (max - min) * u with u in [0, 1].
    // An empty range (max == min) fills with `min` instead of evaluating
    // rand() % 0, which is undefined behavior.
    void init_random(int64_t min = 0, int64_t max = 2)
    {
        int64_t const span = max - min;
        for (int c = 0; c < C; ++c) {
            for (int y = 0; y < Y; ++y) {
                for (int x = 0; x < X; ++x) {
                    if (std::is_integral<T>::value) {
                        at(c, y, x) = (span != 0) ? (rand() % span) + min : min;
                    } else {
                        at(c, y, x) = (span * (rand() / float(RAND_MAX))) + min;
                    }
                }
            }
        }
    }

    // Number of bytes needed to back a C x Y x X tensor of T.
    static int size(int C, int Y, int X)
    {
        return static_cast<int>(C * Y * X * sizeof(T));
    }
};

// Forward declaration that supplies the default template arguments.
template<typename T,typename Tb = uint8_t, int is_xint8 = 0, int is_a8w8 = 0>
struct ConvWgtTensor;


// Non-owning view of a caller-supplied buffer holding convolution weights
// tiled into sub-volumes of Cos output channels by Cis input channels.
// Sub-volumes are ordered Co-major, Ci-minor (see subv_ptr()). Each one
// starts with its weights, rounded up to subv_align_bytes, followed by
//   - is_xint8 != 0: a bias block of num_bias_subv * Cos entries of Tb;
//   - otherwise:     QDQ parameters in this order: c0[Cos] (int64_t), then
//                    c1, c2, shift_tdm, shift_res, zp_wgt (int32_t each).
// is_xint8 / is_a8w8 also select a 16-wide instead of 8-wide output-channel
// granule inside at().
template<typename Tw, typename Tb, int is_xint8, int is_a8w8>
struct ConvWgtTensor
{
    static int constexpr subv_align_bytes = 64;
    static int constexpr bias_subv_gran = 8;

    static int constexpr num_bias_subv = 4;
    static int constexpr subv_qdq_c1_size = sizeof(int32_t);
    static int constexpr subv_qdq_c2_size = sizeof(int32_t);
    static int constexpr subv_shift_tdm_size = sizeof(int32_t);
    static int constexpr subv_shift_res_size = sizeof(int32_t);
    static int constexpr subv_zp_wgt_size = sizeof(int32_t);

    int const Co;
    int  Ci;
    int  Ky;
    int  Kx;
    int const Cos;
    int const Cis;
    int const subv_qdq_c0_size;
    int const subv_qdq_size;
    int subv_wgt_size;
    int subv_size;
    int const subv_bias_size;
    char* const data;
    int const data_size;

    // Wraps `data`, which must provide at least size(Co, Ci, Ky, Kx, Cos, Cis)
    // bytes. The initializer list follows member declaration order, which is
    // the order C++ actually initializes members in.
    ConvWgtTensor(int Co, int Ci, int Ky, int Kx, int Cos, int Cis, void* data)
        : Co(Co)
        , Ci(Ci)
        , Ky(Ky)
        , Kx(Kx)
        , Cos(Cos)
        , Cis(Cis)
        , subv_qdq_c0_size(is_xint8 ? 0 : Cos * sizeof(int64_t))
        , subv_qdq_size(is_xint8 ? 0 : (subv_qdq_c0_size + subv_qdq_c1_size +
                                         subv_qdq_c2_size +
                                         subv_shift_tdm_size +
                                         subv_shift_res_size +
                                         subv_zp_wgt_size))
        , subv_wgt_size(round_up_to_multiple(Cos * Cis * Ky * Kx * sizeof(Tw), subv_align_bytes))
        , subv_size(0)  // set below; depends on subv_bias_size, which is declared later
        , subv_bias_size(round_up_to_multiple(Cos * sizeof(Tb) * num_bias_subv, subv_align_bytes))
        , data(static_cast<char*>(data))
        , data_size(size(Co, Ci, Ky, Kx, Cos, Cis))
    {
        assert(Co % 8 == 0);
        assert(Ci % 8 == 0);
        if(is_xint8) {
            assert(Co % Cos == 0);
            assert(Ci % Cis == 0);
            subv_size = (subv_wgt_size + subv_bias_size);
        } else {
            assert(Co >= Cos);
            assert(Ci >= Cis);
            subv_size = (round_up_to_multiple(subv_wgt_size + subv_qdq_size, subv_align_bytes));
        }
    }

    // Returns the start of the sub-volume that contains (co, ci).
    char* subv_ptr(int co, int ci)
    {
        //
        // Indexing equation determined by tiling, with the following order.
        // Read this list from right-to-left to determine inner-to-outermost
        // traversal order.
        //
        // Co Ci
        //
        int offset = subv_size * (
            ((co / Cos) * (Ci / Cis)) +
            (ci / Cis)
        );
        return data + offset;
    }

    // Returns a reference to weight (co, ci, ky, kx) inside its sub-volume.
    Tw& at(int co, int ci, int ky, int kx)
    {
        assert(co < Co);
        assert(ci < Ci);
        assert(ky < Ky);
        assert(kx < Kx);
        //
        // Indexing equation determined by kernel, with the following order.
        // Read this list from right-to-left to determine inner-to-outermost
        // traversal order.
        //
        // Co:Cos Ci:Cis Ky Kx Ci:8 Co:8
        //
        int Cgran = (is_xint8 || is_a8w8) ? 16 : 8;
        int subv_idx =
            (((co % Cos) / Cgran) * Cis * Ky * Kx * Cgran) +
            (((ci % Cis) / 8) * Ky * Kx * Cgran * 8) +
            (ky * Kx * Cgran * 8) +
            (kx * Cgran * 8) +
            ((ci % 8) * 8) +
            (co % 8);
        // With a 16-wide granule the upper 8 output channels form a second 8x8 tile.
        if(is_xint8 || is_a8w8) subv_idx += (((co % 16) / 8) * 8 * 8);
        Tw* ptr = reinterpret_cast<Tw*>(subv_ptr(co, ci));
        return ptr[subv_idx];
    }

    // Reads the bias of output channel `co` from the first bias replica of
    // the last Ci sub-volume.
    Tb bias_at(int co)
    {
        assert(co < Co);
        int i = co % Cos;
        int subv_idx =
            ((i / bias_subv_gran) * num_bias_subv * bias_subv_gran) +
            (i % bias_subv_gran);
        Tb* ptr = reinterpret_cast<Tb*>(subv_ptr(co, Ci - 1) + subv_wgt_size);
        assert(uint64_t(&ptr[subv_idx]) - uint64_t(data) + sizeof(*ptr) <= static_cast<uint64_t>(data_size));
        return ptr[subv_idx];
    }

    // Writes the bias block of every sub-volume. Each bias value is
    // replicated num_bias_subv times in granules of bias_subv_gran. Only the
    // last Ci sub-volume of each Co group (the one bias_at() reads) receives
    // the real values; all earlier ones are zero-filled.
    void set_bias(std::vector<Tb> const& bias_vec)
    {
        assert(static_cast<int>(bias_vec.size()) == Co);
        for (int co = 0; co < Co; co += Cos) {
            for (int ci = 0; ci < Ci; ci += Cis) {
                // The original compared against Cis, which is always true and
                // made the zero-fill branch unreachable.
                bool is_last_subv = (ci + Cis >= Ci);
                Tb* ptr = reinterpret_cast<Tb*>(subv_ptr(co, ci) + subv_wgt_size);
                for (int n = 0; n < num_bias_subv; ++n) {
                    for (int i = 0; i < Cos; ++i) {
                        int data_idx = co + i;
                        int subv_idx =
                            ((i / bias_subv_gran) * num_bias_subv * bias_subv_gran) +
                            (n * bias_subv_gran) +
                            (i % bias_subv_gran);
                        assert(data_idx < Co);
                        assert(uint64_t(&ptr[subv_idx]) - uint64_t(data) + sizeof(*ptr) <= static_cast<uint64_t>(data_size));
                        ptr[subv_idx] = (is_last_subv) ? bias_vec[data_idx] : 0;
                    }
                }
            }
        }
    }

    // Stores the per-output-channel QDQ coefficient c0 for `co` in every Ci
    // sub-volume belonging to that channel's Co group.
    void set_qdq_c0(int co, int64_t coeff0)
    {
        for (int ci = 0; ci < Ci; ci += Cis) {
            int64_t* qdq_c0 = reinterpret_cast<int64_t*>(
                subv_ptr(co, ci)
                + subv_wgt_size
            );
            qdq_c0[co % Cos] = coeff0;
        }
    }

    // Stores QDQ coefficient c1 in every sub-volume.
    void set_qdq_c1(int32_t coeff1)
    {
        for (int co = 0; co < Co; co += Cos) {
            for (int ci = 0; ci < Ci; ci += Cis) {
                int32_t* qdq_c1 = reinterpret_cast<int32_t*>(
                    subv_ptr(co, ci)
                    + subv_wgt_size
                    + subv_qdq_c0_size
                );
                *qdq_c1 = coeff1;
            }
        }
    }

    // Stores QDQ coefficient c2 in every sub-volume.
    void set_qdq_c2(int32_t coeff2)
    {
        for (int co = 0; co < Co; co += Cos) {
            for (int ci = 0; ci < Ci; ci += Cis) {
                int32_t* qdq_c2 = reinterpret_cast<int32_t*>(
                    subv_ptr(co, ci)
                    + subv_wgt_size
                    + subv_qdq_c0_size
                    + subv_qdq_c1_size
                );
                *qdq_c2 = coeff2;
            }
        }
    }

    // Stores the shift_tdm field in every sub-volume.
    void set_shift_tdm(int32_t shift)
    {
        for (int co = 0; co < Co; co += Cos) {
            for (int ci = 0; ci < Ci; ci += Cis) {
                int32_t* shift_tdm = reinterpret_cast<int32_t*>(
                    subv_ptr(co, ci)
                    + subv_wgt_size
                    + subv_qdq_c0_size
                    + subv_qdq_c1_size
                    + subv_qdq_c2_size
                );
                *shift_tdm = shift;
            }
        }
    }

    // Stores the shift_res field in every sub-volume.
    void set_shift_res(int32_t shift)
    {
        for (int co = 0; co < Co; co += Cos) {
            for (int ci = 0; ci < Ci; ci += Cis) {
                int32_t* shift_res = reinterpret_cast<int32_t*>(
                    subv_ptr(co, ci)
                    + subv_wgt_size
                    + subv_qdq_c0_size
                    + subv_qdq_c1_size
                    + subv_qdq_c2_size
                    + subv_shift_tdm_size
                );
                *shift_res = shift;
            }
        }
    }

    // Stores the weight zero-point field in every sub-volume.
    void set_zp_wgt(int32_t zp)
    {
        for (int co = 0; co < Co; co += Cos) {
            for (int ci = 0; ci < Ci; ci += Cis) {
                int32_t* zp_wgt = reinterpret_cast<int32_t*>(
                    subv_ptr(co, ci)
                    + subv_wgt_size
                    + subv_qdq_c0_size
                    + subv_qdq_c1_size
                    + subv_qdq_c2_size
                    + subv_shift_tdm_size
                    + subv_shift_res_size
                );
                *zp_wgt = zp;
            }
        }
    }

    // Writes the optional `msg`, then every Ky x Kx kernel labelled by its
    // (co, ci) pair, and - for is_xint8 - one line with all biases.
    void print(char const* msg = nullptr)
    {
        if (msg != nullptr) {
            std::cout << msg;
        }
        for (int co = 0; co < Co; ++co) {
            for (int ci = 0; ci < Ci; ++ci) {
                std::cout<< "cout: " << co  << ", cin: " << ci << std::endl;
                for (int ky = 0; ky < Ky; ++ky) {
                    for (int kx = 0; kx < Kx; ++kx) {
                        std::cout << static_cast<int64_t>(at(co, ci, ky, kx)) << " ";
                    }
                    std::cout << "\n";
                }
                std::cout << "\n";
            }
            std::cout << "\n";
        }
        std::cout << "\n";
        if(is_xint8) {
            for (int co = 0; co < Co; ++co) {
                std::cout << int64_t(bias_at(co)) << " ";
            }
            std::cout << "\n";
        }
    }

    // Fills all weights (and, for is_xint8, all biases) with rand() values in
    // [min, max). For is_xint8 the requested range is overridden with [-4, 4).
    // An empty range (max == min) fills with `min` rather than evaluating
    // rand() % 0, which is undefined behavior.
    void init_random(int64_t min = 0, int64_t max = 2)
    {
        if(is_xint8){
            min = -4;
            max = 4;
        }
        int64_t const span = max - min;
        for (int co = 0; co < Co; ++co) {
            for (int ci = 0; ci < Ci; ++ci) {
                for (int ky = 0; ky < Ky; ++ky) {
                    for (int kx = 0; kx < Kx; ++kx) {
                        at(co, ci, ky, kx) = (span != 0) ? (rand() % span) + min : min;
                    }
                }
            }
        }
        if(is_xint8) {
            std::vector<Tb> bias_vec(Co);
            for (int co = 0; co < Co; ++co) {
                bias_vec[co] = (span != 0) ? (rand() % span) + min : min;
            }
            set_bias(bias_vec);
        }
    }

    // Copies weights from a dense buffer laid out as [Co][Ci][Ky][Kx]
    // (Kx innermost) into the tiled layout.
    void copy_data(Tw* orig_wgt)
    {
        for (int co = 0; co < Co; ++co) {
            for (int ci = 0; ci < Ci; ++ci) {
                for (int ky = 0; ky < Ky; ++ky) {
                    for (int kx = 0; kx < Kx; ++kx) {
                        int orig_idx = co * Ci * Ky * Kx + ci * Ky * Kx + ky * Kx + kx;
                        at(co, ci, ky, kx) = orig_wgt[orig_idx];
                    }
                }
            }
        }
    }

    // Number of bytes needed to back a tensor with these dimensions; mirrors
    // the per-sub-volume size computed by the constructor.
    static int size(int Co, int Ci, int Ky, int Kx, int Cos, int Cis)
    {
        int num_subv = (Co / Cos) * (Ci / Cis);
        int subv_wgt_size = round_up_to_multiple(Cos * Cis * Ky * Kx * sizeof(Tw), subv_align_bytes);
        if(is_xint8){
            int subv_bias_size = round_up_to_multiple(Cos * sizeof(Tb) * num_bias_subv, subv_align_bytes);
            int subv_size = subv_wgt_size + subv_bias_size;
            return num_subv * subv_size;
        }
        int subv_qdq_c0_size = Cos * sizeof(int64_t);
        int subv_qdq_size = subv_qdq_c0_size + subv_qdq_c1_size +
                                         subv_qdq_c2_size +
                                         subv_shift_tdm_size +
                                         subv_shift_res_size +
                                         subv_zp_wgt_size;
        int subv_size = round_up_to_multiple(subv_wgt_size + subv_qdq_size, subv_align_bytes);
        return num_subv * subv_size;
    }

};

// Primary template: declared only. The uint8_t specialization follows.
template<typename T, int Cs>
struct DwcWgtTensor;

// Non-owning view of a caller-supplied buffer holding depthwise-convolution
// weights tiled into sub-volumes of Cs channels. Each sub-volume holds its
// weights (rounded up to subv_align_bytes) followed by QDQ parameters in this
// order: c0[Cs] (int64_t), c1, c2, shift_tdm, shift_res, zp_wgt (int32_t each).
template<int Cs>
struct DwcWgtTensor<uint8_t, Cs>
{
    using Tw = uint8_t;

    static int constexpr subv_align_bytes = 64;

    // Byte sizes of the QDQ fields, listed in storage order.
    static int constexpr subv_qdq_c0_size = Cs * sizeof(int64_t);
    static int constexpr subv_qdq_c1_size = sizeof(int32_t);
    static int constexpr subv_qdq_c2_size = sizeof(int32_t);
    static int constexpr subv_shift_tdm_size = sizeof(int32_t);
    static int constexpr subv_shift_res_size = sizeof(int32_t);
    static int constexpr subv_zp_wgt_size = sizeof(int32_t);
    static int constexpr subv_qdq_size =
        subv_qdq_c0_size + subv_qdq_c1_size + subv_qdq_c2_size +
        subv_shift_tdm_size + subv_shift_res_size + subv_zp_wgt_size;

    int const C;
    int const Ky;
    int const Kx;
    int const subv_wgt_size;
    int const subv_size;
    char* const data;

    // Wraps `data`, which must provide at least size(C, Ky, Kx) bytes.
    DwcWgtTensor(int C, int Ky, int Kx, void* data)
        : C(C)
        , Ky(Ky)
        , Kx(Kx)
        , subv_wgt_size(round_up_to_multiple(Cs * Ky * Kx * sizeof(Tw), subv_align_bytes))
        , subv_size(round_up_to_multiple(subv_wgt_size + subv_qdq_size, subv_align_bytes))
        , data(static_cast<char*>(data))
    {
        assert(C >= Cs);
    }

    // Returns the start of the sub-volume that contains channel c.
    char* subv_ptr(int c)
    {
        return data + (subv_size * (c / Cs));
    }

    // Returns a reference to weight (c, ky, kx).
    // Kernel-defined order inside a sub-volume, outermost to innermost:
    //   C:Cs Ky Kx C:8
    Tw& at(int c, int ky, int kx)
    {
        assert(c < C);
        assert(ky < Ky);
        assert(kx < Kx);
        int const group = (c % Cs) / 8;  // which block of 8 channels inside the sub-volume
        int const lane = c % 8;          // position inside that block
        int const subv_idx =
            (group * Ky * Kx * 8) +
            (ky * Kx * 8) +
            (kx * 8) +
            lane;
        return reinterpret_cast<Tw*>(subv_ptr(c))[subv_idx];
    }

    // Stores the per-channel QDQ coefficient c0 for channel c.
    void set_qdq_c0(int c, int64_t coeff0)
    {
        char* const field = subv_ptr(c) + subv_wgt_size;
        reinterpret_cast<int64_t*>(field)[c % Cs] = coeff0;
    }

    // Stores QDQ coefficient c1 in every sub-volume.
    void set_qdq_c1(int32_t coeff1)
    {
        int const field_offset = subv_wgt_size + subv_qdq_c0_size;
        for (int first = 0; first < C; first += Cs) {
            *reinterpret_cast<int32_t*>(subv_ptr(first) + field_offset) = coeff1;
        }
    }

    // Stores QDQ coefficient c2 in every sub-volume.
    void set_qdq_c2(int32_t coeff2)
    {
        int const field_offset = subv_wgt_size + subv_qdq_c0_size + subv_qdq_c1_size;
        for (int first = 0; first < C; first += Cs) {
            *reinterpret_cast<int32_t*>(subv_ptr(first) + field_offset) = coeff2;
        }
    }

    // Stores the shift_tdm field in every sub-volume.
    void set_shift_tdm(int32_t shift)
    {
        int const field_offset =
            subv_wgt_size + subv_qdq_c0_size + subv_qdq_c1_size + subv_qdq_c2_size;
        for (int first = 0; first < C; first += Cs) {
            *reinterpret_cast<int32_t*>(subv_ptr(first) + field_offset) = shift;
        }
    }

    // Stores the shift_res field in every sub-volume.
    void set_shift_res(int32_t shift)
    {
        int const field_offset =
            subv_wgt_size + subv_qdq_c0_size + subv_qdq_c1_size + subv_qdq_c2_size +
            subv_shift_tdm_size;
        for (int first = 0; first < C; first += Cs) {
            *reinterpret_cast<int32_t*>(subv_ptr(first) + field_offset) = shift;
        }
    }

    // Stores the weight zero-point field in every sub-volume.
    void set_zp_wgt(int32_t zp)
    {
        int const field_offset =
            subv_wgt_size + subv_qdq_c0_size + subv_qdq_c1_size + subv_qdq_c2_size +
            subv_shift_tdm_size + subv_shift_res_size;
        for (int first = 0; first < C; first += Cs) {
            *reinterpret_cast<int32_t*>(subv_ptr(first) + field_offset) = zp;
        }
    }

    // Writes the optional `msg` and then each channel's Ky x Kx kernel,
    // one kernel row per line and a blank line after each channel.
    void print(char const* msg = nullptr)
    {
        if (msg != nullptr) {
            std::cout << msg;
        }
        for (int ch = 0; ch < C; ++ch) {
            for (int row = 0; row < Ky; ++row) {
                for (int col = 0; col < Kx; ++col) {
                    std::cout << static_cast<int64_t>(at(ch, row, col)) << " ";
                }
                std::cout << "\n";
            }
            std::cout << "\n";
        }
    }

    // Fills every weight with a rand() value in [min, max).
    void init_random(int64_t min = 0, int64_t max = 8)
    {
        for (int ch = 0; ch < C; ++ch) {
            for (int row = 0; row < Ky; ++row) {
                for (int col = 0; col < Kx; ++col) {
                    at(ch, row, col) = (rand() % (max - min)) + min;
                }
            }
        }
    }

    // Number of bytes needed to back a tensor with these dimensions.
    static int size(int C, int Ky, int Kx)
    {
        int const wgt_bytes = round_up_to_multiple(Cs * Ky * Kx * sizeof(Tw), subv_align_bytes);
        int const subv_bytes = round_up_to_multiple(wgt_bytes + subv_qdq_size, subv_align_bytes);
        return (C / Cs) * subv_bytes;
    }
};

// Computes the Ci dimension of a folded weight matrix: Ci is first rounded
// up to a multiple of Ci_gran and then multiplied by fold_factor.
//
// Marked inline because it is defined in a header; a non-inline definition
// would be emitted by every including translation unit and clash at link time.
inline int fold_channel_in_dim(int Ci, int fold_factor, int Ci_gran)
{
    int Ci_p = ((Ci + Ci_gran - 1) / Ci_gran) * Ci_gran;
    int Ci_f = Ci_p * fold_factor;
    return Ci_f;
}

// Computes the Kx dimension of a folded weight matrix:
// ceil(Kx / fold_factor).
//
// Marked inline because it is defined in a header; a non-inline definition
// would be emitted by every including translation unit and clash at link time.
inline int fold_kernel_x_dim(int Kx, int fold_factor)
{
    int Kx_f = (Kx + fold_factor - 1) / fold_factor;
    return Kx_f;
}

// Computes the X extent of a folded input: the span covered by Xo outputs
// with the folded kernel width and folded stride, rounded up to a multiple
// of Xi_gran.
//
// Note: the folded stride is Sx / fold_factor with integer truncation; unlike
// fold_stride_x_dim() this function does not assert that Sx is divisible.
//
// Marked inline because it is defined in a header; a non-inline definition
// would be emitted by every including translation unit and clash at link time.
inline int fold_spatial_x_dim(int Xo, int Kx, int Sx, int fold_factor, int Xi_gran)
{
    int Kx_f = fold_kernel_x_dim(Kx, fold_factor);
    int Sx_f = Sx / fold_factor;
    int x = ((Xo - 1) * Sx_f) + Kx_f;
    return round_up_to_multiple(x, Xi_gran);
}

// Computes the Y extent of a folded input: Yi plus `pad` rows on each side.
//
// Marked inline because it is defined in a header; a non-inline definition
// would be emitted by every including translation unit and clash at link time.
inline int fold_spatial_y_dim(int Yi, int pad)
{
    int Yi_f = Yi + 2 * pad;
    return Yi_f;
}

// Computes the X stride after folding: Sx / fold_factor. Sx must be an
// exact multiple of fold_factor (checked with assert()).
//
// Marked inline because it is defined in a header; a non-inline definition
// would be emitted by every including translation unit and clash at link time.
inline int fold_stride_x_dim(int Sx, int fold_factor)
{
    assert(Sx % fold_factor == 0);
    int Sx_f = Sx / fold_factor;
    return Sx_f;
}

// Folds kernel columns (Kx) into the input-channel dimension.
//
// Source layout: dense [Co][Ci][Ky][Kx], Kx innermost.
// Destination:   wgt_fold, which must have been built with
//                  Co = Co
//                  Ci = round_up(Ci, Ci_gran) * fold_factor
//                  Ky = Ky
//                  Kx = ceil(Kx / fold_factor)
//
// Source columns are dealt out to the folds round-robin: column 0 goes to
// fold 0, column 1 to fold 1, ..., wrapping after fold_factor columns. Along
// the destination channel axis the folds are laid end to end (all of fold 0,
// then fold 1, ...), each occupying round_up(Ci, Ci_gran) channels.
//
// Every destination element with no source counterpart - channels added by
// rounding up to Ci_gran, or columns past the end of the kernel - is set to
// wgt_zp.
template<typename T, typename Tb, int is_xint8, int is_a8w8>
void fold_conv_wgt(
    T const* wgt_data,
    T wgt_zp,
    int Co,
    int Ci,
    int Ky,
    int Kx,
    int fold_factor,
    int Ci_gran,
    ConvWgtTensor<T, Tb, is_xint8, is_a8w8> wgt_fold)
{
    int const ci_per_fold = ((Ci + Ci_gran - 1) / Ci_gran) * Ci_gran;
    int const ci_folded = fold_channel_in_dim(Ci, fold_factor, Ci_gran);
    int const kx_folded = fold_kernel_x_dim(Kx, fold_factor);

    assert(wgt_fold.Co == Co);
    assert(wgt_fold.Ci == ci_folded);
    assert(wgt_fold.Ky == Ky);
    assert(wgt_fold.Kx == kx_folded);

    for (int co = 0; co < Co; ++co) {
        for (int fold = 0; fold < fold_factor; ++fold) {
            int const dst_ci_base = fold * ci_per_fold;
            for (int ci = 0; ci < ci_per_fold; ++ci) {
                bool const ci_has_source = (ci < Ci);
                for (int ky = 0; ky < Ky; ++ky) {
                    for (int kx_dst = 0; kx_dst < kx_folded; ++kx_dst) {
                        int const kx_src = (kx_dst * fold_factor) + fold;
                        // Anything that falls outside the source kernel is
                        // padding and takes the zero-point.
                        T value = wgt_zp;
                        if (ci_has_source && (kx_src < Kx)) {
                            int const src_idx =
                                (co * Ci * Ky * Kx) +
                                (ci * Ky * Kx) +
                                (ky * Kx) +
                                kx_src;
                            value = wgt_data[src_idx];
                        }
                        wgt_fold.at(co, dst_ci_base + ci, ky, kx_dst) = value;
                    }
                }
            }
        }
    }
}

// Folds a dense [Co][Ci][Ky][Kx] weight buffer into a dense
// [folded.Co][folded.Ci][folded.Ky][folded.Kx] buffer using a fixed column
// mapping. Folded channel c_fold is split into overlay = c_fold / orig.Ci
// and source channel c_fold % orig.Ci; the source column is then
//   folded column 0: overlay          (overlays 0..2 -> columns 0,1,2)
//   folded column 1: 3                (overlay 2 only)
//   folded column 2: overlay + 4      (overlays 0..2 -> columns 4,5,6)
// Every other destination element receives wgt_zp.
template<typename T>
void fold_conv_wgt_7x7(
    const T* wgt_data,    // [Co][Ci][Ky][Kx]
    T wgt_zp,
    TestConfig orig,
    TestConfig folded,
    T* wgt_fold            // [Co][Ci*fold_factor][Ky][Kx_new]
)
{
    // Start from an all-zero-point destination.
    int const dst_count = folded.Co * folded.Ci * folded.Ky * folded.Kx;
    for (int n = 0; n < dst_count; ++n) {
        wgt_fold[n] = wgt_zp;
    }

    for (int co = 0; co < orig.Co; ++co) {
        for (int ky = 0; ky < orig.Ky; ++ky) {
            for (int kx_dst = 0; kx_dst < folded.Kx; ++kx_dst) {
                for (int ci_dst = 0; ci_dst < folded.Ci; ++ci_dst) {
                    int const overlay = ci_dst / orig.Ci;
                    int const ci_src = ci_dst % orig.Ci;

                    // -1 marks "no source column": the element stays padding.
                    int kx_src = -1;
                    switch (kx_dst) {
                    case 0:
                        if (overlay < 3) kx_src = overlay;
                        break;
                    case 1:
                        if (overlay == 2) kx_src = 3;
                        break;
                    case 2:
                        if (overlay < 3) kx_src = overlay + 4;
                        break;
                    default:
                        break;
                    }

                    T value = wgt_zp;
                    if (0 <= kx_src && kx_src < orig.Kx) {
                        int const src_idx =
                            co*orig.Ci*orig.Ky*orig.Kx + ci_src*orig.Ky*orig.Kx + ky*orig.Kx + kx_src;
                        value = wgt_data[src_idx];
                    }
                    int const dst_idx =
                        co*folded.Ci*folded.Ky*folded.Kx + ci_dst*folded.Ky*folded.Kx + ky*folded.Kx + kx_dst;
                    wgt_fold[dst_idx] = value;
                }
            }
        }
    }
}

// Variant of the 7x7 weight fold whose source buffer is unpadded
// ([CO_nopad][CI_nopad][Ky][Kx]) while `orig` carries the padded dimensions.
// The destination is written through wgt_fold.at(co, ci, ky, kx); elements
// with no source counterpart are left untouched, so the caller must have
// pre-filled the destination.
//
// Column mapping (overlay = c_fold / orig.Ci):
//   folded column 0: overlay          (overlays 0..2 -> columns 0,1,2)
//   folded column 1: 3                (overlay 2 only)
//   folded column 2: overlay + 4      (overlays 0..2 -> columns 4,5,6)
template<typename T, typename Tfold>
void fold_conv_wgt_7x7_pad(
    const T* wgt_data,    // [Co][Ci][Ky][Kx]
    T wgt_zp,
    int CO_nopad,
    int CI_nopad,
    TestConfig orig,   // pad info
    TestConfig folded,
    Tfold wgt_fold            // [Co][Ci*fold_factor][Ky][Kx_new]
)
{
    for (int co = 0; co < CO_nopad; ++co) {
        for (int ky = 0; ky < orig.Ky; ++ky) {
            for (int kx_dst = 0; kx_dst < folded.Kx; ++kx_dst) {
                for (int ci_dst = 0; ci_dst < folded.Ci; ++ci_dst) {
                    int const overlay = ci_dst / orig.Ci;
                    int const ci_src = ci_dst % orig.Ci;

                    // -1 marks "no source column".
                    int kx_src = -1;
                    switch (kx_dst) {
                    case 0:
                        if (overlay < 3) kx_src = overlay;
                        break;
                    case 1:
                        if (overlay == 2) kx_src = 3;
                        break;
                    case 2:
                        if (overlay < 3) kx_src = overlay + 4;
                        break;
                    default:
                        break;
                    }

                    // Skip padding columns and channels that exist only in
                    // the padded configuration.
                    if (kx_src < 0 || kx_src >= orig.Kx || ci_src >= CI_nopad) {
                        continue;
                    }

                    int const src_idx =
                        co * CI_nopad * orig.Ky * orig.Kx + ci_src * orig.Ky * orig.Kx + ky * orig.Kx + kx_src;
                    wgt_fold.at(co, ci_dst, ky, kx_dst) = wgt_data[src_idx];
                }
            }
        }
    }
}


// This function will fold pixels from X dimension into
// the C dimension, and round up the X/C dimensions according
// to the provided granularities. The same padding is applied
// in all dimensions, with additional trailing padding in the X and C
// dimensions to fill gaps where dimensions are rounded up to
// meet the granularity requirement.
//
// The input data is assumed to be formatted as Ci Yi Xi where
// the innermost traversal is read from right-to-left.
template<typename T>
void fold_conv_ifm(
    T const* ifm_data,
    T ifm_zp,
    int Ci,
    int Yi,
    int Xi,
    int Xo,
    int Kx,
    int Sx,
    int pad,
    int fold_factor,
    int Ci_gran,
    int Xi_gran,
    ActTensor<T> ifm_fold)
{
    int Ci_fold = fold_channel_in_dim(Ci, fold_factor, Ci_gran);
    int Yi_fold = fold_spatial_y_dim(Yi, pad);
    int Xi_fold = fold_spatial_x_dim(Xo, Kx, Sx, fold_factor, Xi_gran);

    assert(ifm_fold.C == Ci_fold);
    assert(ifm_fold.Y == Yi_fold);
    assert(ifm_fold.X == Xi_fold);

    // NOTE: Since ActTensor is YXC ordered, folding pixels from the
    // X dimension into the C dimension requires no additional formatting.
    // We just need to insert the padding values in the YXC dimensions.

    int Ci_pad = Ci_fold / fold_factor;
    int Yi_pad = Yi_fold;
    int Xi_pad = Xi_fold * fold_factor;
    ActTensor<T> ifm_reshape(Ci_pad, Yi_pad, Xi_pad, ifm_fold.data);
    for (int c = 0; c < Ci_pad; ++c) {
        for (int y = 0; y < Yi_pad; ++y) {
            for (int x = 0; x < Xi_pad; ++x) {
                if (c < Ci &&
                    pad <= y && y < Yi + pad &&
                    pad <= x && x < Xi + pad)
                {
                    int idx =
                        (c * Yi * Xi) +
                        ((y - pad) * Xi) +
                        ((x - pad));
                    assert(idx < Ci * Yi * Xi);
                    ifm_reshape.at(c, y, x) = ifm_data[idx];
                } else {
                    ifm_reshape.at(c, y, x) = ifm_zp;
                }
            }
        }
    }
}

// Folds a dense [Yi][Xi][Ci] input (Ci innermost) into ifm_fold for the 7x7
// case. ifm_fold must have extents (folded.Ci, folded.Yi, folded.Xi).
//
// Each folded column x_new gathers three source columns, stored as three
// consecutive bands of orig.Ci channels:
//   even x_new: bands 0 and 1 are padding, band 2 = column 2 * x_new
//   odd  x_new: bands 0..2 = columns 2*x_new - 1, 2*x_new, 2*x_new + 1
// Rows map one-to-one (no Y padding is inserted here). Every element without
// a valid source position holds ifm_zp.
template<typename T>
void fold_conv_ifm_7x7(
    T const* ifm_data,   // [Yi][Xi][Ci] (HWC)
    T ifm_zp,
    TestConfig orig,
    TestConfig folded,
    ActTensor<T> ifm_fold // [Ci*3][Yi+2*pad][Xi_fold]
)
{
    int const c_total = folded.Ci;
    int const y_total = folded.Yi;
    int const x_total = folded.Xi;

    assert(ifm_fold.C == c_total);
    assert(ifm_fold.Y == y_total);
    assert(ifm_fold.X == x_total);

    // Pre-fill the whole destination with the zero-point.
    for (int c = 0; c < c_total; ++c) {
        for (int y = 0; y < y_total; ++y) {
            for (int x = 0; x < x_total; ++x) {
                ifm_fold.at(c, y, x) = ifm_zp;
            }
        }
    }

    // Overwrite every element that has a source pixel.
    for (int y = 0; y < y_total; ++y) {
        if (y < 0 || y >= orig.Yi) {
            continue;  // row lies outside the source image
        }
        int const src_row_base = y * orig.Xi * orig.Ci;

        for (int x_new = 0; x_new < x_total; ++x_new) {
            int const centre = 2 * x_new;
            bool const is_odd = (x_new & 1) != 0;
            int const src_x[3] = {
                is_odd ? centre - 1 : -1,
                is_odd ? centre     : -1,
                is_odd ? centre + 1 : centre
            };

            for (int band = 0; band < 3; ++band) {
                int const x_old = src_x[band];
                if (x_old < 0 || x_old >= orig.Xi) {
                    continue;  // padding band
                }
                for (int c = 0; c < orig.Ci; ++c) {
                    int const src_idx = src_row_base + (x_old * orig.Ci) + c;
                    ifm_fold.at((band * orig.Ci) + c, y, x_new) = ifm_data[src_idx];
                }
            }
        }
    }
}

// Variant of the 7x7 input fold that reads an unpadded source
// ([Yi][Xi][CI_nopad], channels innermost) and writes a raw destination
// buffer laid out as [folded.Yi][folded.Xi][folded.Ci], channels innermost.
//
// Each folded column x_new gathers three source columns, stored as three
// consecutive bands that start orig.Ci channels apart:
//   even x_new: bands 0 and 1 are padding, band 2 = column 2 * x_new
//   odd  x_new: bands 0..2 = columns 2*x_new - 1, 2*x_new, 2*x_new + 1
// Only the first CI_nopad channels of each band are copied; everything else
// keeps the ifm_zp it is pre-filled with.
template<typename T>
void fold_conv_ifm_7x7_pad(
    const T* ifm_data,   // [Yi][Xi][Ci] (HWC)
    T ifm_zp,
    int CI_nopad,
    const TestConfig& orig,
    const TestConfig& folded,
    T* ifm_fold
)
{
    int const c_total = folded.Ci;
    int const y_total = folded.Yi;
    int const x_total = folded.Xi;

    // Pre-fill the whole destination with the zero-point.
    int const dst_count = c_total * y_total * x_total;
    for (int n = 0; n < dst_count; ++n) {
        ifm_fold[n] = ifm_zp;
    }

    for (int y = 0; y < y_total; ++y) {
        if (y < 0 || y >= orig.Yi) {
            continue;  // row lies outside the source image
        }

        for (int x_new = 0; x_new < x_total; ++x_new) {
            int const centre = 2 * x_new;
            bool const is_odd = (x_new & 1) != 0;
            int const src_x[3] = {
                is_odd ? centre - 1 : -1,
                is_odd ? centre     : -1,
                is_odd ? centre + 1 : centre
            };

            for (int band = 0; band < 3; ++band) {
                int const x_old = src_x[band];
                if (x_old < 0 || x_old >= orig.Xi) {
                    continue;  // padding band
                }

                int const src_base =
                    (y * orig.Xi * CI_nopad) +
                    (x_old * CI_nopad);
                int const dst_base =
                    (y * x_total * c_total) +
                    (x_new * c_total) +
                    (band * orig.Ci);

                const T* const src = ifm_data + src_base;
                T* const dst = ifm_fold + dst_base;

                // Element-wise forward copy of one pixel's channels.
                for (int c = 0; c < CI_nopad; ++c) {
                    dst[c] = src[c];
                }
            }
        }
    }
}

/**
 * Shift-round-saturate: scales a 64-bit accumulator down by `shift` bits and
 * clamps the result to the representable range of the output type T.
 *
 * @tparam T        Output element type. int8_t/uint8_t/int16_t/uint16_t/
 *                  int32_t/uint32_t are saturated to their own range; any
 *                  other type receives the shifted value converted as-is.
 * @param acc       Accumulator to scale.
 * @param shift     Number of bits to shift right. With `is_xint8` a value
 *                  <= 0 leaves the accumulator unscaled; without it the value
 *                  is fed straight to `>>` and therefore must be in [0, 63].
 * @param is_xint8  When true, round half up (add 2^(shift-1) before the
 *                  shift). When false, shift without rounding.
 * @return The scaled and saturated value.
 */
template<typename T>
T srs(int64_t acc, int shift, bool is_xint8 = false)
{
    if (is_xint8) {
        if (shift > 0) {
            // Build the rounding constant in 64 bits: a plain `1 << (shift - 1)`
            // is a 32-bit int, which turns negative at shift == 32 and is
            // undefined for larger shifts.
            const int64_t half = static_cast<int64_t>(1) << (shift - 1);
            acc = (acc + half) >> shift;
        }
    } else {
        acc = acc >> shift;
    }

    // Select the saturation bounds for T; types that are not listed keep the
    // full int64_t range, i.e. they are not clamped.
    int64_t lo = INT64_MIN;
    int64_t hi = INT64_MAX;
    if (std::is_same<T, int8_t>::value) {
        lo = INT8_MIN;
        hi = INT8_MAX;
    } else if (std::is_same<T, uint8_t>::value) {
        lo = 0;
        hi = UINT8_MAX;
    } else if (std::is_same<T, int16_t>::value) {
        lo = INT16_MIN;
        hi = INT16_MAX;
    } else if (std::is_same<T, uint16_t>::value) {
        lo = 0;
        hi = UINT16_MAX;
    } else if (std::is_same<T, int32_t>::value) {
        lo = INT32_MIN;
        hi = INT32_MAX;
    } else if (std::is_same<T, uint32_t>::value) {
        lo = 0;
        hi = UINT32_MAX;
    }

    const int64_t clamped = (acc > hi) ? hi :
                            (acc < lo) ? lo :
                                         acc;
    return static_cast<T>(clamped);
}

/**
 * Reference element-wise addition of two activation tensors.
 *
 * For every (c, y, x) of `ofm`, each input element is scaled by
 * 2^shiftin, the two are summed in 64 bits, the sum is clamped to the int8
 * range, and the clamped value is shifted right by `shiftout` before being
 * stored.
 *
 * NOTE(review): the sum is saturated to [INT8_MIN, INT8_MAX] *before* the
 * output shift, and the int8 range is used regardless of Ta. This order is
 * kept as-is to preserve existing reference results — confirm it matches the
 * kernel being modelled.
 *
 * @param ifm1      First input; read at every coordinate of `ofm`.
 * @param ifm2      Second input; read at every coordinate of `ofm`.
 * @param ofm       Output tensor; its C/Y/X drive the iteration.
 * @param shiftin0  Left shift applied to `ifm1` elements (must be >= 0).
 * @param shiftin1  Left shift applied to `ifm2` elements (must be >= 0).
 * @param shiftout  Right shift applied to the clamped sum.
 */
template<typename Ta>
void cpu_add_2d(
    ActTensor<Ta> ifm1,
    ActTensor<Ta> ifm2,
    ActTensor<Ta> ofm,
    int shiftin0, int shiftin1,int shiftout)
{
    // Widen to 64 bits before scaling. Shifting the raw element would happen
    // at `int` width (overflow for large shifts) and left-shifting a negative
    // value is undefined before C++20; multiplying by 2^shift is equivalent
    // and well defined.
    const int64_t scale0 = static_cast<int64_t>(1) << shiftin0;
    const int64_t scale1 = static_cast<int64_t>(1) << shiftin1;

    for (int co = 0; co < ofm.C; ++co) {
        for (int yo = 0; yo < ofm.Y; ++yo) {
            for (int xo = 0; xo < ofm.X; ++xo) {
                const int64_t lhs = static_cast<int64_t>(ifm1.at(co, yo, xo)) * scale0;
                const int64_t rhs = static_cast<int64_t>(ifm2.at(co, yo, xo)) * scale1;

                int64_t sum = lhs + rhs;
                if (sum > INT8_MAX) {
                    sum = INT8_MAX;
                } else if (sum < INT8_MIN) {
                    sum = INT8_MIN;
                }

                ofm.at(co, yo, xo) = static_cast<int8_t>(sum >> shiftout);
            }
        }
    }
}

/**
 * Reference depthwise convolution: every channel of `ifm` is convolved with
 * its own Ky x Kx kernel from `wgt` and written to the same channel of `ofm`.
 *
 * Input positions that fall outside `ifm` contribute a zero activation. The
 * accumulated sum is converted to Ta through srs<Ta>(acc, shift), i.e. shifted
 * without rounding and saturated.
 *
 * @param ifm    Input activations; must have the same channel count as `ofm`.
 * @param wgt    Depthwise weights, indexed as (channel, ky, kx).
 * @param ofm    Output activations; its Y/X extents drive the iteration.
 * @param Sy,Sx  Vertical / horizontal stride.
 * @param Py_b,Px_b  Padding before the first row / column.
 * @param Py_a,Px_a  Padding after the last row / column; accepted but not
 *                   read here, since out-of-range reads are already zero.
 * @param shift  Right shift handed to srs().
 */
template<typename Ta, typename Tw, int Cs>
void cpu_conv_dw(
    ActTensor<Ta> ifm,
    DwcWgtTensor<Tw, Cs> wgt,
    ActTensor<Ta> ofm,
    int Sy, int Sx,
    int Py_b, int Px_b, int Py_a, int Px_a,
    int shift)
{
    assert(ifm.C == ofm.C);

    for (int ch = 0; ch < ifm.C; ++ch) {
        for (int out_y = 0; out_y < ofm.Y; ++out_y) {
            // Top row of the receptive field for this output row.
            const int origin_y = out_y * Sy - Py_b;

            for (int out_x = 0; out_x < ofm.X; ++out_x) {
                // Left column of the receptive field for this output column.
                const int origin_x = out_x * Sx - Px_b;

                int64_t sum = 0;
                for (int ky = 0; ky < wgt.Ky; ++ky) {
                    const int in_y = origin_y + ky;
                    const bool row_inside = (in_y >= 0) && (in_y < ifm.Y);

                    for (int kx = 0; kx < wgt.Kx; ++kx) {
                        const int in_x = origin_x + kx;
                        const bool col_inside = (in_x >= 0) && (in_x < ifm.X);

                        // Padding region reads as zero.
                        Ta sample = 0;
                        if (row_inside && col_inside) {
                            sample = ifm.at(ch, in_y, in_x);
                        }

                        Tw weight = wgt.at(ch, ky, kx);
                        sum += sample * weight;
                    }
                }

                ofm.at(ch, out_y, out_x) = srs<Ta>(sum, shift);
            }
        }
    }
}

/**
 * Compares two tensors element by element and reports mismatches on stdout.
 *
 * An element whose absolute difference exceeds `epsilon` is logged as an
 * ERROR and counted; a non-zero difference within `epsilon` is logged as a
 * WARNING only. A summary ("Error Count" plus a "DI: PASS/FAIL CxYxX" line)
 * is printed at the end.
 *
 * NOTE(review): the per-element messages print the index as [c, x, y], while
 * at() and the summary line use (c, y, x) order. The output format is kept
 * unchanged here — confirm which order is intended.
 *
 * @param expected  Reference tensor.
 * @param received  Tensor under test; must have the same C/Y/X as `expected`.
 * @param epsilon   Largest absolute difference that is not counted as an error.
 * @return Number of elements whose difference exceeds `epsilon`.
 */
template<typename T>
int check_result(
    ActTensor<T> expected,
    ActTensor<T> received,
    int epsilon = 0)
{
    assert(expected.C == received.C);
    assert(expected.Y == received.Y);
    assert(expected.X == received.X);

    int err_count = 0;
    for (int c = 0; c < expected.C; ++c) {
        for (int y = 0; y < expected.Y; ++y) {
            for (int x = 0; x < expected.X; ++x) {
                const T exp_val = expected.at(c, y, x);
                const T rcv_val = received.at(c, y, x);

                // Integral elements are widened to 64 bits before subtracting:
                // an `int` difference overflows for int32_t and wraps for
                // uint32_t, which can make a large mismatch look tiny.
                // Non-integral elements keep subtract-then-truncate semantics.
                int64_t diff = std::is_integral<T>::value
                    ? (static_cast<int64_t>(exp_val) - static_cast<int64_t>(rcv_val))
                    : static_cast<int64_t>(exp_val - rcv_val);
                if (diff < 0) {
                    diff = -diff;
                }

                const bool fail = (diff > epsilon);
                const bool warn = (diff > 0);
                if (fail) {
                    err_count += 1;
                }
                if (fail || warn) {
                    std::cout << (fail ? "ERROR: [" : "WARNING: [")
                              << c << ", " << x << ", " << y << "]: "
                              << "Expected: " << static_cast<int>(exp_val) << ", "
                              << "Received: " << static_cast<int>(rcv_val) << "\n";
                }
                // Matching elements are intentionally not logged.
            }
        }
    }

    std::cout << "Error Count = " << err_count << "\n";
    std::cout << ((err_count > 0) ? "DI: FAIL " : "DI: PASS ")
              << expected.C << "x" << expected.Y << "x" << expected.X << "\n";

    return err_count;
}

// Two-level stringification helpers.
// Q(x) turns its argument into a string literal exactly as written.
// QUOTE(x) forwards through Q, so a macro passed as x is expanded first and
// the expansion (not the macro's name) ends up in the string.
#define Q(x) #x
#define QUOTE(x) Q(x)


/**
 * Reference 2-D convolution.
 *
 * For every output element, accumulates ifm * wgt over all input channels and
 * the Ky x Kx window in 64 bits; input positions outside `ifm` read as zero.
 * Strides and leading padding come from `cfg` (Sy, Sx, Py_b, Px_b).
 *
 * Post-processing depends on the `is_xint8` template argument:
 *  - non-zero: the bias for the output channel is added; a positive sum goes
 *    through rounding srs with cfg.shift_res, any other sum is first
 *    multiplied by cfg.lrelu_alpha and then shifted (with rounding) by
 *    cfg.shift_res + cfg.lrelu_shift.
 *  - zero: the sum goes through plain srs with cfg.shift_res; no bias is
 *    applied.
 *
 * @param ifm  Input activations.
 * @param wgt  Weights indexed as (co, ci, ky, kx), with a per-output-channel
 *             bias available via bias_at(co).
 * @param ofm  Output activations; its C/Y/X extents drive the iteration.
 * @param cfg  Test configuration supplying strides, padding and shifts.
 */
template<typename Ta, typename Tw, typename Tb, int is_xint8, int is_a8w8>
void cpu_conv_2d(
    ActTensor<Ta> ifm,
    ConvWgtTensor<Tw, Tb, is_xint8, is_a8w8> wgt,
    ActTensor<Ta> ofm,
    TestConfig cfg)
{
    for (int oc = 0; oc < ofm.C; ++oc) {
        for (int out_y = 0; out_y < ofm.Y; ++out_y) {
            // Top row of the receptive field for this output row.
            const int origin_y = out_y * cfg.Sy - cfg.Py_b;

            for (int out_x = 0; out_x < ofm.X; ++out_x) {
                // Left column of the receptive field for this output column.
                const int origin_x = out_x * cfg.Sx - cfg.Px_b;

                int64_t sum = 0;
                for (int ic = 0; ic < ifm.C; ++ic) {
                    for (int ky = 0; ky < wgt.Ky; ++ky) {
                        const int in_y = origin_y + ky;
                        const bool row_inside = (in_y >= 0) && (in_y < ifm.Y);

                        for (int kx = 0; kx < wgt.Kx; ++kx) {
                            const int in_x = origin_x + kx;
                            const bool col_inside = (in_x >= 0) && (in_x < ifm.X);

                            // Padding region reads as zero.
                            Ta sample = 0;
                            if (row_inside && col_inside) {
                                sample = ifm.at(ic, in_y, in_x);
                            }

                            Tw weight = wgt.at(oc, ic, ky, kx);
                            sum += sample * weight;
                        }
                    }
                }

                if (!is_xint8) {
                    // Plain path: shift and saturate, no bias, no rounding.
                    ofm.at(oc, out_y, out_x) = srs<Ta>(sum, cfg.shift_res);
                    continue;
                }

                sum += wgt.bias_at(oc);
                if (sum > 0) {
                    ofm.at(oc, out_y, out_x) = srs<Ta>(sum, cfg.shift_res, true);
                } else {
                    // Non-positive sums are scaled by lrelu_alpha and take the
                    // extra lrelu_shift on top of shift_res.
                    ofm.at(oc, out_y, out_x) =
                        srs<Ta>(sum * cfg.lrelu_alpha,
                                (cfg.shift_res + cfg.lrelu_shift),
                                true);
                }
            }
        }
    }
}

/**
 * Derives the configuration of a "folded" convolution from the original one.
 *
 * This is NOT a generic mapping: each `fold_case_id` hard-codes the values
 * for one specific case.
 *
 *  - case 0: [Ky, Kx] = [7, 7] -> [7, 3]. Ci is multiplied by 3, Xi is
 *    divided by 2, and Kx / Sx / Px_b / Px_a are overwritten with 3 / 2 / 1 / 0.
 *    All other fields (including every Y-direction field) are copied from
 *    `orig` unchanged.
 *  - any other id: `orig` is returned unchanged.
 *
 * Declared `inline` because this is a non-template function defined in a
 * header; without it, including the header from more than one translation
 * unit produces multiple-definition link errors.
 *
 * @param fold_case_id  Selects the hard-coded folding case.
 * @param orig          Configuration of the unfolded convolution.
 * @return The configuration to use for the folded convolution.
 */
inline TestConfig kernel_fold_mapping(int fold_case_id, const TestConfig& orig)
{
    // Granularities the folded Ci / Xi are asserted against.
    constexpr int Ci_gran = 8;
    constexpr int Xi_gran = 8;
    // Only read inside assert(); keep NDEBUG builds free of unused warnings.
    (void)Ci_gran;
    (void)Xi_gran;

    TestConfig cfg = orig;

    switch (fold_case_id) {

        case 0: {
            // [Ky, Kx] = [7, 7] -> [7, 3]
            constexpr int fold_factor_Ci = 3;
            constexpr int fold_factor_Xi = 2;
            cfg.Ci = orig.Ci * fold_factor_Ci;
            cfg.Xi = orig.Xi / fold_factor_Xi;
            assert(cfg.Ci % Ci_gran == 0);
            assert(cfg.Xi % Xi_gran == 0);

            // Only the X direction is remapped; Ky, Sy, Py_b and Py_a keep
            // the values copied from `orig`.
            cfg.Kx = 3;
            cfg.Sx = 2;
            cfg.Px_b = 1;
            cfg.Px_a = 0;

            break;
        }
        default: {
            // Unknown case id: return the original configuration untouched.
            break;
        }
    }
    return cfg;
}


#endif // TENSOR_HPP
