import struct
import os
from utils.utils_common import log
from utils.build_utils import is_qdq_fp16


def float32_to_bfloat16_bytes(f32):
    """Encode a Python float as bfloat16, returned as 2 little-endian bytes.

    The value is first narrowed to IEEE 754 single precision. bfloat16 is the
    upper half of that 32-bit pattern (sign bit, 8-bit exponent, top 7
    mantissa bits). The lower 16 mantissa bits are simply discarded, so the
    conversion truncates toward zero rather than rounding.
    """
    # Packed little-endian, a float32 is laid out as bits [0-7, 8-15, 16-23,
    # 24-31]. The last two bytes therefore hold bits 16-31, the bfloat16
    # half, already in little-endian order.
    single = struct.pack('<f', f32)
    return single[2:]


def float32_to_float16_bits(f):
    """Return the IEEE 754 binary16 (fp16) encoding of *f* as 2 little-endian bytes.

    *f* is first narrowed to single precision via ``struct``. The fp16 value is
    then derived from the float32 bit pattern:

    * the sign bit is copied, so zeros, infinities and NaNs keep their sign;
    * the exponent is re-biased from 127 to 15;
    * the mantissa is truncated from 23 to 10 bits (round toward zero, not
      round-to-nearest-even);
    * float32 zeros/subnormals, and results whose re-biased exponent is <= 0
      (magnitude below 2**-14), are flushed to signed zero. fp16 subnormals
      are never produced;
    * results whose re-biased exponent is >= 31 become signed infinity;
    * infinity stays infinity, and any NaN becomes a quiet NaN (mantissa 0x200).

    Raises:
        OverflowError: if *f* is too large to be packed as a float32
            (raised by ``struct.pack``).
    """
    f32_bits = struct.unpack('>I', struct.pack('>f', f))[0]

    sign = (f32_bits >> 16) & 0x8000    # float32 bit 31 -> fp16 bit 15
    exponent = (f32_bits >> 23) & 0xFF  # float32 bits 23-30
    mantissa = f32_bits & 0x7FFFFF      # float32 bits 0-22

    if exponent == 0:
        # float32 zero or subnormal: far below fp16 range, flush to zero.
        new_exp, new_mant = 0, 0
    elif exponent == 0xFF:
        # Infinity keeps a zero mantissa; any NaN payload collapses to a quiet NaN.
        new_exp = 0x1F
        new_mant = 0x200 if mantissa else 0
    else:
        new_exp = exponent - 127 + 15
        if new_exp >= 0x1F:
            # Too large for fp16: overflow to infinity.
            new_exp, new_mant = 0x1F, 0
        elif new_exp <= 0:
            # Would need an fp16 subnormal: deliberately flushed to zero.
            new_exp, new_mant = 0, 0
        else:
            # Normal range: keep the top 10 of the 23 mantissa bits.
            new_mant = mantissa >> 13

    return struct.pack('<H', sign | (new_exp << 10) | new_mant)

def float32_to_2_bytes(value):
    """Encode *value* as a 2-byte little-endian 16-bit float.

    The format is chosen per call: IEEE fp16 when ``is_qdq_fp16()`` reports
    true, bfloat16 otherwise.
    """
    encode = float32_to_float16_bits if is_qdq_fp16() else float32_to_bfloat16_bytes
    return encode(value)
    

def linear_approx_layer_params(
    CoreInputAddr: int,
    CoreLUTAB_Addr: int,
    CoreLUTCD_Addr: int,
    CoreSpillAddr: int,
    CoreOutputAddr: int,
    QdqParamAddr: int,
    dqBufferAddr: int,
    qBufferAddr: int,
    sign_A: int,
    sign_O: int,
    idx_bias: float,    # packed as float32
    num_iters: int,     # packed as u16
    idx_max: float,     # packed as 16-bit float (fp16 or bfloat16)
    idx_min: float,     # packed as 16-bit float (fp16 or bfloat16)
    idx_mul: float,     # packed as 16-bit float (fp16 or bfloat16)
) -> bytes:
    """Serialize the parameter block for a linear-approximation layer.

    Layout, all little-endian with no padding (52 bytes):

        8 x u32  CoreInputAddr, CoreLUTAB_Addr, CoreLUTCD_Addr, CoreSpillAddr,
                 CoreOutputAddr, QdqParamAddr, dqBufferAddr, qBufferAddr
                 (each with 0xE0000 added)
        2 x u32  sign_A, sign_O
        f32      idx_bias
        u16      num_iters
        3 x 16b  idx_max, idx_min, idx_mul, encoded by ``float32_to_2_bytes``
                 (fp16 or bfloat16 depending on ``is_qdq_fp16()``)

    Raises:
        OverflowError: if an address (after the offset), sign_A, sign_O or
            num_iters is negative or does not fit its unsigned field width.
    """
    # NOTE(review): base offset added to every address; its meaning (e.g. a
    # memory-window base) is not visible here — confirm with the target spec.
    addr_offset = 0xE0000

    # Field order matters: Spill precedes Output in the serialized block.
    addresses = (
        CoreInputAddr,
        CoreLUTAB_Addr,
        CoreLUTCD_Addr,
        CoreSpillAddr,
        CoreOutputAddr,
        QdqParamAddr,
        dqBufferAddr,
        qBufferAddr,
    )
    fields = [
        (addr + addr_offset).to_bytes(length=4, byteorder="little", signed=False)
        for addr in addresses
    ]
    fields += [
        sign_A.to_bytes(length=4, byteorder="little", signed=False),
        sign_O.to_bytes(length=4, byteorder="little", signed=False),
        # Explicit '<' so the float32 is little-endian like every other field,
        # regardless of host byte order.
        struct.pack('<f', idx_bias),
        num_iters.to_bytes(length=2, byteorder="little", signed=False),
        float32_to_2_bytes(idx_max),
        float32_to_2_bytes(idx_min),
        float32_to_2_bytes(idx_mul),
    ]
    params = b"".join(fields)

    # NOTE(review): this enables logging process-wide and never restores the
    # previous value; it looks like a debugging leftover — consider scoping it.
    os.environ["ENABLE_LOG"] = "true"
    log("Number of bytes linear approx layer params:", len(params))
    return params