Torcheck —— 测试你的 PyTorch 模型

AI算法与图像处理

共 6113字,需浏览 13分钟

 ·

2021-06-19 17:44

引言


你是否有过这样的经历:长时间训练 PyTorch 模型,结果发现在模型的 forward 方法中输入了一行错误?你是否曾经遇到过这样的情况:你从模型中获得了一些合理的输出,但是不确定这是否表明你构建的模型是正确的,或者这只是因为深度学习是如此强大,即使是错误的模型架构也会产生下降的结果。


就我个人而言,测试深度学习模型有时会让我抓狂。最突出的痛点是:


  • 它的黑盒特性使它很难测试。即使不是不可能,也需要很多专业知识来理解中间结果

  • 长的训练时间大大减少了迭代次数

  • 没有专门的工具。通常,您希望在一个小样本数据集上测试模型,这需要重复编写样板代码,以便进行设置优化、计算损失和反向传播

为了减少这种开销,我之前做了一些研究。这里有一篇详细文章:https://thenerdstation.medium.com/how-to-unit-test-machine-learning-code-57cf6fd81765其核心思想是,我们永远不能百分之百确定我们的模型是正确的,但至少它应该能够通过一些合理性检验。换句话说,这些合理性检验是必要的,但可能还不够。


为了节省您的时间,以下是他提出的所有合理性检验的摘要:


  • 如果一个模型参数在训练过程中没有被故意冻结,那么它应该在训练过程中不断变化。这可以是 PyTorch 线性层的张量

  • 模型参数在训练过程中不应该改变,如果它被冻结。这可能是一个你不想更新的预训练好的层

  • 根据您的模型属性,模型输出的范围应该服从某些条件。例如,如果它是一个分类模型,它的输出不应该都在范围(0,1)内。否则,很有可能您在计算损失之前错误地将 softmax 最大激活函数应用于输出

  • (这实际上不是来自那篇文章,而是一个常见的问题)在大多数情况下,模型参数不应该包含 NaN 或 Inf(infinite number),这同样适用于模型输出


除了提出这些检查,他还构建了一个 Python 包来实现这些检查。这是一个很好的包,但仍然有未解决的痛点。这个包是几年前创建的,现在已经不再维护了。


因此,受这种合理性检验思想的启发,目标是创建一个易于使用的 Python 包,创建 torcheck!其主要创新包括:


  • 不再需要额外的测试代码。只需添加几行代码指定训练前的检查,torcheck 将在训练发生时执行检查,并在检查失败时提出信息性错误消息

  • 可以在不同的级别检查模型。你可以指定子模块、线性层甚至权重张量的检查,而不是检验整个模型!这样就可以对复杂体系结构的检查进行更多的自定义


接下来,我们会给你一个关于 torcheck 的快速教程。


假设我们已经编写了一个 ConvNet 模型来对 MNIST 数据集进行分类:

# modelclass CNN(nn.Module):        def __init__(self):        super().__init__()        self.conv1 = nn.Conv2d(1, 1, kernel_size=1, stride=1)        self.conv2 = nn.Conv2d(1, 8, kernel_size=3, stride=1, padding=1)        self.conv3 = nn.Conv2d(8, 16, kernel_size=3, stride=1, padding=1)        self.relu = nn.ReLU()        self.maxpool = nn.MaxPool2d(2, 2)        self.fc1 = nn.Linear(16 * 7 * 7, 128)        self.fc2 = nn.Linear(128, 10)            def forward(self, x):        output = self.relu(self.conv1(x))        output = self.relu(self.conv2(x))        output = self.maxpool(output)        output = self.relu(self.conv3(output))        output = self.maxpool(output)        output = output.view(output.size()[0], -1)        output = self.relu(self.fc1(output))        output = self.fc2(output)        return output

# training routinemodel = CNN()optimizer = optim.Adam(model.parameters(), lr=0.001)for epoch in range(num_epochs): for x, y in dataloader: y_pred = model(x) loss = F.cross_entropy(y_pred, y) loss.backward() optimizer.step() optimizer.zero_grad()

在模型代码中实际上有一个细微的错误。你们中的一些人可能已经注意到了:在第16行中,我们不小心把 x 放在了右边,而应该是 output。


现在让我们看看 torcheck 如何帮助我们检验这个隐藏的错误!


步骤0:安装


在我们开始之前,首先用一行代码安装软件包。

$ pip install torcheck


步骤1: 添加 torcheck 代码


接下来我们将添加代码。Torcheck 代码总是驻留在循环训练之前,在模型和优化器实例化之后,如下所示:

