如何使用 PyTorch torch.max()

在本文中,我们将了解如何使用 PyTorch torch.max() 函数。

正如大家所料,这是一个非常简单的功能,但有趣的是,它的功能比想象的要多。

让我们通过一些简单的例子来看看如何使用这个函数。

注意 :在撰写本文时,使用的 PyTorch 版本是 PyTorch 1.5.0


PyTorch torch.max() – 基本语法

要使用 PyTorch torch.max(),首先导入 torch

import torch

现在,此函数返回 Tensor 中元素的最大值。

PyTorch torch.max() 的默认行为

默认行为是返回单个元素和一个索引,对应于全局最大元素。

max_element = torch.max(input_tensor)

下面是一个例子:

p = torch.randn([2, 3])
print(p)
max_element = torch.max(p)
print(max_element)

输出

tensor([[-0.0665,  2.7976,  0.9753],
        [ 0.0688, -1.0376,  1.4443]])
tensor(2.7976)

事实上,这给了我们 Tensor 中的全局最大元素!

沿维度使用 torch.max()

但是,大家可能希望获得沿特定维度的最大值,作为张量,而不是单个元素。

要指定维度(轴 – 在 numpy 中),还有另一个可选的关键字参数,称为 dim

这代表了我们取最大值的方向。

这将返回一个元组 max_elements 和 max_indices

  • max_elements -> Tensor的所有最大元素。
  • max_indices -> 对应于最大元素的索引。
max_elements, max_indices = torch.max(input_tensor, dim)

这将返回一个 Tensor,它具有沿维度 dim 的最大元素。

现在让我们看一些例子。

p = torch.randn([2, 3])
print(p)

# 沿 dim = 0 (axis = 0) 获取最大值
max_elements, max_idxs = torch.max(p, dim=0)
print(max_elements)
print(max_idxs)

输出如下所示

tensor([[-0.0665,  2.7976,  0.9753],
        [ 0.0688, -1.0376,  1.4443]])
tensor([0.0688, 2.7976, 1.4443])
tensor([1, 0, 1])

如你所见,我们找到了沿维度 0 的最大值(沿列的最大值)。

此外,我们得到与元素对应的索引。 例如,0.0688 在第 0 列的索引为 1

同样,如果要沿行查找最大值,请使用 dim=1

# 沿 dim = 1(axis = 1)获取最大值
max_elements, max_idxs = torch.max(p, dim=1)
print(max_elements)
print(max_idxs)

输出如下所示

tensor([2.7976, 1.4443])
tensor([1, 2])

实际上,我们得到了沿行的最大元素,以及相应的索引(沿行)。


使用 torch.max() 进行比较

我们还可以使用 torch.max() 来获取两个 Tensor 之间的最大值。

output_tensor = torch.max(a, b)

在这里,a 和 b 必须具有相同的维度,或者必须是“可广播的” Tensor。

这是一个比较两个具有相同维度的Tensor的简单示例。

p = torch.randn([2, 3])
q = torch.randn([2, 3])

print("p =", p)
print("q =",q)

# 比较 p 和 q 的元素并得到最大值
max_elements = torch.max(p, q)

print(max_elements)

结果输出如下所示

p = tensor([[-0.0665,  2.7976,  0.9753],
        [ 0.0688, -1.0376,  1.4443]])
q = tensor([[-0.0678,  0.2042,  0.8254],
        [-0.1530,  0.0581, -0.3694]])
tensor([[-0.0665,  2.7976,  0.9753],
        [ 0.0688,  0.0581,  1.4443]])

实际上,我们得到的输出 Tensor 在 p 和 q 之间具有最大元素。


总结

在本文中,我们学习了如何使用 torch.max() 函数来找出 Tensor 的最大元素。

我们还使用这个函数来比较两个 Tensor 并获得其中的最大值。

免责声明:
1.本站所有内容由本站原创、网络转载、消息撰写、网友投稿等几部分组成。
2.本站原创文字内容若未经特别声明,则遵循协议CC3.0共享协议,转载请务必注明原文链接。
3.本站部分来源于网络转载的文章信息是出于传递更多信息之目的,不意味着赞同其观点。
4.本站所有源码与软件均为原作者提供,仅供学习和研究使用。
5.如您对本网站的相关版权有任何异议,或者认为侵犯了您的合法权益,请及时通知我们处理。
火焰兔 » 如何使用 PyTorch torch.max()