PyTorch磨刀篇|argmax和argmin函数_一维数组

一、语法格式

格式一(只针对argmax函数):

torch.argmax(input) → LongTensor

功能:

Returns the indices of the maximum value of all elements in the input tensor。

即:返回输入张量中所有元素中最大值对应的索引(按行搜索);如果有多个相同的值,则返回第一次遇到的那个值对应的索引。

举例:

In [28]: r=torch.tensor([[1,2,3,4,5],[6,7,8,9,10],[11,12,13,14,15]])

In [29]: torch.argmax(r)

Out[29]: tensor(14)


格式二:

[1]torch.argmax(input, dim=None, keepdim=False)

功能:

Returns the indices of the maximum values of a tensor across a dimension.

  • input(​Tensor​) – the input tensor.即:输出张量。
  • dim(​int​) – the dimension to reduce. If​None​, the argmax of the flattened input is returned.即:要减少的维数。
  • keepdim(​bool​​) – whether the output tensor has ​​dim​​ retained or not. Ignored if ​​dim=None​​.即:


举例:

In [30]: a = torch.randn(4, 4)

In [31]: a
Out[31]:
tensor([[ 1.4360, 0.6342, -0.5233, 0.4902],
[ 1.1998, -0.8644, 0.5244, 0.2690],
[ 0.0998, -1.5043, 0.1619, -1.4634],
[ 0.0992, -1.0843, -1.3829, 0.5790]])

In [32]: torch.argmax(a)
Out[32]: tensor(0)

In [33]: torch.argmax(a,dim=0)
Out[33]: tensor([0, 0, 1, 3])

In [34]: torch.argmax(a,dim=1)
Out[34]: tensor([0, 0, 2, 3])
  • 对于tensor(0)输出,意义如下:

第0个:

1.4360

第1个:

0.6342

第2个:

-0.5233

第3个:

0.4902

第4个:

1.1998

第5个:

-0.8644

第6个:

0.5244

第7个:

0.2690

第8个:

0.0998

第9个:

-1.5043

第10个:

0.1619

第11个:

-1.4634

第12个:

0.0992

第13个:

-1.0843

第14个:

-1.3829

第15个:

0.5790




  • 对于tensor([0, 0, 1, 3])输出,意义如下:

PyTorch磨刀篇|argmax和argmin函数_数组_02

这时,每一列视为下标从0到3的一个数组。易见,从左到右每一列(数组)中最大值分别为:1.4360、0.6342、0.5244、0.5790,它们对应的一维数组中的下标分别为0、0、1、3,于是得到张量tensor([0, 0, 1, 3])。

  • ​对于tensor([0, 0, 2, 3])输出:

意义就容易理解了。沿水平方向从左向右从上到下看,每一行对应一个数组,下标向左向右依次为0、1、2、3。于是,这4个数组中最大值分别为1.4360、1.1998、0.1619、1.3829,它们对应的一维数组中的下标分别为0、0、2、3,于是得到张量tensor([0, 0, 2, 3])。


功能:

[2]torch.argmin(input, dim=None, keepdim=False) → LongTensor

argmin功能:Returns the indices of the minimum value(s) of the flattened tensor or along a dimension。

理解类似上面argmax函数的第二种格式,相应于dim=0和dim=1,依次返回由最小值对应下标组成的列方向数组与行方向数组组成的张量。