# pylint: skip-file

# Copyright (C) 2019 - 2022 Xilinx, Inc. All rights reserved.
# Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.
#
# This file contains confidential and proprietary information
# of Xilinx, Inc. and is protected under U.S. and
# international copyright and other intellectual property
# laws.
#
# DISCLAIMER
# This disclaimer is not a license and does not grant any
# rights to the materials distributed herewith. Except as
# otherwise provided in a valid license issued to you by
# Xilinx, and to the maximum extent permitted by applicable
# law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
# WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
# AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
# BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
# INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
# (2) Xilinx shall not be liable (whether in contract or tort,
# including negligence, or under any other theory of
# liability) for any loss or damage of any kind or nature
# related to, arising under or in connection with these
# materials, including for any direct, or any indirect,
# special, incidental, or consequential loss or damage
# (including loss of data, profits, goodwill, or any type of
# loss or damage suffered as a result of any action brought
# by a third party) even if such damage or loss was
# reasonably foreseeable or Xilinx had been advised of the
# possibility of the same.
#
# CRITICAL APPLICATIONS
# Xilinx products are not designed or intended to be fail-
# safe, or for use in any application requiring fail-safe
# performance, such as life-support or safety devices or
# systems, Class III medical devices, nuclear facilities,
# applications related to the deployment of airbags, or any
# other applications that could lead to death, personal
# injury, or severe property or environmental damage
# (individually and collectively, "Critical
# Applications"). Customer assumes the sole risk and
# liability of any use of Xilinx products in Critical
# Applications, subject only to applicable laws and
# regulations governing limitations on product liability.
#
# THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
# PART OF THIS FILE AT ALL TIMES.

import numbers
import sys
import copy
import typing
from .data_generator.data_converter import DataConverter, decode_expo_bits

