一、torch.argmax()

(1)​​torch.argmax(input, dim=None, keepdim=False)​​​返回指定维度最大值的序号;
(2)​​​dim​​​给定的定义是:the demention to reduce.也就是把​​dim​​这个维度的,变成这个维度的最大值的index。

二、栗子

# -*- coding: utf-8 -*-
"""
Created on Fri Jan 7 15:05:09 2022

@author: 86493
"""
import torch
a=torch.tensor([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],

[
[-1, 7, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]
]])
b=torch.argmax(a,dim=1)
print(a)
print(a.shape)
print(b)

(1)这个例子,​​tensor(2, 3, 4)​​​,因为是​​dim=1​​​,即将第二维度去掉,变成​​tensor(2, 4)​​,将每一个3x4数组,变成1x4数组。

[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]

如上所示的3×4矩阵,取每一列的最大值对应的下标,a[0]中第一列的最大值的行标为1, 第二列的最大值的行标为2,第三列的最大值行标为0,第4列的最大值行标为1,所以最后输出[1, 2, 0, 1],取每一列的最大值,结果为:

tensor([[[ 1,  5,  5,  2],
[ 9, -6, 2, 8],
[-3, 7, -9, 1]],

[[-1, 7, -5, 2],
[ 9, 6, 2, 8],
[ 3, 7, 9, 1]]])
torch.Size([2, 3, 4])
tensor([[1, 2, 0, 1],
[1, 0, 2, 1]])

(1)如果改成​​dim=2​​,即将第三维去掉,即取每一行的最大值对应的下标,结果为​​tensor(2, 3)​​。

import torch
a=torch.tensor([
[
[1, 5, 5, 2],
[9, -6, 2, 8],
[-3, 7, -9, 1]
],

[
[-1, 7, -5, 2],
[9, 6, 2, 8],
[3, 7, 9, 1]
]])
b=torch.argmax(a,dim=2)
print(b)
print(a.shape)
"""
tensor([[2, 0, 1],
[1, 0, 2]])
torch.Size([2, 3, 4])
"""