// Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.
////////////////////////////////////////////////////////////////////////

#pragma once

#include "conv_lcp.h"
#include "adf_utils_header.hpp"
#include <algorithm>
#include <numeric>
#include <cmath>
#include <cassert>
#include <array>
#include <optional>

#include <stdint.h>

// Build-time switch; defaults to 0 when the build does not define it.
// It is not referenced in this part of the file.
#ifndef AIE2P_USE_SHIFTX
#define AIE2P_USE_SHIFTX 0
#endif
// Architecture-dependent padding / alignment constants.
//
// - Without __TXNRT__: the values are selected by __AIE_ARCH__ (20 or 21).
// - With __TXNRT__:    the same values as the __AIE_ARCH__ == 21 branch are used.
//
// width_pad is subtracted from the IFM width in calculate_ox() for 3-wide
// kernels with a fused maxpool. extra_pad and align_ldst_bytes are not
// referenced in this part of the file.
//
// NOTE(review): when __TXNRT__ is undefined and __AIE_ARCH__ is neither 20 nor
// 21 (or is undefined), none of these constants are declared and
// calculate_ox() will fail to compile on its use of width_pad.
#ifndef __TXNRT__
#if __AIE_ARCH__ == 20
static constexpr uint32_t width_pad = 2;
static constexpr uint32_t extra_pad = 2;
static constexpr uint32_t align_ldst_bytes = 32;
#elif __AIE_ARCH__ == 21
static constexpr uint32_t width_pad = 6;
static constexpr uint32_t extra_pad = 2;
static constexpr uint32_t align_ldst_bytes = 64;
#endif
#else
static constexpr uint32_t width_pad = 6;
static constexpr uint32_t extra_pad = 2;
static constexpr uint32_t align_ldst_bytes = 64;

#endif

// Returns the OFM length (element count) reported to the kernel for one subvolume.
//
//   * CONVDWC       : always ofmsv_dim.size(); neither adjustment below applies.
//   * fused maxpool : ofmsv_dim describes the tensor *after* pooling, so the
//                     pre-pool (post-conv) extent is reconstructed from the pool
//                     stride. The +4 / +1 terms are a best-guess padding formula.
//   * FC            : the length is halved (the reason is not documented).
//
// kstride is part of the signature but is not used by the computation.
static uint32_t calculate_ofmlen(
    const conv_lcp_t &params_mllib,
    const adf_utils::tensor_HCWc& ofmsv_dim,
    const adf_utils::tensor_2d &kstride,
    const std::optional<maxpool_params_t> maxpool_params = std::nullopt

){
    // CONVDWC overrides every other rule, so settle it first.
    if (params_mllib.op_mode == layer_mode::CONVDWC) {
        return ofmsv_dim.size();
    }

    uint32_t element_count = ofmsv_dim.size();

    if (maxpool_params) {
        // Undo the pooling to recover the convolution's own output extent.
        const uint32_t pre_pool_w = (ofmsv_dim.W * maxpool_params->stride) + 4;
        const uint32_t pre_pool_h = (ofmsv_dim.H * maxpool_params->stride) + 1;

        element_count = ofmsv_dim.channels() * pre_pool_w * pre_pool_h;
    }

    if (params_mllib.op_mode == layer_mode::FC) {
        // Logic behind the halving is unknown.
        element_count /= 2;
    }

    return element_count;
}

// Computes the OFM width (ox) handed to the kernel, depending on the layer mode.
//
//   FC                : always 1.
//   CONVDWC           : ofmsv W * stride.x + fused DWC kernel width - stride.x,
//                       passed through adf_utils::round_up(..., 8).
//   no fused maxpool  : the OFM subvolume width as-is.
//   fused maxpool     : derived from the IFM subvolume width:
//       - c_inner < 8  : the first layer uses a reshape optimisation; the width
//                        is reduced by the ratio 8 / c_inner.
//       - ksize.x == 3 : width_pad is removed from the width.
//     and the result is (width - ksize.x + 1) / kstride.x.
//
// TODO: handle more special cases for other layers (fused / standalone DWC
//       without maxpool currently fall through to the plain OFM width).
static uint32_t calculate_ox(
    layer_mode op_mode,
    const adf_utils::tensor_HCWc& ifmsv_dim,
    const adf_utils::tensor_HCWc& ofmsv_dim,
    const adf_utils::tensor_2d &ksize,
    const adf_utils::tensor_2d &kstride,
    bool has_fused_maxpool,
    const fused_dw_params_t &fused_dw_params = {0}
)
{
    // TODO: add comment about why FC is always 1
    if (op_mode == layer_mode::FC) {
        return 1;
    }

    // TODO: verify handling for fused DWC
    if (op_mode == layer_mode::CONVDWC) {
        const int padded_width = adf_utils::round_up(ofmsv_dim.W* kstride.x + fused_dw_params.ksize.x - kstride.x, 8);
        return padded_width;
    }

    if (!has_fused_maxpool) {
        return ofmsv_dim.W;
    }

    // Fused maxpool: work backwards from the input width.
    uint32_t in_width = ifmsv_dim.W;

    if (ifmsv_dim.c_inner < 8) {
        // The first layer has a reshaped optimization
        const unsigned pack_ratio = 8 / ifmsv_dim.c_inner;
        in_width = ((in_width - 8) / pack_ratio) - 1;
    }
    else if (ksize.x == 3) {
        // Remove width pad
        in_width -= width_pad;
    }

    return (in_width - ksize.x + 1) / kstride.x;
}



