目录

1、结论

2、使用和理解

点乘mul(元素乘)

补充点积

2维矩阵乘mm

补充"混合"矩阵乘法 torch.matmul()

3、参考


1、结论

  • @操作符可以执行矩阵乘法操作,类似 torch.mm(), torch.bmm(), torch.matmul() ;
  • *乘法操作可以执行元素乘法,使用方法类似 torch.mul()

2、使用和理解

点乘mul(元素乘)

对应点相乘,x.mul(y) ,即点乘操作,点乘不求和操作,又可以叫作Hadamard product,官方手册参考;点乘再求和就是点积dot,即为卷积。

>>> a = torch.Tensor([[1,2], [3,4], [5, 6]])
>>> a
tensor([[1., 2.],
        [3., 4.],
        [5., 6.]])
>>> a.mul(a)
tensor([[ 1.,  4.],
        [ 9., 16.],
        [25., 36.]])
 
# a*a等价于a.mul(a)

补充点积

手册参考 

>>> torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1]))
tensor(7)

2维矩阵乘mm

矩阵相乘,x.mm(y) , 矩阵大小需满足: (i, n)x(n, j)。

  • torch.mm(mat1, mat2, out=None)
  • 其中mat1(nxm), mat2(mxd), 输出out(nxd)
  • 一般只用来计算两个二维矩阵的矩阵乘法,而且不支持broadcast操作。
  • 官方手册参考
>>> a
tensor([[1., 2.],
        [3., 4.],
        [5., 6.]])
>>> b = a.t()  # 转置
>>> b
tensor([[1., 3., 5.],
        [2., 4., 6.]])
 
>>> a.mm(b)
tensor([[ 5., 11., 17.],
        [11., 25., 39.],
        [17., 39., 61.]])

补充3维带Batch矩阵乘法 torch.bmm()

  • torch.bmm(bmat1, bmat2, out=None)
  • 其中bmat1(B x n x m), bmat2(B x m x d), 输出out(B x n x d)。输出维度就是少了相同的那个维度,其他的拼接
  • 该函数的两个输入必须是3维矩阵且第一维相同(表示Batch维度),不支持broadcast操作。
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])

补充"混合"矩阵乘法 torch.matmul()

  • torch.matmul(input, other, out=None)
  • 输出维度根据情况判定(少相同的那个维度,剩下的拼接)
  • 支持broadcast操作.
>>> # vector x vector
>>> tensor1 = torch.randn(3)
>>> tensor2 = torch.randn(3)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([])
>>> # matrix x vector
>>> tensor1 = torch.randn(3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([3])
>>> # batched matrix x broadcasted vector
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
>>> # batched matrix x batched matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(10, 4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
>>> # batched matrix x broadcasted matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])

3、参考

pytorch】矩阵乘法 mm bmm matmul mul @ * 总结

PyTorch 对应点相乘、矩阵相乘