以下内容翻译自:Optimize Deep Learning GPU Operators with TVM: A Depthwise Convolution Example

高效的深度学习算子是深度学习系统的核心。通常这些算子很难优化,并且需要高性能计算专家的努力。TVM,端到端张量IR/DSL堆栈,使得这项任务更容易。

这个博客教你如何在TVM的帮助下编写高性能GPU运算核心。我们使用深度卷积(即topi.nn.depthwise_conv2d_nchw)作为示例,并演示如何在tensorflow中优化手动调优过的CUDA内核。在不同的工作负载下,我们的最终版本比tf-1.2中的优化内核快2到4倍,启用算子融合时速度快了3x-7倍。以下是在GTX1080上,filter size= [1,256,3,3],stride = [1,1],padding ='SAME’的测试结果:

让GPU一次处理多个对象_DeepLearning

Depthwise Convolution介绍

深度卷积是现代架构的重要组成部分,如XceptionMobileNet。这是一种降低深度神经网络计算复杂度的有效方法。

让GPU一次处理多个对象_2d_02

source: http://machinethink.net/blog/googles-mobile-net-architecture-on-iphone/

在TVM中,深度卷积可以被声明为:

# padding stage
PaddedInput = tvm.compute(
    (batch, in_channel, height_after_pad, width_after_pad),
    lambda b, c, i, j: tvm.select(
        tvm.all(i >= pad_top, i - pad_top < in_height, j >= pad_left, j - pad_left < in_width),
        Input[b, c, i - pad_top, j - pad_left], tvm.const(0.0)),
    name="PaddedInput")
# depthconv stage
di = tvm.reduce_axis((0, filter_height), name='di')
dj = tvm.reduce_axis((0, filter_width), name='dj')
Output = tvm.compute(
    (batch, out_channel, out_height, out_width),
    lambda b, c, i, j: tvm.sum(
        PaddedInput[b, c/channel_multiplier, i*stride_h + di, j*stride_w + dj] * Filter[c/channel_multiplier, c%channel_multiplier, di, dj],
        axis=[di, dj]),
    name='DepthwiseConv2d')

通用GPU优化指南

本部分简要介绍了优化CUDA代码时应该了解的三个概念:数据重用,共享内存和存储体冲突。如果你已了解它们,很好,那么你可以跳过这部分。

数据重用

在现代计算体系结构中,从存储器加载数据的成本远高于进行单个浮点计算。因此,我们总是希望在输入数据加载到寄存器或共享内存(缓存)后重新使用输入数据。

深度卷积有两种形式的数据重用:

  • 滤波器重用
  • 输入重用

滤波器重用发生在滤波器滑动窗口并进行多次计算时;输入重用是通过平铺来实现的,我们以3x3深度转换为例:

让GPU一次处理多个对象_GPU_03

如果没有平铺,每个线程加载3x3输入数据并计算1个输出元素。16个线程一起有9x16负载。

让GPU一次处理多个对象_GPU_04

通过平铺,每个线程加载4x4输入数据并计算2x2输出元素。4个线程一起有16x4负载。

共享内存和Bank Conflicts

共享内存可以被看作是GPU中的缓存。它是片上的,比全局存储器要快得多。

让GPU一次处理多个对象_2d_05

共享内存按块分配。通常的做法是将全局内存中的数据加载到共享内存中,然后块中的所有线程都从共享内存中读取数据。

共享内存的大小是有限的(通常是48K),所以我们必须注意共享内存溢出。此外,分配给一个块的共享内存太多会限制每个多处理器的活动块数量。

共享内存的另一个性能问题是Bank Conflicts。共享内存被分成可以同时访问的大小相同的内存模块(bank),但是,如果多个线程访问相同的存储体(导致bank冲突),访问将被串行化,从而降低有效带宽。

共享存储体的组织方式使得连续的地址被分配给连续的存储体。为了避免存储体冲突,最好连续的线程访问连续的内存地址,如下所示(每种颜色代表一个共享内存组):

让GPU一次处理多个对象_让GPU一次处理多个对象_06

有关共享内存和存储体冲突的更多详细信息,请参阅Nvidia的博客

好吧,现在让我们开始优化TVM中的深度卷积。

Schedule优化

内联计算PaddedInput以节省内存分配

正如我们从第1部分看到的那样,填充被明确地声明为一个单独的阶段。我们在线计算它以避免冗余内存分配:

s = tvm.create_schedule(Output.op)
s[PaddedInput].compute_inline()

将一个大通道分成较小的块

深度卷积的一个简单的调度是一个cuda块负责一个输入通道和相应的滤波器,将它们加载到共享内存中,然后计算:

