半精度(FP16)调试血泪总结
点击上方“视学算法”,选择加"星标"或“置顶”
重磅干货,第一时间送达
导读
本文分享了一般的FP16数值溢出情况下的两种处理方式:一个快速查找数值溢出算子的方法;一个替换多个算子,从原始模型解决FP16数值溢出的方法。
问题描述
MMOCR在MMDeploy中部署时,PANet模型在以TensorRT-fp16为后端的情况下会有精度损失。hmean-iou由原本的0.8-掉点到0.2-。此时需要相应的debug查找问题原因。
MMOCR:https://github.com/open-mmlab/mmocr
MMDeploy:https://github.com/open-mmlab/mmdeploy
PANet:https://github.com/open-mmlab/mmocr/blob/main/configs/textdet/panet/panet_r18_fpem_ffm_600e_icdar2015.py
排除法查找节点
首先请教了有相关经验的同事,被告知一般只能二分查找,没有更方便的工具。此外,如果ONNX中有reduce_sum和avg_pool节点,可以重点尝试排查。经过netron可视化ONNX文件,发现PANet中无以上两种节点。没办法,只能笨办法慢慢查。这时候,我首先想到的是,骨干网络应该不会出现数值溢出问题。而PANet中与其他常用的无此相关问题的模型的差异在于其neck部分是,FPEM-FFM结构。会不会是在这个结构中产生问题的呢?所以,我没有直接使用二分查找。
卷积层映射
具体深入查看FPEM-FFM结构,发现有其他常用网络不用的卷积,深度可分离卷积。这时,很直觉的判断就是这个,所以,这时候我对该结构内的层采用的应对FP16数值溢出常用的解决方法。计算前减小数值,计算后恢复原来大小。那么对于卷积而言,一次卷积运算可以认为是:输入x,输出conv(x) = wx+b,其中W是权重weight,b是偏移量bias。如果我想要保证卷积过程无数值溢出情况,只需要将输入减小,计算完成后再将结果映射成真实值。实际卷积的计算过程变成:conv(x/p) = w(x/p)+b。卷积完成后,先乘以p再减去(p-1)b即可。亦即:p*conv(x/p)-(p-1)b = wx+b。很自然地,对于一个卷积层conv和输入张量x,进行如下变换:
def warp_conv(x, conv, factor: int=32):
"""(W*(x/p)+b)*p-b*(p-1) == Wx+b
conv(x) == warp_conv(x, conv)
"""
x_tmp = conv(x / factor)
return factor * x_tmp - (factor - 1) * conv.bias.reshape(
1, -1, 1, 1).repeat(1, 1, x_tmp.size(2), x_tmp.size(3))
原始计算过程conv(x)被等效成 warp_conv(x, conv),并且,等效过程不会出现数值溢出。
归一化层映射
除了可能的卷积层中数值溢出外,归一化层出现问题的几率更高。归一化层的计算过程如下:
# input:x, output:out
w = weight / torch.sqrt(running_var + eps)
out = x * w + (bias - running_mean * w)
可以发现,归一化层bn的计算过程实际也是个线性运算:bn(x) = wx+b,其中w = weight / torch.sqrt(running_var + eps),而b = bias - running_mean * w。那么类似卷积层的一个映射函数可以是:
def warp_bn(x, bn, factor: int=32):
import torch
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
bias = bn.bias - bn.running_mean * scale
bias_t = bias.reshape(1, -1, 1, 1).repeat(1, 1, x.size(2), x.size(3))
return bn(x / factor) * factor - (factor - 1) * bias_t
验证
有了以上的分析,直接将FPEM-FFM结构中的所有的卷积和归一化层都替换一下。这样导出的模型再转换成TensorRT模型,进行精度测试后发现结果有很大提升,由原来的0.2-提升到0.6+。看起来验证是成功的,必然是其中的某个层出现了数值溢出,替换后可以进行fp16推理。可是,为什么精度没有对齐到fp32呢?难道上述的替换并不完全等效?为此我进行了如下实验:
print(conv(x) - warp_conv(x, conv).sum())
print(bn(x) - warp_bn(x, bn).sum())
发现对于卷积层,并无数值误差,而对归一化层,出现了一定的误差。而且,bn层的误差是随着映射函数中的变量factor增大而增大。“一定是多个归一化层累计效应,同时factor不应该设置太大”,我想。
就这样,我减小了factor的大小,同时逐一验证是具体哪个归一化层有问题。最后却没有定位到一个具体的归一化层产生数值溢出。一时间,整个debug似乎卡住了,没法继续推进了,似乎只能二分法?
整理思路
数值溢出的可能形式
不考虑TensorRT做层融合等情况,只考虑PyTorch做推理,数值溢出的可能形式有哪些呢?
某个算子内部计算过程数值溢出,输入输出均可以用fp16表示 跨内部连续多个算子出现数值溢出 整个网络计算过程都有数值溢出
首先排除最后一种情况,一般的,输入是图像归一化后的结果,而输出一般需要对应到标签或者标签相关的值。所以,输入和输出都不可能出现数值溢出。而第一种情况种,PANet只有FPEM-FFM结构是迥异于其他模型的,替换其中的卷积层和归一化层都不能解决问题的话,数值溢出只剩下第二种情况了。
多算子数值溢出
考虑到连续多个算子计算过程均出现数值溢出的可能,似乎直接进行单个算子映射已经无法解决问题了。同时,具体从哪个算子出现问题,我们仍然不知道。难道还是要寻求二分法的帮助?
寻找fp16失效的算子
没辙,只好二分法吧。可是如何二分法其实也很有考究。可行的debug方式:
每次提前返回结果,二分地导出ONNX再导出TensorRT模型,未被导出的部分继续以PyTorch代码衔接到TensoRT的计算结果后。 直接运行PyTorch模型,设置断点,查看哪些计算过程有数值异常地大。
第一种方法最为精准,肯定是可以找到具体的节点的。但是过程非常繁琐,同时需要大量的测试代码。第二种方法最为直接,但是也同样繁琐,因为一个图的节点太多了。要断点查看的话,可能需要很久。而且不能保证结果一定找到,可能存在疏漏。
第二种方法需要指导fp16数值表示的变化范围。IEEE标准中,fp32的取值范围是1.4e-45至3.4e38,fp16的取值范围是5.96e-8 ~ 65504。也就是说,fp16的最大值不超过65504。
断点调试
权衡了下,选择断点查看是否有数值异常。最后找到了异常的起始位置为骨干网络输出的特征层的最后一层,里面有大量的点的数值大于65504。随后再逐步往前找到开始出现异常的节点,同时再逐步往后,找到异常结束的节点。结果发现起始位置在最后一层的最后一次归一化层,结束位置在FPEM-FFM结构的某个归一化完成后。
优化
上面的方法仍然略显繁琐,有没有不用手动调试,直接用程序查找的方式呢?
之前的断点调试本质是寻找图计算过程中,内部嵌套的nn.Module的推理过程中的张量是否超出65504的点罢了。我们完全可以用代码查找。直接对模型的每一层进行遍历,查看是否有异常点。
一个可能的实现方式是利用PyTorch的钩子(hook)技术。所谓的钩子技术其实并非PyTorch所独有,事实上,很多的软件架构都有提供。钩子技术是在某个事件执行完成后,自动执行的函数。也就是说,我们只要在网络的每个层设置一个钩子,在该层推理完成后再对输入和输出进行查找是否有异常点即可。
from pyclbr import Function
from typing import Sequence
import torch
def fp16_check(module: torch.nn.Module, input: torch.Tensor, output: torch.Tensor) -> None:
if isinstance(input, dict):
for _, value in input.items():
fp16_check(module, value, output)
return
if isinstance(input, Sequence):
for value in input:
fp16_check(module, value, output)
return
if isinstance(output, dict):
for _, value in output.items():
fp16_check(module, input, value)
return
if isinstance(output, Sequence):
for value in output:
fp16_check(module, input, value)
return
if torch.abs(input).max()<65504 and torch.abs(output).max()>65504:
print('from: ', module.finspect_name)
if torch.abs(input).max()>65504 and torch.abs(output).max()<65504:
print('to: ', module.finspect_name)
return
from contextlib import contextmanager
class FInspect:
module_names = ['model']
handlers = []
def hook_all_impl(cls, module: torch.nn.Module, hook_func: Function)-> None:
for name, child in module.named_children():
cls.module_names.append(name)
cls.hook_all_impl(cls, module=child, hook_func=hook_func)
linked_name='->'.join(cls.module_names)
setattr(module, 'finspect_name', linked_name)
cls.module_names.pop()
handler = module.register_forward_hook(hook=hook_func)
cls.handlers.append(handler)
@classmethod
@contextmanager
def hook_all(cls, module: torch.nn.Module, hook_func: Function)-> None:
cls.hook_all_impl(cls, module, hook_func)
yield
[i.remove() for i in cls.handlers]
with FInspect.hook_all(patched_model, fp16_check):
patched_model(inputs)
尝试映射
整个异常过程可以表达成:bn(conv(relu(bn(x) + residual)))。对于这样一个过程,能否通过类似上述的对卷积层和归一化层的处理方式解决数值溢出问题呢?
relu一生之敌
我们知道,一个初始值,经过多个线性运算后,结果可以用一次线性运算还原。比如:w2(w1x + b1) + b2 = w1w2x + w2b1 + b2,结果还是个线性运算。这也是为什么神经网络需要激活函数————不然多个线性层的结果等效于一个线性层。
那么,上面的异常过程可以简化成w2(relu(w1x + b1)) + b2。其中w1, w2, b1, b2都可以计算出来。可是,relu激活函数可以在输入缩放减小后,对输出进行还原得到吗?我们都知道,relu函数有个很好的性质,那就是relu(px) = p*relu(x)。除此之外,再难有其他性质可以被利用,以对抗数值溢出。可是,如果计算过程是relu(wx + b),缩放输入x以后得到的结果relu(wx/p + b)不能再简单地恢复到relu(wx + b)。
不过,我们可以通过整体缩放relu的输入,将计算过程变成relu((wx + b)/p),这样再乘以p就可以恢复成relu(wx + b)了。如此,relu函数也可以绕过,对抗数值溢出。
公式
那么对于原公式w2(relu(w1x + b1)) + b2,就可以映射成 p * w2(relu((w1x + b1)/p)) + b2 - (p-1)*b2。
实施
在 MMDeploy框架中,只要利用函数重写劫持掉原始的PANet中的两段函数
mmocr.models.textdet.necks.FPEM_FFM.forward
和
mmdet.models.backbones.resnet.BasicBlock.forward
即可。在劫持的函数里,替换原有的计算过程,将上述的映射实现即可。
import torch
import torch.nn.functional as F
from mmdeploy.core import FUNCTION_REWRITER
from mmdeploy.utils.constants import Backend
FACTOR = 32
ENABLE = False
CHANNEL_THRESH = 400
@FUNCTION_REWRITER.register_rewriter(
func_name='mmocr.models.textdet.necks.FPEM_FFM.forward',
backend=Backend.TENSORRT.value)
def fpem_ffm__forward__trt(ctx, self, x, *args, **kwargs):
c2, c3, c4, c5 = x
# reduce channel
c2 = self.reduce_conv_c2(c2)
c3 = self.reduce_conv_c3(c3)
c4 = self.reduce_conv_c4(c4)
if ENABLE:
bn_w = self.reduce_conv_c5[1].weight / torch.sqrt(
self.reduce_conv_c5[1].running_var + self.reduce_conv_c5[1].eps)
bn_b = self.reduce_conv_c5[
1].bias - self.reduce_conv_c5[1].running_mean * bn_w
bn_w = bn_w.reshape(1, -1, 1, 1).repeat(1, 1, c5.size(2), c5.size(3))
bn_b = bn_b.reshape(1, -1, 1, 1).repeat(1, 1, c5.size(2), c5.size(3))
conv_b = self.reduce_conv_c5[0].bias.reshape(1, -1, 1, 1).repeat(
1, 1, c5.size(2), c5.size(3))
c5 = FACTOR * (self.reduce_conv_c5[:-1](c5)) - (FACTOR - 1) * (
bn_w * conv_b + bn_b)
c5 = self.reduce_conv_c5[-1](c5)
else:
c5 = self.reduce_conv_c5(c5)
# FPEM
for i, fpem in enumerate(self.fpems):
c2, c3, c4, c5 = fpem(c2, c3, c4, c5)
if i == 0:
c2_ffm = c2
c3_ffm = c3
c4_ffm = c4
c5_ffm = c5
else:
c2_ffm += c2
c3_ffm += c3
c4_ffm += c4
c5_ffm += c5
# FFM
c5 = F.interpolate(
c5_ffm,
c2_ffm.size()[-2:],
mode='bilinear',
align_corners=self.align_corners)
c4 = F.interpolate(
c4_ffm,
c2_ffm.size()[-2:],
mode='bilinear',
align_corners=self.align_corners)
c3 = F.interpolate(
c3_ffm,
c2_ffm.size()[-2:],
mode='bilinear',
align_corners=self.align_corners)
outs = [c2_ffm, c3, c4, c5]
return tuple(outs)
@FUNCTION_REWRITER.register_rewriter(
func_name='mmdet.models.backbones.resnet.BasicBlock.forward',
backend=Backend.TENSORRT.value)
def basic_block__forward__trt(ctx, self, x):
if self.conv1.in_channels < CHANNEL_THRESH:
return ctx.origin_func(self, x)
identity = x
out = self.conv1(x)
out = self.norm1(out)
out = self.relu(out)
out = self.conv2(out)
if torch.abs(self.norm2(out)).max() < 65504:
out = self.norm2(out)
out += identity
out = self.relu(out)
return out
else:
global ENABLE
ENABLE = True
# the output of the last bn layer exceeds the range of fp16
w1 = self.norm2.weight / torch.sqrt(self.norm2.running_var +
self.norm2.eps)
bias = self.norm2.bias - self.norm2.running_mean * w1
w1 = w1.reshape(1, -1, 1, 1).repeat(1, 1, out.size(2), out.size(3))
bias = bias.reshape(1, -1, 1, 1).repeat(1, 1, out.size(2),
out.size(3)) + identity
out = self.relu(w1 * (out / FACTOR) + bias / FACTOR)
return out
通过上述的重写函数,最后导出PANet模型可以媲美原始PyTorch模型,甚至略有超过(数值误差)。
总结
总结一下这篇博客,分享了一般的FP16数值溢出情况下的处理方式。
一个快速查找数值溢出算子的方法。 一个替换多个算子,从原始模型解决FP16数值溢出的方法。
点个在看 paper不断!