深度学习TabNet能否超越GBDT?

共 4456字,需浏览 9分钟

 ·

2022-05-14 00:12

996a249beed66533f5153864ddbc9e03.webp

向AI转型的程序员都关注了这个号👇👇👇

机器学习AI算法工程   公众号:datayx


随着深度神经网络的不断发展,DNN在图像、文本和语音等类型的数据上都有了广泛的应用,然而对于同样非常常见的一种数据——表格数据,DNN却似乎并没有取得像它在其他领域那么大的成功。参加过Kaggle等数据挖掘竞赛的同学应该都知道,对于采用表格数据的任务,基本都是决策树模型的主场,像XGBoost和LightGBM这类提升(Boosting)树模型已经成为了现在数据挖掘比赛中的标配。相比于DNN,这类树模型好处主要有:

  • 模型的决策流形(decision manifolds)是可以看成是超平面边界的,对于表格数据的效果很好

  • 可以根据决策树追溯推断过程,可解释性较好

  • 训练起来更快

而对于DNN,它的优势在于:

  • 类似于图像和文本,可以对表格数据进行编码(encode),从而得到一个能够表征表格数据的方法,这种表征学习(representation learning)可以用在很多地方

  • 可以减少对于特征工程(feature engineering)的依赖(相信打过比赛的同学都知道这有多重要)

  • 可以通过online learning的方式来更新模型,而树模型只能用整个数据集重新训练

然而对于传统的DNN,一味地堆叠网络层很容易导致模型过参数化(overparametrized),导致DNN在表格数据集上表现并不尽如人意。因此,如果能够设计这样一种DNN,它既吸收了树模型的长处,又继承了DNN的优点,那么这样的模型无疑是针对于表格数据的一大利器,而这次介绍的论文就巧妙地设计出了这样的模型——TabNet,它在保留DNN的end-to-end和representation learning特点的基础上,还拥有了树模型的可解释性和稀疏特征选择的优点,这使得它在具备DNN优点的同时,在表格数据任务上也可以和目前主流的树模型相媲美,接下来我们就开始具体介绍TabNet。

用DNN构造决策树

既然想要让DNN具有树模型的优点,那么我们首先需要解决的一个问题就是:如何构建一个与树模型具有相似决策流形的神经网络?下图是一个决策树流形的简单示例。

cb5c8f55631204e4fdc53ad1b5fc5d70.webp


732bf1ab019fa6b5ab764cd2eeebace3.webp



fccbed305407fe2ca41788bd00b2ab23.webp



模型架构

为了理解起来比较容易,上面的那个神经网络构造得比较简单,作为一个加性模型它只有两步,Mask层是人为设置好的,特征计算用的也是一个简单的FC层,而接下来介绍的TabNet就对这些地方做了改进,它的基本结构如下所示。

98fe1ceb5b3d27ff47da0f2606c1597c.webp


2b87077b612f434f8e7104a394e9334a.webp



763ff21000d980a3122b02367204c16c.webp



0d9f5eef0367e44dc84cf0915a0c7297.webp


71d6863e8ae6cdb1a23ecc33029d6d94.webp


  • 特征选择:Attentive transformer层可以根据上一个step的结果得到当前step的Mask矩阵,并尽量使得Mask矩阵是稀疏且不重复的。值得注意的一点是,不同样本的Mask向量可以不同,也就是说TabNet可以让不同的样本选择不同的特征(instance-wise),而这个特点是树模型所不具备的,对于XGBoost这类加性模型,一个step就是一棵树,而这棵决策树用到的特征是在所有样本上挑选出来的(例如通过计算信息增益),它没有办法做到instance-wise。

  • 特征计算:Feature transformer层实现了对于当前step步所选取特征的计算处理。还是类比于决策树,对于给定的一些特征,一棵决策树构造的是单个特征的大小关系的组合,也就是上面提到的决策流形,而之前那个简单神经网络就是通过一个FC层来模仿这个决策流形,但FC层只是构造了一组简单的线性关系,并没有考虑更加复杂的情况,因此TabNet通过更复杂的Feature transformer层来进行特征计算,个人感觉它的决策流形不一定和决策树的相似,在一些特征组合上它可能比决策树做得更好。

自监督学习

前面提到了DNN的一个好处就是可以进行表征学习,而TabNet就应用了自监督学习的方法,通过encoder-decoder框架来获得表格数据的representation,从而也有助于分类和回归任务,如下图所示:

1d3f87bd8ad64317f5009be7ed31c30c.webp



简单来说,我们认为同一样本的不同特征之间是有关联的,因此自监督学习就是先人为mask掉一些feature,然后通过encoder-decoder模型来对mask掉的feature进行预测。我们认为通过这样的方式训练出来的encoder模型,可以有效地将表征样本的feature(可以理解为对数据进行了编码或压缩),这时再将encoder模型由于回归或分类任务,就能够事半功倍。自监督学习时的encoder模型就是上图中的模型,decoder模型如下所示:

aa3ba1230fcf233ca589d93a7768ded0.webp

这里的encoded representation就是encoder中没有经过FC层的加和向量,将它作为decoder的输入,decoder同样利用了Feature transformer层,只不过这次的目的是将representation向量重构为feature,然后类似地经过若干个step的加和,得到最后的重构feature。


dbffe2e3e037a0abe90d176a21f6bca9.webp


实验

为了证明TabNet确实具有上文中提到的种种优点,这篇文章在不同的数据集上进行了各种类型的实验,这里只介绍一部分,其它实验以及具体实验细节可以看论文原文,写得也很详细。



这个模型的代码:

