# pylint: skip-file

import numpy as np
from kernels.python.named_list import *
from kernels.python.data_generator.generate_data_helpers import GenerateDataHelpers

# numpy-backed aliases that deliberately shadow the builtins of the same name.
log2 = np.log2
range = np.arange


def all( *conditions ):
    """Apply ``np.all`` to the tuple of positional arguments."""
    return np.all( conditions )


def any( *conditions ):
    """Apply ``np.any`` to the tuple of positional arguments."""
    return np.any( conditions )


# Keep handles on the built-in max/min; the module redefines both names below.
pmax = max
pmin = min

def max( a, *b ):
    """numpy-aware replacement for the built-in ``max``.

    * ``max(a)``          -> ``np.max(a)`` (reduction over ``a``)
    * ``max(a, b)``       -> ``np.maximum(a, b)`` (elementwise)
    * ``max(a, b, c...)`` -> built-in ``max`` over all arguments
    """
    others = len( b )
    if others == 1:
        return np.maximum( a, b[0] )
    if others == 0:
        return np.max( a )
    return pmax( a, *b )

def min( a, *b ):
    """numpy-aware replacement for the built-in ``min``.

    * ``min(a)``          -> ``np.min(a)`` (reduction over ``a``)
    * ``min(a, b)``       -> ``np.minimum(a, b)`` (elementwise)
    * ``min(a, b, c...)`` -> built-in ``min`` over all arguments
    """
    if not b:
        return np.min( a )
    if len( b ) == 1:
        return np.minimum( a, b[0] )
    return pmin( a, *b )

def ceil( n, d=1 ):
    """Round ``n`` up, optionally to a multiple of ``d``.

    Args:
        n: value to round up.
        d: granularity; for ``d != 1`` the result is ``d * ceil(n / d)``.

    Returns:
        ``int(np.ceil(n))`` for ``d == 1``; otherwise ``d`` times
        ``int(np.ceil(n / d))``.
    """
    if d != 1:
        multiples = int( np.ceil( n / d ))
        return d * multiples
    return int( np.ceil( n ))

def floor( n, d=1 ):
    """Round ``n`` down, optionally to a multiple of ``d``.

    Args:
        n: value to round down.
        d: granularity; for ``d != 1`` the result is ``d * floor(n / d)``.

    Returns:
        The floored value converted with ``np.int64`` (times ``d`` when
        ``d != 1``).
    """
    if d != 1:
        multiples = np.int64( np.floor( n / d ))
        return d * multiples
    return np.int64( np.floor( n ))

def round( n, d=1 ):
    """Round ``n`` with ``np.round`` (half-to-even), optionally to a multiple of ``d``.

    Args:
        n: value to round.
        d: granularity; for ``d != 1`` the result is ``d * round(n / d)``.

    Returns:
        The rounded value converted with ``np.int64`` (times ``d`` when
        ``d != 1``).
    """
    if d == 1:
        rounded = np.round( n )
        return np.int64( rounded )
    rounded = np.round( n / d )
    return d * np.int64( rounded )

def sign( val ):
    """Return the sign information of a type name, or whether a number is negative.

    Args:
        val: either a type-name string, which is decoded with ``type_decoder``,
            or a number / numpy array.

    Returns:
        For strings, the second element of ``type_decoder(val)`` (presumably
        the signedness flag — confirm against ``type_decoder``).  Otherwise
        ``val < 0``, elementwise for arrays.
    """
    # isinstance (rather than an exact type check) also accepts str subclasses
    # such as numpy.str_. An exact check would send those to ``val < 0``,
    # which raises TypeError for strings.
    if isinstance( val, str ):
        return type_decoder( val )[1]
    return val < 0

def sizeof( val ):
    """Return the bit width that ``type_decoder`` reports for ``val``, floor-divided by 8."""
    bit_width = type_decoder( val )[0]
    return bit_width // 8

def min_value_dtype( val ):
    """Return the minimum value used for the type named by ``val``.

    Only a 16-bit type whose ``type_decoder`` sign flag is 0 maps to ``0``.
    Every other type maps to the int16 minimum (-32768).

    NOTE(review): 8- and 32-bit types also get the int16 minimum here —
    confirm that this is intentional.
    """
    bit_width, signedness = type_decoder( val )
    is_unsigned_16 = bit_width == 16 and signedness == 0
    return 0 if is_unsigned_16 else np.iinfo( np.int16 ).min

