深度学习文本分类|模型&代码&技巧
文本分类是NLP的必备入门任务,在搜索、推荐、对话等场景中随处可见,并有情感分析、新闻分类、标签分类等成熟的研究分支和数据集。
本文主要介绍深度学习文本分类的常用模型原理、优缺点以及技巧。
P.S. 有基础的同学可以直接看文末的技巧
Fasttext
论文:https://arxiv.org/abs/1607.01759
代码:https://github.com/facebookresearch/fastText
Fasttext是Facebook推出的一个便捷的工具,包含文本分类和词向量训练两个功能。
Fasttext的分类实现很简单:把输入转化为词向量,取平均,再经过线性分类器得到类别。输入的词向量可以是预先训练好的,也可以随机初始化,跟着分类任务一起训练。
Fasttext直到现在还被不少人使用,主要有以下优点:
模型本身复杂度低,但效果不错,能快速产生任务的baseline Facebook使用C++进行实现,进一步提升了计算效率 采用了char-level的n-gram作为附加特征,比如paper的trigram是 [pap, ape, per],在将输入paper转为向量的同时也会把trigram转为向量一起参与计算。这样一方面解决了长尾词的OOV (out-of-vocabulary)问题,一方面利用n-gram特征提升了表现 当类别过多时,支持采用hierarchical softmax进行分类,提升效率
对于文本长且对速度要求高的场景,Fasttext是baseline首选。同时用它在无监督语料上训练词向量,进行文本表示也不错。不过想继续提升效果还需要更复杂的模型。
TextCNN
论文:https://arxiv.org/abs/1408.5882
代码:https://github.com/yoonkim/CNN_sentence
TextCNN是Yoon Kim小哥在2014年提出的模型,开创了用CNN编码n-gram特征的先河。
模型结构如图,图像中的卷积都是二维的,而TextCNN则使用「一维卷积」,即filter_size * embedding_dim
,有一个维度和embedding相等。这样就能抽取filter_size个gram的信息。以1个样本为例,整体的前向逻辑是:
对词进行embedding,得到 [seq_length, embedding_dim]
用N个卷积核,得到N个 seq_length-filter_size+1
长度的一维feature map对feature map进行max-pooling(因为是时间维度的,也称max-over-time pooling),得到N个 1x1
的数值,拼接成一个N维向量,作为文本的句子表示将N维向量压缩到类目个数的维度,过Softmax
在TextCNN的实践中,有很多地方可以优化(参考这篇论文[1]):
Filter尺寸:这个参数决定了抽取n-gram特征的长度,这个参数主要跟数据有关,平均长度在50以内的话,用10以下就可以了,否则可以长一些。在调参时可以先用一个尺寸grid search,找到一个最优尺寸,然后尝试最优尺寸和附近尺寸的组合 Filter个数:这个参数会影响最终特征的维度,维度太大的话训练速度就会变慢。这里在100-600之间调参即可 CNN的激活函数:可以尝试Identity、ReLU、tanh 正则化:指对CNN参数的正则化,可以使用dropout或L2,但能起的作用很小,可以试下小的dropout率(<0.5),L2限制大一点 Pooling方法:根据情况选择mean、max、k-max pooling,大部分时候max表现就很好,因为分类任务对细粒度语义的要求不高,只抓住最大特征就好了 Embedding表:中文可以选择char或word级别的输入,也可以两种都用,会提升些效果。如果训练数据充足(10w+),也可以从头训练 蒸馏BERT的logits,利用领域内无监督数据 加深全连接:原论文只使用了一层全连接,而加到3、4层左右效果会更好[2]
TextCNN是很适合中短文本场景的强baseline,但不太适合长文本,因为卷积核尺寸通常不会设很大,无法捕获长距离特征。同时max-pooling也存在局限,会丢掉一些有用特征。另外再仔细想的话,TextCNN和传统的n-gram词袋模型本质是一样的,它的好效果很大部分来自于词向量的引入[3],因为解决了词袋模型的稀疏性问题。
DPCNN
论文:https://ai.tencent.com/ailab/media/publications/ACL3-Brady.pdf
代码:https://github.com/649453932/Chinese-Text-Classification-Pytorch
上面介绍TextCNN有太浅和长距离依赖的问题,那直接多怼几层CNN是否可以呢?感兴趣的同学可以试试,就会发现事情没想象的那么简单。直到2017年,腾讯才提出了把TextCNN做到更深的DPCNN模型:
上图中的ShallowCNN指TextCNN。DPCNN的核心改进如下:
在Region embedding时不采用CNN那样加权卷积的做法,而是对n个词进行pooling后再加个1x1的卷积,因为实验下来效果差不多,且作者认为前者的表示能力更强,容易过拟合 使用1/2池化层,用size=3 stride=2的卷积核,直接让模型可编码的sequence长度翻倍(自己在纸上画一下就get啦) 残差链接,参考ResNet,减缓梯度弥散问题
凭借以上一些精妙的改进,DPCNN相比TextCNN有1-2个百分点的提升。
TextRCNN
论文:https://dl.acm.org/doi/10.5555/2886521.2886636
代码:https://github.com/649453932/Chinese-Text-Classification-Pytorch
除了DPCNN那样增加感受野的方式,RNN也可以缓解长距离依赖的问题。下面介绍一篇经典TextRCNN。
模型的前向过程是:
得到单词 i 的表示 通过RNN得到左右双向的表示 和 将表示拼接得到 ,再经过变换得到 对多个 进行 max-pooling,得到句子表示 ,在做最终的分类
这里的convolutional是指max-pooling。通过加入RNN,比纯CNN提升了1-2个百分点。
TextBiLSTM+Attention
论文:https://www.aclweb.org/anthology/P16-2034.pdf
代码:https://github.com/649453932/Chinese-Text-Classification-Pytorch
从前面介绍的几种方法,可以自然地得到文本分类的框架,就是先基于上下文对token编码,然后pooling出句子表示再分类。在最终池化时,max-pooling通常表现更好,因为文本分类经常是主题上的分类,从句子中一两个主要的词就可以得到结论,其他大多是噪声,对分类没有意义。而到更细粒度的分析时,max-pooling可能又把有用的特征去掉了,这时便可以用attention进行句子表示的融合:
BiLSTM就不解释了,要注意的是,计算attention score时会先进行变换:
其中 是context vector,随机初始化并随着训练更新。最后得到句子表示 ,再进行分类。
这个加attention的套路用到CNN编码器之后代替pooling也是可以的,从实验结果来看attention的加入可以提高2个点。如果是情感分析这种由句子整体决定分类结果的任务首选RNN。
HAN
论文:https://www.aclweb.org/anthology/N16-1174.pdf
代码:https://github.com/richliao/textClassifier
上文都是句子级别的分类,虽然用到长文本、篇章级也是可以的,但速度精度都会下降,于是有研究者提出了层次注意力分类框架,即Hierarchical Attention。先对每个句子用 BiGRU+Att 编码得到句向量,再对句向量用 BiGRU+Att 得到doc级别的表示进行分类:
方法很符合直觉,不过实验结果来看比起avg、max池化只高了不到1个点(狗头,真要是很大的doc分类,好好清洗下,fasttext其实也能顶的(捂脸。
BERT
BERT的原理代码就不用放了叭~
BERT分类的优化可以尝试:
多试试不同的预训练模型,比如RoBERT、WWM、ALBERT 除了 [CLS] 外还可以用 avg、max 池化做句表示,甚至可以把不同层组合起来 在领域数据上增量预训练 集成蒸馏,训多个大模型集成起来后蒸馏到一个上 先用多任务训,再迁移到自己的任务
其他模型
除了上述常用模型之外,还有Capsule Network[4]、TextGCN[5]等红极一时的模型,因为涉及的背景知识较多,本文就暂不介绍了(嘻嘻)。
虽然实际的落地应用中比较少见,但在机器学习比赛中还是可以用的。Capsule Network被证明在多标签迁移的任务上性能远超CNN和LSTM[6],但这方面的研究在18年以后就很少了。TextGCN则可以学到更多的global信息,用在半监督场景中,但碰到较长的需要序列信息的文本表现就会差些[7]。
技巧
模型说得差不多了,下面介绍一些自己的数据处理血泪经验,如有不同意见欢迎讨论~
数据集构建
首先是标签体系的构建,拿到任务时自己先试标一两百条,看有多少是难确定(思考1s以上)的,如果占比太多,那这个任务的定义就有问题。可能是标签体系不清晰,或者是要分的类目太难了,这时候就要找项目owner去反馈而不是继续往下做。
其次是训练评估集的构建,可以构建两个评估集,一个是贴合真实数据分布的线上评估集,反映线上效果,另一个是用规则去重后均匀采样的随机评估集,反映模型的真实能力。训练集则尽可能和评估集分布一致,有时候我们会去相近的领域拿现成的有标注训练数据,这时就要注意调整分布,比如句子长度、标点、干净程度等,尽可能做到自己分不出这个句子是本任务的还是从别人那里借来的。
最后是数据清洗:
去掉文本强pattern:比如做新闻主题分类,一些爬下来的数据中带有的XX报道、XX编辑高频字段就没有用,可以对语料的片段或词进行统计,把很高频的无用元素去掉。还有一些会明显影响模型的判断,比如之前我在判断句子是否为无意义的闲聊时,发现加个句号就会让样本由正转负,因为训练预料中的闲聊很少带句号(跟大家的打字习惯有关),于是去掉这个pattern就好了不少 纠正标注错误:这个我真的屡试不爽,生生把自己从一个算法变成了标注人员。简单的说就是把训练集和评估集拼起来,用该数据集训练模型两三个epoch(防止过拟合),再去预测这个数据集,把模型判错的拿出来按 abs(label-prob) 排序,少的话就自己看,多的话就反馈给标注人员,把数据质量搞上去了提升好几个点都是可能的
长文本
任务简单的话(比如新闻分类),直接用fasttext就可以达到不错的效果。
想要用BERT的话,最简单的方法是粗暴截断,比如只取句首+句尾、句首+tfidf筛几个词出来;或者每句都预测,最后对结果综合。
另外还有一些魔改的模型可以尝试,比如XLNet、Reformer、Longformer。
如果是离线任务且来得及的话还是建议跑全部,让我们相信模型的编码能力。
少样本
自从用了BERT之后,很少受到数据不均衡或者过少的困扰,先无脑训一版。
如果样本在几百条,可以先把分类问题转化成匹配问题,或者用这种思想再去标一些高置信度的数据,或者用自监督、半监督的方法。
鲁棒性
在实际的应用中,鲁棒性是个很重要的问题,否则在面对badcase时会很尴尬,怎么明明那样就分对了,加一个字就错了呢?
这里可以直接使用一些粗暴的数据增强,加停用词加标点、删词、同义词替换等,如果效果下降就把增强后的训练数据洗一下。
当然也可以用对抗学习、对比学习这样的高阶技巧来提升,一般可以提1个点左右,但不一定能避免上面那种尴尬的情况。
总结
文本分类是工业界最常用的任务,同时也是大多数NLPer入门做的第一个任务,我当年就是啥都不会,从训练到部署地实践了文本分类后就顺畅了。上文给出了不少模型,但实际任务中常用的也就那几个,下面是快速选型的建议:
实际上,落地时主要还是和数据的博弈。数据决定模型的上限,大多数人工标注的准确率达到95%以上就很好了,而文本分类通常会对准确率的要求更高一些,与其苦苦调参想fancy的结构,不如好好看看badcase,做一些数据增强提升模型鲁棒性更实用。
参考资料
A Sensitivity Analysis of (and Practitioners' Guide to) Convolutional Neural Networks for Sentence Classification: https://arxiv.org/pdf/1510.03820.pdf
[2]卷积层和分类层,哪个更重要?: https://www.zhihu.com/question/270245936
[3]从经典文本分类模型TextCNN到深度模型DPCNN: https://zhuanlan.zhihu.com/p/35457093
[4]揭开迷雾,来一顿美味的Capsule盛宴: https://kexue.fm/archives/4819
[5]Graph Convolutional Networks for Text Classification: https://arxiv.org/abs/1809.05679
[6]胶囊网络(Capsule Network)在文本分类中的探索: https://zhuanlan.zhihu.com/p/35409788
[7]怎么看待最近比较火的 GNN?: https://www.zhihu.com/question/307086081/answer/717456124