// Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.
////////////////////////////////////////////////////////////////////////

#pragma once

#include "conv_xint8/super_kernel_types.h"
#include "adf_utils_header.hpp"

// Declares, inside a struct body, the five members that make up a 3D iterator.
// For a prefix P the expansion is, in this order:
//   int32_t P_incr0; int32_t P_incr1; uint32_t P_wrap0; uint32_t P_wrap1; int32_t P_incr2;
// The expansion already ends with ';', so invoke the macro without a trailing one.
#define DECL_ITR3D(prefix)      \
    int32_t  prefix##_incr0;    \
    int32_t  prefix##_incr1;    \
    uint32_t prefix##_wrap0;    \
    uint32_t prefix##_wrap1;    \
    int32_t  prefix##_incr2;

// Declares, inside a struct body, the three members that make up a 2D iterator.
// For a prefix P the expansion is, in this order:
//   int32_t P_incr0; uint32_t P_wrap0; int32_t P_incr1;
// The expansion already ends with ';', so invoke the macro without a trailing one.
#define DECL_ITR2D(prefix)      \
    int32_t  prefix##_incr0;    \
    uint32_t prefix##_wrap0;    \
    int32_t  prefix##_incr1;

// Parameter block for the 2D max-pooling kernel (per the struct name).
// Member order defines the memory layout; do not reorder.
struct maxpool2d_lcp_t
{
    uint16_t outer_loop_count;

    // Expands to iterator_incr0/1, iterator_wrap0/1 and iterator_incr2.
    DECL_ITR3D(iterator)

    // Increments, stored as 16-bit signed values.
    int16_t  inc_Ky;
    int16_t  inc_A_0;

    // Eight shft_* values, stored as 8-bit signed values.
    int8_t   shft_0;
    int8_t   shft_1;
    int8_t   shft_2;
    int8_t   shft_3;
    int8_t   shft_4;
    int8_t   shft_5;
    int8_t   shft_6;
    int8_t   shft_7;

    // TODO: both shfl_* members were flagged "Check" by the original author.
    int8_t   shfl_0;
    int8_t   shfl_1;

    uint16_t offset;
    uint8_t  downshift_result;
};

// Parameter block for the 2D padding stage (per the struct name).
// All nine members are one byte wide, so the struct contains no padding.
// Member order defines the memory layout; do not reorder.
struct pad2d_lcp_t
{
    uint8_t sv_x;
    uint8_t sv_y;
    uint8_t nifms;
    uint8_t col_pad_mask;
    int8_t  pad_val;        // the only signed member

    // Per-side pad_* members.
    uint8_t pad_top;
    uint8_t pad_bot;
    uint8_t pad_left;
    uint8_t pad_right;
};

// Parameter block for the 2D multiply kernel (per the struct name).
// TODO: the parameter set is still a placeholder; the original author had not
// yet decided what belongs here beyond the multiplication factor.
struct mul2d_lcp_t
{
    uint32_t multiplication_factor;
};

// Parameter block for the 2D element-wise add kernel (per the struct name).
// Member order defines the memory layout; do not reorder.
struct add2d_lcp_t
{
    uint8_t zero_ifm2;      // Forces IFM2 to zero at the start of a cascade
    uint8_t shift_in;       // Bit shift for the lower-precision input while it is loaded into the accumulator
    uint8_t shift_out;      // Output shift applied before the accumulator is written
    uint8_t in0_sign;       // Signedness of input-0: 0 = unsigned, 1 = signed
    int     offset;         // Offset inside the IFM buffer from which input-0 is picked
    int     num_inputs;     // Tensor length in multiples of 32 (H*W*N*C/32)
    act_t   act;            // Activation: 0 = Linear, 1 = ReLU
    DECL_ITR3D(itr_left)    // Iterator for the IFM2 traversal (8-bit precision input)
    int     ofm_len;
    uint8_t upshift_fused;
    uint8_t upshift_nonfused;
    uint8_t downshift_eltw_res;
};

// Parameter block for the 2D global-average-pooling kernel (per the struct name).
// Member order defines the memory layout; do not reorder.
struct globalavgpool2d_lcp_t
{
    uint32_t    div_factor;
    uint16_t    ifm_x_eff;
    uint8_t     div_shift;

    // Cascade-mode RTP; the wrapper updates it after setup, based on the column.
    casc_mode_t casc_mode;

