更简单的掩码图像建模框架SimMIM介绍和PyTorch代码实现
数据派THU
共 18403字,需浏览 37分钟
· 2022-08-25
来源:DeepHub IMBA 本文约4000字,建议阅读10+分钟 本文中我们介绍了 SimMIM,这是一种受掩码建模启发的强大自监督学习(SSL)算法:一部分输入数据被掩码,模型的目标是重建被掩码的部分,即最小化重建损失。
图像中的掩码技术
PyTorch 实现
from torch import randn

# Stand-in for real patch embeddings: standard-normal noise of shape
# (batch_size, n_tokens, token_dim). Later steps replace it with the
# actual tokens produced by a ViT.
tokens = randn((batch_size, n_tokens, token_dim))
from torch import randn

tokens = randn((batch_size, n_tokens, token_dim))

# One random score per token; the n_masked_tokens highest-scoring positions
# in each row are the ones that get masked.
scores = randn(batch_size, n_tokens)

# Masking ratio: 0.5 of all tokens works well on average, but smaller patches
# usually benefit from a higher ratio (e.g. 0.5 suits 32x32 patches, while
# 16x16 patches may do better around 0.8).
n_masked_tokens = int(0.5 * n_tokens)

# topk along dim=1 works row by row and returns both the values and their
# positions; only the positions (.indices) are needed to decide what to mask.
indices_to_mask = scores.topk(k=n_masked_tokens, dim=1).indices
from torch import randn, zeros

tokens = randn((batch_size, n_tokens, token_dim))
n_masked_tokens = int(0.5 * n_tokens)
indices_to_mask = randn(batch_size, n_tokens).topk(k=n_masked_tokens, dim=1).indices

# Boolean mask of shape (batch_size, n_tokens). It starts as all zeros
# (False). For each row i, every column listed in indices_to_mask[i] is set
# to 1, e.g. indices_to_mask[3] == [2, 4, 7] sets bitmask[3][2],
# bitmask[3][4] and bitmask[3][7]. Casting to bool leaves True exactly at
# the positions to be masked.
bitmask = zeros(batch_size, n_tokens).scatter(
    dim=1,
    index=indices_to_mask,
    value=1,
).bool()
from torch import randn, zeros

# Real tokens from a timm ViT: its patch_embed module turns the image batch
# `input` (batch_size X n_channels X height X width) into tokens of shape
# batch_size X n_tokens X token_dim.
tokens = vit.patch_embed(input)

n_masked_tokens = int(0.5 * n_tokens)
random_scores = randn(batch_size, n_tokens)
top_scores = random_scores.topk(k=n_masked_tokens, dim=1)
indices_to_mask = top_scores.indices

# True wherever a token is to be masked, False everywhere else.
empty_mask = zeros(batch_size, n_tokens)
bitmask = empty_mask.scatter(dim=1, index=indices_to_mask, value=1).bool()
from torch import randn, zeros
from torch.nn import Parameter

tokens = vit.patch_embed(input)

# A single learnable vector of length token_dim that will stand in for every
# masked patch. Tiling it to batch_size X n_tokens X token_dim makes it line
# up element-for-element with `tokens`.
mask_token = Parameter(randn(token_dim))
mask_tokens = mask_token.repeat(batch_size, n_tokens, 1)

# Pick n_masked_tokens random positions per row and flag them in a boolean
# (batch_size, n_tokens) mask.
n_masked_tokens = int(0.5 * n_tokens)
indices_to_mask = randn(batch_size, n_tokens).topk(k=n_masked_tokens, dim=1).indices
bitmask = zeros(batch_size, n_tokens).scatter(dim=1, index=indices_to_mask, value=1).bool()
from torch import randn, zeros
from torch.nn import Parameter

tokens = vit.patch_embed(input)
mask_token = Parameter(randn(token_dim))
mask_tokens = mask_token.repeat(batch_size, n_tokens, 1)

n_masked_tokens = int(0.5 * n_tokens)
indices_to_mask = randn(batch_size, n_tokens).topk(k=n_masked_tokens, dim=1).indices

# The mask needs as many axes as tokens/mask_tokens, so unsqueeze(2) adds a
# trailing singleton axis: batch_size X n_tokens X 1, which broadcasts over
# token_dim.
bitmask = (
    zeros(batch_size, n_tokens)
    .scatter(dim=1, index=indices_to_mask, value=1)
    .bool()
    .unsqueeze(2)
)

