目录
1、结论
2、使用和理解
点乘mul(元素乘)
补充点积
2维矩阵乘mm
补充"混合"矩阵乘法 torch.matmul()
3、参考
1、结论
- @操作符可以执行矩阵乘法操作,类似 torch.mm(), torch.bmm(), torch.matmul() ;
- *乘法操作可以执行元素乘法,使用方法类似 torch.mul()
2、使用和理解
点乘mul(元素乘)
对应点相乘,x.mul(y) ,即点乘操作,点乘不求和操作,又可以叫作Hadamard product,官方手册参考;点乘再求和就是点积dot,即为卷积。
>>> a = torch.Tensor([[1,2], [3,4], [5, 6]])
>>> a
tensor([[1., 2.],
[3., 4.],
[5., 6.]])
>>> a.mul(a)
tensor([[ 1., 4.],
[ 9., 16.],
[25., 36.]])
# a*a等价于a.mul(a)
补充点积
手册参考
>>> torch.dot(torch.tensor([2, 3]), torch.tensor([2, 1]))
tensor(7)
2维矩阵乘mm
矩阵相乘,x.mm(y) , 矩阵大小需满足: (i, n)x(n, j)。
torch.mm(mat1, mat2, out=None)
- 其中mat1(nxm), mat2(mxd), 输出out(nxd)
- 一般只用来计算两个二维矩阵的矩阵乘法,而且不支持broadcast操作。
- 官方手册参考
>>> a
tensor([[1., 2.],
[3., 4.],
[5., 6.]])
>>> b = a.t() # 转置
>>> b
tensor([[1., 3., 5.],
[2., 4., 6.]])
>>> a.mm(b)
tensor([[ 5., 11., 17.],
[11., 25., 39.],
[17., 39., 61.]])
补充3维带Batch矩阵乘法 torch.bmm()
torch.bmm(bmat1, bmat2, out=None)
- 其中bmat1(B x n x m), bmat2(B x m x d), 输出out(B x n x d)。输出维度就是少了相同的那个维度,其他的拼接
- 该函数的两个输入必须是3维矩阵且第一维相同(表示Batch维度),不支持broadcast操作。
>>> input = torch.randn(10, 3, 4)
>>> mat2 = torch.randn(10, 4, 5)
>>> res = torch.bmm(input, mat2)
>>> res.size()
torch.Size([10, 3, 5])
补充"混合"矩阵乘法 torch.matmul()
torch.matmul(input, other, out=None)
输出维度根据情况判定(少相同的那个维度,剩下的拼接)
- 支持broadcast操作.
>>> # vector x vector
>>> tensor1 = torch.randn(3)
>>> tensor2 = torch.randn(3)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([])
>>> # matrix x vector
>>> tensor1 = torch.randn(3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([3])
>>> # batched matrix x broadcasted vector
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3])
>>> # batched matrix x batched matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(10, 4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
>>> # batched matrix x broadcasted matrix
>>> tensor1 = torch.randn(10, 3, 4)
>>> tensor2 = torch.randn(4, 5)
>>> torch.matmul(tensor1, tensor2).size()
torch.Size([10, 3, 5])
3、参考
pytorch】矩阵乘法 mm bmm matmul mul @ * 总结
PyTorch 对应点相乘、矩阵相乘