import numpy as np
import math
import fnmatch
from typing import Tuple, Dict, List
from dataflow_common import ceildiv
import re
import unittest

# Size in bytes of one element, keyed by datatype name.
datatype_bytes = {
    # Integer types: signed and unsigned variants share the same width.
    'int8': 1, 'uint8': 1,
    'int16': 2, 'uint16': 2,
    'int32': 4, 'uint32': 4,
    'int64': 8, 'uint64': 8,
    # Floating-point types.
    'fp32': 4,
    'bfloat16': 2,
    # Fractional size (9 bits per element); presumably a block format with a
    # shared exponent -- TODO confirm.
    'bfp16': 1.125,
}

def sigmoid(x):
    """
    Power-of-two step function of x, clamped to the range [1/16, 1].

    x -- scalar or numpy array
    return -- 2**floor(log2(x / 512)), limited to at most 1 and at least 1/16

    Note: despite the name this is not a logistic sigmoid; it only produces
    the values 1/16, 1/8, 1/4, 1/2 and 1.
    """
    exponent = np.floor(np.log2(x / 512))
    # Exponents above 0 are capped so the result never exceeds 1.
    capped_power = 2 ** np.minimum(exponent, 0)
    return np.maximum(1 / 16, capped_power)

def factors(num, gran):
    """
    List the divisors of num that are multiples of gran.

    num -- number to factor (converted with int())
    gran -- granularity; only divisors d with d % gran == 0 are kept
    return -- sorted list of matching divisors
    """
    num = int(num)
    found = []
    # Only scan up to sqrt(num): every divisor below the root has a partner
    # above it, so both members of the pair are checked in the same step.
    for small in range(1, math.isqrt(num) + 1):
        large, remainder = divmod(num, small)
        if remainder:
            continue
        if small % gran == 0:
            found.append(small)
        # For a perfect square the pair collapses; do not add the root twice.
        if large != small and large % gran == 0:
            found.append(large)
    found.sort()
    return found

def get_bank_name(subbank_name: str) -> str:
    """
    Strip any subbank number or begin/end specification from a bank name.

    subbank_name -- "BANK<id>[.<subbank id>][.{BEGIN|END}]"
                    as read from placement metadata
    return -- "BANK<id>", i.e. everything before the first "."

    Note: buffer_allocator.py BankNameMapper also parses the same format
    """
    bank_name, _, _ = subbank_name.partition(".")
    return bank_name

def compute_inverted_placement(placement_constraints: dict[str, dict[str, int | str]]) -> dict[str, str]:
    """
    Sum up, per real bank, the placement terms of all buffers as one formula string.

    placement_constraints -- placement constraints metadata, placement[buffer_name]["BANK<id>..."] = number | formula
    return -- inverted placement formulas per real bank, inverted_placement["BANK<id>"] = formula
              Each formula starts with '0 ' and gets ' + <term>' appended for
              every entry that maps to that bank.
    """
    inverted_placement: dict[str, str] = {}
    for bank_terms in placement_constraints.values():
        for subbank_name, term in bank_terms.items():
            bank = get_bank_name(subbank_name)
            formula_so_far = inverted_placement.get(bank, '0 ')
            inverted_placement[bank] = formula_so_far + ' + ' + str(term)
    return inverted_placement

def count_leading_bits(arr):
    """
    Count the leading sign-equal bits of int32 values: leading zeros for
    positive numbers and leading ones for negative numbers.

    Parameters:
    arr (numpy.ndarray): Input numbers (scalar or array); converted to int32.

    Returns:
    numpy.ndarray: Count of leading bits (zeros or ones) for each number;
    0 and -1 both yield 32. A scalar input yields a plain Python number.
    """
    scalar_input = np.isscalar(arr)
    values = np.asarray(arr, dtype=np.int32)

    # log2 is evaluated for every element, including those where it is
    # undefined (zero / wrong sign); those lanes are discarded by the masks
    # below, so the resulting floating-point warnings are silenced.
    with np.errstate(divide='ignore', invalid='ignore'):
        # Position of the highest set bit of the value itself ...
        top_bit = np.floor(np.log2(values)).astype(np.int32)
        # ... and of its bitwise NOT: in two's complement the leading ones of
        # a negative number are the leading zeros of its complement.
        top_bit_of_complement = np.floor(np.log2(~values)).astype(np.int32)
        zeros_if_positive = np.where(values > 0, 31 - top_bit, 0)
        ones_if_negative = np.where(values < 0, 31 - top_bit_of_complement, 0)

    # 0 is all zeros and -1 is all ones: every one of the 32 bits counts.
    all_bits_equal = (values == 0) | (values == -1)
    counts = np.where(all_bits_equal, 32, zeros_if_positive + ones_if_negative)

    if scalar_input:
        return counts.item()
    return counts