// Computes the OFM height (oy) handed to the kernel, depending on the layer mode.
//
//   fused maxpool : (ifmsv H - ksize.y + kstride.y) / kstride.y.
//                   This takes precedence over the CONVDWC rule.
//   CONVDWC       : the height the fused depth-wise stage consumes, i.e.
//                   ofmsv H * kstride.y + fused DWC kernel height - kstride.y.
//   otherwise     : the OFM subvolume height as-is.
//
// The CONVDWC rule is the vertical counterpart of the one in calculate_ox(), so
// it uses the y components of the stride and of the fused kernel size. (It used
// to read the x components, which only gives the right height when stride and
// kernel are square.)
static uint32_t calculate_oy(
    layer_mode op_mode,
    const adf_utils::tensor_HCWc& ifmsv_dim,
    const adf_utils::tensor_HCWc& ofmsv_dim,
    const adf_utils::tensor_2d &ksize,
    const adf_utils::tensor_2d &kstride,
    bool has_fused_maxpool,
    const fused_dw_params_t &fused_dw_params = {0}
)
{
    if (has_fused_maxpool) {
        return (ifmsv_dim.H - ksize.y + kstride.y) / kstride.y;
    }

    if (op_mode == layer_mode::CONVDWC) {
        // Height axis: use the vertical stride and vertical kernel extent.
        return ofmsv_dim.H * kstride.y + fused_dw_params.ksize.y - kstride.y;
    }

    return ofmsv_dim.H;
}

// Returns the ky used by the convolution stage of a fused CONV+DWC layer.
//
// The value is currently fixed at 1; ifm_dim is accepted but not consulted.
//
// TODO: can this parameter come from the tiling_model? The padding of ky for a
// small channel count is disabled; the intended logic was:
//     if (ifm_dim.channels() < LayerConv::min_loop_iterations) {
//         if (ifm_dim.c_inner != 8) {
//             throw std::runtime_error("c_inner value assumed to be 8 for conv+dwc ksize derivation");
//         }
//         ky = adf_utils::ceil_div(8, ifm_dim.C);
//     }
// (an older variant: ky = C >= 8 ? 1 : C >= 4 ? 2 : 4)
static uint32_t calculate_fused_ksize(
    const adf_utils::tensor_HCWc& ifm_dim
)
{
    constexpr uint32_t unpadded_ky = 1;
    return unpadded_ky;
}

