/*
    Copyright (C) 2019 - 2022 Xilinx, Inc. All rights reserved.
    Copyright (C) 2022 - 2024 Advanced Micro Devices, Inc. All rights reserved.

    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.

    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.

    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.

    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */

#ifndef __CONVERSIONS_H__
#define __CONVERSIONS_H__



// Returns the float produced by as_float() from an int whose bits 23 and up
// hold 127 + 23 + shift and whose bit 22 is set. Read as an IEEE-754 single
// that pattern is 1.5 * 2^(23 + shift).
// NOTE(review): assumes as_float() reinterprets the bit pattern - confirm.
inline float get_magic( int shift ) {
    const int exponent_field   = ( 127 + 23 + shift ) << 23;
    const int top_mantissa_bit = 1 << 22;
    return as_float( exponent_field + top_mantissa_bit );
}



// to int4 ---------------------------------------------------




// to int8 ---------------------------------------------------

// Converts 64 float accumulator lanes to 8-bit integers.
// get_magic( -shift ) is broadcast and added to the input in the float domain;
// the sum and the magic value are then both cast to v64acc32 and subtracted,
// and the difference is narrowed with ssrs( ..., 0, sign ).
// NOTE(review): looks like the "magic number" float-to-int trick with a
// 2^shift scale folded in - confirm the casts are bit reinterpretations.
inline v64int8 to_v64int8( v64accfloat in, int shift, bool sign ) {
    v64accfloat magic   = broadcast_to_v64float( get_magic( -shift ));
    v64accfloat biased  = in + magic;
    v64acc32    residue = ( v64acc32 ) biased - ( v64acc32 ) magic;
    return ssrs( residue, 0, sign );
}
// Signed convenience overload: forwards to the three-argument version with
// sign set to true.
inline v64int8 to_v64int8( v64accfloat in, int shift ) {
    const bool is_signed = true;
    return to_v64int8( in, shift, is_signed );
}
// Unsigned variant: runs the three-argument to_v64int8 with sign set to false
// and casts the result to v64uint8.
inline v64uint8 to_v64uint8( v64accfloat in, int shift ) {
    const bool is_signed = false;
    v64int8 raw = to_v64int8( in, shift, is_signed );
    return ( v64uint8 ) raw;
}



// to int16 ---------------------------------------------------

// Converts 32 float accumulator lanes to 16-bit integers.
// get_magic( -shift ) is broadcast and added to the input in the float domain;
// the sum and the magic value are then both cast to v32acc32 and subtracted,
// and the difference is narrowed with lsrs( ..., 0, sign ).
// NOTE(review): assumes the accfloat <-> acc32 casts are bit reinterpretations.
inline v32int16 to_v32int16( v32accfloat in, int shift, bool sign ) {
    v32accfloat magic   = broadcast_to_v32float( get_magic( -shift ));
    v32accfloat biased  = in + magic;
    v32acc32    residue = ( v32acc32 ) biased - ( v32acc32 ) magic;
    return lsrs( residue, 0, sign );
}
// Signed convenience overload: forwards to the three-argument version with
// sign set to true.
inline v32int16 to_v32int16( v32accfloat in, int shift ) {
    const bool is_signed = true;
    return to_v32int16( in, shift, is_signed );
}
// Unsigned variant: runs the three-argument to_v32int16 with sign set to false
// and casts the result to v32uint16.
inline v32uint16 to_v32uint16( v32accfloat in, int shift ) {
    const bool is_signed = false;
    v32int16 raw = to_v32int16( in, shift, is_signed );
    return ( v32uint16 ) raw;
}



