如何使用 PyTorch torch.max()
在本文中,我们将了解如何使用 PyTorch torch.max()
函数。
正如大家所料,这是一个非常简单的功能,但有趣的是,它的功能比想象的要多。
让我们通过一些简单的例子来看看如何使用这个函数。
注意
:在撰写本文时,使用的 PyTorch 版本是 PyTorch 1.5.0
PyTorch torch.max() – 基本语法
要使用 PyTorch torch.max()
,首先导入 torch。
import torch
现在,此函数返回 Tensor 中元素的最大值。
PyTorch torch.max() 的默认行为
默认行为是返回单个元素和一个索引,对应于全局最大元素。
max_element = torch.max(input_tensor)
下面是一个例子:
p = torch.randn([2, 3])
print(p)
max_element = torch.max(p)
print(max_element)
输出
tensor([[-0.0665, 2.7976, 0.9753],
[ 0.0688, -1.0376, 1.4443]])
tensor(2.7976)
事实上,这给了我们 Tensor 中的全局最大元素!
沿维度使用 torch.max()
但是,大家可能希望获得沿特定维度的最大值,作为张量,而不是单个元素。
要指定维度(轴 – 在 numpy 中),还有另一个可选的关键字参数,称为 dim
这代表了我们取最大值的方向。
这将返回一个元组 max_elements
和 max_indices
。
- max_elements -> Tensor的所有最大元素。
- max_indices -> 对应于最大元素的索引。
max_elements, max_indices = torch.max(input_tensor, dim)
这将返回一个 Tensor,它具有沿维度 dim
的最大元素。
现在让我们看一些例子。
p = torch.randn([2, 3])
print(p)
# 沿 dim = 0 (axis = 0) 获取最大值
max_elements, max_idxs = torch.max(p, dim=0)
print(max_elements)
print(max_idxs)
输出如下所示
tensor([[-0.0665, 2.7976, 0.9753],
[ 0.0688, -1.0376, 1.4443]])
tensor([0.0688, 2.7976, 1.4443])
tensor([1, 0, 1])
如你所见,我们找到了沿维度 0 的最大值(沿列的最大值)。
此外,我们得到与元素对应的索引。 例如,0.0688 在第 0 列的索引为 1
同样,如果要沿行查找最大值,请使用 dim=1
。
# 沿 dim = 1(axis = 1)获取最大值
max_elements, max_idxs = torch.max(p, dim=1)
print(max_elements)
print(max_idxs)
输出如下所示
tensor([2.7976, 1.4443])
tensor([1, 2])
实际上,我们得到了沿行的最大元素,以及相应的索引(沿行)。
使用 torch.max() 进行比较
我们还可以使用 torch.max() 来获取两个 Tensor 之间的最大值。
output_tensor = torch.max(a, b)
在这里,a 和 b 必须具有相同的维度,或者必须是“可广播的” Tensor。
这是一个比较两个具有相同维度的Tensor的简单示例。
p = torch.randn([2, 3])
q = torch.randn([2, 3])
print("p =", p)
print("q =",q)
# 比较 p 和 q 的元素并得到最大值
max_elements = torch.max(p, q)
print(max_elements)
结果输出如下所示
p = tensor([[-0.0665, 2.7976, 0.9753],
[ 0.0688, -1.0376, 1.4443]])
q = tensor([[-0.0678, 0.2042, 0.8254],
[-0.1530, 0.0581, -0.3694]])
tensor([[-0.0665, 2.7976, 0.9753],
[ 0.0688, 0.0581, 1.4443]])
实际上,我们得到的输出 Tensor 在 p 和 q 之间具有最大元素。
总结
在本文中,我们学习了如何使用 torch.max()
函数来找出 Tensor 的最大元素。
我们还使用这个函数来比较两个 Tensor 并获得其中的最大值。