# pylint: skip-file

import numpy as np
import re
import importlib
import typing
from kernel_lib.python.named_list import *
from kernel_lib.python.data_generator.data_converter import type_expo_bits_default
from kernel_lib.python.known_functions import *


# Meta description most recently handed to MetaParser.__init__ (None until a
# parser is constructed); evaluate() reads its ``name`` to build the module path
# for "function ..." expressions.
meta_hook: typing.Optional[ typing.Any ] = None

class MetaParser:
    """Resolve a kernel meta description for one concrete parameter set.

    ``meta`` is a NamedList whose entries (``variables``, ``requirements``,
    ``parameters``, ``kernel_interface``, ``kernel_parameter_setup``,
    ``model_parameters`` / ``model_extra_parameters``, ``performance``) hold
    expressions that are evaluated with :func:`evaluate` against the given
    ``parameters`` and ``templates``.

    After :meth:`update` has run, the instance exposes ``kernel_interface``,
    ``time_split``, ``kernel_params``, ``model_params`` and
    ``performance_metrics``.
    """

    def __init__( self,
            meta: NamedList,
            parameters: Optional_NamedList_or_dict = None,
            templates: Optional_NamedList_or_dict = None
    ):
        """Store ``meta`` and run :meth:`update` right away when ``parameters`` is given.

        ``meta`` is also published as the module-level ``meta_hook``, which
        :func:`evaluate` uses to locate ``function ...`` helper modules.
        """
        self._meta = meta
        global meta_hook
        meta_hook = meta
        if parameters is not None:
            self.update( parameters, templates )


    def update( self,
            parameters: NamedList_or_dict,
            templates: Optional_NamedList_or_dict = None
    ):
        """Evaluate the stored meta description for ``parameters`` / ``templates``.

        Sets ``self.kernel_interface`` (one NamedList per buffer with dtype, bit
        width, shape information, data order, sync type and sizes),
        ``self.time_split``, ``self.kernel_params``, ``self.model_params`` and
        ``self.performance_metrics``.  Also rebinds the module-level name
        ``index`` to a fresh ``MetaIndex``.  Prints diagnostics while evaluating.

        Raises:
            ValueError: if an entry of ``meta.requirements`` evaluates falsy.
        """
        meta = self._meta

        def set_index():
            global index
            index = MetaIndex( meta, parameters, templates )

        def do_eval( ex, variables=None, **kwargs ):
            # meta.parameters is passed as evaluate()'s parameter_types table.
            return evaluate( ex, templates, parameters, variables, meta.parameters, **kwargs )

        def parse_function(
                func_str: str,
                parse_only: typing.Optional[ bool ] = False ):
            """Resolve ``"function <file>: <name>(<args>)"`` and call it (or return it with ``parse_only``).

            The callable is looked up in module
            ``kernel_lib.kernels.<meta.name>.<meta.name>_<file>``; ``(<args>`` is
            evaluated against the current ``variables`` of ``update``.
            """
            assert func_str.startswith( "function" ), "string must start with 'function'"
            file_name, func = func_str[8:].split( ":", 1 )
            file_name = f"kernel_lib.kernels.{meta.name}.{meta.name}_{file_name.strip( )}"
            func_name, args = func.split( '(', 1 )
            func_name = func_name.strip( )
            func_mod = importlib.import_module( file_name )
            assert hasattr( func_mod, func_name ), f"expected function {func_name} in file {file_name}.py"
            func = getattr( func_mod, func_name )
            args = do_eval( '(' + args, variables )
            return ( func, args ) if parse_only else func( *args )

        def select_key( d, variables=None, **kwargs ):
            """Evaluate the ``key`` of a ``{"key", "values"}`` selector to a lookup string.

            A boolean result is mapped to "1"/"0" when ``values`` has no
            "True"/"False" entry.
            """
            k = str( do_eval( d["key"], variables, **kwargs ))
            if k in ( "True", "False" ) and k not in d["values"]:
                k = str( int( k == "True" ))
            return k

        def dict_eval( d, variables=None, **kwargs ):
            """Recursively evaluate a meta entry (selector, mapping, sequence or expression)."""
            if isinstance( d, ( NamedList, dict )):
                if set( d ) == { "key", "values" }:
                    return dict_eval( d["values"][select_key( d, variables, **kwargs )], variables, **kwargs )
                ret = NamedList([ ])
                for k,v in items( d ):
                    ret._append( k, dict_eval( v, variables, **kwargs ))
                return ret
            if type( d ) != str:
                if not hasattr( d, "__iter__" ):
                    # plain scalars (numbers, bools, None) are already values
                    return d
                return [ do_eval( v, variables, **kwargs ) for v in d ]
            if d.startswith( "function" ):
                return parse_function( d )
            return do_eval( d, variables, **kwargs )

        def str_eval( ex, variables=None, **kwargs ):
            # Evaluate with string values expanded; a string result is evaluated once more.
            ret = do_eval( ex, variables, evaluate_strings=True, **kwargs )
            print( "str_eval:", ret )
            if type( ret ) == str:
                ret = do_eval( ret, variables, **kwargs )
            print( "str_eval:", ret )
            return ret

        set_index( )
        variables = {}
        if "variables" in meta:
            for k,v in items( meta.variables ):
                # evaluated in order, so later variables can use earlier ones
                variables[k] = dict_eval( v, variables )
        for k,v in items( meta.requirements ):
            if isinstance( v, ( tuple, list )):
                for e in v:
                    if not do_eval( e, variables ):
                        raise ValueError( f"requirement {k} not met for params={parameters} with templates={templates} for subexpression '{e}'" )

            elif not do_eval( v, variables ):
                raise ValueError( f"requirement {k} not met for params={parameters} with templates={templates}" )

        dims = keys( meta.parameters["subvolume"] )
        # Time split per dim, from parameters["time_split"] or variables["time_split"]; missing dims default to 1.
        if "time_split" in meta.parameters:
            dims_ts = keys( meta.parameters["time_split"] )
            time_split = NamedList( dims_ts, [ get( get( parameters, "time_split", get( variables, "time_split", {} )), k, 1 ) for k in dims_ts ])
        else:
            time_split = NamedList( dims, [ get( get( parameters, "time_split", get( variables, "time_split", {} )), k, 1 ) for k in dims ])
        # Tensor extent per dim: explicit parameters["tensor"] wins, else subvolume * time_split.
        tensors = NamedList( dims, [parameters["tensor"][k] if "tensor" in parameters and k in parameters["tensor"]
                                    else parameters["subvolume"][k] * time_split[k] if k in keys(time_split)
                                    else parameters["subvolume"][k] for k in dims ])
        tensor_vars = tensors
        if "variables" in meta:
            tensor_vars |= meta.variables
        self.kernel_interface = NamedList( meta.kernel_interface._keys( ))
        self.time_split = time_split

        for buf,p in items( meta.kernel_interface ):
            # Element type: buffer default, overridable globally or per buffer via parameters["dtype"].
            dtype = p._get( "dtype", "int8" )
            if "dtype" in parameters:
                if type( parameters["dtype"] ) == str:
                    dtype = parameters["dtype"]
                else:
                    dtype = get( parameters["dtype"], buf, dtype )
            if "type" in dtype or "T" in dtype:
                # dtype looks like an expression rather than a plain type name; evaluate it
                dtype = evaluate( dtype, templates, parameters )
            bits, sign = type_decoder( dtype )
            sync_type = p._get( "sync_type", "L1" )
            if type( sync_type ) != str:
                # a collection of options: "stream"-capable buffers take the choice from templates, others the first
                if "stream" in sync_type:
                    assert templates["sync_type"][buf] in sync_type, f"Expected sync_type to be found in buffer interface for buffer {buf}"
                    sync_type = templates["sync_type"][buf]
                else:
                    sync_type = sync_type[0]

            def get_shape( dim_vars, str_mode=False ):
                # Shape evaluated with dim_vars bound: from "shape", else "size" divided by the element byte width.
                if "shape" not in p:
                    if buf[0] == "P" and "size" not in p:
                        return bits // 8
                    return str_eval( p.size, dim_vars ) // ( bits // 8 )
                elif type( p.shape ) == str:
                    return str_eval( p.shape, dim_vars )
                else:
                    return [ str_eval( d, dim_vars ) for d in p.shape ]
            def advance( k ):
                # Subvolume with dim k doubled; the shape difference against sv gives the step along k.
                svv = NamedList( dims, parameters["subvolume"] )
                svv[k] *= 2
                if "variables" in meta:
                    svv |= meta.variables
                return svv

            sv = get_shape( variables, str_mode=True )

            if "time_iter_dims" in p:
                td = NamedList([( k, v ) for k,v in time_split._items( ) if k in p.time_iter_dims ])
                print( td )
                not_td = NamedList([( k, v ) for k,v in time_split._items( ) if k not in p.time_iter_dims ])
                buf_tensor_vars = NamedList([ (k, tensor_vars[k]//v) for k,v in not_td._items( )] + [(k, v) for k,v in tensor_vars._items( ) if k not in not_td._keys( )])
                print( not_td )
            else:
                buf_tensor_vars = tensor_vars
                td = time_split

            if "tensor_shape" in p:
                tv = dict_eval( p.tensor_shape, buf_tensor_vars, evaluate_strings=True )
            else:
                tv = get_shape( buf_tensor_vars )

            if "time_step" in p:
                ts = dict_eval( p.time_step, variables )
                if not isinstance( ts, ( NamedList, dict )):
                    ts = np.array( ts )
                    if ts.ndim == 1:
                        ts = np.diag( ts )
                    ts = dict( zip( dims, ts ))
                ts = [ ts[k] for k in keys( td )]
            else:
                # 1 when the shape references time_split.<k>, otherwise the shape growth from doubling dim k.
                ts = [( 1 if "time_split." + k in get( p, "shape", "" ) else np.subtract( get_shape( advance( k )), sv )) for k in keys( td )]
            if "padding" in p:
                tp = np.array( dict_eval( p.padding, variables ))
                if tp.ndim == 1:
                    tp = np.vstack(( tp, np.matmul( np.subtract( values( td ), 1 ), ts ) + sv - ( tv + tp ))).T
            else:
                # (time_iters - 1) * time_step + subvolume - tensor: extent covered beyond the tensor.
                # Positive parts end up in column 1 of a two-column table whose column 0 is zero
                # (presumably (before, after) padding).
                tp = np.matmul( np.subtract( values( td ), 1 ), np.vstack(ts)) + sv - tv
                if np.any( tp > 0 ):
                    tp = np.pad( np.maximum( 0, tp ).reshape( 1, -1 ), (( 1, 0 ), ( 0, 0 )), "constant" ).T
                else:
                    tp = 0

            shape = {
                "subvolume": sv,
                "tensor": tv,
                "time_iters": td,
                "time_step": [ ts[i] for i,v in enumerate( values( td )) if v > 1 ],
                "padding": tp
            }

            order = p.order[2] if templates and "sync_type" in templates and templates["sync_type"].get( buf ) == "stream" else ( None if "order" not in p else p.order[1] )
            if isinstance( order, ( dict, NamedList )):
                k = select_key( order, variables )
                print( f'order dict: get {order["key"]} as {k} from {order["values"]}' )
                order = order["values"][k]

            self.kernel_interface[buf] = NamedList({
                "dtype": dtype,
                "bits": bits,
                "sign": sign,
                "bfloat": "float" in dtype,
                "expo_bits": type_expo_bits_default[dtype] if "float" in dtype else 8,
                "shape": shape,
                "data_order": order,
                "raw_order": p.order[0] if "order" in p else None,
                "sync_type": sync_type,
                "name": p.name if "name" in p else {"I": "input", "O": "output"}.get( buf[0], buf )
            })

            # specify reuse over multiple iterations for async buffers.
            # Usage:
            #   1. "reuse": "<iters>"
            #      buffer used over this number of iterations
            #   2. "reuse": { "iters": "<iters>", "step": "<step>" }
            #      In addition allow a step between different iterations (linear, default = 0)
            if "reuse" in p:
                if isinstance( p.reuse, ( dict, NamedList )):
                    reuse = {
                        "iters": do_eval( p.reuse["iters"], variables ),
                        "step" : do_eval( p.reuse["step"], variables )
                    }
                else:
                    reuse = {
                        "iters": do_eval( p.reuse, variables ),
                        "step" : 0
                    }
                self.kernel_interface[buf]["reuse"] = reuse

            # Check if there is a range
            if "range" in p:
                min_val = do_eval( p.range[0], variables )
                max_val = do_eval( p.range[1], variables )
                self.kernel_interface[buf]["min"] = min_val
                self.kernel_interface[buf]["max"] = max_val

            if "value" in p:
                if p.value.startswith( "function" ):
                    (func, args) = parse_function( p.value, parse_only=True )
                    self.kernel_interface[buf]._append( "value", { "function": func, "args": args })
                else:
                    self.kernel_interface[buf]._append( "value", do_eval( p.value, variables ))

            if "packing" in p:
                packing = NamedList({
                    "shape": [ do_eval( d, variables ) for d in p.packing.shape ],
                    "refs": {
                        "body": p.packing.function.body,
                        "args": p.packing.function.args,
                        "vars": variables,
                    }
                })
                # Bind this buffer's packing as a default: a plain closure over `packing` would see
                # the value from the LAST buffer that has packing once the loop has finished.
                func = lambda *args, _packing=packing: do_eval(
                    _packing["refs"]["body"],
                    _packing["refs"]["vars"] | dict( map( lambda z: ( z[0], LocalVar( *z )), zip( _packing["refs"]["args"], args ))),
                    use_exec=True )
                packing._append( "function", func )
                self.kernel_interface[buf]._append( "packing", packing )

            size_vars = NamedList(( "shape", "dtype", "packing" ), ( shape["subvolume"], dtype, packing if "packing" in p else None )) | variables
            if "size" in p:
                size = do_eval( p.size, size_vars )
            else:
                # true division keeps sub-byte element types meaningful
                size = prod( shape["subvolume"] ) * ( bits / 8 )
            if "L1_size" in p:
                L1_size = do_eval( p.L1_size, size_vars )
                assert L1_size >= size, "Size to allocate in L1 must be at least match the data size"
            else:
                L1_size = size

            self.kernel_interface[buf]._append( "sv_size", size )
            self.kernel_interface[buf]._append( "L1_size", L1_size )

        for k,v in items( self.kernel_interface ):
            print( f"kernel_interface {k}:", str( v ))

        self.kernel_params = TypedNamedList([ ])
        for k,v in items( meta.kernel_parameter_setup ):
            if k.startswith( "function" ):
                self.kernel_params._extend( parse_function( k + ':' + v ))
            else:
                vls = dict_eval( v, variables )
                self.kernel_params._append( k, vls )
                # keys may hold several words (presumably "<type> <name>"); expose the value by the last one
                variables[ k.split( )[-1]] = vls

        if "model_parameters" in meta:
            if isinstance( meta.model_parameters, ( dict, NamedList )):
                self.model_params = NamedList([ ])
                for k,v in items( meta.model_parameters ):
                    self.model_params._append( k, dict_eval( v, variables ))
            else:
                self.model_params = do_eval( meta.model_parameters, variables )
        else:
            self.model_params = NamedList( self.kernel_params )
            if "model_extra_parameters" in meta:
                for k,v in items( meta.model_extra_parameters ):
                    self.model_params._append( k, do_eval( v, variables ))
        print( str( self.model_params ))


        performance = NamedList([])
        for k,v in items( meta.performance ):
            # unwrap (possibly nested) selectors, normalising boolean keys like dict_eval does
            while isinstance( v, ( dict, NamedList )):
                v = v["values"][select_key( v, variables )]
            performance[k] = do_eval( v, variables )
            variables[k] = performance[k]
        self.performance_metrics = performance





class LocalVar:
    """Placeholder for a name that is local to an evaluated expression.

    ``evaluate`` writes ``str(obj)`` (the bare ``name``) into the resolved
    source and, when ``ref`` is not None, binds ``name`` to ``ref`` in the
    locals mapping handed to ``eval``/``exec``.
    """
    # TODO: there should be a better way than a placeholder class.

    def __init__( self, name, ref=None ):
        self.name, self.ref = name, ref

    def __str__( self ):
        return self.name


# Tokenizer helpers for evaluate(); compiled once instead of on every (often recursive) call.
_IS_FLOAT = re.compile(r"^[-+]?(?:\b[0-9]+(?:\.[0-9]*)?|\.[0-9]+\b)(?:[eE][-+]?[0-9]+\b)?$").match
_IS_HEX = re.compile(r"^0x[0-9a-fA-F]+$").match
_STR_PAT = re.compile( r"('[^']+')" )
_WORD_PAT = re.compile( r"([0-9a-zA-Z_.\[\]]+)" )


def evaluate( expression: typing.Union[ str, int, float, None ],
        templates: NamedList_or_dict,
        parameters: NamedList_or_dict,
        variables: Optional_NamedList_or_dict = None,
        parameter_types: Optional_NamedList_or_dict = None,
        parse_only: bool = False,
        use_exec: bool = False,
        evaluate_strings: bool = False,
):
    """Resolve the names in ``expression`` and evaluate it as Python source.

    Non-string ``expression`` values are returned unchanged.  Otherwise the text
    is tokenized: single-quoted literals, numbers, hex literals, punctuation and
    names in ``known_ops`` are kept verbatim; names in ``known_types`` become
    quoted strings.  Every other dotted/indexed name (``a.b[c]``) is resolved by
    looking up its first segment in ``variables``, then ``templates``, then
    ``parameters``, then inside the dict/NamedList values of ``parameters``.
    Unresolved single-character names become :class:`LocalVar` placeholders,
    and the literal names ``parameters`` / ``templates`` refer to those
    arguments.  Later segments index into the object resolved so far.  A
    segment that is neither an attribute nor a key is itself evaluated.

    Callables, dicts, NamedLists and ndarrays are bound to generated names
    (``obj0``, ``obj1``, ...) in the locals mapping.  Other values are inserted
    as ``str(obj)``, with strings wrapped in double quotes.

    Args:
        expression: expression text, or a value to return as-is.
        templates: template values consulted after ``variables``.
        parameters: parameter values consulted after ``templates``.
        variables: highest-priority name bindings.
        parameter_types: type table matching ``parameters``; a string parameter
            whose declared type contains "int" or "float" is evaluated recursively.
        parse_only: return ``(resolved_source, local_objects)`` instead of evaluating.
        use_exec: run the resolved source with ``exec`` (which returns None)
            instead of ``eval``.
        evaluate_strings: recursively evaluate string values that did not come
            from ``templates``/``parameters``.  A string starting with
            "function" imports and calls the referenced helper.

    Raises:
        AttributeError: a name starts with ``[``, or indexing a resolved object fails.
        KeyError: the first segment of a name cannot be resolved.
        Exception: whatever ``eval``/``exec`` of the resolved source raises
            (re-raised unchanged).
    """
    if type( expression ) != str:
        return expression
    resolved = ''
    local_objects = {}
    local_object_id = 0

    def check_parameter( p, dt ):
        # A string parameter declared with a numeric type is itself an expression.
        if type( p ) == str and ( "int" in dt or "float" in dt ):
            print( f"call again with {p}" )
            p = evaluate( p, templates, parameters, parameter_types=parameter_types )
        return p

    try:
        for sp in _STR_PAT.split( expression ):
            if _STR_PAT.match( sp ):
                # single-quoted literal: keep verbatim
                resolved += sp
                continue
            for part in _WORD_PAT.split( sp ):
                if _IS_FLOAT( part ) or _IS_HEX( part ) or ( not _WORD_PAT.match( part )) or part in known_ops or part in ( '[', ']' ):
                    resolved += part
                elif part in known_types:
                    resolved += "'" + part + "'"
                else:
                    obj = None
                    if part[0] == '[':
                        raise AttributeError( f"list specification not supported for {expression}. Use tuple instead" )
                    from_params = False
                    while len( part ) > 0:
                        # split off the next segment at the first '.', '[' or ']'
                        s = min(( part.find( c ) for c in ".[]" if c in part ), default=len( part ))
                        name = part[:s]
                        part = part[s+1:]
                        if name.isdigit( ):
                            name = int( name )
                        if obj is not None:
                            # segment after the first: unknown names are expressions (e.g. a[i])
                            if type( name ) == str and not hasattr( obj, name ) and ( not isinstance( obj, ( dict, NamedList )) or ( name not in obj )):
                                name = evaluate( name, templates, parameters, parameter_types=parameter_types )
                            try:
                                obj = obj[name]
                            except Exception as err:
                                raise AttributeError( f"Failed to access '{obj}' with '{name}'" ) from err
                        elif variables is not None and name in variables:
                            obj = variables[name]
                        elif templates is not None and name in templates:
                            obj = templates[name]
                            from_params = True
                        elif name in parameters:
                            obj = parameters[name]
                            if parameter_types is not None:
                                obj = check_parameter( obj, parameter_types[name] )
                            from_params = True
                        else:
                            for k,v in items( parameters ):
                                if isinstance( v, ( dict, NamedList )) and name in v:
                                    obj = v[name]
                                    if parameter_types is not None:
                                        obj = check_parameter( obj, parameter_types[k][name] )
                                    from_params = True
                                    break
                            else:
                                if obj is None and len( name + part ) == 1:
                                    obj = LocalVar( name )
                                elif name == "parameters":
                                    obj = parameters
                                elif name == "templates":
                                    obj = templates
                                else:
                                    raise KeyError( f"Error: Failed to find '{name}' (of: {expression})" )
                    if evaluate_strings and type( obj ) == str and not from_params:
                        if obj.startswith( "function" ):
                            file_name, func = obj[8:].split( ":", 1 )
                            assert meta_hook is not None
                            file_name = f"kernel_lib.kernels.{meta_hook.name}.{meta_hook.name}_{file_name.strip( )}"
                            func_name, args = func.split( '(', 1 )
                            func_name = func_name.strip( )
                            func_mod = importlib.import_module( file_name )
                            assert hasattr( func_mod, func_name ), f"expected function {func_name} in file {file_name}.py"
                            func = getattr( func_mod, func_name )
                            args = evaluate( '(' + args, templates, parameters, variables, parameter_types, parse_only, evaluate_strings=True )
                            assert parse_only == False, "Next statement is currently executing, not just parsing"
                            obj = func( *args )
                        else:
                            print( f"call again with {obj}" )
                            obj = evaluate( obj, templates, parameters, variables, parameter_types, parse_only, evaluate_strings=True )
                    if callable( obj ) or isinstance( obj, ( NamedList, dict, np.ndarray )):
                        # objects without a faithful source form are passed via the locals mapping
                        n = f"obj{local_object_id}"
                        local_objects[n] = obj
                        resolved += n
                        local_object_id += 1
                    else:
                        if isinstance( obj, LocalVar ) and obj.ref is not None:
                            local_objects[obj.name] = obj.ref
                        resolved += str( obj ) if type( obj ) != str else f'"{obj}"'
    except Exception:
        print( f"ERROR encountered while parsing expression {expression}" )
        # re-raise the original exception (raising repr(e) would itself fail with TypeError)
        raise
    if parse_only:
        return ( resolved, local_objects )
    print( f"expression '{expression}' resolved to '{resolved}' " )
    # NOTE(review): eval/exec run arbitrary code; expressions must come from trusted meta descriptions.
    try:
        if use_exec:
            ret = exec( resolved, None, local_objects )
        else:
            ret = eval( resolved, None, local_objects )
    except Exception:
        print( "failed with error:" )
        raise
    print( f"yielding: {ret}" )
    return ret