// Converts 64 float accumulator lanes to 16-bit integers.
// get_magic( -shift ) is broadcast and added to the input in the float domain;
// the sum and the magic value are then both cast to v64acc32 and subtracted,
// and the difference is narrowed with lsrs( ..., 0, sign ).
// NOTE(review): assumes the accfloat <-> acc32 casts are bit reinterpretations.
inline v64int16 to_v64int16( v64accfloat in, int shift, bool sign ) {
    v64accfloat magic   = broadcast_to_v64float( get_magic( -shift ));
    v64accfloat biased  = in + magic;
    v64acc32    residue = ( v64acc32 ) biased - ( v64acc32 ) magic;
    return lsrs( residue, 0, sign );
}
// Signed convenience overload: forwards to the three-argument version with
// sign set to true.
inline v64int16 to_v64int16( v64accfloat in, int shift ) {
    const bool is_signed = true;
    return to_v64int16( in, shift, is_signed );
}
// Unsigned variant: runs the three-argument to_v64int16 with sign set to false
// and casts the result to v64uint16.
inline v64uint16 to_v64uint16( v64accfloat in, int shift ) {
    const bool is_signed = false;
    v64int16 raw = to_v64int16( in, shift, is_signed );
    return ( v64uint16 ) raw;
}
// Same conversion as the value-returning to_v64int16, but the result is
// written through `ptr` as two v32int16 halves, each narrowed separately with
// lsrs( ..., 0, sign ).
// Store order: upper half (index 1) first, then lower half (index 0).
inline void to_v64int16( v64int16 * restrict ptr, v64accfloat in, int shift, bool sign ) {
    v64accfloat magic   = broadcast_to_v64float( get_magic( -shift ));
    v64accfloat biased  = in + magic;
    v64acc32    residue = ( v64acc32 ) biased - ( v64acc32 ) magic;
    v32int16 * halves = ( v32int16* ) ptr;
    halves[1] = lsrs( extract_v32acc32( residue, 1 ), 0, sign );
    halves[0] = lsrs( extract_v32acc32( residue, 0 ), 0, sign );
}



// to int32 ---------------------------------------------------
// Converts 32 float accumulator lanes to int32 by first narrowing them to
// bfloat16 and then running bfloat16_to_int on each 16-lane half.
// NOTE(review): going through bfloat16 presumably limits precision - confirm
// this is acceptable for callers.
inline v32int32 to_v32int32_via_bfloat16( v32accfloat in, int shift ) {
    v32bfloat16 narrowed = to_v32bfloat16( in );
    v16bfloat16 lower = extract_v16bfloat16( narrowed, 0 );
    v16bfloat16 upper = extract_v16bfloat16( narrowed, 1 );
    v16int32 lower_int = bfloat16_to_int( lower, shift );
    v16int32 upper_int = bfloat16_to_int( upper, shift );
    return concat( lower_int, upper_int );
}

// Converts 32 float accumulator lanes to int32 in two 16-bit steps, using two
// magic constants: get_magic( -shift + 16 ) for the high part and
// get_magic( -shift ) for the low part. The high part is extracted first,
// converted back to float and subtracted from the input; the remainder gives
// the low part. Both parts are recombined in a 64-bit accumulator.
// NOTE(review): assumes the accfloat <-> acc32 casts are bit reinterpretations.
inline v32int32 to_v32int32_via_int16( v32accfloat in, int shift ) {
    v32accfloat magic_lo = broadcast_to_v32float( get_magic( -shift ));
    v32accfloat magic_hi = broadcast_to_v32float( get_magic( -shift + 16 ));
    v32int16 one = broadcast_to_v32int16( 1 );

    v32accfloat af;
    v32acc32 ai;
    // high part: add the high magic in float, subtract its bit pattern as integer
    af = in + magic_hi;
    ai = ( v32acc32 ) af - ( v32acc32 ) magic_hi;

    v32int16 out_hi = lsrs( ai, 0 );

    // convert the high part back to float so it can be subtracted from the input
    ai = mac_elem_32_narrow( out_hi, one, ( v32acc32 ) magic_hi );
    af = ( v32accfloat ) ai - magic_hi;

    // subtract the already converted high part and convert the low part
    af = in - af + magic_lo;

    ai = ( v32acc32 ) af - ( v32acc32 ) magic_lo;

    // needs to be unsigned since we need one more bit because of asymmetric ranges;
    // the remainder is split into a positive part (ai) and a negative part (-ai)
    v32uint16 out_lp = ulsrs( ai, 0 );
    v32uint16 out_ln = ulsrs( -ai, 0 );

    // combine all parts to finalize the conversion: out_lp is paired with +1,
    // out_ln with -1, and the high part is pre-loaded via lups( out_hi, 16 )
    v32acc64 acc = mac_elem_32_2( concat( out_lp, out_ln ), concat( one, -one ), lups( out_hi, 16 ));
    return lsrs( acc, 0 );
}

