记MoCo中一个关于CrossEntropyLoss的计算问题

机器学习与生成对抗网络

共 1780字,需浏览 4分钟

 ·

2022-05-10 18:10

来源:知乎—葱鸭Fighting
地址:https://zhuanlan.zhihu.com/p/490742250

编辑:人工智能前沿讲习

昨晚在看对比学习算法 MoCo[1] 的源代码时,中间有一个涉及Pytorch中CrossEntropyLoss的计算问题困扰了我较长时间,因此记录下来加深一下印象:
问题描述:
MoCo 中 contrastive loss 的组成是由query正样本对相似度(代码图中的 l_pos),以及query与一系列queue中的负样本相似度(代码图中的 l_neg)共同构成的:
MoCo contrastive loss
在经过拼接后,logits 为一个N*(1+K) 的矩阵,矩阵的第一列为正样本对间的相似度,而其他剩余K列为正负样本对之间的相似度,因此我会直观地认为,在对应到标签计算CrossEntropyLoss时,第一列的标签应该为1,而其余K列的标签都为0。但在算法实现的时候,可以明显地看到此处的labels为一个值全为0的张量,这是为什么?这个labels张量不应该是第一个元素为1,其他元素都为0吗?
MoCo 部分源码
我在反复品读GitHub issue中其他人关于这个问题的解答(链接如下),以及pytorch文档中CrossEntropyLoss的计算方法后,总算意识到自己之前理解的错误所在:labels中的0元素并不是指代正负样本对,而是告诉CrossEntropyLoss输入第一维的标签为1(ground truth),也就是第0维指代的是正样本对。
https://github.com/facebookresearch/moco/issues/24#issuecomment-625508418
举例说明
上面这句话理解起来可能仍然有点抽象,因此举个简单例子说明一下:
  • logits矩阵:
logits = [[0.5, 0.2, 0.2, 0.1]
[0.6, 0.1, 0.1, 0.2]]
矩阵的行表示不同的数据样本;第一列是正样本对间的相似度,其他列表示正样本与负样本之间的相似度。
  • labels 张量:
labels = [0, 0]
注意这里labels的长度,是与logits的第一维也就是样本数量是一致的。labels中的元素实际上意味着在进行CrossEntropyLoss计算时,标签为1的ground truth的索引是多少,以logits中第一个样本为例的话,此时0号元素为ground truth,即数值0.5对应的标签值为1,其他数值对应的标签值为0,在进行CrossEntropyLoss计算时,会由 logits [0.5, 0.2, 0.2, 0.1] 与 label [1, 0, 0, 0] 来计算loss的数值。
之前会有错误理解的原因在于对Pytorch中CrossEntropyLoss的计算方法理解还不够深,在弄明白它的计算方法后自然就不会产生这样的疑问啦。
[1] He, K., Fan, H., Wu, Y., Xie, S., & Girshick, R. (2020). Momentum contrast for unsupervised visual representation learning. InProceedings of the IEEE/CVF conference on computer vision and pattern recognition(pp. 9729-9738).


猜您喜欢:

 戳我,查看GAN的系列专辑~!
一顿午饭外卖,成为CV视觉前沿弄潮儿!
CVPR 2022 | 25+方向、最新50篇GAN论文
 ICCV 2021 | 35个主题GAN论文汇总
超110篇!CVPR 2021最全GAN论文梳理
超100篇!CVPR 2020最全GAN论文梳理


拆解组新的GAN:解耦表征MixNMatch

StarGAN第2版:多域多样性图像生成


附下载 | 《可解释的机器学习》中文版

附下载 |《TensorFlow 2.0 深度学习算法实战》

附下载 |《计算机视觉中的数学方法》分享


《基于深度学习的表面缺陷检测方法综述》

《零样本图像分类综述: 十年进展》

《基于深度神经网络的少样本学习综述》



浏览 51
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报