# model and optimizer instantiationmodel = CNN()optimizer = optim.Adam(model.parameters(), lr=0.001)
############################ torcheck code goes here ############################
# training routinefor epoch in range(num_epochs): for x, y in dataloader: y_pred = model(x) loss = F.cross_entropy(y_pred, y) loss.backward() optimizer.step() optimizer.zero_grad()


步骤1.1:注册优化器


首先,用 torcheck 注册你的优化器:

torcheck.register(optimizer)


步骤1.2:增加合理性检验


接下来,添加要在四个类别中执行的所有检验。


1. 参数改变/不改变


对于我们的例子,我们希望在训练过程中更改所有的模型参数:

# check all the model parameters will change# module_name is optional, but it makes error messages more informative when checks failtorcheck.add_module_changing_check(model, module_name="my_model")

附注


为了演示 torcheck 的全部功能,让我们假设稍后你已经冻结了卷积层,只想微调线性层。在这种情况下将会像这样:

# check the first convolutional layer's parameters won't changetorcheck.add_module_unchanging_check(model.conv1, module_name="conv_layer_1")# check the second convolutional layer's parameters won't changetorcheck.add_module_unchanging_check(model.conv2, module_name="conv_layer_2")# check the third convolutional layer's parameters won't changetorcheck.add_module_unchanging_check(model.conv3, module_name="conv_layer_3")# check the first linear layer's parameters will changetorcheck.add_module_changing_check(model.fc1, module_name="linear_layer_1")# check the second linear layer's parameters will changetorcheck.add_module_changing_check(model.fc2, module_name="linear_layer_2")

2. 输出范围检查


因为我们的模型是一个分类模型,所以我们想要添加前面提到的检验:模型输出不应该都在范围(0,1)内。

# check model outputs are not all within (0, 1)# aka softmax hasn't been applied before loss calculationtorcheck.add_module_output_range_check(    model,    output_range=(0, 1),    negate_range=True,)

negate_range = True 参数带有“ not all”的含义。如果您只是想检查模型输出都在某个范围内,只需删除该参数。


尽管 torcheck 不适用于我们的示例,但是它还允许你检查子模块的中间输出。


3. NaN 检查


我们当然希望确保模型参数在训练期间不会变成 NaN,并且模型输出不包含 NaN。添加 NaN 检查很简单:

# check whether model parameters become NaN or outputs contain NaNtorcheck.add_module_nan_check(model)

4. Inf 检查


类似地,添加 Inf 检查:

# check whether model parameters become infinite or outputs contain infinite valuetorcheck.add_module_inf_check(model)

在添加了所有感兴趣的检验之后,最终的训练代码如下:

# model and optimizer instantiationmodel = CNN()optimizer = optim.Adam(model.parameters(), lr=0.001)
# torcheck codetorcheck.register(optimizer)torcheck.add_module_changing_check(model, module_name="my_model")torcheck.add_module_output_range_check(model, output_range=(0, 1), negate_range=True)torcheck.add_module_nan_check(model)torcheck.add_module_inf_check(model)
# training routinefor epoch in range(num_epochs): for x, y in dataloader: y_pred = model(x) loss = F.cross_entropy(y_pred, y) loss.backward() optimizer.step() optimizer.zero_grad()

步骤2:训练和修复


现在让我们像往常一样进行训练,看看会发生什么:

$ python run.pyTraceback (most recent call last):  (stack trace information here)RuntimeError: The following errors are detected while training:Module my_model's conv1.weight should change.Module my_model's conv1.bias should change.

砰!我们立即得到一个错误消息,说我们的模型的 conv1.weight 和 conv1.bias 不会改变。一定是 model.conv1出了什么问题。


正如预期的那样,我们转向模型代码,注意错误,修复它,并重新运行训练。


(可选)步骤3:关闭检验


耶! 我们的模型通过了所有的检验。最后,我们可以对其进行关闭:

torcheck.disable()

当你想要在验证集上运行模型,或者只想从你的模型训练中消除检查耗时,这是非常有用的。


如果你还想继续使用,只需要:

torcheck.enable()


·  END  ·

个人微信(如果没有备注不拉群!
请注明:地区+学校/企业+研究方向+昵称



下载1:何恺明顶会分享


AI算法与图像处理」公众号后台回复:何恺明,即可下载。总共有6份PDF,涉及 ResNet、Mask RCNN等经典工作的总结分析


下载2:终身受益的编程指南:Google编程风格指南


AI算法与图像处理」公众号后台回复:c++,即可下载。历经十年考验,最权威的编程规范!



下载3 CVPR2021

AI算法与图像处公众号后台回复:CVPR即可下载1467篇CVPR 2020论文 和 CVPR 2021 最新论文

点亮 ,告诉大家你也在看



浏览 36
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报