// Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.
////////////////////////////////////////////////////////////////////////

#ifndef _MLLIB_CONFIG_H_
#define _MLLIB_CONFIG_H_

#include "conv/conv_xint8/mllib_const.h"

// Kernel-architecture selection. The sources that test these macros are not
// visible from this header.
//
// Alternative configuration (currently disabled): fast kernels plus an LR=8
// kernel after reaching layer 5.
//#define LAYER1_5
// Active configuration: use only LR=8 kernels.
#define ARCH_LR8
#define CONV_LR8_KERNEL    1
// Activation selector. LINEAR is not defined in this header
// (presumably it comes from mllib_const.h -- TODO confirm).
#define ACT_TYPE           LINEAR

// Element types for inputs, outputs and weights, plus the filter size.
// uint8 / int8 are not declared in this header (presumably provided by the
// toolchain or mllib_const.h -- TODO confirm).
// NOTE(review): the "act: ReLU" remarks below disagree with ACT_TYPE being set
// to LINEAR in this file -- confirm which one is current.
#define DTYPE_IN           uint8        // IFM, act: ReLU
#define DTYPE_IN2          int8         // Second IFM for add2d, produced by conv2d
#define DTYPE_OUT          uint8        // OFM, act: ReLU
#define DTYPE_WTS          int8         // WTS
#define FILTER_K           3            // WTS: 3x3x16x32

// Selects the adf::bpc_* type used for IFM, WTS, OFM and IFM2.
// ASYNC_ALL != 0: IFM and WTS use adf::bpc_async_1d.
// ASYNC_ALL == 0: IFM and WTS use adf::bpc_sync_1d.
// OFM and IFM2 always use adf::bpc_async_1d, regardless of ASYNC_ALL.
#define ASYNC_ALL          1
#if ASYNC_ALL
#define IFM_BPC_TYPE       adf::bpc_async_1d
#define WTS_BPC_TYPE       adf::bpc_async_1d
#else
#define IFM_BPC_TYPE       adf::bpc_sync_1d
#define WTS_BPC_TYPE       adf::bpc_sync_1d
#endif
#define OFM_BPC_TYPE       adf::bpc_async_1d
#define IFM2_BPC_TYPE      adf::bpc_async_1d

// TODO: these values are network-specific and belong with the network
// definition rather than in this shared header.
constexpr unsigned BATCH_SIZE     = 1U;
constexpr unsigned SM_BATCH_SIZE  = BATCH_SIZE;   // always tracks BATCH_SIZE
constexpr unsigned NUM_SM_CLASSES = 1001U;

// IFM reuse: when the IFM fits entirely in L1, it is reused across multiple OC iterations. This is handled by the
// wrapper and the ifm_reuse_iter field in super_kernel_params_t. When disabled (0), the DMA retransmits the IFM on
// every OC iteration. Currently enabled.
#define IFM_REUSE          1

// WTS reuse: when the WTS fit entirely in L1, they are reused across multiple iterations. This is handled by the
// wrapper and the wts_reuse_iter field in super_kernel_params_t. When disabled (0), the DMA retransmits the WTS on
// every iteration even if they fit in L1. Currently disabled.
#define WTS_REUSE          0

// Moved from Makefile: build options that are now fixed in this header.

// Presence-only flags (defined without a value).
#define CONV2D_OPT1
#define STRIDE2_OPT
// Valued flags; both are currently 0.
#define AIE2P_USE_SHIFTX   0
#define USE_OPMODE_1       0
#define SHIFT_IFM1         8            // Shift for IFM for add2d, from conv2d
#define ADD2D_SS_SIZE 2

// Short aliases for the input_cascade / output_cascade identifiers, which are
// not declared in this header.
#define IN_CASC input_cascade
#define OUT_CASC output_cascade

// Operation (or fused sequence of operations) performed by a layer.
// Stored as uint8_t. Values are assigned explicitly and are not contiguous:
// 1 is not used by any enumerator.
enum class layer_mode : uint8_t
{
    CONV2D                      = 0,    // Standalone Conv
    CONV2D_ADD2D                = 2,    // Conv + Elementwise Add Fused
    CONV2D_ADD2D_GAP            = 3,    // Conv + Elementwise Add + GlobalAvgPool Fused
    FC                          = 4,    // Conv in fully connected mode. TODO: can this be merged into conv_type?
    CONV2D_PAD2D_MAXPOOL        = 5,    // Conv + Pad2D + Maxpool Fused
    AVGPOOL2D_PREFUSED_CONV2D   = 6,    // Avgpool + Conv2d Fused. TODO: confirm this mode is actually used
    AVGPOOL2D_CONV2D_GAP        = 7,    // Avgpool + Conv2d + GlobalAvgPool Fused. TODO: confirm this mode is actually used
    MAXPOOL2D                   = 8,    // Standalone Maxpool
    CONV2D_GAP                  = 9,    // Conv + GlobalAvgPool Fused
    DWC                         = 10,   // Standalone Depthwise. NOTE(review): original comment trailed off with "needs" -- requirement unknown
    CONVDWC                     = 11,   // Conv + DWC Fused; needs three inputs
    RESIZE_NEAREST              = 12,   // Resize operation, nearest
    RESIZE_BILINEAR             = 13,   // Resize operation, bilinear
    ELTWISE_MUL                 = 14,   // Elementwise mul (vector x scalar)
    MUL_ADD2D                   = 15,   // Elementwise mul (vector x scalar) + add2d
    GAP2D                       = 16,   // Standalone GlobalAvgPool (used in early work of inceptionv4)
    AVGPOOL2D                   = 17,   // Standalone AvgPool (used in early work of inceptionv4)
    SOFTMAX                     = 18,   // Standalone softmax
    CONV2D_LEAKYRELU            = 19,   // Conv + LeakyRelu. NOTE(review): comment previously duplicated the ADD2D variant -- confirm no Add is fused
    CONV2D_ADD2D_LEAKYRELU      = 20,   // Conv + Elementwise Add + LeakyRelu
    GAPDWC                      = 21,   // GAP implemented as a DWC
};

// Convolution variant selector, stored as uint8_t.
// Values are assigned explicitly; 6 and 7 are not used by any enumerator.
// Per the notes below, several variants adjust parameters for a specific mode.
enum class conv_type : uint8_t
{
    CONV2D_REGULAR              = 0,            // Conv
    CONV2D_LAYER1               = 1,            // Conv
    CONV2D_FC1                  = 2,            // Manipulates params for FC mode
    CONV2D_7x7S2_LYR1           = 3,            // Manipulates params for Layer1 mode
    CONV2D_7x7S2_LYR1_EDGE      = 4,            // Manipulates params for Layer1 mode, EDGE
    CONV2D_1x1Fuse              = 5,            // Custom conv_type to describe Conv 1x1 + DWC
    CONV2D_FC2                  = 8,            // Manipulates params for FC mode
};

// Mode selector for resize, stored as uint8_t.
// NOTE(review): presumably the coordinate-mapping convention used by the
// RESIZE_NEAREST / RESIZE_BILINEAR layer modes -- confirm against the resize kernels.
enum class resize_mode : uint8_t
{
    half_pixel                = 0,
    asymmetric                = 1,
    symmetric                 = 2,
};

// This needs to be at the bottom of the file, as LCP types use macros defined above
//#include "conv_lcp.h"

#endif //_MLLIB_CONFIG_H_
