// Copyright 2022-2024 Advanced Micro Devices, Inc. All Rights Reserved.
////////////////////////////////////////////////////////////////////////

#pragma once
#include <array>
#include <cmath>
#include "conv_xint8/kernel_setup/tensor.hpp"
#include "conv_xint8/kernel_setup/wrapper_mode.h"

/*
############# Overlay Choices #############
These structures are used to configure the overlay
used in each layer.
*/

// Axis along which a layer's work is divided between subarrays.
// Selected per layer through overlay_choices::axis.
enum class subarray_split {
    // Split over the Height axis.
    // Currently only splitting subarrays over H is supported.
    H,
    // Split over the Output Channel axis.
    // NOTE(review): compute_ifm_ofm_iterations() does scale the OC split by the
    // subarray count when this value is selected, yet the comment above says only
    // H is supported — confirm before relying on OC.
    OC
};

// Per-layer overlay configuration: how many subarrays a layer uses and which
// axis the work is split over. A default-constructed value means one subarray,
// split over H, with an all-zero halo.
struct overlay_choices
{
    // Number of subarrays to use for this layer (must be <= max subarrays).
    // Must also be >= 1 for now: compute_ifm_ofm_iterations() multiplies the split
    // factor by this value and then divides by that split factor.
    // TODO: Allow defaulting to 0 - automatically choose the maximum number of subarrays available.
    //       This requires moving iteration calculation inside the Layer object.
    uint32_t num_subarrays = 1;

    // Which axis to split the subarrays over.
    subarray_split axis    = subarray_split::H;

    // Overcompute / Over-Transfer for each subarray to account for Halo region.
    // Not read by the iteration calculation in this header.
    adf_utils::tensor_HCWc halo       = {0, 0, 0, 0};
};

/*
############# Layer Params #############
These structures are used to pass parameters to the tiling functions.
They are usually reduced to a smaller subset of fields, known as LCPs
(Layer Configuration Parameters) or RTPs (Run Time Parameters), which is
what actually gets passed to the kernels.
*/

// Tiling parameters for the GAP layer.
// No member has a default initializer, so every field must be set by the caller.
// NOTE(review): "GAP" presumably means Global Average Pooling, and the per-field
// notes below are inferred from the names only — confirm against the kernel.
struct gap_params_t
{
    unsigned int ifm_width;            // presumably the width of the input feature map
    unsigned int ifm_height;           // presumably the height of the input feature map
    unsigned int ifm_sv_width_actual;  // presumably the subvolume width before any rounding/padding — TODO confirm
    unsigned int ifm_sv_width;         // presumably the input subvolume width
    unsigned int ifm_sv_height;        // presumably the input subvolume height
    unsigned int div_shift;            // presumably a shift amount used in place of a division — TODO confirm
};

// Tiling parameters for the avgpool2d layer.
// The fields are currently identical to those of gap_params_t, and none of them
// has a default initializer.
// TODO JL: Copied from the gap params for now; avgpool2d still needs its own parameters.
struct avgpool2d_params_t
{
    unsigned int ifm_width;
    unsigned int ifm_height;
    unsigned int ifm_sv_width_actual;
    unsigned int ifm_sv_width;
    unsigned int ifm_sv_height;
    unsigned int div_shift;
};

// Tiling parameters for the maxpool layer.
// Only manual_ba_pad_val has a default initializer; the other fields must be
// set by the caller (or zeroed through value-initialization).
struct maxpool_params_t
{
    // Padding amount on each of the four edges.
    struct padding_t
    {
        // TODO: Use adf_utils::image_padding instead.
        unsigned int left;
        unsigned int top;
        unsigned int right;
        unsigned int bottom;

        //! @brief Returns true if padding is non-zero on any edge.
        //! Prefer this over is_zero(), whose name says the opposite of what it returns.
        bool has_padding() const
        {
            return (left != 0 || top != 0 || right != 0 || bottom != 0);
        }

        //! @note Despite its name, returns true if padding is NON-zero on any edge
        //!       (i.e. exactly what has_padding() returns). The result is kept as-is
        //!       so that existing callers do not change behaviour.
        bool is_zero() const
        {
            return has_padding();
        }
    };

    unsigned int ksize;
    unsigned int stride;

    unsigned int in_width;

    unsigned int ofm_pad; // should go to conv2d?
    padding_t padding;

    unsigned int manual_ba_pad_val = 0;
};

