import numpy as np
import ml_dtypes

def dequantize(input_int, zero_point, scale):
    """
    Map low-precision integer codes back to real values stored as bfloat16.

    Evaluates ``scale * (input_int - zero_point)``, with the integer input
    widened to int32 first, and casts the product to bfloat16.

    Args:
        input_int (np.ndarray or int): Integer code(s) to dequantize.
        zero_point (int): Integer code that represents real zero.
        scale (ml_dtypes.bfloat16 or float): Real-valued step between codes.

    Returns:
        np.ndarray or ml_dtypes.bfloat16: Dequantized value(s) in bfloat16,
        with the same shape as ``input_int``.
    """
    # Widen to int32 before subtracting the zero point so a narrow input
    # type (e.g. uint8) cannot wrap around.
    shifted = np.array(input_int, dtype=np.int32) - zero_point
    return (scale * shifted).astype(ml_dtypes.bfloat16)


def saturate(value, dtype):
    """
    Round ``value`` to the nearest integer and clamp it into ``dtype``'s range.

    Rounding uses ``np.round`` (ties go to the nearest even integer). Results
    outside ``[np.iinfo(dtype).min, np.iinfo(dtype).max]`` are pinned to the
    nearer bound before the final cast to ``dtype``.
    """
    limits = np.iinfo(dtype)
    rounded = np.round(value)
    clamped = np.clip(rounded, limits.min, limits.max)
    return clamped.astype(dtype)

def quantize(input_bf16, scale, zero_point, dtype=np.uint8):
    """
    Quantize real (bfloat16) values to a low-precision integer type.

    Computes ``input / scale + zero_point`` with the input and scale promoted
    to float32, then rounds, clamps to the range of ``dtype`` and casts via
    ``saturate``.

    Args:
        input_bf16 (np.ndarray or float): Real input value(s), nominally bfloat16.
        scale (ml_dtypes.bfloat16 or float): Quantization step size.
        zero_point (int): Integer code that represents real zero.
        dtype (np.dtype): Target integer type (e.g. np.uint8, np.int8,
            np.uint16, np.int16).

    Returns:
        np.ndarray or int: Quantized value(s) of type ``dtype``.
    """
    # Every bfloat16 value is exactly representable in float32, so this
    # promotion does not perturb the input.
    real = np.array(input_bf16, dtype=np.float32)
    step = np.float32(scale)
    return saturate(real / step + zero_point, dtype)

def add_qdq(input1, input2, scale1, zero_point1, scale2, zero_point2, output_scale, output_zero_point, dtype=np.uint8):
    """
    Reference (unfused) quantized add: dequantize, add, requantize.

    Each integer operand is dequantized to bfloat16 with its own scale and
    zero point. The two bfloat16 results are summed, and the sum is quantized
    with the output scale and zero point.

    Args:
        input1 (np.ndarray or int): First integer operand.
        input2 (np.ndarray or int): Second integer operand.
        scale1 (ml_dtypes.bfloat16 or float): Scale of the first operand.
        zero_point1 (int): Zero point of the first operand.
        scale2 (ml_dtypes.bfloat16 or float): Scale of the second operand.
        zero_point2 (int): Zero point of the second operand.
        output_scale (ml_dtypes.bfloat16 or float): Scale of the result.
        output_zero_point (int): Zero point of the result.
        dtype (np.dtype): Result integer type (e.g. np.uint8, np.int8,
            np.uint16, np.int16).

    Returns:
        np.ndarray or int: Quantized sum of type ``dtype``.
    """
    lhs = dequantize(input1, zero_point1, scale1)
    rhs = dequantize(input2, zero_point2, scale2)
    # The sum is taken of the two bfloat16 dequantized values before requantizing.
    return quantize(lhs + rhs, output_scale, output_zero_point, dtype)