tensorflow版本的代码

https://github.com/google-research/google-research/tree/master/tabnet


  • pytorch版本的代码

  • https://github.com/microsoft/qlib/blob/main/qlib/contrib/model/pytorch_tabnet.py


  1. Instance-wise feature selection


59082bb0ff355e57e0e7b1e529d60b85.webp


2. 真实数据集

Forest Cover Type:这个数据集是一个分类任务——根据cartographic变量来对森林覆盖类型进行分类,实验的baseline采用了如XGBoost等目前主流的树模型、可以自动构造高阶特征的AutoInt、以及AutoML Tables这种用了神经网络结构搜索 (Neural Architecture Search)的强力模型(node hours的数量反映了模型的复杂性),对比结果如下:


4d52e7371535fd894ce39f12c560f2f2.webp

Higgs Boson:这是一个物理领域的数据集,任务是将产生希格斯玻色子的信号与背景信号分辨开来,由于这个数据集很大,因此DNN比树模型的表现更好,下面是对比结果,其中Sparse evolutionary MLP应用了目前最好的evolutionary sparsification算法,能够有效减小原始MLP模型的大小,不过可以看出,和它大小相近的TabNet-S的性能也只是稍弱一点,这说明轻量级的TabNet表现依旧很好。

57f598d78664fa3b0934fb1bdcd25257.webp


3. 可解释性

f1e634010e73e7381c8a57e3782b9910.webp


dae2d17ece8fd6aa8785150790ea48de.webp



c378823c81f6b8095364ea9ad8007507.webp


4. 自监督学习

前面已经提到了,自监督学习可以提高模型的小样本学习能力,还能加快模型的收敛速度。为了验证这一点,这里我们采用Higgs Boson数据集,其中用全部样本来做自监督学习(pre-training),而只用部分样本做监督学习(fine-tuning),该方法与直接全样本监督学习的对比结果如下所示:

c8ce53b6bf7bb852395ff758d9960fbe.webp4e1c37cc4d39a9dddc703895285cb6ef.webp

从结果中可以看出,通过自监督学习进行预训练之后,模型的收敛速度明显更快,小样本学习的结果也变得更好。

总结

这篇论文提出的TabNet是一种针对于表格数据的神经网络,它通过类似于加性模型的顺序注意力机制(sequential attention mechanism)实现了instance-wise的特征选择,还通过encoder-decoder框架实现了自监督学习,从而将树模型的可解释性与DNN的表征能力很好地结合到了一起,相信这种兼具两者优点的模型将会成为数据挖掘竞赛中的一大利器,也对未来的研究提供了一个很好的思路。

参考资料

[1] TabNet: Attentive Interpretable Tabular Learning

https://arxiv.org/abs/1908.07442



机器学习算法AI大数据技术

 搜索公众号添加: datanlp

长按图片,识别二维码




阅读过本文的人还看了以下文章:


TensorFlow 2.0深度学习案例实战


基于40万表格数据集TableBank,用MaskRCNN做表格检测


《基于深度学习的自然语言处理》中/英PDF


Deep Learning 中文版初版-周志华团队


【全套视频课】最全的目标检测算法系列讲解,通俗易懂!


《美团机器学习实践》_美团算法团队.pdf


《深度学习入门:基于Python的理论与实现》高清中文PDF+源码


《深度学习:基于Keras的Python实践》PDF和代码


特征提取与图像处理(第二版).pdf


python就业班学习视频,从入门到实战项目


2019最新《PyTorch自然语言处理》英、中文版PDF+源码


《21个项目玩转深度学习:基于TensorFlow的实践详解》完整版PDF+附书代码


《深度学习之pytorch》pdf+附书源码


PyTorch深度学习快速实战入门《pytorch-handbook》


【下载】豆瓣评分8.1,《机器学习实战:基于Scikit-Learn和TensorFlow》


《Python数据分析与挖掘实战》PDF+完整源码


汽车行业完整知识图谱项目实战视频(全23课)


李沐大神开源《动手学深度学习》,加州伯克利深度学习(2019春)教材


笔记、代码清晰易懂!李航《统计学习方法》最新资源全套!


《神经网络与深度学习》最新2018版中英PDF+源码


将机器学习模型部署为REST API


FashionAI服装属性标签图像识别Top1-5方案分享


重要开源!CNN-RNN-CTC 实现手写汉字识别


yolo3 检测出图像中的不规则汉字


同样是机器学习算法工程师,你的面试为什么过不了?


前海征信大数据算法:风险概率预测


【Keras】完整实现‘交通标志’分类、‘票据’分类两个项目,让你掌握深度学习图像分类


VGG16迁移学习,实现医学图像识别分类工程项目


特征工程(一)


特征工程(二) :文本数据的展开、过滤和分块


特征工程(三):特征缩放,从词袋到 TF-IDF


特征工程(四): 类别特征


特征工程(五): PCA 降维


特征工程(六): 非线性特征提取和模型堆叠


特征工程(七):图像特征提取和深度学习


如何利用全新的决策树集成级联结构gcForest做特征工程并打分?


Machine Learning Yearning 中文翻译稿


蚂蚁金服2018秋招-算法工程师(共四面)通过


全球AI挑战-场景分类的比赛源码(多模型融合)


斯坦福CS230官方指南:CNN、RNN及使用技巧速查(打印收藏)


python+flask搭建CNN在线识别手写中文网站


中科院Kaggle全球文本匹配竞赛华人第1名团队-深度学习与特征工程



不断更新资源

深度学习、机器学习、数据分析、python

 搜索公众号添加: datayx  



浏览 61
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

分享
举报