#ifndef BILINEAR_PIXEL_RESIZE_BF16_RUNTIME_HPP 
#define BILINEAR_PIXEL_RESIZE_BF16_RUNTIME_HPP

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <vector>

/*
    Coordinate transformation modes for bilinear interpolation. A mode decides
    how an output index is mapped back to a fractional input coordinate along
    one axis (see bilinearWGT::set_raw_indices), where "in" / "out" are the
    input / output extents of that axis:
    - HalfPixel:    pixel centres are aligned,
                    src = (dst + 0.5) * in / out - 0.5
    - AlignCorners: the first and last pixels of input and output coincide,
                    src = dst * (in - 1) / (out - 1)
    - Asymmetric:   plain scaling without any offset,
                    src = dst * in / out
*/
enum class CoordinateTransformationMode {
    HalfPixel = 0,
    AlignCorners = 1,
    Asymmetric = 2,
};

/*
    Runtime parameters of the BilinearQDQ operation: an optional dequantize
    step applied to the input and an optional quantize step applied to the
    output. These parameters are packed as part of the constants buffer, so
    the structure is pinned to exactly 64 bytes: 6 used uint16_t fields plus
    26 reserved uint16_t fields (enforced by the static_assert below).
*/
struct BilinearQDQParams {
    uint16_t dq_enable;     // non-zero: dequantize the input
    uint16_t dq_sc;         // dequantization scale, stored as bfloat16 bits
    uint16_t dq_zp;         // dequantization zero point
    uint16_t q_enable;      // non-zero: quantize the output
    uint16_t q_sc;          // quantization scale, stored as bfloat16 bits
    uint16_t q_zp;          // quantization zero point
    uint16_t reserved[26];  // Reserved fields to align to 64 bytes
};

// The constants-buffer layout depends on this size; fail the build if it drifts.
static_assert(sizeof(BilinearQDQParams) == 64, "BilinearQDQParams must be exactly 64 bytes");

/*
    Host-side generator of the weight buffer for the bilinear resize kernel.

    The output (Yo x Xo) is tiled into sub-volumes of Yos x Xos; split_Y x
    split_X sub-volumes cover it. For every sub-volume, set_wgt() writes
    (Yos + Xos) entries of four T words each, Y entries first, then X entries:
        [cf0, cf1, id0, id1]
    cf0 = bfloat16(1 - frac), cf1 = bfloat16(frac) are the interpolation
    weights of the fractional input coordinate; id0 / id1 are the low / high
    16 bits of the integer coordinate relative to the sub-volume's input origin
    (Yis_step * tile + Yis_offset, resp. Xis_step * tile + Xis_offset).

    data must point to at least size(Yo, Xo, Yos, Xos) bytes; it is NOT owned.
    raw_indices_Y / raw_indices_X are allocated here and owned by this object;
    copies share them, and they are freed when the last copy is destroyed.
*/
template <typename T>
struct bilinearWGT
{
    CoordinateTransformationMode const mode; // coordinate mapping used by set_raw_indices()
    int const Yo;          // output extent along Y
    int const Xo;          // output extent along X
    int const Yos;         // output sub-volume extent along Y
    int const Xos;         // output sub-volume extent along X
    int const Yi;          // input extent along Y
    int const Xi;          // input extent along X
    int const Yis;         // input sub-volume extent along Y (stored, not used here)
    int const Xis;         // input sub-volume extent along X (stored, not used here)
    int const Xis_step;    // input X origin advance per output sub-volume column
    int const Yis_step;    // input Y origin advance per output sub-volume row
    int const Xis_offset;  // input X origin of the first sub-volume column
    int const Yis_offset;  // input Y origin of the first sub-volume row
    int const split_Y;     // ceil(Yo / Yos): number of sub-volumes along Y
    int const split_X;     // ceil(Xo / Xos): number of sub-volumes along X
    T* data;               // weight buffer written by set_wgt() (not owned)
    float* raw_indices_Y;  // Yo fractional input Y coordinates, filled by set_raw_indices()
    float* raw_indices_X;  // Xo fractional input X coordinates, filled by set_raw_indices()