// Default float -> int32 conversion for 32 lanes. Routed through the bfloat16
// path; to_v32int32_via_int16( in, shift ) is the alternative implementation
// with the same arguments.
inline v32int32 to_v32int32( v32accfloat in, int shift ) {
    v32int32 converted = to_v32int32_via_bfloat16( in, shift );
    return converted;
}


// Converts 64 float accumulator lanes to 32-bit integers (returned as
// v64acc32) by narrowing each 32-lane half to bfloat16 and running
// bfloat16_to_int on each 16-lane quarter.
// The quarters are taken in lane order - bf0[0], bf0[1], bf1[0], bf1[1] - so
// that output lane i corresponds to input lane i, matching the 32-lane
// to_v32int32_via_bfloat16.
// NOTE(review): going through bfloat16 presumably limits precision - confirm
// this is acceptable for callers.
inline v64acc32 to_v64int32_via_bfloat16( v64accfloat in, int shift ) {
    v32bfloat16 bf0 = to_v32bfloat16( extract_v32accfloat( in, 0 ));
    v32bfloat16 bf1 = to_v32bfloat16( extract_v32accfloat( in, 1 ));
    v16acc32 y0 = ( v16acc32 ) bfloat16_to_int( extract_v16bfloat16( bf0, 0 ), shift );
    v16acc32 y1 = ( v16acc32 ) bfloat16_to_int( extract_v16bfloat16( bf0, 1 ), shift );
    v16acc32 y2 = ( v16acc32 ) bfloat16_to_int( extract_v16bfloat16( bf1, 0 ), shift );
    v16acc32 y3 = ( v16acc32 ) bfloat16_to_int( extract_v16bfloat16( bf1, 1 ), shift );
    return concat( y0, y1, y2, y3 );
}