# Blend the two tensors. (~bitmask) * tokens zeroes out every position that
# is to be masked, and bitmask * mask_tokens is zero everywhere except at
# those positions. Their sum keeps the original tokens and swaps in the mask
# token where bitmask is True.
tokens = (~bitmask) * tokens + bitmask * mask_tokens
然后加上位置嵌入:
from torch import randn, zeros
from torch.nn import Parameter

# Patch tokens plus a learnable mask token tiled to the same shape.
tokens = vit.patch_embed(input)
mask_token = Parameter(randn(token_dim))
mask_tokens = mask_token.repeat(batch_size, n_tokens, 1)

# Random per-sample selection of n_masked_tokens positions, stored as a
# batch_size X n_tokens X 1 boolean mask.
n_masked_tokens = int(0.5 * n_tokens)
indices_to_mask = randn(batch_size, n_tokens).topk(k=n_masked_tokens, dim=1).indices
bitmask = (
    zeros(batch_size, n_tokens)
    .scatter(dim=1, index=indices_to_mask, value=1)
    .bool()
    .unsqueeze(2)
)
tokens = (~bitmask) * tokens + bitmask * mask_tokens

# A timm ViT exposes its position embeddings as vit.pos_embed. Slicing with
# [:, 1:] drops the first vector, which belongs to the class token; that
# token is not used for self-supervised learning here.
tokens = tokens + vit.pos_embed[:, 1:]
随后将令牌输入 ViT 的 Transformer 块(vit.blocks),即可获得它们的编码表示。
from torch import randn, zeros
from torch.nn import Parameter

# 1. Embed patches and prepare the tiled learnable mask token.
tokens = vit.patch_embed(input)
mask_token = Parameter(randn(token_dim))
mask_tokens = mask_token.repeat(batch_size, n_tokens, 1)

# 2. Randomly choose which positions to mask (True = masked).
n_masked_tokens = int(0.5 * n_tokens)
indices_to_mask = randn(batch_size, n_tokens).topk(k=n_masked_tokens, dim=1).indices
bitmask = (
    zeros(batch_size, n_tokens)
    .scatter(dim=1, index=indices_to_mask, value=1)
    .bool()
    .unsqueeze(2)
)

# 3. Substitute the mask token, then add position embeddings (skipping the
#    class-token slot at index 0).
tokens = (~bitmask) * tokens + bitmask * mask_tokens
tokens = tokens + vit.pos_embed[:, 1:]

# 4. Run the transformer blocks to obtain the encoded representation.
encoded = vit.blocks(tokens)
接下来从编码结果中取出被掩码令牌对应的部分,再将它们送入一个线性层,重建出原始像素值。
from torch import (
    randn,
    zeros,
)
from torch.nn import (
    Linear,
    Parameter,
)

# Mask tokens, encode with the ViT, then decode the masked tokens back into
# raw pixel values with a single linear layer.
tokens = vit.patch_embed(input)
mask_token = Parameter(randn(token_dim))
mask_tokens = mask_token.repeat(batch_size, n_tokens, 1)
indices_to_mask = randn(batch_size, n_tokens)
n_masked_tokens = int(0.5*n_tokens)
indices_to_mask = indices_to_mask.topk(
    k=n_masked_tokens,
    dim=1,
)
indices_to_mask = indices_to_mask.indices
bitmask = zeros(batch_size, n_tokens)
bitmask = bitmask.scatter(
    dim=1,
    index=indices_to_mask,
    value=1,
)
bitmask = bitmask.bool()
bitmask = bitmask.unsqueeze(2)
tokens = (~bitmask)*tokens + bitmask*mask_tokens
tokens = tokens+vit.pos_embed[:, 1:]
encoded = vit.blocks(tokens)

# To index encoded with bitmask, the axis that was added must be removed.
# This reverts bitmask to a size of batch_size X n_tokens.
bitmask = bitmask.squeeze(2)
# Boolean-mask indexing gathers the masked positions of every sample into
# one flat list, so the shape is (batch_size*n_masked_tokens) X token_dim.
masked_tokens_encoded = encoded[bitmask]

# timm's PatchEmbed stores patch_size as a (height, width) tuple. Assigning
# that tuple to both names would make 3*patch_height*patch_width raise a
# TypeError, so unpack it (and accept a plain int as well).
patch_size = vit.patch_embed.patch_size
if isinstance(patch_size, int):
    patch_size = (patch_size, patch_size)
patch_height, patch_width = patch_size