    /*
        Stores the geometry and allocates the raw index arrays. The arrays are
        uninitialised until set_raw_indices() (or set_wgt()) is called.
        Note: the parameter order (Yis_step, Xis_step, Yis_offset, Xis_offset)
        differs from the member declaration order.
    */
    bilinearWGT(
        CoordinateTransformationMode mode,
        int Yo,
        int Xo,
        int Yos,
        int Xos,
        int Yi,
        int Xi,
        int Yis,
        int Xis,
        int Yis_step,
        int Xis_step,
        int Yis_offset,
        int Xis_offset,
        T* data
    )
        // Listed in member-declaration order, which is the order C++ actually
        // initialises them in (avoids -Wreorder surprises).
        : mode(mode)
        , Yo(Yo)
        , Xo(Xo)
        , Yos(Yos)
        , Xos(Xos)
        , Yi(Yi)
        , Xi(Xi)
        , Yis(Yis)
        , Xis(Xis)
        , Xis_step(Xis_step)
        , Yis_step(Yis_step)
        , Xis_offset(Xis_offset)
        , Yis_offset(Yis_offset)
        , split_Y(ceil_div(Yo, Yos))
        , split_X(ceil_div(Xo, Xos))
        , data(data)
        , raw_indices_Y(static_cast<float*>(std::malloc(Yo * sizeof(float))))
        , raw_indices_X(static_cast<float*>(std::malloc(Xo * sizeof(float))))
        , raw_indices_Y_owner_(raw_indices_Y, free_deleter)
        , raw_indices_X_owner_(raw_indices_X, free_deleter)
    {
    }

    // Bytes of weight data per output sub-volume: (Yos + Xos) entries of 4 T words.
    static int subvol_size(int Yos, int Xos)
    {
        return (Yos + Xos) * 4 * sizeof(T);
    }

    /*
        Fills raw_indices_Y / raw_indices_X with the fractional input coordinate
        of every output row / column:
            in = (out + grid_offset) / scale - grid_offset
            scale = (out_extent - scale_adjust) / (in_extent - scale_adjust)
        where (grid_offset, scale_adjust) is (0.5, 0) for HalfPixel, (0, 1) for
        AlignCorners and (0, 0) for Asymmetric. Results are clamped to
        [0, in_extent - 1].
    */
    void set_raw_indices(bool debug_print=false)
    {
        float* iy = raw_indices_Y;
        float* ix = raw_indices_X;
        float grid_offset = 0.0f;
        float scale_adjust = 0.0f;
        if (mode == CoordinateTransformationMode::HalfPixel) {
            grid_offset = 0.5f;
            scale_adjust = 0.0f;
        } else if (mode == CoordinateTransformationMode::AlignCorners) {
            grid_offset = 0.0f;
            scale_adjust = 1.0f;
        } else { // Asymmetric
            grid_offset = 0.0f;
            scale_adjust = 0.0f;
        }
        float scale_Y = (static_cast<float>(Yo) - scale_adjust) / (static_cast<float>(Yi) - scale_adjust);
        float scale_X = (static_cast<float>(Xo) - scale_adjust) / (static_cast<float>(Xi) - scale_adjust);
        if (debug_print) {
            printf("scale_Y = %f, scale_X = %f, grid_offset = %f, scale_adjust = %f\n", scale_Y, scale_X, grid_offset, scale_adjust);
        }
        for (int y = 0; y < Yo; ++y) {
            float float_yidx = static_cast<float>(y) + grid_offset;
            float scaled_yidx = std::max(0.0f, float_yidx / scale_Y - grid_offset);
            scaled_yidx = std::min(scaled_yidx, static_cast<float>(Yi - 1));
            if (debug_print) {
                printf("y = %d, float_yidx = %f, scaled_yidx = %f\n", y, float_yidx, scaled_yidx);
            }
            iy[y] = scaled_yidx;
        }
        for (int x = 0; x < Xo; ++x) {
            float float_xidx = static_cast<float>(x) + grid_offset;
            float scaled_xidx = std::max(0.0f, float_xidx / scale_X - grid_offset);
            scaled_xidx = std::min(scaled_xidx, static_cast<float>(Xi - 1));
            if (debug_print) {
                printf("x = %d, float_xidx = %f, scaled_xidx = %f\n", x, float_xidx, scaled_xidx);
            }
            ix[x] = scaled_xidx;
        }
    }

