import numbers
import sys
import typing
import numpy as np
import ctypes
# from kernel_lib.python.named_list import *
# from kernel_lib.python.data_generator.generate_data_helpers import GenerateDataHelpers

# NumPy-backed stand-ins; `all`, `any` and `range` shadow the builtins of the
# same name for every definition below in this module.
log2 = np.log2


def all( *x ):
    """Return np.all over the positional arguments taken as one tuple."""
    return np.all( x )


def any( *x ):
    """Return np.any over the positional arguments taken as one tuple."""
    return np.any( x )


range = np.arange

# max/min under alternative names.
pmax = max
pmin = min


class DataConverter:
    def __init__( self, rounding=True, bfloat=False, filter_denorm=True, fp8_extend_saturation=False, print_info=False ):
        self.rounding = rounding
        self.bfloat   = bfloat
        self.filter_denorm = filter_denorm
        self.fp8_extend_saturation = fp8_extend_saturation
        if not hasattr( self, "print_info"):
            self.print_info = print_info

    def extend_to_multiple_of(self,arr,N):
        # Calculate the number of elements to add
        padding = (N - len(arr) % N)
        # Extend the array with zeros
        extended_arr = np.pad(arr, (0, padding), mode='constant', constant_values=0)
        return extended_arr

    def reinterpret_cast( self, data, bits_in, bits_out, sgn_out=False, bfloat_in=None, expo_bits_in=8 ):
        if bfloat_in or ( bfloat_in is None and self.bfloat ):
            data = self.collapse_float( data, bits_in, expo_bits_in )
        if bits_in < bits_out:
            data[data<0] += 2**bits_in
            scale_factor = bits_out//bits_in
            if(data.size % scale_factor):
                data = self.extend_to_multiple_of(data,scale_factor)
            data = np.sum( [data[..., p::scale_factor] * 2**( bits_in*p ) for p in range( scale_factor )], 0 )
        elif bits_in > bits_out:
            tmp  = np.zeros( data.shape + ( bits_in//bits_out, ), dtype=np.int_ )
            for p in range( bits_in//bits_out ):
                tmp[..., p] = ( data >> ( bits_out*p )) & ( 2**bits_out-1 )
            data = tmp.reshape( data.shape[:-1] + ( -1, ))
        else:
            data &= 2**bits_out-1
        if sgn_out: data[data>=( 2**( bits_out-1 ))] -= 2**bits_out
        return data

    def pack_shifts( self, shifts, dim, shift_bits=1, subtile_size=2 ):
        shifts_per_byte = 8 // shift_bits
        shape = list( shifts.shape )
        shape[dim+1] = shifts_per_byte
        shape.insert( dim+1, shifts.shape[dim+1] // shifts_per_byte )
        shifts = shifts.reshape( shape )
        def pack( shifts ):
            tmp=0
            for shift in shifts :
                tmp = tmp >> 1 | shift << 7  # shifts in memory are stored in reversed order
            return tmp - 256 * ( tmp >= 128 )
        return  np.apply_along_axis( pack, dim+2, shifts )


    def check_rounding_mode( self, rounding=None, bfloat=None ):
        if rounding is None: rounding = self.rounding
        if rounding == True:
            if bfloat or ( bfloat is None and self.bfloat ):
                rounding = 'half_inf'
            else:
                rounding = 'sym_inf'
        if rounding == False and ( bfloat or ( bfloat is None and self.bfloat )):
            rounding = 'sym_floor'
        return rounding


    def shift_round( self, data, shift, rounding=None, bfloat=None ):
        rounding = self.check_rounding_mode( rounding, bfloat )
        dtype = np.dtype( data.dtype, copy=True )
        data  = data.astype( np.long )
        bits  = dtype.itemsize * 8
        if isinstance( shift, np.ndarray ):
            shift = shift.astype( np.long )
            mask = ( shift >= 63 ) | ( shift <= -bits )
            if np.any( mask ):
                data[ mask] = 0
                shift[mask] = 0
        elif shift >= 63 or shift <= -bits:
            return np.zeros( data.shape, dtype=dtype )
        rshift = np.maximum( 0,  shift )
        lshift = np.maximum( 0, -shift )
        if rounding == 'even':
            data += (( 1 << ( rshift - 1 )) - 1 + (( data >> rshift ) & 1 )) * ( rshift > 0 )
        elif rounding == 'sym_inf':
            data += (( 1 << ( rshift - 1 )) - ( np.sign( data ) < 0 )) * ( rshift > 0 )
        elif rounding == 'sym_zero':
            data += (( 1 << ( rshift - 1 )) - ( np.sign( data ) > 0 )) * ( rshift > 0 )
        elif rounding == 'to_odd':
            data |= (( data & (( 1 << rshift ) - 1 )) != 0 ) << rshift
        elif rounding == 'floor':
            data += (( 1 << rshift ) - 1 )  * ( np.sign( data ) < 0 )  * ( rshift > 0 )
        elif rounding == 'ceil':
            data +=  (( 1 <<  rshift ) - 1 ) * ( np.sign( data ) > 0 )  * ( rshift > 0 )
        elif rounding == 'sym_floor':
            data = data
        elif rounding:
            data +=  ( 1 << ( rshift - 1 )) * ( rshift > 0 )
        if np.any( rshift ): data >>= rshift
        if np.any( lshift ): data <<= lshift
        return data.astype( dtype )


    def srs( self, mat, shift, bits, sgn=True, bfloat=None, ebs=None, expo_bits=8, as_float=False, float_to_int=True, rounding = None, filter_denorm=None ):
        if ebs is not None:
            return self.f2bfp( mat, bits, ebs[0], ebs[1], bfloat=bfloat, expo_bits=expo_bits, as_float=as_float )
        elif bfloat or ( bfloat is None and self.bfloat ):
            return self.f2bf( mat, bits, expo_bits, rounding = rounding, filter_denorm=filter_denorm )
        if float_to_int and not issubclass( mat.dtype.type, numbers.Integral ):
            mat = ( mat * 2**shift ).astype( np.float32 ).getfield( np.int32 )
            exp = ( mat >> 23 ) & 255 
            normal = np.logical_not(exp == 0)
            shift = np.maximum( 0, 23 - (exp - 127 ))
            mant = (( 1 << 23 )*normal + ( mat & 0x7FFFFF )) * np.sign( mat )
            rnd  = self.shift_round( mant, shift ) << shift
            mat  = ( mat + ( rnd - mant ) * np.sign( mat )).getfield( np.float32 )
            #mat = np.round( mat * 2**shift ).astype( np.long )
            shift = 0

        low, high = -2**( bits-1 )*int( sgn ), 2**( bits-int( sgn ))-1
        mat = self.shift_round( mat.astype( np.long ), shift )
        if self.print_info:
          large_value_threshold = int( self.large_value_threshold * high ) if type( self.large_value_threshold ) is float else self.large_value_threshold
          small = np.sum( mat<self.small_value_threshold )
          large = np.sum( mat>large_value_threshold )
          if small > mat.size/10:
            self.log_msg.append( "[INFO]: Reference has {} % values smaller than {}".format( 100.0*small/mat.size, self.small_value_threshold ))
          if large > mat.size/10:
            self.log_msg.append( "[INFO]: Reference has {} % values larger than {}".format( 100.0*large/mat.size, large_value_threshold ))
        if not isinstance( mat, np.ndarray ): mat = np.array( mat )
        mat[mat<low]  = low
        mat[mat>high] = high
        return mat


    def ups( self, mat, shift, bits=32, sgn=True, bfloat=None, ebs=None ):
        if ebs is not None and len( mat )==2:
            return self.bfp2f( mat, bits, ebs[0], ebs[1], bfloat=bfloat )
        elif bfloat or ( bfloat is None and self.bfloat ):
            return mat
        mat = mat.astype( np.long ) << shift
        if bits < 64:
            low, high = -2**( bits-1 )*int( sgn ), 2**( bits-int( sgn ))-1
            mat[mat<low]  = low
            mat[mat>high] = high
        return mat

    def truncate( self, mat, bits, sgn=True, bfloat=None, ebs=1 ):
        if bfloat or ( bfloat is None and self.bfloat ): return mat.astype( np.float32 )
        elif (ebs > 1): return mat.astype( np.float32 )
        elif ( mat.dtype == np.long and bits==64 ): return mat
        mat = mat & ( 2**bits-1 )
        if sgn: mat[mat >= 2**( bits-1 )] -= 2**bits
        return mat



    def f2bf( self, f, bits=16, expo_bits=8, rounding=None, filter_denorm=None ):
        """Round float data to a reduced float format, returned as float32.

        f             -- array-like, np.ndarray or NumPy scalar; float64 is cast
                         to float32, any other dtype fails the assertion.
        bits          -- total width of the target format.
        expo_bits     -- exponent description understood by decode_expo_bits
                         (exponent width, bias, range extension).
        rounding      -- rounding mode, resolved by check_rounding_mode in
                         bfloat mode.
        filter_denorm -- flush values below the normal range to signed zero;
                         None defers to self.filter_denorm.

        Steps: the mantissa is rounded to the target width with shift_round;
        for exponents narrower than 8 bits, values beyond the representable
        range are replaced by infinity or the largest representable value
        depending on the rounding mode, and small values are flushed or
        rounded as denormals. With a range extension and
        self.fp8_extend_saturation, infinities are replaced by a finite value.

        Fixes compared to the previous version:
          * 'floor' / 'ceil' overflow handling used `huge & sign < 0`, which
            Python parses as `( huge & sign ) < 0`; negative values were never
            selected and all huge values were treated as positive.
          * the saturation step masked an int32 array with 1<<31, which raises
            OverflowError under NumPy 2; the sign bit is now extracted in int64.
        """
        if not isinstance( f, ( np.ndarray, np.number )):
            f = np.array( f, dtype=np.float32 )
        elif f.dtype == np.float64:
            f = f.astype( np.float32 )
        assert f.dtype == np.float32, "type {} not supported for f2bf".format( f.dtype )
        rounding = self.check_rounding_mode( rounding, bfloat=True )
        expo_bits, expo_bias, extend = decode_expo_bits( expo_bits )
        # Width the value would have with a full 8-bit exponent.
        bits_fp = bits + 8 - expo_bits
        bi = f.copy( ).getfield( np.int32 )
        if isinstance( bi, np.number ): bi = np.array( bi )
        if bits_fp < 32:
            expo = ( bi >> 23 ) & 255
            # Round everything except inf/nan and (for narrow exponents) values
            # that fall into the target's denormal range.
            mask = ( expo < 255 ) & (( expo_bits >= 8 ) | ( expo > 127 - expo_bias ))
            bi[ mask ] = self.shift_round( bi[ mask ], 32 - bits_fp, rounding=rounding, bfloat=True ) << ( 32 - bits_fp )
        if expo_bits < 8:
            if isinstance( bi, np.number ): bi = np.array( bi )
            sign = np.sign( bi )
            expo = ( bi >> 23 ) & 255
            mant = bi & 0x7FFFFF
            expo_overflow = 126 + 2**expo_bits + ( extend > 0 ) - expo_bias
            expo_valid_max = expo_overflow - 1
            mant_valid_mask = 0x7FFFFF & ( -1 << ( 32 - bits_fp ))
            mant_valid_max  = 0x7FFFFF & ( -1 << ( 32 - bits_fp + ( extend > 1 )))
            huge = ( expo < 255 ) & (( expo >= expo_overflow ) | (( extend > 1 ) & ( expo == expo_valid_max ) & ( mant >= mant_valid_mask )))
            max_repr_value = ( expo_valid_max << 23 ) | mant_valid_max
            infinity = ( 255 << 23 )
            negative = sign < 0
            positive = sign > 0
            if rounding == 'floor':
                bi[huge & negative] = ( -1 << 31 ) | infinity
                bi[huge & positive] = max_repr_value
            elif rounding == 'ceil':
                bi[huge & negative] = ( -1 << 31 ) | max_repr_value
                bi[huge & positive] = infinity
            elif rounding == 'sym_floor' or not rounding:
                bi[huge] = ( np.minimum( 0, sign[huge] ) << 31 ) | max_repr_value
            else:
                bi[huge] = ( np.minimum( 0, sign[huge] ) << 31 ) | infinity
            # float32 denormals become signed zero.
            bi[ expo == 0 ] &= -1 << 31
            mask = ( expo <= ( 127 - expo_bias )) & ( expo > 0 )
            if filter_denorm or ( filter_denorm is None and self.filter_denorm ):
                bi[ mask ] &= ( -1 << 31 )
            else:
                shift_dn = 128 - expo_bias - expo[ mask ] + 32 - bits_fp
                mant_dn = self.shift_round(( 1 << 23 ) | ( bi[mask] & (( 1 << 23 ) - 1 )), shift_dn, rounding=rounding, bfloat=True ) << shift_dn
                expo_dn = expo[ mask ]
                expo_dn[ mant_dn == 0 ] = 0
                # Rounding carried into the next binade.
                expo_dn[ mant_dn >= ( 2 << 23 )] += 1
                bi[ mask ] = ( expo_dn << 23 ) + ( mant_dn & (( 1 << 23 ) - 1 )) - ( bi[mask].astype(np.uint32) & ( 1<<31 ) )
            if extend and self.fp8_extend_saturation:
                mask = ( bi & (( 1 << 31 ) - 1 )) == ( 255 << 23 )
                saturated = (( 126 + 2**expo_bits + ( extend > 0 ) - expo_bias ) << 23 ) - ( 1 << ( 32 - bits_fp ))
                # int64 keeps 1<<31 representable; the result fits int32 again.
                bi[ mask ] = saturated - ( bi[ mask ].astype( np.int64 ) & ( 1<<31 ))
        elif filter_denorm or ( filter_denorm is None and self.filter_denorm ):
            # Zero exponent before or after rounding -> signed zero.
            bi[( bi & ( 0xFF << 23 ) == 0 ) | ( f.getfield( np.int32 ) & ( 0xFF << 23 ) == 0 )] &= ( -1 << 31 )
        b = bi.getfield( np.float32 )
        return b


    def f2bfp( self, f, bits=8, ebs=1, dim=-1, bfloat=None, expo_bits=8, rounding=None, as_float=False, limit_range=False, m2_0=False, m2_0_rnd=None, check_double_rnd=False, shift_bits=0, subtile_size=2, sgn_mag=False, sparse=False ):
        """Convert float data to block floating point along axis `dim`.

        Axis `dim` of f (cast to float32) is split into blocks of `ebs`
        elements. Each block gets one shared exponent `me` (the maximum element
        exponent) and each group of `subtile_size` elements inside a block gets
        a shift value limited to 2**shift_bits - 1.

        Returns the tuple ( bm, me - 256 * ( me >= 128 ), shift_values ) where
        bm has the shape of f, or the result of bfp2f on that tuple when
        as_float is set.

        In bfloat mode (argument, or self.bfloat when None) the elements are
        scaled and rounded with f2bf; otherwise the integer mantissas are
        produced with shift_round.

        NOTE(review): with limit_range=True the first statement of that branch
        reads the local `shift` before it has been assigned in this method.
        NOTE(review): m2_0_rnd is computed and `sparse` is accepted, but neither
        is used further in this method.
        """
        # print( bits, ebs, dim, rounding, as_float, limit_range, m2_0, m2_0_rnd, check_double_rnd, shift_bits, subtile_size )
        f = f.astype( np.float32 )
        max_shift = ( 2**shift_bits - 1 )
        if dim < 0: dim += len( f.shape )
        if m2_0_rnd is None: m2_0_rnd = self.rounding != 'sym_floor'
        # Split axis dim into ( number of blocks, ebs ).
        shape = list( f.shape )
        shape[dim] = ebs
        shape.insert( dim, f.shape[dim]//ebs )
        fi = f.reshape( shape ).getfield( np.int32 )
        # Sign, exponent field and mantissa (with the implicit one for non-zero exponents).
        fis = np.sign( fi )
        fie = ( fi >> 23 ) & 255
        fim = (( fie>0 )<<23 ) | ( fi & (( 1<<23 )-1 ))
        fie = np.maximum( 1, fie )
        fieb = fie 
        if bfloat or ( bfloat is None and self.bfloat ):
            eb, expo_bias, extend = decode_expo_bits( expo_bits )
            if extend == 2:
                fieb = ((( fi & (( 1 << 31 ) - 1 )) + ( 1 << ( 23 - ( bits - eb - 1 )))) >> 23 ) & 255
        # Shared exponent per block: maximum over the ebs axis.
        me  = np.max( fieb - m2_0 * ( fis*fim == ( -1<<23 )), dim+1 )

        if isinstance( bits, np.ndarray ):
            assert np.all( bits.shape == me.shape ), "[f2bfp]: If bits is ndarray, then shape needs to match shared exponents"
            mbits = np.expand_dims( bits, dim+1 )
        else:
            mbits = bits
        # Largest mantissa magnitude that fits next to the sign bit.
        bm_max = ( 1<<( mbits-1 ))-1
        if limit_range:
            shift = 25-mbits + np.mod( shift, mbits-4 ) #TODO is mbits-4 correct?
            me = 127 + np.mod( me-127, mbits-4 )
            fim &= bm_max << shift

        if rounding is None: rounding = self.rounding
        if rounding == True: rounding = 'sym_inf'
        
        # Split the ebs axis into ( ebs // subtile_size, subtile_size ) to find
        # the smallest exponent distance inside every subtile.
        shift_shape = shape[:]
        shift_shape[dim+1] = subtile_size
        shift_shape.insert( dim+1, ebs//subtile_size )
        exp_diff = ( np.expand_dims( me, dim+1 ) - fie ).reshape( shift_shape ).getfield( np.int32 )
        min_shift = np.min( exp_diff, axis=dim+2 )
        shift_values =  np.minimum( min_shift, max_shift )
        if bfloat or ( bfloat is None and self.bfloat ):
            #TODO -2 below is not accurate as it depends on the
            me = np.maximum( 0, me - ( 2**eb - 1 - expo_bias - ( not extend )))
            shift = np.expand_dims( me, dim+1 ) - 127 - np.repeat( shift_values, subtile_size, dim+1 )
            # print( "shift:", shift[0, 0, :2] )
            # print( me[0, 0] )
            scale = ( 2.0**np.minimum( 127, -shift )).astype( np.float32 )
            #TODO what's the limit for the common exponent
            bm = self.f2bf( f.reshape( shape ) * scale, bits, expo_bits, rounding=rounding, filter_denorm = False) / scale
            # print( bm[0, 0, 0:2] ) 
        else:
            #assert 0, "We should not be here"
            # print( "my data" )
            # print( me[0,0] )
            # Right-shift distance that aligns every mantissa to the shared exponent.
            shift = np.expand_dims( me, dim+1 ) - fie + 25 - mbits - np.repeat( shift_values, subtile_size, dim+1 )
            # print( shift[0,0,:] )
            bm = self.shift_round( fim, shift,  rounding=rounding )
            bm[bm>bm_max] = bm_max # saturate to bmax
            bm = bm & bm_max
            bm[shift>24] = 0 # prevents sign to become part of mantissa
            # Place the sign bit at position mbits-1 for negative inputs.
            bm[fis<0] = (1 << ( mbits-1 ) | bm[fis<0] )
            if check_double_rnd:
                # Remember this result so that a following call can compare against it.
                self.f2bfp_check_rnd_value = bm.copy( )
            elif hasattr( self, 'f2bfp_check_rnd_value' ) and self.f2bfp_check_rnd_value is not None:
                if self.f2bfp_check_rnd_value.size == bm.size:
                    mask = self.f2bfp_check_rnd_value != bm
                    assert not np.any( mask ), "f2bfp double rnd diff errors: {}\n1. call:\n{}\n2. call\n{}".format( np.sum( mask ), self.f2bfp_check_rnd_value[mask], bm[mask] )
                    self.f2bfp_check_rnd_value = None
                else:
                    self.log_msg.append( "[Warning]: f2bfp double round check skipped due to size mismatch: {} != {}".format( self.f2bfp_check_rnd_value.size, bm.size ))
                self.f2bfp_check_rnd_value = None
            if sgn_mag:
                # Keep the sign/magnitude bit pattern, expressed as a negative number.
                bm[bm>=2**( bits-1 )] -= 2**bits
            else:
                # Sign/magnitude -> negative value of the magnitude.
                bm[bm>=2**( bits-1 )] = 2**( bits-1 ) - bm[bm>=2**( bits-1 )]
        # print( "bm fin", bm[0,0,:] )
        bm = bm.reshape( f.shape )
        # Shared exponents >= 128 are returned as negative values.
        bfp = ( bm , me - 256 * ( me >= 128 ), shift_values )
        ret = self.bfp2f( bfp, bits, ebs, dim, bfloat, sgn_mag, subtile_size ) if as_float else bfp
        # if as_float:
        #     print( "output", ret[0,:ebs] )
        return ret

    def bfp2f( self, bfp, bits=8, ebs=1, dim=-1, bfloat=None, sgn_mag=False, subtile_size=2 ):
        """Convert a block floating point triple ( bm, me, shift ) back to float32.

        In bfloat mode (argument, or self.bfloat when None) bm is returned
        unchanged. Otherwise a float32 bit pattern is assembled per element
        from the sign, the shared exponent `me` (repeated `ebs` times along
        `dim`), the subtile shift and the normalised mantissa; elements with a
        zero magnitude keep only their sign bit.

        bits may be an ndarray, in which case it is repeated `ebs` times along
        `dim`. sgn_mag selects sign/magnitude instead of two's complement
        decoding of bm.
        """
        bm, me, shift = bfp
        # print( bfloat, self.bfloat, bm[0, :ebs] )
        if bfloat or ( bfloat is None and self.bfloat ):
            return bm
        if dim < 0: dim += len( me.shape )
        # Output shape: the block axis of me expanded by ebs.
        shape = list( me.shape )
        shape[dim] *= ebs
        if isinstance( bits, np.ndarray ):
            bits = np.repeat( bits, ebs, axis=dim )
        # Arithmetic shift: 0 for non-negative values, -1 for negative ones.
        fis = bm >> ( bits-1 )
        if sgn_mag:
            mag_bm = bm & ( 2**( bits-1 ) -1 )
        else:
            # ( x ^ s ) - s with s in {0, -1} yields the absolute value.
            mag_bm = ( bm ^ fis ) - fis
        # Number of leading zeros of the magnitude below the sign bit.
        lz  = bits-2 - np.floor( np.log2( np.maximum( 1, mag_bm ))).astype( int )
        fie = np.maximum( 0, np.repeat( me & 255, ebs, axis=dim ) - np.repeat( shift, subtile_size, dim+1 ).reshape( shape ) - lz )
        # A negative shift makes shift_round shift left, normalising the mantissa.
        fim = self.shift_round( mag_bm, bits - 25 - lz )
        fi  = ( fis<<31 ) | ((( fie<<23 ) | ( fim & (( 1<<23 )-1 ))) * ( mag_bm!=0 ))
        f = fi.getfield( np.float32 )
        # print( np.max( f ))
        return np.reshape( f, shape ).astype( np.float32 )


    def collapse_bfp( self, bfp, bits=8, ebs=1, dim=-1, bfloat=None, expo_bits=8, order=None, shift_bits=0, subtile_size=2, sparse = False, sparse_factor = 1.0, duplicate_ebs = False ):
        """Pack a block floating point tuple into one array of byte values.

        bfp is ( bm, me, shift ), or ( spm, bm, me, shift ) when sparse is set.
        Per block the output concatenates, along axis dim+1, the shared
        exponent (preceded by spm in sparse mode), the packed shifts (only when
        shift_bits > 0) and the mantissas regrouped into 8-bit values; the
        block and payload axes are then merged.

        Returns the packed array. When `order` is given, returns a tuple of the
        packed array and an order string: if `order` ends in digits that follow
        a non-digit character, that trailing number is rescaled by the bytes
        per sample; if the last character is not a digit, `order` is returned
        unchanged.

        NOTE(review): duplicate_ebs halves ebs and subtile_size with true
        division, so both become floats; np.repeat in the bfloat branch is then
        called with a float repeat count.
        NOTE(review): `dim` is used in the sparse and duplicate_ebs branches
        before negative values are normalised further down.
        """
        
        if sparse == True:
            sparse_factor = sparse_factor       # bm is already compressed
            enc_factor = 1
            spm, bm, me, shift = bfp  # we get the sparse info spm
            shape = list( me.shape )
            me = np.concatenate( ( spm, np.expand_dims(me, dim) ), dim )
        else:
            sparse_factor = 1
            enc_factor = 0
            bm, me, shift = bfp
            shape = list( me.shape )
            
        if duplicate_ebs:
            # only for weights
            # 
            # remap to duplicated ebs = ebs/2
            subtile_size = subtile_size/2
            ebs = ebs/2
            me = np.repeat(me, 2, axis=dim)
            shape = list( me.shape )
            c = np.repeat(shift, 2, axis=dim+1)
            shift = c.reshape((shape[dim-1],shape[dim],-1))
            
        if dim < 0: dim += len( bm.shape )
        # print( bm[0, 0], me[0], shift, bits, ebs, dim )
        if bfloat or ( bfloat is None and self.bfloat ):
            # Scale the float mantissas by the block exponent and subtile shift,
            # then pack them with collapse_float.
            shift = np.expand_dims( me + 256 * ( me < 0 ), dim+1 ) - 127 - np.repeat( shift, subtile_size, dim+1 )
            scale = ( 2.0**np.minimum( 127, -shift.reshape( bm.shape ))).astype( np.float32 )
            # print( "collapse:", bm[0, 0:2], shift[0, 0, 0] )
            bm = self.collapse_float( bm * scale, bits, expo_bits )
            # print( "collapsed:", bm[0, 0:2] )
        
        # Number of mantissa bytes per block.
        shape.insert( dim+1, int( np.ceil( ebs * (bits) / 8.0 )))
        
        if bits != 8:
            # reinterpret_cast works on the last axis: move dim there, regroup
            # to 8-bit values, then move the axis back.
            perm = list( range( len( bm.shape )))
            perm.append( perm.pop( dim ))
            bm = bm.transpose( perm )
            if np.lcm( 8, bits ) in ( 8, bits ):
                bm = self.reinterpret_cast( bm, bits, 8, True, bfloat_in=False )
            else:
                # Widths that do not divide each other go through their least common multiple.
                bm = self.reinterpret_cast( bm, bits, np.lcm( 8, bits ),    True, bfloat_in=False )
                bm = self.reinterpret_cast( bm,       np.lcm( 8, bits ), 8, True, bfloat_in=False )
            perm = list( range( len( bm.shape )))
            perm.insert( dim, perm.pop( -1 ))
            bm = bm.transpose( perm )
            
        if shift_bits > 0:
            packed_shifts = self.pack_shifts( shift, dim, shift_bits, subtile_size )
            if sparse == True:
                clps = np.concatenate(( me, packed_shifts, bm.reshape( shape )), dim+1 )
            else:
                clps = np.concatenate(( np.expand_dims( me, dim+1 ), packed_shifts, bm.reshape( shape )), dim+1 )
        else:
            if sparse == True:
                clps = np.concatenate(( me, bm.reshape( shape )), dim+1 )
            else:
                clps = np.concatenate(( np.expand_dims( me, dim+1 ), bm.reshape( shape )), dim+1 )
            
        # Merge the per-block payload axis into axis dim.
        shape[dim] *= shape.pop( dim+1 ) + 1 + int( np.ceil( ebs / subtile_size * shift_bits / 8 )) + enc_factor*3 # 3B for encoding 24b
        
        if order is not None:
            # Collect the trailing digits of `order`, scanning from the end.
            val=''
            for p in reversed( order ):
                if p.isdigit( ):
                    val = p + val
                elif val:
                    return ( clps.reshape( shape ), order[:-len( val )] + str( int( np.ceil( int( val ) * ( enc_factor*((3*(ebs/4))/8)/(ebs) + sparse_factor * ( (bits / 8.0) + (1.0 / ebs) + shift_bits/( subtile_size*8.0 ))))))) # number of elements * bytes per sample
                else:
                    return ( clps.reshape( shape ), order )
                
        return clps.reshape( shape )
    

    def collapse_float( self, data, bits=16, expo_bits=8 ):
        """Pack float32 data into the integer bit pattern of a narrower float format.

        data      -- array-like, np.ndarray or NumPy scalar; float64 is cast to
                     float32, any other dtype fails the assertion.
        bits      -- total width of the target format.
        expo_bits -- exponent description understood by decode_expo_bits.

        Returns an int32 array holding the float32 bit pattern with the
        exponent field narrowed to the decoded width and the result shifted
        right by 32 - ( bits + 8 - exponent width ) bits. The input is expected
        to be already rounded to the target precision; no rounding is done here.
        """
        if not isinstance( data, ( np.ndarray, np.number )):
            data = np.array( data, dtype=np.float32 )
        elif data.dtype == np.float64:
            data = data.astype( np.float32 )
        assert data.dtype == np.float32, "type {} not supported for collapse_float".format( data.dtype )
        expo_bits, expo_bias, extend = decode_expo_bits( expo_bits )
        shift = 32 - ( bits + 8 - expo_bits )
        bi = data.copy( ).getfield( np.int32 )
        if expo_bits < 8:
            expo = ( bi >> 23 ) & 255
            if extend:
                # Remember the all-ones exponents; they are re-encoded at the end.
                nan = expo == 255
            # Values below the target's normal range: shift the mantissa (with
            # its implicit one) down and keep only the sign.
            denorm = ( expo <= ( 127 - expo_bias )) & ( expo > 0 )
            if np.any( denorm ):
                bi[ denorm ] = ((( 1 << 23 ) | ( bi[ denorm ] & (( 1 << 23 ) - 1 ))) >> ( 128 - expo_bias - (( bi[ denorm ] >> 23 ) & 255 ))) - ( bi[ denorm ] & ( -1*2**31 ))
            # Only the difference to the default bias 2**( expo_bits-1 ) - 1 is applied.
            expo_bias -= 2**( expo_bits - 1 ) - 1
            if expo_bias != 0:
                expo = ( bi >> 23 ) & 255
                bi[( expo > 0 ) & ( expo < 255 )] += expo_bias << 23
            # Keep the two top bits (shifted down) and the low 22 + expo_bits bits.
            bi = (( bi & ~(( 1 << 30 ) - 1 )) >> ( 8 - expo_bits )) | ( bi & (( 1 << ( 22 + expo_bits )) - 1 ))
            if extend == 2:
                bi[nan] = ( bi[nan] > 0 ) * 2**( bits - 1 + shift ) - 1
            elif extend:
                bi[(( bi >> shift ) & (( 1 << ( bits )) - 1 )) == 0 ] = 0
                bi[nan] = -1 << ( 23 + expo_bits )
        if shift > 0:
            bi >>= shift
        return bi

    # bits_out must be a valid data type 
    def reinterpret_cast_with_lcm(self, data_in, bits_in, bits_out, sgn_out=False, bfloat_in=None, expo_bits_in=8):
        if np.lcm( bits_out, bits_in ) in ( bits_out, bits_in ):
            data_out = self.reinterpret_cast( data_in, bits_in, bits_out, sgn_out , bfloat_in )
        else:
            data_out = self.reinterpret_cast( data_in, bits_in, np.lcm( bits_out, bits_in ),  True, bfloat_in=False )
            data_out = self.reinterpret_cast( data_out.copy(), np.lcm( bits_out, bits_in ), bits_out, True, bfloat_in=False )
        return data_out


# Default exponent-bit codes per type name, looked up by decode_expo_bits when
# it is given a string. A fractional code carries extra fields: the integer
# part is the exponent width and 1000 * fraction is split into an extension
# level (multiples of 300) and an exponent bias (the remainder), e.g. 4.607
# decodes to ( 4, 7, 2 ) and 2.301 to ( 2, 1, 1 ).
type_expo_bits_default = {"float": 8, "float32": 8, "float16": 5, "float8": 4.607, "float6": 2.301, "float4": 2, "bfloat16": 8}

            
def decode_expo_bits( eb ):
    """Decode an exponent description into ( bits, bias, extend ).

    A string is first looked up in type_expo_bits_default. A whole number is
    returned with the default bias 2**( eb-1 ) - 1 and extend False. For a
    fractional value the integer part is the exponent width and
    round( 1000 * fraction ) encodes 300 * extend + bias.
    """
    if type( eb ) is str:
        eb = type_expo_bits_default[eb]
    if np.mod( eb, 1 ) != 0:
        bits = int( eb )
        code = int( round( 1e3 * np.mod( eb, 1 )))
        ext  = int( round( code / 300.0 ))
        assert bits < 8, "Extended exponent range and exponent bias only supported with exponents of less then 8 bits"
        return ( bits, code - 300 * ext, ext )
    return ( eb, 2**( eb - 1 ) - 1, False )

def encode_expo_bits( bits=0, bias=None, extend=False ):
    """Encode exponent width, bias and extension level into one number.

    Without bias and extension the plain width is returned. Otherwise the
    result is bits + 0.3 * extend + 0.001 * bias, where a missing bias
    defaults to 2**( bits-1 ) - 1.
    """
    if bias is None:
        if not extend:
            return bits
        bias = 2**( bits - 1 ) - 1
    return bits + 0.3 * extend + 1e-3 * bias

class NamedList:
    """Ordered collection of named values that are stored as instance attributes.

    Values can be read and written by attribute, by field name, by integer
    position, by slice or by a sequence of keys. Iteration yields the field
    names in order. Comparison operators work element-wise and return a list
    of booleans instead of a single bool.
    """

    def __init__( self, field_names, init_val=0, init_access=None, init_missing=0 ):
        """Create the fields and set their initial values.

        field_names  -- sequence of names, a sequence of ( name, value ) pairs,
                        a dict (keys are used in sorted order) or a NamedList.
        init_val     -- one value for all fields, a list/tuple used by
                        position, or a dict/NamedList used by name.
        init_access  -- 'list', 'dict' or 'key' to force how init_val is read;
                        derived from the type of init_val when not given.
        init_missing -- value for fields that init_val does not provide.

        dict values are converted to nested NamedList instances.
        """
        if isinstance( field_names, ( dict, NamedList )):
            init_access = 'dict'
            init_val = field_names
            if isinstance( field_names, dict ):
                field_names = sorted( field_names.keys( ))
            else:
                field_names = field_names._keys( )
        elif isinstance( field_names, ( tuple, list )) and len( field_names ) > 0 and isinstance( field_names[0], ( tuple, list )) and len( field_names[0] ) == 2:
            field_names,init_val = zip( *field_names )
        self._field_names = list( field_names )
        if sys.version_info >= ( 3, 0 ) and isinstance( init_val, ( map, zip, filter )): init_val = tuple( init_val )
        if not init_access:
            if   isinstance( init_val, ( list, tuple )):     init_access = 'list'
            elif isinstance( init_val, ( dict, NamedList )): init_access = 'dict'
        # Iterate the stored list: field_names itself may be a one-shot
        # iterable that list( ) has already consumed.
        for idx, key in enumerate( self._field_names ):
            if not init_access: val = init_val
            elif init_access in ( 'dict', 'key' ):
                if isinstance( init_val, NamedList ): val = init_val._get( key, init_missing )
                else: val = init_val.get( key, init_missing )
            else: val = init_val[idx] if len( init_val )>idx else init_missing
            if isinstance( val, dict ):
                val = NamedList( val )
            setattr( self, key, val )

    def __getitem__( self, key ):
        """Return the value(s) for an index, slice, sequence of keys or name.

        Raises KeyError for a name that is not an attribute of the instance.
        """
        if isinstance( key, numbers.Integral ):
            return getattr( self, self._field_names[key] )
        elif type( key ) is slice:
            return [getattr( self, k ) for k in self._field_names[key]]
        elif isinstance( key, ( tuple, list )) or ( sys.version_info >= ( 3, 0 ) and isinstance( key, ( map, zip ))):
            return tuple( map( self.__getitem__, key ))
        else:
            if not hasattr( self, key ):
                raise KeyError( "'{}' not in NamedList. Available are {}".format( key, self._field_names ))
            return getattr( self, key )

    def __setitem__( self, key, val ):
        """Set the value(s) for an index, slice, sequence of keys or name.

        An unknown name is appended as a new field.
        """
        if isinstance( key, numbers.Integral ):
            setattr( self, self._field_names[key], val )
        elif type( key ) is slice:
            for k, v in zip( self._field_names[key], val ): setattr( self, k, v )
        elif isinstance( key, ( tuple, list )) or ( sys.version_info >= ( 3, 0 ) and isinstance( key, ( map, zip ))):
            for k, v in zip( key, val ): self.__setitem__( k, v )
        elif not hasattr( self, key ):
            self._append( key, val )
        else:
            setattr( self, key, val )

    def __len__( self ):
        """Return the number of fields."""
        return len( self._field_names )

    def __iter__( self ):
        """Iterate over the field names in order."""
        return iter( self._field_names )

    def _values( self ):
        """Return the field values as a list, in field order."""
        return [getattr( self, key ) for key in self._field_names]

    def _get( self, key, default=None ):
        """Return the value for an index or field name, or default if absent.

        Any index valid for the field list is accepted, including -len( self )
        (which used to fall through to the default).
        """
        count = len( self._field_names )
        if isinstance( key, numbers.Integral ) and -count <= key < count:
            return getattr( self, self._field_names[key] )
        elif key in self._field_names:
            return getattr( self, key )
        else:
            return default

    def _keys( self ):
        """Return the field names as a tuple."""
        return tuple( self._field_names )

    def _items( self ):
        """Return an iterator of ( name, value ) pairs."""
        return zip( self._field_names, self._values( ))

    def _append( self, key, val ):
        """Add a new field at the end."""
        self._field_names.append( key )
        setattr( self, key, val )

    def _extend( self, named_list ):
        """Set all entries of a dict or NamedList, adding fields that are missing."""
        keys = named_list if isinstance( named_list, ( dict, )) else named_list._field_names
        for key in keys:
            if key not in self._field_names:
                self._field_names.append( key )
            setattr( self, key, named_list[key] )

    def _update( self, named_list ):
        """Overwrite existing fields only; no field is added.

        A list/tuple is applied by position, a dict or NamedList by name.
        """
        if isinstance( named_list, ( tuple, list )):
            min_len = min( len( self ), len( named_list ))
            self[:min_len] = named_list[:min_len]
        else:
            keys = named_list if isinstance( named_list, ( dict, )) else named_list._field_names
            for key in self._field_names:
                if key in keys:
                    setattr( self, key, named_list[key] )

    def __or__( self, named_list ):
        """Return a copy extended by the fields of named_list that self lacks."""
        new = self._copy( )
        for k,v in ( named_list.items( ) if isinstance( named_list, dict ) else named_list._items( )):
            if k not in self._field_names:
                new._append( k, v )
        return new


    def _cmp( self, op, c ):
        """Apply the binary predicate op element-wise and return a list of results.

        c may be a list/tuple/NamedList (compared by position), a dict
        (compared by field name) or a single value compared with every field.
        A None operand is replaced by -inf; two None operands compare False.
        """
        def top( a, b ):
            if a is None:
                if b is None:
                    return False
                else:
                    return op( -float('inf'), b )
            else:
                if b is None:
                    return op( a, -float('inf') )
                else:
                    return op( a, b )
        if isinstance( c, ( tuple, list, NamedList )):
            return [ top( getattr( self, name ), c[idx] ) for idx, name in enumerate( self._field_names )]
        elif isinstance( c, dict ):
            return [ top( self[k], c.get( k )) for k in self._field_names ]
        else:
            return [ top( self[k], c ) for k in self._field_names ]

    def __eq__( self, c ):
        return self._cmp( lambda x, y: x==y, c )
    def __ge__( self, c ):
        return self._cmp( lambda x, y: x>=y, c )
    def __gt__( self, c ):
        return self._cmp( lambda x, y: x>y, c )
    def __le__( self, c ):
        return self._cmp( lambda x, y: x<=y, c )
    def __lt__(self, c):
        return self._cmp(lambda x, y: x < y, c)
    def __ne__( self, c ):
        return self._cmp( lambda x, y: x!=y, c )

    def _copy( self ):
        """Return a new instance of the same type built from this one."""
        #return copy.copy( self )
        return type( self )( self )

    def __str__( self ):
        """Format as '{"name": value, ...}'; strings are quoted, bools shown as 0/1."""
        return "{" + ", ".join( ['"{}": {}'.format( k, '"'+v+'"' if type( v ) is str else ( int( v ) if type( v ) is bool else v )) for k, v in self._items( )] ) + "}"



def type_decoder( typ ):
    """Return ( bits, signed ) for a type name.

    A leading 'uint' marks the type as unsigned. 'int' and 'float' are 32 bits,
    'bool' is 8 bits and 'bfloat16' is 16 bits. For other 'int...' names the
    width is the number after ':' if present, else the digits between 'int'
    and the first '_'. For other 'float...' names it is the trailing number.

    Raises TypeError for any other name.
    """
    sgn = not typ.startswith( 'uint' )
    if not sgn:
        typ = typ[1:]
    fixed_widths = { "int": 32, "float": 32, "bool": 8, "bfloat16": 16 }
    if typ in fixed_widths:
        bits = fixed_widths[typ]
    elif typ.startswith( 'int' ):
        bits = int( typ.split( ':' )[-1] ) if ':' in typ else int( typ.split( '_' )[0][3:] )
    elif typ.startswith( "float" ):
        bits = int( typ[5:] )
    else:
        raise TypeError( "Please implement packing for this type {}".format( typ ))
    return bits, sgn

def sizeof( val ):
    """Return the size in bytes of the type named val (true division, so a float)."""
    bits, _ = type_decoder( val )
    return bits / 8

def sign( val ):
    """Return the signedness of a type name, or whether a value is negative."""
    if type( val ) is not str:
        return val < 0
    return type_decoder( val )[1]


class TypedNamedList( NamedList ):
    """NamedList whose fields each carry a type descriptor.

    The descriptors are kept in ``self._types``, parallel to ``self._field_names``, and
    drive the bit-level packing in ``_get_stream`` and unpacking in ``_set_from_stream``.
    A descriptor is either a type string understood by ``type_decoder`` (optionally with
    a ``:bits`` bit-field suffix) or a ``TypedNamedList`` subclass for a nested structure.
    """
    def __init__( self, typed_fields, init_val=0, init_access=None, converter=None ):
        """Build the list from field specifications.

        typed_fields: an empty sequence (no fields; ``NamedList.__init__`` is not called),
                      another TypedNamedList (its types are copied and it is passed on to
                      ``NamedList.__init__``), or an iterable of specs accepted by
                      ``__split_named_type``.
        init_val, init_access: forwarded unchanged to ``NamedList.__init__``.
        converter: object providing ``collapse_float`` for float fields; defaults to
                   ``DataConverter( bfloat=True )``.
        """
        self._converter = converter if converter else DataConverter( bfloat=True )
        if len( typed_fields ) == 0:
            self._types = []
            self._field_names = []
        elif isinstance( typed_fields, TypedNamedList ):
            self._types = list( typed_fields._types )
            NamedList.__init__( self, typed_fields )
        else:
            types, field_names = zip( *map( self.__split_named_type, typed_fields ))
            self._types = list( types )
            NamedList.__init__( self, field_names, init_val=init_val, init_access=init_access )

    def __split_named_type( self, name ):
        """Split a field spec into ``( type, name )``.

        Tuples/lists (and specs whose first element is a TypedNamedList subclass) are
        returned unchanged. A string must consist of two whitespace-separated tokens,
        ``"type name"``; a ``name:bits`` suffix is moved onto the type as ``"type:bits"``.
        """
        if ( callable( name[0] ) and issubclass( name[0], TypedNamedList )) or isinstance( name, ( tuple, list )):
            return name
        typ, name = name.split( )
        if ':' in name:
            name, bw = name.split( ':' )
            typ += ':' + bw
        return ( typ, name )


    def _append( self, key, val ):
        """Append one typed field; ``key`` is a spec as accepted by ``__split_named_type``."""
        t,key = self.__split_named_type( key )
        self._types.append( t )
        self._field_names.append( key )
        setattr( self, key, val )

    def _extend( self, named_list ):
        """Add or overwrite fields from ``named_list``.

        For a TypedNamedList the stored types/names are used; otherwise each element is
        parsed with ``__split_named_type``. Existing fields keep their type and position
        but take the new value, looked up as ``named_list[key]``.
        """
        for t,key in map( self.__split_named_type, named_list ) if not isinstance( named_list, TypedNamedList ) else zip( named_list._types, named_list._field_names ):
            if key not in self._field_names:
                self._types.append( t )
                self._field_names.append( key )
            setattr( self, key, named_list[key] )


    def __or__( self, named_list ):
        """Return a copy of ``self`` extended by the fields of ``named_list`` that ``self`` lacks.

        Fields already present keep the value from ``self``.
        """
        assert isinstance( named_list, TypedNamedList ), "Only TypedNamedList supported for rhs of 'or'"
        new = self._copy( )
        for i,k in enumerate( named_list ):
            if k not in self._field_names:
                new._append(( named_list._types[i], k ), named_list[k] )
        return new

    def _get_alignment( self ):
        """Return the alignment of this structure in bits (at least 8).

        It is the largest bit width among all fields, recursing into nested
        TypedNamedLists. For bit-fields the width of the base type (the part before
        ':') is used.
        """
        align = 8
        for typ, val in zip( self._types, self._values( )):
            if isinstance( val, TypedNamedList ):
                align = max( align, val._get_alignment( ))
            elif callable( typ ) and issubclass( typ, TypedNamedList ):
                align = max( align, typ( init_val=val )._get_alignment( ))
            else:
                align = max( align, type_decoder( typ.split( ':' )[0] )[0] )
        return align


    def _get_stream( self, bits_stream=32, stream=None, ptr=0 ):
        """Pack all field values into a list of ``bits_stream``-bit words.

        Values are packed starting at the least significant bit of each word and spill
        into the following word when they do not fit. Non-bit-field scalars are padded
        with zero bits to a multiple of their own width; the structure as a whole is
        padded to ``_get_alignment( )`` at the end (and at the start when chained).

        stream, ptr: existing word list and current bit position when called for a
                     nested structure; leave at the defaults for a top-level call.
        Returns ``( stream, ptr )`` when ``stream`` was given, otherwise just the word list.
        """
        def append_stream( stream, ptr, val, bits, sign=False ):
            # Append the low `bits` bits of val at bit position ptr.
            # `sign` is accepted for symmetry with type_decoder but is not used here:
            # masking already yields the two's-complement bit pattern.
            if type( val ) is bool: val = int( val )
            val = int(val) & ( 2**bits-1 )
            pos  = ptr %  bits_stream
            word = ptr // bits_stream
            # Fill up the current word and continue in the next one while the value does not fit.
            while pos + bits > bits_stream:
                take = bits_stream - pos
                push = (( val & ( 2**take-1 )) << pos )
                if pos==0: stream.append( push )
                else: stream[word] |= push
                val >>= take
                ptr += take; bits -= take
                pos  = ptr %  bits_stream
                word = ptr // bits_stream
            if bits > 0:
                if pos==0: stream.append( val )
                else: stream[word] |= ( val << pos )
                ptr += bits
            return stream, ptr

        chained = stream is not None
        align = self._get_alignment( )
        if not chained: stream = []
        elif ptr % align != 0:
            pad = align - ptr % align
            stream, ptr = append_stream( stream, ptr, 0, pad )

        for typ, val in zip( self._types, self._values( )):
            if isinstance( val, TypedNamedList ):
                stream, ptr = val._get_stream( bits_stream, stream, ptr )
            elif callable( typ ) and issubclass( typ, TypedNamedList ):
                stream, ptr = typ( init_val=val )._get_stream( bits_stream, stream, ptr )
            else:
                bits, sign = type_decoder( typ )
                # Bit-fields (':' in the type) are packed back to back without alignment.
                if ptr % bits != 0 and ':' not in typ:
                    pad = bits - ptr % bits
                    stream, ptr = append_stream( stream, ptr, 0, pad )
                if "float" in typ:
                    ( expo_bits, expo_bias, fmt_ext ) = decode_expo_bits( typ )
                    val = self._converter.collapse_float( val, bits=bits, expo_bits=expo_bits )
                stream, ptr = append_stream( stream, ptr, val, bits, sign )
        if ptr % align != 0:
            pad = align - ptr % align
            stream, ptr = append_stream( stream, ptr, 0, pad )
        return ( stream, ptr ) if chained else stream


    def _set_from_stream( self, stream, bits_stream=32, ptr=None, buf=0, has=0 ):
        """Set the field values from a word stream laid out as produced by ``_get_stream``.

        Float fields are not supported and trigger an assertion.

        stream:      indexable sequence of ``bits_stream``-bit words.
        ptr:         index of the next unread word; ``None`` marks a top-level call.
        buf, has:    bits already fetched from the stream but not yet consumed, and how
                     many of them are valid; the consumed bit position is
                     ``ptr * bits_stream - has``.
        Returns ``( ptr, buf, has )`` when chained (``ptr`` given), otherwise ``None``.
        """
        def strip_padding( pad, ptr, buf, has ):
            # Skip `pad` bits of padding and return the updated read state.
            if has >= pad:
                buf >>= pad
                has -=  pad
            else:
                # Every buffered bit is padding. Drop them all so that no stale bits
                # remain when the rest of the padding ends exactly on a word boundary.
                pad -= has
                buf = 0
                has = 0
                ptr += pad // bits_stream
                pad = pad % bits_stream
                if pad != 0:
                    has = bits_stream - pad
                    # Mask like the refill loop below so out-of-range words cannot leak high bits.
                    buf = ( stream[ptr] & ( 2**bits_stream-1 )) >> pad
                    ptr += 1
            return ptr, buf, has

        chained = ptr is not None
        align = self._get_alignment( )
        if not chained: ptr = 0
        elif ( ptr * bits_stream - has ) % align != 0:
            ptr, buf, has = strip_padding( align - ( ptr * bits_stream - has ) % align, ptr, buf, has )

        for idx, typ in enumerate( self._types ):
            if isinstance( self._get( idx ), TypedNamedList ):
                ptr, buf, has = self._get( idx )._set_from_stream( stream, bits_stream, ptr, buf, has )
            elif callable( typ ) and issubclass( typ, TypedNamedList ):
                tnl = typ( )
                ptr, buf, has = tnl._set_from_stream( stream, bits_stream, ptr, buf, has )
                self.__setitem__( idx, tnl )
            else:
                bits, sign = type_decoder( typ )
                assert "float" not in typ, "Conversion from byte array to float type not yet implemented"
                # Mirror the packing rule: only non-bit-field scalars are aligned to their width.
                if ( ptr * bits_stream - has ) % bits != 0 and ':' not in typ:
                    ptr, buf, has = strip_padding( bits - ( ptr * bits_stream - has ) % bits, ptr, buf, has )
                # Refill the bit buffer until the whole field is available.
                while has < bits:
                    buf |= ( stream[ptr] & ( 2**bits_stream-1 )) << has
                    has += bits_stream
                    ptr += 1
                val = buf & ( 2**bits-1 )
                buf >>= bits
                has -=  bits
                # Two's-complement sign extension for signed types.
                if sign and val & ( 2**( bits-1 )):
                    val -= 2**bits
                self.__setitem__( idx, val )

        if ( ptr * bits_stream - has ) % align != 0:
            ptr, buf, has = strip_padding( align - ( ptr * bits_stream - has ) % align, ptr, buf, has )
        return ( ptr, buf, has ) if chained else None


    def _to_byte_array( self ):
        """Return the packed structure as a ``bytearray`` (stream word size of 8 bits)."""
        return bytearray( self._get_stream( 8 ))

    def _set_from_byte_array( self, barr ):
        """Set the field values from a byte sequence (stream word size of 8 bits)."""
        self._set_from_stream( barr, 8 )

# Type alias for arguments that may be given either as a NamedList or as a plain dict.
NamedList_or_dict = typing.Union[ NamedList, dict ]
# Same alias, but None is accepted as well.
Optional_NamedList_or_dict = typing.Optional[ NamedList_or_dict ]

#helper functions to access dict or NamedList
def values( dnl: NamedList_or_dict ):
    """Return the values of ``dnl``, using ``_values( )`` for a NamedList and ``values( )`` otherwise."""
    if isinstance( dnl, NamedList ):
        return dnl._values( )
    return dnl.values( )
def items( dnl: NamedList_or_dict ):
    """Return the items of ``dnl``, using ``_items( )`` for a NamedList and ``items( )`` otherwise."""
    if isinstance( dnl, NamedList ):
        return dnl._items( )
    return dnl.items( )
def keys( dnl: NamedList_or_dict ):
    """Return the keys of ``dnl``, using ``_keys( )`` for a NamedList and ``keys( )`` otherwise."""
    if isinstance( dnl, NamedList ):
        return dnl._keys( )
    return dnl.keys( )
def get( dnl: NamedList_or_dict, key, default=None ):
    """Look up ``key`` in ``dnl`` with a fallback, using ``_get`` for a NamedList and ``get`` otherwise."""
    if isinstance( dnl, NamedList ):
        return dnl._get( key, default )
    return dnl.get( key, default )
def copy( dnl: NamedList_or_dict ):
    """Return a copy of ``dnl``, using ``_copy( )`` for a NamedList and ``copy( )`` otherwise."""
    if isinstance( dnl, NamedList ):
        return dnl._copy( )
    return dnl.copy( )



def max( a, *b ):
    """Maximum that dispatches on the number of arguments.

    One argument:   ``np.max( a )``, the largest element of ``a``.
    Two arguments:  ``np.maximum``, the element-wise maximum.
    More arguments: the Python built-in ``max`` (kept as ``pmax``).
    """
    if not b:
        return np.max( a )
    if len( b ) == 1:
        return np.maximum( a, b[0] )
    return pmax( a, *b )

def min( a, *b ):
    """Minimum that dispatches on the number of arguments.

    One argument:   ``np.min( a )``, the smallest element of ``a``.
    Two arguments:  ``np.minimum``, the element-wise minimum.
    More arguments: the Python built-in ``min`` (kept as ``pmin``).
    """
    if not b:
        return np.min( a )
    if len( b ) == 1:
        return np.minimum( a, b[0] )
    return pmin( a, *b )

def ceil( n, d=1 ):
    """Round ``n`` up to the next multiple of ``d`` (to the next integer for ``d == 1``).

    The rounded quotient is converted with the built-in ``int``.
    """
    if d != 1:
        return d * int( np.ceil( n / d ))
    return int( np.ceil( n ))

def floor( n, d=1 ):
    """Round ``n`` down to a multiple of ``d`` (to an integer for ``d == 1``).

    The rounded quotient is converted with ``np.int64``.
    """
    if d != 1:
        return d * np.int64( np.floor( n / d ))
    return np.int64( np.floor( n ))

def round( n, d=1 ):
    """Round ``n`` to the nearest multiple of ``d`` (to the nearest integer for ``d == 1``).

    Rounding is done by ``np.round``; the rounded quotient is converted with ``np.int64``.
    """
    if d != 1:
        return d * np.int64( np.round( n / d ))
    return np.int64( np.round( n ))

class DimsHelper:
    """Bookkeeping helper that turns per-dimension steps into address increments.

    ``reset`` is a running offset that is added to the next increment; every dimension
    added through ``add_dimension`` subtracts ``num * step`` from it. ``bits`` is the
    width used for the ``num``/``inc`` fields generated by ``from_steps``.
    """
    def __init__( self, reset=0, bits=32 ):
        self.bits = bits
        self.reset = reset

    def __getitem__( self, key ):
        """Allow ``helper["reset"]``-style access to attributes."""
        return getattr( self, key )

    def add_dimension( self, num, step ):
        """Return the increment for a dimension of ``num`` iterations with stride ``step``.

        The increment includes the current ``reset``; afterwards ``reset`` is lowered
        by ``num * step``.
        """
        increment = self.reset + step
        self.reset = self.reset - num * step
        return increment

    def from_steps( self, wraps, steps, next_loop_level=False ):
        """Compute the increments (and iteration counts) for one to five steps.

        wraps, steps:    scalars or sequences; at least ``len( steps ) - 1`` wraps are required.
        next_loop_level: when true, ``reset`` is set to ``-wraps[i] * step`` after the last step.

        For every step except the last a count is recorded: ``wraps[i] - 1``, or, for
        every third dimension (index 2, 5, ...), the product of all wraps so far minus 1.
        A step without a matching wrap contributes ``reset + step`` and clears ``reset``.

        Returns the single increment for one step, otherwise a TypedNamedList with the
        fields ``num0..`` (unsigned) followed by ``inc0..`` (signed), each ``bits`` wide.
        """
        wraps = make_tuple( wraps )
        steps = make_tuple( steps )
        assert 1 <= len( steps ) <= 5, "Only 1d to 5d address increments supported"
        assert len( wraps ) >= len( steps ) - 1, "Wrap dimesions passed are not sufficient"

        counts = []
        increments = []
        last = len( steps ) - 1
        for idx, step in enumerate( steps ):
            if idx == len( wraps ):
                # No wrap for this step: emit the increment directly and clear the offset.
                increments.append( self.reset + step )
                self.reset = 0
                continue

            is_last = idx == last
            is_third = idx % 3 == 2
            if not is_last and is_third:
                count = wraps[idx]
                counts.append( np.prod( wraps[:idx+1] ) - 1 )
            else:
                count = wraps[idx] - 1
                if not is_last:
                    counts.append( count )
            increments.append( self.add_dimension( count, step ))

            if ( is_last and next_loop_level ) or ( is_third and not is_last ):
                self.reset = -wraps[idx] * step

        if len( increments ) == 1:
            return increments[0]

        count_fields = [ f"uint{self.bits} num{n}" for n, _ in enumerate( counts )]
        increment_fields = [ f"int{self.bits} inc{n}" for n, _ in enumerate( increments )]
        return TypedNamedList( count_fields + increment_fields, counts + increments )

    def __str__(self):
        template = "DimsHelper(reset={}, bits={})"
        return template.format( self.reset, self.bits )

def make_tuple( val ):
    """Return ``val`` as a tuple.

    Anything with a ``__len__`` attribute is converted with ``tuple( )`` (so a string
    becomes a tuple of its characters); everything else is wrapped in a 1-tuple.
    """
    if not hasattr( val, "__len__" ):
        return ( val, )
    return tuple( val )

class dims_2d_param(ctypes.Structure):
    # ctypes layout with one unsigned 32-bit count followed by two signed 32-bit
    # increments. Names and order match what DimsHelper.from_steps generates for
    # two steps with 32-bit fields.
    _fields_ = (
        [( name, ctypes.c_uint32 ) for name in ( "num0", )]
        + [( name, ctypes.c_int32 ) for name in ( "inc0", "inc1" )]
    )

class dims_3d_param(ctypes.Structure):
    # ctypes layout with two unsigned 32-bit counts followed by three signed 32-bit
    # increments. Names and order match what DimsHelper.from_steps generates for
    # three steps with 32-bit fields.
    _fields_ = (
        [( name, ctypes.c_uint32 ) for name in ( "num0", "num1" )]
        + [( name, ctypes.c_int32 ) for name in ( "inc0", "inc1", "inc2" )]
    )

class dims_4d_param(ctypes.Structure):
    # ctypes layout with three unsigned 32-bit counts followed by four signed 32-bit
    # increments. Names and order match what DimsHelper.from_steps generates for
    # four steps with 32-bit fields.
    _fields_ = (
        [( name, ctypes.c_uint32 ) for name in ( "num0", "num1", "num2" )]
        + [( name, ctypes.c_int32 ) for name in ( "inc0", "inc1", "inc2", "inc3" )]
    )

class dims_5d_param(ctypes.Structure):
    # ctypes layout with four unsigned 32-bit counts followed by five signed 32-bit
    # increments. Names and order match what DimsHelper.from_steps generates for
    # five steps with 32-bit fields.
    _fields_ = (
        [( name, ctypes.c_uint32 ) for name in ( "num0", "num1", "num2", "num3" )]
        + [( name, ctypes.c_int32 ) for name in ( "inc0", "inc1", "inc2", "inc3", "inc4" )]
    )

def random_gen(low=-128, high=127, size=1, dtype=np.int8 ):
    """Draw random integers with ``np.random.randint``.

    The arguments are passed through unchanged; ``high`` is the exclusive upper bound
    of ``np.random.randint``, so the defaults never produce 127.
    """
    draw = np.random.randint
    return draw( low, high, size, dtype )

def conv_to_local_ptr(addr: int) -> int:
    """Return ``addr`` shifted by the fixed offset 0xE0000."""
    # NOTE(review): the constant's name suggests this is the base of a core-local
    # address window — confirm against the target memory map.
    core_local_offset = 0xE0000
    shifted = core_local_offset + addr
    return shifted