// Populates the padding, spatial-split and iteration fields of `params` for a
// fused CONV+DWC layer. Only `params` is written; every other argument is read.
//
//   params_mllib   - supplies the layer's ifm_padding (left/right/top/bottom).
//   iters          - supplies OC / H / W iteration counts.
//   ifm_tensor_dim - full IFM tensor; its W/H become pixel_w / pixel_h.
//   ifmsv_dim      - IFM subvolume; its W/H become sv_x / sv_y.
//   ofmsv_dim      - OFM subvolume; used for nifms and the DWC strides.
//   ksize, kstride - kernel size / stride used for the right-padding estimate.
//   ds             - passed to adf_utils::get_h_split() for the height split.
static void fused_padding_params(
    arch_params_lcp_t &params,
    const conv_lcp_t &params_mllib,
    const adf_utils::tiling_iterations &iters,
    const adf_utils::tensor_HCWc &ifm_tensor_dim,
    const adf_utils::tensor_HCWc &ifmsv_dim,
    const adf_utils::tensor_HCWc &ofmsv_dim,
    const adf_utils::tensor_2d &ksize,
    const adf_utils::tensor_2d &kstride,
    const adf_utils::dynamic_shape_t ds)
{
    const auto &layer_pad = params_mllib.ifm_padding;

    // --- subvolume / tensor geometry -------------------------------------
    // TODO: is this true, are these just the input sizes?
    params.sv_x    = ifmsv_dim.W;
    params.sv_y    = ifmsv_dim.H;
    params.nifms   = ofmsv_dim.channels();
    params.pixel_w = ifm_tensor_dim.W;
    params.pixel_h = ifm_tensor_dim.H;

    // --- spatial split ----------------------------------------------------
    // TODO: validate correct split selection
    params.spatial_split_w_ld = 0;
    params.spatial_split_h_ld = adf_utils::get_h_split(ds);

    // --- depth-wise strides -----------------------------------------------
    params.dwc_stride_w = ofmsv_dim.W * kstride.x;
    params.dwc_stride_h = ofmsv_dim.H * kstride.y;

    // --- padding ----------------------------------------------------------
    // Input width actually needed to produce ofmsv_dim.W outputs; whatever the
    // subvolume holds beyond that is added to the right padding.
    // TODO: validate pad right. A previous formulation was:
    //   fused_dw_params.ifm_padding.right + ifmsv_dim.W
    //     - LayerConv::unaligned_input_dim_calculation(ofmsv_dim.W, fused_dw_params.ksize.x, kstride.x)
    const auto needed_in_width = ofmsv_dim.W * kstride.x + ksize.x - kstride.x;

    params.pad_left   = layer_pad.left;
    params.pad_right  = layer_pad.right + ifmsv_dim.W - needed_in_width;
    params.pad_top    = layer_pad.top;
    params.pad_bottom = layer_pad.bottom;

    // --- iteration counts -------------------------------------------------
    params.ddr2mt_c = iters.OC;
    params.mt2aie_h = iters.H;
    params.mt2aie_w = iters.W;
}


