#ifndef RESIZE_NNI_HPP
#define RESIZE_NNI_HPP

#include "common.hpp"

/*
    Coordinate transformation modes used by ResizeNNIWGT::set_raw_indices when
    mapping an output index `out` back to a (fractional) input coordinate `in`:
    - HalfPixel:    in = (out + 0.5) / scale - 0.5,  scale = out_size / in_size
    - AlignCorners: in = out / scale,                scale = (out_size - 1) / (in_size - 1)
    - Asymmetric:   in = out / scale,                scale = out_size / in_size
    In every mode the result is clamped to [0, in_size - 1].
*/
enum class CoordinateTransformationMode {
    HalfPixel,
    AlignCorners,
    Asymmetric
};

/*
    Generates the index table ("weights") for a tiled resize.

    The Yo x Xo output is covered by split_Y x split_X tiles of Yos x Xos
    outputs. For every tile the table holds Yos row indices followed by Xos
    column indices. Each index is the input coordinate rounded to the nearest
    integer and is written to `data` as two consecutive T values:
    (low 16 bits, high 16 bits). Row indices are additionally made relative
    (see set_wgt).

    T is the element type of the caller-provided table buffer.
*/
template <typename T>
struct ResizeNNIWGT
{
    CoordinateTransformationMode const mode;   // coordinate mapping applied by set_raw_indices
    int const Yo;           // number of output rows
    int const Xo;           // number of output columns
    int const Yi;           // number of input rows; row coordinates are clamped to [0, Yi - 1]
    int const Xi;           // number of input columns; column coordinates are clamped to [0, Xi - 1]
    int const Yos;          // output rows per tile
    int const Xos;          // output columns per tile
    // The next four values are only stored; no code in this struct reads them.
    // NOTE(review): presumably input tile sizes/steps consumed elsewhere - confirm.
    int const Yis;
    int const Xis;
    int const Yis_step;
    int const Xis_step;
    int const Yis_offset;   // subtracted from every relative row index when use_padding is true
    int const Xis_offset;   // only stored; not read by this struct
    int const split_Y;      // number of tiles along Y: ceil(Yo / Yos)
    int const split_X;      // number of tiles along X: ceil(Xo / Xos)
    bool const use_padding; // enables the Yis_offset subtraction in set_wgt
    T* data;                // caller-owned destination of set_wgt
    float* raw_indices_Y;   // Yo fractional input row coordinates (filled by set_raw_indices)
    float* raw_indices_X;   // Xo fractional input column coordinates (filled by set_raw_indices)

    /*
        Stores the geometry, derives the tile counts and allocates the two
        coordinate arrays (uninitialised until set_raw_indices / set_wgt runs).

        `data` must point to at least size(Yo, Xo, Yos, Xos) bytes; it is
        written by set_wgt.

        NOTE(review): raw_indices_Y / raw_indices_X come from malloc and are
        never released by this struct (no destructor; copies share the same
        pointers). Ownership is deliberately left unchanged because callers
        outside this file may already free() them - confirm and consider
        shared ownership.
    */
    ResizeNNIWGT(
        CoordinateTransformationMode mode,
        int Yo, int Xo,
        int Yi, int Xi,
        int Yos, int Xos,
        int Yis, int Xis,
        int Yis_step, int Xis_step,
        int Yis_offset, int Xis_offset,
        bool use_padding,
        void* data
    )
        // Initialisers are listed in declaration order, which is the order in
        // which they actually run.
        : mode(mode)
        , Yo(Yo)
        , Xo(Xo)
        , Yi(Yi)
        , Xi(Xi)
        , Yos(Yos)
        , Xos(Xos)
        , Yis(Yis)
        , Xis(Xis)
        , Yis_step(Yis_step)
        , Xis_step(Xis_step)
        , Yis_offset(Yis_offset)
        , Xis_offset(Xis_offset)
        , split_Y(ceil_div(Yo, Yos))
        , split_X(ceil_div(Xo, Xos))
        , use_padding(use_padding)
        , data(static_cast<T*>(data))
        , raw_indices_Y(static_cast<float*>(malloc(Yo * sizeof(float))))
        , raw_indices_X(static_cast<float*>(malloc(Xo * sizeof(float))))
    {}