IS = s.cache_read(PaddedInput, "shared", [DepthwiseConv2d])
FS = s.cache_read(Filter, "shared", [DepthwiseConv2d])
block_y = tvm.thread_axis("blockIdx.y")
block_x = tvm.thread_axis("blockIdx.x")
# bind the dimension of batch (N in NCHW) with block_y
s[Output].bind(Output.op.axis[0], block_y)
# bind the dimension of channel (C in NCHW) with block_x
s[Output].bind(Output.op.axis[1], block_x)

我们在GTX 1080上测试1000次运行的平均时间成本,并与tensorflow中的depthwise conv2d进行比较。结果如下:

Input

Filter

stride

tf-1.2 SAME pad (us)

TVM SAME pad (us)

[1, 256, 21, 21]

[256, 1, 3, 3]

[1, 1]

16.1

9.1

[1, 256, 32, 32]

[256, 1, 3, 3]

[1, 1]

34.8

14.5

[1, 256, 64, 64]

[256, 1, 3, 3]

[1, 1]

130.9

98.9

[1, 256, 96, 96]

[256, 1, 3, 3]

[1, 1]

251.6

387.4

正如我们所看到的,这个调度表在21x21或32x32这样的小特征图下表现良好,然而,随着特征图增加到大于64x64,其性能严重下降。一个主要原因是分配的共享内存过多 一个块限制每个多处理器的活动块数量。

我们修改调度表将一个大通道分成更小的块。例如,一个通道(64x64或96x96)被分成32x32的块,一个cuda块处理一个32x32的块:

blocking_h = 32
blocking_w = 32
# split the dimension of height (H in NCHW)
bx1, _ = s[Output].split(Output.op.axis[2], factor=blocking_h)
# split the dimension of width (W in NCHW)
bx2, _ = s[Output].split(Output.op.axis[3], factor=blocking_w)
# assign one 32 x 32 block to one cuda block
by = s[Output].fuse(Output.op.axis[0], Output.op.axis[1])
s[Output].bind(by, block_y)
bx = s[Output].fuse(bx1, bx2)
s[Output].bind(bx, block_x)

这是新的结果:

Input

[blocking_h, blocking_w]

tf-1.2 SAME pad (us)

TVM SAME pad (us)

[1, 256, 64, 64]

[32, 32]

130.9

63.4

[1, 256, 96, 96]

[32, 32]

251.6

132.5

我们的分块策略有效!对于64x64尺寸通道,它带来1.6倍的加速(98.9us->63.4us); 对于96x96尺寸通道,它带来了2.9倍的加速(387.4us->132.5us)。

调整线程号参数

如何在一个cuda块中安排32x32线程的工作负载?直观地说,它应该是这样的:

num_thread_y = 8
num_thread_x = 8
thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
ty, yi = s[Output].split(h_dim, nparts=num_thread_y)
tx, xi = s[Output].split(w_dim, nparts=num_thread_x)
s[Output].reorder(ty, tx, yi, xi)
s[Output].bind(ty, thread_y)
s[Output].bind(tx, thread_x)

调度中有两个参数:num_thread_ynum_thread_x。如何确定它们的最佳组合? 那么,我们先做一些实验。以下是Filter = [256,1,3,3]和stride = [1,1]的结果:

Case

Input

num_thread_y

num_thread_x

TVM SAME pad (us)

1

[1, 256, 32, 32]

8

32

9.7

2

[1, 256, 32, 32]

4

32

8.8

3

[1, 256, 32, 32]

1

32

17.7

4

[1, 256, 32, 32]

32

1

32.5

上面一些有趣的观察结果:

  • 情况2比情况1快。在情况2中,每个线程计算输出中的8×1分片,其对应于输入中的10×3分片。它比情况1的4x1分片具有更好的数据重用性。
  • 情况3比情况2慢。这是因为在情况3中,每个线程的工作量太大并且导致本地存储器读取的很多成本。
  • 情况4比情况3慢。这是因为num_thread_x=32确保没有存储体冲突,而num_thread_y=32不能。

总结我们从以上观察得出的结论:

  • 大块分片有利于数据重用,但对本地内存读取不利。
  • num_thread_ynum_thread_x对存储体冲突的影响是不同的。
  • 要找到num_thread_ynum_thread_x的最佳组合,可以实现有效的共享内存访问(避免存储库冲突),数据重用和本地内存读取之间的平衡。

非常棘手。那么,我们应该做些什么才能找到最佳组合?答案是蛮力搜索。我们可以将num_thread_ynum_thread_x作为参数传递给schedule函数,并尝试所有可能的组合以找到最优的一个。这可以在TVM中轻松完成:

