深入探讨:简单的label smooth为什么能够涨点?

极市平台

共 3004字,需浏览 7分钟

 ·

2021-01-27 02:49

↑ 点击蓝字 关注极市平台

作者丨 史开杰
来源丨PandaCV
编辑丨极市平台

极市导读

 

label smooth(标签平滑)作为一种简单的训练trick,能通过很少的代价(只需要修改target的编码方式),即可获得准确率的提升。本文想要通过一些简单的公式推导,理解target使用label smooth表示会比单纯的使用one-hot好在哪里。 >>加入极市CV技术交流群,走在计算机视觉的最前沿

感谢深度眸在本文中对我的帮助。

前言

本文的开头先引用下《深度学习》中文版,第四章开头的一段话。

机器学习的算法通常需要大量的数值计算。这通常是指通过迭代过程更新解的估计值来解决数学问题的算法,而不是通过解析过程推导出公式来提供正确解的方法。常见的操作包括优化(找到最小化或者最大化函数值的参数)和线性方程组的求解。对数字计算机来说实数无法在有限内存下精确表示,因此仅仅是计算涉及实数的函数也是困难的。

这里就涉及到两个求解数学问题的方法:

  1. 迭代更新解的估计值,如通过二分法牛顿迭代法求开方的过程。
  2. 解析过程推导公式,如我们考试中的那些需要大量计算的数学题,一般最后会得出一个解析解。

在整日学习深度学习之后,我们有的时候也需要用解析解,即公式推导求解一些深度学习的问题。

one-hot解析解推导

神经网络的输出称为logits,简记为,经过softmax之后转化为和为1的概率形式,记为,真值target记为为分类类别的数量。本文所有讨论的内容是在导数等于0的情况下(解析解的情况下),为多少(神经网络的输出是多少)。当损失函数为交叉熵且target的编码和为1时, 导数则为(求导过程文章:https://zhuanlan.zhihu.com/p/343988823 ), 假设总共有个类. 可以有如下的公式.

令公式(1)的导数等于0, 可以得到公式(2), 记真值下标为.

是通过推导出来的, 则

通过公式(3.1)可得

通过公式(3.2)也可以得上面的结果。所以targetone-hot编码,损失函数为交叉熵的情况下。解析解是

所以通过上述推导可以得到:最优的情况下,在one-hot编码和交叉熵的损失函数下,错误类的logit值要是负无穷,正确类要是一个常数。这种最优的情况一般是不能达到的,且会远大于. 在文章《Rethinking the inception architecture for computer vision》里面认为如果远大于,会出现两个不好的性质

  1. 导致过拟合,将所有的概率都赋给了真值,会导致泛化能力下降
  2. 鼓励真值对应的logit远大于其他值的logit,但是导数是有界的,也就是数值不会很大,想要达成远大于的效果,要更新很多很多次。

个人认为:logit要是负无穷,损失才会变为0,神经网络很难会有输出负无穷的情况(权重衰减还会约束着神经网络的参数)

label smooth解析解推导

label smooth是在《Rethinking the inception architecture for computer vision》里面提出来的。我觉的作者的想法应该是这样的:蒸馏改变了学习的真值,能获得更好的结果,但是它需要准确率更高的教师网络;如果我现在想要训练出一个准确率最高的模型,那么是没有网络能给我知识的,所以就通过label smooth学习一种简单的知识。

label smooth 学习的编码形式如公式(4)所示,其中是预定义好的一个超参数,一般取值0.1,是该分类问题的类别个数

令公式(4)导数等于0,可得到公式(5.1)和(5.2)。类似于公式(1)的求导,但是要注意target编码的和要为1( https://zhuanlan.zhihu.com/p/343988823 里面有解释).

因为正确的类只有1个;错误的类有K-1个,且解析解的情况下,错误类的概率是相等的。所以公式(5.1)可以推导为公式(6):

把公式(6)的放到右边,两边再取下对数可得公式(7)

我们通过公式(5.2)也能推出相同的解。右边的公式分子分母颠倒一下可得公式(8)

因为错误类的值是相等的,所以,则可得公式(9)

记为, 则可得公式(10),即导数等于0的情况下,logit的取值。

和论文《bag of tricks for image classification with convolutional neural networks》中,给出的结果是一样的(文章里面交叉熵的好像写反了) 带入label smooth定义的公式验算一下则是

所以,在损失函数为交叉熵的情况下,如果我们使用label-smooth编码,错误类的logit不会要求是负无穷。且错误类和正确类的logit值有一定大小误差的情况下,loss就会很小很小。

label smooth中的gap

论文《bag of tricks for image classification with convolutional neural networks》还画出了gap图,此处的gap就是导数等于0的情况下,之间的数值误差

gap就是,其中K是分类的类别数,(eps)是label smooth的超参数。假设取0.5且是1000分类,那么

意思是,正确类和错误类的误差等于7就够了,损失不想要继续更新参数让他们的误差越来越大。实际代码的过程中,一般取即可。

总结

one-hot的编码方式需要错误类的logit趋向于负无穷,这样会导致正确类和错误类的logit输出误差很大,网络的泛化能力不强。并且因为网络训练时会有一些正则化的存在,logit的输出很难是负无穷。label-smooth的编码方式只要正确类和错误类有一定的数值误差即可,这个取决于分类的类别数量和。网络极使在正则化的情况下也比one-hot容易学习到最优情况。

代码

这里推荐https://github.com/CoinCheung/pytorch-loss/blob/master/label_smooth.py,大家需要注意的是这个代码的编码表示值和好像不为1.


推荐阅读




添加极市小助手微信(ID : cvmart2),备注:姓名-学校/公司-研究方向-城市(如:小极-北大-目标检测-深圳),即可申请加入极市目标检测/图像分割/工业检测/人脸/医学影像/3D/SLAM/自动驾驶/超分辨率/姿态估计/ReID/GAN/图像增强/OCR/视频理解等技术交流群:月大咖直播分享、真实项目需求对接、求职内推、算法竞赛、干货资讯汇总、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企视觉开发者互动交流~
△长按添加极市小助手

△长按关注极市平台,获取最新CV干货

觉得有用麻烦给个在看啦~  
浏览 80
点赞
评论
收藏
分享

手机扫一扫分享

举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

举报