    // Number of tiles of `tile` elements needed to cover `total` elements
    // (integer ceiling division). Returns 0 when either argument is not
    // positive, so a zero tile size can never cause a division by zero.
    static int ceil_div(int total, int tile)
    {
        if (total <= 0 || tile <= 0) {
            return 0;
        }
        return (total - 1) / tile + 1;
    }

    /*
        Fills raw_indices_Y[0..Yo) and raw_indices_X[0..Xo) with the fractional
        input coordinate of every output row / column according to `mode`,
        clamped to [0, Yi - 1] and [0, Xi - 1] respectively.

        debug_print: when true, prints the derived parameters and every
        computed coordinate with printf.
    */
    void set_raw_indices(bool debug_print=false)
    {
        float grid_offset = 0.0f;
        float scale_adjust = 0.0f;
        switch (mode) {
        case CoordinateTransformationMode::HalfPixel:
            grid_offset = 0.5f;
            break;
        case CoordinateTransformationMode::AlignCorners:
            scale_adjust = 1.0f;
            break;
        case CoordinateTransformationMode::Asymmetric:
        default:
            break;
        }

        float const scale_Y = (static_cast<float>(Yo) - scale_adjust) / (static_cast<float>(Yi) - scale_adjust);
        float const scale_X = (static_cast<float>(Xo) - scale_adjust) / (static_cast<float>(Xi) - scale_adjust);
        if (debug_print) {
            printf("split_Y: %d, split_X: %d\n", split_Y, split_X);
            printf("scale_Y: %f, scale_X: %f\n", scale_Y, scale_X);
            printf("grid_offset: %f, scale_adjust: %f\n", grid_offset, scale_adjust);
        }

        // Maps an already offset output position to a clamped input coordinate.
        auto const to_input_coord = [grid_offset](float shifted_idx, float scale, int in_extent) {
            float const unclamped = shifted_idx / scale - grid_offset;
            // Keep the constant as the FIRST argument: std::max(0.0f, NaN)
            // yields 0.0f, which is what makes the degenerate AlignCorners
            // cases (extent 1, scale 0/0) resolve to index 0.
            float const lower_clamped = std::max(0.0f, unclamped);
            return std::min(lower_clamped, static_cast<float>(in_extent - 1));
        };

        for (int y = 0; y < Yo; ++y) {
            float const float_yidx = static_cast<float>(y) + grid_offset;
            float const scaled_yidx = to_input_coord(float_yidx, scale_Y, Yi);
            if (debug_print) {
                printf("y = %d, float_yidx = %f, scaled_yidx = %f\n", y, float_yidx, scaled_yidx);
            }
            raw_indices_Y[y] = scaled_yidx;
        }
        for (int x = 0; x < Xo; ++x) {
            float const float_xidx = static_cast<float>(x) + grid_offset;
            float const scaled_xidx = to_input_coord(float_xidx, scale_X, Xi);
            if (debug_print) {
                printf("x = %d, float_xidx = %f, scaled_xidx = %f\n", x, float_xidx, scaled_xidx);
            }
            raw_indices_X[x] = scaled_xidx;
        }
    }

    // Prints both coordinate arrays, one "index: value" line per entry.
    void print_raw_indices() const
    {
        printf("Raw indices Y:\n");
        for (int i = 0; i < Yo; ++i) {
            printf("%d: %f\n", i, raw_indices_Y[i]);
        }
        printf("Raw indices X:\n");
        for (int i = 0; i < Xo; ++i) {
            printf("%d: %f\n", i, raw_indices_X[i]);
        }
    }