    // Bounds-checked (in debug builds) access to the raw Y coordinates.
    float& get_raw_indices_Y(int y)
    {
        assert(y >= 0 && y < Yo);
        return raw_indices_Y[y];
    }
    // Bounds-checked (in debug builds) access to the raw X coordinates.
    float& get_raw_indices_X(int x)
    {
        assert(x >= 0 && x < Xo);
        return raw_indices_X[x];
    }

    // Dumps both raw coordinate arrays to stdout.
    void print_raw_indices()
    {
        printf("Fractional Indices (iy): ");
        for (int y = 0; y < Yo; ++y) {
            printf("%f ", raw_indices_Y[y]);
        }
        printf("\n");
        printf("Fractional Indices (ix): ");
        for (int x = 0; x < Xo; ++x) {
            printf("%f ", raw_indices_X[x]);
        }
        printf("\n");
    }

    /*
        Computes the raw coordinates and writes the weight buffer described in
        the struct comment into data. Output positions past Yo / Xo in the last
        sub-volume (padding) use input coordinate 0.
    */
    void set_wgt(bool debug_print=false)
    {
        set_raw_indices(debug_print);
        int coeff_len = Yos + Xos;
        // coeff[ty][tx][i]: fractional part of coordinate i of sub-volume (ty, tx)
        // index[ty][tx][i]: integer part, relative to the sub-volume's input origin
        std::vector<std::vector<std::vector<float>>> coeff(split_Y, std::vector<std::vector<float>>(split_X, std::vector<float>(coeff_len)));
        std::vector<std::vector<std::vector<int32_t>>> index(split_Y, std::vector<std::vector<int32_t>>(split_X, std::vector<int32_t>(coeff_len)));
        if (debug_print) {
            printf("coeff_len = %d, split_Y = %d, split_X = %d, Yos = %d, Xos = %d\n", coeff_len, split_Y, split_X, Yos, Xos);
        }

        for (int y = 0; y < split_Y; ++y) {
            for (int x = 0; x < split_X; ++x) {
                // Y indices occupy entries [0, Yos)
                for (int i = 0; i < Yos; ++i) {
                    int arr_idx = y * Yos + i;
                    float idx = (arr_idx < Yo) ? raw_indices_Y[arr_idx] : 0.0f; // Handle out of bounds
                    coeff[y][x][i] = std::fmod(idx, 1.0f);
                    index[y][x][i] = static_cast<int32_t>(std::floor(idx)) - (Yis_step * y + Yis_offset);
                }
                // X indices occupy entries [Yos, Yos + Xos)
                for (int i = 0; i < Xos; ++i) {
                    int arr_idx = x * Xos + i;
                    float idx = (arr_idx < Xo) ? raw_indices_X[arr_idx] : 0.0f; // Handle out of bounds
                    coeff[y][x][i+Yos] = std::fmod(idx, 1.0f);
                    // NOTE(review): X offsets are scaled by 16, presumably a per-column
                    // stride expected by the kernel -- confirm against the kernel.
                    index[y][x][i+Yos] = 16 * (static_cast<int32_t>(std::floor(idx)) - (Xis_step * x + Xis_offset));
                }
            }
        }
        if (debug_print) {
            // Print coefficients and indices for debugging
            for (int y = 0; y < split_Y; ++y) {
                for (int x = 0; x < split_X; ++x) {
                    for (int i = 0; i < Yos; ++i) {
                        int arr_idx = y * Yos + i;
                        float idx = (arr_idx < Yo) ? raw_indices_Y[arr_idx] : 0.0f; // Handle out of bounds
                        printf("raw_indices_Y[%d] = %f\n", arr_idx, idx);
                        printf("coeff[%d][%d][%d] = %f Yos part\n", y, x, i, coeff[y][x][i]);
                    }
                    for (int i = 0; i < Xos; ++i) {
                        int arr_idx = x * Xos + i;
                        float idx = (arr_idx < Xo) ? raw_indices_X[arr_idx] : 0.0f; // Handle out of bounds
                        printf("raw_indices_X[%d] = %f\n", arr_idx, idx);
                        printf("coeff[%d][%d][%d] = %f Xos part\n", y, x, i+Yos, coeff[y][x][i+Yos]);
                    }
                }
            }
            for (int y = 0; y < split_Y; ++y) {
                for (int x = 0; x < split_X; ++x) {
                    for (int i = 0; i < coeff_len; ++i) {
                        if (i < Yos) {
                            printf("index[%d][%d][%d] = %d Yos part\n", y, x, i, index[y][x][i]);
                        } else {
                            printf("index[%d][%d][%d] = %d Xos part\n", y, x, i, index[y][x][i]);
                        }
                    }
                }
            }
        }

        // Fill weights buffer: [cf0, cf1, id0, id1] along last axis
        int idx_flat = 0;
        for (int y = 0; y < split_Y; ++y) {
            for (int x = 0; x < split_X; ++x) {
                for (int i = 0; i < coeff_len; ++i) {
                    uint16_t cf0 = float_to_bfloat16((1.0f - coeff[y][x][i])).value;
                    uint16_t cf1 = float_to_bfloat16((coeff[y][x][i])).value;
                    // Split the signed 32-bit index into low / high 16-bit words.
                    uint16_t id0 = static_cast<uint16_t>(index[y][x][i] & 0xFFFF);
                    uint16_t id1 = static_cast<uint16_t>(index[y][x][i] >> 16);
                    data[idx_flat++] = cf0;
                    data[idx_flat++] = cf1;
                    data[idx_flat++] = id0;
                    data[idx_flat++] = id1;
                }
            }
        }
    }

