Keras入门(八)K折交叉验证
Python爬虫与算法
共 2911字,需浏览 6分钟
· 2021-01-26
在文章Keras入门(一)搭建深度神经网络(DNN)解决多分类问题中,笔者介绍了如何搭建DNN模型来解决IRIS数据集的多分类问题。
本文将在此基础上介绍如何在Keras中实现K折交叉验证。
什么是K折交叉验证?
K折交叉验证是机器学习中的一个专业术语,它指的是将原始数据随机分成K份,每次选择K-1份作为训练集,剩余的1份作为测试集。交叉验证重复K次,取K次准确率的平均值作为最终模型的评价指标。一般取K=10,即10折交叉验证,如下图所示:
用交叉验证的目的是为了得到可靠稳定的模型。K折交叉验证能够有效提高模型的学习能力,类似于增加了训练样本数量,使得学习的模型更加稳健,鲁棒性更强。选择合适的K值能够有效避免过拟合。
Keras实现K折交叉验证
我们仍采用文章Keras入门(一)搭建深度神经网络(DNN)解决多分类问题中的模型,如下:
同时,我们对IRIS数据集采用10折交叉验证,完整的实现代码如下:
# -*- coding: utf-8 -*-
# model_train.py
# Python 3.6.8, TensorFlow 2.3.0, Keras 2.4.3
# 导入模块
import keras as K
import pandas as pd
from sklearn.model_selection import KFold
# 读取CSV数据集
# 该函数的传入参数为csv_file_path: csv文件路径
def load_data(sv_file_path):
iris = pd.read_csv(sv_file_path)
target_var = 'class' # 目标变量
# 数据集的特征
features = list(iris.columns)
features.remove(target_var)
# 目标变量的类别
Class = iris[target_var].unique()
# 目标变量的类别字典
Class_dict = dict(zip(Class, range(len(Class))))
# 增加一列target, 将目标变量转化为类别变量
iris['target'] = iris[target_var].apply(lambda x: Class_dict[x])
return features, 'target', iris
# 创建模型
def create_model():
init = K.initializers.glorot_uniform(seed=1)
simple_adam = K.optimizers.Adam()
model = K.models.Sequential()
model.add(K.layers.Dense(units=5, input_dim=4, kernel_initializer=init, activation='relu'))
model.add(K.layers.Dense(units=6, kernel_initializer=init, activation='relu'))
model.add(K.layers.Dense(units=3, kernel_initializer=init, activation='softmax'))
model.compile(loss='sparse_categorical_crossentropy', optimizer=simple_adam, metrics=['accuracy'])
return model
def main():
# 1. 读取CSV数据集
print("Loading Iris data into memory")
n_split = 10
features, target, data = load_data("./iris_data.csv")
x = data[features]
y = data[target]
avg_accuracy = 0
avg_loss = 0
for train_index, test_index in KFold(n_split).split(x):
print("test index: ", test_index)
x_train, x_test = x.iloc[train_index], x.iloc[test_index]
y_train, y_test = y.iloc[train_index], y.iloc[test_index]
print("create model and train model")
model = create_model()
model.fit(x_train, y_train, batch_size=1, epochs=80, verbose=0)
print('Model evaluation: ', model.evaluate(x_test, y_test))
avg_accuracy += model.evaluate(x_test, y_test)[1]
avg_loss += model.evaluate(x_test, y_test)[0]
print("K fold average accuracy: {}".format(avg_accuracy / n_split))
print("K fold average accuracy: {}".format(avg_loss / n_split))
main()
模型的输出结果如下:
Iteration | loss | accuracy |
---|---|---|
1 | 0.00056 | 1.0 |
2 | 0.00021 | 1.0 |
3 | 0.00022 | 1.0 |
4 | 0.00608 | 1.0 |
5 | 0.21925 | 0.8667 |
6 | 0.52390 | 0.8667 |
7 | 0.00998 | 1.0 |
8 | 0.04431 | 1.0 |
9 | 0.14590 | 1.0 |
10 | 0.21286 | 0.8667 |
avg | 0.11633 | 0.9600 |
10折交叉验证的平均loss为0.11633,平均准确率为96.00%。
总结
本文代码已存放至Github,网址为:https://github.com/percent4/Keras-K-fold-test 。
感谢大家的阅读~
2020.1.24于上海浦东
评论
某程序员吐槽:公司最近招了一批35左右的,这帮人习惯天天卷到八点多,导致现在我们也要八点才下班
架构师大咖
架构师大咖,打造有价值的架构师交流平台。分享架构师干货、教程、课程、资讯。架构师大咖,每日推送。
公众号该公众号已被封禁某位程序员的吐槽引发了广泛的思考和共鸣。他抱怨公司
源码共读
0
原来Matplotlib能画股票K线图!!附代码
之前在一篇文章中提到Matplotlib可视化,甚至可以用来画股票K线图,许多同学也在问代码,这次来发个文回应下。Python用matplotlib绘制K线图,需要配合talib、numpy、mpl_finance等第三方库来使用,效果展示如下:简单讲讲K线图的结构,我不搞股票,所以不太懂,特地查了
Python大数据分析
9
实践教程 | 在yolov5上验证一些不成熟的想法
点击上方“小白学视觉”,选择加"星标"或“置顶”重磅干货,第一时间送达作者丨王小二@知乎(已授权)来源丨https://zhuanlan.zhihu.com/p/388246083编辑丨极市平台极市导读 本文做了两件事:一是把基于mxnet的训练代码迁移到pytorch上,二是在yolov
小白学视觉
10
CPU的入门知识
不管你玩硬件还是做软件,你的世界都少不了计算机最核心的 —— CPU。01CPU是什么?CPU与计算机的关系就相当于大脑和人的关系,它是一种小型的计算机芯片,通常嵌入在电脑的主板上。CPU的构建是通过在单个计算机芯片上放置数十亿个微型晶体管来实现。这些晶体管使它能够执行运行存储在系统内存中的程序所需
机器学习算法与Python实战
10
【机器学习】如何在交叉验证中使用SHAP?
在许多情况下,机器学习模型比传统线性模型更受欢迎,因为它们具有更好的预测性能和处理复杂非线性数据的能力。然而,机器学习模型的一个常见问题是它们缺乏可解释性。例如,集成方法如XGBoost和随机森林将许多个体学习器的结果组合起来生成结果。尽管这通常会带来更好的性能,但它使得难以知道数据集中每个特征对输
机器学习初学者
10
10个python爬虫入门实例
涉及主要知识点:web是如何交互的requests库的get、post函数的应用response对象的相关函数,属性python文件的打开,保存代码中给出了注释,并且可以直接运行哦如何安装requests库(安装好python的朋友可以直接参考,没...
马哥Linux运维
0
Gin 框架介绍与快速入门
目录Gin 框架介绍与快速入门1.gin.Engine2.gin.Context1.安装2.导入3.第一个Gin 应用1. 快速和轻量级2. 路由和中间件3. JSON解析4. 支持插件5. Gin相关文档一、Gin框架介绍二、基本使用三、应用举例四、Gin 入门核心...
马哥Linux运维
0
Java日志系统历史从入门到崩溃
前言最早开始撸码当时就遇到几次日志jar包冲突的问题,当时也是很烦躁,毕竟了解的也不多,什么那里4j,这里4j,还有什么桥接包,而且在我感觉他们的包名都还差不多!!我当时是比较懵逼的,上网搜了下,随便看到一...
浪尖聊大数据
0