    /*
        Shift-round-saturate: converts a fixed-point value to an integer.

        value:           fixed-point input with `fractional_bits` fraction bits
        fractional_bits: number of bits dropped; the result is rounded by
                         adding half an LSB before the arithmetic shift.
                         Values <= 0 perform no shift. Must be below 64.
        output_bits:     width of the saturation range; expected to fit the
                         int32_t return type.
        sgn:             true  -> clamp to [-2^(output_bits-1), 2^(output_bits-1) - 1]
                         false -> clamp to [0, 2^output_bits - 1]
        Returns the rounded, clamped value.
    */
    static int32_t srs(int64_t value, int fractional_bits, int output_bits, bool sgn = true) {
        int64_t scaled = value;
        // With no fraction bits there is nothing to round; the guard also
        // avoids the undefined negative shift `1LL << -1`.
        if (fractional_bits > 0) {
            int64_t const half_lsb = 1LL << (fractional_bits - 1);
            scaled = (value + half_lsb) >> fractional_bits;
        }

        int64_t lowest = 0;
        int64_t highest = 0;
        if (sgn) {
            highest = (1LL << (output_bits - 1)) - 1;
            lowest = -(1LL << (output_bits - 1));
        } else {
            highest = (1LL << output_bits) - 1;
        }
        scaled = std::max(lowest, std::min(highest, scaled));

        return static_cast<int32_t>(scaled);
    }

    // Rounded integer index for slot `pos` of a coordinate array holding
    // `count` valid entries. Slots at or beyond `count` (the unused tail of
    // the last tile) yield 0. The coordinate is converted to a fixed-point
    // value with 24 fraction bits and rounded with srs.
    static int32_t rounded_index(float const* raw, int pos, int count)
    {
        if (pos >= count) {
            return 0;
        }
        int64_t const fixed_point = static_cast<int64_t>(raw[pos] * (1LL << 24));
        return srs(fixed_point, 24, 32, true);
    }

    /*
        Recomputes the raw coordinates and writes the complete index table to
        `data`.

        Layout: tiles in row-major order (Y outer, X inner); per tile Yos row
        entries followed by Xos column entries; per entry two T values
        (id0, id1), where id1 is the rounded index shifted right by 16 and id0
        is derived from its low 16 bits:
          - column entries: id0 is the low 16 bits unchanged;
          - row entries of the first tile row: low 16 bits minus the padding
            offset;
          - row entries of later tile rows: low 16 bits minus the low 16 bits
            of that tile's first row index, minus the padding offset.
        The padding offset is Yis_offset when use_padding is set, else 0.
        Slots past Yo / Xo in the last tile use index 0 before the adjustment.

        debug_print: when true, prints the mode, the raw coordinates and every
        emitted (id0, id1) pair.
    */
    void set_wgt(bool debug_print=false)
    {
        if (debug_print) {
            printf("Setting weights for ResizeNNIWGT with mode: %d\n", static_cast<int>(mode));
        }
        set_raw_indices(debug_print);

        int const pad_offset = use_padding ? Yis_offset : 0;
        int idx_flat = 0;
        // Appends one (id0, id1) pair to the output table.
        auto const emit = [&](int32_t id0_val, int32_t id1_val) {
            if (debug_print) {
                printf("%d, %d\n", id0_val, id1_val);
            }
            data[idx_flat++] = static_cast<T>(id0_val);
            data[idx_flat++] = static_cast<T>(id1_val);
        };

        for (int y = 0; y < split_Y; ++y) {
            // Row indices of every tile row except the first are expressed
            // relative to the low half of that tile's first row index.
            int32_t const row_base =
                (y > 0) ? (rounded_index(raw_indices_Y, y * Yos, Yo) & 0xFFFF) : 0;
            for (int x = 0; x < split_X; ++x) {
                for (int i = 0; i < Yos; ++i) {
                    int32_t const index = rounded_index(raw_indices_Y, y * Yos + i, Yo);
                    emit((index & 0xFFFF) - row_base - pad_offset, index >> 16);
                }
                for (int i = 0; i < Xos; ++i) {
                    int32_t const index = rounded_index(raw_indices_X, x * Xos + i, Xo);
                    emit(index & 0xFFFF, index >> 16);
                }
            }
        }
    }

