基于keras实现多标签分类(multi-label classification)
共 2340字,需浏览 5分钟
·
2022-02-06 09:12
向AI转型的程序员都关注了这个号👇👇👇
机器学习AI算法工程 公众号:datayx
首先讨论多标签分类数据集(以及如何快速构建自己的数据集)。
之后简要讨论SmallerVGGNet,我们将实现的Keras神经网络架构,并用于多标签分类。
然后我们将实施SmallerVGGNet并使用我们的多标签分类数据集对其进行训练。
最后,我们将通过在示例图像上测试我们的网络,并讨论何时适合多标签分类,包括需要注意的一些注意事项。
相关代码,获取方式:
关注微信公众号 datayx 然后回复 多标签分类 即可获取。
数据集包含六个类别的2,167个图像,包括:
黑色牛仔裤(344图像)
蓝色连衣裙(386图像)
蓝色牛仔裤(356图像)
蓝色衬衫(369图像)
红色连衣裙(380图像)
红色衬衫(332图像)
6类图像数据可以通过python爬虫在网站上抓取得到。
为了方便起见,可以通过使用Bing图像搜索API(Microsoft’s Bing Image Search API)建立图像数据(需要在线注册获得api key,使用key进行图像搜索),python代码:
使用find方法得到下载的图像数据数目
多标签分类multi-label classsification
这里给出的是项目的文件结构
多标签分类的网络结构--smallervggnet【Very Deep Convolutional Networks for Large Scale Image Recognition.】
https://arxiv.org/pdf/1409.1556/
smallervggnet.py
train.py
run
继续preprocessing
run
构建训练和测试数据集,做数据增强
构建模型,初始化Adam优化器
编译模型,开始训练
训练后保存模型,并二值化标签
绘制出acc,loss
绘制好的结果会保存成图片格式保存。
多标签分类模型训练
python train.py --dataset dataset --model fashion.model --labelbin mlb.pickle
使用训练完成的模型预测新的图像
classify.py
最终显示出预测的分类结果
使用Keras执行多标签分类非常简单,包括两个主要步骤:
1.使用sigmoid激活替换网络末端的softmax激活
2.二值交叉熵作为分类交叉熵损失函数
shortcomings:
网络无法预测没有在训练集中出现过的数据样品,如果出现的次数过少,预测的效果也不会很好,解决办法是增大数据集,这样可能非常不容易,还有一种用的已经很多的方法用在大的数据集上训练得到的权重数据对网络做初始化,提高模型的泛化能力。
机器学习算法AI大数据技术
搜索公众号添加: datanlp
长按图片,识别二维码
阅读过本文的人还看了以下文章:
基于40万表格数据集TableBank,用MaskRCNN做表格检测
《深度学习入门:基于Python的理论与实现》高清中文PDF+源码
2019最新《PyTorch自然语言处理》英、中文版PDF+源码
《21个项目玩转深度学习:基于TensorFlow的实践详解》完整版PDF+附书代码
PyTorch深度学习快速实战入门《pytorch-handbook》
【下载】豆瓣评分8.1,《机器学习实战:基于Scikit-Learn和TensorFlow》
李沐大神开源《动手学深度学习》,加州伯克利深度学习(2019春)教材
【Keras】完整实现‘交通标志’分类、‘票据’分类两个项目,让你掌握深度学习图像分类
如何利用全新的决策树集成级联结构gcForest做特征工程并打分?
Machine Learning Yearning 中文翻译稿
斯坦福CS230官方指南:CNN、RNN及使用技巧速查(打印收藏)
中科院Kaggle全球文本匹配竞赛华人第1名团队-深度学习与特征工程
不断更新资源
深度学习、机器学习、数据分析、python
搜索公众号添加: datayx