FP8 Quantization: Principles, Implementation, and Error Analysis
2.1 Floating-Point Representation
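A computer represents a floating-point number with three fields: a sign bit, an exponent, and a mantissa. The sign takes a single bit and encodes whether the number is positive or negative: 0 means positive and 1 means negative. As a consequence, zero has two encodings in floating point, +0 (sign bit 0) and -0 (sign bit 1).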
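The two encodings of zero are easy to check on ordinary FP32 values. The sketch below is only an illustration; `float_bits` and `sign_bit` are hypothetical helper names, not part of PPQ.

```cpp
#include <cstdint>
#include <cstring>

// Returns the raw IEEE 754 bit pattern of a float.
// memcpy is used instead of a union to stay within defined behaviour in C++.
inline uint32_t float_bits(float x) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    return bits;
}

// Returns the sign bit: 0 for positive values and +0, 1 for negative values and -0.
inline uint32_t sign_bit(float x) {
    return float_bits(x) >> 31;
}
```

`+0.0f` and `-0.0f` compare equal as numbers, yet `float_bits` returns `0x00000000` for the first and `0x80000000` for the second: only the sign bit differs.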
Next we turn to the exponent field of the floating-point representation.
Finally, we turn to the mantissa field.
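To make the roles of the three fields concrete, here is a minimal sketch that decodes one FP8 E4M3 byte into a float. It assumes the E4M3 layout described in the paper linked at the end of this article (1 sign bit, 4 exponent bits, 3 mantissa bits, exponent bias 7, no infinities, NaN at S.1111.111). The function name `decode_e4m3` is hypothetical.

```cpp
#include <cmath>
#include <cstdint>
#include <limits>

// Decodes one FP8 E4M3 byte, assuming bias 7 and the layout S.EEEE.MMM.
inline float decode_e4m3(uint8_t byte) {
    const int sign = (byte >> 7) & 0x1;
    const int exp_field = (byte >> 3) & 0xF;
    const int man_field = byte & 0x7;

    // S.1111.111 is the only NaN pattern; this format has no infinity.
    if (exp_field == 0xF && man_field == 0x7) {
        return std::numeric_limits<float>::quiet_NaN();
    }

    float magnitude;
    if (exp_field == 0) {
        // Subnormal: no implicit leading 1, fixed exponent -6.
        magnitude = std::ldexp(man_field / 8.0f, -6);
    } else {
        // Normal: implicit leading 1, exponent = field - bias.
        magnitude = std::ldexp(1.0f + man_field / 8.0f, exp_field - 7);
    }
    return sign ? -magnitude : magnitude;
}
```

Under these assumptions, `0x7E` (0 1111 110) decodes to 448, the largest E4M3 magnitude, and `0x01` (0 0000 001) decodes to 2^-9, the smallest positive value.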
2.2 Floating-Point Quantization of a Number
2.3 Converting FP32 to FP8
Unscaled FP32 = FP32 / scale
FP8 = Convert(Unscaled FP32)
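When the Unscaled FP32 value lies outside the FP8 range, that is, its magnitude is greater than 448, it is simply clamped to the range limit. This is floating-point overflow.
When the Unscaled FP32 value lies inside the FP8 range and its magnitude is larger than the smallest value FP8 can represent, the surplus mantissa bits are dropped and the mantissa is rounded to nearest.
When the magnitude of the Unscaled FP32 value is smaller than the smallest value FP8 can represent, floating-point underflow occurs. The only thing to check is whether the value rounds up to (0 0000 001); if it does not, the result is 0.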
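The three cases above can be sketched with plain floating-point arithmetic instead of bit manipulation. This is a reference sketch, not the PPQ kernel: the names `quantize_e4m3_reference` and `quantize_with_scale` are hypothetical, the rounding mode is the default round-to-nearest-even of `std::nearbyint`, and the sketch uses the fixed 2^-9 spacing in the subnormal range, a detail the description above does not spell out.

```cpp
#include <algorithm>
#include <cmath>

// Rounds an already unscaled FP32 value to the nearest FP8 E4M3 value.
// Hypothetical reference helper for illustration only.
inline float quantize_e4m3_reference(float unscaled) {
    const float kMax = 448.0f;     // largest E4M3 magnitude
    const int kMantissaBits = 3;
    const int kMinNormalExp = -6;  // exponent of the smallest normal value

    if (std::isnan(unscaled)) return unscaled;
    // Case 1: overflow, clamp to +/-448.
    if (std::fabs(unscaled) >= kMax) return std::copysign(kMax, unscaled);
    if (unscaled == 0.0f) return unscaled;

    // Exponent e such that |x| = f * 2^e with f in [1, 2).
    int e = std::ilogb(unscaled);
    // Below the smallest normal value the grid spacing stays at 2^-9.
    e = std::max(e, kMinNormalExp);
    // Distance between neighbouring representable values around x.
    const float step = std::ldexp(1.0f, e - kMantissaBits);

    // Cases 2 and 3: snap to the grid. Tiny inputs land on 0 or +/-2^-9.
    return std::nearbyint(unscaled / step) * step;
}

// Applies the two formulas of this section: divide by scale, then convert.
inline float quantize_with_scale(float fp32, float scale) {
    return quantize_e4m3_reference(fp32 / scale);
}
```

For example, 1000 clamps to 448, 1.07 rounds to 1.125 (the grid spacing near 1 is 0.125), 0.0015 rounds up to 2^-9, and 0.0005 rounds down to 0.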
union FPConvertHelper {
    float value;
    uint32_t data;
};
template<typename Dtype, typename Stype, typename Otype>
__device__ __inline__
float QuantizeScalarFloating(
    const Dtype value, const Stype scale, const Otype offset,
    const int exponent, const int mantissa,
    const float clip_min, const float clip_max,
    const Rounding rounding){
    /**
     * PPQ Quantization Function implementation.
     * This function converts a float value to a low-precision float.
     */
    FPConvertHelper helper; FPConvertHelper rounding_helper;
    helper.value = static_cast<float>(value) / scale;
    // The following code splits the float32 into sign, exponent and mantissa.
    /* IEEE 754 Standard: 1 bit sign, 8 bit exponent, 23 bit mantissa */
    /* In binary 10000000 00000000 00000000 00000000 = 0x80000000 in Hex */
    /* In binary 01111111 10000000 00000000 00000000 = 0x7F800000 in Hex */
    /* In binary 00000000 01111111 11111111 11111111 = 0x007FFFFF in Hex */
    /* Tool: https://www.h-schmidt.net/FloatConverter/IEEE754.html */
    uint32_t fp32_sign = helper.data & 0x80000000;
    int32_t fp32_exp = helper.data & 0x7F800000;
    int32_t fp32_mantissa = helper.data & 0x007FFFFF;
    // exponent_min is the exponent of the smallest subnormal: 1 - bias - mantissa,
    // with bias = 2^(exponent - 1) - 1. exponent_max is the largest exponent.
    int32_t exponent_min = -(1 << (exponent - 1)) - mantissa + 2;
    int32_t exponent_max = (1 << (exponent - 1));
    // Float Overflow: compare the unscaled value against the clipping range.
    if (helper.value > clip_max) return clip_max;
    if (helper.value < clip_min) return clip_min;
    // Following code will process Float underflow
    /* Float underflow means fp32_exp is smaller than exponent_min */
    /* Where exponent_min is the minimum exponent value of quantized float. */
    /* For FP8 E4M3, the minimum exponent value should be -9. */
    if (((fp32_exp >> 23) - 127) < exponent_min){
        if (((fp32_exp >> 23) - 127) == (exponent_min - 1)){
            // there is a chance to round up to the smallest subnormal, 2^exponent_min
            // rounding_helper holds |x| / 2^exponent_min, a value in [0.5, 1)
            rounding_helper.data = (fp32_mantissa & 0x007FFFFF) + 0x3F000000;
            if (_round2int(rounding_helper.value, rounding)) {
                helper.data = fp32_sign + ((exponent_min + 127) << 23);
                return helper.value;
            }
        }
        return 0.0f;
    }
    if ((fp32_exp >> 23) - 127 > exponent_max){
        if (fp32_sign) return clip_min;
        else return clip_max;
    }
    /* Converting a high precision mantissa to a low precision mantissa requires rounding */
    /* Here we apply a tricky method to round mantissa: */
    /* We create another float, which sign = 0, exponent = 127, mantissa = fp32_mantissa << mantissa */
    /* (only the low 23 bits are kept, i.e. the bits that are about to be dropped) */
    /* Then we round (this float - 1) to int; the result is the rounding carry, 0 or 1 */
    rounding_helper.data = ((fp32_mantissa << (mantissa)) & 0x007FFFFF) + 0x3F800000;
    uint32_t round_bit = _round2int(rounding_helper.value - 1, rounding);
    // process mantissa: keep the top `mantissa` bits and add the carry;
    // a mantissa overflow carries into fp32_exp in the sum below
    fp32_mantissa = ((fp32_mantissa >> (23 - mantissa)) + round_bit) << (23 - mantissa);
    helper.data = fp32_sign + fp32_mantissa + fp32_exp;
    return CLIP<float>(helper.value + offset, clip_min, clip_max);
}
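The mantissa rounding step at the end of the kernel can be tried out on the host. The sketch below is a simplified stand-in, not the PPQ code: instead of building a helper float and calling `_round2int`, it tests the highest discarded bit directly, which corresponds to round-half-up on the magnitude only. The name `round_mantissa` is hypothetical, and the input is assumed to be finite with `0 < mantissa_bits < 23`.

```cpp
#include <cstdint>
#include <cstring>

// Keeps only `mantissa_bits` of an FP32 mantissa, rounding half up in magnitude.
// Hypothetical host-side helper; assumes x is finite and 0 < mantissa_bits < 23.
inline float round_mantissa(float x, int mantissa_bits) {
    uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));

    const int dropped = 23 - mantissa_bits;  // number of low bits to discard
    const uint32_t sign = bits & 0x80000000u;
    const uint32_t exp_and_man = bits & 0x7FFFFFFFu;

    // The highest discarded bit decides whether to round up.
    const uint32_t round_bit = (exp_and_man >> (dropped - 1)) & 0x1u;

    // Clear the discarded bits and add the increment at the lowest kept bit.
    // A carry out of the mantissa bumps the exponent, which is the correct result.
    const uint32_t rounded = ((exp_and_man >> dropped) + round_bit) << dropped;

    bits = sign | rounded;
    float out;
    std::memcpy(&out, &bits, sizeof(out));
    return out;
}
```

With 3 mantissa bits, 1.06 rounds down to 1.0 and 1.07 rounds up to 1.125. The input 1.96 shows the carry: the kept bits 111 overflow, the exponent increases by one, and the result is 2.0.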
FP8 E4M3 | FP8 E5M2 | INT8 |
---|---|---|
0.06% | 0.2% | 0.008% |
Model | INT8 | FP8 |
---|---|---|
Inceptionv3 | 69.4 | 68.2 |
mnasnet | 63.9 | 22.3 |
mnasnet | 72.8 | 71.3 |
squeezenet | 57.8 | 57.1 |
shufflenet | 68.8 | 66.0 |
resnet18 | 69.6 | 69.4 |
mobilenetv2 | 70.9 | 67.2 |
mobilenetv3 | 73.3 | 70.3 |
efficientnet-b0 | 52.8 | 74.9 |
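FP8 quantization is not precise.
FP8 quantization has good tolerance, and we expect it to perform better in quantization-aware training (QAT).
This tolerance lets FP8 quantize some unusual networks, such as EfficientNet.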
· https://github.com/openppl-public/ppq/tree/master/ppq
· https://github.com/openppl-public/ppq/pull/274
· https://www.graphcore.ai/posts/graphcore-and-amd-propose-8-bit-fp-ai-standard-with-qualcomm-support
· Parts of this article are translated from: https://arxiv.org/pdf/2209.05433.pdf