def schedule_depthwise_conv2d(..., num_thread_y=8, num_thread_x=8):
    num_thread_y = num_thread_y
    num_thread_x = num_thread_x
    do_schedule_as_usual
    return schedule

min_time_cost = inf
for num_thread_y, num_thread_x in all_possible_combinations:
    schedule = schedule_depthwise_conv2d(..., num_thread_y=num_thread_y, num_thread_x=num_thread_x)
    time_cost = test_depthwise_conv2d(..., schedule)
    if time_cost < min_time_cost:
        min_time_cost = time_cost
        optimal_combination = [num_thread_y, num_thread_x]

实际上,它可以被看作是一个简单的自动调度程序。

Vthread和Stripped模式

引入TVM中的Vthread(虚拟线程)以支持分步模式。我们可以这样使用它:

num_vthread_y = 2
num_vthread_x = 2
num_thread_y = 8
num_thread_x = 8
thread_vy = tvm.thread_axis((0, num_vthread_y), "vthread", name="vy")
thread_vx = tvm.thread_axis((0, num_vthread_x), "vthread", name="vx")
thread_y = tvm.thread_axis((0, num_thread_y), "threadIdx.y")
thread_x = tvm.thread_axis((0, num_thread_x), "threadIdx.x")
# split the dimension of height (H in NCHW) twice
tvy, vyi = s[Output].split(h_dim, nparts=num_vthread_y)
ty, yi = s[Output].split(vyi, nparts=num_thread_y)
# split the dimension of width (W in NCHW) twice
tvx, vxi = s[Output].split(w_dim, nparts=num_vthread_x)
tx, xi = s[Output].split(vxi, nparts=num_thread_x)
# bind thread and vthread respectively
s[Output].bind(tvy, thread_vy)
s[Output].bind(tvx, thread_vx)
s[Output].bind(ty, thread_y)
s[Output].bind(tx, thread_x)
s[Output].reorder(tvy, tvx, ty, tx, yi, xi)

让我们打印IR以查看vthread的作用:

/* Input = [1, 1, 32, 32], Filter = [1, 1, 3, 3], stride = [1, 1], padding = 'SAME' */
produce DepthwiseConv2d {
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 1
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1
  // attr [iter_var(threadIdx.y, Range(min=0, extent=8), threadIdx.y)] thread_extent = 8
  // attr [iter_var(threadIdx.x, Range(min=0, extent=8), threadIdx.x)] thread_extent = 8
  for (i.inner.inner.inner, 0, 2) {
    for (j.inner.inner.inner, 0, 2) {
      DepthwiseConv2d[((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner)] = 0.000000f
      DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 512)] = 0.000000f
      DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 16)] = 0.000000f
      DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 528)] = 0.000000f
      for (di, 0, 3) {
        for (dj, 0, 3) {
          DepthwiseConv2d[((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner)] = (DepthwiseConv2d[((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner)] + (tvm_if_then_else(((((((1 - di) - i.inner.inner.inner) <= (((blockIdx.x*16) + threadIdx.y)*2)) && ((((blockIdx.x*16) + threadIdx.y)*2) < ((33 - di) - i.inner.inner.inner))) && (((1 - dj) - j.inner.inner.inner) <= (threadIdx.x*2))) && ((threadIdx.x*2) < ((33 - dj) - j.inner.inner.inner))), Input[(((((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + (di*32)) + dj) + -33)], 0.000000f)*Filter[((di*3) + dj)]))
          DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 512)] = (DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 512)] + (tvm_if_then_else(((((((-15 - di) - i.inner.inner.inner) <= (((blockIdx.x*16) + threadIdx.y)*2)) && ((((blockIdx.x*16) + threadIdx.y)*2) < ((17 - di) - i.inner.inner.inner))) && (((1 - dj) - j.inner.inner.inner) <= (threadIdx.x*2))) && ((threadIdx.x*2) < ((33 - dj) - j.inner.inner.inner))), Input[(((((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + (di*32)) + dj) + 479)], 0.000000f)*Filter[((di*3) + dj)]))
          DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 16)] = (DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 16)] + (tvm_if_then_else(((((((1 - di) - i.inner.inner.inner) <= (((blockIdx.x*16) + threadIdx.y)*2)) && ((((blockIdx.x*16) + threadIdx.y)*2) < ((33 - di) - i.inner.inner.inner))) && (((-15 - dj) - j.inner.inner.inner) <= (threadIdx.x*2))) && ((threadIdx.x*2) < ((17 - dj) - j.inner.inner.inner))), Input[(((((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + (di*32)) + dj) + -17)], 0.000000f)*Filter[((di*3) + dj)]))
          DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 528)] = (DepthwiseConv2d[(((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + 528)] + (tvm_if_then_else(((((((-15 - di) - i.inner.inner.inner) <= (((blockIdx.x*16) + threadIdx.y)*2)) && ((((blockIdx.x*16) + threadIdx.y)*2) < ((17 - di) - i.inner.inner.inner))) && (((-15 - dj) - j.inner.inner.inner) <= (threadIdx.x*2))) && ((threadIdx.x*2) < ((17 - dj) - j.inner.inner.inner))), Input[(((((((((((blockIdx.y + blockIdx.x)*16) + threadIdx.y)*32) + threadIdx.x)*2) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + (di*32)) + dj) + 495)], 0.000000f)*Filter[((di*3) + dj)]))
        }
      }
    }
  }
}

