# pylint: skip-file

# Copyright (C) 2019 - 2022 Xilinx, Inc. All rights reserved.
# Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.
#
# This file contains confidential and proprietary information
# of Xilinx, Inc. and is protected under U.S. and
# international copyright and other intellectual property
# laws.
#
# DISCLAIMER
# This disclaimer is not a license and does not grant any
# rights to the materials distributed herewith. Except as
# otherwise provided in a valid license issued to you by
# Xilinx, and to the maximum extent permitted by applicable
# law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
# WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
# AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
# BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
# INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
# (2) Xilinx shall not be liable (whether in contract or tort,
# including negligence, or under any other theory of
# liability) for any loss or damage of any kind or nature
# related to, arising under or in connection with these
# materials, including for any direct, or any indirect,
# special, incidental, or consequential loss or damage
# (including loss of data, profits, goodwill, or any type of
# loss or damage suffered as a result of any action brought
# by a third party) even if such damage or loss was
# reasonably foreseeable or Xilinx had been advised of the
# possibility of the same.
#
# CRITICAL APPLICATIONS
# Xilinx products are not designed or intended to be fail-
# safe, or for use in any application requiring fail-safe
# performance, such as life-support or safety devices or
# systems, Class III medical devices, nuclear facilities,
# applications related to the deployment of airbags, or any
# other applications that could lead to death, personal
# injury, or severe property or environmental damage
# (individually and collectively, "Critical
# Applications"). Customer assumes the sole risk and
# liability of any use of Xilinx products in Critical
# Applications, subject only to applicable laws and
# regulations governing limitations on product liability.
#
# THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
# PART OF THIS FILE AT ALL TIMES.

import numpy as np
import numbers
import os.path
from .data_shaper import DataShaper
from .data_converter import DataConverter