class DimsHelper:
    """Build per-dimension address increments from wrap counts and step sizes.

    The helper is stateful: ``reset`` carries an offset from one increment to
    the next, including across ``from_steps`` calls, so call order matters.
    """
    def __init__( self, reset=0, bits=32 ):
        # reset: offset added into the next increment that is produced.
        # bits: integer width used in the field type strings built by from_steps.
        self.reset = reset
        self.bits = bits

    def __getitem__( self, key ):
        # Subscript access to attributes: helper["reset"] is helper.reset.
        return getattr( self, key )

    def add_dimension( self, num, step ):
        """Return the increment for one dimension and update ``reset``.

        The increment is ``reset + step``.  Afterwards ``num * step`` is
        subtracted from ``reset``.  NOTE(review): presumably this accounts for
        the ``num`` steps taken in this dimension so that a later increment
        jumps back — confirm.
        """
        inc = self.reset + step
        self.reset -= num * step
        return inc

    def from_steps( self, wraps, steps, next_loop_level=False ):
        """Build increments (and wrap counts) for a 1-D to 5-D traversal.

        Args:
            wraps: wrap count per dimension (scalar or sequence).  At least
                ``len(steps) - 1`` entries are required.
            steps: step size per dimension (scalar or sequence, 1 to 5 entries).
            next_loop_level: when True and a wrap is given for the last
                dimension ``i``, ``reset`` is left as ``-wraps[i] * steps[i]``
                after that dimension.

        Returns:
            The single increment if only one was produced.  Otherwise a
            ``TypedNamedList`` with fields ``num0..`` (``uint{bits}``) followed
            by ``inc0..`` (``int{bits}``).

        Raises:
            AssertionError: on an unsupported number of steps or too few wraps.
        """
        wraps = make_tuple( wraps )
        steps = make_tuple( steps )
        assert len( steps ) in [1,2,3,4,5], "Only 1d to 5d address increments supported"
        assert len( wraps ) >= len( steps ) - 1, "Wrap dimesions passed are not sufficient"

        nums = []
        incs = []
        for i,s in enumerate( steps ):
            if i == len( wraps ):
                # Only reachable for the last step when exactly len(steps) - 1
                # wraps were given: fold the pending reset into it and clear it.
                # next_loop_level has no effect in this case.
                incs.append( self.reset + s )
                self.reset = 0
            else:
                if i < len( steps ) - 1:
                    if i % 3 == 2:
                        # Every third non-last dimension: add_dimension gets
                        # wraps[i] (not wraps[i] - 1), and the reported num is the
                        # combined count of dimensions 0..i.
                        # NOTE(review): presumably mirrors counters grouped in
                        # threes by the target address generator — confirm.
                        num = wraps[i]
                        nums.append( np.prod( wraps[:i+1] ) - 1 )
                    else:
                        num = wraps[i] - 1
                        nums.append( num )
                else:
                    # The last dimension's count is not reported in nums.
                    num = wraps[i] - 1
                incs.append( self.add_dimension( num, s ))

                # Overwrites whatever reset add_dimension just left behind.
                if ( next_loop_level and i == len( steps ) - 1 ) or ( i < len( steps ) - 1 and i % 3 == 2 ):
                    self.reset = -wraps[i] * s

        if len( incs ) == 1:
            return incs[0]
        else:
            # `range` in this module is the alias for np.arange.
            return TypedNamedList([ f"uint{self.bits} num{n}" for n in range( len( nums ))] + [ f"int{self.bits} inc{n}" for n in range( len( incs ))], nums + incs )