没有vthread(只设置为1),IR是:

/* Input = [1, 1, 32, 32], Filter = [1, 1, 3, 3], stride = [1, 1], padding = 'SAME' */
produce DepthwiseConv2d {
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 1
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1
  // attr [iter_var(threadIdx.y, Range(min=0, extent=8), threadIdx.y)] thread_extent = 8
  // attr [iter_var(threadIdx.x, Range(min=0, extent=8), threadIdx.x)] thread_extent = 8
  for (i.inner.inner.inner, 0, 4) {
    for (j.inner.inner.inner, 0, 4) {
      DepthwiseConv2d[((((((((blockIdx.y + blockIdx.x)*8) + threadIdx.y)*32) + threadIdx.x)*4) + (i.inner.inner.inner*32)) + j.inner.inner.inner)] = 0.000000f
      for (di, 0, 3) {
        for (dj, 0, 3) {
          DepthwiseConv2d[((((((((blockIdx.y + blockIdx.x)*8) + threadIdx.y)*32) + threadIdx.x)*4) + (i.inner.inner.inner*32)) + j.inner.inner.inner)] = (DepthwiseConv2d[((((((((blockIdx.y + blockIdx.x)*8) + threadIdx.y)*32) + threadIdx.x)*4) + (i.inner.inner.inner*32)) + j.inner.inner.inner)] + (tvm_if_then_else(((((((1 - di) - i.inner.inner.inner) <= (((blockIdx.x*8) + threadIdx.y)*4)) && ((((blockIdx.x*8) + threadIdx.y)*4) < ((33 - di) - i.inner.inner.inner))) && (((1 - dj) - j.inner.inner.inner) <= (threadIdx.x*4))) && ((threadIdx.x*4) < ((33 - dj) - j.inner.inner.inner))), Input[(((((((((((blockIdx.y + blockIdx.x)*8) + threadIdx.y)*32) + threadIdx.x)*4) + (i.inner.inner.inner*32)) + j.inner.inner.inner) + (di*32)) + dj) + -33)], 0.000000f)*Filter[((di*3) + dj)]))
        }
      }
    }
  }
}

正如我们所看到的,当num_vthread_y = 2num_vthread_x = 2时,32 x 32通道被分成四个16 x 16的子通道。每个线程一次计算四个输出元素,一个子通道中有一个元素。

以下是Filter = [256,1,3,3],stride = [1,1],blocking_h = 32,blocking_w = 32的结果:

|Case | Input | num_thread_y, num_thread_x | num_vthread_y, num_vthread_x | TVM SAME pad (us)|

|—|---|—|---|

|1 | [1, 256, 96, 96] | 8, 8 | 1, 1 | 132.5|

|2 | [1, 256, 96, 96] | 8, 8 | 1, 4 | 103.1|

|3 | [1, 256, 96, 96] | 4, 32| 1, 1 | 95.9 |

|4 | [1, 256, 96, 96] | 8, 16| 1, 2 | 90.9 |

情况2比情况1更快。这是因为在情况2中num_thread_x = 8num_vthread_x = 4一起确保连续线程访问连续内存地址,从而避免存储库冲突(如下所示)(每种颜色表示一个线程的工作负载):

让GPU一次处理多个对象_DeepLearning_07

理论上,情况3和4应该是相样快,因为它们每个线程具有相同的工作量,并且都享有高效的共享内存访问。不知怎的,案例4就是更快一点。

还记得tensorflow的速度吗?是251.6us,现在TVM速度提高了2.8倍。387.4 -> 132.5 -> 95.9 -> 90.9,分块帮助最大; 调整线程号节约37us; vthread节约额外的5us。

事实上,在更大或更多通道的卷积上,TVM比tensorflow更快(因为更多的滤波器重用):

