# pylint: skip-file

# Copyright (C) 2019 - 2022 Xilinx, Inc. All rights reserved.
# Copyright (C) 2022 - 2025 Advanced Micro Devices, Inc. All rights reserved.
#
# This file contains confidential and proprietary information
# of Xilinx, Inc. and is protected under U.S. and
# international copyright and other intellectual property
# laws.
#
# DISCLAIMER
# This disclaimer is not a license and does not grant any
# rights to the materials distributed herewith. Except as
# otherwise provided in a valid license issued to you by
# Xilinx, and to the maximum extent permitted by applicable
# law: (1) THESE MATERIALS ARE MADE AVAILABLE "AS IS" AND
# WITH ALL FAULTS, AND XILINX HEREBY DISCLAIMS ALL WARRANTIES
# AND CONDITIONS, EXPRESS, IMPLIED, OR STATUTORY, INCLUDING
# BUT NOT LIMITED TO WARRANTIES OF MERCHANTABILITY, NON-
# INFRINGEMENT, OR FITNESS FOR ANY PARTICULAR PURPOSE; and
# (2) Xilinx shall not be liable (whether in contract or tort,
# including negligence, or under any other theory of
# liability) for any loss or damage of any kind or nature
# related to, arising under or in connection with these
# materials, including for any direct, or any indirect,
# special, incidental, or consequential loss or damage
# (including loss of data, profits, goodwill, or any type of
# loss or damage suffered as a result of any action brought
# by a third party) even if such damage or loss was
# reasonably foreseeable or Xilinx had been advised of the
# possibility of the same.
#
# CRITICAL APPLICATIONS
# Xilinx products are not designed or intended to be fail-
# safe, or for use in any application requiring fail-safe
# performance, such as life-support or safety devices or
# systems, Class III medical devices, nuclear facilities,
# applications related to the deployment of airbags, or any
# other applications that could lead to death, personal
# injury, or severe property or environmental damage
# (individually and collectively, "Critical
# Applications"). Customer assumes the sole risk and
# liability of any use of Xilinx products in Critical
# Applications, subject only to applicable laws and
# regulations governing limitations on product liability.
#
# THIS COPYRIGHT NOTICE AND DISCLAIMER MUST BE RETAINED AS
# PART OF THIS FILE AT ALL TIMES.

import math
import numpy as np