    // Size in bytes of one tile's table: (Yos + Xos) entries of two T each.
    static int subvol_size(int Yos, int Xos){
        return static_cast<int>((Yos + Xos) * 2 * sizeof(T));
    }

    // Size in bytes of the whole table written by set_wgt for the given
    // output extent and tile size.
    static int size(int Yo, int Xo, int Yos, int Xos){
        int const persubvol = subvol_size(Yos, Xos);
        int const Y_tiles = ceil_div(Yo, Yos);
        int const X_tiles = ceil_div(Xo, Xos);
        return Y_tiles * X_tiles * persubvol;
    }
};


/*
    CPU reference for the resize: every output element (c, y, x) receives the
    input element at the rounded source position given by
    wgt.raw_indices_Y[y] and wgt.raw_indices_X[x], or 0 when that position
    lies outside the input.

    ifm: input tensor, read through ifm.at(c, y, x)
    wgt: index generator; its raw_indices_* arrays must already be filled
         (set_raw_indices / set_wgt) and cover ofm.Y rows and ofm.X columns
    ofm: output tensor, written through ofm.at(c, y, x)

    NOTE(review): ofm is taken by value, so this relies on ActTensor referring
    to caller-owned storage rather than owning a copy - confirm in common.hpp.
*/
template <typename Ta, typename Tw, typename To>
inline void cpu_resize_nni(
    ActTensor<Ta> ifm,
    ResizeNNIWGT<Tw> wgt,
    ActTensor<To> ofm
)
{
    int64_t const ifm_Y = ifm.Y;
    int64_t const ifm_X = ifm.X;
    int64_t const ofm_Y = ofm.Y;
    int64_t const ofm_X = ofm.X;

    for (int y = 0; y < ofm_Y; ++y) {
        // The source row depends only on y, so it is resolved once per output
        // row instead of once per element. The coordinate is converted to a
        // fixed-point value with 24 fraction bits and rounded by srs.
        auto const raw_y = wgt.raw_indices_Y[y];
        int64_t const ifm_y = wgt.srs(static_cast<int64_t>(raw_y * (1 << 24)), 24, 32, true);
        bool const row_in_bounds = ifm_y >= 0 && ifm_y < ifm_Y;

        for (int x = 0; x < ofm_X; ++x) {
            // Likewise the source column depends only on x.
            auto const raw_x = wgt.raw_indices_X[x];
            int64_t const ifm_x = wgt.srs(static_cast<int64_t>(raw_x * (1 << 24)), 24, 32, true);
            bool const in_bounds = row_in_bounds && ifm_x >= 0 && ifm_x < ifm_X;

            for (int c = 0; c < ifm.C; ++c) {
                if (in_bounds) {
                    ofm.at(c, y, x) = ifm.at(c, ifm_y, ifm_x);
                } else {
                    ofm.at(c, y, x) = 0; // Handle out-of-bounds
                }
            }
        }
    }
}

/*
    Fills every element of `ifm` with a pseudo-random value in [0, 16) taken
    from rand(). Elements are visited with Y outermost, then X, then C, so for
    a given rand() seed the value stored at each position is reproducible.
*/
template <typename Ta>
inline void init_randon_resize_nni(
    ActTensor<Ta> ifm
)
{
    int64_t const lowest = 0;
    int64_t const upper_bound = 16;            // exclusive
    int64_t const span = upper_bound - lowest;

    for (int row = 0; row < ifm.Y; ++row) {
        for (int col = 0; col < ifm.X; ++col) {
            for (int ch = 0; ch < ifm.C; ++ch) {
                // One rand() draw per element, folded into [lowest, upper_bound).
                int64_t const draw = (rand() % span) + lowest;
                ifm.at(ch, row, col) = static_cast<int8_t>(draw);
            }
        }
    }

}

#endif // RESIZE_NNI_HPP