Input

Filter

stride

tf-1.2 SAME pad (us)

TVM SAME pad (us)

How faster is TVM

[1, 256, 96, 96]

[256, 1, 3, 3]

[1, 1]

251.6

90.9

2.8x

[1, 256, 96, 96]

[256, 1, 5, 5]

[1, 1]

597.6

128.9

4.6x

[1, 256, 96, 96]

[256, 2, 3, 3]

[1, 1]

659.9

143.7

4.6x

[1, 256, 96, 96]

[256, 2, 5, 5]

[1, 1]

1203.9 170.5

7.1x

算子融合

我们可以在深度学习中进行的一种典型优化是运算符融合,即在单个内核中将多个运算符一起计算,而不将中间结果保存回全局内存。TVM支持开箱即用。

考虑神经网络中的常见模式:depthwise_conv2d + scale_shift + relu。 我们可以通过稍微修改原始调度表将三个算子融合为一个:

DepthwiseConv2d = topi.nn.depthwise_conv2d(Input, Filter, stride, padding)
ScaleShift = topi.nn.scale_shift(DepthwiseConv2d, Scale, Shift)
Relu = topi.nn.relu(ScaleShift)

Output = Relu # is no longer DepthwiseConv2d
s[ScaleShift].compute_inline() # this line fuses ScaleShift, explicitly
s[DepthwiseConv2d].set_scope("local") # this line fuses DepthwiseConv2d, implicitly
schedule(Output) # schedule for Output the same way we schedule for DepthwiseConv2d as discussed above
s[DepthwiseConv2d].compute_at(s[Output], tx) # tx is the inner most axis, bound to threadIdx.x

它会产生像这样的IR:

/* Input = [1, 1, 32, 32], Filter = [1, 1, 3, 3], stride = [1, 1], padding = 'SAME' */
produce Relu {
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 1
  // attr [DepthwiseConv2d] storage_scope = "local"
  allocate DepthwiseConv2d[float32 * 1 * 1 * 4 * 4]
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 1
  // attr [iter_var(threadIdx.y, Range(min=0, extent=8), threadIdx.y)] thread_extent = 8
  // attr [iter_var(threadIdx.x, Range(min=0, extent=8), threadIdx.x)] thread_extent = 8
  produce DepthwiseConv2d {
    for (i, 0, 4) {
      for (j, 0, 4) {
        DepthwiseConv2d[((i*4) + j)] = 0.000000f
        for (di, 0, 3) {
          for (dj, 0, 3) {
            DepthwiseConv2d[((i*4) + j)] = (DepthwiseConv2d[((i*4) + j)] + (tvm_if_then_else(((((((1 - di) - i) <= (((blockIdx.x*8) + threadIdx.y)*4)) && ((((blockIdx.x*8) + threadIdx.y)*4) < ((33 - di) - i))) && (((1 - dj) - j) <= (threadIdx.x*4))) && ((threadIdx.x*4) < ((33 - dj) - j))), Input[(((((((((((blockIdx.y + blockIdx.x)*8) + threadIdx.y)*32) + threadIdx.x)*4) + (i*32)) + j) + (di*32)) + dj) + -33)], 0.000000f)*Filter[((di*3) + dj)]))
          }
        }
      }
    }
  }
  for (i2.inner.inner.inner, 0, 4) {
    for (i3.inner.inner.inner, 0, 4) {
      Relu[((((((((blockIdx.y + blockIdx.x)*8) + threadIdx.y)*32) + threadIdx.x)*4) + (i2.inner.inner.inner*32)) + i3.inner.inner.inner)] = max(((DepthwiseConv2d[((i2.inner.inner.inner*4) + i3.inner.inner.inner)]*Scale[0]) + Shift[0]), 0.000000f)
    }
  }
}

正如我们所看到的,每个线程在将depthwise_conv2d的结果写入全局内存之前计算scale_shift和relu。融合的运算符与单个depthwise_conv2d一样快。以下是Input = [1,256,96,96],Filter = [256,1,3,3],stride = [1,1],padding ='SAME’的结果:

  • tf-1.2 depthwise_conv2d:251.6 us
  • tf-1.2 depthwise_conv2d + scale_shift + relu(单独):419.9 us
  • TVM depthwise_conv2d:90.9 us
  • TVM depthwise_conv2d + scale_shift + relu(融合):91.5 us

算子融合的优势是显而易见的。

这不是终点,TVM可以以更智能的方式进行算子融合。你可以参考这个并阅读下面提供的源代码。

让我们看看代码

致谢

作者非常感谢陈天奇的有益建议和鼓舞人心的讨论。