class NamedList:
    """Ordered record whose fields are reachable as attributes, by name or by position.

    Field names are kept, in insertion order, in ``_field_names``; each value is
    stored as an instance attribute of the same name.  Indexing accepts a field
    name, an integer position, a slice, or a tuple/list of keys.  Comparison
    operators work element-wise and return a list of booleans, not a single bool.
    """
    def __init__( self, field_names, init_val=0, init_access=None, init_missing=0 ):
        """Build the list from field names and initial values.

        field_names  -- a dict or NamedList (its keys become the fields -- dict keys
                        are sorted -- and it also supplies the values), a non-empty
                        sequence of (name, value) pairs, or any iterable of names.
        init_val     -- a scalar given to every field, a list/tuple read by position,
                        or a dict/NamedList read by name.
        init_access  -- 'list' for positional reads from init_val, 'dict'/'key' for
                        reads by name; inferred from init_val's type when omitted.
        init_missing -- value used when init_val has no entry for a field.

        Dict values are converted to nested NamedLists.
        """
        if isinstance( field_names, ( dict, NamedList )):
            init_access = 'dict'
            init_val = field_names
            if isinstance( field_names, dict ):
                field_names = sorted( field_names.keys( ))
            else:
                field_names = field_names._keys( )
        elif ( isinstance( field_names, ( tuple, list )) and len( field_names ) > 0
               and isinstance( field_names[0], ( tuple, list )) and len( field_names[0] ) == 2 ):
            field_names, init_val = zip( *field_names )
        # Materialise once and iterate the copy below: field_names may be a one-shot
        # iterator (generator, map, ...) that list() has already exhausted.
        self._field_names = list( field_names )
        if sys.version_info >= ( 3, 0 ) and isinstance( init_val, ( map, zip, filter )):
            init_val = tuple( init_val )
        if not init_access:
            if   isinstance( init_val, ( list, tuple )):     init_access = 'list'
            elif isinstance( init_val, ( dict, NamedList )): init_access = 'dict'
        for idx, key in enumerate( self._field_names ):
            if not init_access:
                val = init_val
            elif init_access in ( 'dict', 'key' ):
                if isinstance( init_val, NamedList ): val = init_val._get( key, init_missing )
                else: val = init_val.get( key, init_missing )
            else:
                val = init_val[idx] if len( init_val ) > idx else init_missing
            if isinstance( val, dict ):
                val = NamedList( val )
            setattr( self, key, val )

    def __getitem__( self, key ):
        """Return a value by position, slice (list), key sequence (tuple) or field name.

        Raises KeyError for a name that is not a field (internal attributes and
        methods are not exposed through indexing).
        """
        if isinstance( key, numbers.Integral ):
            return getattr( self, self._field_names[key] )
        elif type( key ) is slice:
            return [getattr( self, k ) for k in self._field_names[key]]
        elif isinstance( key, ( tuple, list )) or ( sys.version_info >= ( 3, 0 ) and isinstance( key, ( map, zip ))):
            return tuple( map( self.__getitem__, key ))
        else:
            if key not in self._field_names:
                raise KeyError( "'{}' not in NamedList. Available are {}".format( key, self._field_names ))
            return getattr( self, key )

    def __setitem__( self, key, val ):
        """Set a value by position, slice, key sequence or field name.

        Assigning to a name that is not yet a field appends it via ``_append``.
        """
        if isinstance( key, numbers.Integral ):
            setattr( self, self._field_names[key], val )
        elif type( key ) is slice:
            for k, v in zip( self._field_names[key], val ): setattr( self, k, v )
        elif isinstance( key, ( tuple, list )) or ( sys.version_info >= ( 3, 0 ) and isinstance( key, ( map, zip ))):
            for k, v in zip( key, val ): self.__setitem__( k, v )
        elif key in self._field_names:
            setattr( self, key, val )
        else:
            self._append( key, val )

    def __len__( self ):
        return len( self._field_names )

    def __iter__( self ):
        """Iterate over field names (dict-like)."""
        return iter( self._field_names )

    def _values( self ):
        """Return the field values as a list, in field order."""
        return [getattr( self, key ) for key in self._field_names]

    def _get( self, key, default=None ):
        """Return the value at integer position or field name ``key``, else ``default``."""
        n = len( self._field_names )
        if isinstance( key, numbers.Integral ) and -n <= key < n:
            return getattr( self, self._field_names[key] )
        elif key in self._field_names:
            return getattr( self, key )
        else:
            return default

    def _keys( self ):
        """Return the field names as a tuple."""
        return tuple( self._field_names )

    def _items( self ):
        """Return a (one-shot) iterator of (name, value) pairs."""
        return zip( self._field_names, self._values( ))

    def _append( self, key, val ):
        """Add a new field at the end (no duplicate check)."""
        self._field_names.append( key )
        setattr( self, key, val )

    def _extend( self, named_list ):
        """Copy every entry of a dict/NamedList in, appending names not yet present."""
        keys = named_list if isinstance( named_list, ( dict, )) else named_list._field_names
        for key in keys:
            if key not in self._field_names:
                self._field_names.append( key )
            setattr( self, key, named_list[key] )

    def _update( self, named_list ):
        """Overwrite existing fields only.

        A tuple/list updates the leading positions; a dict/NamedList updates the
        fields whose names it contains.
        """
        if isinstance( named_list, ( tuple, list )):
            min_len = min( len( self ), len( named_list ))
            self[:min_len] = named_list[:min_len]
        else:
            keys = named_list if isinstance( named_list, ( dict, )) else named_list._field_names
            for key in self._field_names:
                if key in keys:
                    setattr( self, key, named_list[key] )

    def __or__( self, named_list ):
        """Return a copy extended with the entries of ``named_list`` whose names are new."""
        new = self._copy( )
        for k, v in ( named_list.items( ) if isinstance( named_list, dict ) else named_list._items( )):
            if k not in self._field_names:
                new._append( k, v )
        return new


    def _cmp( self, op, c ):
        """Apply ``op`` element-wise against ``c`` and return a list of results.

        ``c`` may be a sequence/NamedList (compared by position), a dict (by name,
        missing names read as None) or a scalar.  None sorts below every value; two
        Nones compare as False.
        """
        def top( a, b ):
            if a is None:
                if b is None:
                    return False
                else:
                    return op( -float('inf'), b )
            else:
                if b is None:
                    return op( a, -float('inf') )
                else:
                    return op( a, b )
        if isinstance( c, ( tuple, list, NamedList )):
            return [ top( self[k], c[k] ) for k in range( len( self._field_names ))]
        elif isinstance( c, dict ):
            return [ top( self[k], c.get( k )) for k in self._field_names ]
        else:
            return [ top( self[k], c ) for k in self._field_names ]

    def __eq__( self, c ):
        return self._cmp( lambda x, y: x==y, c )
    def __ge__( self, c ):
        return self._cmp( lambda x, y: x>=y, c )
    def __gt__( self, c ):
        return self._cmp( lambda x, y: x>y, c )
    def __le__( self, c ):
        return self._cmp( lambda x, y: x<=y, c )
    def __lt__( self, c ):
        return self._cmp( lambda x, y: x<y, c )
    def __ne__( self, c ):
        return self._cmp( lambda x, y: x!=y, c )

    def _copy( self ):
        """Return a shallow copy of the same concrete type."""
        return type( self )( self )

    def __str__( self ):
        """Render as a JSON-like object: strings quoted, bools shown as 0/1."""
        def fmt( v ):
            if type( v ) is str:
                return '"' + v + '"'
            if type( v ) is bool:
                return int( v )
            return v
        return "{" + ", ".join( '"{}": {}'.format( k, fmt( v )) for k, v in self._items( )) + "}"



