当前位置:首页>编程日记>正文

使用early stopping解决神经网络过拟合问题

神经网络训练多少轮是一个很关键的问题,训练轮数少了欠拟合(underfit),训练轮数多了过拟合(overfit),那如何选择训练轮数呢?

Early stopping可以帮助我们解决这个问题,它的作用就是当模型在验证集上的性能不再增加的时候就停止训练,从而达到充分训练的作用,又避免过拟合。

一、在Keras中使用early stopping

完整代码

Keras中有EarlyStopping类,可以直接拿来使用,非常方便

from keras.callbacks import EarlyStoppingearlystop = EarlyStopping(monitor = 'val_loss',mode='min',min_delta = 0,patience = 3,verbose = 1,)
  1. monitor。想要监控的指标,比如在这里我们主要看的是验证集上的loss,当loss不再降低的时候就停止
  2. mode。想要最大值还是最小值,在这里我们使用的min,当时loss越小越好
  3. min_delta。指标的变化超过min_delta才认为产生了变化,否则都认为不再上升或下降
  4. patience。多少轮不发生变化才停止
  5. verbose。设置为1的时候,训练结束会打印出epoch的情况

二、保存最佳模型

完整代码

在early stopping结束后得到模型不一定是最佳模型,所以我们需要把训练过程中表现最好的模型保存下来,以便使用。在这里我们可以使用Keras提供的另一callback来实现:

from keras.callbacks import ModelCheckpointmc = ModelCheckpoint(file_path='./best_model.h5',monitor='val_accuracy',mode='max',verbose=1,save_best_only=True)
  1. filepath,模型存储的路径
  2. monitor,监控的指标
  3. mode,最大还是最小模式
  4. verbose,日志显示控制
  5. save_best_only,是否只存储最好的模型

通过使用这个方法我们就可以把最好的模型存储下来,在使用的时候直接load就可以了。

三、在IMDB数据集上使用Early Stopping

完整代码​​​​​​​

IMDB是一个情感分析数据集,我们首先在这个数据集上使用一个简单的CNN看看效果,然后再使用Early Stopping作为对比。首先看看CNN代码。先对句子embedding, 然后使用一层Conv1D+Maxpooling。

# Build model
sentence = Input(batch_shape=(None, max_words), dtype='int32', name='sentence')
embedding_layer = Embedding(top_words, embedding_dims, input_length=max_words)
sent_embed = embedding_layer(sentence)
conv_layer = Conv1D(filters, kernel_size, padding='valid', activation='relu')
sent_conv = conv_layer(sent_embed)
sent_pooling = GlobalMaxPooling1D()(sent_conv)
sent_repre = Dense(250)(sent_pooling)
sent_repre = Activation('relu')(sent_repre)
sent_repre = Dense(1)(sent_repre)
pred = Activation('sigmoid')(sent_repre)
model = Model(inputs=sentence, outputs=pred)
rmsprop = optimizers.rmsprop(lr=0.0003)
model.compile(loss='binary_crossentropy', optimizer=rmsprop, metrics=['accuracy'])

最终在数据集上的结果如下,在训练集上基本达到了100,而在测试集上还不到90,看起来有点过拟合了

Training Accuracy: 100%
Test Accuracy: 88.50%

我们再看Loss曲线,大约在第8轮的时候,验证集上的Loss达到最低,但是在往后Loss开始升高,这就更加确定发生了过拟合,我们需要提前停止训练,最好在第8轮之后就停下来。

使用early stopping解决神经网络过拟合问题 配图01

在IMDB数据集上使用Early Stopping

我们再训练过程中加上一个patience=10的earlystop,监控验证集loss。当验证集的loss在近10轮都没有下降的话就停止。

#early stopping
earlystop = EarlyStopping(monitor='val_loss',min_delta=0,patience=10,verbose=1)# fit the model
history = model.fit(x_train, y_train, batch_size=batch_size,epochs=epochs, verbose=1, validation_data=(x_test, y_test), callbacks[earlystop])

结果如下,我们可以看到训练最终在第16轮停止了,停止时在测试集上的准确率为88.40%,并没有高于不使用Early Stopping的情况,但是在训练的第12轮模型的准确达到了89.30%,超过了Baseline。所以我们需要加上存储最好模型的callback。