    // Total weight-buffer size in bytes for a Y x X output tiled by Yos x Xos.
    static int size(int Y, int X, int Yos, int Xos)
    {
        int persubvol = subvol_size(Yos, Xos);
        int y_tiles = ceil_div(Y, Yos);
        int x_tiles = ceil_div(X, Xos);
        int no_of_subvols = y_tiles * x_tiles;
        return no_of_subvols * persubvol;
    }

private:
    // Integer ceil(n / d) for n >= 0, d > 0 (exact, unlike float ceil on large values).
    static int ceil_div(int n, int d)
    {
        assert(d > 0);
        return (n + d - 1) / d;
    }

    // Releases a buffer obtained from std::malloc.
    static void free_deleter(float* p)
    {
        std::free(p);
    }

    // Owners of the malloc'ed coordinate arrays. Declared after raw_indices_Y/X
    // so they are initialised from those pointers. Copies of this struct share
    // ownership, so the raw pointers stay valid for every copy.
    std::shared_ptr<float> raw_indices_Y_owner_;
    std::shared_ptr<float> raw_indices_X_owner_;
};

/*
    Prints a bfloat16 tensor (stored as uint16_t bit patterns) to stdout.
    A "<name>:" header line comes first; then every value is decoded to float
    and printed followed by a space, one line per (channel, row), with an
    empty line after each channel.
*/
inline void print_bfloat16_tensor(
    ActTensor<uint16_t> &tensor,
    const char *name
) {
    printf("%s:\n", name);
    for (int ch = 0; ch < tensor.C; ++ch) {
        for (int row = 0; row < tensor.Y; ++row) {
            for (int col = 0; col < tensor.X; ++col) {
                const uint16_t bits = tensor.at(ch, row, col);
                std::cout << bfloat16_to_float(bfloat16_t{bits}) << ' ';
            }
            std::cout << std::endl;   // end of one row
        }
        std::cout << std::endl;       // blank line separating channels
    }
}