static std::tuple<ReluType, ReluType> get_conv_and_fused_activation(const conv_lcp_t &params_mllib)
{
    // Check activation type support
    ReluType conv_activation = params_mllib.run_time_act;
    ReluType fused_activation = params_mllib.run_time_act_fused;
    {
        // Activation performed by convolution and fused operator
        // NoRelu: there is a Relu operator but it is not used.
        // None: there is not even a Relu operator available in the code
        const std::array valid{ReluType::NoRelu, ReluType::Relu, ReluType::Relu6, ReluType::Leaky_Prelu};
        for (ReluType r : {conv_activation, fused_activation}) {
            const bool unsupported = std::find(valid.begin(), valid.end(), r) == valid.end();
            if (unsupported) {
                throw std::runtime_error("Activation type in run_time_act(_fused) is not supported");
            }
        }
    }
    return {conv_activation, fused_activation};
}
// Dumps the kernel parameter block to stdout in a C-struct-like layout, one
// ".field = value," line per member, every value formatted with %d.
//
// `param` is only read. It is taken by non-const reference to keep the existing
// signature.
//
// Declared `inline` because the definition lives in a header: a non-inline
// definition with external linkage produces a multiple-definition error as soon
// as the header is included from more than one translation unit.
//
// NOTE(review): upshift_elw_ifm1, upshift_elw_ifm2 and downshift_eltw_res are
// populated by convert_to_arch_params() but are not part of this dump.
inline void print_kernel_params(arch_params_lcp_t &param)
{
    printf("kernel_params = { \n");
    printf("    .act_type_1            = %d, \n", param.act_type_1          );
    printf("    .tile_ocg              = %d, \n", param.tile_ocg            );
    printf("    .str_w                 = %d, \n", param.str_w               );
    printf("    .shift_bias_1          = %d, \n", param.shift_bias_1        );
    printf("    .shift_out_1           = %d, \n", param.shift_out_1         );
    printf("    .ifm_sign              = %d, \n", param.ifm_sign            );
    printf("    .shift_psum_in         = %d, \n", param.shift_psum_in       );
    printf("    .shift_psum_out        = %d, \n", param.shift_psum_out      );
    printf("    .shift_out16           = %d, \n", param.shift_out16         );
    printf("    .shift_leaky           = %d, \n", param.shift_leaky         );
    printf("    .leaky_alpha           = %d, \n", param.leaky_alpha         );
    printf("    .step_align            = %d, \n", param.step_align          );
    printf("    .shfl                  = %d, \n", param.shfl                );
    printf("    .shft                  = %d, \n", param.shft                );
    printf("    .incAI1                = %d, \n", param.incAI1              );
    printf("    .numAL1                = %d, \n", param.numAL1              );
    printf("    .incAL1                = %d, \n", param.incAL1              );
    printf("    .numAL2                = %d, \n", param.numAL2              );
    printf("    .incAL2                = %d, \n", param.incAL2              );
    printf("    .incAL3                = %d, \n", param.incAL3              );
    printf("    .numAO1                = %d, \n", param.numAO1              );
    printf("    .incAO1                = %d, \n", param.incAO1              );
    printf("    .numAO2                = %d, \n", param.numAO2              );
    printf("    .incAO2                = %d, \n", param.incAO2              );
    printf("    .incAO3                = %d, \n", param.incAO3              );
    printf("    .numB                  = %d, \n", param.numB                );
    printf("    .incB1                 = %d, \n", param.incB1               );
    printf("    .incB2                 = %d, \n", param.incB2               );
    printf("    .incS0                 = %d, \n", param.incS0               );
    printf("    .numCS1                = %d, \n", param.numCS1              );
    printf("    .incCS1                = %d, \n", param.incCS1              );
    printf("    .numCS2                = %d, \n", param.numCS2              );
    printf("    .incCS2                = %d, \n", param.incCS2              );
    printf("    .incCS3                = %d, \n", param.incCS3              );
    printf("    .inner_loop            = %d, \n", param.inner_loop          );
    printf("    .outer_loop            = %d, \n", param.outer_loop          );
    printf("    .psum0                 = %d, \n", param.psum0               );
    printf("    .psum1                 = %d, \n", param.psum1               );
    printf("    .conv_out              = %d, \n", param.conv_out            );
    printf("    .num_ifm_depth_iter    = %d, \n", param.num_ifm_depth_iter  );
    printf("    .conv_type             = %d, \n", param.conv_type           );
    printf("    .stride_bits           = %d, \n", param.stride_bits         );
    printf("    .hdr_len               = %d, \n", param.hdr_len             );
    printf("    .wts_offset            = %d, \n", param.wts_offset          );
    printf("    .wts_sv_len            = %d, \n", param.wts_sv_len          );
    printf("    .ofm_len               = %d, \n", param.ofm_len             );
    printf("    .num_iter              = %d, \n", param.num_iter            );
    printf("    .wrapper_iter          = %d, \n", param.wrapper_iter        );
    printf("    .out_mode              = %d, \n", param.out_mode            );
    printf("    .op_mode               = %d, \n", param.op_mode             );
    printf("    .stride2_exec_type     = %d, \n", param.stride2_exec_type   );
    printf("    .global_num_cols       = %d, \n", param.global_num_cols     );
    printf("    .ifm_len               = %d, \n", param.ifm_len             );
    printf("    .spatial_split_w_ld    = %d, \n", param.spatial_split_w_ld  );
    printf("    .spatial_split_h_ld    = %d, \n", param.spatial_split_h_ld  );
    printf("    .dwc_stride_w          = %d, \n", param.dwc_stride_w        );
    printf("    .dwc_stride_h          = %d, \n", param.dwc_stride_h        );
    printf("    .spacial_step_h        = %d, \n", param.spacial_step_h      );
    printf("    .pixel_h               = %d, \n", param.pixel_h             );
    printf("    .pixel_w               = %d, \n", param.pixel_w             );
    printf("    .pad_val               = %d, \n", param.pad_val             );
    printf("    .pad_top               = %d, \n", param.pad_top             );
    printf("    .pad_bottom            = %d, \n", param.pad_bottom          );
    printf("    .pad_left              = %d, \n", param.pad_left            );
    printf("    .pad_right             = %d, \n", param.pad_right           );
    printf("    .sv_x                  = %d, \n", param.sv_x                );
    printf("    .sv_y                  = %d, \n", param.sv_y                );
    printf("    .nifms                 = %d, \n", param.nifms               );
    printf("    .ddr2mt_c              = %d, \n", param.ddr2mt_c            );
    printf("    .mt2aie_h              = %d, \n", param.mt2aie_h            );
    printf("    .mt2aie_w              = %d, \n", param.mt2aie_w            );
    printf("} \n");
}

