#ifndef MAXPOOL_HPP
#define MAXPOOL_HPP
#include "common.hpp"
#include <limits>

/// Parameter block for the "noqdq" A8 max-pool path.
///
/// Padded to a fixed 128-byte footprint; the static_assert below keeps that
/// size from drifting if fields are added or changed.
struct MaxPool_NoQdqParams_A8 {
    // Signedness selector. In this file it is written from the `sign` argument
    // of init_random_maxpool_noqdq_a8; cpu_maxpool_noqdq_a8 treats 0 as
    // unsigned (uint8) data and any other value as signed (int8).
    // NOTE(review): confirm other consumers use the same convention.
    uint8_t sign;
    // NOTE: Used to align the struct size to 128 bytes.
    uint8_t reserved[127];
};
static_assert(sizeof(MaxPool_NoQdqParams_A8) == 128,
              "MaxPool_NoQdqParams_A8 must be exactly 128 bytes");

/// Fill `ifm` with pseudo-random test activations and record the signedness.
///
/// Values come from rand(), so the sequence depends on the caller's srand()
/// seeding. Elements are visited channel-major (c, then y, then x).
///
/// @param ifm        tensor to fill in place (through its at() accessor).
/// @param qdq_params receives `sign`; must not be null.
/// @param sign       0 => values drawn from [0, 63]; otherwise from [-16, 15].
inline void init_random_maxpool_noqdq_a8(
    ActTensor<int8_t> ifm,
    MaxPool_NoQdqParams_A8* qdq_params,
    int sign
){
    printf("Inside init random maxpool noqdq a8\n");
    const bool is_unsigned = (sign == 0);
    // Lower bound and width of the half-open sampling interval [lo, lo + span).
    const int lo = is_unsigned ? 0 : -16;
    const int span = is_unsigned ? 64 : 32;
    for (int ch = 0; ch < ifm.C; ch++) {
        for (int row = 0; row < ifm.Y; row++) {
            for (int col = 0; col < ifm.X; col++) {
                const int draw = rand() % span;
                ifm.at(ch, row, col) = static_cast<int8_t>(lo + draw);
            }
        }
    }
    qdq_params->sign = static_cast<uint8_t>(sign);
}

/// Reference CPU max-pooling over per-channel 2-D activations.
///
/// For each channel `c` and output position (yo, xo), scans the Ky x Kx window
/// whose top-left input coordinate is (yo*Sy - Py, xo*Sx - Px) and writes the
/// largest element to ofm(c, yo, xo). Window taps outside the input behave as
/// the minimum representable value, so padding never wins. A window that lies
/// entirely in padding yields that minimum (0 unsigned, -128 signed).
///
/// @param ifm   input activations (int8_t storage).
/// @param ofm   output activations; its C/Y/X extents drive the loops and
///              must be sized by the caller (only C is checked here).
/// @param Sy,Sx vertical / horizontal stride.
/// @param Py,Px top / left padding.
/// @param Ky,Kx kernel height / width.
/// @param sign  0 => bytes are interpreted as uint8_t; any other value =>
///              int8_t. The result is stored back with the same bit pattern.
inline void cpu_maxpool_noqdq_a8(
    ActTensor<int8_t> ifm,
    ActTensor<int8_t> ofm,
    int Sy, int Sx,
    int Py, int Px,
    int Ky, int Kx,
    int sign
){
    printf("Inside cpu maxpool noqdq a8\n");
    assert(ifm.C == ofm.C);
    // Signedness is fixed for the whole call; resolve it once rather than
    // re-testing `sign` at every window tap.
    const bool is_unsigned = (sign == 0);
    // Identity element for max. Out-of-bounds (padded) taps would contribute
    // exactly this value, so they can simply be skipped.
    const int64_t lowest = is_unsigned
        ? int64_t(std::numeric_limits<uint8_t>::lowest())
        : int64_t(std::numeric_limits<int8_t>::lowest());
    for (int c = 0; c < ifm.C; ++c) {
        for (int yo = 0; yo < ofm.Y; ++yo) {
            for (int xo = 0; xo < ofm.X; ++xo) {
                const int yi_start = yo * Sy - Py;
                const int xi_start = xo * Sx - Px;
                int64_t max_val = lowest;
                for (int ky = 0; ky < Ky; ++ky) {
                    const int yi = yi_start + ky;
                    if (yi < 0 || yi >= ifm.Y) {
                        continue;  // whole kernel row is padding
                    }
                    for (int kx = 0; kx < Kx; ++kx) {
                        const int xi = xi_start + kx;
                        if (xi < 0 || xi >= ifm.X) {
                            continue;  // padded column
                        }
                        const int8_t raw = int8_t(ifm.at(c, yi, xi));
                        // Reinterpret the stored byte per the requested signedness.
                        const int64_t val = is_unsigned ? int64_t(uint8_t(raw))
                                                        : int64_t(raw);
                        if (val > max_val) {
                            max_val = val;
                        }
                    }
                }
                // Narrow back to the stored byte. Both arms are int8_t so the
                // conditional does not promote to int; in unsigned mode going
                // through uint8_t keeps the bit pattern (e.g. 200 -> 0xC8).
                ofm.at(c, yo, xo) = is_unsigned
                    ? static_cast<int8_t>(static_cast<uint8_t>(max_val))
                    : static_cast<int8_t>(max_val);
            }
        }
    }
}

#endif // MAXPOOL_HPP