轻松学pytorch – 使用多标签损失函数训练卷积网络

共 13275字,需浏览 27分钟

 ·

2022-06-24 10:55

点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

大家好,我还在坚持继续写,如果我没有记错的话,这个是系列文章的第十五篇,pytorch中有很多非常方便使用的损失函数,本文就演示了如何通过多标签损失函数训练验证码识别网络,实现验证码识别。 


数据集


这个数据是来自Kaggle上的一个验证码识别例子,作者采用的是迁移学习,基于ResNet18做到的训练。

https://www.kaggle.com/anjalichoudhary12/captcha-with-pytorch

这个数据集总计有1070张验证码图像,我把其中的1040张用作训练,30张作为测试,使用pytorch自定义了一个数据集类,代码如下:

 1import torch
2import numpy as np
3from torch.utils.data import Dataset, DataLoader
4from torchvision import transforms
5import os
6import cv2 as cv
7
8NUMBER = ['0''1''2''3''4''5''6''7''8''9']
9ALPHABET = ['a''b''c''d''e''f''g''h''i''j''k''l''m''n''o''p''q''r''s''t''u''v''w''x''y''z']
10ALL_CHAR_SET = NUMBER + ALPHABET
11ALL_CHAR_SET_LEN = len(ALL_CHAR_SET)
12MAX_CAPTCHA = 5
13
14
15def output_nums():
16    return MAX_CAPTCHA * ALL_CHAR_SET_LEN
17
18
19def encode(a):
20    onehot = [0]*ALL_CHAR_SET_LEN
21    idx = ALL_CHAR_SET.index(a)
22    onehot[idx] += 1
23    return onehot
24
25
26class CapchaDataset(Dataset):
27    def __init__(self, root_dir):
28        self.transform = transforms.Compose([transforms.ToTensor()])
29        img_files = os.listdir(root_dir)
30        self.txt_labels = []
31        self.encodes = []
32        self.images = []
33        for file_name in img_files:
34            label = file_name[:-4]
35            label_oh = []
36            for i in label:
37                label_oh += encode(i)
38            self.images.append(os.path.join(root_dir, file_name))
39            self.encodes.append(np.array(label_oh))
40            self.txt_labels.append(label)
41
42    def __len__(self):
43        return len(self.images)
44
45    def num_of_samples(self):
46        return len(self.images)
47
48    def __getitem__(self, idx):
49        if torch.is_tensor(idx):
50            idx = idx.tolist()
51            image_path = self.images[idx]
52        else:
53            image_path = self.images[idx]
54        img = cv.imread(image_path)  # BGR order
55        h, w, c = img.shape
56        # rescale
57        img = cv.resize(img, (12832))
58        img = (np.float32(img) /255.0 - 0.5) / 0.5
59        # H, W C to C, H, W
60        img = img.transpose((201))
61        sample = {'image': torch.from_numpy(img), 'encode': self.encodes[idx], 'label': self.txt_labels[idx]}
62        return sample

 

模型实现

 

基于ResNet的block结构,我实现了一个比较简单的残差网络,最后加一个全连接层输出多个标签。验证码是有5个字符的,每个字符的是小写26个字母加上0~9十个数字,总计36个类别,所以5个字符就有5x36=180个输出,其中每个字符是独热编码,这个可以从数据集类的实现看到。模型的输入与输出格式:

输入:NCHW=Nx3x32x128
卷积层最终输出:NCHW=Nx256x1x4
全连接层:Nx(256x4)
最终输出层:Nx180

代码实现如下:


 1class CapchaResNet(torch.nn.Module):
2    def __init__(self):
3        super(CapchaResNet, self).__init__()
4        self.cnn_layers = torch.nn.Sequential(
5            # 卷积层 (128x32x3)
6            ResidualBlock(3321),
7            ResidualBlock(32642),
8            ResidualBlock(64642),
9            ResidualBlock(641282),
10            ResidualBlock(1282562),
11            ResidualBlock(2562562),
12        )
13
14        self.fc_layers = torch.nn.Sequential(
15            torch.nn.Linear(256 * 4, output_nums()),
16        )
17
18    def forward(self, x):
19        # stack convolution layers
20        x = self.cnn_layers(x)
21        out = x.view(-14 * 256)
22        out = self.fc_layers(out)
23        return out

 

