一、简介

广播机制实际上是在运算过程中,去处理两个形状不同向量的一种手段。

Pytorch中的广播机制和numpy中的广播机制一样, 因为都是数组的广播机制。神经网络模型构建过程中,很多时候使用到这里的广播机制。

二、​两个张量进行广播机制的条件


1 两个张量都至少有一个维度

2  按从右往左顺序看两个张量的每一个维度,每个对应的两个维度都匹配

什么情况下算匹配上了?满足下面的条件就可以:


     a. 这两个维度的大小相等;

     b. 某个维度上,一个张量有,一个张量没有;

     c. 某个维度上,一个张量有,一个张量也有——但大小是1;


举例:

x=torch.empty(5,3,4,1)
y=torch.empty( 3,1,1)

上面代码中,首先将两个张量维度向右靠齐,从右往左看


  • 两个张量第四维大小相等,都为1,满足上面条件a;
  • 第三个维度大小不相等,但第二个张量第三维大小为1,满足上面条件b;
  • 第二个维度大小相等都为3,满足上面条件a;
  • 第一个维度第一个张量有,第二个张量没有,满足上面条件b。

因此,两个张量每个维度都符合上面广播条件,因此可以进行广播。


两个张量维度从右往左看,如果出现两个张量在某个维度位置上面,维度大小不相等,且两个维度大小没有一个是1,那么这两个张量一定不能进行广播。


三、当两个张量满足可广播条件,怎么进行广播

举例:

x=torch.empty(5,3,4,1)
y=torch.empty( 3,1,1)

如上面代码所示:

 [第一步]将上面条件b的类型变成条件c的类型,也即是把第二个张量在缺失维度的位置上新增一个维度,维度大小为1,新增的维度如下面所示。

统一前:

x=torch.empty(5,3,4,1)

y=torch.empty( 3,1,1)

统一后:

x=torch.empty(5,3,4,1)

y=torch.empty(1,3,1,1)

 [第二步]x、y对应维度不等的位置,把size为1的维度会被广播得和对应维度一样大,比如y中0维的1会变成5,y中2维的1会变成4,最后两个张量的维度大小变成一样,然后再进行张量运算,转变的维度如下所示。

​统一前:

x=torch.empty(5,3,4,1)

y=torch.empty(1,3,1,1)

统一后:

x=torch.empty(5,3,4,1)

y=torch.empty(5,3,4,1)

[第三步]相量的相加运算:

In [66]: x=torch.empty(5,3,4,1)


In [67]: y=torch.empty(5,3,4,1)


In [68]: x+y

Out[68]: 

tensor([[[[ 2.5639e-09],

          [ 2.7450e-06],

          [ 1.1545e+19],

          [ 1.0013e+33]],


         [[ 1.8361e+25],

          [ 1.5128e+04],

          [ 7.0062e+22],

          [ 4.7423e+30]],


         [[ 4.7393e+30],

          [ 9.5478e-01],

          [ 4.4377e+27],

          [ 1.7975e+19]]],



        [[[ 4.6894e+27],

          [ 7.9463e+08],

          [ 3.2604e-12],

          [ 1.7743e+28]],


         [[ 3.4088e-19],

          [ 5.9682e-02],

          [ 7.0374e+22],

          [ 3.8946e+21]],


         [[ 4.4650e+30],

          [ 7.0975e+22],

          [ 7.9309e+34],

          [ 7.9439e+08]]],



        [[[ 2.5672e-09],

          [ 7.3113e+34],

          [ 1.8936e+23],

          [ 7.2153e+31]],


         [[ 6.0041e+31],

          [ 4.3638e+24],

          [ 2.2700e+31],

          [ 1.0899e+27]],


         [[ 5.7886e+22],

          [ 6.7120e+22],

          [ 1.1632e+33],

          [ 5.6003e-02]]],



        [[[ 7.0374e+22],

          [ 6.9983e+28],

          [ 1.9859e+29],

          [ 4.3218e+27]],


         [[ 4.7423e+30],

          [ 2.2856e+20],

          [ 3.2607e-12],

          [ 7.4086e+28]],


         [[ 7.1463e+22],

          [ 4.6241e+30],

          [ 1.0552e+24],

          [ 5.5757e-02]]],



        [[[ 1.8728e+31],

          [ 8.2661e-10],

          [ 2.9551e+21],

          [ 1.7036e+19]],


         [[ 4.3988e+21],

          [ 1.8524e+28],

          [ 3.7292e-08],

          [ 1.8728e-41]],


         [[ 1.2612e-44],

          [ 0.0000e+00],

          [-3.8139e-02],

          [-1.0845e-19]]]])


引用