#
# Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
#
from enum import Enum
from quark.shares.data_type import BaseDataType
from quark.shares.utils.log import ScreenLogger
from .data_type import (
BFP16,
MX4,
MX6,
MX9,
MXFP4E2M1,
MXFP6E2M3,
MXFP6E3M2,
MXFP8E4M3,
MXFP8E5M2,
BFloat16,
Int8,
Int16,
Int32,
MXInt8,
UInt8,
UInt16,
UInt32,
)
logger = ScreenLogger(__name__)
# TODO: Write a separate class for each calibration method.
class CalibMethod(Enum):
    """
    Enumeration of calibration methods used for determining quantization parameters.
    """

    MinMax = 0
    MinMSE = 1
    Percentile = 2
    Entropy = 3
    LayerwisePercentile = 4
    Distribution = 5
class ScaleType(Enum):
    """
    Enumeration of scale types used in quantization.
    """

    Float32 = 0
    PowerOf2 = 1
    Int16 = 2
class QuantGranularity(Enum):
    """
    Enumeration of quantization granularity (per-tensor, per-channel, or per-group).
    """

    Tensor = 0
    Channel = 1
    Group = 2
# TODO: Move QTensorConfig into the quark/shares
class QTensorConfig:
    """
    Configuration for a quantized tensor.

    Plain data holder; the setters exist to keep the established public API
    (callers use them), even though direct attribute assignment would also work.

    :param bool symmetric: Whether to use symmetric quantization.
    :param ScaleType scale_type: Type of scaling to apply.
    :param CalibMethod calibration_method: Method for calibration.
    :param QuantGranularity quant_granularity: Level of quantization granularity.
    :param BaseDataType data_type: Data type of quantization.
    """

    def __init__(
        self,
        symmetric: bool,
        scale_type: ScaleType,
        calibration_method: CalibMethod,
        quant_granularity: QuantGranularity,
        data_type: BaseDataType,
    ) -> None:
        self.symmetric = symmetric
        self.scale_type = scale_type
        self.calibration_method = calibration_method
        self.quant_granularity = quant_granularity
        self.data_type = data_type

    def set_symmetric(self, symmetric: bool) -> None:
        """Set whether symmetric quantization is used."""
        self.symmetric = symmetric

    def set_scale_type(self, scale_type: ScaleType) -> None:
        """Set the scale type."""
        self.scale_type = scale_type

    def set_calibration_method(self, calibration_method: CalibMethod) -> None:
        """Set the calibration method."""
        self.calibration_method = calibration_method

    def set_quant_granularity(self, quant_granularity: QuantGranularity) -> None:
        """Set the quantization granularity."""
        self.quant_granularity = quant_granularity

    def set_data_type(self, data_type: BaseDataType) -> None:
        """Set the data type."""
        self.data_type = data_type
class Int8Spec(QTensorConfig):
    """
    Quantization specification for int8 tensors.

    Defaults: symmetric, Float32 scaling, Percentile calibration, per-tensor granularity.
    """

    def __init__(
        self,
        symmetric: bool = True,
        scale_type: ScaleType = ScaleType.Float32,
        calibration_method: CalibMethod = CalibMethod.Percentile,
        quant_granularity: QuantGranularity = QuantGranularity.Tensor,
        data_type: type[BaseDataType] = Int8,
    ):
        super().__init__(symmetric, scale_type, calibration_method, quant_granularity, data_type)
class UInt8Spec(QTensorConfig):
    """
    Quantization specification for uint8 tensors.

    Defaults: asymmetric (unsigned range), Float32 scaling, Percentile calibration, per-tensor granularity.
    """

    def __init__(
        self,
        symmetric: bool = False,
        scale_type: ScaleType = ScaleType.Float32,
        calibration_method: CalibMethod = CalibMethod.Percentile,
        quant_granularity: QuantGranularity = QuantGranularity.Tensor,
        data_type: type[BaseDataType] = UInt8,
    ):
        super().__init__(symmetric, scale_type, calibration_method, quant_granularity, data_type)
class XInt8Spec(Int8Spec):
    """
    Quantization specification for int8 tensors with power-of-2 scaling.

    Defaults: symmetric, PowerOf2 scaling, MinMSE calibration, per-tensor granularity.
    """

    def __init__(
        self,
        symmetric: bool = True,
        scale_type: ScaleType = ScaleType.PowerOf2,
        calibration_method: CalibMethod = CalibMethod.MinMSE,
        quant_granularity: QuantGranularity = QuantGranularity.Tensor,
        data_type: type[BaseDataType] = Int8,
    ):
        super().__init__(symmetric, scale_type, calibration_method, quant_granularity, data_type)