def find_closest_shifted_int8( float_val ):
    """Approximate ``float_val`` by ``m / 2**shift`` with an int8 mantissa ``m``.

    The search starts at ``shift = 0`` and doubles the magnitude until it leaves
    the int8 range.  At each step the mantissa is ``round(magnitude * 2**shift)``.
    The candidate with the lowest relative error
    ``|float_val - m / 2**shift| / |float_val|`` is kept; on ties the earliest
    shift wins.

    Args:
        float_val: real number to approximate.

    Returns:
        ``[best_int, best_shift_val]``, with ``best_int`` in the int8 range and
        ``best_shift_val >= 0``.  Special cases:

        * ``0`` returns ``[0, 0]`` (exact).
        * Magnitudes too large even at shift 0 saturate to ``[127, 0]`` for
          positive inputs and ``[-128, 0]`` for negative ones.
        * Negative inputs are searched on their magnitude (allowing up to 128,
          since int8 reaches -128), and the mantissa is then negated.

    Raises:
        ValueError: if ``float_val`` is NaN.
    """
    import math

    INT8_MAX = 127
    INT8_MIN = -128

    if math.isnan( float_val ):
        raise ValueError( "find_closest_shifted_int8: float_val must not be NaN" )
    if float_val == 0:
        # Doubling zero never leaves the int8 range, so the search below would not stop.
        return [ np.int64( 0 ), 0 ]

    # Doubling a negative value never exceeds the bound either, so search on the magnitude.
    negative = float_val < 0
    magnitude = -float_val if negative else float_val
    limit = -INT8_MIN if negative else INT8_MAX

    # Saturated fallback, used when the magnitude is out of range already at shift 0.
    best_int = np.int64( limit )
    best_shift_val = 0
    prev_rel_err = math.inf

    curr_float_val = magnitude
    shift_val = 0
    while curr_float_val <= limit:
        closest_curr_int = round( curr_float_val )
        # ldexp(m, -s) equals m / 2**s exactly, and cannot overflow for very
        # large s (tiny or denormal inputs).
        approx = math.ldexp( float( closest_curr_int ), -shift_val )
        cur_rel_err = abs( magnitude - approx ) / magnitude

        if cur_rel_err < prev_rel_err:
            prev_rel_err = cur_rel_err
            best_shift_val = shift_val
            best_int = closest_curr_int

        curr_float_val *= 2
        shift_val += 1

    return [ -best_int if negative else best_int, best_shift_val ]
    
def srs_shift( inner, bitsA, sgnA, bitsW, sgnW, bitsO, sgnO ):
    """Return a non-negative shift amount for an inner-product result.

    First step: ``bitsA - sgnA + bitsW - sgnW + round(log2(inner)) - bitsO + sgnO``,
    clamped at 0.

    Second step: subtract ``max(sgnA, sgnW) * round(log2(inner) / 2)`` from
    that and clamp at 0 again.  If the ``sgn*`` flags are 0/1, the reduction
    only applies when A or W is signed.
    """
    log_inner = log2( inner )
    total_bits = bitsA - sgnA + bitsW - sgnW + int( round( log_inner ) )
    shift = int( max( 0, total_bits - bitsO + sgnO ) )
    signed_reduction = max( sgnA, sgnW ) * int( round( log_inner / 2 ) )
    return int( max( 0, shift - signed_reduction ) )

class MetaIndex:
    """Resolve a template/parameter name to ``entry.index(value)``.

    ``meta.templates`` / ``meta.parameters`` are searched for the key.  The
    matching entry's ``index()`` is called with the corresponding value taken
    from ``templates`` / ``parameters``.
    """

    def __init__( self, meta, parameters, templates = None ):
        self._meta = meta
        self._parameters = parameters
        self._templates = templates

    def __getitem__( self, key ):
        """Return the index of the chosen value for ``key``.

        Templates are searched first (only when templates were given), then
        parameters.  Each search checks the top level, then one level down
        inside nested ``dict`` / ``NamedList`` entries.

        Raises:
            AssertionError: if ``key`` is not a str or cannot be found.
        """
        assert type( key ) is str, "Index requires str argument, not " + str( key )

        def find( described, chosen ):
            # Direct hit at the top level.
            if key in described:
                return ( described[key], chosen[key] )
            # Otherwise look one level down inside nested groups.
            for group, nested in items( described ):
                if isinstance( nested, ( dict, NamedList )) and key in nested:
                    return ( nested[key], chosen[group][key] )
            return None

        hit = None
        if self._templates is not None:
            hit = find( self._meta.templates, self._templates )
        if hit is None:
            hit = find( self._meta.parameters, self._parameters )
        assert hit is not None, f"'{key}' not found in templates and parameters"
        options, value = hit
        return options.index( value )

    def __call__( self, key ):
        # Calling the object is the same as subscripting it.
        return self.__getitem__( key )