// Parameters for the resize layer.
struct resize_params_t
{
    // Coordinate transformation mode to use; resize_mode is declared outside this file.
    resize_mode coordinate_transformation_mode;
};

// Parameters for the bilinear resize layer.
// No member has a default initializer here.
struct bilinear_params_t
{
    adf_utils::image_padding padding;
    // NOTE(review): presumably the fixed-point radix for the vertical and
    // horizontal interpolation respectively — confirm against the kernel.
    uint8_t radix_h;
    uint8_t radix_w;
    // Coordinate transformation mode to use; resize_mode is declared outside this file.
    resize_mode coordinate_transformation_mode;
};

// Parameters for a fused DW stage (presumably a depthwise convolution fused
// into another layer — confirm).
struct fused_dw_params_t
{
    // Kernel size; defaults to 0 in both x and y.
    adf_utils::tensor_2d ksize = {.x = 0, .y = 0};
    // Stride; defaults to 0 in both x and y.
    adf_utils::tensor_2d stride = {.x = 0, .y = 0};
    // Padding; no initializer is given here.
    adf_utils::image_padding padding;
};

// Parameters for the leaky ReLU activation.
// No member has a default initializer.
// NOTE(review): the names suggest a fixed-point slope (alpha) plus shift amounts
// for the alpha product and for the output — confirm against the kernel.
struct leakyrelu_params_t
{
    int16_t alpha;
    int8_t shift_alpha;
    int8_t shift_out;
};

// Parameters for the GAPDWC layer.
// No member has a default initializer.
// NOTE(review): gap_x / gap_y presumably describe the GAP extent in x and y —
// confirm against the kernel.
struct gapdwc_params_t
{
    uint8_t gap_x;
    uint8_t gap_y;
};

namespace adf_utils {
// Rounds numToRound up to the nearest multiple of `multiple`.
// A value that is already a multiple is returned unchanged.
//
// The remainder form is used instead of (numToRound + multiple - 1) because that
// sum can wrap around for large inputs even when the rounded result itself fits
// in 32 bits (e.g. round_up(0xFFFFFFFF, 3)).
//
// Precondition: multiple != 0 (otherwise this divides by zero).
// If the rounded value does not fit in uint32_t the result wraps modulo 2^32.
static constexpr uint32_t round_up(uint32_t numToRound, uint32_t multiple)
{
    const uint32_t remainder = numToRound % multiple;
    return (remainder == 0) ? numToRound : numToRound + (multiple - remainder);
}

// Integer division of num by den, rounded up.
//
// Computed as quotient plus one when there is a remainder, so no intermediate
// value can exceed num. Rounding num up first can wrap around for large inputs
// (e.g. ceil_div(0xFFFFFFFF, 2) came out as 0 that way).
//
// Precondition: den != 0 (otherwise this divides by zero).
static constexpr uint32_t ceil_div(uint32_t num, uint32_t den)
{
    return num / den + ((num % den != 0) ? 1u : 0u);
}

// "Dynamic Shape" is a description of how work is split across the AIE Matrix of tiles
// If a Platform with a different number of tiles is to be supported, this should be extended
enum class dynamic_shape_t {
    H4OC4,  // Split Height by 4, Output Channels by 4  (get_h_split() == 4, get_oc_split() == 4)
    H8OC2,  // Split Height by 8, Output Channels by 2  (get_h_split() == 8, get_oc_split() == 2)
    H2OC8,  // Split Height by 2, Output Channels by 8  (get_h_split() == 2, get_oc_split() == 8)
    H1OC8x2  // Split Height by 1, Output Channels by 16 in 8 groups of 2  (get_h_split() == 1, get_oc_split() == 16)
};

// Returns the enumerator's name as text.
// A value that matches none of the known shapes yields an empty view.
static std::string_view to_string(dynamic_shape_t ds)
{
    switch (ds) {
        case dynamic_shape_t::H4OC4:   return "H4OC4";
        case dynamic_shape_t::H8OC2:   return "H8OC2";
        case dynamic_shape_t::H2OC8:   return "H2OC8";
        case dynamic_shape_t::H1OC8x2: return "H1OC8x2";
    }

    return "";
}

// Streams the textual name of a dynamic shape (see to_string) and returns the
// stream so that insertions can be chained.
static std::ostream &operator<<(std::ostream &os, const dynamic_shape_t &ds)
{
    return os << to_string(ds);
}

// Number of rows that the given dynamic shape corresponds to.
static constexpr uint32_t get_rows(dynamic_shape_t ds)
{
    switch (ds) {
        case dynamic_shape_t::H4OC4:   return 4;
        case dynamic_shape_t::H2OC8:   return 8;
        case dynamic_shape_t::H8OC2:   return 2;
        case dynamic_shape_t::H1OC8x2: return 16;
    }

    // Unrecognised shape: -1 converted to uint32_t, i.e. the largest 32-bit value.
    return static_cast<uint32_t>(-1);
}

// Number of columns that the given dynamic shape corresponds to.
static constexpr uint32_t get_cols(dynamic_shape_t ds)
{
    switch (ds) {
        case dynamic_shape_t::H4OC4:   return 4;
        case dynamic_shape_t::H2OC8:   return 2;
        case dynamic_shape_t::H8OC2:   return 8;
        case dynamic_shape_t::H1OC8x2: return 1;
    }

    // Unrecognised shape: -1 converted to uint32_t, i.e. the largest 32-bit value.
    return static_cast<uint32_t>(-1);
}

// Height split for the given dynamic shape.
// Height is currently split across the columns, so this is the column count.
static constexpr uint32_t get_h_split(dynamic_shape_t ds)
{
    const uint32_t columns = get_cols(ds);
    return columns;
}

// Output Channel split for the given dynamic shape.
// Output channels are currently split across the rows, so this is the row count.
static constexpr uint32_t get_oc_split(dynamic_shape_t ds)
{
    const uint32_t rows = get_rows(ds);
    return rows;
}


// Iteration counts describing how a single layer is tiled.
// Every count defaults to 1.
struct tiling_iterations {
    // Number of iterations over Height
    uint32_t H = 1;