/*
    Fills every element of the input tensor with a pseudo-random value from
    the inclusive range [min, max], drawn with rand() % range (so slightly
    biased, and limited to RAND_MAX + 1 distinct values).

    Values are converted to uint16_t with the usual modulo-2^16 wrap, so a
    negative min produces the two's-complement bit patterns of negative int16
    values. If max < min the bounds are swapped (previously this evaluated
    rand() % 0 or a negative modulus).

    @param input  tensor to fill (all C x Y x X elements are written)
    @param max    inclusive upper bound (default 16)
    @param min    inclusive lower bound (default 0)
*/
inline void init_bilinear_input_tensor(
    ActTensor<uint16_t> &input,
    int max = 16,
    int min = 0
) {
    const int lo = std::min(min, max);
    const int hi = std::max(min, max);
    // 64-bit so that extreme bounds (e.g. INT_MIN..INT_MAX) cannot overflow.
    const long long range = static_cast<long long>(hi) - lo + 1;
    for (int c = 0; c < input.C; ++c) {
        for (int y = 0; y < input.Y; ++y) {
            for (int x = 0; x < input.X; ++x) {
                input.at(c, y, x) = static_cast<uint16_t>(lo + rand() % range);
            }
        }
    }
}


/*
    Fills a BilinearQDQParams structure.

    The scales are stored as bfloat16 bit patterns (float_to_bfloat16(...).value);
    the enable flags and zero points are stored verbatim. The structure is
    value-initialised first, so the reserved fields are zero instead of
    carrying uninitialised bytes into the constants buffer.
*/
inline void init_bilinear_qdq_params(
    BilinearQDQParams &qdq_params,
    uint16_t dq_enable,
    float dq_sc,
    uint16_t dq_zp,
    uint16_t q_enable,
    float q_sc,
    uint16_t q_zp
) {
    qdq_params = BilinearQDQParams{};  // zeroes every field, including reserved[]
    qdq_params.dq_enable = dq_enable;
    qdq_params.dq_sc = float_to_bfloat16(dq_sc).value;
    qdq_params.dq_zp = dq_zp;
    qdq_params.q_enable = q_enable;
    qdq_params.q_sc = float_to_bfloat16(q_sc).value;
    qdq_params.q_zp = q_zp;
}

/*
    Reference bilinear resize on bfloat16 data (uint16_t bit patterns), written
    to match the Python model function.

    For each output pixel (h, w) the fractional input coordinates iy[h] and
    ix[w] are split into floor / ceil neighbours and 1-D weights. The 1-D
    weights, the four 2-D weight products and the final sum are each rounded
    to bfloat16; the four neighbour contributions are summed left to right.
    Neighbour indices are clamped to the input extent.

    @param input   source tensor
    @param output  destination tensor; every C x Y x X element is written
    @param iy      output.Y fractional input row coordinates
    @param ix      output.X fractional input column coordinates
*/
inline void bilinear_resize_bf16(
    ActTensor<uint16_t> input,
    ActTensor<uint16_t> output,
    float* iy,
    float* ix
) {
    // Round a float to the nearest bfloat16 and widen it back.
    auto round_bf16 = [](float v) { return bfloat16_to_float(float_to_bfloat16(v)); };
    // Decode one input element to float.
    auto sample = [&input](int ch, int row, int col) {
        return bfloat16_to_float(bfloat16_t{input.at(ch, row, col)});
    };

    for (int h = 0; h < output.Y; ++h) {
        // Row neighbours and weights depend only on h.
        const float fy = iy[h];
        const int y_floor = static_cast<int>(std::floor(fy));
        const int y_ceil = static_cast<int>(std::ceil(fy));
        const float wy_lo = round_bf16(1.0f - (fy - float(y_floor)));
        const float wy_hi = round_bf16(fy - float(y_floor));
        const int row_lo = std::max(0, std::min(y_floor, input.Y - 1));
        const int row_hi = std::max(0, std::min(y_ceil, input.Y - 1));

        for (int w = 0; w < output.X; ++w) {
            const float fx = ix[w];
            const int x_floor = static_cast<int>(std::floor(fx));
            const int x_ceil = static_cast<int>(std::ceil(fx));
            const float wx_lo = round_bf16(1.0f - (fx - float(x_floor)));
            const float wx_hi = round_bf16(fx - float(x_floor));
            const int col_lo = std::max(0, std::min(x_floor, input.X - 1));
            const int col_hi = std::max(0, std::min(x_ceil, input.X - 1));

            // 2-D weights, each rounded to bfloat16.
            const auto w_ll = round_bf16(wy_lo * wx_lo);
            const auto w_lh = round_bf16(wy_lo * wx_hi);
            const auto w_hl = round_bf16(wy_hi * wx_lo);
            const auto w_hh = round_bf16(wy_hi * wx_hi);

            for (int ch = 0; ch < output.C; ++ch) {
                output.at(ch, h, w) = float_to_bfloat16(
                    sample(ch, row_lo, col_lo) * w_ll +
                    sample(ch, row_lo, col_hi) * w_lh +
                    sample(ch, row_hi, col_lo) * w_hl +
                    sample(ch, row_hi, col_hi) * w_hh).value;
            }
        }
    }
}


