#pragma once

#include <ostream>
#include <stdexcept>
#include <string>

#include "conv/conv_xint8/kernel_setup/mllib_config.h"

/// @brief Convert a layer_mode enumerator to its upper-case name.
///
/// The returned text is exactly the enumerator's identifier (e.g.
/// layer_mode::CONV2D -> "CONV2D").
///
/// @param mode  Enumerator to name.
/// @return      Name of the enumerator.
/// @throws std::invalid_argument if @p mode is not one of the enumerators
///         listed in the switch. The message includes the enumerator's
///         integer value.
static std::string layer_mode_string(layer_mode mode)
{
    switch (mode) {
        case layer_mode::CONV2D:                    return "CONV2D";
        case layer_mode::CONV2D_ADD2D:              return "CONV2D_ADD2D";
        case layer_mode::CONV2D_ADD2D_GAP:          return "CONV2D_ADD2D_GAP";
        case layer_mode::FC:                        return "FC";
        case layer_mode::CONV2D_PAD2D_MAXPOOL:      return "CONV2D_PAD2D_MAXPOOL";
        case layer_mode::AVGPOOL2D_PREFUSED_CONV2D: return "AVGPOOL2D_PREFUSED_CONV2D";
        case layer_mode::AVGPOOL2D_CONV2D_GAP:      return "AVGPOOL2D_CONV2D_GAP";
        case layer_mode::MAXPOOL2D:                 return "MAXPOOL2D";
        case layer_mode::CONV2D_GAP:                return "CONV2D_GAP";
        case layer_mode::DWC:                       return "DWC";
        case layer_mode::CONVDWC:                   return "CONVDWC";
        case layer_mode::RESIZE_NEAREST:            return "RESIZE_NEAREST";
        case layer_mode::RESIZE_BILINEAR:           return "RESIZE_BILINEAR";
        case layer_mode::ELTWISE_MUL:               return "ELTWISE_MUL";
        case layer_mode::MUL_ADD2D:                 return "MUL_ADD2D";
        case layer_mode::GAP2D:                     return "GAP2D";
        case layer_mode::AVGPOOL2D:                 return "AVGPOOL2D";
        case layer_mode::SOFTMAX:                   return "SOFTMAX";
        case layer_mode::CONV2D_LEAKYRELU:          return "CONV2D_LEAKYRELU";
        case layer_mode::CONV2D_ADD2D_LEAKYRELU:    return "CONV2D_ADD2D_LEAKYRELU";
        case layer_mode::GAPDWC:                    return "GAPDWC";

        default:
            // Unlisted value (e.g. a new enumerator not yet added here, or an
            // out-of-range cast). Report its raw integer value.
            throw std::invalid_argument("Invalid value for layer_mode: " +
                                        std::to_string(static_cast<int>(mode)));
    }
}

/// @brief Report whether a layer mode is one of the ADD2D-fused convolution modes.
///
/// @param mode  Layer mode to classify.
/// @return true for CONV2D_ADD2D, CONV2D_ADD2D_GAP and CONV2D_ADD2D_LEAKYRELU;
///         false for every other mode.
static constexpr bool has_fused_elementwise(layer_mode mode)
{
    return mode == layer_mode::CONV2D_ADD2D
        || mode == layer_mode::CONV2D_ADD2D_GAP
        || mode == layer_mode::CONV2D_ADD2D_LEAKYRELU;
}

/// @brief Report whether a layer mode belongs to the convolution/FC group.
///
/// @param mode  Layer mode to classify.
/// @return true for CONV2D, CONV2D_ADD2D, CONV2D_ADD2D_GAP,
///         CONV2D_PAD2D_MAXPOOL, CONV2D_GAP, CONVDWC, CONV2D_LEAKYRELU,
///         CONV2D_ADD2D_LEAKYRELU and FC; false for every other mode.
///
/// NOTE(review): AVGPOOL2D_PREFUSED_CONV2D, AVGPOOL2D_CONV2D_GAP, DWC and
/// GAPDWC are reported as false even though their names mention a
/// convolution. Confirm this exclusion is intentional.
static constexpr bool has_convolution(layer_mode mode)
{
    return mode == layer_mode::FC
        || mode == layer_mode::CONV2D
        || mode == layer_mode::CONV2D_GAP
        || mode == layer_mode::CONV2D_LEAKYRELU
        || mode == layer_mode::CONV2D_PAD2D_MAXPOOL
        || mode == layer_mode::CONV2D_ADD2D
        || mode == layer_mode::CONV2D_ADD2D_GAP
        || mode == layer_mode::CONV2D_ADD2D_LEAKYRELU
        || mode == layer_mode::CONVDWC;
}

/// @brief Report whether a layer mode is GAPDWC.
///
/// @param mode  Layer mode to test.
/// @return true only when @p mode equals layer_mode::GAPDWC.
static constexpr bool is_gap_dwc(layer_mode mode)
{
    return mode == layer_mode::GAPDWC;
}


/// @brief Stream the name of a layer_mode, as produced by layer_mode_string().
///
/// @param os    Output stream to write to.
/// @param mode  Layer mode to print.
/// @return      @p os, to allow chaining.
/// @throws std::invalid_argument propagated from layer_mode_string() for an
///         unlisted enumerator value.
static std::ostream &operator<<(std::ostream &os, const layer_mode &mode)
{
    return os << layer_mode_string(mode);
}
