脚本之家,脚本语言编程技术及教程分享平台!
分类导航

Python|VBS|Ruby|Lua|perl|VBA|Golang|PowerShell|Erlang|autoit|Dos|bat|

服务器之家 - 脚本之家 - Python - pytorch中torch.topk()函数的快速理解

pytorch中torch.topk()函数的快速理解

2022-10-14 10:52Neo很努力 Python

我们在做分类算法时,时常见到@acc1和@acc5的情况,@acc1比较容易实现,但是一直苦于@acc5算法的实现,在此为大家提供一种@topk的实现方法,这篇文章主要给大家介绍了关于pytorch中torch.topk()函数的快速理解,需要的朋友可以参考下

函数作用:

pytorch中torch.topk()函数的快速理解

pytorch中torch.topk()函数的快速理解

该函数的作用即按字面意思理解,topk:取数组的前k个元素进行排序。

通常该函数返回2个值,第一个值为排序的数组,第二个值为该数组中获取到的元素在原数组中的位置标号。

举个栗子:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import numpy as np
import torch
import torch.utils.data.dataset as Dataset
from torch.utils.data import Dataset,DataLoader
 
####################准备一个数组#########################
tensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10],
             [3,4,5,1,1,1,1,1,1,1,1],
             [7,8,9,1,1,1,1,1,1,1,1],
             [1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32)
 
####################打印这个原数组#########################
print('tensor1:')
print(tensor1)
 
#################使用torch.topk()这个函数##################
print('使用torch.topk()这个函数得到:')
 
'''k=3代表从原数组中取得3个元素,dim=1表示从原数组中的第一维获取元素
(在本例中是分别从[10,1,2,1,1,1,1,1,1,1,10]、[3,4,5,1,1,1,1,1,1,1,1]、
  [7,8,9,1,1,1,1,1,1,1,1]、[1,4,7,1,1,1,1,1,1,1,1]这四个数组中获取3个元素)
其中largest=True表示从大到小取元素'''
print(torch.topk(tensor1, k=3, dim=1, largest=True))
 
 
#################打印这个函数第一个返回值####################
print('函数第一个返回值topk[0]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])
 
#################打印这个函数第二个返回值####################
print('函数第二个返回值topk[1]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])
'''
 
#######################运行结果##########################
tensor1:
tensor([[10.,  1.,  2.,  1.,  1.,  1.,  1.,  1.,  1.,  1., 10.],
        [ 3.,  4.,  5.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 7.,  8.,  9.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 1.,  4.,  7.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]])
 
使用torch.topk()这个函数得到:
 
'得到的values是原数组dim=1的四组从大到小的三个元素值;
得到的indices是获取到的元素值在原数组dim=1中的位置。'
 
 
torch.return_types.topk(
values=tensor([[10., 10.,  2.],
        [ 5.,  4.,  3.],
        [ 9.,  8.,  7.],
        [ 7.,  4.,  1.]]),
indices=tensor([[ 0, 10,  2],
        [ 2,  1,  0],
        [ 2,  1,  0],
        [ 2,  1,  0]]))
 
函数第一个返回值topk[0]如下
tensor([[10., 10.,  2.],
        [ 5.,  4.,  3.],
        [ 9.,  8.,  7.],
        [ 7.,  4.,  1.]])
        
函数第二个返回值topk[1]如下
tensor([[ 0, 10,  2],
        [ 2,  1,  0],
        [ 2,  1,  0],
        [ 2,  1,  0]])
'''

该函数功能经常用来获取张量或者数组中最大或者最小的元素以及索引位置,是一个经常用到的基本函数。

实例演示

任务一:

取top1(最大值):

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.01450.4053],
        [ 0.72651.41641.34431.20351.8823],
        [-0.44510.16731.2590, -2.07571.7255],
        [ 0.20210.30410.13830.3849, -1.6311]])
print(pred)
values, indices = pred.topk(1, dim=0, largest=True, sorted=True)
print(indices)
print(values)
# 用max得到的结果,设置keepdim为True,避免降维。因为topk函数返回的index不降维,shape和输入一致。
_, indices_max = pred.max(dim=0, keepdim=True)
print(indices_max)
print(indices_max == indices)
输出:
tensor([[-0.5816, -0.3873, -1.0215, -1.01450.4053],
        [ 0.72651.41641.34431.20351.8823],
        [-0.44510.16731.2590, -2.07571.7255],
        [ 0.20210.30410.13830.3849, -1.6311]])
tensor([[1, 1, 1, 1, 1]])
tensor([[0.7265, 1.4164, 1.3443, 1.2035, 1.8823]])
tensor([[1, 1, 1, 1, 1]])
tensor([[True, True, True, True, True]])

任务二:

按行取出topk,将小于topk的置为inf:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
pred = torch.tensor([[-0.5816, -0.3873, -1.0215, -1.01450.4053],
        [ 0.72651.41641.34431.20351.8823],
        [-0.44510.16731.2590, -2.07571.7255],
        [ 0.20210.30410.13830.3849, -1.6311]])
print(pred)
top_k = 2  # 按行求出每一行的最大的前两个值
filter_value=-float('Inf')
indices_to_remove = pred < torch.topk(pred, top_k)[0][..., -1, None]
print(indices_to_remove)
pred[indices_to_remove] = filter_value  # 对于topk之外的其他元素的logits值设为负无穷
print(pred)
 