# The decoder maps each encoded token back to the raw pixels of its patch:
# 3 channels times patch_height*patch_width pixels, i.e. the patch's shape
# before it was tokenized.
decoder_out_dim = 3*patch_height*patch_width
decoder = Linear(
    in_features=token_dim,
    out_features=decoder_out_dim,
)
# The reconstructed pixels, of shape
# (batch_size*n_masked_tokens) X 3*patch_height*patch_width
masked_patches_reconstructed = decoder(masked_tokens_encoded)
from einops import (
    rearrange,
)
from torch import (
    randn,
    zeros,
)
from torch.nn import (
    Linear,
    Parameter,
)

tokens = vit.patch_embed(input)
# Built from the names imported above. `torch` itself is not imported in
# this snippet and there is no `self` outside a class.
mask_token = Parameter(randn(token_dim))
mask_tokens = mask_token.repeat(batch_size, n_tokens, 1)
indices_to_mask = randn(batch_size, n_tokens)
n_masked_tokens = int(0.5*n_tokens)
indices_to_mask = indices_to_mask.topk(
    k=n_masked_tokens,
    dim=1,
)
indices_to_mask = indices_to_mask.indices
bitmask = zeros(batch_size, n_tokens)
bitmask = bitmask.scatter(
    dim=1,
    index=indices_to_mask,
    value=1,
)
bitmask = bitmask.bool()
bitmask = bitmask.unsqueeze(2)
tokens = (~bitmask)*tokens + bitmask*mask_tokens
tokens = tokens+vit.pos_embed[:, 1:]
encoded = vit.blocks(tokens)
bitmask = bitmask.squeeze(2)
# Shape: (batch_size*n_masked_tokens) X token_dim
masked_tokens_encoded = encoded[bitmask]

# timm's PatchEmbed stores patch_size as a (height, width) tuple; unpack it
# (accepting a plain int too) so both the arithmetic and einops get ints.
patch_size = vit.patch_embed.patch_size
if isinstance(patch_size, int):
    patch_size = (patch_size, patch_size)
patch_height, patch_width = patch_size

decoder_out_dim = 3*patch_height*patch_width
decoder = Linear(
    in_features=token_dim,
    out_features=decoder_out_dim,
)
# Shape: (batch_size*n_masked_tokens) X 3*patch_height*patch_width
masked_patches_reconstructed = decoder(masked_tokens_encoded)

# The pattern tells einops how to rearrange the tensor: 'shape_before -> shape_after'.
# The target layout is batch_size X n_tokens X n_channels*patch_height*patch_width.
# einops requires every axis to appear on both sides, so height and width are
# split as height = n_patches_height * patch_height and
# width = n_patches_width * patch_width. Then n_tokens is
# n_patches_height * n_patches_width. "(x y)" denotes a product of axes.
pattern = (
    'batch_size n_channels (n_patches_height patch_height) (n_patches_width patch_width) -> '
    'batch_size (n_patches_height n_patches_width) (n_channels patch_height patch_width)'
)
# einops cannot infer patch_height and patch_width, so they are passed
# explicitly. patches is of shape batch_size X n_tokens X 3*patch_height*patch_width.
patches = rearrange(
    tensor=input,
    pattern=pattern,
    patch_height=patch_height,
    patch_width=patch_width,
)
然后取出与 masked_patches_reconstructed 相对应的原始 patch(即被掩码的那些 patch),作为重建目标:
from einops import (
    rearrange,
)
from torch import (
    randn,
    zeros,
)
from torch.nn import (
    Linear,
    Parameter,
)

tokens = vit.patch_embed(input)
# Built from the names imported above; `torch` is not imported here and
# there is no `self` outside a class.
mask_token = Parameter(randn(token_dim))
mask_tokens = mask_token.repeat(batch_size, n_tokens, 1)
indices_to_mask = randn(batch_size, n_tokens)
n_masked_tokens = int(0.5*n_tokens)
indices_to_mask = indices_to_mask.topk(
    k=n_masked_tokens,
    dim=1,
)
indices_to_mask = indices_to_mask.indices
bitmask = zeros(batch_size, n_tokens)
bitmask = bitmask.scatter(
    dim=1,
    index=indices_to_mask,
    value=1,
)
bitmask = bitmask.bool()
bitmask = bitmask.unsqueeze(2)
tokens = (~bitmask)*tokens + bitmask*mask_tokens
tokens = tokens+vit.pos_embed[:, 1:]
encoded = vit.blocks(tokens)
bitmask = bitmask.squeeze(2)
masked_tokens_encoded = encoded[bitmask]