// Builds the arch_params_lcp_t kernel parameter block for one convolution layer.
//
// Parameters:
//   layer_idx          - not referenced in the body.
//   params_mllib       - layer description (op_mode, conv_type, shifts, activations, ...).
//   iters              - tiling iteration counts; H, W, IC and OC are read.
//   ifmsv_dim          - IFM subvolume dimensions.
//   ofmsv_dim          - OFM subvolume dimensions.
//   ksize, kstride     - filter size and stride. For layer_mode::CONVDWC the kernel
//                        parameters are instead derived from a 1 x calculate_fused_ksize()
//                        filter with stride 1x1.
//   maxpool_params     - when set, the fused-maxpool paths of calculate_ox/oy/ofmlen are used.
//   fused_dw_params    - forwarded to calculate_ox/calculate_oy; zero-initialised when absent.
//   ifm_tensor_dim     - not referenced (its only use is in the commented-out block below).
//
// Returns the populated parameter block.
//
// Throws std::runtime_error when a derived size or iteration count is zero, when the
// FC channel workaround does not produce its expected values, or when an activation
// type is not supported (via get_conv_and_fused_activation).
//
// Side effect: the result is printed with print_kernel_params() on every call.
static arch_params_lcp_t convert_to_arch_params(
    unsigned layer_idx,
    const conv_lcp_t &params_mllib,
    const adf_utils::tiling_iterations &iters,
    const adf_utils::tensor_HCWc &ifmsv_dim,
    const adf_utils::tensor_HCWc &ofmsv_dim,
    const adf_utils::tensor_2d &ksize,
    const adf_utils::tensor_2d &kstride,
    //const adf_utils::dynamic_shape_t &ds,
    std::optional<maxpool_params_t> maxpool_params = std::nullopt,
    std::optional<fused_dw_params_t> fused_dw_params = std::nullopt,
    std::optional<const adf_utils::tensor_HCWc> ifm_tensor_dim = std::nullopt)
{
    const bool has_fused_maxpool = maxpool_params.has_value();
    // Fall back to an all-zero fused DWC description when none was supplied.
    fused_dw_params_t dw_params = {0};
    dw_params = fused_dw_params.has_value() ? fused_dw_params.value() : dw_params;

    //calculating ofm_sizes through helpers here, needed for fused cases
    // Note: the helpers receive the caller's ksize/kstride, not the *_used overrides below.
    uint32_t ofm_xsize = calculate_ox(params_mllib.op_mode, ifmsv_dim, ofmsv_dim, ksize, kstride, has_fused_maxpool, dw_params);
    uint32_t ofm_ysize = calculate_oy(params_mllib.op_mode, ifmsv_dim, ofmsv_dim, ksize, kstride, has_fused_maxpool, dw_params);

    // For CONVDWC the kernel parameters use a 1 x ky filter with stride 1x1.
    adf_utils::tensor_2d ksize_used = ksize;
    if ( params_mllib.op_mode == layer_mode::CONVDWC ) {
        adf_utils::tensor_2d ksize_fused = { .x = 1, .y = calculate_fused_ksize(ifmsv_dim)};
        ksize_used = ksize_fused;
    }

    adf_utils::tensor_2d kstride_fused = { .x = 1, .y = 1};
    adf_utils::tensor_2d kstride_used = params_mllib.op_mode == layer_mode::CONVDWC ? kstride_fused  : kstride;

    constexpr unsigned int BIAS_MULTIPLE = 2;

    uint32_t height_iter = iters.H;
    uint32_t width_iter = iters.W;
    uint32_t depth_iter;

    // TODO: generalize this check for all non-convolution layers
    if (is_gap_dwc(params_mllib.op_mode))
        depth_iter = iters.H * iters.W;
    else if (has_convolution(params_mllib.op_mode))
        depth_iter = iters.IC;
    else
        depth_iter = 1;

    uint32_t channel_iter = iters.OC;
    uint32_t ofm_len = calculate_ofmlen(params_mllib, ofmsv_dim, kstride, maxpool_params);

    uint32_t ifm = ifmsv_dim.channels();
    ifm = ifm <= 8? 8 : ifm;    // Clamp to at least 8 (a lower bound, i.e. max(ifm, 8)). This may be an incorrect assumption, but it's my best guess for why layer0 is not 4.

    unsigned num_ofm_ch_scaled = ofmsv_dim.C;
    unsigned ifm_scaled = ifmsv_dim.C;
    if(params_mllib.op_mode == layer_mode::FC) {
        // I don't know the logic behind FC having different values :(
        num_ofm_ch_scaled /= 2;
        ifm_scaled *= 16;
        // what is this and does this need to mult by 8 now ?
        //ifm *= 8;
        ifm *= 4;

        // If you're hitting these errors - You probably are trying to run a different network
        // Remove them ONLY after reviewing the FC workaround above.
        unsigned check_ofm_fc = 2;
        unsigned check_ifm_fc = 16;
        if (num_ofm_ch_scaled != check_ofm_fc) throw std::runtime_error("NotWellImplemented: Calculation for FC OFM Channels.");
        if (ifm_scaled != check_ifm_fc)        throw std::runtime_error("NotWellImplemented: Calculation for FC IFM Channels.");
    }

    // Sanity checks. The iteration counts and lengths are unsigned, so "<= 0" is
    // equivalent to "== 0" for them.
    if (ofm_xsize == 0) throw std::runtime_error("ofm_xsize must be greater than zero");
    if (ofm_ysize == 0) throw std::runtime_error("ofm_ysize must be greater than zero");
    if (ksize_used.x < 1) throw std::runtime_error("filter x size must be greater than zero");
    if (ksize_used.y < 1) throw std::runtime_error("filter y size must be greater than zero");
    if (kstride_used.x < 1) throw std::runtime_error("filter x stride must be greater than zero");
    if (kstride_used.y < 1) throw std::runtime_error("filter y stride must be greater than zero");
    if (height_iter <= 0) throw std::runtime_error("height_iter must be greater than zero");
    if (width_iter <= 0) throw std::runtime_error("width_iter must be greater than zero");
    if (depth_iter <= 0) throw std::runtime_error("depth_iter must be greater than zero");
    if (channel_iter <= 0) throw std::runtime_error("channel_iter must be greater than zero");
    if (ofm_len <= 0) throw std::runtime_error("ofm_len must be greater than zero");
    if (ifm_scaled <= 0) throw std::runtime_error("ifm_scaled must be greater than zero");
    if (num_ofm_ch_scaled <= 0) throw std::runtime_error("num_ofm_ch_scaled must be greater than zero");

    arch_params_lcp_t param = {};


    //Keep underneath type of kernel params here
    uint16_t oyp         = ofm_ysize;
    uint32_t oxp         = ofm_xsize;
    uint16_t num_ofm_ch  = (num_ofm_ch_scaled) * 8;

    uint32_t fc_mode = (params_mllib.conv_type == conv_type::CONV2D_FC1) ||
                       (params_mllib.conv_type == conv_type::CONV2D_FC2);

    //kernel granularities are ic=8, oc=16
    // NOTE(review): tile_ocg is num_ofm_ch_scaled / 2, so a single 8-channel OFM group
    // yields 0 here (and numAO1/numB become -1) - confirm callers never pass that.
    uint16_t tile_icg = ifm_scaled;
    param.tile_ocg = num_ofm_ch_scaled / 2;

    //setting up helper parameters for creating kernel params
    // batch is fixed at 1, so every "batch >= 4" / "batch < 4" choice below is constant.
    const int batch = 1;
    int align_kx = (batch >= 4 ? 1 : 8 / batch);
    int oxg = fc_mode ? 2 : (oxp * kstride_used.x / 8);
    int oyg = oyp;
    int kx = ksize_used.x;
    int ky = ksize_used.y;
    int sx = kstride_used.x;
    int sy = kstride_used.y;

    //int conv_ifm_w = oxg * 8 + ((kx - 1 + align_kx - 1) / align_kx) * align_kx;
    
    // else case is added to enable dimensions with IFM W non-multiple of 8
    int conv_ifm_w;
    if(ifmsv_dim.W % 8 == 0)
        conv_ifm_w = oxg * 8 + ((kx - 1 + align_kx - 1) / align_kx) * align_kx;
    else
        conv_ifm_w = oxg * 8;
    
    int step_xn = 32;
    int step_so = (sx == 2 && batch < 4 ? 1 : 2) * step_xn;
    int step_si = (sx == 2 && batch >= 4 ? 2 : 1) * step_xn;
    int step_co = (oxg * 64 * batch) >> (sx == 2);

    int step_ci = fc_mode ? 32 : conv_ifm_w * 8 * batch;

    int step_yo = step_co * param.tile_ocg * 2;
    int step_yi = step_ci * tile_icg;

    int incAI_next, incAI_reset;
    //end of helper setup
    //setting up kernel internal parameters
    // NOTE(review): str_w is taken from the y stride - confirm this is intended.
    param.str_w = kstride_used.y;
    param.incAI1 = step_si;
    param.shfl = (sx == 2 && batch < 4 ? (batch == 1 ? 8 : 16) : 64);
    param.shft = (sx == 2 && batch < 4 ? param.shfl : 11);
    // incAI_next is assigned here but not read anywhere below.
    incAI_next = step_si;
    incAI_reset = -step_si;

    param.numAL1 = kx - 1;
    param.incAL1 = incAI_reset + 8 * batch;
    param.numAL2 = ky - 1;
    param.incAL2 = incAI_reset - 8 * batch * param.numAL1 + step_yi;
    param.incAL3 = incAI_reset - 8 * batch * param.numAL1 - step_yi * param.numAL2 + step_ci;

    //step_align is calculated differently when using fifo in conv
    // Non-FC: 0 if incAL1 is not a multiple of 16, 1 if a multiple of 16 but not 32, else 2.
    param.step_align = fc_mode ? 2 : (param.incAL1 & 15 ? 0 : (param.incAL1 & 31 ? 1 : 2));

    if (fc_mode)
        param.incAL3 = 0;

    param.numAO1 = param.tile_ocg - 1;
    param.incAO1 = -step_yi;
    param.numAO2 = oxg * batch - 1;
    param.incAO2 = param.incAO1 + 2 * step_si;
    param.incAO3 = step_yi * (sy - 1) - 2 * step_si * param.numAO2;

    param.incS0 = (sx == 2 && batch < 4 ? 0 : step_xn);

    // need to change something here ?
    // TODO: generalize this computation to take into account fused maxpool
    // The "32 * (conv_type == CONV2D_7x7S2_LYR1)" terms add/subtract 32 only for that layer type.
    param.numCS1 = param.numAO1;
    param.incCS1 = (step_co - step_so) * 1 + step_xn + 32 * (params_mllib.conv_type == conv_type::CONV2D_7x7S2_LYR1);
    param.numCS2 = param.numAO2;
    param.incCS2 = fc_mode ? -32 * tile_icg : (step_co - step_yo - 32 * (params_mllib.conv_type == conv_type::CONV2D_7x7S2_LYR1)) * 1 + step_xn;
    param.incCS3 = step_xn + 32 * (params_mllib.conv_type == conv_type::CONV2D_7x7S2_LYR1);

    param.numB = param.tile_ocg - 1;
    param.incB1 = 0;
    param.incB2 = -kx * ky * tile_icg * param.tile_ocg * 128;

    param.inner_loop = ky * kx * tile_icg;
    param.outer_loop = fc_mode ? 4*2 : oyg * oxg * param.tile_ocg;
    //end setting up kernel internal parameters

    //layer common parameters
    param.shift_bias_1   = params_mllib.shift_bias_init;
    param.shift_out_1    = params_mllib.shift_out;
    param.ifm_sign       = params_mllib.ifm_sign;
    param.shift_psum_in  = params_mllib.shift_psum_in;
    param.shift_psum_out = params_mllib.shift_psum_out;

    // Validates both activation types (throws if unsupported) and stores the conv one.
    // NOTE(review): act_type_1 is overwritten with run_time_act_fused further down,
    // so only the validation has a lasting effect here.
    std::tie(param.act_type_1, std::ignore) = get_conv_and_fused_activation(params_mllib);

    // Copy over some things required by conv2d for now - this will help us to completely get rid of conv2d_params_t later
    param.num_ifm_depth_iter = depth_iter;
    param.conv_type          = params_mllib.conv_type;
    param.out_mode           = params_mllib.out_mode;

    // std::log2 yields a floating-point value that is converted to the field's type on
    // assignment; it is exact when the stride is a power of two.
    param.stride_bits        = std::log2(kstride_used.x);
    param.hdr_len            = num_ofm_ch * BIAS_MULTIPLE;
    param.wts_offset         = 0;
    param.ofm_len            = ofm_len;
    param.wts_sv_len         = (ifm * num_ofm_ch) + (2 * 4 * num_ofm_ch);
    param.op_mode            = params_mllib.op_mode;
    param.conv_out           = 0;

    // Each iteration count is narrowed to 16 bits before the product is formed.
    param.num_iter           = uint16_t(width_iter) *
                               uint16_t(height_iter) *
                               uint16_t(channel_iter) *
                               uint16_t(depth_iter);
    bool fused_elew = params_mllib.op_mode == layer_mode::CONV2D_ADD2D || params_mllib.op_mode == layer_mode::CONV2D_ADD2D_GAP;
    param.wrapper_iter = (fused_elew && depth_iter > 13) ? 2 * depth_iter : param.num_iter;

    // Only consumed inside the STRIDE2_OPT block below.
    const bool is_first_layer_conv = (param.conv_type == conv_type::CONV2D_7x7S2_LYR1);

#ifdef STRIDE2_OPT
    assert(batch == 1);
    const unsigned stride2_kernel_select = 4;
    param.stride2_exec_type = (param.stride_bits && !is_first_layer_conv) ? stride2_kernel_select : 0;
#endif
    // populating the DWC + Conv fused parameters here
    // (currently disabled, so fused_padding_params() is never called from this function)
    //if (params_mllib.op_mode == layer_mode::CONVDWC ) {
    //    if ( !fused_dw_params.has_value( ) || !ifm_tensor_dim.has_value( ))
    //        throw std::runtime_error( "fused_dw_params and ifm_tensor_dim needs to be passed for fused CONV+DWC" );
    //if ( ifm_tensor_dim.has_value( ))
    //    fused_padding_params(param, params_mllib, iters, ifm_tensor_dim.value( ), ifmsv_dim, ofmsv_dim, ksize_used, kstride, ds);
    //}

    // Final activation / leaky-ReLU / element-wise shift fields.
    param.act_type_1 = params_mllib.run_time_act_fused;
    param.shift_out16 = params_mllib.shift_out16;
    param.shift_leaky = params_mllib.shift_leaky + params_mllib.shift_out;
    param.leaky_alpha = params_mllib.leaky_alpha;
    param.upshift_elw_ifm1 = params_mllib.upshift_fused;
    param.upshift_elw_ifm2 = params_mllib.upshift_nonfused;
    param.downshift_eltw_res = params_mllib.downshift_eltw_res;
    // Unconditional debug dump to stdout.
    print_kernel_params(param);

    return param;
}

