Source code for quark.onnx.quantization.config.data_type
#
# Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT
#
"""Quark ONNX Quantization Data Type Classes"""
from onnx import TensorProto
from onnxruntime.quantization.quant_utils import QuantType
from quark.onnx.quantization.quant_utils import ExtendedQuantType
from quark.shares.data_type import (
BaseBFloat16,
BaseBFP16,
BaseDataType,
BaseFloat16,
BaseInt4,
BaseInt8,
BaseInt16,
BaseInt32,
BaseMX4,
BaseMX6,
BaseMX9,
BaseMXFP4_E2M1,
BaseMXFP6_E2M3,
BaseMXFP6_E3M2,
BaseMXFP8_E4M3,
BaseMXFP8_E5M2,
BaseMXInt8,
BaseUInt4,
BaseUInt8,
BaseUInt16,
BaseUInt32,
)
[docs]
class DataType(BaseDataType):
    """
    Common base describing a quark onnx quantization data type.

    Only declares (annotates) the two class-level attributes below; no
    values are assigned here.
    """

    # ONNX Runtime quantization type this data type maps to.
    map_onnx_format: ExtendedQuantType | QuantType
    # ONNX TensorProto data type this data type corresponds to.
    onnx_proto_dtype: TensorProto
[docs]
class Int4(BaseInt4):
    """Signed 4-bit integer quark onnx quantization data type."""

    # Corresponding ONNX TensorProto data type.
    # This must be an assignment: a bare annotation
    # (``onnx_proto_dtype: TensorProto.INT4``) binds no class attribute, so
    # reading ``Int4.onnx_proto_dtype`` would not yield TensorProto.INT4.
    onnx_proto_dtype = TensorProto.INT4
    # Mapping to the quantization type used for this data type.
    map_onnx_format = ExtendedQuantType.QInt4
[docs]
class UInt4(BaseUInt4):
    """Quark onnx quantization data type for unsigned 4-bit integers."""

    # Quantization type this data type maps to.
    map_onnx_format = ExtendedQuantType.QUInt4
    # Matching ONNX TensorProto data type.
    onnx_proto_dtype = TensorProto.UINT4
[docs]
class Int8(BaseInt8):
    """Quark onnx quantization data type for signed 8-bit integers."""

    # Quantization type this data type maps to (a plain QuantType member,
    # not an ExtendedQuantType one).
    map_onnx_format = QuantType.QInt8
    # Matching ONNX TensorProto data type.
    onnx_proto_dtype = TensorProto.INT8
[docs]
class UInt8(BaseUInt8):
    """Quark onnx quantization data type for unsigned 8-bit integers."""

    # Quantization type this data type maps to (a plain QuantType member,
    # not an ExtendedQuantType one).
    map_onnx_format = QuantType.QUInt8
    # Matching ONNX TensorProto data type.
    onnx_proto_dtype = TensorProto.UINT8
[docs]
class Int16(BaseInt16):
    """Quark onnx quantization data type for signed 16-bit integers."""

    # Quantization type this data type maps to.
    map_onnx_format = ExtendedQuantType.QInt16
    # Matching ONNX TensorProto data type.
    onnx_proto_dtype = TensorProto.INT16
[docs]
class UInt16(BaseUInt16):
    """Quark onnx quantization data type for unsigned 16-bit integers."""

    # Quantization type this data type maps to.
    map_onnx_format = ExtendedQuantType.QUInt16
    # Matching ONNX TensorProto data type.
    onnx_proto_dtype = TensorProto.UINT16
[docs]
class Int32(BaseInt32):
    """Quark onnx quantization data type for signed 32-bit integers."""

    # Quantization type this data type maps to.
    map_onnx_format = ExtendedQuantType.QInt32
    # Matching ONNX TensorProto data type.
    onnx_proto_dtype = TensorProto.INT32
[docs]
class UInt32(BaseUInt32):
    """Quark onnx quantization data type for unsigned 32-bit integers."""

    # Quantization type this data type maps to.
    map_onnx_format = ExtendedQuantType.QUInt32
    # Matching ONNX TensorProto data type.
    onnx_proto_dtype = TensorProto.UINT32