输出:
tensor([[-0.5816, -0.3873, -1.0215, -1.01450.4053],
        [ 0.72651.41641.34431.20351.8823],
        [-0.44510.16731.2590, -2.07571.7255],
        [ 0.20210.30410.13830.3849, -1.6311]])
tensor([[4],
        [4],
        [4],
        [3]])
tensor([[0.4053],
        [1.8823],
        [1.7255],
        [0.3849]])
tensor([[ True, FalseTrueTrue, False],
        [ True, FalseTrueTrue, False],
        [ TrueTrue, FalseTrue, False],
        [ True, FalseTrue, FalseTrue]])
tensor([[   -inf, -0.3873,    -inf,    -inf,  0.4053],
        [   -inf,  1.4164,    -inf,    -inf,  1.8823],
        [   -inf,    -inf,  1.2590,    -inf,  1.7255],
        [   -inf,  0.3041,    -inf,  0.3849,    -inf]])

任务三:

?
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import numpy as np
import torch
import torch.utils.data.dataset as Dataset
from torch.utils.data import Dataset,DataLoader
tensor1=torch.tensor([[10,1,2,1,1,1,1,1,1,1,10],
             [3,4,5,1,1,1,1,1,1,1,1],
             [7,8,9,1,1,1,1,1,1,1,1],
             [1,4,7,1,1,1,1,1,1,1,1]],dtype=torch.float32)
# tensor2=torch.tensor([[3,2,1],
#                       [6,5,4],
#                       [1,4,7],
#                       [9,8,7]],dtype=torch.float32)
#
print('tensor1:')
print(tensor1)
print('直接输出topk,会得到两个东西,我们需要的是第二个indices')
print(torch.topk(tensor1, k=3, dim=1, largest=True))
print('topk[0]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[0])
print('topk[1]如下')
print(torch.topk(tensor1, k=3, dim=1, largest=True)[1])
'''
tensor1:
tensor([[10.,  1.,  2.,  1.,  1.,  1.,  1.,  1.,  1.,  1., 10.],
        [ 3.,  4.,  5.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 7.,  8.,  9.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 1.,  4.,  7.,  1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.]])
直接输出topk,会得到两个东西,我们需要的是第二个indices
torch.return_types.topk(
values=tensor([[10., 10.,  2.],
        [ 5.,  4.,  3.],
        [ 9.,  8.,  7.],
        [ 7.,  4.,  1.]]),
indices=tensor([[ 0, 10,  2],
        [ 2,  1,  0],
        [ 2,  1,  0],
        [ 2,  1,  0]]))
topk[0]如下
tensor([[10., 10.,  2.],
        [ 5.,  4.,  3.],
        [ 9.,  8.,  7.],
        [ 7.,  4.,  1.]])
topk[1]如下
tensor([[ 0, 10,  2],
        [ 2,  1,  0],
        [ 2,  1,  0],
        [ 2,  1,  0]])
'''

总结

到此这篇关于pytorch中torch.topk()函数快速理解的文章就介绍到这了,更多相关pytorch torch.topk()函数理解内容请搜索服务器之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持服务器之家!

原文链接:https://blog.csdn.net/qq_45193872/article/details/119878804

延伸 · 阅读

精彩推荐
  • PythonDjango Admin 实现外键过滤的方法

    Django Admin 实现外键过滤的方法

    下面小编就为大家带来一篇Django Admin 实现外键过滤的方法。小编觉得挺不错的,现在就分享给大家,也给大家做个参考。一起跟随小编过来看看吧...

    ishouyong6972020-12-10
  • PythonPython 利用切片从列表中取出一部分使用的方法

    Python 利用切片从列表中取出一部分使用的方法

    今天小编就为大家分享一篇Python 利用切片从列表中取出一部分使用的方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...

    a41117801012542021-05-25
  • Pythonpandas读取excel时获取读取进度的实现

    pandas读取excel时获取读取进度的实现

    这篇文章主要介绍了pandas读取excel时获取读取进度的实现,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋...

    THUNDER4742021-10-09
  • Python巧妙使用Python装饰器处理if...elif...else

    巧妙使用Python装饰器处理if...elif...else

    大家好,今天在 Github 阅读 EdgeDB[1] 的代码,发现它在处理大量if…elif…else的时候,巧妙地使用了装饰器,方法设计精巧,分享给大家一下,欢迎收藏学习...

    Python学习与数据挖掘11902022-03-03
  • Pythonpython学习笔记之列表(list)与元组(tuple)详解

    python学习笔记之列表(list)与元组(tuple)详解

    List(列表)是Python中使用最频繁的数据类型,而元组是另一个数据类型,类似于List(列表)。这篇文章主要给大家介绍了python学习笔记之列表(list)与元组...

    kaka_4092020-12-20
  • PythonPython实现定时自动关闭的tkinter窗口方法

    Python实现定时自动关闭的tkinter窗口方法

    今天小编就为大家分享一篇Python实现定时自动关闭的tkinter窗口方法,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧...

    Python_小屋7302021-05-28
  • Pythonpytorch 预训练模型读取修改相关参数的填坑问题

    pytorch 预训练模型读取修改相关参数的填坑问题

    这篇文章主要介绍了pytorch 预训练模型读取修改相关参数的填坑问题,具有很好的参考价值,希望对大家有所帮助。如有错误或未考虑完全的地方,望不吝...

    DRACO于6532021-11-23
  • Pythonpython读写数据读写csv文件(pandas用法)

    python读写数据读写csv文件(pandas用法)

    这篇文章主要介绍了python读写数据读写csv文件(pandas用法),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋...

    小朱小朱绝不服输10082021-08-13