# Placeholder bound to the name "index", which appears in known_helper_ops.
# NOTE(review): presumably rebound (e.g. to a MetaIndex instance) before any
# expression that calls index(...) is evaluated — confirm with the caller.
index = "Not defined"


def array( *x ):
    """Build a numpy array from the arguments, unwrapping NamedLists first.

    NamedList arguments (and NamedLists nested directly inside them) are
    replaced by their values.  With several arguments they are combined into
    one array.  A single argument is passed to ``np.array`` on its own.
    """
    def unwrap( seq ):
        unwrapped = []
        for item in seq:
            if isinstance( item, NamedList ):
                unwrapped.append( unwrap( item._values( )))
            else:
                unwrapped.append( item )
        return unwrapped

    parts = unwrap( x )
    if len( parts ) > 1:
        return np.array( parts )
    return np.array( *parts )


def _unwrap_values( vals ):
    """Return a NamedList/dict as a tuple of its values (via ``values``); pass anything else through."""
    if isinstance( vals, ( NamedList, dict )):
        return tuple( values( vals ))
    return vals


def sum( vals, axis=None ):
    """``np.sum`` over ``vals``; NamedList/dict inputs are reduced over their values."""
    # No debug printing: this runs inside expression evaluation and should be silent.
    return np.sum( _unwrap_values( vals ), axis=axis )


def prod( vals, axis=None ):
    """``np.prod`` over ``vals``; NamedList/dict inputs are reduced over their values."""
    return np.prod( _unwrap_values( vals ), axis=axis )


def cumsum( vals, axis=None ):
    """``np.cumsum`` over ``vals``; NamedList/dict inputs use their values."""
    return np.cumsum( _unwrap_values( vals ), axis=axis )


def cumprod( vals, axis=None ):
    """``np.cumprod`` over ``vals``; NamedList/dict inputs use their values."""
    return np.cumprod( _unwrap_values( vals ), axis=axis )

def random_mat( shape, bits, bfloat=False, seed=237 ):
    """Return ``GenerateDataHelpers(...).random_mat(shape, bits)``.

    The helper is built with default_order "E" and the given bfloat flag
    and seed.
    """
    return GenerateDataHelpers( None, default_order="E", bfloat=bfloat, seed=seed ).random_mat( shape, bits )

def random_gen( low=-128, high=127, size=1, dtype=np.int8 ):
    """Draw ``size`` random integers from numpy's global generator.

    ``high`` is exclusive (``np.random.randint`` semantics), so the defaults
    produce values in [-128, 126].  NOTE(review): if the full int8 range was
    intended, the default ``high`` would need to be 128 — confirm.
    """
    samples = np.random.randint( low=low, high=high, size=size, dtype=dtype )
    return samples

# Whitelists of identifier names, grouped by category.
# NOTE(review): presumably consulted by the expression evaluator — confirm.
known_math_ops = (
    "min", "max", "ceil", "floor", "round", "int", "float",
    "range", "log2", "prod", "sum", "cumprod", "cumsum", "array",
)
known_logic_ops = (
    "for", "in", "if", "else", "and", "or", "not", "all", "any",
)
known_helper_ops = (
    "sign", "sizeof", "min_value_dtype", "DimsHelper", "from_steps",
    "random_mat", "index", "find_closest_shifted_int8", "srs_shift",
    "random_gen",
)
known_ops = known_math_ops + known_logic_ops + known_helper_ops

known_types = (
    "int8", "int16", "int32",
    "uint8", "uint16", "uint32",
    "float8", "float16", "float32", "bfloat16",
)

def make_tuple( val ):
    """Normalize ``val`` to a tuple.

    Sized inputs (lists, tuples, arrays with ndim >= 1, ...) are converted with
    ``tuple()``.  Anything else, including 0-d numpy arrays, is wrapped as a
    one-element tuple.
    """
    # A 0-d ndarray exposes __len__ but cannot be iterated, so tuple() on it
    # raises TypeError; treat it like a scalar instead.
    if isinstance( val, np.ndarray ) and val.ndim == 0:
        return ( val, )
    if hasattr( val, "__len__" ):
        return tuple( val )
    return ( val, )