脚本之家,脚本语言编程技术及教程分享平台!
分类导航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服务器之家 - 脚本之家 - Python - pytorch使用nn.Moudle实现逻辑回归

pytorch使用nn.Moudle实现逻辑回归

2022-07-30 15:25zeroooo000oo Python

这篇文章主要为大家详细介绍了pytorch使用nn.Moudle实现逻辑回归,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

本文实例为大家分享了pytorch使用nn.Moudle实现逻辑回归的具体代码,供大家参考,具体内容如下

内容

pytorch使用nn.Moudle实现逻辑回归

问题

loss下降不明显

解决方法

?
1
2
3
4
5
6
7
8
9
10
#源代码 out的数据接收方式
     if torch.cuda.is_available():
         x_data=Variable(x).cuda()
         y_data=Variable(y).cuda()
     else:
         x_data=Variable(x)
         y_data=Variable(y)
    
    out=logistic_model(x_data)  #根据逻辑回归模型拟合出的y值
    loss=criterion(out.squeeze(),y_data)  #计算损失函数
?
1
2
3
4
5
6
7
8
9
10
11
#源代码 out的数据有拼装数据直接输入
#     if torch.cuda.is_available():
#         x_data=Variable(x).cuda()
#         y_data=Variable(y).cuda()
#     else:
#         x_data=Variable(x)
#         y_data=Variable(y)
    
    out=logistic_model(x_data)  #根据逻辑回归模型拟合出的y值
    loss=criterion(out.squeeze(),y_data)  #计算损失函数
    print_loss=loss.data.item()  #得出损失函数值

源代码

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import torch
from torch import nn
from torch.autograd import Variable
import matplotlib.pyplot as plt
import numpy as np
 
#生成数据
sample_nums = 100
mean_value = 1.7
bias = 1
n_data = torch.ones(sample_nums, 2)
x0 = torch.normal(mean_value * n_data, 1) + bias      # 类别0 数据 shape=(100, 2)
y0 = torch.zeros(sample_nums)                         # 类别0 标签 shape=(100, 1)
x1 = torch.normal(-mean_value * n_data, 1) + bias     # 类别1 数据 shape=(100, 2)
y1 = torch.ones(sample_nums)                          # 类别1 标签 shape=(100, 1)
x_data = torch.cat((x0, x1), 0)  #按维数0行拼接
y_data = torch.cat((y0, y1), 0)
 
#画图
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=y.data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.show()
 
# 利用torch.nn实现逻辑回归
class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.lr = nn.Linear(2, 1)
        self.sm = nn.Sigmoid()
 
    def forward(self, x):
        x = self.lr(x)
        x = self.sm(x)
        return x
    
logistic_model = LogisticRegression()
# if torch.cuda.is_available():
#     logistic_model.cuda()
 
#loss函数和优化
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(logistic_model.parameters(), lr=0.01, momentum=0.9)
#开始训练
#训练10000次
for epoch in range(10000):
#     if torch.cuda.is_available():
#         x_data=Variable(x).cuda()
#         y_data=Variable(y).cuda()
#     else:
#         x_data=Variable(x)
#         y_data=Variable(y)
    
    out=logistic_model(x_data)  #根据逻辑回归模型拟合出的y值
    loss=criterion(out.squeeze(),y_data)  #计算损失函数
    print_loss=loss.data.item()  #得出损失函数值
    #反向传播
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    mask=out.ge(0.5).float()  #以0.5为阈值进行分类
    correct=(mask==y_data).sum().squeeze()  #计算正确预测的样本个数
    acc=correct.item()/x_data.size(0)  #计算精度
    #每隔20轮打印一下当前的误差和精度
    if (epoch+1)%100==0:
        print('*'*10)
        print('epoch {}'.format(epoch+1))  #误差
        print('loss is {:.4f}'.format(print_loss))
        print('acc is {:.4f}'.format(acc))  #精度
        
        
w0, w1 = logistic_model.lr.weight[0]
w0 = float(w0.item())
w1 = float(w1.item())
b = float(logistic_model.lr.bias.item())
plot_x = np.arange(-7, 7, 0.1)
plot_y = (-w0 * plot_x - b) / w1
plt.xlim(-5, 7)
plt.ylim(-7, 7)
plt.scatter(x.data.numpy()[:, 0], x.data.numpy()[:, 1], c=logistic_model(x_data)[:,0].cpu().data.numpy(), s=100, lw=0, cmap='RdYlGn')
plt.plot(plot_x, plot_y)
plt.show()

输出结果

pytorch使用nn.Moudle实现逻辑回归

pytorch使用nn.Moudle实现逻辑回归

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持服务器之家。

原文链接:https://blog.csdn.net/zeroooo000oo/article/details/107885489

延伸 · 阅读

精彩推荐
  • PythonPython基于动态规划算法解决01背包问题实例

    Python基于动态规划算法解决01背包问题实例

    这篇文章主要介绍了Python基于动态规划算法解决01背包问题,结合实例形式分析了Python动态规划算法解决01背包问题的原理与具体实现技巧,需要的朋友可以参...

    littlethunder5702020-12-22
  • PythonPython、 Pycharm、Django安装详细教程(图文)

    Python、 Pycharm、Django安装详细教程(图文)

    这篇文章主要介绍了Python、 Pycharm、Django安装详细教程,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友...

    diligentkong10092021-06-15
  • PythonPython 中字符串拼接的多种方法

    Python 中字符串拼接的多种方法

    本篇文章给大家介绍python中字符串拼接的多种方法,非常不错,具有一定的参考借鉴价值,需要的朋友参考下吧...

    木头释然7462021-03-24
  • Python在类Unix系统上开始Python3编程入门

    在类Unix系统上开始Python3编程入门

    这篇文章主要介绍了在类Unix系统上开始Python3编程入门,讲解了最基础最直观的利用Print函数进行各种输出的方法,需要的朋友可以参考下...

    一线涯2832020-07-30
  • Pythonpython发送HTTP请求的方法小结

    python发送HTTP请求的方法小结

    这篇文章主要介绍了python发送HTTP请求的方法,实例总结了GET、HEAD与POST方式发送http请求的相关技巧,需要的朋友可以参考下...

    鉴客4462020-07-20
  • PythonPython缩进和冒号详解

    Python缩进和冒号详解

    下面小编就为大家带来一篇Python缩进和冒号详解。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧...

    脚本之家22302020-08-24
  • PythonPython基础之标准库和常用的第三方库案例教程

    Python基础之标准库和常用的第三方库案例教程

    这篇文章主要介绍了Python基础之标准库和常用的第三方库案例教程,本篇文章通过简要的案例,讲解了该项技术的了解与使用,以下就是详细内容,需要的朋友可...

    Holidaylovesam4382021-12-14
  • PythonPython教程使用Chord包实现炫彩弦图示例

    Python教程使用Chord包实现炫彩弦图示例

    在可视化中,有时候会使用到弦图(Chord Diagram)来表示事物之间关系,本篇文章教大家如何使用Chord包实现炫彩弦图,有需要的朋友可以借鉴参考下,希望大...

    麦片加奶不加糖11012022-01-17