// Convenience entry point: builds the kernel parameter block for a regular
// CONV2D + leaky-ReLU layer from plain integer dimensions.
//
//   Xis, Cis, Yis   - input subvolume width, channel count, height
//   Ky, Kx          - filter height, width
//   Xos, Cos, Yos   - output subvolume width, channel count, height
//   Sx, Sy          - stride in x, y
//   shift           - output shift (conv_lcp_t::shift_out)
//   elw_shift_*     - element-wise shifts; each is stored as (8 - value) cast to uint8_t
//   lrelu_alpha / lrelu_shift - leaky-ReLU alpha and shift
//
// Channel counts are expressed to the kernel in groups of 8 (Cis / 8, Cos / 8,
// integer division).
//
// NOTE(review): `iters` is handed to convert_to_arch_params() without any field
//   being set here, while that function reads H, W, IC and OC. Confirm that
//   adf_utils::tiling_iterations gives them usable defaults.
// NOTE(review): the stated kernel constraint is Kx * Ky * (Cis / 8) >= 8, but the
//   assertion below checks Kx * Ky * Cis >= 8 - confirm which one is intended.
// NOTE(review): this function has external linkage and is defined in a header;
//   including the header from several translation units will clash at link time.
arch_params_lcp_t compute_conv_kernel_params(int Xis, int Cis, int Yis, int Ky, int Kx, 
                int Xos, int Cos, int Yos, int Sx, int Sy, int8_t shift, int elw_shift_ifm1,
                int elw_shift_ifm2, int elw_shift_ofm, int16_t lrelu_alpha, int16_t lrelu_shift) {
    constexpr int channel_group = 8;      // minimum channel granularity of the kernel
    constexpr int min_inner_loop = 8;     // minimum inner loop range of the kernel

    assert((Ky*Kx*Cis) >= min_inner_loop);

    // Input subvolume, channels expressed in groups of channel_group.
    adf_utils::tensor_HCWc in_sv;
    in_sv.H = Yis;
    in_sv.C = Cis / channel_group;
    in_sv.W = Xis;
    in_sv.c_inner = channel_group;

    // Output subvolume; the kernel's outer loop iterates ox, then oy, then oc.
    adf_utils::tensor_HCWc out_sv;
    out_sv.H = Yos;
    out_sv.C = Cos / channel_group;
    out_sv.W = Xos;
    out_sv.c_inner = channel_group;

    adf_utils::tensor_2d kernel_size;
    kernel_size.x = Kx;
    kernel_size.y = Ky;

    adf_utils::tensor_2d kernel_stride;
    kernel_stride.x = Sx;
    kernel_stride.y = Sy;

    adf_utils::tiling_iterations iters;

    const conv_lcp_t layer_params = {
        .shift_out = shift,
        .out_mode = 0,
        .upshift_fused = static_cast <uint8_t>( 8 - elw_shift_ifm1),
        .upshift_nonfused =static_cast <uint8_t>( 8 - elw_shift_ifm2),
        .downshift_eltw_res =static_cast <uint8_t>( 8 - elw_shift_ofm),
        .op_mode = layer_mode::CONV2D_LEAKYRELU,
        .run_time_act = ReluType::Leaky_Prelu,
        .run_time_act_fused = ReluType::Leaky_Prelu,
        .conv_type = conv_type::CONV2D_REGULAR,
        .in0_sign = 0,
        .ifm_sign = 1,
        .ifm_padding = {
            .top = 0,
            .left = 0,
            .bottom = 0,
            .right = 0
        },
        .ofm_shift_biased = 0,
        .shift_alpha_lrelu = 0,
        .dwc_shift_out = 0,
        .shift_out16 = 0,
        .shift_leaky = lrelu_shift,
        .leaky_alpha = lrelu_alpha,

        .conv_after_gap = 0
    };

    return convert_to_arch_params(0, layer_params, iters, in_sv, out_sv, kernel_size, kernel_stride);
}