def process_overheads(args):
    """
    Evaluate the overhead formulas in args["overheads"] and return cycle counts.

    args -- dict with an "overheads" entry, overheads[name] = one of
              * list of Python statements, when 'vars' is part of name; the
                statements are executed in order and define variables
              * dict mapping entry names to either such a statement list
                ('vars' part of the entry name) or a loop description: dict
                of expression strings keyed by 'cycles', 'call_count',
                'kernel_body_OH', 'outer_loop', 'outer_loop_OH',
                'inner_loop', 'cycles_per_inner_loop' (missing keys are '0')
    return -- (overhead_cycles, var_dict, overhead_details)
              overhead_cycles[name] = sum over the loop entries of name whose
                  entry name does not contain 'loop_skip'
              var_dict = all variables defined by the executed statements
              overhead_details[name][entry] = evaluated values per key, with
                  'cycles' replaced by the computed cycle count
              overhead_details[name]['total_cycles'] = overhead_cycles[name]

    Statements see the helper names 'clb', 'ceildiv', 'np', 're', every
    variable defined so far and 'args'. Expressions see the same names plus
    this module's globals.

    NOTE: the statements and expressions are run with exec()/eval(); the
    overhead description must come from a trusted source.
    """
    def _calc_cycles_np(overhead: dict[str, np.ndarray]) -> np.ndarray:
        # cycles = calls * (outer * (inner * cycles_per_inner + outer_OH) + body_OH)
        per_outer_iteration = (
            (overhead['inner_loop'] * overhead['cycles_per_inner_loop'])
            + overhead['outer_loop_OH']
        )
        per_call = (overhead['outer_loop'] * per_outer_iteration) + overhead['kernel_body_OH']
        return overhead['call_count'] * per_call

    overhead_cycles = {}
    overhead_details = {}
    var_dict = {}
    keys_to_process = [
            'cycles',
            'call_count',
            'kernel_body_OH',
            'outer_loop',
            'outer_loop_OH',
            'inner_loop',
            'cycles_per_inner_loop',
            ]
    # Helper functions/modules additionally made available to the executed
    # statements and the evaluated expressions.
    add_exec_scope = {
            'clb': count_leading_bits,
            'ceildiv': ceildiv,
            'np': np,
            're': re,
            }

    def _namespace() -> dict:
        # Explicit name table for exec()/eval(). The variables used to be
        # injected with locals().update(), which only worked by accident:
        # writes to locals() are not guaranteed to be visible to later code,
        # and since Python 3.13 every locals() call inside a function returns
        # an independent snapshot (as it already did under some debuggers).
        # 'args' comes last so a user variable cannot shadow it.
        return {**add_exec_scope, **var_dict, 'args': args}

    def _run_statements(statements) -> None:
        # Assignments made by a statement land in var_dict (the exec locals);
        # a fresh global table per statement lets nested scopes (lambdas,
        # comprehensions) see the variables defined by earlier statements.
        for statement in statements:
            exec(statement, _namespace(), var_dict)

    for k1, v1 in args["overheads"].items():
        overhead_cycles_list = []
        overhead_details[k1] = {}
        if 'vars' in k1:
            # Top-level variable definitions, executed before any loop info.
            _run_statements(v1)
        else:
            for k2, v2 in v1.items():
                if 'vars' in k2:
                    _run_statements(v2)
                    continue
                overheads = {}
                for key in keys_to_process:
                    expr = v2.get(key, '0')
                    # Module globals stay visible; the explicit table adds the
                    # helpers, the variables and 'args' on top.
                    overheads[key] = eval(expr, globals(), _namespace())
                # Any 'cycles' expression is only a placeholder; the value is
                # always recomputed from the loop parameters.
                overheads['cycles'] = _calc_cycles_np(overheads)
                if 'loop_skip' not in k2:
                    overhead_cycles_list.append(overheads['cycles'])
                overhead_details[k1][k2] = overheads
        # Element-wise sum over the entries; an empty list sums to 0.0.
        overhead_cycles[k1] = np.sum(overhead_cycles_list, axis=0)
        overhead_details[k1]['total_cycles'] = overhead_cycles[k1]

    return overhead_cycles, var_dict, overhead_details