def type_decoder( typ ):
    """Return ``(bit_width, is_signed)`` for a scalar type string.

    A leading 'u' on an 'int...' type marks it unsigned.  'int'/'float' are 32 bits,
    'bool' 8, 'bfloat16' 16; 'intN' / 'intN_...' use N, a bit-field 'int...:W'
    uses W, and 'floatN' uses N.  Any other type raises TypeError.
    """
    signed = not typ.startswith( 'uint' )
    if not signed:
        typ = typ[1:]
    fixed_widths = { "int": 32, "float": 32, "bool": 8, "bfloat16": 16 }
    if typ in fixed_widths:
        return fixed_widths[typ], signed
    if typ.startswith( 'int' ):
        width = typ.split( ':' )[-1] if ':' in typ else typ.split( '_' )[0][3:]
        return int( width ), signed
    if typ.startswith( "float" ):
        return int( typ[5:] ), signed
    raise TypeError( "Please implement packing for this type {}".format( typ ))



class TypedNamedList( NamedList ):
    """NamedList whose fields carry type strings, so the record can be packed into
    and unpacked from a stream of fixed-width integer words.

    Each field is declared as a "<type> <name>" string ("<type> <name>:<bits>" for a
    bit-field), as a (type, name) pair, or as a (TypedNamedList subclass, name) pair
    for a nested record.  ``_types`` is kept parallel to ``_field_names``.
    """
    def __init__( self, typed_fields, init_val=0, init_access=None, converter=None ):
        """Create the record.

        typed_fields -- field declarations (see class docstring), or another
                        TypedNamedList whose types, names and values are copied.
        init_val, init_access -- forwarded to NamedList.__init__ (not used when
                        copying a TypedNamedList).
        converter    -- object whose ``collapse_float`` is used when packing float
                        fields.  When omitted it is taken from ``typed_fields`` if that
                        is a TypedNamedList (so ``_copy`` keeps it); otherwise a
                        ``DataConverter( bfloat=True )`` is created.
        """
        if not converter and isinstance( typed_fields, TypedNamedList ):
            converter = getattr( typed_fields, '_converter', None )
        self._converter = converter if converter else DataConverter( bfloat=True )
        if len( typed_fields ) == 0:
            self._types = []
            self._field_names = []
        elif isinstance( typed_fields, TypedNamedList ):
            self._types = list( typed_fields._types )
            NamedList.__init__( self, typed_fields )
        else:
            types, field_names = zip( *map( self.__split_named_type, typed_fields ))
            self._types = list( types )
            NamedList.__init__( self, field_names, init_val=init_val, init_access=init_access )

    def __split_named_type( self, name ):
        """Split a field declaration into a (type, name) pair.

        Tuples/lists are returned unchanged; a "<type> <name>[:<bits>]" string
        becomes ("<type>[:<bits>]", "<name>").
        """
        if ( callable( name[0] ) and issubclass( name[0], TypedNamedList )) or isinstance( name, ( tuple, list )):
            return name
        typ, name = name.split( )
        if ':' in name:
            name, bw = name.split( ':' )
            typ += ':' + bw
        return ( typ, name )


    def _append( self, key, val ):
        """Append a field given as a declaration string or (type, name) pair."""
        t, key = self.__split_named_type( key )
        self._types.append( t )
        self._field_names.append( key )
        setattr( self, key, val )

    def _extend( self, named_list ):
        """Copy fields from another TypedNamedList, or from a mapping keyed by field
        declarations (e.g. {"int foo": 1}), appending names not yet present."""
        if isinstance( named_list, TypedNamedList ):
            entries = [(( t, k ), k ) for t, k in zip( named_list._types, named_list._field_names )]
        else:
            # Values must be read with the original declaration key, not the bare name.
            entries = [( self.__split_named_type( raw ), raw ) for raw in named_list]
        for ( t, key ), src_key in entries:
            if key not in self._field_names:
                self._types.append( t )
                self._field_names.append( key )
            setattr( self, key, named_list[src_key] )


    def __or__( self, named_list ):
        """Return a copy extended with the fields of ``named_list`` whose names are new."""
        assert isinstance( named_list, TypedNamedList ), "Only TypedNamedList supported for rhs of 'or'"
        new = self._copy( )
        for i, k in enumerate( named_list ):
            if k not in self._field_names:
                new._append(( named_list._types[i], k ), named_list[k] )
        return new

    def _get_alignment( self ):
        """Return the record alignment in bits: the maximum of 8, the (base-type)
        width of every scalar field and the alignment of every nested record."""
        align = 8
        for typ, ( key, val ) in zip( self._types, self._items( )):
            if isinstance( val, TypedNamedList ):
                align = max( align, val._get_alignment( ))
            elif callable( typ ) and issubclass( typ, TypedNamedList ):
                align = max( align, typ( init_val=val )._get_alignment( ))
            else:
                try:
                    # Bit-fields align like their base type ("int:3" -> "int").
                    align = max( align, type_decoder( typ.split( ':' )[0] )[0] )
                except Exception:
                    print( f"Error encountered while extracting alignment for element {key} with type {typ}:" )
                    raise
        return align


    def _get_stream( self, bits_stream=32, stream=None, ptr=0 ):
        """Pack the record into a list of ``bits_stream``-bit integer words.

        Values are written least-significant bit first, starting at bit ``ptr`` of
        the stream.  Each non-bit-field scalar is first padded to a multiple of its
        own width; the record starts (when chained) and ends on a multiple of
        ``_get_alignment()`` bits.  Float fields are converted with
        ``self._converter.collapse_float`` before packing.

        Returns the word list, or ``(stream, ptr)`` when an existing ``stream`` was
        passed in (used to chain nested records).
        """
        def append_stream( stream, ptr, val, bits ):
            """Write the low ``bits`` bits of ``val`` at bit offset ``ptr``."""
            if type( val ) is bool: val = int( val )
            val = int( val ) & ( 2**bits-1 )
            pos  = ptr %  bits_stream
            word = ptr // bits_stream
            # Values that straddle a word boundary are split over consecutive words.
            while pos + bits > bits_stream:
                take = bits_stream - pos
                push = (( val & ( 2**take-1 )) << pos )
                if pos == 0: stream.append( push )
                else: stream[word] |= push
                val >>= take
                ptr += take; bits -= take
                pos  = ptr %  bits_stream
                word = ptr // bits_stream
            if bits > 0:
                if pos == 0: stream.append( val )
                else: stream[word] |= ( val << pos )
                ptr += bits
            return stream, ptr

        chained = stream is not None
        align = self._get_alignment( )
        if not chained: stream = []
        elif ptr % align != 0:
            pad = align - ptr % align
            stream, ptr = append_stream( stream, ptr, 0, pad )

        for typ, ( key, val ) in zip( self._types, self._items( )):
            if isinstance( val, TypedNamedList ):
                stream, ptr = val._get_stream( bits_stream, stream, ptr )
            elif callable( typ ) and issubclass( typ, TypedNamedList ):
                stream, ptr = typ( init_val=val )._get_stream( bits_stream, stream, ptr )
            else:
                try:
                    bits, _sign = type_decoder( typ )
                except Exception:
                    print( f"Error encountered while processing element {key} with type {typ}:" )
                    raise
                # Bit-fields are packed back-to-back; other scalars are self-aligned.
                if ptr % bits != 0 and ':' not in typ:
                    pad = bits - ptr % bits
                    stream, ptr = append_stream( stream, ptr, 0, pad )
                if "float" in typ:
                    ( expo_bits, _expo_bias, _fmt_ext ) = decode_expo_bits( typ )
                    val = self._converter.collapse_float( val, bits=bits, expo_bits=expo_bits )
                stream, ptr = append_stream( stream, ptr, val, bits )
        if ptr % align != 0:
            pad = align - ptr % align
            stream, ptr = append_stream( stream, ptr, 0, pad )
        return ( stream, ptr ) if chained else stream


    def _set_from_stream( self, stream, bits_stream=32, ptr=None, buf=0, has=0 ):
        """Unpack field values from ``stream`` (inverse of ``_get_stream``).

        ``ptr`` is the index of the next word to read; ``buf`` holds ``has`` bits
        already read from the stream but not yet consumed.  Signed fields are
        sign-extended.  Float fields are not supported (assertion).

        Returns ``(ptr, buf, has)`` when ``ptr`` was given (chained call for nested
        records), otherwise None.
        """
        word_mask = 2**bits_stream - 1

        def strip_padding( pad, ptr, buf, has ):
            """Discard ``pad`` bits: from ``buf`` first, then whole and partial words."""
            if has >= pad:
                return ptr, buf >> pad, has - pad
            # Every buffered bit is padding; skip the rest directly in the stream.
            pad -= has
            ptr += pad // bits_stream
            pad %= bits_stream
            if pad == 0:
                # Ended exactly on a word boundary: nothing left buffered.
                return ptr, 0, 0
            return ptr + 1, ( stream[ptr] & word_mask ) >> pad, bits_stream - pad

        chained = ptr is not None
        align = self._get_alignment( )
        if not chained: ptr = 0
        elif ( ptr * bits_stream - has ) % align != 0:
            ptr, buf, has = strip_padding( align - ( ptr * bits_stream - has ) % align, ptr, buf, has )

        for idx, typ in enumerate( self._types ):
            current = self[idx]
            if isinstance( current, TypedNamedList ):
                ptr, buf, has = current._set_from_stream( stream, bits_stream, ptr, buf, has )
            elif callable( typ ) and issubclass( typ, TypedNamedList ):
                tnl = typ( )
                ptr, buf, has = tnl._set_from_stream( stream, bits_stream, ptr, buf, has )
                self[idx] = tnl
            else:
                try:
                    bits, sign = type_decoder( typ )
                except Exception:
                    print( f"Error encountered while processing element {self._field_names[idx]} with type {typ}:" )
                    raise
                assert "float" not in typ, "Conversion from byte array to float type not yet implemented"
                # ptr * bits_stream - has == number of bits consumed so far.
                if ( ptr * bits_stream - has ) % bits != 0 and ':' not in typ:
                    ptr, buf, has = strip_padding( bits - ( ptr * bits_stream - has ) % bits, ptr, buf, has )
                while has < bits:
                    buf |= ( stream[ptr] & word_mask ) << has
                    has += bits_stream
                    ptr += 1
                val = buf & ( 2**bits-1 )
                buf >>= bits
                has -=  bits
                if sign and val & ( 2**( bits-1 )):
                    val -= 2**bits
                self[idx] = val

        if ( ptr * bits_stream - has ) % align != 0:
            ptr, buf, has = strip_padding( align - ( ptr * bits_stream - has ) % align, ptr, buf, has )
        return ( ptr, buf, has ) if chained else None


    def _to_byte_array( self ):
        """Pack the record into a bytearray (8-bit words)."""
        return bytearray( self._get_stream( 8 ))

    def _set_from_byte_array( self, barr ):
        """Unpack field values from a byte sequence."""
        self._set_from_stream( barr, 8 )