/*
    CPU reference model of the bilinear pixel resize on bfloat16 data.

    1. If qdq_params.dq_enable is set, each input element is dequantized as
       bf16(bf16(input - dq_zp) * dq_sc); otherwise its bits are copied.
    2. The result is resized by bilinear_resize_bf16 using the fractional
       coordinates in wgt.raw_indices_Y / wgt.raw_indices_X, so
       wgt.set_raw_indices() (or wgt.set_wgt()) must have been called.
    3. If qdq_params.q_enable is set, each resized value is quantized as
       round(value * q_sc) + q_zp, wrapped to 16 bits; otherwise its bfloat16
       bits are copied to output.

    The mode argument is not consulted: the coordinate transformation is
    already encoded in wgt's raw indices.
*/
inline void cpu_model(
    CoordinateTransformationMode mode,
    ActTensor<uint16_t> input,
    ActTensor<uint16_t> output,
    BilinearQDQParams qdq_params,
    bilinearWGT<uint16_t> wgt
) {
    static_cast<void>(mode);  // see comment above: mode is baked into wgt

    int ifm_size = ActTensor<uint16_t>::size(input.C, input.Y, input.X);
    int ofm_size = ActTensor<uint16_t>::size(output.C, output.Y, output.X);

    // Scratch storage for the intermediate tensors, released automatically on
    // return (the previous malloc'ed buffers were never freed).
    std::vector<uint8_t> ifm_storage(static_cast<std::size_t>(ifm_size));
    std::vector<uint8_t> ofm_storage(static_cast<std::size_t>(ofm_size));
    ActTensor<uint16_t> ifm(input.C, input.Y, input.X, static_cast<void*>(ifm_storage.data()));
    ActTensor<uint16_t> ofm(output.C, output.Y, output.X, static_cast<void*>(ofm_storage.data()));

    // Initialize the intermediate output tensor to bfloat16 zero
    for (int c = 0; c < output.C; ++c) {
        for (int y = 0; y < output.Y; ++y) {
            for (int x = 0; x < output.X; ++x) {
                ofm.at(c, y, x) = float_to_bfloat16(0.0f).value;
            }
        }
    }

    // Dequantize the input, or copy its bfloat16 bits unchanged
    for (int y = 0; y < input.Y; ++y) {
        for (int x = 0; x < input.X; ++x) {
            for (int c = 0; c < input.C; ++c) {
                if (qdq_params.dq_enable) {
                    ifm.at(c, y, x) = float_to_bfloat16( bfloat16_to_float( float_to_bfloat16( input.at(c, y, x) - qdq_params.dq_zp )) * bfloat16_to_float(bfloat16_t{qdq_params.dq_sc})).value;
                } else {
                    ifm.at(c, y, x) = input.at(c, y, x);
                }
            }
        }
    }

    // Perform bilinear interpolation
    bilinear_resize_bf16(ifm, ofm, wgt.raw_indices_Y, wgt.raw_indices_X);

    // Quantize the output, or copy its bfloat16 bits unchanged
    for (int y = 0; y < output.Y; ++y) {
        for (int x = 0; x < output.X; ++x) {
            for (int c = 0; c < output.C; ++c) {
                if (qdq_params.q_enable) {
                    // Convert through int32_t: converting a negative or >65535
                    // float straight to uint16_t is undefined behaviour. The
                    // final narrowing to uint16_t wraps modulo 2^16.
                    const int32_t rounded = static_cast<int32_t>(std::round(bfloat16_to_float(bfloat16_t{ofm.at(c, y, x)}) * bfloat16_to_float(bfloat16_t{qdq_params.q_sc})));
                    output.at(c, y, x) = static_cast<uint16_t>(rounded + qdq_params.q_zp);
                } else {
                    output.at(c, y, x) = ofm.at(c, y, x);
                }
            }
        }
    }
}