class Int16Spec(QTensorConfig):
    """
    Quantization specification for int16 tensors.

    Defaults: symmetric, Float32 scaling, Percentile calibration, per-tensor granularity.
    """

    def __init__(
        self,
        symmetric: bool = True,
        scale_type: ScaleType = ScaleType.Float32,
        calibration_method: CalibMethod = CalibMethod.Percentile,
        quant_granularity: QuantGranularity = QuantGranularity.Tensor,
        data_type: type[BaseDataType] = Int16,
    ):
        super().__init__(symmetric, scale_type, calibration_method, quant_granularity, data_type)
class UInt16Spec(QTensorConfig):
    """
    Quantization specification for uint16 tensors.

    Defaults: asymmetric (unsigned range), Float32 scaling, Percentile calibration, per-tensor granularity.
    """

    def __init__(
        self,
        symmetric: bool = False,
        scale_type: ScaleType = ScaleType.Float32,
        calibration_method: CalibMethod = CalibMethod.Percentile,
        quant_granularity: QuantGranularity = QuantGranularity.Tensor,
        data_type: type[BaseDataType] = UInt16,
    ):
        super().__init__(symmetric, scale_type, calibration_method, quant_granularity, data_type)
class Int32Spec(QTensorConfig):
    """
    Quantization specification for int32 tensors.

    Defaults: symmetric, Float32 scaling, Percentile calibration, per-tensor granularity.
    """

    def __init__(
        self,
        symmetric: bool = True,
        scale_type: ScaleType = ScaleType.Float32,
        calibration_method: CalibMethod = CalibMethod.Percentile,
        quant_granularity: QuantGranularity = QuantGranularity.Tensor,
        data_type: type[BaseDataType] = Int32,
    ):
        super().__init__(symmetric, scale_type, calibration_method, quant_granularity, data_type)
class UInt32Spec(QTensorConfig):
    """
    Quantization specification for uint32 tensors.

    Defaults: asymmetric (unsigned range), Float32 scaling, Percentile calibration, per-tensor granularity.
    """

    def __init__(
        self,
        symmetric: bool = False,
        scale_type: ScaleType = ScaleType.Float32,
        calibration_method: CalibMethod = CalibMethod.Percentile,
        quant_granularity: QuantGranularity = QuantGranularity.Tensor,
        data_type: type[BaseDataType] = UInt32,
    ):
        super().__init__(symmetric, scale_type, calibration_method, quant_granularity, data_type)
class BFloat16Spec(QTensorConfig):
    """
    Specification for bfloat16 tensors.

    Defaults: symmetric, Float32 scaling, MinMax calibration, per-tensor granularity.
    """

    def __init__(
        self,
        symmetric: bool = True,
        scale_type: ScaleType = ScaleType.Float32,
        calibration_method: CalibMethod = CalibMethod.MinMax,
        quant_granularity: QuantGranularity = QuantGranularity.Tensor,
        data_type: type[BaseDataType] = BFloat16,
    ):
        super().__init__(symmetric, scale_type, calibration_method, quant_granularity, data_type)
class BFP16Spec(QTensorConfig):
    """
    Specification for Block Floating Point (BFP16) tensors.

    Defaults: symmetric, Float32 scaling, MinMax calibration, per-tensor granularity.
    """

    def __init__(
        self,
        symmetric: bool = True,
        scale_type: ScaleType = ScaleType.Float32,
        calibration_method: CalibMethod = CalibMethod.MinMax,
        quant_granularity: QuantGranularity = QuantGranularity.Tensor,
        data_type: type[BaseDataType] = BFP16,
    ):
        super().__init__(symmetric, scale_type, calibration_method, quant_granularity, data_type)
class MX4Spec(QTensorConfig):
    """
    Specification for MX4 tensors.

    Defaults: symmetric, Float32 scaling, MinMax calibration, per-tensor granularity.
    """

    def __init__(
        self,
        symmetric: bool = True,
        scale_type: ScaleType = ScaleType.Float32,
        calibration_method: CalibMethod = CalibMethod.MinMax,
        quant_granularity: QuantGranularity = QuantGranularity.Tensor,
        data_type: type[BaseDataType] = MX4,
    ):
        super().__init__(symmetric, scale_type, calibration_method, quant_granularity, data_type)
class MX6Spec(QTensorConfig):
    """
    Specification for MX6 tensors.

    Defaults: symmetric, Float32 scaling, MinMax calibration, per-tensor granularity.
    """

    def __init__(
        self,
        symmetric: bool = True,
        scale_type: ScaleType = ScaleType.Float32,
        calibration_method: CalibMethod = CalibMethod.MinMax,
        quant_granularity: QuantGranularity = QuantGranularity.Tensor,
        data_type: type[BaseDataType] = MX6,
    ):
        super().__init__(symmetric, scale_type, calibration_method, quant_granularity, data_type)
class MX9Spec(QTensorConfig):
    """
    Specification for MX9 tensors.

    Defaults: symmetric, Float32 scaling, MinMax calibration, per-tensor granularity.
    """

    def __init__(
        self,
        symmetric: bool = True,
        scale_type: ScaleType = ScaleType.Float32,
        calibration_method: CalibMethod = CalibMethod.MinMax,
        quant_granularity: QuantGranularity = QuantGranularity.Tensor,
        data_type: type[BaseDataType] = MX9,
    ):
        super().__init__(symmetric, scale_type, calibration_method, quant_granularity, data_type)