    uint16_t    step_Ci;
    uint16_t    step_Co;

    uint16_t    inner_loop_count;
    uint16_t    outer_loop_count;

    // Expands to iterator_inner_incr0, iterator_inner_wrap0 and iterator_inner_incr1.
    DECL_ITR2D(iterator_inner)
};

// Parameter block for the DWC kernel.
// NOTE(review): "dwc" presumably stands for depthwise convolution - confirm.
// Member order defines the memory layout; do not reorder.
struct dwc_lcp_t
{
    // ---- Layer parameters ----
    int32_t  kernel_type;
    alignas(int32_t) ReluType act_type;   // explicitly aligned like an int32_t
    int32_t  tile_ohg;
    int32_t  tile_owg;
    uint32_t iter_tile_ocg;
    int32_t  bias_offset;
    int32_t  ker_h;
    int32_t  ker_w;
    int32_t  str_w;
    int32_t  str_h;
    uint32_t hloop_wloop_itercnt;
    int32_t  ifm_offset;
    int32_t  shift_bias;
    int32_t  shift_cut;
    int32_t  ofm_offset;
    int32_t  wts_offset;
    int32_t  trans_offset;
    uint32_t transpose_dwc;

    int32_t  incAI1;
    int32_t  incAI2;
    int32_t  incAO1;
    int32_t  incAO2;
    // incAO3 is intentionally absent (it was commented out in the original layout).

    // ---- IFM addressing (add_3d_byte) ----
    int32_t  numAL1;
    int32_t  incAL1;
    int32_t  numAL2;
    int32_t  incAL2;
    int32_t  incAL3;

    // ---- Weight line jump ----
    int32_t  incBI;

    // ---- Weight addressing (add_2d_byte) ----
    int32_t  numB1;
    int32_t  incB1;
    int32_t  incB2;

    // ---- Bias addressing (add_2d_byte) ----
    int32_t  numC;
    int32_t  incC1;
    int32_t  incC2;

    // ---- OFM addressing (add_3d_byte) ----
    int32_t  numALO1;
    int32_t  incALO1;
    int32_t  numALO2;
    int32_t  incALO2;
    int32_t  incALO3;

    // ---- Loop counts and tensor dimensions ----
    int32_t  outer_loop;
    int32_t  inner_loop;
    int32_t  ofm_c;
    int32_t  ofm_w;
    int32_t  ofm_h;
    int32_t  ifm_w_align;
    int32_t  ifm_h;
    int32_t  ker_w_align;
    int32_t  tile_ocg;
    int32_t  iter_ocg;
    int32_t  h_loop;
    int32_t  w_loop;
    int32_t  iter_cnt;
    int32_t  ofm_size;

    // ---- Alternate increments used when the shared IFM differs ----
    int32_t  incALO2_h0;
    int32_t  incALO2_h1;

    bool     is_transpose;
};


// Parameter block for the 2D average-pooling kernel (per the struct name).
// Member order defines the memory layout; do not reorder.
struct avgpool2d_lcp_t
{
    // Storage type of div_factor (a type alias occupies no space in the layout).
    using DivFactorType = uint16_t;

    uint16_t      outer_loop_count;
    uint16_t      ofm_len;
    uint16_t      offset;
    int16_t       inc_Ky;
    int16_t       inc_A_0;

    // Expands to iterator_incr0/1, iterator_wrap0/1 and iterator_incr2.
    DECL_ITR3D(iterator)

    uint8_t       Kx_g;
    uint8_t       Ky_g;
    uint8_t       Ci_g;

    // Eight shft_* values, stored as 8-bit signed values.
    int8_t        shft_0;
    int8_t        shft_1;
    int8_t        shft_2;
    int8_t        shft_3;
    int8_t        shft_4;
    int8_t        shft_5;
    int8_t        shft_6;
    int8_t        shft_7;

    // TODO: both shfl_* members were flagged "Check" by the original author.
    int8_t        shfl_0;
    int8_t        shfl_1;

    uint8_t       nifms;

    // Fixed-point representation of 1/4 or 1/9, depending on a 2x2 or 3x3 filter.
    DivFactorType div_factor;
    uint8_t       div_shift;
};

// Parameter block for Nearest Neighbour (NN) resize.
// Member order defines the memory layout; do not reorder.
struct resize_nn_lcp_t
{
    float    step_w;
    float    step_h;

    // Offsets for partitioned interpolation; these might become RTPs.
    float    off_step_h;
    float    off_step_w;