class DataShaper:
    def __init__( self, defOrder='RC', print_info=False, get_first_significant_dim=True ):
        """Store the shaper configuration and start with an empty message log.

        Args:
            defOrder: String of dimension letters, one per axis of the data,
                used whenever a method is called without its own ``defOrder``.
            print_info: When true, ``_reorder_decode`` appends an ``[INFO]``
                line describing every decoded order to ``log_msg``.
            get_first_significant_dim: Flag consulted by ``get_dim_steps``
                when it decides that a dimension has been fully covered.
        """
        # Messages are collected here instead of being printed.
        self.log_msg = []
        self.defOrder, self.print_info = defOrder, print_info
        self.get_first_significant_dim = get_first_significant_dim


    def _reorder_granularity_range( self, order, z, start_data_dim=-1, stop_str_dim=None ):
        """Multiply up the explicit split sizes attached to letter ``z`` in ``order``.

        ``order`` is cut at every occurrence of ``z``; the text following each
        occurrence is inspected for
          * a leading run of digits (the split size, ``step``),
          * an optional ``>N`` / ``<N`` right after those digits (signed offset),
          * when ``z`` sits inside a ``( ... )`` group, the rest of that group
            (``pg``); the digits after the closing ``)`` are then the step and a
            ``<N`` / ``>N`` directly after ``z`` inside the group is the offset.

        An occurrence contributes only if it has a step and
        ``idx + off > start_data_dim and idx <= stop_str_dim``, where ``idx`` is
        the 0-based occurrence count of ``z`` in the string.

        The offset is read as a full (possibly multi-digit) number so that it
        agrees with ``_reorder_decode``, which accumulates all digits following
        ``<`` / ``>``.

        Args:
            order: Order string to scan.
            z: Dimension letter whose occurrences are examined.
            start_data_dim: Exclusive lower bound applied to ``idx + off``.
            stop_str_dim: Inclusive upper bound applied to ``idx``; defaults to
                ``len( order )``, i.e. no effective limit.

        Returns:
            ``( gran, pre_group )``: ``gran`` is the product of the contributing
            steps; ``pre_group`` maps every other letter sharing a group with a
            contributing occurrence to the product of those groups' steps.

        Raises:
            ValueError: If a ``<`` / ``>`` marker is not followed by a digit.
        """
        if stop_str_dim is None: stop_str_dim = len( order )

        def signed_offset( text ):
            # text[0] is the '<' / '>' marker; every digit right after it is part of the offset.
            digits = ''
            for ch in text[1:]:
                if not ch.isdigit( ): break
                digits += ch
            magnitude = int( digits )
            return ( magnitude if text[0] == '>' else -magnitude ), len( digits )

        gran = 1
        pre_group = {}
        for idx, s in enumerate( order.split( z )[1:] ):
            step = ''; pg = ''
            off = 0
            # A ')' that comes before any '(' means this occurrence of z lies inside a group:
            # keep the remaining group members in pg and continue with the text after ')'.
            if s.find( ')' ) >= 0 and ( s.find( '(' ) > s.find( ')' ) or s.find( '(' ) < 0 ):
                pg, s = s.split( ')', 1 )
            for i, c in enumerate( s ):
                if c.isdigit( ):
                    step += c
                else:
                    # Only the first non-digit is looked at; it may carry the offset.
                    if c in '<>': off, _ = signed_offset( s[i:] )
                    break
            if len( pg )>0 and pg[0] in '<>':
                # Offset written directly after z inside the group takes precedence.
                off, n_digits = signed_offset( pg )
                pg = pg[1+n_digits:]
            if step and idx+off > start_data_dim and idx <= stop_str_dim:
                gran *= int( step )
                for p in pg:
                    if p.isdigit( ) or p in '<>': continue
                    elif p in pre_group: pre_group[p] *= int( step )
                    else:                pre_group[p]  = int( step )
        return gran, pre_group


    def _reorder_decode( self, shape, order, defOrder=None ):
        """Decode an order string into reshape / transpose / padding parameters.

        ``order`` is scanned from its last character to its first. Recognised
        characters:
          * digits      - accumulated into a pending number (``val``, or
                          ``val_gi`` while inside a group),
          * ``>`` ``<`` - turn the pending number into a positive / negative
                          slot offset ``off``,
          * ``)`` ``(`` - start / end of a group (seen in reverse),
          * ``%`` ``*`` ``|`` - explicit padding, broadcast and alignment,
          * a letter of ``defOrder`` - one (sub-)dimension entry.
        Any other character is ignored.

        Args:
            shape: Extents of the data, one entry per letter of ``defOrder``.
            order: Order string to decode.
            defOrder: Dimension letters; falls back to ``self.defOrder``.

        Returns:
            Tuple ``( pad_im, size, perm, pad_ex, brdcst, align )``:
              * ``pad_im`` - per input axis, padding needed so that the splits divide evenly,
              * ``size``   - extents of all sub-dimensions, grouped per ``defOrder`` letter,
              * ``perm``   - for each position of the order string, the index into ``size``,
              * ``pad_ex`` - explicit padding per order position (``%``),
              * ``brdcst`` - repeat factor per order position (``*``),
              * ``align``  - alignment factor per order position (``|``).
            ``reorder_mat`` feeds ``size`` to ``reshape`` and ``perm`` to ``transpose``.
        """
        if not defOrder: defOrder = self.defOrder
        # Ds[i]: number of times letter defOrder[i] occurs in the order string.
        Ds = [order.count( c ) for c in defOrder]
        # D[i]: extent of dimension i that has not yet been assigned to a sub-dimension.
        D  = list( shape )
        size = [0]*sum( Ds )
        perm = [0]*sum( Ds )
        pad_im = [0]*len( shape )
        pad_ex = [0]*sum( Ds )
        brdcst = [1]*sum( Ds )
        align  = [1]*sum( Ds )
        val = ''
        val_gi = ''
        off = 0
        group = False
        # d[i]: next free slot in `size` for dimension i, starting at the last slot of its range.
        d = [sum( Ds[0:i+1] )-1 for i in range( len( Ds ))]
        # p: order position being filled; counts down because the string is read backwards.
        p =  sum( Ds )-1
        for z in reversed( order ):
            if z.isdigit( ):
                # Prepend, since digits arrive least-significant first.
                if group: val_gi = z + val_gi
                else:     val    = z + val
            elif z == '>':
                if group: off =  int( val_gi ); val_gi = ''
                else:     off =  int( val );    val    = ''
            elif z == '<':
                if group: off = -int( val_gi ); val_gi = ''
                else:     off = -int( val );    val    = ''
            elif z == ')': group = True
            elif z == '(':
                # End of the group (its first character in forward order). `idx` still refers to
                # the letter handled last, and d[idx] was already decremented, so pdi is that
                # letter's slot. Whatever is left of the group's number is written there.
                # NOTE(review): relies on at least one letter having been seen inside the group.
                vi = int( val )
                if vi>1:
                    pdi = d[idx] + 1
                    if D[idx]%vi != 0:
                        dim_sub = np.prod( np.maximum( 1, size[sum( Ds[0:idx] ):sum( Ds[0:idx+1] )] ))
                        pad_im[idx] += ( vi - D[idx]%vi ) * dim_sub
                    size[pdi+off] = vi
                    D[idx] = int( math.ceil( 1.0*D[idx]/vi ))
                group = False; off=0; val=''; val_gi=''
            elif z == '%':  # Pad dimension by N
                # Acts on position p+1, i.e. the entry handled just before this character.
                pad_ex[p+1] += max( 0, int( val )-1 ) * ( size[perm[p+1]]+pad_ex[p+1] )
                val = ''
            elif z == '*':  # Broadcast dimension by N
                # Acts on position p, i.e. the entry that will be handled next.
                brdcst[p] *= int( val )
                val = ''
            elif z == '|':  # Align data after a dimension to N
                align[p] *= int( val )
                val = ''
            elif z in defOrder:
                idx = defOrder.find( z )
                perm[p] = d[idx]+off
                if off<0:
                    # Negative offset: ask _reorder_granularity_range for the product of this
                    # letter's explicit splits (and of letters grouped with it) and reduce the
                    # remaining extent by it.
                    start_dim = d[idx]+off-sum( Ds[0:idx] ) if val else -1
                    stop_dim  = d[idx]-sum( Ds[0:idx] )
                    gran, pre_group = self._reorder_granularity_range( order, z, start_dim, stop_dim )
                    for i, c in enumerate( defOrder ):
                        if c in pre_group:
                            if D[i] >= pre_group[c]:
                                gran //= pre_group[c]
                            elif D[i] > 1:
                                gran = int( math.ceil( 1.0*gran/D[i] ))
                    D_rem = max( 1, D[idx] // gran )
                else:
                    D_rem = D[idx]
                if val:
                    vi = int( val )
                    if group:
                        # Inside a group the number is shared: this letter takes at most D_rem
                        # and the remaining factor is handed on through `val`.
                        if vi>D_rem: vi_rem = vi//D_rem; vi//=vi_rem; val=str( vi_rem )
                        else: val = '1'
                    else: val = ''
                else:
                    # No explicit size: the entry takes whatever extent is left.
                    vi = D_rem
                if vi>0:
                    if D[idx]%vi != 0:
                        # Remaining extent is not a multiple of vi: record the implicit padding,
                        # scaled by the sub-dimension sizes already assigned to this dimension.
                        dim_sub = np.prod( np.maximum( 1, size[sum( Ds[0:idx] ):sum( Ds[0:idx+1] )] ))
                        pad_im[idx] += ( vi - D[idx]%vi ) * dim_sub
                    size[d[idx]+off] = vi
                    D[idx] = int( math.ceil( 1.0*D[idx]/vi ))
                if not group:
                    off = 0
                d[idx]-=1; p-=1
        if self.print_info:
            self.log_msg.append( '[INFO]: reorder s={:<15} o={:<15} -> pi={:<15} s={:<30} p={:<30} pe={:<30}, b={:<30}, a={:<30}'.format( *list(map( str, ( shape, order, pad_im, size, perm, pad_ex, brdcst, align )))))
        return pad_im, size, perm, pad_ex, brdcst, align


    def reorder_mat( self, mat, order, defOrder=None, inverse=False ):
        pad_im, size, perm, pad_ex, brdcst, align = self._reorder_decode( mat.shape, order, defOrder )
        if not inverse:
            if sum( pad_im ) > 0:
                mat = np.pad( mat, tuple( zip( [0]*len( pad_im ), pad_im )), 'constant' )
            mat = mat.reshape( *size ).transpose( perm )
            if sum( pad_ex ) > 0:
                mat = np.pad( mat, tuple( zip( [0]*len( pad_ex ), pad_ex )), 'constant' )
            if np.prod( brdcst ) > 1:
                for idx, b in enumerate( brdcst ):
                    if b>1:
                        mat = np.repeat( mat, b, axis=idx )
            if np.prod( align ) > 1:
                for idx, a in reversed( tuple( enumerate( align ))):
                    if a>1:
                        mat = mat.reshape( mat.shape[:idx+1] + ( -1, ))
                        pad = a - ( mat.shape[-1] % a )
                        if pad < a:
                            mp = np.zeros(( len( mat.shape ), 2 ), dtype=int )
                            mp[-1, -1] = pad
                            mat = np.pad( mat, mp, 'constant' )
        else:
            assert sum( pad_im )==0, "Reverse of implicit padding not supported"
            assert sum( pad_ex )==0, "Reverse of explicit padding not supported"
            assert np.prod( brdcst )==1, "Reverse of broadcasting not supported"
            assert np.prod( align )==1, "Reverse of alignment not supported"
            perm_inv = [perm.index( p ) for p in range( len( perm ))]
            size_inv = [size[p] for p in perm]
            mat = mat.reshape( *size_inv )
            mat = mat.transpose( perm_inv )

        return mat.reshape( -1 )


    def get_dim_steps( self, shape, order, defOrder=None, bits=8, ebs=None, shift_bits_avg=0, sparse_ratio=1 ):
        """Compute a step value per ``defOrder`` dimension for data laid out as ``order``.

        The order is decoded with ``_reorder_decode``. For every dimension the
        method picks one position of the reordered layout, walking the
        sub-dimension sizes from the last one backwards, and then accumulates
        the extents of the layout from the innermost position outwards. The
        value accumulated before a picked position becomes that dimension's step.

        Args:
            shape: Extents of the data, one per letter of ``defOrder``.
            order: Order string.
            defOrder: Dimension letters; falls back to ``self.defOrder``.
            bits: Bits per sample; scales the innermost extent by ``bits / 8``.
                For ``bits == 4`` with an innermost extent of 2 the
                second-innermost extent is scaled instead.
            ebs: Block size used for the extra per-block term added to the
                scaled extent (presumably the exponent block size, going by the
                assertion text - confirm with callers).
            shift_bits_avg: Extra bits per sample, only applied when ``ebs`` is set.
            sparse_ratio: Factor applied to ``bits``; a value other than 1 also
                changes the per-block term.

        Returns:
            List of ``len( shape ) + 1`` step values: one per dimension in
            ``defOrder`` order, followed by the total accumulated size.

        Raises:
            AssertionError: If ``ebs`` or a non-unit ``sparse_ratio`` is used
                and the scaled extent is smaller than 8.
        """
        if not defOrder: defOrder = self.defOrder
        pad_im, size, perm, pad_ex, brdcst, align = self._reorder_decode( shape, order, defOrder )
        sz = 1
        d = len( shape )-1      # dimension currently being resolved (last one first)
        sp = len( perm )        # number of sub-dimension slots not yet attributed to a resolved dimension
        dim = [0]*len( shape )  # chosen layout position per dimension
        # Letters of the order string only; used for the coupling log message.
        key = ''.join( [ x for x in order if x.upper( ) >= 'A' and x.upper( ) <= 'Z' ] )
        for i, s in enumerate( reversed( size )):
            sz *= s
            p = len( perm )-1-i
            if ( sz >= shape[d]+pad_im[d] and self.get_first_significant_dim ) or order.count( defOrder[ d ] ) == sp - p:
                # current dimension contains all elements
                dim[d] = pi0 = perm.index( p )
                if s < sz:
                    # The dimension spans several sub-dimensions: follow the chain of slots that
                    # directly succeed each other in the layout (or are separated only by size-1 entries).
                    pi1 = pi0
                    for pt in range( p+1, sp ):
                        pi2 = perm.index( pt )
                        if pi2 == pi1+1 or ( pi2 > pi1 and np.all( np.take( size, perm[pi1+1:pi2] ) == 1 )):
                            # Found X...X coupling
                            dim[d] = pi2
                            # Log while pi1 still marks the start of the coupled span, so the
                            # message shows every letter from pi1 to pi2.
                            self.log_msg.append( 'INFO: Found {} coupling ( order={}, size={}, perm={}, p={} )'.format( key[pi1:pi2+1], order, size, perm, p ))
                            pi1 = pi2
                        else:
                            break
                sp -= order.count( defOrder[ d ] )
                sz = 1
                d -= 1
        # Extents in layout order, including explicit padding and broadcasting.
        size_inv = ( np.array( size )[perm] + pad_ex ) * brdcst
        idx = -2 if bits == 4 and size_inv[-1] == 2 else -1
        if ebs or sparse_ratio != 1: assert size_inv[idx] >= 8, "Sparse/exponent block is too small. Data ( order ) unexpected or update to script is required"
        # bytes per sample -> bits per sample per field accumulated and translated to bytes per sample and then applied to the data format shape.
        size_inv[idx] = int( ( size_inv[idx] * ( sparse_ratio * bits + ( shift_bits_avg / sparse_ratio if ebs else 0 ) ) ) / 8 ) + ( ( ( ( (size_inv[idx]//ebs)*(ebs//4)*3 ) // 8 ) if sparse_ratio != 1.0 else size_inv[idx]//ebs ) if ebs else 0 )
        step = [0]*( len( shape ) + 1 )
        cur = 1
        for i_rev, ( s, al ) in enumerate( reversed( tuple( zip( size_inv, align )))):
            i = len( perm )-1-i_rev
            if al > 1:
                # Round the running size up to the requested alignment.
                cur = (( cur + al-1 )//al )*al
            if i in dim:
                step[dim.index( i )] = cur
            cur *= s
        step[-1] = cur
        return step