[docs]
class Float16(BaseFloat16):
    """Quark onnx quantization data type for 16-bit floating point."""

    # Quantization type this data type maps to.
    map_onnx_format = ExtendedQuantType.QFloat16
    # Matching ONNX TensorProto data type.
    onnx_proto_dtype = TensorProto.FLOAT16
[docs]
class BFloat16(BaseBFloat16):
    """Quark onnx quantization data type for 16-bit Brain Floating Point."""

    # Quantization type this data type maps to.
    map_onnx_format = ExtendedQuantType.QBFloat16
    # Matching ONNX TensorProto data type.
    onnx_proto_dtype = TensorProto.BFLOAT16
[docs]
class BFP16(BaseBFP16):
    """Quark onnx quantization data type for Block Floating Point."""

    # Quantization type this data type maps to.
    map_onnx_format = ExtendedQuantType.QBFP
    # TensorProto.UNDEFINED is used as the ONNX tensor data type here.
    onnx_proto_dtype = TensorProto.UNDEFINED
[docs]
class MX4(BaseMX4):
    """Quark onnx quantization data type for MX4."""

    # Quantization type this data type maps to (the same QBFP member that
    # BFP16 uses).
    map_onnx_format = ExtendedQuantType.QBFP
    # TensorProto.UNDEFINED is used as the ONNX tensor data type here.
    onnx_proto_dtype = TensorProto.UNDEFINED
[docs]
class MX6(BaseMX6):
    """Quark onnx quantization data type for MX6."""

    # Quantization type this data type maps to (the same QBFP member that
    # BFP16 uses).
    map_onnx_format = ExtendedQuantType.QBFP
    # TensorProto.UNDEFINED is used as the ONNX tensor data type here.
    onnx_proto_dtype = TensorProto.UNDEFINED
[docs]
class MX9(BaseMX9):
    """Quark onnx quantization data type for MX9."""

    # Quantization type this data type maps to (the same QBFP member that
    # BFP16 uses).
    map_onnx_format = ExtendedQuantType.QBFP
    # TensorProto.UNDEFINED is used as the ONNX tensor data type here.
    onnx_proto_dtype = TensorProto.UNDEFINED
[docs]
class MXFP4E2M1(BaseMXFP4_E2M1):
    """Quark onnx quantization data type for MXFP4E2M1."""

    # Quantization type this data type maps to.
    map_onnx_format = ExtendedQuantType.QMX
    # TensorProto.UNDEFINED is used as the ONNX tensor data type here.
    onnx_proto_dtype = TensorProto.UNDEFINED
[docs]
class MXFP6E3M2(BaseMXFP6_E3M2):
    """Quark onnx quantization data type for MXFP6E3M2."""

    # Quantization type this data type maps to.
    map_onnx_format = ExtendedQuantType.QMX
    # TensorProto.UNDEFINED is used as the ONNX tensor data type here.
    onnx_proto_dtype = TensorProto.UNDEFINED
[docs]
class MXFP6E2M3(BaseMXFP6_E2M3):
    """Quark onnx quantization data type for MXFP6E2M3."""

    # Quantization type this data type maps to.
    map_onnx_format = ExtendedQuantType.QMX
    # TensorProto.UNDEFINED is used as the ONNX tensor data type here.
    onnx_proto_dtype = TensorProto.UNDEFINED
[docs]
class MXFP8E5M2(BaseMXFP8_E5M2):
    """Quark onnx quantization data type for MXFP8E5M2."""

    # Quantization type this data type maps to.
    map_onnx_format = ExtendedQuantType.QMX
    # TensorProto.UNDEFINED is used as the ONNX tensor data type here.
    onnx_proto_dtype = TensorProto.UNDEFINED
[docs]
class MXFP8E4M3(BaseMXFP8_E4M3):
    """Quark onnx quantization data type for MXFP8E4M3."""

    # Quantization type this data type maps to.
    map_onnx_format = ExtendedQuantType.QMX
    # TensorProto.UNDEFINED is used as the ONNX tensor data type here.
    onnx_proto_dtype = TensorProto.UNDEFINED
[docs]
class MXInt8(BaseMXInt8):
    """Quark onnx quantization data type for MXInt8."""

    # Quantization type this data type maps to.
    map_onnx_format = ExtendedQuantType.QMX
    # TensorProto.UNDEFINED is used as the ONNX tensor data type here.
    onnx_proto_dtype = TensorProto.UNDEFINED