# timm's PatchEmbed stores patch_size as a (height, width) tuple.
patch_size = vit.patch_embed.patch_size
if isinstance(patch_size, int):
    patch_size = (patch_size, patch_size)
patch_height, patch_width = patch_size

decoder_out_dim = 3*patch_height*patch_width
decoder = Linear(
    in_features=token_dim,
    out_features=decoder_out_dim,
)
masked_patches_reconstructed = decoder(masked_tokens_encoded)
pattern = (
    'batch_size n_channels (n_patches_height patch_height) (n_patches_width patch_width) -> '
    'batch_size (n_patches_height n_patches_width) (n_channels patch_height patch_width)'
)
# Only `rearrange` is imported (not the `einops` module itself).
patches = rearrange(
    tensor=input,
    pattern=pattern,
    patch_height=patch_height,
    patch_width=patch_width,
)
# Select the original pixels of the masked patches, exactly as
# masked_tokens_encoded was selected. The shape is
# (batch_size*n_masked_tokens) X 3*patch_height*patch_width, matching
# masked_patches_reconstructed.
masked_patches_original = patches[bitmask]
最后计算重建损失:
from einops import (
    rearrange,
)
from torch import (
    randn,
    zeros,
)
from torch.nn import (
    Linear,
    Parameter,
)
from torch.nn.functional import (
    l1_loss,
)

tokens = vit.patch_embed(input)
mask_token = Parameter(randn(token_dim))
mask_tokens = mask_token.repeat(batch_size, n_tokens, 1)
indices_to_mask = randn(batch_size, n_tokens)
n_masked_tokens = int(0.5*n_tokens)
indices_to_mask = indices_to_mask.topk(
    k=n_masked_tokens,
    dim=1,
)
indices_to_mask = indices_to_mask.indices
bitmask = zeros(batch_size, n_tokens)
bitmask = bitmask.scatter(
    dim=1,
    index=indices_to_mask,
    value=1,
)
bitmask = bitmask.bool()
bitmask = bitmask.unsqueeze(2)
tokens = (~bitmask)*tokens + bitmask*mask_tokens
tokens = tokens+vit.pos_embed[:, 1:]
encoded = vit.blocks(tokens)
bitmask = bitmask.squeeze(2)
masked_tokens_encoded = encoded[bitmask]

# timm's PatchEmbed stores patch_size as a (height, width) tuple.
patch_size = vit.patch_embed.patch_size
if isinstance(patch_size, int):
    patch_size = (patch_size, patch_size)
patch_height, patch_width = patch_size

decoder_out_dim = 3*patch_height*patch_width
decoder = Linear(
    in_features=token_dim,
    out_features=decoder_out_dim,
)
masked_patches_reconstructed = decoder(masked_tokens_encoded)
pattern = (
    'batch_size n_channels (n_patches_height patch_height) (n_patches_width patch_width) -> '
    'batch_size (n_patches_height n_patches_width) (n_channels patch_height patch_width)'
)
patches = rearrange(
    tensor=input,
    pattern=pattern,
    patch_height=patch_height,
    patch_width=patch_width,
)
masked_patches_original = patches[bitmask]

# SimMIM's objective is the L1 distance between predicted and original pixel
# values, averaged over all masked pixels. l1_loss's default reduction='mean'
# already divides the summed error by that element count. An extra division
# by n_masked_tokens would normalize twice and only rescale the loss.
loss = l1_loss(
    input=masked_patches_reconstructed,
    target=masked_patches_original,
)
把上面的代码封装成一个类(SimMIM)并增加一些辅助函数,这里就不贴出了,有兴趣的读者可以查看文末的源代码。使用方式如下:
from timm import (
    create_model,
)
from torch.nn.functional import (
    l1_loss,
)
from torch.optim import (
    AdamW,
)

# Usage example: pretrain a timm ViT with the SimMIM wrapper (defined in the
# linked source code, not shown here). num_classes=0 drops the classifier
# head, since only the backbone is trained.
vit = create_model(
    'vit_small_patch32_224',
    num_classes=0,
)
simmim = SimMIM(
    vit=vit,
    masking_ratio=0.5,
)
optimizer = AdamW(
    params=simmim.parameters(),
    lr=1e-4,
    weight_decay=5e-2,
)