// Converts 64 float accumulator lanes to 32-bit integers (returned as
// v64acc32) in two 16-bit steps, using two magic constants:
// get_magic( -shift + 16 ) for the high part and get_magic( -shift ) for the
// low part. The high part is extracted first, converted back to float and
// subtracted from the input; the remainder gives the low part. The parts are
// recombined per 32-lane half in 64-bit accumulators.
// NOTE(review): assumes the accfloat <-> acc32 casts are bit reinterpretations.
inline v64acc32 to_v64int32_via_int16( v64accfloat in, int shift ) {
    v64accfloat magic_lo = broadcast_to_v64float( get_magic( -shift ));
    v64accfloat magic_hi = broadcast_to_v64float( get_magic( -shift + 16 ));
    v32int16 one = broadcast_to_v32int16( 1 );

    v64accfloat af;
    v64acc32 ai;
    // high part: add the high magic in float, subtract its bit pattern as integer
    af = in + magic_hi;
    ai = sub(( v64acc32 )af, ( v64acc32 ) magic_hi );

    v32int16 out0_hi = lsrs( extract_v32acc32( ai, 0 ), 0 );
    v32int16 out1_hi = lsrs( extract_v32acc32( ai, 1 ), 0 );

    // convert the high part back to float so it can be subtracted from the input
    ai = mac_elem_64( concat( out0_hi, out1_hi ), concat( one, one ), ( v64acc32 )magic_hi );
    af = ( v64accfloat )ai - magic_hi;

    // subtract the already converted high part and convert the low part
    af = in - af + magic_lo;

    ai = sub(( v64acc32 )af, ( v64acc32 )magic_lo );

    // needs to be unsigned since we need one more bit because of asymmetric ranges;
    // the remainder is split into a positive part (ai) and a negative part (-ai)
    v32uint16 out0_lp = ulsrs( extract_v32acc32( ai, 0 ), 0 );
    v32uint16 out1_lp = ulsrs( extract_v32acc32( ai, 1 ), 0 );
    v32uint16 out0_ln = ulsrs( extract_v32acc32( -ai, 0 ), 0 );
    v32uint16 out1_ln = ulsrs( extract_v32acc32( -ai, 1 ), 0 );

    // combine all parts to finalize the conversion: the positive part is paired
    // with +1, the negative part with -1, and the high part is pre-loaded via lups( ..., 16 )
    v32acc64 acc0 = mac_elem_32_2( concat( out0_lp, out0_ln ), concat( one, -one ), lups( out0_hi, 16 ));
    v32acc64 acc1 = mac_elem_32_2( concat( out1_lp, out1_ln ), concat( one, -one ), lups( out1_hi, 16 ));
    return concat(( v32acc32 ) lsrs( acc0, 0 ), ( v32acc32 ) lsrs( acc1, 0 ));
}

// Default float -> int32 conversion for 64 lanes (result returned as
// v64acc32). Routed through the bfloat16 path;
// to_v64int32_via_int16( in, shift ) is the alternative implementation with
// the same arguments.
inline v64acc32 to_v64int32( v64accfloat in, int shift ) {
    v64acc32 converted = to_v64int32_via_bfloat16( in, shift );
    return converted;
}



// to float32 ---------------------------------------------------
// Converts 64 int8 lanes to float accumulators through an 8x8 * 8x8T
// block-floating-point multiply.
// `identity` carries get_expo( -shift ) in both exponent slots and a
// mantissa vector built by sel() from zero and 0x40 with the mask
// 0x8040201008040201 (one bit set in every byte, stepping along the diagonal
// of an 8x8 bit grid). `data` carries the input bytes with get_expo( 6 ) in
// both exponent slots.
// NOTE(review): presumably `identity` acts as a scaled identity matrix so the
// product is the input scaled by a power of two depending on shift - confirm.
inline v64accfloat to_v64accfloat( v64int8 in, int shift, bool sign ) {
    v64bfp16ebs8 data, identity;
    // right-hand operand: exponents first, then the diagonal mantissas
    identity = insert( identity, 0, get_expo( -shift ));
    identity = insert( identity, 1, get_expo( -shift ));
    identity = insert( identity, sel( broadcast_zero_s8( ), broadcast_s8( 0x40 ), 0x8040201008040201ll ));
    // left-hand operand: exponents first, then the raw input as mantissas
    data = insert( data, 0, get_expo( 6 ));
    data = insert( data, 1, get_expo( 6 ));
    data = insert( data, in );
    return mul_8x8_8x8T( data, sign, identity, true );
}
// Signed convenience overload: forwards to the three-argument v64int8
// version with sign set to true.
inline v64accfloat to_v64accfloat( v64int8 in, int shift ) {
    const bool is_signed = true;
    return to_v64accfloat( in, shift, is_signed );
}
// Unsigned variant: casts the input to v64int8 and forwards to the
// three-argument version with sign set to false.
inline v64accfloat to_v64accfloat( v64uint8 in, int shift ) {
    const bool is_signed = false;
    v64int8 reinterpreted = ( v64int8 ) in;
    return to_v64accfloat( reinterpreted, shift, is_signed );
}