class MXFP4E2M1Spec(QTensorConfig):
    """
    Specification for MXFP4E2M1 tensors.

    Defaults: symmetric, Float32 scaling, MinMax calibration, per-tensor granularity.
    """

    def __init__(
        self,
        symmetric: bool = True,
        scale_type: ScaleType = ScaleType.Float32,
        calibration_method: CalibMethod = CalibMethod.MinMax,
        quant_granularity: QuantGranularity = QuantGranularity.Tensor,
        data_type: type[BaseDataType] = MXFP4E2M1,
    ):
        super().__init__(symmetric, scale_type, calibration_method, quant_granularity, data_type)
class MXFP6E3M2Spec(QTensorConfig):
    """
    Specification for MXFP6E3M2 tensors.

    Defaults: symmetric, Float32 scaling, MinMax calibration, per-tensor granularity.
    """

    def __init__(
        self,
        symmetric: bool = True,
        scale_type: ScaleType = ScaleType.Float32,
        calibration_method: CalibMethod = CalibMethod.MinMax,
        quant_granularity: QuantGranularity = QuantGranularity.Tensor,
        data_type: type[BaseDataType] = MXFP6E3M2,
    ):
        super().__init__(symmetric, scale_type, calibration_method, quant_granularity, data_type)
class MXFP6E2M3Spec(QTensorConfig):
    """
    Specification for MXFP6E2M3 tensors.

    Defaults: symmetric, Float32 scaling, MinMax calibration, per-tensor granularity.
    """

    def __init__(
        self,
        symmetric: bool = True,
        scale_type: ScaleType = ScaleType.Float32,
        calibration_method: CalibMethod = CalibMethod.MinMax,
        quant_granularity: QuantGranularity = QuantGranularity.Tensor,
        data_type: type[BaseDataType] = MXFP6E2M3,
    ):
        super().__init__(symmetric, scale_type, calibration_method, quant_granularity, data_type)
class MXFP8E5M2Spec(QTensorConfig):
    """
    Specification for MXFP8E5M2 tensors.

    Defaults: symmetric, Float32 scaling, MinMax calibration, per-tensor granularity.
    """

    def __init__(
        self,
        symmetric: bool = True,
        scale_type: ScaleType = ScaleType.Float32,
        calibration_method: CalibMethod = CalibMethod.MinMax,
        quant_granularity: QuantGranularity = QuantGranularity.Tensor,
        data_type: type[BaseDataType] = MXFP8E5M2,
    ):
        super().__init__(symmetric, scale_type, calibration_method, quant_granularity, data_type)
class MXFP8E4M3Spec(QTensorConfig):
    """
    Specification for MXFP8E4M3 tensors.

    Defaults: symmetric, Float32 scaling, MinMax calibration, per-tensor granularity.
    """

    def __init__(
        self,
        symmetric: bool = True,
        scale_type: ScaleType = ScaleType.Float32,
        calibration_method: CalibMethod = CalibMethod.MinMax,
        quant_granularity: QuantGranularity = QuantGranularity.Tensor,
        data_type: type[BaseDataType] = MXFP8E4M3,
    ):
        super().__init__(symmetric, scale_type, calibration_method, quant_granularity, data_type)
class MXInt8Spec(QTensorConfig):
    """
    Specification for MXInt8 tensors.

    Defaults: symmetric, Float32 scaling, MinMax calibration, per-tensor granularity.
    """

    def __init__(
        self,
        symmetric: bool = True,
        scale_type: ScaleType = ScaleType.Float32,
        calibration_method: CalibMethod = CalibMethod.MinMax,
        quant_granularity: QuantGranularity = QuantGranularity.Tensor,
        data_type: type[BaseDataType] = MXInt8,
    ):
        super().__init__(symmetric, scale_type, calibration_method, quant_granularity, data_type)
# TODO: Move QLayerConfig into the quark/shares
class QLayerConfig:
    """
    Layer-level quantization configuration.

    Any field left as ``None`` means that tensor kind is not quantized for the layer.

    :param QTensorConfig input_tensors: Quantization spec for input_tensors.
    :param QTensorConfig activation: Quantization spec for activations.
    :param QTensorConfig weight: Quantization spec for weights.
    :param QTensorConfig bias: Quantization spec for bias.
    :param QTensorConfig output_tensors: Quantization spec for output_tensors.
    """

    def __init__(
        self,
        input_tensors: QTensorConfig | None = None,
        activation: QTensorConfig | None = None,
        weight: QTensorConfig | None = None,
        bias: QTensorConfig | None = None,
        output_tensors: QTensorConfig | None = None,
    ):
        self.input_tensors = input_tensors
        self.activation = activation
        self.weight = weight
        self.bias = bias
        self.output_tensors = output_tensors