/*
    Compares an output tensor with the golden reference, element by element.

    With q_enable set, elements are compared as unsigned 16-bit integers;
    otherwise both are decoded from bfloat16 and compared as floats. An element
    is an error when |output - gold| exceeds epsilon. In bfloat16 mode,
    bit-identical elements always match, and a NaN difference (e.g. NaN on
    only one side) is counted as an error rather than silently passing.

    Every element is logged as PASS or ERROR, followed by a summary and the
    maximum observed difference.

    @return number of mismatching elements (0 means the tensors match)
*/
inline int check_bilinear_output(
    ActTensor<uint16_t> gold_model_output,
    ActTensor<uint16_t> output,
    int q_enable,
    float epsilon = 0.0f
) {
    // Check if the output tensor matches the expected output tensor
    assert(output.C == gold_model_output.C && output.Y == gold_model_output.Y && output.X == gold_model_output.X);
    int error_count = 0;
    float max_gold_diff = 0.0f;
    for (int y = 0; y < output.Y; ++y) {
        for (int x = 0; x < output.X; ++x) {
            for (int c = 0; c < output.C; ++c) {
                const uint16_t gold_bits = gold_model_output.at(c, y, x);
                const uint16_t out_bits = output.at(c, y, x);
                if (q_enable) {
                    // compare the 16-bit integer values
                    const float gold_diff = std::abs(float(out_bits) - float(gold_bits));
                    max_gold_diff = std::max(max_gold_diff, gold_diff);
                    if (gold_diff > epsilon) {
                        error_count++;
                        printf(
                            "ERROR at Y[%d][%d][%d]: \tgold=%d, \taie=%d, \tgold_diff=%f\n", y, x, c,
                            gold_bits, out_bits, gold_diff
                        );
                    } else {
                        printf(
                            "PASS at Y[%d][%d][%d]: \tgold=%d, \taie=%d\n", y, x, c,
                            gold_bits, out_bits
                        );
                    }
                } else {
                    // compare bfloat16 values decoded to float
                    const auto gold_val = bfloat16_to_float(bfloat16_t{gold_bits});
                    const auto out_val = bfloat16_to_float(bfloat16_t{out_bits});
                    // Identical bit patterns (including identical Inf/NaN) always match.
                    const float gold_diff = (out_bits == gold_bits) ? 0.0f : std::abs(out_val - gold_val);
                    max_gold_diff = std::max(max_gold_diff, gold_diff);
                    // !(diff <= eps) rather than (diff > eps): a NaN difference is an error.
                    if (!(gold_diff <= epsilon)) {
                        error_count++;
                        printf(
                            "ERROR at Y[%d][%d][%d]: \tgold=%f, \taie=%f, \tgold_diff = %f\n", y, x, c,
                               gold_val, out_val, gold_diff
                        );
                    } else {
                        printf(
                            "PASS at [%d][%d][%d]: \tgold=%f, \taie=%f\n", y, x, c,
                               gold_val, out_val
                        );
                    }
                }
            }
        }
    }
    if (error_count == 0) {
        printf("Output matches expected data.\n");
    } else {
        printf("Output does not match expected data. Total errors: %d\n", error_count);
    }
    printf("Max gold diff: %f\n", max_gold_diff);
    return error_count;
}

#endif // BILINEAR_PIXEL_RESIZE_BF16_RUNTIME_HPP