深度学习中的“不确定性基线”
发布人:Google Research Brain 团队研究工程师 Zachary Nado 和研究员 Dustin Tran
机器学习 (ML) 越来越多地被用于实际应用,因此了解模型的不确定性和稳健性对于确保其在实践中的性能很有必要。例如,将模型部署到与训练数据不同的数据上时,其表现如何?模型在可能出错时如何发出信号?
不确定性和稳健性
https://slideslive.com/38935801/practical-uncertainty-estimation-outofdistribution-robustness-in-deep-learning
为掌握 ML 模型的行为,我们通常会根据目标任务的基线来衡量其性能。对于每个基线,研究人员必须尝试仅使用相应论文中的描述来重现结果,这为复现带来了严峻挑战。在实验代码得到完好记录和维护的前提下,查看这些代码可能更有用。但是,基线必须经过严格验证,因此仅仅查看代码还不够。
带来了严峻挑战
https://paperswithcode.com/rc2020
例如,在对一系列研究 [1、2、3] 进行回顾性分析时,作者会发现简单且经过优化的基线往往优于更复杂的方法。为真正了解模型之间的相对表现,且让研究人员能够衡量新理念是否切实取得有意义的进展,必须将目标模型与共同基线进行比较。
1
https://arxiv.org/abs/1707.05589
2
https://arxiv.org/abs/1807.04720
3
https://arxiv.org/abs/2102.06356
在“不确定性基线:深度学习中不确定性和稳健性的基准 (Uncertainty Baselines: Benchmarks for Uncertainty & Robustness in Deep Learning) ”一文中,我们介绍了“不确定性基线”,这是针对各种任务的标准化和先进深度学习方法的高质量实现合集,旨在促使不确定性和稳健性的相关研究更具可重复性。
不确定性基线:深度学习中不确定性和稳健性的基准
https://arxiv.org/abs/2106.04015
不确定性基线
https://github.com/google/uncertainty-baselines
该合集涵盖 9 大任务的 19 种方法,每种方法有至少五项指标。每个基线都是独立的实验流水线,具有易于重复使用且可扩展的组件,并且在其编写框架之外具有最小的依赖性。内含的流水线可在 TensorFlow、PyTorch 和 Jax 中得到实现。此外,每个基线的超参数都已在多次迭代中经过广泛调整,可提供更有力的结果。
TensorFlow
https://tensorflow.google.cn/
PyTorch
https://pytorch.org/
Jax
https://jax.readthedocs.io/en/latest/notebooks/quickstart.html
至撰写本文时,不确定性基线共提供了 83 个基线,包括 19 种方法,涵盖九个数据集的标准和最新策略。示例方法包括 BatchEnsemble(批集成)、Deep Ensembles(深度集成)、Rank-1 Bayesian Neural Nets(1 阶贝叶斯神经网络)、Monte Carlo Dropout 和 Spectral-normalized Neural Gaussian Processes(光谱归一化神经高斯过程)。
BatchEnsemble
https://arxiv.org/abs/2002.06715
Deep Ensembles
https://arxiv.org/abs/1612.01474
Rank-1 Bayesian Neural Nets
https://arxiv.org/abs/2005.07186
Monte Carlo Dropout
https://arxiv.org/abs/1506.02142
Spectral-normalized Neural Gaussian Processes
https://arxiv.org/abs/2006.10108
不确定性基线可以作为继任者,合并社区中如下流行基准:您可以相信模型的不确定性吗?、BDL 基准和 Edward2 基线。
您可以相信模型的不确定性吗?
https://github.com/google-research/google-research/tree/master/uq_benchmark_2019
BDL 基准
https://github.com/OATML/bdl-benchmarks
Edward2 基线
https://github.com/google/edward2/tree/master/baselines
数据集 | 输入 | 输出 | 训练示例 | 测试 数据集 |
CIFAR | RGB 图像 | 10 类分布 | 50,000 | 3 |
ImageNet | RGB 图像 | 1000 类分布 | 1,281,167 | 6 |
CLINC 意图检测 | 对话框系统 查询文本 | 150 类分布 (10 个网域) | 15,000 | 2 |
Kaggle 糖尿病性视网膜病变检测 | RGB 图像 | 糖尿病性视网膜病变的概率 | 35,126 | 1 |
维基百科 毒性 | 维基百科评论文本 | 毒性概率 | 159,571 | 3 |
CIFAR
https://www.cs.toronto.edu/~kriz/cifar.html
ImageNet
https://image-net.org/
CLINC 意图检测
https://github.com/clinc/oos-eval
Kaggle 糖尿病性视网膜病变检测
https://www.kaggle.com/c/diabetic-retinopathy-detection
维基百科毒性
https://www.kaggle.com/c/jigsaw-toxic-comment-classification-challenge
我们共为 9 个可用数据集提供基线,上表显示其中 5 个数据集的子集。数据集涵盖了表格、文本和图像模态。
不确定性基线根据选择的基础模型、训练数据集和一套评估指标设置各个基线。然后通过超参数对各个基线进行调整,以最大限度地提高这些指标的性能。上述三个轴线的可用基线各不相同:
基础模型(架构)包括 Wide ResNet 28-10、ResNet-50、BERT 和简单的全连接网络。
Wide ResNet 28-10
https://arxiv.org/abs/1605.07146
ResNet-50
https://arxiv.org/abs/1512.03385
BERT
https://arxiv.org/abs/1810.04805
训练数据集包括标准机器学习数据集(CIFAR、ImageNet 和 UCI)以及更多现实问题(Clinc 意图检测、Kaggle 糖尿病性视网膜病变检测和维基百科毒性)。
UCI
https://archive.ics.uci.edu/ml/datasets.php
Clinc 意图检测
https://tensorflow.google.cn/datasets/catalog/clinc_oos
维基百科毒性
https://tensorflow.google.cn/datasets/catalog/wikipedia_toxicity_subtypes
评估包括预测性指标(如准确率)、不确定性指标(如选择性预测和校准误差)、计算指标(推断延迟)以及分布内外数据集的性能。
为便于研究人员使用基线并在其基础上进行构建,我们特意对其进行优化,尽可能采用模块化设计,并实现最小化。如下方工作流图所示,不确定性基线没有引入新的类抽象,而是重复使用生态系统中预先存在的类(例如 TensorFlow 的 tf.data.Dataset
)。各个基线的训练/评估流水线均包含在实验的独立 Python 文件(可以在 CPU、GPU 或 Google Cloud TPU 上运行)中。由于基线之间的这种独立性,我们得以在 TensorFlow、PyTorch 或 JAX 任意一者中开发基线。
tf.data.Dataset
https://tensorflow.google.cn/api_docs/python/tf/data/Dataset
PyTorch
https://github.com/google/uncertainty-baselines/blob/master/baselines/diabetic_retinopathy_detection/torch_dropout.py
JAX
https://github.com/google/uncertainty-baselines/blob/master/baselines/jft/deterministic.py
工作流示意图:不确定性基线不同组成部分的构造方式。所有数据集都是 BaseDataset 类的子类,BaseDataset 类提供的简单 API 可用于使用任何受支持框架编写的基线。然后,任何基线的输出均可使用稳健性指标库进行分析
稳健性指标
https://github.com/google-research/robustness_metrics/
研究工程师对如何管理超参数和其他实验配置值(很轻松就能达到几十个)存在争议。我们没有使用针对该问题构建的任意一个框架,不想冒用户必须学习另一个库的风险,因此选择仅使用 Python 标志,这些标志通过遵循 Python 约定的 Abseil 定义。大多数研究人员应该非常熟悉该技术,其很容易扩展并插入其他流水线。
Abseil
https://abseil.io/docs/python/guides/flags
除了能够使用记录的命令运行我们的所有基线并获得相同的报告结果之外,我们还力求发布超参数调整结果和最终模型检查点,以进一步实现可重复性。目前,我们只针对糖尿病性视网膜病变基线完全开源上述内容,但我们会在运行基线的过程中继续上传更多结果。此外,我们提供的基线示例在硬件确定性方面完全可重复。
糖尿病性视网膜病变基线
https://github.com/google/uncertainty-baselines/tree/master/baselines/diabetic_retinopathy_detection
基线示例
https://github.com/google/uncertainty-baselines/blob/df320d4987deddf2e23a8a7cb45eda87d3c5f210/baselines/cifar/deterministic.py#L132
代码库中包含的所有基线都经过了广泛的超参数调整,我们希望研究人员可以轻松地重复使用这些基线,而无需进行昂贵的重新训练或重新调整。此外,我们希望避免流水线实现中影响基线比较的细微差异。
不确定性基线已被用于众多研究项目。如果您是一名研究人员,想要贡献其他方法或数据集,您可以在 GitHub 上创建一个议题,开启讨论!
众多研究项目
https://github.com/google/uncertainty-baselines#papers-using-uncertainty-baselines
衷心感谢各位共同开发的人员,以及提供指导和/或帮助审核本文的人员:Neil Band、Mark Collier、Josip Djolonga、Michael W. Dusenberry、Sebastian Farquhar、Angelos Filos、Marton Havasi、Rodolphe Jenatton、Ghassen Jerfel、Jeremiah Liu、Zelda Mariet、Jeremy Nixon、Shreyas Padhy、Jie Ren、Tim G. J. Rudner、Yeming Wen、Florian Wenzel、Kevin Murphy、D. Sculley、Balaji Lakshminarayanan、Jasper Snoek、Yarin Gal。
推荐阅读
辅助模块加速收敛,精度大幅提升!移动端实时的NanoDet-Plus来了!
机器学习算法工程师
一个用心的公众号
不要忘记“一键三连”哦~
分享
点赞
在看