// Converts 64 int16 lanes to float accumulators using two 8x8 * 8x8T
// block-floating-point products.
// The two 32-lane halves of `in` are shuffled with T8_64x2_lo into bfp0
// (exponents get_expo( 6 )) and with T8_64x2_hi into bfp1 (exponents
// get_expo( 14 )). Both are multiplied by the same `eye` operand, whose
// exponents are get_expo( -shift ); bfp1 uses `sign`, bfp0 is always
// accumulated with sign = false.
// NOTE(review): presumably the lo/hi shuffles separate the low and high bytes
// of every int16 and `eye` is a scaled identity matrix - confirm.
inline v64accfloat to_v64accfloat_via_bfp16ebs8( v64int16 in, int shift, bool sign ) {
    v64bfp16ebs8 bfp0, bfp1, eye;
    bfp0 = insert( bfp0, 0, get_expo( 6 ));
    bfp0 = insert( bfp0, 1, get_expo( 6 ));
    bfp1 = insert( bfp1, 0, get_expo( 14 ));
    bfp1 = insert( bfp1, 1, get_expo( 14 ));
    bfp0 = insert( bfp0, ( v64int8 ) shuffle( extract_v32int16( in, 0 ), extract_v32int16( in, 1 ), T8_64x2_lo ));
    bfp1 = insert( bfp1, ( v64int8 ) shuffle( extract_v32int16( in, 0 ), extract_v32int16( in, 1 ), T8_64x2_hi ));
    eye  = insert( eye,  0, get_expo( -shift ));
    eye  = insert( eye,  1, get_expo( -shift ));
    // mantissas: 0x40 where the mask bit is set, zero elsewhere (one bit per byte of the mask)
    eye  = insert( eye, sel( broadcast_zero_s8( ), broadcast_s8( 0x40 ), 0x8040201008040201ll ));
    // bfp1 product carries the caller's sign; the bfp0 product is added as unsigned
    v64accfloat acc = mul_8x8_8x8T( bfp1, sign,  eye, true );
    return            mac_8x8_8x8T( bfp0, false, eye, true, acc );
}


// Converts 64 int16 lanes to float accumulators using two 8x16 * 16x8T
// block-floating-point products against sparse operands.
// The raw int16 data is inserted, cast to bytes, into `bfpy`, whose four
// exponent slots all hold get_expo( 6 - shift ). `eyeh` (exponents
// get_expo( 8 ), get_sparse( 10 )) is multiplied using `sign`; `eyel`
// (exponents get_expo( 0 ), get_sparse( 5 )) is accumulated with sign = false.
// NOTE(review): presumably eyel / eyeh select the low / high byte of every
// int16, with the high byte weighted by 2^8 - confirm against get_sparse.
inline v64accfloat to_v64accfloat_via_bfp16ebs16( v64int16 in, int shift, bool sign ) {
    v128bfp16ebs16 bfpy;
    v128bfp16ebs16_sparse eyel, eyeh;
    bfpy = insert( bfpy, 0, get_expo( 6 - shift ));
    bfpy = insert( bfpy, 1, get_expo( 6 - shift ));
    bfpy = insert( bfpy, 2, get_expo( 6 - shift ));
    bfpy = insert( bfpy, 3, get_expo( 6 - shift ));
    // input halves are inserted upper (index 1) first, then lower (index 0)
    bfpy = insert( bfpy, 1, ( v64int8 ) extract_v32int16( in, 1 ));
    bfpy = insert( bfpy, 0, ( v64int8 ) extract_v32int16( in, 0 ));
    eyel = insert( eyel, 0, get_expo( 0 ));
    eyel = insert( eyel, 1, get_expo( 0 ));
    eyel = insert( eyel, get_sparse( 5 ));
    // mantissas: 0x40 where the mask bit is set, zero elsewhere
    eyel = insert( eyel, sel( broadcast_zero_s8( ), broadcast_s8( 0x40 ), 0x8040201008040201ll ));
    eyeh = insert( eyeh, 0, get_expo( 8 ));
    eyeh = insert( eyeh, 1, get_expo( 8 ));
    eyeh = insert( eyeh, get_sparse( 10 ));
    eyeh = insert( eyeh, sel( broadcast_zero_s8( ), broadcast_s8( 0x40 ), 0x8040201008040201ll ));
    // eyeh product carries the caller's sign; the eyel product is added as unsigned
    v64accfloat acc = mul_8x16_16x8T( bfpy, sign,  eyeh, true );
    return            mac_8x16_16x8T( bfpy, false, eyel, true, acc );
}


