# pylint: skip-file

# Copyright (C) 2019 - 2022 Xilinx, Inc. All rights reserved.
# Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.
#
# This file contains confidential and proprietary information
# of Xilinx, Inc. and is protected under U.S. and
# international copyright and other intellectual property
# laws.
#
# DISCLAIMER
# This disclaimer is not a license and does not grant any
# rights to the materials distributed herewith. Except as
# otherwise provided in a valid license issued to you by
# Xilinx, and to the maximum extent permitted by applicable
# law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
# WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
# AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
# BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
# INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
# (2) Xilinx shall not be liable (whether in contract or tort,
# including negligence, or under any other theory of
# liability) for any loss or damage of any kind or nature
# related to, arising under or in connection with these
# materials, including for any direct, or any indirect,
# special, incidental, or consequential loss or damage
# (including loss of data, profits, goodwill, or any type of
# loss or damage suffered as a result of any action brought
# by a third party) even if such damage or loss was
# reasonably foreseeable or Xilinx had been advised of the
# possibility of the same.
#
# CRITICAL APPLICATIONS
# Xilinx products are not designed or intended to be fail-
# safe, or for use in any application requiring fail-safe
# performance, such as life-support or safety devices or
# systems, Class III medical devices, nuclear facilities,
# applications related to the deployment of airbags, or any
# other applications that could lead to death, personal
# injury, or severe property or environmental damage
# (individually and collectively, "Critical
# Applications"). Customer assumes the sole risk and
# liability of any use of Xilinx products in Critical
# Applications, subject only to applicable laws and
# regulations governing limitations on product liability.
#
# THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
# PART OF THIS FILE AT ALL TIMES.

import numpy as np
import numbers