    // Number of iterations over Width
    uint32_t W = 1;

    // Number of iterations over Input Channels
    uint32_t IC = 1;

    // Number of iterations over Output Channels
    uint32_t OC = 1;

    // Re-loading of Weights, using a sub-iteration of Input Channels
    uint32_t super_ifm = 1;

    // Re-loading of Weights, using a sub-iteration of Output Channels
    uint32_t super_ofm = 1;

    // Total number of times weights need to be reloaded from L3.
    constexpr uint32_t super_total() const
    {
        return super_ifm * super_ofm;
    }

    // Product of all non-super iteration counts (kernel_iter).
    // For layers with a convolution, IC and OC are independent and both multiply in.
    // For every other layer IC and OC must be equal and are counted only once;
    // a mismatch throws std::runtime_error.
    constexpr uint32_t total() const
    {
        const uint32_t spatial_iters = H * W;

        if (has_convolution(layer_mode)) {
            return spatial_iters * IC * OC;
        }

        if (IC != OC) {
            throw std::runtime_error("IC and OC should match");
        }

        return spatial_iters * IC;
    }

    // Sub-Iteration of Output Channels in Width: OC divided by super_ofm, rounded up.
    constexpr uint32_t sub_channel_iter() const
    {
        return ceil_div(OC, super_ofm);
    }

    // Equality over the six iteration counts, for ease of comparisons (e.g. testing).
    // NOTE(review): layer_mode is not part of the comparison — confirm that is intended.
    constexpr bool operator==(const tiling_iterations& other) const
    {
        const bool same_spatial  = (H == other.H) && (W == other.W);
        const bool same_channels = (IC == other.IC) && (OC == other.OC);
        const bool same_super    = (super_ifm == other.super_ifm) && (super_ofm == other.super_ofm);

        return same_spatial && same_channels && same_super;
    }