for _ in range(n_epochs):
    # `images` rather than `input`, to avoid shadowing the builtin.
    for images in dataloader:
        # Assumes simmim(...) returns
        # (n_masked_tokens, reconstructed_patches, original_patches).
        # The count is not needed because l1_loss already averages over
        # every masked pixel.
        _, masked_patches_reconstructed, masked_patches_original = simmim(images)
        loss = l1_loss(
            input=masked_patches_reconstructed,
            target=masked_patches_original,
        )
        loss.backward()
        # Optimizers apply gradients via step(); they have no backward().
        optimizer.step()
        optimizer.zero_grad()
总结
引用:
A Simple Framework for Contrastive Learning of Visual Representations
https://arxiv.org/abs/2002.05709
Exploring Simple Siamese Representation Learning
https://arxiv.org/abs/2011.10566
SimMIM: A Simple Framework for Masked Image Modeling
https://arxiv.org/abs/2111.09886
本文代码:
https://github.com/BobMcDear/PyTorch-SimMIM
编辑:王菁
校对:林亦霖
评论
真高!比亚迪员工爆料比亚迪在越南的薪资水平:基本工资480万,全勤奖35万,交通补助20万,餐补110万,每周6天,每天10小时
上一篇:某大公司为逼迫员工离职,竟然把他的工位安排到厕所旁,没想到他直接开始记录领导的如厕时间,还发到公司大群...对此,你怎么看?--完--PS:欢迎在留言区留下你的观点,一起讨论提高。如果今天的文章让你有新的启发,欢迎转发分享给更多人。全文完,感谢你的耐心阅读。如果你还想看到我的文章,请一定给本
开发者全社区
0
某大公司为逼迫员工离职,竟然把他的工位安排到厕所旁,没想到他直接开始记录领导的如厕时间,还发到公司大群...
上一篇:字节的跳动职级与薪资(2024年)我们与公司间的合作,宛如两艘船只在茫茫大海上相互依靠,共同抵御风浪,携手驶向成功的彼岸。然而,当航向开始产生分歧,或是波涛汹涌的风浪改变了我们的初衷,我们或许应当冷静地选择和平分手,而非在风雨中硬撑。最近,一位网友的遭遇引起了广大职场人的关注和热议。这位网友
开发者全社区
0
金融研究 | 使用Python测量关键审计事项的「信息含量」
Tips: 公众号推送后内容只能更改一次,且只能改20字符。如果内容出问题,或者想更新内容, 只能重复推送。为了更好的阅读体验,建议阅读本文博客版, 链接地址https://textdata.cn/blog/2023-01-13-information-content-of-critical-aud
大邓和他的Python
0
我看阿里的年终奖总算发了!
到4月底了,这两天看朋友圈,发现阿里的年终奖终于发了,问了问老同学,也从网上检索了不少信息,基本搞清楚了阿里今年的年终奖情况。近来来阿里一些集团对绩效等级做了较大的调整,以前的旧绩效系统中,绩效分为3.25、3.5、3.75、4和5五个等级,其中4和5是较高绩效等级,较少见。而且之前3.5绩效内部划
公子龙
0
CVPR 2024|大视觉模型的开山之作!无需任何语言数据即可打造大视觉模型
↑ 点击蓝字 关注极市平台作者丨科技猛兽编辑丨极市平台极市导读 本文提出一种序列建模 (sequential modeling) 的方法,不使用任何语言数据,训练大视觉模型。>>加入极市CV技术交流群,走在计算机视觉的最前沿本文目录1 序列建模打造大视觉模型(来自 U
极市平台
1
金融研究(更新) | 使用Python构建关键审计事项的「信息含量」
Tips: 公众号推送后内容只能更改一次,且只能改20字符。如果内容出问题,或者想更新内容, 只能重复推送。为了更好的阅读体验,建议阅读本文博客版, 链接地址https://textdata.cn/blog/2023-01-13-information-content-of-critical-aud
大邓和他的Python
0
字节的跳动职级与薪资(2024年)
上一篇:阿里公布年终奖,P7, 3.5+,22W年终奖,还有35W长期现金激励,真香字节跳动自2012年3月成立以来,已经迅速成长为一个全球性的科技公司。其产品和服务已经遍布全球150多个国家与地区,并且支持超过75种不同的语言。在字节跳动的官方网站上,列出了一系列引人注目的产品和服务,包括但不限于
开发者全社区
0
盘点Lombok的几个骚操作,你绝对没用过!
👉 欢迎加入小哈的星球 ,你将获得: 专属的项目实战 / Java 学习路线 / 一对一提问 / 学习打卡 / 赠书福利全栈前后端分离博客项目 2.0 版本完结啦, 演示链接:http://116.62.199.48/ ,新项目正在酝酿中
小哈学Java
0