class GenerateDataHelpers( DataShaper, DataConverter ):
    """Helpers to generate, sparsify, compress and dump test data.

    Number-format helpers used below (``srs``, ``f2bfp``, ``collapse_bfp``,
    ``collapse_float``, ``reinterpret_cast``, ``reinterpret_cast_with_lcm``,
    ``reorder_mat``, ``_reorder_granularity_range``) and attributes such as
    ``self.bfloat``, ``self.defOrder`` and ``self.log_msg`` are not defined in
    this class; presumably they come from the ``DataShaper`` / ``DataConverter``
    mixins -- confirm there.

    :meth:`random_mat` draws from a private ``np.random.RandomState`` seeded in
    ``__init__``, so its output is reproducible for a given ``seed``.
    """

    def __init__( self, path,
                 default_order='NHWC',
                 rounding=True, bfloat=False, seed=None,
                 appendix=None, print_info=False ):
        """
        :param path: directory that :meth:`write_file` and :meth:`write_size_file` write into.
        :param default_order: forwarded to ``DataShaper.__init__``.
        :param rounding: forwarded to ``DataConverter.__init__``.
        :param bfloat: forwarded to ``DataConverter.__init__``.
        :param seed: seed of the private generator used by :meth:`random_mat`.
        :param appendix: default suffix appended to file names by :meth:`str_append`.
        :param print_info: forwarded to ``DataShaper.__init__``.
        """
        DataShaper.__init__( self, default_order, print_info )
        DataConverter.__init__( self, rounding, bfloat )
        self.path      = path
        self.appendix  = appendix

        self.idx       = 0      # current iteration; selects per-iteration ``order`` entries and the default file suffix
        self.itr       = None   # when not None, used instead of ``idx`` as file-name suffix
        self.sizes     = {}     # {iteration: {file name: size}} read by write_size_file
        self.small_value_threshold = 4
        self.large_value_threshold = 0.99
        self.limit_actv_minimum = 1   # lower bound on remaining granules per dimension in limit_actv_size

        self.random = np.random.RandomState( seed )
        self.random_exponent = False  # when True, random_mat scales gaussian data by 2**N(0, 2)

    def check_double_rnd_enabled( self ):
        """Return the ``check_double_rnd`` flag passed to ``f2bfp`` by :meth:`random_mat`."""
        return True

    def set_np_print_to_hex( self ):
        """Make NumPy print integer arrays as hexadecimal (negatives print as ``-0x...``)."""
        # int() instead of the removed/platform-dependent np.long alias.
        np.set_printoptions( formatter={'int':lambda x: hex( int( x ))} )

    def set_np_print_to_dec( self ):
        """Reset NumPy array printing to its default formatting."""
        np.set_printoptions( formatter={} )


    def random_mat( self, shape, bits, sgn=True, bfloat=None, ebs=None, expo_bits=8, gauss=False, var=None, padding=None, debug=False, **kwargs ):
        """Generate a test tensor.

        Modes, checked in this order:

        * ``debug``: an integer ramp ``((n*sa + y)*sa + x)*sa + c`` in the leading
          ``sa`` entries of each of four dimensions (shorter shapes are padded with
          trailing 1s), wrapped into the ``bits``-wide (signed if ``sgn``) range.
        * ``gauss`` / ``bfloat`` / ``ebs``: normal samples (std ``var``) as float32,
          quantized with ``self.f2bfp`` when ``ebs`` is given, else with ``self.srs``.
        * otherwise: uniform integers from ``self.random`` -- in
          ``[kwargs['min'], kwargs['max'])`` when both keys are given, else over the
          full ``bits``-wide (signed if ``sgn``) range, both ends inclusive.

        :param shape: output shape.
        :param bits: element bit width.
        :param sgn: signed (True) or unsigned range.
        :param bfloat: float mode; defaults to ``self.bfloat``.
        :param ebs: optional ``(block_a, block_b)`` pair forwarded to ``f2bfp``.
        :param expo_bits: exponent width for float / block-float formats.
        :param gauss: force the gaussian path for integer data.
        :param var: standard deviation of the gaussian path (derived from
            ``bits`` / ``expo_bits`` when None).
        :param padding: per-axis number of zeros appended after the data.
        :param debug: generate the deterministic ramp instead of random data.
        :returns: the generated array.
        """
        if bfloat is None: bfloat = self.bfloat
        if var is None: var = min( 128, 2**( expo_bits + 1 )) if bfloat else 2**(( bits + 1 ) * 3 // 4 )
        if debug:
            # list.extend() returns None; build the 4-D padded shape by concatenation.
            if len( shape ) < 4: shape = list( shape ) + [1]*( 4-len( shape ))
            data = np.zeros( shape=shape, dtype=np.int_ )
            sa = 16 if bits>8 else ( 4 if bits>4 else 2 )
            for n in range( min( sa, shape[0] )):
                for y in range( min( sa, shape[1] )):
                    for x in range( min( sa, shape[2] )):
                        for c in range( min( sa, shape[3] )):
                            data[n, y, x, c] = (( n*sa + y )*sa + x )*sa + c
            # wrap the ramp into the representable two's complement range
            data[data >= 2**( bits-int( sgn ))] -= 2**bits
        elif gauss or bfloat or ebs:
            data = self.random.normal( size=shape, loc=( 0 if sgn else var/4 ), scale=var )
            if self.random_exponent:
                data *= 2**self.random.normal( size=shape, loc=0, scale=2 )
            # if ebs:
            #     data = np.zeros( shape )
            #     mi = np.min( shape )
            #     ma = np.max( shape )
            #     for i in range( mi ):
            #         #data[i, i+(ma-mi)//2] = i
            #         data[i, i] = i - ( mi!=ma ) * mi / 4
            data = data.astype( np.float32 )
            if ebs:
                data = self.f2bfp( data, bits, ebs[0], ebs[1], as_float=True, expo_bits=expo_bits, rounding = False, m2_0_rnd=False, check_double_rnd=self.check_double_rnd_enabled())
            else :
                data = self.srs( data, 0, bits, sgn, bfloat=bfloat, ebs=ebs, expo_bits=expo_bits, as_float=ebs is not None, float_to_int=not bfloat )
        else:
            # np.int64 instead of np.long: the alias is missing in NumPy 1.24-1.26
            # and is int32 on Windows under NumPy 2.
            if {"min", "max"}.issubset(kwargs.keys()):
                low  = kwargs["min"]
                high = kwargs["max"]
                # NOTE(review): 'max' is exclusive here, unlike the inclusive
                # full-range branch below -- confirm callers expect this.
                data = self.random.randint( low, high, size=shape, dtype=np.int64 )
            else:
                low, high = -2**( bits-1 )*int( sgn ), 2**( bits-int( sgn ))-1
                data = self.random.randint( low, high+1, size=shape, dtype=np.int64 )
        if padding is not None:
            data = np.pad( data, list(zip( [0]*len( padding ), padding )), 'constant' )
        return data


    def block_float_dot_product( self, actv, wght, accum=None, bits=13, dot_len=16, ebs=None, reduce_first=False ):
        """Emulate a block-floating-point matrix product ``actv @ wght (+ accum)``.

        ``actv`` is ``(M, K)`` and ``wght`` is ``(K, N)``; both are float32 arrays
        holding values of the lower precision input type.  Products are reduced
        along ``K`` in chunks of ``dot_len``: each chunk and the running accumulator
        are aligned to their common maximum exponent, bits below the 24-bit mantissa
        window are rounded away, and the chunk is summed in float64 and stored back
        as float32.

        :param actv: ``(M, K)`` float32 activations.
        :param wght: ``(K, N)`` float32 weights.
        :param accum: optional ``(M, N)`` float32 initial accumulator (not modified).
        :param bits: when below 16 (or when ``reduce_first`` is set), products are
            first summed in groups of ``reduction`` along ``K`` before accumulation.
        :param dot_len: products accumulated per alignment step.
        :param ebs: optional exponent block size along ``K``; operand exponents are
            replaced by their block maximum.
        :param reduce_first: group size for the pre-reduction (default 2).
        :returns: ``(M, N)`` float32 result.

        Assumes ``K`` is divisible by the pre-reduction group and the reduced length
        by ``dot_len``; otherwise ``reshape`` / ``np.take`` raise.
        """
        assert isinstance( actv, np.ndarray ) and actv.dtype == np.float32, "Expected first input as ndarray of type float32 holding the values of given lower precision type"
        assert isinstance( wght, np.ndarray ) and wght.dtype == np.float32, "Expected second input as ndarray of type float32 holding the values of given lower precision type"
        assert accum is None or ( isinstance( accum, np.ndarray ) and accum.dtype == np.float32 ), "Expected accum to be None or ndarray of type float32"

        assert len( actv.shape ) == 2 and len( wght.shape ) == 2, "broadcasting pattern hardcoded as of now"
        actv = np.expand_dims( actv, 2 )   # (M, K, 1)
        wght = np.expand_dims( wght, 0 )   # (1, K, N)
        idx  = 1                           # reduction axis (K)
        mul  = actv * wght                 # (M, K, N) partial products

        def get_expo( data, ebs, dim=-1 ):
            # Biased float32 exponent field; with ebs, the max over blocks of ebs along dim.
            data = ( data.getfield( np.int32 ) >> 23 ) & 255
            if ebs:
                shape = list( data.shape )
                shape[dim] = ebs
                shape.insert( dim, data.shape[dim]//ebs )
                data = np.broadcast_to( np.max( data.reshape( shape ), axis=dim+1, keepdims=True ), shape ).reshape( data.shape )
            return data

        def rounding( data, max_expo ):
            # Align data to exponent max_expo and round off bits below the 24-bit
            # mantissa window (ties to even, except at the widest shift).
            # getfield() returns a *view*; copy so the in-place writes below do not
            # modify the caller's array (e.g. the user-supplied accum).
            if data.dtype == np.float64:
                data = data.getfield( np.int64 ).copy( )
                # 896 = 1023 - 127 rebiases the float32 exponent to float64;
                # 29 = 52 - 23 extra float64 mantissa bits are always dropped.
                shift = np.maximum( -5, max_expo + 896 - (( data >> 52 ) & 2047 ))
                mant = data & (( 1 << 52 ) - 1 )
                data[( shift > 24 ) | (( shift == 24 ) & ( mant == 0 ))] = 0
                mask = ( shift == 24 ) & ( mant != 0 )
                if np.any( mask ):
                    data[mask] += ( 1 << 52 )
                    data[mask] &= ~(( 1 << 52 ) - 1 )
                shift = np.minimum( 23, shift )
                mag  = data & (( 1 << 63 ) - 1 )
                mag += (( 1 << ( shift + 28 )) - 1 + ((( data  >> ( shift + 29 )) & 1 ) | ( shift == 23 ))) * ( shift >= -5 )
                data = ( mag & ~(( 1 << ( shift + 29 )) - 1 )) + ( data & ( -1 << 63 ))
                return data.getfield( np.float64 )

            data = data.getfield( np.int32 ).copy( )
            shift = np.maximum( 0, max_expo - (( data >> 23 ) & 255 ))
            mant = data & (( 1 << 23 ) - 1 )
            data[( shift > 24 ) | (( shift == 24 ) & ( mant == 0 ))] = 0
            mask = ( shift == 24 ) & ( mant != 0 )
            if np.any( mask ):
                data[mask] += ( 1 << 23 )
                data[mask] &= ~(( 1 << 23 ) - 1 )
            shift = np.minimum( 23, shift )
            mag  = data & (( 1 << 31 ) - 1 )
            mag += (( 1 << ( shift - 1 )) - 1 + ((( data  >> shift ) & 1 ) | ( shift == 23 ))) * ( shift > 0 )
            data = ( mag & ~(( 1 << shift ) - 1 )) + ( data & ( -1 << 31 ))
            return data.getfield( np.float32 )

        # biased exponent of every product (sum of operand exponents minus one bias)
        expo_in = get_expo( actv, ebs, idx ) + get_expo( wght, ebs, idx ) - 127
        if accum is None:
            accum = np.zeros( np.array( mul.shape )[ np.arange( len( mul.shape )) != idx ], dtype=np.float32 )


        if bits < 16 or reduce_first:
            reduction = reduce_first if reduce_first else 2
            assert idx==1, "next statement incorrect when idx != 1"
            redu_shape = ( mul.shape[0], -1, reduction, mul.shape[2] )
            max_expo = np.max( expo_in.reshape( redu_shape ), axis=idx+1 )
            mul = mul.reshape( redu_shape )
            if not ebs:
                mul = rounding( mul, np.expand_dims( max_expo, idx+1 ))
            mul = np.sum( mul.astype( np.float64 ), axis=idx+1 )
            expo_in = max_expo
            dot_len //= reduction

        for p in range( 0, mul.shape[idx], dot_len ):
            mi = np.take( mul, np.arange( p, p+dot_len ), axis=idx )
            max_expo = np.maximum( np.max( np.take( expo_in, np.arange( p, p+dot_len ), axis=idx ), axis=idx ), get_expo( accum, None ))
            accum = rounding( accum, max_expo )
            mi = rounding( mi, np.expand_dims( max_expo, idx ))
            accum = accum.astype( np.float64 )
            accum += np.sum( mi.astype( np.float64 ), axis=1 )
            accum = accum.astype( np.float32 )
        return accum


    def compress_data( self, data, bits=8, structured=False, stochastic=False, structured_format=( 4, 2 ), bfloat=None, expo_bits=8 ):
        """Compress ``data`` into a mask + values stream.

        The data is split into lines of ``mask_width`` entries (16 when
        ``structured``, else 32); plain arrays are first reinterpreted as bytes.
        Each line is emitted as:

        * a ``mask_width``-bit keep-mask as ``mask_width//8`` bytes, least
          significant byte first (bit i = entry i), then
        * the kept entries, zero-padded to a multiple of 4 entries.

        Unless ``stochastic``, the keep-mask is forced to exactly
        ``structured_format[1]`` kept elements per group of ``structured_format[0]``.
        Filling up a group is supported only for 4:2.

        ``data`` may also be a ``(values, exps)`` tuple/list (presumably block
        floating point mantissas and shared exponents -- confirm with callers).
        The values are then compressed without reinterpretation and recombined
        with ``exps`` through ``self.collapse_bfp``; ``exps`` itself is not modified.

        :returns: the compressed stream as a flat integer array.
        """
        def enforce_structured( zero ):
            # Force exactly structured_format[1] kept elements per group of
            # structured_format[0]: empty groups keep their leading elements, a
            # group with too few non-zeros also keeps the mirrored position.
            zero = zero.reshape( -1, max( 1, bits//8 ))
            for m in range( 0, zero.shape[0], structured_format[0] ):
                z = np.any( zero[m:m+structured_format[0], :], 1 )
                if sum( z ) == 0:
                    z[:structured_format[1]] = True
                elif sum( z ) < structured_format[1]:
                    assert structured_format == ( 4, 2 ), "Only 4:2 format supprted for structured compression"
                    z[-1-np.where( z==True )[0][0]] = True
                assert sum( z ) == structured_format[1], "Invalid number of mask bits set"
                zero[m:m+structured_format[0], :] = np.expand_dims( z, 1 )
            return zero.flatten( )

        mask_width = 16 if structured else 32
        # Only a tuple/list is a (values, exps) pair; len() alone would also match
        # any ndarray whose first dimension happens to be 2.
        if isinstance( data, ( tuple, list )) and len( data ) == 2:
            data, exps = data
            words = data.size // mask_width
        else:
            exps = None
            words = data.size * bits//8 // mask_width
            data = self.reinterpret_cast( data, bits, 8, bfloat_in=bfloat, expo_bits_in=expo_bits )
        data = data.reshape( -1 )
        comp = np.zeros( words * ( mask_width + mask_width//8 ), dtype=np.int_ )
        ptr  = 0

        # True: the structured mask is applied before the mask bytes are written,
        # so the stored mask matches the stored values. The False path is kept
        # for experimentation.
        fixed_mask = True

        for idx in range( words ):
            line = data[idx*mask_width:( idx+1 )*mask_width]
            zero = line != 0

            if not stochastic and fixed_mask:
                zero = enforce_structured( zero )

            mask = int( ''.join( [str( int( x )) for x in reversed( zero )] ), 2 )
            comp[ptr:ptr+mask_width//8] = [( mask>>( 8*p )) & 0xFF for p in range( mask_width//8 )]
            ptr += mask_width//8

            if not stochastic and not fixed_mask:
                zero = enforce_structured( zero )

            comp[ptr:ptr+line[zero].size] = line[zero]
            ptr += int( np.ceil( np.sum( zero ) / 4.0 ) * 4 )

        comp = comp[:ptr]
        if exps is not None:
            cmpr_ebs = int( mask_width*structured )
            comp = comp.reshape( -1, cmpr_ebs + mask_width//8 )
            # work on a copy so the caller's exponent array is left untouched
            exps = np.array( exps )
            exps[exps<0] += 256
            bfp = self.collapse_bfp(( comp[:, mask_width//8:], np.expand_dims( exps, -1 )), bits, cmpr_ebs, bfloat=bfloat, expo_bits=expo_bits )
            comp = np.concatenate(( comp[:, :mask_width//8], bfp ), 1 ).flatten( )

        self.log_msg.append( "Compressed size: {:2.4} %, Non-zeros: {:2.4} %".format( 100.0 * ptr / data.size, 100.0 * np.sum( data!=0 ) / data.size ))

        return comp

    def sparsify_data( self, data, sparsity,  bits=8, sgn=True ):
        """Zero randomly chosen non-zero entries until ``sparsity`` percent are zero.

        If ``data`` is already at least that sparse, nothing is zeroed.  ``bits`` and
        ``sgn`` are unused.  When ``data.reshape(-1)`` returns a view (contiguous
        input) the caller's array is modified in place as well as returned.
        """
        # clamp: a negative count (data already sparse enough) made choice() raise
        zeroes_to_add = max( 0, int(( sparsity/ 100.0 ) * data.size ) - np.count_nonzero( data == 0 ))
        # NOTE(review): uses the global np.random state rather than self.random, so
        # the constructor seed does not make this reproducible. Left as is because
        # switching would shift every later draw from self.random.
        idx = np.random.choice( np.flatnonzero( data ), zeroes_to_add, False )
        shape = data.shape
        flattened_data = data.reshape( -1 )
        flattened_data[idx] = 0
        self.log_msg.append( "Data sparsity achieved : {} %".format(( flattened_data.size - np.count_nonzero( flattened_data ))/float( flattened_data.size ) * 100 ))
        return flattened_data.reshape( shape )

    def str_append( self, string, appendix=None ):
        """Return ``string`` + ``'_'`` + ``appendix`` (default ``self.appendix``).

        No underscore is inserted if ``appendix`` already starts with one; if both
        ``appendix`` and ``self.appendix`` are None, ``string`` is returned as is.
        """
        if appendix is None: appendix = self.appendix
        if appendix is None: return string
        appendix = str( appendix )
        if not appendix.startswith( '_' ): string += '_'
        return string + appendix


    def write_file( self, name, mat, order=None, bits=32, bits_per_line=32, reinterpret=None, bfloat=None, ebs=None, expo_bits=8, is_hex=False, defOrder=None, append=False, itr=None, filter_denorm=None ):
        """Write ``mat`` as whitespace-separated text to ``<path>/<name>[_<appendix>]_<itr>.txt``.

        The data is optionally reordered from ``defOrder`` to ``order``, converted to
        block floating point (``ebs``) or collapsed float format (``bfloat``), and
        reinterpreted to a different element width before writing
        ``bits_per_line // bits`` values per line (decimal, or hex with ``is_hex``).

        :param order: target order, or a per-iteration tuple/list indexed by ``self.idx``.
        :param reinterpret: optional ``(bits_in, bits_out)`` reinterpretation.
        :param append: append to the file instead of overwriting it.
        :param itr: file suffix; defaults to ``self.itr`` or else ``self.idx``.
        :returns: data size in bytes, computed before any reinterpretation.
        """
        if type( order ) in ( tuple, list ):
            order = order[self.idx]
        if ebs:
            if type( mat ) is not tuple:
                mat = self.f2bfp( mat, bits, *ebs, bfloat=bfloat, expo_bits=expo_bits, filter_denorm=filter_denorm )
            mat, order = self.collapse_bfp( mat, bits, *ebs, bfloat=bfloat, expo_bits=expo_bits, order=order )
            data = self.reorder_mat( mat, order, defOrder )
            bits = 8
        elif order and order != defOrder:
            data = self.reorder_mat( mat, order, defOrder )
        elif type( mat ) == np.ndarray:
            data = mat.reshape( -1 )
        else: data = mat
        if bfloat is None: bfloat = self.bfloat
        if bfloat and not ebs and 'param' not in name:
            data = self.collapse_float( data, bits, expo_bits )
            if reinterpret is None and np.mod( np.log2( bits ), 1 ) != 0:
                reinterpret = ( bits, int(2**np.ceil( np.log2( bits ))))
                if reinterpret[1] < 8:
                    print( "WARNING: Reinterpret cast operation might fail since destination reinterpretation is not a standard data type size" )
                    # str(): concatenating the int directly raised TypeError
                    print( "    " + str( reinterpret[1] ) + " bits data type not allowed" )
                if np.mod( bits * data.size, reinterpret[1] ) != 0:
                    print( "WARNING: Reinterpret cast operation might fail since padding is necessary" )
        size = len( data ) * bits // 8
        if itr is None:
            itr = self.idx if self.itr is None else self.itr
        if bits < 8 and reinterpret is None: reinterpret = ( bits, 8 )
        if reinterpret is not None:
            data = self.reinterpret_cast_with_lcm(data.copy(), *reinterpret, bfloat_in=False)
            bits = bits * reinterpret[1] // reinterpret[0]
        values_per_line = max( 1, bits_per_line // bits )
        # the field width depends only on bits, so build the format once
        frmt = '{{:{}x}}'.format( bits // 4 ) if is_hex else '{{:{}}}'.format( int( np.ceil( 0.301 * bits + 1 )))
        with open( os.path.join( self.path, self.str_append( self.str_append( name ), itr ) + '.txt' ), 'w' if not append else 'a' ) as fo:
            for i in range( 0, len( data ), values_per_line ):
                fo.write( ' '.join( map( frmt.format, data[i:i+values_per_line] ))+'\n' )

        # TODO update size sanity to new file naming strategy
        #if self.idx not in self.sizes:
        #    self.sizes[self.idx] = {}
        #if name not in self.sizes[self.idx]:
        #    self.sizes[self.idx][name] = size
        #elif append:
        #    sanity = self.sizes[self.idx].get( name + "_append_sanity" )
        #    if sanity is None:
        #        self.sizes[self.idx][name + "_append_sanity"] = [self.sizes[self.idx][name], size]
        #        self.sizes[self.idx][name] += size
        #    elif sanity[-1] is None:
        #        assert size in sanity[1:-1], "Appended size to file expected to be constant between iterations ( was one of={}, got={} )".format( sanity[1:-1], size )
        #    else:
        #        self.sizes[self.idx][name] += size
        #        self.sizes[self.idx][name + "_append_sanity"].append( size )
        #else:
        #    sanity = self.sizes[self.idx].get( name + "_append_sanity" )
        #    if sanity is not None:
        #        self.sizes[self.idx][name + "_append_sanity"].append( None )
        #        sz = sanity[0]
        #    else:
        #        sz = self.sizes[self.idx][name]

        #    print( f"{name}, {self.sizes}" )
        #    assert sz == size, "File size expected to be constant between iterations ( was={}, got={} )".format( sz, size )
        return size


    def write_size_file( self, name ):
        """Write one line per recorded iteration (sorted) with the size of ``name`` (0 if missing)."""
        with open( os.path.join( self.path, self.str_append( self.str_append( name ), 'size' ) + '.txt' ), 'w' ) as f:
            for k in sorted( self.sizes.keys( )):
                f.write( str( self.sizes[k].get( name, 0 )) + '\n' )


    def limit_actv_size( self, shape, order, bits, size=0x3000, defOrder=None ):
        """Shrink ``shape`` (given in ``defOrder``) so it fits in ``size`` bytes.

        Each dimension is counted in units of the reorder granularity returned by
        ``self._reorder_granularity_range`` for ``order``.  The dimension with the
        most remaining units is halved (rounding up, never below
        ``self.limit_actv_minimum`` units) until
        ``prod(granularity * units) * bits / 8`` no longer exceeds the limit.

        :param size: byte limit, or a callable mapping a shape to a byte limit.
        :returns: the (possibly) reduced shape; always logs a ``[WARN]`` line.
        :raises ValueError: when every dimension is at its minimum and it still does not fit.
        """
        if not defOrder: defOrder = self.defOrder
        if type( order ) in ( tuple, list ):
            order = order[self.idx]
        shape_gran = []
        for z in defOrder:
            gran, pre_g = self._reorder_granularity_range( order, z )
            for key in pre_g:
                p = defOrder.find( key )
                fix  = shape_gran[p] // pre_g[key]
                free = shape_gran[p] // shape[p]
                used = shape_gran[p] // free // fix
                shape_gran[p] //= free
                gran          //= used
            shape_gran.append( gran )
        shape_rem = np.maximum( 1, np.divide( shape, shape_gran ))
        cut_off = np.maximum( np.ones( shape_rem.shape ), self.limit_actv_minimum )
        shape_new = shape
        sz = size if isinstance( size, numbers.Integral ) else size( shape )
        while np.prod( shape_gran * shape_rem ) * bits/8 > sz:
            if np.all( shape_rem == cut_off ):
                raise ValueError( "Unable to find shape fitting in {} with granuarity {}, bits {} and minimum {}".format( sz, shape_gran, bits, self.limit_actv_minimum ))
            # halve the dimension with the most remaining granules above its cut-off
            p = np.argmax( shape_rem * ( shape_rem > cut_off ))
            shape_rem[p] = max( cut_off[p], np.ceil( shape_rem[p] / 2 ))
            shape_new = [int( x ) for x in np.minimum( shape, shape_gran * shape_rem )]
            sz = size if isinstance( size, numbers.Integral ) else size( shape_new )
        self.log_msg.append( '[WARN]: Limit s={:<15} o={:<15} -> s={:<15}'.format( str( shape ), order, str( shape_new )))
        return shape_new


    def sign_config_decoder( self, config, setup ):
        """Expand a signedness configuration into ``{key: sign}`` over the keys of ``setup``.

        * bool: the same value for every key.
        * int: bit ``i`` of ``config`` for the ``i``-th key.
        * dict: returned unchanged.
        * otherwise an iterable of the above, producing ``{key: [sign per entry]}``.
        """
        signs = {}
        # bool must be tested before int (bool is an int subclass); isinstance also
        # accepts NumPy integers/bools, which previously fell through and crashed.
        if isinstance( config, ( bool, np.bool_ )):
            for key in setup: signs[key] = config
        elif isinstance( config, numbers.Integral ):
            for i, key in enumerate( setup ): signs[key] = ( config>>i ) & 1
        elif isinstance( config, dict ):
            signs = config
        else:
            for key in setup: signs[key] = []
            for conf in config:
                if isinstance( conf, ( bool, np.bool_ )):
                    for key in setup: signs[key].append( conf )
                elif isinstance( conf, numbers.Integral ):
                    for i, key in enumerate( setup ):
                        signs[key].append(( conf>>i ) & 1 )
                elif isinstance( conf, dict ):
                    # iterate keys directly: enumerate() produced (i, key) tuples -> KeyError
                    for key in setup:
                        signs[key].append( conf[key] )
        return signs