// Converts 64 int16 lanes to float accumulators with the magic-number trick:
// the input is widened by sups( in, 0, sign ), added to the magic value cast
// to v64acc32, the sum is cast back to v64accfloat and the magic value is
// subtracted again in the float domain.
// NOTE(review): assumes the accfloat <-> acc32 casts are bit reinterpretations.
inline v64accfloat to_v64accfloat_via_magic( v64int16 in, int shift, bool sign ) {
    v64accfloat magic = broadcast_to_v64float( get_magic( -shift ));
    v64acc32 magic_bits = ( v64acc32 ) magic;
    v64accfloat biased = ( v64accfloat )( sups( in, 0, sign ) + magic_bits );
    return biased - magic;
}

// Default int16 -> float conversion; routed through the bfp16ebs16 path.
// to_v64accfloat_via_bfp16ebs8 and to_v64accfloat_via_magic take the same
// arguments and are the alternative implementations.
inline v64accfloat to_v64accfloat( v64int16 in, int shift, bool sign ) {
    v64accfloat converted = to_v64accfloat_via_bfp16ebs16( in, shift, sign );
    return converted;
}
// Signed convenience overload: forwards to the three-argument v64int16
// version with sign set to true.
inline v64accfloat to_v64accfloat( v64int16 in, int shift ) {
    const bool is_signed = true;
    return to_v64accfloat( in, shift, is_signed );
}
// Unsigned variant: casts the input to v64int16 and forwards to the
// three-argument version with sign set to false.
inline v64accfloat to_v64accfloat( v64uint16 in, int shift ) {
    const bool is_signed = false;
    v64int16 reinterpreted = ( v64int16 ) in;
    return to_v64accfloat( reinterpreted, shift, is_signed );
}


