/*
    Copyright (C) 2019 - 2022 Xilinx, Inc. All rights reserved.
    Copyright (C) 2022 - 2024 Advanced Micro Devices, Inc. All rights reserved.

    This file contains confidential and proprietary information
    of Xilinx, Inc. and is protected under U.S. and
    international copyright and other intellectual property
    laws.

    DISCLAIMER
    This disclaimer is not a license and does not grant any
    rights to the materials distributed herewith. Except as
    otherwise provided in a valid license issued to you by
    Xilinx, and to the maximum extent permitted by applicable
    law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
    WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
    AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
    BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
    INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
    (2) Xilinx shall not be liable (whether in contract or tort,
    including negligence, or under any other theory of
    liability) for any loss or damage of any kind or nature
    related to, arising under or in connection with these
    materials, including for any direct, or any indirect,
    special, incidental, or consequential loss or damage
    (including loss of data, profits, goodwill, or any type of
    loss or damage suffered as a result of any action brought
    by a third party) even if such damage or loss was
    reasonably foreseeable or Xilinx had been advised of the
    possibility of the same.

    CRITICAL APPLICATIONS
    Xilinx products are not designed or intended to be fail-
    safe, or for use in any application requiring fail-safe
    performance, such as life-support or safety devices or
    systems, Class III medical devices, nuclear facilities,
    applications related to the deployment of airbags, or any
    other applications that could lead to death, personal
    injury, or severe property or environmental damage
    (individually and collectively, "Critical
    Applications"). Customer assumes the sole risk and
    liability of any use of Xilinx products in Critical
    Applications, subject only to applicable laws and
    regulations governing limitations on product liability.

    THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
    PART OF THIS FILE AT ALL TIMES.                       */

#ifndef __FP32_VMULT_H__
#define __FP32_VMULT_H__

#include "kernel_helpers.h"

/**
 * Element-wise product of two 16-lane accumulator-float vectors.
 *
 * Every 32-bit lane is split into its upper 9 bits (bfa/bfb) and its low
 * 23 bits (ma/mb). The upper parts are multiplied through the bfloat16
 * multiplier to give ob; the low parts are multiplied as integers and
 * shifted right by 23 to give mo. The result is accumulated as
 *     ( ob|ma ) - ob  +  ( ob|mb ) - ob  +  ( ob|mo )
 * which, with S = ob and fa, fb the fractions encoded by ma, mb, is intended
 * to equal S * ( 1 + fa + fb + fa*fb ), i.e. a * b.
 *
 * NOTE(review): this presumes every lane is a normal IEEE-754 binary32
 * pattern (implicit leading one); nothing here special-cases zero,
 * denormal, Inf or NaN lanes - confirm callers never pass them.
 *
 * @param a  first factor
 * @param b  second factor
 * @return   sub-vector 0 of the 32-lane working accumulator
 */
ALWAYS_INLINE v16accfloat fp_mul( v16accfloat a, v16accfloat b ) {
    v16int32 mask_m = broadcast_s32(( 1<<23 )-1 );             // low 23 bits of each lane
    v16int32 mask_e = bneg( mask_m );                          // complement: the upper 9 bits
    v16int32 bfa = band(( v16int32 ) a, mask_e );              // 1x MV
    v16int32 ma  = band(( v16int32 ) a, mask_m );              // 1x MV
    v16int32 bfb = band(( v16int32 ) b, mask_e );              // 1x MV
    v16int32 mb  = band(( v16int32 ) b, mask_m );              // 1x MV
    // Product of the upper-bit parts, converted back to bfloat16 and viewed as 32-bit lanes.
    v16int32 ob  = ( v16int32 ) to_v32bfloat16( mul_elem_32(( v32bfloat16 ) bfa, ( v32bfloat16 ) bfb ));                 //        1x VEC, 1x ST
    // Integer product of the two 23-bit fields, shifted back down by 23 bits.
    v16int32 mo  = lsrs( mul_elem_16( ma, mb ), 23 );        // 5x MV, 4x VEC, 1x ST  or  1x MV, 4x VEC, 5x ST  opt  3x MV, 2x VEC, 1x ST
    // Presumably gathers the upper 16-bit half of every ob lane so it can be
    // fed to the bfloat16 multiplier - TODO confirm against the shuffle mode table.
    v32bfloat16 ob16 = ( v32bfloat16 ) shuffle( ob, T16_32x2_hi );
    v32accfloat o;
    // o = ( ob|ma ) - ob16 * 1.0
    o =    msc_elem_32( ob16, broadcast_bfloat16( 1.0 ),    set_v32accfloat( 0, ( v16accfloat ) bor( ob, ma )));                         // 1x MV
    // o = o + ( ob|mb ) - ob16 * 1.0
    o = addmsc_elem_32( ob16, broadcast_bfloat16( 1.0 ), o, set_v32accfloat( 0, ( v16accfloat ) bor( ob, mb )));
    // o = o + ( ob|mo )
    o += set_v32accfloat( 0, ( v16accfloat ) bor( ob, mo ));                     // 3x MV, 1x VEC
    return extract_v16accfloat( o, 0 );
}

