使用PyTorch进行小样本学习的图像分类
数据派THU
共 7219字,需浏览 15分钟
· 2022-11-18
来源:DeepHub IMBA
什么是小样本学习?
N-Shot Learning (NSL) Few-Shot Learning ( FSL ) One-Shot Learning (OSL) Zero-Shot Learning (ZSL)
小样本学习方法
小样本学习图像分类算法
元学习者在每个分集(episode)开始时创建自己的副本C, C 在这一分集上进行训练(在 base-model 的帮助下), C 对查询集进行预测, 从这些预测中计算出的损失用于更新 C, 这种情况一直持续到完成所有分集的训练。
来自支持集和查询集的每个图像都被馈送到一个 CNN,该 CNN 为它们输出特征的嵌入 查询图像使用支持集训练的模型得到嵌入特征的余弦距离,通过 softmax 进行分类 分类结果的交叉熵损失通过 CNN 反向传播更新特征嵌入模型
使用 Open-AI Clip 进行零样本学习
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.gitimport numpy as np
import torch
from pkg_resources import packaging
print("Torch version:", torch.__version__)
import clipclip.available_models() # it will list the names of available CLIP modelsmodel, preprocess = clip.load("ViT-B/32")
model.cuda().eval()
input_resolution = model.visual.input_resolution
context_length = model.context_length
vocab_size = model.vocab_size
print("Model parameters:", f"{np.sum([int(np.prod(p.shape)) for p in model.parameters()]):,}")
print("Input resolution:", input_resolution)
print("Context length:", context_length)
print("Vocab size:", vocab_size)
import os
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
from collections import OrderedDict
import torch
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
# images in skimage to use and their textual descriptions
descriptions = {
"page": "a page of text about segmentation",
"chelsea": "a facial photo of a tabby cat",
"astronaut": "a portrait of an astronaut with the American flag",
"rocket": "a rocket standing on a launchpad",
"motorcycle_right": "a red motorcycle standing in a garage",
"camera": "a person looking at a camera on a tripod",
"horse": "a black-and-white silhouette of a horse",
"coffee": "a cup of coffee on a saucer"
}original_images = []
images = []
texts = []
plt.figure(figsize=(16, 5))
for filename in [filename for filename in os.listdir(skimage.data_dir) if filename.endswith(".png") or filename.endswith(".jpg")]:
name = os.path.splitext(filename)[0]
if name not in descriptions:
continue
image = Image.open(os.path.join(skimage.data_dir, filename)).convert("RGB")
plt.subplot(2, 4, len(images) + 1)
plt.imshow(image)
plt.title(f"{filename}\n{descriptions[name]}")
plt.xticks([])
plt.yticks([])
original_images.append(image)
images.append(preprocess(image))
texts.append(descriptions[name])
plt.tight_layout()
image_input = torch.tensor(np.stack(images)).cuda()
text_tokens = clip.tokenize(["This is " + desc for desc in texts]).cuda()
with torch.no_grad():
image_features = model.encode_image(image_input).float()
text_features = model.encode_text(text_tokens).float()
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = text_features.cpu().numpy() @ image_features.cpu().numpy().T
count = len(descriptions)
plt.figure(figsize=(20, 14))
plt.imshow(similarity, vmin=0.1, vmax=0.3)
# plt.colorbar()
plt.yticks(range(count), texts, fontsize=18)
plt.xticks([])
for i, image in enumerate(original_images):
plt.imshow(image, extent=(i - 0.5, i + 0.5, -1.6, -0.6), origin="lower")
for x in range(similarity.shape[1]):
for y in range(similarity.shape[0]):
plt.text(x, y, f"{similarity[y, x]:.2f}", ha="center", va="center", size=12)
for side in ["left", "top", "right", "bottom"]:
plt.gca().spines[side].set_visible(False)
plt.xlim([-0.5, count - 0.5])
plt.ylim([count + 0.5, -2])
plt.title("Cosine similarity between text and image features", size=20)
from torchvision.datasets import CIFAR100
cifar100 = CIFAR100(os.path.expanduser("~/.cache"), transform=preprocess, download=True)
text_descriptions = [f"This is a photo of a {label}" for label in cifar100.classes]
text_tokens = clip.tokenize(text_descriptions).cuda()
with torch.no_grad():
text_features = model.encode_text(text_tokens).float()
text_features /= text_features.norm(dim=-1, keepdim=True)
text_probs = (100.0 * image_features @ text_features.T).softmax(dim=-1)
top_probs, top_labels = text_probs.cpu().topk(5, dim=-1)
plt.figure(figsize=(16, 16))
for i, image in enumerate(original_images):
plt.subplot(4, 4, 2 * i + 1)
plt.imshow(image)
plt.axis("off")
plt.subplot(4, 4, 2 * i + 2)
y = np.arange(top_probs.shape[-1])
plt.grid()
plt.barh(y, top_probs[i])
plt.gca().invert_yaxis()
plt.gca().set_axisbelow(True)
plt.yticks(y, [cifar100.classes[index] for index in top_labels[i].numpy()])
plt.xlabel("probability")
plt.subplots_adjust(wspace=0.5)
plt.show()
评论
真高!比亚迪员工爆料比亚迪在越南的薪资水平:基本工资480万,全勤奖35万,交通补助20万,餐补110万,每周6天,每天10小时
上一篇:某大公司为逼迫员工离职,竟然把他的工位安排到厕所旁,没想到他直接开始记录领导的如厕时间,还发到公司大群...对此,你怎么看?--完--PS:欢迎在留言区留下你的观点,一起讨论提高。如果今天的文章让你有新的启发,欢迎转发分享给更多人。全文完,感谢你的耐心阅读。如果你还想看到我的文章,请一定给本
开发者全社区
0
太敢穿了!透视纱裙!性感火辣的身材
绝了呀今天的厂花:吴宣仪1995年1月26日,吴宣仪出生于海南省海口市,中国内地流行乐女歌手、影视演员。2016年2月,吴宣仪随宇宙少女发行首张迷你专辑正式出道。2018年4月,她参加《创造101》综艺选秀,获得第二名,成功加入火箭少女101组合。吴宣仪的颜值一直备受称赞,她的五官立体精致,皮肤白皙
逆锋起笔
0
某大公司为逼迫员工离职,竟然把他的工位安排到厕所旁,没想到他直接开始记录领导的如厕时间,还发到公司大群...
上一篇:字节的跳动职级与薪资(2024年)我们与公司间的合作,宛如两艘船只在茫茫大海上相互依靠,共同抵御风浪,携手驶向成功的彼岸。然而,当航向开始产生分歧,或是波涛汹涌的风浪改变了我们的初衷,我们或许应当冷静地选择和平分手,而非在风雨中硬撑。最近,一位网友的遭遇引起了广大职场人的关注和热议。这位网友
开发者全社区
0
金融研究 | 使用Python测量关键审计事项的「信息含量」
Tips: 公众号推送后内容只能更改一次,且只能改20字符。如果内容出问题,或者想更新内容, 只能重复推送。为了更好的阅读体验,建议阅读本文博客版, 链接地址https://textdata.cn/blog/2023-01-13-information-content-of-critical-aud
大邓和他的Python
0
我看阿里的年终奖总算发了!
到4月底了,这两天看朋友圈,发现阿里的年终奖终于发了,问了问老同学,也从网上检索了不少信息,基本搞清楚了阿里今年的年终奖情况。近来来阿里一些集团对绩效等级做了较大的调整,以前的旧绩效系统中,绩效分为3.25、3.5、3.75、4和5五个等级,其中4和5是较高绩效等级,较少见。而且之前3.5绩效内部划
公子龙
0
CVPR 2024|大视觉模型的开山之作!无需任何语言数据即可打造大视觉模型
↑ 点击蓝字 关注极市平台作者丨科技猛兽编辑丨极市平台极市导读 本文提出一种序列建模 (sequential modeling) 的方法,不使用任何语言数据,训练大视觉模型。>>加入极市CV技术交流群,走在计算机视觉的最前沿本文目录1 序列建模打造大视觉模型(来自 U
极市平台
1
金融研究(更新) | 使用Python构建关键审计事项的「信息含量」
Tips: 公众号推送后内容只能更改一次,且只能改20字符。如果内容出问题,或者想更新内容, 只能重复推送。为了更好的阅读体验,建议阅读本文博客版, 链接地址https://textdata.cn/blog/2023-01-13-information-content-of-critical-aud
大邓和他的Python
0
词向量(更新) | 使用MD&A2001-2022语料训练Word2Vec模型
buTips: 公众号推送后内容只能更改一次,且只能改20字符。 如果内容出问题,或者想更新内容, 只能重复推送。 为了更好的阅读体验,建议阅读本文博客版, 链接地址https://textdata.cn/blog/2023-03-24-load-w2v-and-expand-your-concpe
大邓和他的Python
0