// Converts 64 int32 lanes (passed as four v16int32 vectors) to float
// accumulators using 8x16 * 16x8T block-floating-point products against
// sparse operands.
// bfpy0 holds the T8_64x2_lo shuffles of (in0, in1) and (in2, in3) with
// exponents get_expo( 6 - shift ); bfpy1 holds the T8_64x2_hi shuffles with
// exponents get_expo( 14 - shift ). Three sparse operands are used:
// eyel (expo 0, get_sparse( 5 )), eyeh (expo 16, get_sparse( 10 )) and
// eyeh2 (expo 6, get_sparse( 10 )). Only the products of bfpy1 with eyeh2 and
// eyeh use `sign`; all other products are unsigned.
// NOTE(review): presumably each product picks one byte of every int32 and the
// exponents supply that byte's weight - confirm against the shuffle modes and
// get_sparse.
inline v64accfloat to_v64accfloat_via_bfp16ebs16( v16int32 in0, v16int32 in1, v16int32 in2, v16int32 in3, int shift, bool sign ) {
    v128bfp16ebs16 bfpy0, bfpy1;
    v128bfp16ebs16_sparse eyel, eyeh, eyeh2;
    bfpy0 = insert( bfpy0, 0, get_expo( 6  - shift ));
    bfpy0 = insert( bfpy0, 1, get_expo( 6  - shift ));
    bfpy0 = insert( bfpy0, 2, get_expo( 6  - shift ));
    bfpy0 = insert( bfpy0, 3, get_expo( 6  - shift ));
    bfpy1 = insert( bfpy1, 0, get_expo( 14 - shift ));
    bfpy1 = insert( bfpy1, 1, get_expo( 14 - shift ));
    bfpy1 = insert( bfpy1, 2, get_expo( 14 - shift ));
    bfpy1 = insert( bfpy1, 3, get_expo( 14 - shift ));
    // data halves are inserted upper (index 1, from in2/in3) first, then lower (index 0, from in0/in1)
    bfpy0 = insert( bfpy0, 1, ( v64int8 ) shuffle( in2, in3, T8_64x2_lo ));
    bfpy0 = insert( bfpy0, 0, ( v64int8 ) shuffle( in0, in1, T8_64x2_lo ));
    bfpy1 = insert( bfpy1, 1, ( v64int8 ) shuffle( in2, in3, T8_64x2_hi ));
    bfpy1 = insert( bfpy1, 0, ( v64int8 ) shuffle( in0, in1, T8_64x2_hi ));
    // all three sparse operands share the same mantissas: 0x40 where the mask bit is set, zero elsewhere
    eyel  = insert( eyel,  0, get_expo( 0 ));
    eyel  = insert( eyel,  1, get_expo( 0 ));
    eyel  = insert( eyel,  get_sparse( 5 ));
    eyel  = insert( eyel,  sel( broadcast_zero_s8( ), broadcast_s8( 0x40 ), 0x8040201008040201ll ));
    eyeh  = insert( eyeh,  0, get_expo( 16 ));
    eyeh  = insert( eyeh,  1, get_expo( 16 ));
    eyeh  = insert( eyeh,  get_sparse( 10 ));
    eyeh  = insert( eyeh,  sel( broadcast_zero_s8( ), broadcast_s8( 0x40 ), 0x8040201008040201ll ));
    eyeh2 = insert( eyeh2, 0, get_expo( 6 ));
    eyeh2 = insert( eyeh2, 1, get_expo( 6 ));
    eyeh2 = insert( eyeh2, get_sparse( 10 ));
    eyeh2 = insert( eyeh2, sel( broadcast_zero_s8( ), broadcast_s8( 0x40 ), 0x8040201008040201ll ));
    v64accfloat acc, ac2, ach;
    // signed products of the hi-shuffled data
    ac2  = mul_8x16_16x8T( bfpy1, sign, eyeh2, true );
    ach  = mul_8x16_16x8T( bfpy1, sign, eyeh,  true );
    // NOTE(review): the trailing 0, 0, 0, 1 are configuration arguments whose meaning is not visible here - confirm
    ach  = addmac_8x16_16x8T_conf( bfpy0, false, eyeh, true, ach, ac2, 0, 0, 0, 1 );
    // unsigned products against eyel, combined with ac2
    acc  = mul_8x16_16x8T(    bfpy1, false, eyel, true );
    acc  = addmac_8x16_16x8T( bfpy0, false, eyel, true, acc, ac2 );
    acc += ach;
    return acc;
}