/**
 * Element-wise negated product of two 16-lane accumulator-float vectors.
 *
 * Same construction as the 16-lane fp_mul, except that the upper-bit parts
 * are combined with negmul_elem_32, so ob carries the negated product of the
 * sign/exponent parts. The accumulation
 *     ( ob|ma ) - ob  +  ( ob|mb ) - ob  +  ( ob|mo )
 * is therefore intended to yield -( a * b ).
 *
 * NOTE(review): this presumes every lane is a normal IEEE-754 binary32
 * pattern; zero, denormal, Inf and NaN lanes are not special-cased.
 * NOTE(review): bfa/bfb hold zero in the low 16 bits of every lane, so the
 * low bfloat16 slots of the product are 0 * 0. If negmul_elem_32 produces a
 * negative zero there, bit 15 of ob would be set and leak into the bor()
 * results below - confirm against the intrinsic's behaviour.
 *
 * @param a  first factor
 * @param b  second factor
 * @return   sub-vector 0 of the 32-lane working accumulator
 */
ALWAYS_INLINE v16accfloat fp_negmul( v16accfloat a, v16accfloat b ) {
    v16int32 mask_m = broadcast_s32(( 1<<23 )-1 );             // low 23 bits of each lane
    v16int32 mask_e = bneg( mask_m );                          // complement: the upper 9 bits
    v16int32 bfa = band(( v16int32 ) a, mask_e );              // 1x MV
    v16int32 ma  = band(( v16int32 ) a, mask_m );              // 1x MV
    v16int32 bfb = band(( v16int32 ) b, mask_e );              // 1x MV
    v16int32 mb  = band(( v16int32 ) b, mask_m );              // 1x MV
    // Negated product of the upper-bit parts, converted back to bfloat16 and viewed as 32-bit lanes.
    v16int32 ob  = ( v16int32 ) to_v32bfloat16( negmul_elem_32(( v32bfloat16 ) bfa, ( v32bfloat16 ) bfb ));                 //        1x VEC, 1x ST
    // Integer product of the two 23-bit fields, shifted back down by 23 bits.
    v16int32 mo  = lsrs( mul_elem_16( ma, mb ), 23 );        // 5x MV, 4x VEC, 1x ST  or  1x MV, 4x VEC, 5x ST  opt  3x MV, 2x VEC, 1x ST
    // Presumably gathers the upper 16-bit half of every ob lane so it can be
    // fed to the bfloat16 multiplier - TODO confirm against the shuffle mode table.
    v32bfloat16 ob16 = ( v32bfloat16 ) shuffle( ob, T16_32x2_hi );
    v32accfloat o;
    // o = ( ob|ma ) - ob16 * 1.0
    o =    msc_elem_32( ob16, broadcast_bfloat16( 1.0 ),    set_v32accfloat( 0, ( v16accfloat ) bor( ob, ma )));                         // 1x MV
    // o = o + ( ob|mb ) - ob16 * 1.0
    o = addmsc_elem_32( ob16, broadcast_bfloat16( 1.0 ), o, set_v32accfloat( 0, ( v16accfloat ) bor( ob, mb )));
    // o = o + ( ob|mo )
    o += set_v32accfloat( 0, ( v16accfloat ) bor( ob, mo ));                     // 3x MV, 1x VEC
    return extract_v16accfloat( o, 0 );
}

/**
 * Adds two accumulators to the negated element-wise product of a and b.
 *
 * The product is built as in the 16-lane fp_negmul (negmul_elem_32 on the
 * upper-bit parts, integer multiply of the low 23-bit parts), and acc1 and
 * acc2 are then added, in that order, to the running sum. The intended
 * result is acc1 + acc2 - a * b.
 *
 * NOTE(review): this presumes every lane of a and b is a normal IEEE-754
 * binary32 pattern; zero, denormal, Inf and NaN lanes are not special-cased.
 * NOTE(review): bfa/bfb hold zero in the low 16 bits of every lane, so the
 * low bfloat16 slots of the product are 0 * 0. If negmul_elem_32 produces a
 * negative zero there, bit 15 of ob would be set and leak into the bor()
 * results below - confirm against the intrinsic's behaviour.
 *
 * @param a     first factor
 * @param b     second factor
 * @param acc1  first addend
 * @param acc2  second addend
 * @return      sub-vector 0 of the 32-lane working accumulator
 */
