一、torch.argmax()
(1)torch.argmax(input, dim=None, keepdim=False)
返回指定维度最大值的序号;
(2)dim
给定的定义是:the demention to reduce.也就是把dim
这个维度的,变成这个维度的最大值的index。
二、栗子
(1)这个例子,tensor(2, 3, 4)
,因为是dim=1
,即将第二维度去掉,变成tensor(2, 4)
,将每一个3x4数组,变成1x4数组。
如上所示的3×4矩阵,取每一列的最大值对应的下标,a[0]中第一列的最大值的行标为1, 第二列的最大值的行标为2,第三列的最大值行标为0,第4列的最大值行标为1,所以最后输出[1, 2, 0, 1],取每一列的最大值,结果为:
(1)如果改成dim=2
,即将第三维去掉,即取每一行的最大值对应的下标,结果为tensor(2, 3)
。