深度学习中的“不确定性基线”

共 5387字,需浏览 11分钟

 ·

2022-01-19 20:58


发布人:Google Research Brain 团队研究工程师 Zachary Nado 和研究员 Dustin Tran


机器学习 (ML) 越来越多地被用于实际应用,因此了解模型的不确定性和稳健性对于确保其在实践中的性能很有必要。例如,将模型部署到与训练数据不同的数据上时,其表现如何?模型在可能出错时如何发出信号?

  • 不确定性和稳健性

    https://slideslive.com/38935801/practical-uncertainty-estimation-outofdistribution-robustness-in-deep-learning


为掌握 ML 模型的行为,我们通常会根据目标任务的基线来衡量其性能。对于每个基线,研究人员必须尝试仅使用相应论文中的描述来重现结果,这为复现带来了严峻挑战。在实验代码得到完好记录和维护的前提下,查看这些代码可能更有用。但是,基线必须经过严格验证,因此仅仅查看代码还不够。

  • 带来了严峻挑战

    https://paperswithcode.com/rc2020


例如,在对一系列研究 [123] 进行回顾性分析时,作者会发现简单且经过优化的基线往往优于更复杂的方法。为真正了解模型之间的相对表现,且让研究人员能够衡量新理念是否切实取得有意义的进展,必须将目标模型与共同基线进行比较。

  • 1

    https://arxiv.org/abs/1707.05589

  • 2

    https://arxiv.org/abs/1807.04720

  • 3

    https://arxiv.org/abs/2102.06356


在“不确定性基线:深度学习中不确定性和稳健性的基准 (Uncertainty Baselines: Benchmarks for Uncertainty & Robustness in Deep Learning) ”一文中,我们介绍了“不确定性基线”,这是针对各种任务的标准化和先进深度学习方法的高质量实现合集,旨在促使不确定性和稳健性的相关研究更具可重复性。

  • 不确定性基线:深度学习中不确定性和稳健性的基准

    https://arxiv.org/abs/2106.04015

  • 不确定性基线

    https://github.com/google/uncertainty-baselines


该合集涵盖 9 大任务的 19 种方法,每种方法有至少五项指标。每个基线都是独立的实验流水线,具有易于重复使用且可扩展的组件,并且在其编写框架之外具有最小的依赖性。内含的流水线可在 TensorFlowPyTorchJax 中得到实现。此外,每个基线的超参数都已在多次迭代中经过广泛调整,可提供更有力的结果。

  • TensorFlow

    https://tensorflow.google.cn/

  • PyTorch

    https://pytorch.org/

  • Jax

    https://jax.readthedocs.io/en/latest/notebooks/quickstart.html


不确定性基线


至撰写本文时,不确定性基线共提供了 83 个基线,包括 19 种方法,涵盖九个数据集的标准和最新策略。示例方法包括 BatchEnsemble(批集成)、Deep Ensembles(深度集成)、Rank-1 Bayesian Neural Nets(1 阶贝叶斯神经网络)、Monte Carlo DropoutSpectral-normalized Neural Gaussian Processes(光谱归一化神经高斯过程)。

  • BatchEnsemble

    https://arxiv.org/abs/2002.06715

  • Deep Ensembles

    https://arxiv.org/abs/1612.01474

  • Rank-1 Bayesian Neural Nets

    https://arxiv.org/abs/2005.07186

  • Monte Carlo Dropout

    https://arxiv.org/abs/1506.02142

  • Spectral-normalized Neural Gaussian Processes

    https://arxiv.org/abs/2006.10108


不确定性基线可以作为继任者,合并社区中如下流行基准:您可以相信模型的不确定性吗?BDL 基准Edward2 基线

  • 您可以相信模型的不确定性吗?

    https://github.com/google-research/google-research/tree/master/uq_benchmark_2019

  • BDL 基准

    https://github.com/OATML/bdl-benchmarks

  • Edward2 基线

    https://github.com/google/edward2/tree/master/baselines


数据集

输入

输出

训练示例

测试

数据集

CIFAR

RGB 图像

10 类分布

50,000

3

ImageNet

RGB 图像

1000

类分布

1,281,167

6

CLINC

意图检测

对话框系统

查询文本

150 类分布

(10 个网域)

15,000

2

Kaggle 糖尿病性视网膜病变检测

RGB 图像

糖尿病性视网膜病变的概率

35,126

1

维基百科

毒性

维基百科评论文本

毒性概率

159,571

3

  • CIFAR

    https://www.cs.toronto.edu/~kriz/cifar.html

  • ImageNet

    https://image-net.org/

  • CLINC 意图检测

    https://github.com/clinc/oos-eval

  • Kaggle 糖尿病性视网膜病变检测

    https://www.kaggle.com/c/diabetic-retinopathy-detection

  • 维基百科毒性

    https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge


我们共为 9 个可用数据集提供基线,上表显示其中 5 个数据集的子集。数据集涵盖了表格、文本和图像模态。


不确定性基线根据选择的基础模型、训练数据集和一套评估指标设置各个基线。然后通过超参数对各个基线进行调整,以最大限度地提高这些指标的性能。上述三个轴线的可用基线各不相同:


基础模型(架构)包括 Wide ResNet 28-10ResNet-50BERT 和简单的全连接网络。

  • Wide ResNet 28-10

    https://arxiv.org/abs/1605.07146

  • ResNet-50

    https://arxiv.org/abs/1512.03385

  • BERT

    https://arxiv.org/abs/1810.04805


训练数据集包括标准机器学习数据集(CIFAR、ImageNetUCI)以及更多现实问题(Clinc 意图检测、Kaggle 糖尿病性视网膜病变检测和维基百科毒性)。

  • UCI

    https://archive.ics.uci.edu/ml/datasets.php

  • Clinc 意图检测

    https://tensorflow.google.cn/datasets/catalog/clinc_oos

  • 维基百科毒性

    https://tensorflow.google.cn/datasets/catalog/wikipedia_toxicity_subtypes


评估包括预测性指标(如准确率)、不确定性指标(如选择性预测和校准误差)、计算指标(推断延迟)以及分布内外数据集的性能。


模块化和可复用性



为便于研究人员使用基线并在其基础上进行构建,我们特意对其进行优化,尽可能采用模块化设计,并实现最小化。如下方工作流图所示,不确定性基线没有引入新的类抽象,而是重复使用生态系统中预先存在的类(例如 TensorFlow 的 tf.data.Dataset)。各个基线的训练/评估流水线均包含在实验的独立 Python 文件(可以在 CPU、GPU 或 Google Cloud TPU 上运行)中。由于基线之间的这种独立性,我们得以在 TensorFlow、PyTorchJAX 任意一者中开发基线。


  • tf.data.Dataset

    https://tensorflow.google.cn/api_docs/python/tf/data/Dataset

  • PyTorch

    https://github.com/google/uncertainty-baselines/blob/master/baselines/diabetic_retinopathy_detection/torch_dropout.py

  • JAX

    https://github.com/google/uncertainty-baselines/blob/master/baselines/jft/deterministic.py


工作流示意图:不确定性基线不同组成部分的构造方式。所有数据集都是 BaseDataset 类的子类,BaseDataset 类提供的简单 API 可用于使用任何受支持框架编写的基线。然后,任何基线的输出均可使用稳健性指标库进行分析

  • 稳健性指标

    https://github.com/google-research/robustness_metrics/


研究工程师对如何管理超参数和其他实验配置值(很轻松就能达到几十个)存在争议。我们没有使用针对该问题构建的任意一个框架,不想冒用户必须学习另一个库的风险,因此选择仅使用 Python 标志,这些标志通过遵循 Python 约定的 Abseil 定义。大多数研究人员应该非常熟悉该技术,其很容易扩展并插入其他流水线。

  • Abseil

    https://abseil.io/docs/python/guides/flags


可重复性


除了能够使用记录的命令运行我们的所有基线并获得相同的报告结果之外,我们还力求发布超参数调整结果和最终模型检查点,以进一步实现可重复性。目前,我们只针对糖尿病性视网膜病变基线完全开源上述内容,但我们会在运行基线的过程中继续上传更多结果。此外,我们提供的基线示例在硬件确定性方面完全可重复。

  • 糖尿病性视网膜病变基线

    https://github.com/google/uncertainty-baselines/tree/master/baselines/diabetic_retinopathy_detection

  • 基线示例

    https://github.com/google/uncertainty-baselines/blob/df320d4987deddf2e23a8a7cb45eda87d3c5f210/baselines/cifar/deterministic.py#L132


实际影响


代码库中包含的所有基线都经过了广泛的超参数调整,我们希望研究人员可以轻松地重复使用这些基线,而无需进行昂贵的重新训练或重新调整。此外,我们希望避免流水线实现中影响基线比较的细微差异。


不确定性基线已被用于众多研究项目。如果您是一名研究人员,想要贡献其他方法或数据集,您可以在 GitHub 上创建一个议题,开启讨论!

  • 众多研究项目

    https://github.com/google/uncertainty-baselines#papers-using-uncertainty-baselines


致谢


衷心感谢各位共同开发的人员,以及提供指导和/或帮助审核本文的人员:Neil Band、Mark Collier、Josip Djolonga、Michael W. Dusenberry、Sebastian Farquhar、Angelos Filos、Marton Havasi、Rodolphe Jenatton、Ghassen Jerfel、Jeremiah Liu、Zelda Mariet、Jeremy Nixon、Shreyas Padhy、Jie Ren、Tim G. J. Rudner、Yeming Wen、Florian Wenzel、Kevin Murphy、D. Sculley、Balaji Lakshminarayanan、Jasper Snoek、Yarin Gal。



推荐阅读

深入理解生成模型VAE

SOTA模型Swin Transformer是如何炼成的!

快来解锁PyTorch新技能:torch.fix

集成YYDS!让你的模型更快更准!

辅助模块加速收敛,精度大幅提升!移动端实时的NanoDet-Plus来了!

SimMIM:一种更简单的MIM方法

SSD的torchvision版本实现详解


机器学习算法工程师


                                    一个用心的公众号




不要忘记“一键三连”哦~

分享

点赞

在看

浏览 119
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报