ALWAYS_INLINE v16accfloat fp_addmsc( v16accfloat a, v16accfloat b, v16accfloat acc1, v16accfloat acc2 ) {
    v16int32 mask_m = broadcast_s32(( 1<<23 )-1 );             // low 23 bits of each lane
    v16int32 mask_e = bneg( mask_m );                          // complement: the upper 9 bits
    v16int32 bfa = band(( v16int32 ) a, mask_e );              // 1x MV
    v16int32 ma  = band(( v16int32 ) a, mask_m );              // 1x MV
    v16int32 bfb = band(( v16int32 ) b, mask_e );              // 1x MV
    v16int32 mb  = band(( v16int32 ) b, mask_m );              // 1x MV
    // Negated product of the upper-bit parts, converted back to bfloat16 and viewed as 32-bit lanes.
    v16int32 ob  = ( v16int32 ) to_v32bfloat16( negmul_elem_32(( v32bfloat16 ) bfa, ( v32bfloat16 ) bfb ));                 //        1x VEC, 1x ST
    // Integer product of the two 23-bit fields, shifted back down by 23 bits.
    v16int32 mo  = lsrs( mul_elem_16( ma, mb ), 23 );        // 5x MV, 4x VEC, 1x ST  or  1x MV, 4x VEC, 5x ST  opt  3x MV, 2x VEC, 1x ST
    // Presumably gathers the upper 16-bit half of every ob lane so it can be
    // fed to the bfloat16 multiplier - TODO confirm against the shuffle mode table.
    v32bfloat16 ob16 = ( v32bfloat16 ) shuffle( ob, T16_32x2_hi );
    v32accfloat o;
    // o = ( ob|ma ) - ob16 * 1.0
    o =    msc_elem_32( ob16, broadcast_bfloat16( 1.0 ),    set_v32accfloat( 0, ( v16accfloat ) bor( ob, ma )));                         // 1x MV
    // o = o + ( ob|mb ) - ob16 * 1.0
    o = addmsc_elem_32( ob16, broadcast_bfloat16( 1.0 ), o, set_v32accfloat( 0, ( v16accfloat ) bor( ob, mb )));
    // o = o + ( ob|mo )
    o += set_v32accfloat( 0, ( v16accfloat ) bor( ob, mo ));                     // 3x MV, 1x VEC
    // Fold in the caller's accumulators after the product terms.
    o += set_v32accfloat( 0, acc1 );
    o += set_v32accfloat( 0, acc2 );
    return extract_v16accfloat( o, 0 );
}

/**
 * Element-wise product of two 32-lane accumulator-float vectors.
 *
 * The operands are processed as two 16-lane halves. For each half the
 * upper 9 bits of every lane (bfa*, bfb*) are multiplied through the
 * bfloat16 multiplier to give ob*, and the low 23 bits (ma*, mb*) are
 * multiplied as integers. The integer product is assembled for all 32 lanes
 * at once from four 16-bit x 16-bit partial products and then shifted right
 * by 23 to give mo*. The result is accumulated as
 *     ( ob|ma ) - ob  +  ( ob|mb ) - ob  +  ( ob|mo )
 * which is intended to equal a * b.
 *
 * NOTE(review): this presumes every lane is a normal IEEE-754 binary32
 * pattern (implicit leading one); nothing here special-cases zero,
 * denormal, Inf or NaN lanes - confirm callers never pass them.
 *
 * @param a  first factor
 * @param b  second factor
 * @return   the 32-lane accumulated result
 */