class DataConverter:
    def __init__( self, rounding=True, bfloat=False, filter_denorm=True, fp8_extend_saturation=False, print_info=False ):
        """Store the conversion defaults consulted by the other methods.

        rounding              -- default rounding mode; resolved by check_rounding_mode
        bfloat                -- default used when a method receives bfloat=None
        filter_denorm         -- default used by f2bf when it receives filter_denorm=None
        fp8_extend_saturation -- read by f2bf for extended exponent formats
        print_info            -- stored only if the instance has no `print_info`
                                 attribute yet (presumably so that a subclass or
                                 class attribute can preset it)
        """
        self.rounding, self.bfloat = rounding, bfloat
        self.filter_denorm, self.fp8_extend_saturation = filter_denorm, fp8_extend_saturation
        try:
            self.print_info
        except AttributeError:
            # nobody provided the attribute so far: fall back to the argument
            self.print_info = print_info

    def extend_to_multiple_of(self,arr,N):
        # Calculate the number of elements to add
        padding = (N - len(arr) % N)
        # Extend the array with zeros
        extended_arr = np.pad(arr, (0, padding), mode='constant', constant_values=0)
        return extended_arr

    def reinterpret_cast( self, data, bits_in, bits_out, sgn_out=False, bfloat_in=None, expo_bits_in=8 ):
        if bfloat_in or ( bfloat_in is None and self.bfloat ):
            data = self.collapse_float( data, bits_in, expo_bits_in )
        if bits_in < bits_out:
            data[data<0] += 2**bits_in
            scale_factor = bits_out//bits_in
            if(data.size % scale_factor):
                data = self.extend_to_multiple_of(data,scale_factor)
            data = np.sum( [data[..., p::scale_factor] * 2**( bits_in*p ) for p in range( scale_factor )], 0 )
        elif bits_in > bits_out:
            tmp  = np.zeros( data.shape + ( bits_in//bits_out, ), dtype=np.int_ )
            for p in range( bits_in//bits_out ):
                tmp[..., p] = ( data >> ( bits_out*p )) & ( 2**bits_out-1 )
            data = tmp.reshape( data.shape[:-1] + ( -1, ))
        else:
            data &= 2**bits_out-1
        if sgn_out: data[data>=( 2**( bits_out-1 ))] -= 2**bits_out
        return data

    def pack_shifts( self, shifts, dim, shift_bits=1, subtile_size=2 ):
        shifts_per_byte = 8 // shift_bits
        shape = list( shifts.shape )
        shape[dim+1] = shifts_per_byte
        shape.insert( dim+1, shifts.shape[dim+1] // shifts_per_byte )
        shifts = shifts.reshape( shape )
        def pack( shifts ):
            tmp=0
            for shift in shifts :
                tmp = tmp >> 1 | shift << 7  # shifts in memory are stored in reversed order
            return tmp - 256 * ( tmp >= 128 )
        return  np.apply_along_axis( pack, dim+2, shifts )


    def check_rounding_mode( self, rounding=None, bfloat=None ):
        """Resolve a rounding request into the mode handed to shift_round.

        rounding -- None selects self.rounding; a value equal to True becomes
                    'half_inf' (bfloat) or 'sym_inf' (otherwise); a value equal
                    to False becomes 'sym_floor' for bfloat and is returned
                    unchanged otherwise; anything else is passed through
        bfloat   -- None selects self.bfloat
        """
        def wants_bfloat( ):
            # evaluated lazily, only when the answer is actually needed
            return bfloat or ( bfloat is None and self.bfloat )

        mode = self.rounding if rounding is None else rounding
        if mode == True:
            return 'half_inf' if wants_bfloat( ) else 'sym_inf'
        if mode == False and wants_bfloat( ):
            return 'sym_floor'
        return mode


    def shift_round( self, data, shift, rounding=None, bfloat=None ):
        rounding = self.check_rounding_mode( rounding, bfloat )
        dtype = np.dtype( data.dtype, copy=True )
        data  = data.astype( np.long )
        bits  = dtype.itemsize * 8
        if isinstance( shift, np.ndarray ):
            shift = shift.astype( np.long )
            mask = ( shift >= 63 ) | ( shift <= -bits )
            if np.any( mask ):
                data[ mask] = 0
                shift[mask] = 0
        elif shift >= 63 or shift <= -bits:
            return np.zeros( data.shape, dtype=dtype )
        rshift = np.maximum( 0,  shift )
        lshift = np.maximum( 0, -shift )
        if rounding == 'even':
            data += (( 1 << ( rshift - 1 )) - 1 + (( data >> rshift ) & 1 )) * ( rshift > 0 )
        elif rounding == 'sym_inf':
            data += (( 1 << ( rshift - 1 )) - ( np.sign( data ) < 0 )) * ( rshift > 0 )
        elif rounding == 'sym_zero':
            data += (( 1 << ( rshift - 1 )) - ( np.sign( data ) > 0 )) * ( rshift > 0 )
        elif rounding == 'to_odd':
            data |= (( data & (( 1 << rshift ) - 1 )) != 0 ) << rshift
        elif rounding == 'floor':
            data += (( 1 << rshift ) - 1 )  * ( np.sign( data ) < 0 )  * ( rshift > 0 )
        elif rounding == 'ceil':
            data +=  (( 1 <<  rshift ) - 1 ) * ( np.sign( data ) > 0 )  * ( rshift > 0 )
        elif rounding == 'sym_floor':
            data = data
        elif rounding:
            data +=  ( 1 << ( rshift - 1 )) * ( rshift > 0 )
        if np.any( rshift ): data >>= rshift
        if np.any( lshift ): data <<= lshift
        return data.astype( dtype )


    def srs( self, mat, shift, bits, sgn=True, bfloat=None, ebs=None, expo_bits=8, as_float=False, float_to_int=True, rounding = None, filter_denorm=None ):
        if ebs is not None:
            return self.f2bfp( mat, bits, ebs[0], ebs[1], bfloat=bfloat, expo_bits=expo_bits, as_float=as_float )
        elif bfloat or ( bfloat is None and self.bfloat ):
            return self.f2bf( mat, bits, expo_bits, rounding = rounding, filter_denorm=filter_denorm )
        if float_to_int and not issubclass( mat.dtype.type, numbers.Integral ):
            mat = ( mat * 2**shift ).astype( np.float32 ).getfield( np.int32 )
            exp = ( mat >> 23 ) & 255 
            normal = np.logical_not(exp == 0)
            shift = np.maximum( 0, 23 - (exp - 127 ))
            mant = (( 1 << 23 )*normal + ( mat & 0x7FFFFF )) * np.sign( mat )
            rnd  = self.shift_round( mant, shift ) << shift
            mat  = ( mat + ( rnd - mant ) * np.sign( mat )).getfield( np.float32 )
            #mat = np.round( mat * 2**shift ).astype( np.long )
            shift = 0

        low, high = -2**( bits-1 )*int( sgn ), 2**( bits-int( sgn ))-1
        mat = self.shift_round( mat.astype( np.long ), shift )
        if self.print_info:
          large_value_threshold = int( self.large_value_threshold * high ) if type( self.large_value_threshold ) is float else self.large_value_threshold
          small = np.sum( mat<self.small_value_threshold )
          large = np.sum( mat>large_value_threshold )
          if small > mat.size/10:
            self.log_msg.append( "[INFO]: Reference has {} % values smaller than {}".format( 100.0*small/mat.size, self.small_value_threshold ))
          if large > mat.size/10:
            self.log_msg.append( "[INFO]: Reference has {} % values larger than {}".format( 100.0*large/mat.size, large_value_threshold ))
        if not isinstance( mat, np.ndarray ): mat = np.array( mat )
        mat[mat<low]  = low
        mat[mat>high] = high
        return mat


    def ups( self, mat, shift, bits=32, sgn=True, bfloat=None, ebs=None ):
        if ebs is not None and len( mat )==2:
            return self.bfp2f( mat, bits, ebs[0], ebs[1], bfloat=bfloat )
        elif bfloat or ( bfloat is None and self.bfloat ):
            return mat
        mat = mat.astype( np.long ) << shift
        if bits < 64:
            low, high = -2**( bits-1 )*int( sgn ), 2**( bits-int( sgn ))-1
            mat[mat<low]  = low
            mat[mat>high] = high
        return mat

    def truncate( self, mat, bits, sgn=True, bfloat=None, ebs=1 ):
        if bfloat or ( bfloat is None and self.bfloat ): return mat.astype( np.float32 )
        elif (ebs > 1): return mat.astype( np.float32 )
        elif ( mat.dtype == np.long and bits==64 ): return mat
        mat = mat & ( 2**bits-1 )
        if sgn: mat[mat >= 2**( bits-1 )] -= 2**bits
        return mat



    def f2bf( self, f, bits=16, expo_bits=8, rounding=None, filter_denorm=None ):
        """Round float data to a reduced precision float format, kept as float32.

        f             -- scalar, sequence or ndarray; float64 is converted to
                         float32, other dtypes are rejected by an assertion
        bits          -- total number of bits of the target format
        expo_bits     -- exponent description understood by decode_expo_bits
        rounding      -- rounding mode, resolved with check_rounding_mode( bfloat=True )
        filter_denorm -- flush values below the normal range of the target to
                         signed zero; None selects self.filter_denorm
        Returns float32 data whose bit patterns were rounded on the float32
        representation (shift_round on the int32 view).

        For formats with fewer than 8 exponent bits, values beyond the target's
        exponent range are replaced according to the rounding mode, e.g. 'floor'
        maps negative overflow to -infinity and positive overflow to the largest
        representable value.
        """
        if not isinstance( f, ( np.ndarray, np.number )):
            f = np.array( f, dtype=np.float32 )
        elif f.dtype == np.float64:
            f = f.astype( np.float32 )
        assert f.dtype == np.float32, "type {} not supported for f2bf".format( f.dtype )
        rounding = self.check_rounding_mode( rounding, bfloat=True )
        expo_bits, expo_bias, extend = decode_expo_bits( expo_bits )
        # width of the format when it is embedded into a word with 8 exponent bits
        bits_fp = bits + 8 - expo_bits
        bi = f.copy( ).getfield( np.int32 )
        if isinstance( bi, np.number ): bi = np.array( bi )
        if bits_fp < 32:
            # round away the mantissa bits the target cannot hold; NaN/Inf and (for
            # small exponents) values at or below the denormal limit are skipped here
            expo = ( bi >> 23 ) & 255
            mask = ( expo < 255 ) & (( expo_bits >= 8 ) | ( expo > 127 - expo_bias ))
            bi[ mask ] = self.shift_round( bi[ mask ], 32 - bits_fp, rounding=rounding, bfloat=True ) << ( 32 - bits_fp )
        if expo_bits < 8:
            if isinstance( bi, np.number ): bi = np.array( bi )
            sign = np.sign( bi )
            expo = ( bi >> 23 ) & 255
            mant = bi & 0x7FFFFF
            expo_overflow = 126 + 2**expo_bits + ( extend > 0 ) - expo_bias
            expo_valid_max = expo_overflow - 1
            mant_valid_mask = 0x7FFFFF & ( -1 << ( 32 - bits_fp ))
            mant_valid_max  = 0x7FFFFF & ( -1 << ( 32 - bits_fp + ( extend > 1 )))
            huge = ( expo < 255 ) & (( expo >= expo_overflow ) | (( extend > 1 ) & ( expo == expo_valid_max ) & ( mant >= mant_valid_mask )))
            max_repr_value = ( expo_valid_max << 23 ) | mant_valid_max
            infinity = ( 255 << 23 )
            if rounding == 'floor' or rounding == 'ceil':
                # `&` binds tighter than `<` / `>`, so the sign tests need their
                # own parentheses to select overflowing values per sign
                huge_neg = huge & ( sign < 0 )
                huge_pos = huge & ( sign > 0 )
                if rounding == 'floor':
                    bi[huge_neg] = ( -1 << 31 ) | infinity
                    bi[huge_pos] = max_repr_value
                else:
                    bi[huge_neg] = ( -1 << 31 ) | max_repr_value
                    bi[huge_pos] = infinity
            elif rounding == 'sym_floor' or not rounding:
                bi[huge] = ( np.minimum( 0, sign[huge] ) << 31 ) | max_repr_value
            else:
                bi[huge] = ( np.minimum( 0, sign[huge] ) << 31 ) | infinity
            # float32 denormals become signed zero
            bi[ expo == 0 ] &= -1 << 31
            # values below the normal range of the target format
            mask = ( expo <= ( 127 - expo_bias )) & ( expo > 0 )
            if filter_denorm or ( filter_denorm is None and self.filter_denorm ):
                bi[ mask ] &= ( -1 << 31 )
            else:
                shift_dn = 128 - expo_bias - expo[ mask ] + 32 - bits_fp
                mant_dn = self.shift_round(( 1 << 23 ) | ( bi[mask] & (( 1 << 23 ) - 1 )), shift_dn, rounding=rounding, bfloat=True ) << shift_dn
                expo_dn = expo[ mask ]
                expo_dn[ mant_dn == 0 ] = 0
                expo_dn[ mant_dn >= ( 2 << 23 )] += 1
                bi[ mask ] = ( expo_dn << 23 ) + ( mant_dn & (( 1 << 23 ) - 1 )) - ( bi[mask].astype(np.uint32) & ( 1<<31 ) )
            if extend and self.fp8_extend_saturation:
                # replace +-infinity by a finite pattern, keeping the sign bit.
                # The sign is extracted with the int32 compatible mask -2**31
                # because the literal 2**31 does not fit into int32.
                mask = ( bi & (( 1 << 31 ) - 1 )) == ( 255 << 23 )
                bi[ mask ] = (( 126 + 2**expo_bits + ( extend > 0 ) - expo_bias ) << 23 ) - ( 1 << 32 - bits_fp ) + ( bi[ mask ] & ( -1 << 31 ))
        elif filter_denorm or ( filter_denorm is None and self.filter_denorm ):
            # flush to signed zero where the result or the input has a zero exponent field
            bi[( bi & ( 0xFF << 23 ) == 0 ) | ( f.getfield( np.int32 ) & ( 0xFF << 23 ) == 0 )] &= ( -1 << 31 )
        b = bi.getfield( np.float32 )
        return b


    def f2bfp( self, f, bits=8, ebs=1, dim=-1, bfloat=None, expo_bits=8, rounding=None, as_float=False, limit_range=False, m2_0=False, m2_0_rnd=None, check_double_rnd=False, shift_bits=0, subtile_size=2, sgn_mag=False, sparse=False ):
        """Convert float data to block floating point (BFP).

        Axis `dim` of `f` is split into blocks of `ebs` elements.  Every block
        gets one shared exponent `me` (the maximum of the biased float32 exponents
        of its elements, optionally adjusted via `m2_0`), and every group of
        `subtile_size` elements inside a block gets a shift value limited to
        2**shift_bits - 1.

        bits             -- mantissa width; may be an ndarray matching the shape of
                            the shared exponents
        bfloat/expo_bits -- in bfloat mode (None selects self.bfloat) mantissas are
                            produced by f2bf on the rescaled data, otherwise by
                            shift_round on the integer mantissas
        rounding         -- None selects self.rounding, True selects 'sym_inf'
        as_float         -- return bfp2f( ... ) of the result instead of the tuple
        check_double_rnd -- store the mantissas in self.f2bfp_check_rnd_value; a
                            later call compares its mantissas against that value
        sgn_mag          -- selects how mantissas with the top bit set are remapped
        m2_0_rnd, sparse -- m2_0_rnd is only defaulted and `sparse` is not read at
                            all in this method
        Returns ( mantissas shaped like f, shared exponents wrapped to
        [-128, 127], shift values ) or the float reconstruction.

        NOTE(review): the `limit_range` branch reads `shift` before this method
        assigns it, so it looks like that branch raises UnboundLocalError -
        confirm the intended value.
        NOTE(review): self.log_msg is used on the size mismatch path but is not
        set by this class' __init__ - presumably provided by a subclass.
        """
        # print( bits, ebs, dim, rounding, as_float, limit_range, m2_0, m2_0_rnd, check_double_rnd, shift_bits, subtile_size )
        f = f.astype( np.float32 )
        max_shift = ( 2**shift_bits - 1 )
        if dim < 0: dim += len( f.shape )
        if m2_0_rnd is None: m2_0_rnd = self.rounding != 'sym_floor'
        # split axis `dim` into ( number of blocks, ebs )
        shape = list( f.shape )
        shape[dim] = ebs
        shape.insert( dim, f.shape[dim]//ebs )
        # sign, biased exponent and mantissa (with the implicit one for normal
        # values) taken from the float32 bit pattern
        fi = f.reshape( shape ).getfield( np.int32 )
        fis = np.sign( fi )
        fie = ( fi >> 23 ) & 255
        fim = (( fie>0 )<<23 ) | ( fi & (( 1<<23 )-1 ))
        fie = np.maximum( 1, fie )
        fieb = fie 
        if bfloat or ( bfloat is None and self.bfloat ):
            eb, expo_bias, extend = decode_expo_bits( expo_bits )
            if extend == 2:
                fieb = ((( fi & (( 1 << 31 ) - 1 )) + ( 1 << ( 23 - ( bits - eb - 1 )))) >> 23 ) & 255
        # shared exponent: maximum over the elements of each block
        me  = np.max( fieb - m2_0 * ( fis*fim == ( -1<<23 )), dim+1 )

        if isinstance( bits, np.ndarray ):
            assert np.all( bits.shape == me.shape ), "[f2bfp]: If bits is ndarray, then shape needs to match shared exponents"
            mbits = np.expand_dims( bits, dim+1 )
        else:
            mbits = bits
        bm_max = ( 1<<( mbits-1 ))-1
        if limit_range:
            shift = 25-mbits + np.mod( shift, mbits-4 ) #TODO is mbits-4 correct?
            me = 127 + np.mod( me-127, mbits-4 )
            fim &= bm_max << shift

        if rounding is None: rounding = self.rounding
        if rounding == True: rounding = 'sym_inf'
        
        # per subtile shift: smallest exponent distance to the shared exponent
        # inside the subtile, limited to max_shift
        shift_shape = shape[:]
        shift_shape[dim+1] = subtile_size
        shift_shape.insert( dim+1, ebs//subtile_size )
        exp_diff = ( np.expand_dims( me, dim+1 ) - fie ).reshape( shift_shape ).getfield( np.int32 )
        min_shift = np.min( exp_diff, axis=dim+2 )
        shift_values =  np.minimum( min_shift, max_shift )
        if bfloat or ( bfloat is None and self.bfloat ):
            #TODO -2 below is not accurate as it depends on the
            me = np.maximum( 0, me - ( 2**eb - 1 - expo_bias - ( not extend )))
            shift = np.expand_dims( me, dim+1 ) - 127 - np.repeat( shift_values, subtile_size, dim+1 )
            # print( "shift:", shift[0, 0, :2] )
            # print( me[0, 0] )
            scale = ( 2.0**np.minimum( 127, -shift )).astype( np.float32 )
            #TODO what's the limit for the common exponent
            bm = self.f2bf( f.reshape( shape ) * scale, bits, expo_bits, rounding=rounding, filter_denorm = False) / scale
            # print( bm[0, 0, 0:2] ) 
        else:
            #assert 0, "We should not be here"
            # print( "my data" )
            # print( me[0,0] )
            shift = np.expand_dims( me, dim+1 ) - fie + 25 - mbits - np.repeat( shift_values, subtile_size, dim+1 )
            # print( shift[0,0,:] )
            bm = self.shift_round( fim, shift,  rounding=rounding )
            bm[bm>bm_max] = bm_max # saturate to bmax
            bm = bm & bm_max
            bm[shift>24] = 0 # prevents sign to become part of mantissa
            bm[fis<0] = (1 << ( mbits-1 ) | bm[fis<0] )
            if check_double_rnd:
                self.f2bfp_check_rnd_value = bm.copy( )
            elif hasattr( self, 'f2bfp_check_rnd_value' ) and self.f2bfp_check_rnd_value is not None:
                if self.f2bfp_check_rnd_value.size == bm.size:
                    mask = self.f2bfp_check_rnd_value != bm
                    assert not np.any( mask ), "f2bfp double rnd diff errors: {}\n1. call:\n{}\n2. call\n{}".format( np.sum( mask ), self.f2bfp_check_rnd_value[mask], bm[mask] )
                    self.f2bfp_check_rnd_value = None
                else:
                    self.log_msg.append( "[Warning]: f2bfp double round check skipped due to size mismatch: {} != {}".format( self.f2bfp_check_rnd_value.size, bm.size ))
                self.f2bfp_check_rnd_value = None
            # remap mantissas that carry the sign bit set above
            if sgn_mag:
                bm[bm>=2**( bits-1 )] -= 2**bits
            else:
                bm[bm>=2**( bits-1 )] = 2**( bits-1 ) - bm[bm>=2**( bits-1 )]
        # print( "bm fin", bm[0,0,:] )
        bm = bm.reshape( f.shape )
        # shared exponents >= 128 are wrapped to negative values
        bfp = ( bm , me - 256 * ( me >= 128 ), shift_values )
        ret = self.bfp2f( bfp, bits, ebs, dim, bfloat, sgn_mag, subtile_size ) if as_float else bfp
        # if as_float:
        #     print( "output", ret[0,:ebs] )
        return ret

    def bfp2f( self, bfp, bits=8, ebs=1, dim=-1, bfloat=None, sgn_mag=False, subtile_size=2 ):
        bm, me, shift = bfp
        # print( bfloat, self.bfloat, bm[0, :ebs] )
        if bfloat or ( bfloat is None and self.bfloat ):
            return bm
        if dim < 0: dim += len( me.shape )
        shape = list( me.shape )
        shape[dim] *= ebs
        if isinstance( bits, np.ndarray ):
            bits = np.repeat( bits, ebs, axis=dim )
        fis = bm >> ( bits-1 )
        if sgn_mag:
            mag_bm = bm & ( 2**( bits-1 ) -1 )
        else:
            mag_bm = ( bm ^ fis ) - fis
        lz  = bits-2 - np.floor( np.log2( np.maximum( 1, mag_bm ))).astype( int )
        fie = np.maximum( 0, np.repeat( me & 255, ebs, axis=dim ) - np.repeat( shift, subtile_size, dim+1 ).reshape( shape ) - lz )
        fim = self.shift_round( mag_bm, bits - 25 - lz )
        fi  = ( fis<<31 ) | ((( fie<<23 ) | ( fim & (( 1<<23 )-1 ))) * ( mag_bm!=0 ))
        f = fi.getfield( np.float32 )
        # print( np.max( f ))
        return np.reshape( f, shape ).astype( np.float32 )


    def collapse_bfp( self, bfp, bits=8, ebs=1, dim=-1, bfloat=None, expo_bits=8, order=None, shift_bits=0, subtile_size=2, sparse = False, sparse_factor = 1.0, duplicate_ebs = False ):
        """Merge a block floating point tuple into a single array.

        Per block the result concatenates, along axis dim+1, the shared exponent
        (preceded by the sparsity info `spm` when `sparse` is set), the packed
        shift values (only if shift_bits > 0) and the mantissas repacked into
        8 bit fields; axis dim+1 is then folded into axis `dim`.

        bfp           -- ( bm, me, shift ), or ( spm, bm, me, shift ) if `sparse`
        bits          -- mantissa width; widths other than 8 are repacked with
                         reinterpret_cast
        bfloat        -- in bfloat mode (None selects self.bfloat) the mantissas
                         are rescaled and converted with collapse_float first
        order         -- optional string; when it ends in digits, that number is
                         rescaled by the bytes-per-sample factor and a tuple
                         ( array, new order string ) is returned
        duplicate_ebs -- repeats exponents and shifts and halves ebs/subtile_size
        Returns the collapsed array, or ( array, order ) when `order` is given
        and non-empty.

        NOTE(review): the `sparse` and `duplicate_ebs` branches use `dim` before
        negative values are normalised further down - confirm that this is the
        intended axis for positive `dim` as well.
        NOTE(review): with duplicate_ebs, `subtile_size` and `ebs` become floats
        ( `/2` ); np.repeat in the bfloat branch presumably rejects a float
        repeat count - confirm.
        """
        
        if sparse == True:
            sparse_factor = sparse_factor       # bm is already compressed
            enc_factor = 1
            spm, bm, me, shift = bfp  # we get the sparse info spm
            shape = list( me.shape )
            me = np.concatenate( ( spm, np.expand_dims(me, dim) ), dim )
        else:
            sparse_factor = 1
            enc_factor = 0
            bm, me, shift = bfp
            shape = list( me.shape )
            
        if duplicate_ebs:
            # only for weights
            # 
            # remap to duplicated ebs = ebs/2
            subtile_size = subtile_size/2
            ebs = ebs/2
            me = np.repeat(me, 2, axis=dim)
            shape = list( me.shape )
            c = np.repeat(shift, 2, axis=dim+1)
            shift = c.reshape((shape[dim-1],shape[dim],-1))
            
        if dim < 0: dim += len( bm.shape )
        # print( bm[0, 0], me[0], shift, bits, ebs, dim )
        if bfloat or ( bfloat is None and self.bfloat ):
            # undo the wrap of the exponent to signed values, then rescale the
            # mantissas before converting them to their integer bit patterns
            shift = np.expand_dims( me + 256 * ( me < 0 ), dim+1 ) - 127 - np.repeat( shift, subtile_size, dim+1 )
            scale = ( 2.0**np.minimum( 127, -shift.reshape( bm.shape ))).astype( np.float32 )
            # print( "collapse:", bm[0, 0:2], shift[0, 0, 0] )
            bm = self.collapse_float( bm * scale, bits, expo_bits )
            # print( "collapsed:", bm[0, 0:2] )
        
        # number of mantissa bytes per block
        shape.insert( dim+1, int( np.ceil( ebs * (bits) / 8.0 )))
        
        if bits != 8:
            # reinterpret_cast packs along the last axis: move `dim` there and back
            perm = list( range( len( bm.shape )))
            perm.append( perm.pop( dim ))
            bm = bm.transpose( perm )
            if np.lcm( 8, bits ) in ( 8, bits ):
                bm = self.reinterpret_cast( bm, bits, 8, True, bfloat_in=False )
            else:
                # no integer ratio between the widths: go through lcm( 8, bits ) wide words
                bm = self.reinterpret_cast( bm, bits, np.lcm( 8, bits ),    True, bfloat_in=False )
                bm = self.reinterpret_cast( bm,       np.lcm( 8, bits ), 8, True, bfloat_in=False )
            perm = list( range( len( bm.shape )))
            perm.insert( dim, perm.pop( -1 ))
            bm = bm.transpose( perm )
            
        if shift_bits > 0:
            packed_shifts = self.pack_shifts( shift, dim, shift_bits, subtile_size )
            if sparse == True:
                clps = np.concatenate(( me, packed_shifts, bm.reshape( shape )), dim+1 )
            else:
                clps = np.concatenate(( np.expand_dims( me, dim+1 ), packed_shifts, bm.reshape( shape )), dim+1 )
        else:
            if sparse == True:
                clps = np.concatenate(( me, bm.reshape( shape )), dim+1 )
            else:
                clps = np.concatenate(( np.expand_dims( me, dim+1 ), bm.reshape( shape )), dim+1 )
            
        # fold the per block bytes ( mantissas + 1 exponent + shift bytes + encoding ) into axis `dim`
        shape[dim] *= shape.pop( dim+1 ) + 1 + int( np.ceil( ebs / subtile_size * shift_bits / 8 )) + enc_factor*3 # 3B for encoding 24b
        
        if order is not None:
            # collect the trailing digits of `order` and rescale that number
            val=''
            for p in reversed( order ):
                if p.isdigit( ):
                    val = p + val
                elif val:
                    return ( clps.reshape( shape ), order[:-len( val )] + str( int( np.ceil( int( val ) * ( enc_factor*((3*(ebs/4))/8)/(ebs) + sparse_factor * ( (bits / 8.0) + (1.0 / ebs) + shift_bits/( subtile_size*8.0 ))))))) # number of elements * bytes per sample
                else:
                    return ( clps.reshape( shape ), order )
                
        return clps.reshape( shape )
    

    def collapse_float( self, data, bits=16, expo_bits=8 ):
        """Convert float32 values into integer bit patterns of a narrower float format.

        data      -- scalar, sequence or ndarray; float64 is converted to float32,
                     other dtypes are rejected by an assertion
        bits      -- total number of bits of the target format
        expo_bits -- exponent description understood by decode_expo_bits
        Returns the int32 view of a copy of `data`, re-encoded and shifted right
        by 32 - ( bits + 8 - expo_bits ) bits.  The input is expected to be
        representable in the target format already (presumably the output of
        f2bf - TODO confirm); this method itself performs no rounding.
        """
        if not isinstance( data, ( np.ndarray, np.number )):
            data = np.array( data, dtype=np.float32 )
        elif data.dtype == np.float64:
            data = data.astype( np.float32 )
        assert data.dtype == np.float32, "type {} not supported for collapse_float".format( data.dtype )
        expo_bits, expo_bias, extend = decode_expo_bits( expo_bits )
        # unused low bits once the format is embedded into a word with 8 exponent bits
        shift = 32 - ( bits + 8 - expo_bits )
        bi = data.copy( ).getfield( np.int32 )
        if expo_bits < 8:
            expo = ( bi >> 23 ) & 255
            if extend:
                # remember the all-ones exponent patterns; they are rewritten at the end
                nan = expo == 255
            # values below the normal range of the target: exponent field becomes
            # zero and the mantissa (with its implicit one) is shifted down
            denorm = ( expo <= ( 127 - expo_bias )) & ( expo > 0 )
            if np.any( denorm ):
                bi[ denorm ] = ((( 1 << 23 ) | ( bi[ denorm ] & (( 1 << 23 ) - 1 ))) >> ( 128 - expo_bias - (( bi[ denorm ] >> 23 ) & 255 ))) - ( bi[ denorm ] & ( -1*2**31 ))
            # difference between the requested bias and the default bias of the width
            expo_bias -= 2**( expo_bits - 1 ) - 1
            if expo_bias != 0:
                expo = ( bi >> 23 ) & 255
                bi[( expo > 0 ) & ( expo < 255 )] += expo_bias << 23
            # drop the exponent bits the narrow format lacks: the two top bits move
            # down by 8 - expo_bits, the low 22 + expo_bits bits stay in place
            bi = (( bi & ~(( 1 << 30 ) - 1 )) >> ( 8 - expo_bits )) | ( bi & (( 1 << ( 22 + expo_bits )) - 1 ))
            if extend == 2:
                bi[nan] = ( bi[nan] > 0 ) * 2**( bits - 1 + shift ) - 1
            elif extend:
                bi[(( bi >> shift ) & (( 1 << ( bits )) - 1 )) == 0 ] = 0
                bi[nan] = -1 << ( 23 + expo_bits )
        if shift > 0:
            bi >>= shift
        return bi

    # bits_out must be a valid data type 
    def reinterpret_cast_with_lcm(self, data_in, bits_in, bits_out, sgn_out=False, bfloat_in=None, expo_bits_in=8):
        if np.lcm( bits_out, bits_in ) in ( bits_out, bits_in ):
            data_out = self.reinterpret_cast( data_in, bits_in, bits_out, sgn_out , bfloat_in )
        else:
            data_out = self.reinterpret_cast( data_in, bits_in, np.lcm( bits_out, bits_in ),  True, bfloat_in=False )
            data_out = self.reinterpret_cast( data_out.copy(), np.lcm( bits_out, bits_in ), bits_out, True, bfloat_in=False )
        return data_out


# Default exponent description per type name, in the encoding that
# decode_expo_bits understands (integer part: exponent bits, fractional
# part: extension flag and exponent bias).
type_expo_bits_default = {
    "float":    8,
    "float32":  8,
    "float16":  5,
    "float8":   4.607,
    "float6":   2.301,
    "float4":   2,
    "bfloat16": 8,
}

            
def decode_expo_bits( eb ):
    """Split an exponent description into ( exponent bits, exponent bias, extend ).

    eb -- a type name looked up in type_expo_bits_default (any str subclass is
          accepted), or a number: an integral value means "that many exponent
          bits with the default bias 2**(bits-1)-1 and no extension"; a fractional
          value carries 300*extend + bias in its first three decimals, as
          produced by encode_expo_bits
    Integral inputs given as float (e.g. 8.0) are returned as Python ints so
    that callers can use the result in integer shift arithmetic.
    Raises AssertionError for an encoded description with 8 or more exponent bits.
    """
    if isinstance( eb, str ):
        eb = type_expo_bits_default[eb]
    if np.mod( eb, 1 ) == 0:
        bits = int( eb )
        return ( bits, 2**( bits - 1 ) - 1, False )
    bits = int( eb )
    # the first three decimals hold 300 * extend + bias
    eb   = int( round( 1e3 * np.mod( eb, 1 )))
    ext  = int( round( eb / 300.0 ))
    bias = eb - 300 * ext
    assert bits < 8, "Extended exponent range and exponent bias only supported with exponents of less then 8 bits"
    return ( bits, bias, ext )



def encode_expo_bits( bits=0, bias=None, extend=False ):
    """Build an exponent description from its parts.

    Without bias and extension, `bits` is returned unchanged.  Otherwise the
    result is bits + 0.3 * extend + 1e-3 * bias, where a missing bias defaults
    to 2**(bits-1)-1.
    """
    if bias is None:
        if not extend:
            return bits
        bias = 2**( bits - 1 ) - 1
    return bits + 0.3 * extend + 1e-3 * bias