模型训练与测试


使用多标签损失函数,Adam优化器,代码实现如下:

 1model = CapchaResNet()
2print(model)
3
4# 使用GPU
5if train_on_gpu:
6    model.cuda()
7
8ds = CapchaDataset("D:/python/pytorch_tutorial/capcha/samples")
9num_train_samples = ds.num_of_samples()
10bs = 16
11dataloader = DataLoader(ds, batch_size=bs, shuffle=True)
12
13# 训练模型的次数
14num_epochs = 25
15# optimizer = torch.optim.SGD(model.parameters(), lr=0.001)
16optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
17model.train()
18
19# 损失函数
20mul_loss = torch.nn.MultiLabelSoftMarginLoss()
21index = 0
22for epoch in range(num_epochs):
23    train_loss = 0.0
24    for i_batch, sample_batched in enumerate(dataloader):
25        images_batch, oh_labels = \
26            sample_batched['image'], sample_batched['encode']
27        if train_on_gpu:
28            images_batch, oh_labels= images_batch.cuda(), oh_labels.cuda()
29        optimizer.zero_grad()
30
31        # forward pass: compute predicted outputs by passing inputs to the model
32        m_label_out_ = model(images_batch)
33        oh_labels = torch.autograd.Variable(oh_labels.float())
34
35        # calculate the batch loss
36        loss = mul_loss(m_label_out_, oh_labels)
37
38        # backward pass: compute gradient of the loss with respect to model parameters
39        loss.backward()
40
41        # perform a single optimization step (parameter update)
42        optimizer.step()
43
44        # update training loss
45        train_loss += loss.item()
46        if index % 100 == 0:
47            print('step: {} \tTraining Loss: {:.6f} '.format(index, loss.item()))
48        index += 1
49
50        # 计算平均损失
51    train_loss = train_loss / num_train_samples
52
53    # 显示训练集与验证集的损失函数
54    print('Epoch: {} \tTraining Loss: {:.6f} '.format(epoch, train_loss))
55
56# save model
57model.eval()
58torch.save(model, 'capcha_recognize_model.pt')

调用保存之后的模型,对图像测试代码如下:

 1cnn_model = torch.load("./capcha_recognize_model.pt")
2root_dir = "D:/python/pytorch_tutorial/capcha/testdata"
3files = os.listdir(root_dir)
4one_hot_len = ALL_CHAR_SET_LEN
5for file in files:
6    if os.path.isfile(os.path.join(root_dir, file)):
7        image = cv.imread(os.path.join(root_dir, file))
8        h, w, c = image.shape
9        img = cv.resize(image, (12832))
10        img = (np.float32(img) /255.0 - 0.5) / 0.5
11        img = img.transpose((201))
12        x_input = torch.from_numpy(img).view(1332128)
13        probs = cnn_model(x_input.cuda())
14        mul_pred_labels = probs.squeeze().cpu().tolist()
15        c0 = ALL_CHAR_SET[np.argmax(mul_pred_labels[0:one_hot_len])]
16        c1 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len:one_hot_len*2])]
17        c2 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len*2:one_hot_len*3])]
18        c3 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len*3:one_hot_len*4])]
19        c4 = ALL_CHAR_SET[np.argmax(mul_pred_labels[one_hot_len*4:one_hot_len*5])]
20        pred_txt = '%s%s%s%s%s' % (c0, c1, c2, c3, c4)
21        cv.putText(image, pred_txt, (1020), cv.FONT_HERSHEY_PLAIN, 1.5, (00255), 2)
22        print("current code : %s, predict code : %s "%(file[:-4], pred_txt))
23        cv.imshow("capcha predict", image)
24        cv.waitKey(0)

其中对输入结果,要根据每个字符的独热编码,截取成五个独立的字符分类标签,然后使用argmax获取index根据index查找类别标签,得到最终的验证码预测字符串,代码运行结果如下:

好消息!

小白学视觉知识星球

开始面向外开放啦👇👇👇




下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


浏览 102
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报