// Converts 64 int32 lanes (passed as four v16int32 vectors) to float
// accumulators with the magic-number trick, applied separately to the two
// 16-bit halves of each input.
// The T16_32x2_lo / T16_32x2_hi shuffles split the inputs into in_lo and
// in_hi. Each half is added (via mac_elem_64 with a vector of ones) to the
// v64acc32 cast of its own magic value: get_magic( -shift ) for in_lo and
// get_magic( -shift + 16 ) for in_hi. Both magic values are subtracted again
// in the float domain.
// in_lo is always accumulated as unsigned. in_hi is accumulated as signed
// when `sign` is true and as unsigned when it is false, so that unsigned
// inputs with bit 31 set are not treated as negative.
// NOTE(review): assumes the accfloat <-> acc32 casts are bit reinterpretations.
inline v64accfloat to_v64accfloat_via_magic( v16int32 in0, v16int32 in1, v16int32 in2, v16int32 in3, int shift, bool sign ) {
    v64accfloat magic_lo = broadcast_to_v64float( get_magic( -shift ));
    v64accfloat magic_hi = broadcast_to_v64float( get_magic( -shift + 16 ));
    v64uint16 in_lo;
    v64int16  in_hi;
    v64int16 one = broadcast_to_v64int16( 1 );
    // upper half (from in2/in3) is set first, lower half (from in0/in1) is inserted after
    in_lo = set_v64uint16( 1, ( v32uint16 ) shuffle( in2, in3, T16_32x2_lo ));
    in_lo = insert( in_lo, 0, ( v32uint16 ) shuffle( in0, in1, T16_32x2_lo ));
    in_hi = set_v64int16(  1, ( v32int16  ) shuffle( in2, in3, T16_32x2_hi ));
    in_hi = insert( in_hi, 0, ( v32int16  ) shuffle( in0, in1, T16_32x2_hi ));
    v64accfloat acc_lo = ( v64accfloat ) mac_elem_64( in_lo, one, ( v64acc32 ) magic_lo );
    // the signedness of the upper 16 bits follows the caller's `sign`
    v64acc32 hi_bits;
    if ( sign ) {
        hi_bits = mac_elem_64( in_hi, one, ( v64acc32 ) magic_hi );
    } else {
        hi_bits = mac_elem_64(( v64uint16 ) in_hi, one, ( v64acc32 ) magic_hi );
    }
    v64accfloat acc_hi = ( v64accfloat ) hi_bits;
    // remove both magic offsets; the sum of the two halves remains
    return acc_hi - ( magic_hi + magic_lo ) + acc_lo;
}

// Default int32 -> float conversion for four v16int32 inputs; routed through
// the magic-number path. to_v64accfloat_via_bfp16ebs16 takes the same
// arguments and is the alternative implementation.
inline v64accfloat to_v64accfloat( v16int32 in0, v16int32 in1, v16int32 in2, v16int32 in3, int shift, bool sign ) {
    v64accfloat converted = to_v64accfloat_via_magic( in0, in1, in2, in3, shift, sign );
    return converted;
}
// Signed convenience overload: forwards to the six-argument v16int32 version
// with sign set to true.
inline v64accfloat to_v64accfloat( v16int32 in0, v16int32 in1, v16int32 in2, v16int32 in3, int shift ) {
    const bool is_signed = true;
    return to_v64accfloat( in0, in1, in2, in3, shift, is_signed );
}
// Unsigned variant: casts each input to v16int32 and forwards to the
// six-argument version with sign set to false.
inline v64accfloat to_v64accfloat( v16uint32 in0, v16uint32 in1, v16uint32 in2, v16uint32 in3, int shift ) {
    const bool is_signed = false;
    v16int32 s0 = ( v16int32 )in0;
    v16int32 s1 = ( v16int32 )in1;
    v16int32 s2 = ( v16int32 )in2;
    v16int32 s3 = ( v16int32 )in3;
    return to_v64accfloat( s0, s1, s2, s3, shift, is_signed );
}


// Converts a v64bfp16ebs8 block-floating-point value to float accumulators by
// multiplying it (8x8 * 8x8T) with an operand whose exponents are
// get_expo( 0 ) and whose mantissas are 0x40 where the mask
// 0x8040201008040201 has a bit set and zero elsewhere.
// NOTE(review): presumably that operand is an identity matrix, so the product
// reproduces `in` - confirm.
inline v64accfloat to_v64accfloat( v64bfp16ebs8 in ) {
    v64bfp16ebs8 identity;
    identity = insert( identity, 0, get_expo( 0 ));
    identity = insert( identity, 1, get_expo( 0 ));
    identity = insert( identity, sel( broadcast_zero_s8( ), broadcast_s8( 0x40 ), 0x8040201008040201ll ));
    v64accfloat product = mul_8x8_8x8T( in, identity );
    return product;
}






#endif //__CONVERSIONS_H__