def create_grid(start, end, steps, splits, divisibility):
    """
    Build the list of candidate grid points for one dimension.

    start -- first candidate
    end -- full extent; candidates go up to ceil(end / splits) inclusive
    steps -- distance between consecutive candidates
    splits -- number of splits the extent is divided by
    divisibility -- if true, keep only candidates that divide end exactly
    return -- list of candidates in increasing order
    """
    upper_bound = math.ceil(end / splits)
    candidates = range(start, upper_bound + 1, steps)
    if divisibility:
        return [point for point in candidates if end % point == 0]
    return list(candidates)

def filter_grid(grid: List[int], dim: int, split: int, require_divisible: bool = False) -> List[int]:
    """
    Drop grid points that are larger than needed to cover dim.

    grid -- candidate points
    dim -- extent to cover
    split -- number of splits; a point g covers dim when g * split >= dim
    require_divisible -- if true, additionally keep only points where
                         g * split divides dim or already covers it
    return -- points not larger than the smallest covering point (the grid
              itself when no point covers dim), optionally filtered further
    """
    covering = [point for point in grid if point * split >= dim]
    if covering:
        # Anything beyond the smallest covering point is redundant.
        smallest_covering = min(covering)
        kept = [point for point in grid if point <= smallest_covering]
    else:
        kept = grid

    if require_divisible:
        kept = [
            point for point in kept
            if (dim % (point * split)) == 0 or dim <= (point * split)
        ]
    return kept

def check_reuse_chain_validity(reuse_ratio, num_consumers, max_chain_length, max_lock_value):
    """
    Check whether the reuse ratio can be realized with some chain length.

    reuse_ratio -- scalar or numpy array of reuse ratios
    num_consumers -- multiplier applied to the per-chain-element ratio
    max_chain_length -- chain lengths 1..max_chain_length are tried
    max_lock_value -- upper limit for (reuse_ratio // length) * num_consumers
    return -- true (element-wise for arrays) if at least one chain length
              divides reuse_ratio and stays within max_lock_value
    """
    valid_per_length = [
        ((reuse_ratio % chain_length) == 0)
        & ((reuse_ratio // chain_length) * num_consumers <= max_lock_value)
        for chain_length in range(1, max_chain_length + 1)
    ]
    # Axis 0 runs over the chain lengths; any hit makes the ratio valid.
    return np.array(valid_per_length).any(axis=0)

def check_iteration_chain_length(wrap: int, max_chain_length=4, MAX_ITER_WRAP=64):
    """
    Check whether an iteration wrap can be split over a chain of iterations.

    wrap -- wrap count (scalar or numpy array)
    max_chain_length -- chain lengths 1..max_chain_length are tried
    MAX_ITER_WRAP -- upper limit for wrap // length
    return -- true (element-wise for arrays) if at least one chain length
              divides wrap and leaves wrap // length within MAX_ITER_WRAP
    """
    valid_per_length = [
        ((wrap % chain_length) == 0) & ((wrap // chain_length) <= MAX_ITER_WRAP)
        for chain_length in range(1, max_chain_length + 1)
    ]
    # Axis 0 runs over the chain lengths; any hit makes the wrap valid.
    return np.array(valid_per_length).any(axis=0)

def check_special_handle_shapes(in_shape: int, in_pad: list, kernel: object) -> int:
    s_str = f'{in_shape[0]},{in_shape[1]},{in_shape[2]}'
    if hasattr(kernel, 'special_handle_shape'):
        for key, val in kernel.special_handle_shape.items():
            if fnmatch.fnmatch(s_str, key.replace(" ", "")):
                return [in_pad[idx] if x=='*' else int(x) for idx, x in enumerate(val.split(','))]
    return in_pad

class TestUtils(unittest.TestCase):
    """Unit tests for the helper functions of this module."""

    def test_count_leading_bits(self):
        """Leading-bit counts around zero, for both signs."""
        values = np.array([-5, -4, -3, -2, -1, 0, 1, 2, 3, 4])
        expected = np.array([29, 30, 30, 31, 32, 32, 31, 30, 30, 29])
        np.testing.assert_array_equal(count_leading_bits(values), expected)

# Run this module's unit tests when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()