    uint16_t ifm_ysize;
    uint16_t ifm_xsize;
    uint16_t ofm_ysize;
    uint16_t ofm_xsize;
    uint16_t nifms;
    uint16_t nbyte;
    uint16_t step_Yi;

    // NOTE(review): "conner" looks like a misspelling of "corner"; the member
    // name is left unchanged because it is part of the struct's interface.
    uint8_t  conner_align;
    uint8_t  upshift_result;
};

struct resize_bilinear_lcp_t
{
    uint16_t loop_count;
    uint16_t loop_count_x_pad;
    uint16_t loop_count_y_pad;
    int16_t  step_Ci;
    int16_t  step_Yi;
    int16_t  step_Yo;
    int16_t  incr_Yi;
    int16_t  step_to_last_Xi_eff;
    int16_t  step_to_last_Yi;
    int8_t   shift_pad_x_0;
    int8_t   shift_pad_x_1;
    int16_t  incr_pad_x;
    DECL_ITR2D(iterator_input);
    DECL_ITR2D(iterator_output);
};

// Parameter block for leaky ReLU.
// Per the original note, it corresponds to leakyrelu_kernel_params_t from
// conv/ox8/super_kernel_types.h.
// Member order defines the memory layout; do not reorder.
struct leakyrelu_lcp_t
{
    int8_t  shift_out;
    int8_t  shift_alpha;
    int16_t alpha;
};

// Alias that exposes super_kernel_params_t under the *_lcp_t naming used in this header.
// NOTE(review): super_kernel_params_t is presumably declared in
// conv_xint8/super_kernel_types.h - confirm.
using arch_params_lcp_t = super_kernel_params_t;

// Bundles the pad2d and maxpool parameter blocks; used as the conv2d member of
// op_dependent_params_lcp_t. Member order defines the memory layout.
struct dependent_conv2d_params_lcp_t {
    pad2d_lcp_t     pad2d;
    maxpool2d_lcp_t maxpool;
};

// Wraps the global-average-pool parameter block; used as the gap member of
// op_dependent_params_lcp_t.
// NOTE: the spelling "dependendent" is kept as-is because the type name is
// referenced by op_dependent_params_lcp_t.
struct dependendent_gap_params_lcp_t {
    globalavgpool2d_lcp_t gap;
};

// Operator-dependent parameters. As a union, the conv2d and gap views share
// the same storage, so only one of them holds meaningful data at a time.
union op_dependent_params_lcp_t {
    dependent_conv2d_params_lcp_t conv2d;
    dependendent_gap_params_lcp_t gap;
};

// Operator-dependent parameters for the "mnv3" variant. As a union, the dwc
// and gap views share the same storage, so only one is meaningful at a time.
// NOTE(review): "mnv3" presumably refers to MobileNetV3 - confirm.
union op_dependent_mnv3_params_lcp_t {
    dwc_lcp_t             dwc;
    globalavgpool2d_lcp_t gap;
};

// Parameter block for the convolution kernel (per the struct name).
// Member order defines the memory layout; do not reorder.
// NOTE(review): only some members carry a default member initializer; the rest
// are left uninitialized by default-initialization and must be set by the
// caller before they are read.
struct conv_lcp_t
{
    int8_t   shift_out;
    uint8_t  out_mode;
    uint8_t  upshift_fused = 0;
    uint8_t  upshift_nonfused = 0;
    uint8_t  downshift_eltw_res = 0;
    int8_t   shift_bias_init;
    int8_t   shift_psum_in;
    int8_t   shift_psum_out;
    int8_t   shift_lrelu_out;
    layer_mode op_mode;
    ReluType run_time_act = ReluType::NoRelu;
    ReluType run_time_act_fused = ReluType::NoRelu;
    // The elaborated specifier is required: the member shares its name with the enum type.
    enum conv_type conv_type;
    uint8_t  in0_sign;
    uint8_t  ofm_offset_packed;
    int      ifm_sign;      // IFM sign upper bits
    adf_utils::image_padding ifm_padding;
    int8_t   ofm_shift_biased;
    int8_t   shift_alpha_lrelu;
    int8_t   dwc_shift_out;
    int16_t  shift_out16;
    int16_t  shift_leaky;
    int16_t  leaky_alpha;
    int8_t   multiplication_factor = 0;
    int8_t   conv_after_gap = 0;
};