ALWAYS_INLINE v32accfloat fp_mul( v32accfloat a, v32accfloat b ) {
    v16int32 mask_m = broadcast_s32(( 1<<23 )-1 );             // low 23 bits of each lane
    v16int32 mask_e = bneg( mask_m );                          // complement: the upper 9 bits

    // Raw bits of the two 16-lane halves of each operand.
    v16int32 a0 = ( v16int32 ) extract_v16accfloat( a, 0 );
    v16int32 a1 = ( v16int32 ) extract_v16accfloat( a, 1 );
    v16int32 b0 = ( v16int32 ) extract_v16accfloat( b, 0 );
    v16int32 b1 = ( v16int32 ) extract_v16accfloat( b, 1 );

    v16int32 bfa0 = band( a0, mask_e );
    v16int32 bfa1 = band( a1, mask_e );
    v16int32 ma0  = band( a0, mask_m );
    v16int32 ma1  = band( a1, mask_m );
    v16int32 bfb0 = band( b0, mask_e );
    v16int32 bfb1 = band( b1, mask_e );
    v16int32 mb0  = band( b0, mask_m );
    v16int32 mb1  = band( b1, mask_m );

    // Products of the upper-bit parts, converted back to bfloat16 and viewed as 32-bit lanes.
    v16int32 ob0  = ( v16int32 ) to_v32bfloat16( mul_elem_32(( v32bfloat16 ) bfa0, ( v32bfloat16 ) bfb0 ));
    v16int32 ob1  = ( v16int32 ) to_v32bfloat16( mul_elem_32(( v32bfloat16 ) bfa1, ( v32bfloat16 ) bfb1 ));

    // 16-bit halves of the 23-bit fields across all 32 lanes. The upper halves
    // hold at most 7 significant bits, so they are passed as signed; the lower
    // halves use all 16 bits and are passed as unsigned.
    // Presumably *_hi / *_lo select the upper / lower 16 bits of each lane -
    // TODO confirm against the shuffle mode table.
    v32int16  ma_hi = ( v32int16 )  shuffle( ma0, ma1, T16_32x2_hi );
    v32uint16 ma_lo = ( v32uint16 ) shuffle( ma0, ma1, T16_32x2_lo );
    v32int16  mb_hi = ( v32int16 )  shuffle( mb0, mb1, T16_32x2_hi );
    v32uint16 mb_lo = ( v32uint16 ) shuffle( mb0, mb1, T16_32x2_lo );

    // Mantissa product from four partial products: hi*hi, lo*hi, hi*lo, lo*lo.
    // NOTE(review): the `1` passed to the two _conf calls is presumably the
    // shift16 flag, moving the running sum up by 16 bits before the new
    // product is added - confirm against the intrinsic's signature.
    v32acc64 macc = mul_elem_32( ma_hi, mb_hi );
    macc = mac_elem_32_conf( ma_lo, mb_hi, macc, 0, 1, 0, 0 );
    macc = mac_elem_32( ma_hi, mb_lo, macc );
    macc = mac_elem_32_conf( ma_lo, mb_lo, macc, 0, 1, 0, 0 );
    v16int32 mo0  = lsrs( extract_v16acc64( macc, 0 ), 23 );
    v16int32 mo1  = lsrs( extract_v16acc64( macc, 1 ), 23 );

    // bfloat16 multiplicand taken from ob0/ob1 for the two subtract steps.
    v32bfloat16 ob16 = ( v32bfloat16 ) shuffle( ob0, ob1, T16_32x2_hi );
    v32accfloat o;
    // o = ( ob|ma ) - ob16 * 1.0
    o =    msc_elem_32( ob16, broadcast_to_v32bfloat16( 1.0 ),    insert( set_v32accfloat( 0, ( v16accfloat ) bor( ob0, ma0 )), 1, ( v16accfloat ) bor( ob1, ma1 )));
    // o = o + ( ob|mb ) - ob16 * 1.0
    o = addmsc_elem_32( ob16, broadcast_to_v32bfloat16( 1.0 ), o, insert( set_v32accfloat( 0, ( v16accfloat ) bor( ob0, mb0 )), 1, ( v16accfloat ) bor( ob1, mb1 )));
    // o = o + ( ob|mo )
    o += insert( set_v32accfloat( 0, ( v16accfloat ) bor( ob0, mo0 )), 1, ( v16accfloat ) bor( ob1, mo1 ));
    return o;
}
/**
 * Element-wise product of two 64-lane accumulator-float vectors.
 *
 * Each operand is split into its two 32-lane sub-vectors, the 32-lane
 * fp_mul overload is applied to the matching pairs, and the two results
 * are concatenated (sub-vector 0 first).
 *
 * @param a  first factor
 * @param b  second factor
 * @return   concatenation of the two 32-lane products
 */
ALWAYS_INLINE v64accfloat fp_mul( v64accfloat a, v64accfloat b ) {
    v32accfloat prod0 = fp_mul( extract_v32accfloat( a, 0 ), extract_v32accfloat( b, 0 ));
    v32accfloat prod1 = fp_mul( extract_v32accfloat( a, 1 ), extract_v32accfloat( b, 1 ));
    return concat( prod0, prod1 );
}


#endif //__FP32_VMULT_H__
