楔子
WaitGroup 是用来做任务编排的一个并发原语,它解决的是 "并发 - 等待" 的问题。比如我们要完成一个大的任务,需要使用并行的 goroutine 执行三个小任务,只有这三个小任务都完成,我们才能去执行后面的任务。
这个时候使用使用 WaitGroup 就非常合适了,而且名字也很形象:等待组,它可以对一组 goroutine 进行编排,保证这一组 goroutine 都执行完毕之后程序再往下执行。那么 WaitGroup 是怎么做到的呢?以及用法如何呢?我们下面来看一看。
WaitGroup 的基本用法
首先 WaitGroup 内部有一个计数器,然后围绕着计数器提供了三个方法:
func (wg *WaitGroup) Add(delta int)
func (wg *WaitGroup) Done()
func (wg *WaitGroup) Wait()
我们分别看下这三个方法:
Add:给 WaitGroup 内部计数器的值增加 delta
Done:给 WaitGroup 内部计数器的值减一
Wait:如果计数器的值不为 0,那么此方法会阻塞,如果为 0,程序往下执行
想必你此时已经猜到 WaitGroup 该怎么用了,进入协程时 Add(1),执行结束时 Done() 一下即可。
package main
import (
"sync"
"time"
)
type Counter struct {
sync.Mutex
count int64
}
func (c *Counter) Incr() {
c.Lock()
defer c.Unlock()
c.count++
}
func (c *Counter) Count() int64 {
c.Lock()
defer c.Unlock()
return c.count
}
func main() {
var wg sync.WaitGroup
var count Counter
for i := 0; i < 10; i++ {
// 开启 10 个协程,但是在开始的位置 Add(1) 会使得计数器的值加 1
// 结束的位置 Done() 一下会使得计数器的值减 1
go func() {
wg.Add(1)
count.Incr()
time.Sleep(time.Second)
wg.Done()
}()
}
// 如果计数器的值不为 0,那么会一直阻塞
// 因此我们一定要保证协程执行完毕之后计数器的值为 0,否则程序就会卡死在这里了
wg.Wait()
// 如果我们能够知道开启的协程的数量,那么也可以不把 Add(1) 写在协程中
// 比如我们上面开启了 10 个协程,那么可以在进入 for 循环之前写上 wg.Add(10),此时计数器的值为 10
// 然后在协程内部只调用 wg.Done() 即可,当所有协程都执行完之后,技术器的值会减去 10,最终变成 0
}
因此使用起来还是比较简单的,核心是一定要确保任务执行完之后计数器的值为 0,否则程序就会卡死。很常见的错误就是 Add 之后忘记 Done,尤其是逻辑比较长的时候,写到最后很容易把 wg.Done() 给忘记了,因此建议通过 defer 来保证。
以上就是我们使用 WaitGroup 编排这类任务的常用方式,而 "这类任务" 指的就是需要启动多个 goroutine 执行子任务,并且主 goroutine 需要等待指定的子 goroutine 都完成后才继续执行。
了解完基本用法之后,下面我们来看看底层实现,因为单从语法的使用层面上讲,确实没有什么难度,毕竟 Go 本身用起来就很简单。所以我们还需要剖析底层实现,当然这就不一定简单了。
WaitGroup 的实现
首先来看一下 WaitGroup 的数据结构,结构体只包含了两个成员:
noCopy:辅助字段,一个空结构体(struct {})的别名,主要就是辅助 vet 工具检测该 WaitGroup 实例是否发生了值拷贝,和 Mutex、RWMutex 一样,WaitGroup 在传递时也必须传递指针
state1:一个具有符合意义的字段,包含 WaitGroup 的计数器的值、调用 Wait 方法阻塞时的 waiter 数和信号量
type WaitGroup struct {
noCopy noCopy
state1 [3]uint32
}
然后我们重点说一下这个 state1 成员,它包含了 WaitGroup 的计数器的值、调用 Wait 方法阻塞时的 waiter 数和信号量。但是有一点需要注意,对于不同的处理该字段的值也会有区别。
如果是 64 位机器,那么 state1[0] 表示调用 Wait 方法阻塞时的 waiter 数,可以理解为有几个 goroutine 调用 wg.Wait() 阻塞了,那么 waiter 数就是几;state1[1] 表示 WaitGroup 计数器的值;state1[2] 表示信号量
如果是 32 位机器,那么 state1[0] 表示信号量;state1[1] 表示调用 Wait 方法阻塞时的 waiter 数;state1[2] 表示 WaitGroup 计数器的值
WaitGroup 有一个方法 state,专门用来获取上面的信息,我们来看一下。
func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
// 如果地址(转成十进制整数)模上 8 等于 0,那么说明地址是 8 字节对齐,也就是 64 位机器
// 否则是 32 位机器,因为 Go 只能运行在 64 和 32 位机器上
if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
/* 但是两个 return 可能会有点难理解,首先它们都返回了两个元素,并且第二个值是信号量(指针)
* 但第一个值是什么?我们注意到 state1 数组的元素是 uint32 类型,而这里转成了 *uint64
* 很明显,"调用 Wait 方法阻塞的 waiter 数" 和 "WaitGroup 计数器的值" 这两个 uint32 整数
* 被组合成了一个 uint64 整数,然后返回其指针
*/
return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
} else {
return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
}
}
关于上面的 return,我们实际演示一下,这里只以 64 位机器为例。首先我们来假设一下,如果 state1 数组为 [1, 2, 3]
,那么你认为 state 方法会返回什么?不用想,肯定是 return 两个值,第二个值是 3,但第一个值是什么呢。
uint32 整数 1 转成二进制:00000000 00000000 00000000 00000001
uint32 整数 2 转成二进制:00000000 00000000 00000000 00000010
拼接成 uint64:00000000 00000000 00000000 00000010 00000000 00000000 00000000 00000001
注意:state1[1] 占高位、state1[0] 占低位,因为数组从左往右地址是增大的,而 Go 里面高位存储在高地址中。
package main
import "fmt"
func main() {
fmt.Println(0b00000000_00000000_00000000_00000010_00000000_00000000_00000000_00000001)
/* 8589934593 */
}
我们看到结果是 8589934593,那么实际情况是不是这样呢?
package main
import (
"fmt"
"unsafe"
)
func main() {
var state1 = [3]uint32{1, 2, 3}
fmt.Println(*(*uint64)(unsafe.Pointer(&state1)))
/* 8589934593 */
}
我们看到结果是一样的,至于根据结果进行逆运算也是非常容易的。
state1[1] 放在高位,那么计算的时候直接右移 32 个位即可
state1[0] 放在低位,那么计算的时候只需要和后 16 个位为 1(其它位为 0)的整数进行与运算即可
package main
import (
"fmt"
)
func main() {
var num = 8589934593
fmt.Printf("state[1]:%d,state[0]:%d", num>>32, num&0xFFFF)
/* state[1]:2,state[0]:1 */
}
以上我们就解释了 state 方法里面的 return 到底是怎么回事,当然有点跑题了,不过也还好。
回归正题,我们来看一下 Add、Done、Wait 方法的实现,删除掉了一些用于 race 检查和异常检查的部分,这里我们重点关注 Add、Done、Wait 这些方法本身的实现。
Add 方法
Add 方法,我们知道主要负责操作计数器的值,也就是 state1[1]。我们可以给计数器增加一个 delta,内部会通过原子操作将这个值加上去。此外需要注意的是,delta 也可以是一个负数,Done 方法内部就是通过 Add(-1) 实现的。
func (wg *WaitGroup) Add(delta int) {
statep, semap := wg.state()
// 给计数器增加 delta,所以 delta 左移 32 位加上去即可
state := atomic.AddUint64(statep, uint64(delta)<<32)
// 获取计数器(state[1])的值,我们说它位于高位,所以右移 32 位即可
v := int32(state >> 32)
// 获取调用 Wait 方法阻塞的 waiter 数(state[0]),我们之前是和 0xFFFF 进行与运算
// 但是显然 Go 内部的做法更简单,直接转成 uint32 即可,超过 32 位的部分会截断
w := uint32(state)
// 计数器的值一定大于等于 0
// 如果计数器大于 0,或者没有阻塞的 goroutine,那么直接返回
if v > 0 || w == 0 {
return
}
// 如果计数器的值 v 是 0 并且 waiter 的数量 w 不是 0,那么 state 的值就是 waiter 的数量
// 此时应该将 waiter 的数量设置为 0,因为计数器的值为 0 了,所以应该唤醒所有的 waiter
// 这个 waiter 就是调用 wg.Wait() 阻塞的 goroutine,由于 v 是 0,w 也要设置为 0,所以将 *statep 直接设置为 0 即可
*statep = 0
for ; w != 0; w-- {
runtime_Semrelease(semap, false, 0)
}
}
整个逻辑还是很简单的,就是将计数器的值增加 delta。然后判断计数器的值 v 是否等于 0,如果 v 是 0、并且 w 不是 0,那么唤醒所有的 waiter 即可。
Done 方法
Done 的逻辑很简单,我们说它内部是通过 Add 实现的,可以看一下代码。
// Done 方法实际就是计数器减 1
func (wg *WaitGroup) Done() {
wg.Add(-1)
}
Wait 方法
Wait 方法的逻辑是:不断检查 state 的值,如果其中的计数器的值变为了 0,那么说明所有的任务都已完成,调用者不必再等待,直接返回。如果计数值大于 0,说明此时还有任务没完成,那么调用者就变成了等待者,需要加入 waiter 队列,并且阻塞住自己。
func (wg *WaitGroup) Wait() {
statep, semap := wg.state()
for {
state := atomic.LoadUint64(statep)
// 当前计数器的值
v := int32(state >> 32)
// waiter 的数量,就是调用 wg.Done() 而阻塞的 goroutine 的数量
w := uint32(state)
if v == 0 {
// 如果计数值为 0,调用这个方法的 goroutine 不必再等待,直接返回即可
return
}
// 否则把 waiter 数量加 1
// 期间可能有并发调用 Wait 的情况,所以最外层使用了一个 for 循环
if atomic.CompareAndSwapUint64(statep, state, state+1) {
// 阻塞休眠等待,而唤醒则通过 wg.Add,当计数器更新之后值变为 0,那么会唤醒 waiter
runtime_Semacquire(semap)
return
}
}
}
以上就是这三个方法的底层实现,其实也挺简单的。
使用 WaitGroup 时的常见错误
使用 WaitGroup 也是会有很多坑的,下面来看一下。
计数器的值为负数
WaitGroup 的计数器的值必须大于等于 0,我们在更改这个计数值的时候,WaitGroup 会检查计数器的值,如果小于 0,则引发 panic。
一般情况下,有两种情况会导致计数器设置为负数。
第一种情况:调用 Add 的时候传递一个负数,当然,如果你能保证当前的计数器的值加上这个负数后还是大于等于 0 的话,也没有问题,否则就会导致 panic;
func main() {
var wg sync.WaitGroup
wg.Add(10)
wg.Add(-10) // 将 -10 作为参数调用 Add,计数器的值被设置为0
wg.Add(-1) // 将 -1 作为参数调用 Add,此时会变成 -1,会引发 panic
}
第二种情况:调用 Done 方法的次数过多,超过了 WaitGroup 的计数值;
func main() {
var wg sync.WaitGroup
wg.Add(1)
wg.Done()
wg.Done() // panic
}
使用 WaitGroup 的正确姿势是,预先确定好 WaitGroup 的计数值,然后调用相同次数的 Done 完成相应的任务。比如在 WaitGroup 变量声明之后,就立即设置它的计数,但这要求你必须事先知道组里面到底有多个 goroutine。或者在 goroutine 启动之前增加 1,然后在 goroutine 退出之间调用 Done。
如果你没有遵循这些规则,就很可能会导致 Done 方法调用的次数和计数值不一致,进而造成死锁(Done 调用次数比计数值少)或者 panic(Done 调用次数比计数值多)。
Add 之前先 Wait
在使用 WaitGroup 的时候,要遵循一个原则:等所有的 Add 方法调用之后再调用 Wait,否则就可能导致 panic 或者不期望的结果。
package main
import (
"fmt"
"sync"
"time"
)
func doSomething(wg *sync.WaitGroup) {
time.Sleep(1000) // 故意 sleep 一下
wg.Add(1)
fmt.Println("do something")
wg.Done()
}
func main() {
var wg sync.WaitGroup
go doSomething(&wg)
go doSomething(&wg)
go doSomething(&wg)
wg.Wait()
}
程序执行之后会发现什么也没有打印,原因就在于子协程执行 wg.Add(1) 的时候,主协程就已经执行了 wg.Wait()。而计数器初始是为 0 的,所以主协程不会阻塞,然后程序退出了。因此我们要确保子协程中的 Add 一定要在 Wait 之前执行。
Add 和 Wait 同时调用
首先 WaitGroup 是可重用的,只要 WaitGroup 内部计数器的值恢复到零值的状态,那么它就可以被看作是新创建的 WaitGroup。之后我们可以继续把该 WaitGroup 当成是新创建的来用,然后继续调用 wg.Add,但是这一步一定要在上一轮 wg.Wait 结束之后进行。光说可能有点难理解,我们举个栗子:
package main
import (
"sync"
"time"
)
func main() {
var wg sync.WaitGroup
go func() {
time.Sleep(time.Millisecond)
wg.Add(1)
wg.Done() // 计数器减 1
wg.Add(1) // 计数值加 1
}()
wg.Wait() // 主 goroutine 等待,有可能和第二个 wg.Add(1) 并发执行
}
当调用完 wg.Done() 之后,阻塞在 wg.Wait() 处的协程会被唤醒,然后这个 wg 就可以看成是新创建的 wg,因为此时内部的成员的值为零值,和新创建一个 wg 没有区别。但重点是:执行 wg.Wait() 的同时(唤醒阻塞 goroutine),子协程内部也会执行 wg.Add(1),如果这两者并发执行,那么就会 panic。
不过个人测试了几次均没有出现 panic,因为 wg.Wait() 总是在子协程的第二个 wg.Add(1) 执行之前先执行完,但如果真的不幸、这两者同时执行了,那么就会造成 panic。所以如果想要重用 WaitGroup,那么一定要等到上一轮的 wg.Wait() 执行完毕之后,再执行 wg.Add。
Docker 源码里面就犯过两次这种错误.
小结
WaitGroup 的使用场景还是很明确的,就是编排一组 goroutine。尽管使用 WaitGroup 也容易踩坑,但只要记住以下五点,便可以避免。
不重用 WaitGroup,新建一个 WaitGroup 不会带来多大的资源开销,重用反而更容易出错
保证所有的 Add 方法调用都在 Wait 之前
不传递负数给 Add 方法,只通过 Done 来给计数值减 1
不做多余的 Done 方法调用,保证 Add 的计数值和 Done 方法调用的数量是一样的
不遗漏 Done 方法的调用,否则会导致 Wait 阻塞而无法返回
如果觉得文章对您有所帮助,可以请囊中羞涩的作者喝杯柠檬水,万分感谢,愿每一个来到这里的人都生活愉快,幸福美满。