def fused_add_qdq(input1, input2, zero_point1, zero_point2, zero_point_out,
                  scale_in1, scale_in2, scale_out, dtype=np.int8):
    """
    Fused quantized add: output = saturate(alpha*input1 + beta*input2 + gamma).

    Folding the dequantize/add/quantize chain into one affine map gives
        alpha = scale_in1 / scale_out
        beta  = scale_in2 / scale_out
        gamma = zero_point_out - alpha*zero_point1 - beta*zero_point2
    The scales are converted to float32. No intermediate value is rounded to
    bfloat16, so results may differ slightly from ``add_qdq``.

    Note: the default ``dtype`` here is np.int8, whereas ``add_qdq`` defaults
    to np.uint8. Pass ``dtype`` explicitly when comparing the two.

    Args:
        input1, input2: Integer operands (np.ndarray or int).
        zero_point1, zero_point2, zero_point_out: Zero points of the two
            operands and of the result.
        scale_in1, scale_in2, scale_out: Scales (bfloat16 or float) of the
            two operands and of the result.
        dtype: Result integer type (e.g. np.int8, np.uint8, np.int16, np.uint16).

    Returns:
        np.ndarray or int: Result of type ``dtype``.
    """
    out_scale = np.float32(scale_out)
    alpha = np.float32(scale_in1) / out_scale
    beta = np.float32(scale_in2) / out_scale
    gamma = zero_point_out - alpha * zero_point1 - beta * zero_point2
    # Widen the integer operands to int32 before the affine combination.
    a = np.array(input1, dtype=np.int32)
    b = np.array(input2, dtype=np.int32)
    return saturate(alpha * a + beta * b + gamma, dtype)


def test_add_qdq_vs_fused_add(scale1, zero_point1, scale2, zero_point2,
                             output_scale, output_zero_point, n_samples=1000):
    """
    Compare ``add_qdq`` against ``fused_add_qdq`` on random integer inputs.

    Both implementations receive the same random inputs. The element-wise
    absolute difference of their outputs is returned.

    Input dtypes are taken from ``type(zero_point1)`` / ``type(zero_point2)``,
    and the output dtype from ``type(output_zero_point)``. Pass NumPy scalars
    (e.g. ``np.uint8(128)``) to select narrow types. Inputs are drawn
    uniformly from ``[0, min(16 * 1024, np.iinfo(dtype).max + 1))``.

    Despite the ``test_`` prefix this is a helper that requires arguments, not
    a pytest test. ``__test__`` is set to False right after the definition so
    pytest does not collect it and fail on missing fixtures.

    Args:
        scale1 (float): Scale for first input
        zero_point1 (int): Zero point for first input; its type selects the
            first input's dtype
        scale2 (float): Scale for second input
        zero_point2 (int): Zero point for second input; its type selects the
            second input's dtype
        output_scale (float): Scale for output
        output_zero_point (int): Zero point for output; its type selects the
            output dtype
        n_samples (int): Number of random samples to generate

    Returns:
        np.ndarray: int64 absolute differences between the add_qdq and
        fused_add_qdq outputs.
    """
    # A private generator seeded with 42 draws the same samples that
    # np.random.seed(42) followed by np.random.randint would, without
    # resetting the caller's global NumPy random state.
    rng = np.random.RandomState(42)

    def _draw(zero_point):
        # Cap the exclusive upper bound at the dtype's range. Otherwise narrow
        # types such as np.uint8 / np.int8 make randint raise "high is out of
        # bounds".
        in_dtype = type(zero_point)
        high = min(16 * 1024, int(np.iinfo(in_dtype).max) + 1)
        return rng.randint(0, high, size=n_samples, dtype=in_dtype)

    input1 = _draw(zero_point1)
    input2 = _draw(zero_point2)

    # Output dtype
    dtype = type(output_zero_point)

    # Compute outputs using both methods
    add_qdq_result = add_qdq(
        input1, input2,
        scale1, zero_point1,
        scale2, zero_point2,
        output_scale, output_zero_point,
        dtype
    )

    fused_add_result = fused_add_qdq(
        input1, input2,
        zero_point1, zero_point2, output_zero_point,
        scale1, scale2, output_scale,
        dtype
    )

    # Widen before subtracting. In the (possibly unsigned or narrow) output
    # dtype the difference would wrap around, e.g. uint8(3) - uint8(5) == 254.
    return np.abs(add_qdq_result.astype(np.int64) - fused_add_result.astype(np.int64))


test_add_qdq_vs_fused_add.__test__ = False