脚本之家,脚本语言编程技术及教程分享平台!
分类导航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|shell|

服务器之家 - 脚本之家 - Python - 详解model.train()和model.eval()两种模式的原理与用法

详解model.train()和model.eval()两种模式的原理与用法

2023-03-24 12:06想变厉害的大白菜 Python

这篇文章主要介绍了详解model.train()和model.eval()两种模式的原理与用法,相信很多没有经验的人对此束手无策,那么看完这篇文章一定会对你有所帮助

一、两种模式

pytorch可以给我们提供两种方式来切换训练和评估(推断)的模式,分别是:model.train() 和 model.eval()。

一般用法是:在训练开始之前写上 model.trian() ,在测试时写上 model.eval() 。

二、功能

1. model.train()

在使用 pytorch 构建神经网络的时候,训练过程中会在程序上方添加一句model.train(),作用是 启用 batch normalization 和 dropout 。

如果模型中有BN层(Batch Normalization)和 Dropout ,需要在 训练时 添加 model.train()。

model.train() 是保证 BN 层能够用到 每一批数据 的均值和方差。对于 Dropout,model.train() 是 随机取一部分 网络连接来训练更新参数。

2. model.eval()

model.eval()的作用是 不启用 Batch Normalization 和 Dropout。

如果模型中有 BN 层(Batch Normalization)和 Dropout,在 测试时 添加 model.eval()。

model.eval() 是保证 BN 层能够用 全部训练数据 的均值和方差,即测试过程中要保证 BN 层的均值和方差不变。对于 Dropout,model.eval() 是利用到了 所有 网络连接,即不进行随机舍弃神经元。

为什么测试时要用 model.eval() ?

训练完 train 样本后,生成的模型 model 要用来测试样本了。在 model(test) 之前,需要加上model.eval(),否则的话,有输入数据,即使不训练,它也会改变权值。这是 model 中含有 BN 层和 Dropout 所带来的的性质。

eval() 时,pytorch 会自动把 BN 和 DropOut 固定住,不会取平均,而是用训练好的值。
不然的话,一旦 test 的 batch_size 过小,很容易就会被 BN 层导致生成图片颜色失真极大。
eval() 在非训练的时候是需要加的,没有这句代码,一些网络层的值会发生变动,不会固定,你神经网络每一次生成的结果也是不固定的,生成质量可能好也可能不好。

也就是说,测试过程中使用model.eval(),这时神经网络会 沿用 batch normalization 的值,而并 不使用 dropout。

3. 总结与对比

如果模型中有 BN 层(Batch Normalization)和 Dropout,需要在训练时添加 model.train(),在测试时添加 model.eval()。

其中 model.train() 是保证 BN 层用每一批数据的均值和方差,而 model.eval() 是保证 BN 用全部训练数据的均值和方差;

而对于 Dropout,model.train() 是随机取一部分网络连接来训练更新参数,而 model.eval() 是利用到了所有网络连接。

三、Dropout 简介

dropout 常常用于抑制过拟合。

设置Dropout时,torch.nn.Dropout(0.5),这里的 0.5 是指该层(layer)的神经元在每次迭代训练时会随机有 50% 的可能性被丢弃(失活),不参与训练。也就是将上一层数据减少一半传播。

到此这篇关于详解model.train()和model.eval()两种模式的原理与用法的文章就介绍到这了,更多相关model.train()和model.eval()原理用法内容请搜索服务器之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持服务器之家!

原文链接:https://blog.csdn.net/weixin_44211968/article/details/123774649

延伸 · 阅读

精彩推荐
  • Python使用Python构建Hopfield网络的教程

    使用Python构建Hopfield网络的教程

    这篇文章主要介绍了使用Python构建Hopfield网络的教程,本文来自于IBM官方网站的技术文档,需要的朋友可以参考下...

    脚本之家2782020-06-06
  • PythonPython虚拟环境的创建和使用详解

    Python虚拟环境的创建和使用详解

    这篇文章主要给大家介绍了关于Python虚拟环境的创建和使用的相关资料,文中通过图文介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,...

    300%努力努力再努力15842020-09-08
  • Python在tensorflow中实现去除不足一个batch的数据

    在tensorflow中实现去除不足一个batch的数据

    今天小编就为大家分享一篇在tensorflow中实现去除不足一个batch的数据,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧 ...

    Pywin4882020-04-10
  • PythonPython实现获取系统临时目录及临时文件的方法示例

    Python实现获取系统临时目录及临时文件的方法示例

    这篇文章主要介绍了Python实现获取系统临时目录及临时文件的方法,结合实例形式分析了Python文件与目录操作相关函数与使用技巧,需要的朋友可以参考下...

    轻舞肥羊5472021-07-22
  • Pythonpython中Switch/Case实现的示例代码

    python中Switch/Case实现的示例代码

    本篇文章主要介绍了python中Switch/Case实现的示例代码,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧...

    gerrydeng9332020-12-16
  • Pythonpython编写adb截图工具的实现源码

    python编写adb截图工具的实现源码

    adb截图工具可用于Android手机及Android终端,Android端或者Android终端的远程截图至本地电脑中,今天通过本文给大家介绍python编写adb截图工具的实现源码,感兴...

    mengyuelby11822021-12-23
  • Pythonpython中sys.argv参数用法实例分析

    python中sys.argv参数用法实例分析

    这篇文章主要介绍了python中sys.argv参数用法,实例分析了python中sys.argv参数的功能、定义及使用技巧,需要的朋友可以参考下...

    久月3092020-07-02
  • PythonPython中Selenium模拟JQuery滑动解锁实例

    Python中Selenium模拟JQuery滑动解锁实例

    这篇文章主要介绍了Python中Selenium模拟JQuery滑动解锁实例,具有一定的参考价值,感兴趣的小伙伴们可以参考一下...

    虫师6152020-11-28