# Annotation aliases for parameters that accept either container kind.
NamedList_or_dict = typing.Union[ NamedList, dict ]
Optional_NamedList_or_dict = typing.Union[ NamedList_or_dict, None ]

# Helper functions that accept either a dict or a NamedList.
def values( dnl: NamedList_or_dict ):
    """Return the values of a NamedList (as a list) or a dict (as a view)."""
    if isinstance( dnl, NamedList ):
        return dnl._values( )
    return dnl.values( )
def items( dnl: NamedList_or_dict ):
    """Return the (key, value) pairs of a NamedList (one-shot zip) or a dict."""
    if isinstance( dnl, NamedList ):
        return dnl._items( )
    return dnl.items( )
def keys( dnl: NamedList_or_dict ):
    """Return the keys of a NamedList (as a tuple) or a dict (as a view)."""
    if isinstance( dnl, NamedList ):
        return dnl._keys( )
    return dnl.keys( )
def get( dnl: NamedList_or_dict, key, default=None ):
    """Look up ``key`` in a NamedList or dict, returning ``default`` when absent."""
    if isinstance( dnl, NamedList ):
        return dnl._get( key, default )
    return dnl.get( key, default )
def copy( dnl: NamedList_or_dict ):
    """Return a shallow copy of a NamedList or dict.

    NOTE: this module-level name shadows the ``copy`` module imported at the top
    of the file.
    """
    if isinstance( dnl, NamedList ):
        return dnl._copy( )
    return dnl.copy( )