    // Layer type these iterations belong to; read by total().
    // NOTE(review): unlike the counts above this member has no default initializer.
    enum layer_mode layer_mode;
};

// Maps a tensor and its subvolume to the number of iterations needed along
// H, W and C, taking the height and channel splits into account.
//
//   H: ceil(ceil(tensor.H / height_split) / subvolume.H)
//   W: ceil(tensor.W / subvolume.W)
//   C: ceil(ceil(ceil(tensor.C / channel_split) / subvolume.C) / inner_c_split)
//
// The c_inner field of the result is always 0, because the inner kernel loop
// cannot be iterated over.
static constexpr tensor_HCWc get_tensor_iterator(
    const tensor_HCWc& tensor,
    const tensor_HCWc& subvolume,
    uint32_t height_split,
    uint32_t channel_split
)
{
    // Mostly 1, but some inner loops might have 2. Clamped so it is never 0.
    // TODO: when this is not constexpr, add checks that inner_c_split needs to be a perfect divisor
    const uint32_t c_inner_groups = tensor.c_inner / 8;
    const uint32_t inner_c_split  = (c_inner_groups == 0) ? 1 : c_inner_groups;

    // Height: share per split first, then count the subvolumes within that share.
    const uint32_t rows_per_split = ceil_div(tensor.H, height_split);
    uint32_t iter_h = ceil_div(rows_per_split, subvolume.H);

    // Width is never split, only tiled by the subvolume.
    uint32_t iter_w = ceil_div(tensor.W, subvolume.W);

    // Channels: share per split, then subvolumes, then the inner split.
    const uint32_t channels_per_split = ceil_div(tensor.C, channel_split);
    const uint32_t subvols_per_split  = ceil_div(channels_per_split, subvolume.C);
    uint32_t iter_c = ceil_div(subvols_per_split, inner_c_split);

    return tensor_HCWc{
        .c_inner = 0,  // N/A. Unable to iterate over the inner kernel loop.
        .W = iter_w,
        .C = iter_c,
        .H = iter_h
    };
}

// Derives the tiling iterations of a layer from its input/output dimensions and
// subvolumes. Super iterations are currently specified manually and are passed
// through unchanged.
//
//   ds                      dynamic shape, provides the base H and OC splits
//   ifm_dim / ifmsv_dim     input tensor dimensions and input subvolume
//   ofm_dim / ofmsv_dim     output tensor dimensions and output subvolume
//   super_ifm / super_ofm   copied into the result as-is
//   mode                    layer type; adjusts the splits and the channel handling
//   overlay_cfg             subarray count and the axis it multiplies into
static constexpr tiling_iterations compute_ifm_ofm_iterations(
    const dynamic_shape_t ds,
    const tensor_HCWc& ifm_dim,
    const tensor_HCWc& ifmsv_dim,
    const tensor_HCWc& ofm_dim,
    const tensor_HCWc& ofmsv_dim,
    uint32_t super_ifm = 1,
    uint32_t super_ofm = 1,
    layer_mode mode = layer_mode::CONV2D,
    overlay_choices overlay_cfg = {}
){
    const bool is_gapdwc = (mode == layer_mode::GAPDWC);

    // Base splits come from the dynamic shape.
    // Note: GAPDWC layer does not support H split in overlay arrays, so its
    // height split is folded into the OC split instead.
    uint32_t height_split  = get_h_split(ds);
    uint32_t channel_split = get_oc_split(ds);
    if (is_gapdwc) {
        channel_split = get_oc_split(ds) * get_h_split(ds);
        height_split  = 1;
    }

    // The subarray count scales only the axis chosen in the overlay configuration.
    if (overlay_cfg.axis == subarray_split::H) {
        height_split *= overlay_cfg.num_subarrays;
    }
    else if (overlay_cfg.axis == subarray_split::OC) {
        channel_split *= overlay_cfg.num_subarrays;
    }

    // The input is never split over channels; the output uses both splits.
    const tensor_HCWc ifm_iters = get_tensor_iterator(ifm_dim, ifmsv_dim, height_split, 1);
    const tensor_HCWc ofm_iters = get_tensor_iterator(ofm_dim, ofmsv_dim, height_split, channel_split);

    // For certain layer types, there is only one "channels"
    // So the in/out channel iterations must be the same.
    const bool single_channel_axis =
        mode == layer_mode::DWC ||
        is_gapdwc ||
        mode == layer_mode::MAXPOOL2D ||
        mode == layer_mode::AVGPOOL2D ||
        mode == layer_mode::RESIZE_NEAREST ||
        mode == layer_mode::RESIZE_BILINEAR;

    auto oc_iters = ofm_iters.C;
    auto w_iters  = ofm_iters.W;
    auto ic_iters = single_channel_axis ? ofm_iters.C : ifm_iters.C;

    // GAPDWC takes its height iterations from the input rather than the output.
    auto h_iters = is_gapdwc ? ifm_iters.H : ofm_iters.H;

    return tiling_iterations{
        .H         = h_iters,
        .W         = w_iters,
        .IC        = ic_iters,
        .OC        = oc_iters,
        .super_ifm = super_ifm,
        .super_ofm = super_ofm,
        .layer_mode = mode
    };
}
} // namespace adf_utils






