脚本之家,脚本语言编程技术及教程分享平台!
分类导航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服务器之家 - 脚本之家 - Python - 基于Pytorch的神经网络之Regression的实现

基于Pytorch的神经网络之Regression的实现

2022-11-02 09:34ZDDWLIG Python

本文主要介绍了基于Pytorch的神经网络之Regression的实现,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

1.引言

我们之前已经介绍了神经网络的基本知识,神经网络的主要作用就是预测与分类,现在让我们来搭建第一个用于拟合回归的神经网络吧。

 

2.神经网络搭建

 

2.1 准备工作

要搭建拟合神经网络并绘图我们需要使用python的几个库。

?
1
2
3
4
5
6
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
 
x = torch.unsqueeze(torch.linspace(-5, 5, 100), dim=1)
y = x.pow(3) + 0.2 * torch.rand(x.size())

 既然是拟合,我们当然需要一些数据啦,我选取了在区间 基于Pytorch的神经网络之Regression的实现 内的100个等间距点,并将它们排列成三次函数的图像。

 

2.2 搭建网络

我们定义一个类,继承了封装在torch中的一个模块,我们先分别确定输入层、隐藏层、输出层的神经元数目,继承父类后再使用torch中的.nn.Linear()函数进行输入层到隐藏层的线性变换,隐藏层也进行线性变换后传入输出层predict,接下来定义前向传播的函数forward(),使用relu()作为激活函数,最后输出predict()结果即可。

?
1
2
3
4
5
6
7
8
9
10
11
12
class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)
    def forward(self, x):
        x = F.relu(self.hidden(x))
        return self.predict(x)
net = Net(1, 20, 1)
print(net)
optimizer = torch.optim.Adam(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss()

网络的框架搭建完了,然后我们传入三层对应的神经元数目再定义优化器,这里我选取了Adam而随机梯度下降(SGD),因为它是SGD的优化版本,效果在大部分情况下比SGD好,我们要传入这个神经网络的参数(parameters),并定义学习率(learning rate),学习率通常选取小于1的数,需要凭借经验并不断调试。最后我们选取均方差法(MSE)来计算损失(loss)。

 

2.3 训练网络

接下来我们要对我们搭建好的神经网络进行训练,我训练了2000轮(epoch),先更新结果prediction再计算损失,接着清零梯度,然后根据loss反向传播(backward),最后进行优化,找出最优的拟合曲线。

?
1
2
3
4
5
6
for t in range(2000):
    prediction = net(x)
    loss = loss_func(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

 

3.效果

使用如下绘图的代码展示效果。

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
for t in range(2000):
    prediction = net(x)
    loss = loss_func(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 5 == 0:
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy(), s=10)
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=2)
        plt.text(2, -100, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 10, 'color': 'red'})
        plt.pause(0.1)
plt.ioff()
plt.show()

基于Pytorch的神经网络之Regression的实现

基于Pytorch的神经网络之Regression的实现

最后的结果: 

基于Pytorch的神经网络之Regression的实现

 

4. 完整代码

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
 
x = torch.unsqueeze(torch.linspace(-5, 5, 100), dim=1)
y = x.pow(3) + 0.2 * torch.rand(x.size())
class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)
    def forward(self, x):
        x = F.relu(self.hidden(x))
        return self.predict(x)
net = Net(1, 20, 1)
print(net)
optimizer = torch.optim.Adam(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss()
plt.ion()
for t in range(2000):
    prediction = net(x)
    loss = loss_func(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 5 == 0:
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy(), s=10)
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=2)
        plt.text(2, -100, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 10, 'color': 'red'})
        plt.pause(0.1)
plt.ioff()
plt.show()

到此这篇关于基于Pytorch的神经网络之Regression的实现的文章就介绍到这了,更多相关 Pytorch Regression内容请搜索服务器之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持服务器之家!

原文链接:https://blog.csdn.net/ZDDWLIG/article/details/123488056

延伸 · 阅读

精彩推荐
  • PythonPython遍历文件夹 处理json文件的方法

    Python遍历文件夹 处理json文件的方法

    今天小编就为大家分享一篇Python遍历文件夹 处理json文件的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...

    Norton-Linux内核研究10712021-05-20
  • Python在Python的循环体中使用else语句的方法

    在Python的循环体中使用else语句的方法

    这篇文章主要介绍了在Python的循环体中使用else语句的方法,else语句的使用在各种语言的学习当中均为基本功、本文中主要介绍其在for循环中的应用,需要的...

    Shahriar Tajbakhsh3122020-05-26
  • Pythonpython pygame实现方向键控制小球

    python pygame实现方向键控制小球

    这篇文章主要为大家详细介绍了python pygame实现方向键控制小球,具有一定的参考价值,感兴趣的小伙伴们可以参考一下...

    努力学python5602021-06-27
  • PythonPython使用sorted对字典的key或value排序

    Python使用sorted对字典的key或value排序

    这篇文章主要介绍了Python使用sorted对字典的key或value排序,小编觉得挺不错的,现在分享给大家,也给大家做个参考。一起跟随小编过来看看吧...

    Polar Snow Documentation8462021-04-19
  • PythonPython文件操作之二进制文件详解

    Python文件操作之二进制文件详解

    下面小编就为大家带来一篇使用Python文件操作之二进制文件。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧...

    Tester_Cheng5072022-01-12
  • Python详解Python prometheus_client使用方式

    详解Python prometheus_client使用方式

    本文主要介绍了Python prometheus_client使用方式,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下...

    JoJo938132022-09-15
  • Pythonpython 利用栈和队列模拟递归的过程

    python 利用栈和队列模拟递归的过程

    这篇文章主要介绍了python 利用栈和队列模拟递归的过程,文中并通过两段代码给大家介绍了下递归和非递归的区别,需要的朋友可以参考下...

    渔单渠9632021-02-26
  • Pythonpython看某个模块的版本方法

    python看某个模块的版本方法

    今天小编就为大家分享一篇python看某个模块的版本方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...

    Bovinitwo4542021-04-08