Epoch 2/50
5000/5000 [==============================] - 5s 951us/step - loss: 0.4851 - acc: 0.7986 - val_loss: 0.4320 - val_acc: 0.8170
Epoch 3/50
5000/5000 [==============================] - 5s 918us/step - loss: 0.3193 - acc: 0.8802 - val_loss: 0.3599 - val_acc: 0.8370
Epoch 4/50
5000/5000 [==============================] - 4s 882us/step - loss: 0.2093 - acc: 0.9322 - val_loss: 0.3392 - val_acc: 0.8530
Epoch 5/50
5000/5000 [==============================] - 4s 880us/step - loss: 0.1209 - acc: 0.9702 - val_loss: 0.4001 - val_acc: 0.8260
Epoch 6/50
5000/5000 [==============================] - 4s 887us/step - loss: 0.0600 - acc: 0.9884 - val_loss: 0.2900 - val_acc: 0.8710
Epoch 7/50
5000/5000 [==============================] - 4s 865us/step - loss: 0.0208 - acc: 0.9986 - val_loss: 0.2978 - val_acc: 0.8840
Epoch 8/50
5000/5000 [==============================] - 4s 883us/step - loss: 0.0053 - acc: 1.0000 - val_loss: 0.3180 - val_acc: 0.8840
Epoch 9/50
5000/5000 [==============================] - 4s 856us/step - loss: 0.0011 - acc: 1.0000 - val_loss: 0.3570 - val_acc: 0.8830
Epoch 10/50
5000/5000 [==============================] - 4s 845us/step - loss: 1.7574e-04 - acc: 1.0000 - val_loss: 0.4035 - val_acc: 0.8800
Epoch 11/50
5000/5000 [==============================] - 4s 869us/step - loss: 2.0190e-05 - acc: 1.0000 - val_loss: 0.4490 - val_acc: 0.8820
Epoch 12/50
5000/5000 [==============================] - 4s 846us/step - loss: 1.6874e-06 - acc: 1.0000 - val_loss: 0.5164 - val_acc: 0.8930
Epoch 13/50
5000/5000 [==============================] - 4s 860us/step - loss: 2.6231e-07 - acc: 1.0000 - val_loss: 0.5429 - val_acc: 0.8840
Epoch 14/50
5000/5000 [==============================] - 4s 870us/step - loss: 1.4614e-07 - acc: 1.0000 - val_loss: 0.5754 - val_acc: 0.8810
Epoch 15/50
5000/5000 [==============================] - 4s 888us/step - loss: 1.2477e-07 - acc: 1.0000 - val_loss: 0.5744 - val_acc: 0.8850
Epoch 16/50
5000/5000 [==============================] - 4s 876us/step - loss: 1.1823e-07 - acc: 1.0000 - val_loss: 0.5909 - val_acc: 0.8840
Epoch 00016: early stopping
Accuracy: 88.40%

存储最好模型

我们使用ModelCheckPoint存储最好的模型,具体如下,通过监控验证集上的准确率,我们把准确率最高的模型存储下来

from keras.callbacks import EarlyStopping, ModelCheckpointmc = ModelCheckpoint(filepath='best_model.h5',monitor='val_acc',mode='max',verbose=1,save_best_only=True)

然后在使用的时候进行load,然后就可以进行预测了

from keras.models import load_model
saved_model = load_model('best_model.h5')
# evaluate the model
_, train_acc = saved_model.evaluate(x_train, y_train, verbose=0)
_, test_acc = saved_model.evaluate(x_test, y_test, verbose=0)
print('Train: %.3f, Test: %.3f' % (train_acc, test_acc))

最终的结果如下

Train: 1.000, Test: 0.893

正确使用Early Stopping加上存储最佳模型可以帮助我们减轻过拟合,从而训练出表现更好的模型。

完整代码​​​​​​​​​​​​​​


http://www.coolblog.cn/news/c760d01a401b5e2a.html

相关文章:

  • asp多表查询并显示_SpringBoot系列(五):SpringBoot整合Mybatis实现多表关联查询
  • s7day2学习记录
  • 【求锤得锤的故事】Redis锁从面试连环炮聊到神仙打架。
  • 矿Spring入门Demo
  • 拼音怎么写_老师:不会写的字用圈代替,看到孩子试卷,网友:人才
  • Linux 实时流量监测(iptraf中文图解)
  • Win10 + Python + GPU版MXNet + VS2015 + RTools + R配置
  • 美颜
  • shell访问php文件夹,Shell获取某目录下所有文件夹的名称
  • 如何优雅的实现 Spring Boot 接口参数加密解密?
  • LeCun亲授的深度学习入门课:从飞行器的发明到卷积神经网络
  • Mac原生Terminal快速登录ssh
  • java受保护的数据与_Javascript类定义语法,私有成员、受保护成员、静态成员等介绍...
  • mysql commit 机制_1024MySQL事物提交机制
  • 支撑微博千亿调用的轻量级RPC框架:Motan
  • 法拉利虚拟学院2010 服务器,法拉利虚拟学院2010
  • 2019-9
  • jquery 使用小技巧
  • vscode pylint 错误_将实际未错误的py库添加到pylint白名单
  • 科学计算工具NumPy(3):ndarray的元素处理
  • linux批量创建用户和密码
  • 工程师在工作电脑存 64G 不雅文件,被公司开除后索赔 41 万,结果…
  • js常用阻止冒泡事件
  • newinsets用法java_Java XYPlot.setInsets方法代碼示例
  • 气泡图在开源监控工具中的应用效果
  • 各类型土地利用图例_划重点!国土空间总体规划——土地利用
  • php 启动服务器监听
  • dubbo简单示例
  • [iptables]Redhat 7.2下使用iptables实现NAT
  • Ubuntu13.10:[3]如何开启SSH SERVER服务
  • 【设计模式】 模式PK:策略模式VS状态模式
  • CSS小技巧——CSS滚动条美化
  • JS实现-页面数据无限加载
  • 最新DOS大全
  • 阿里巴巴分布式服务框架 Dubbo
  • Django View(视图系统)
  • 阿里大鱼.net core 发送短信
  • 程序员入错行怎么办?
  • Sorenson Capital:值得投资的 5 种 AI 技术
  • 两张超级大表join优化
  • 第九天函数
  • Linux软件安装-----apache安装
  • HDU 5988 最小费用流
  • Arm芯片的新革命在缓缓上演
  • 《看透springmvc源码分析与实践》读书笔记一
  • 正式开课!如何学习相机模型与标定?(单目+双目+鱼眼+深度相机)
  • nagios自写插件—check_file
  • python3 错误 Max retries exceeded with url 解决方法
  • 通过Spark进行ALS离线和Stream实时推荐
  • 行为模式之Template Method模式