简介

Go的并发原语使构建流数据pipeline变得容易,流数据pipeline可以有效地利用I/O和多个CPU。本文介绍了构建此类pipeline的示例,说明了故障发生时go协程永远阻塞的情况,并介绍了健壮地处理故障的方法。

何为pipeline

Pipeline即流水线、管道。一个pipeline是指一系列用通道(channels)连接的阶段(stages),每个阶段是一组运行同一个函数的go协程,在每一个阶段中的go协程:

  • 从上游管道接收数据
  • 在数据上执行一些操作,或生成新的值
  • 将数据发到下游管道中

例子:平方数

以下例子第一个阶段将给定的数发到管道中,第二个阶段从管道接收数字,取平方后发送到结果管道中。

// 将数字发到管道中
func gen(nums ...int) <-chan int {
    out := make(chan int)
    go func() {
        for _, n := range nums {
            out <- n
        }
        close(out)
    }()
    return out
}
// 从管道接收数字,取平方后发送到其下游管道中
func sq(in <-chan int) <-chan int {
    out := make(chan int)
    go func() {
        for n := range in {
            out <- n * n
        }
        close(out)
    }()
    return out
}

扇入扇出

多个函数能够从同一个通道读取数据直至通道关闭。这提供了一个将工作分配到一组worker的方法。
一个函数能够从多个管道读取数据,直到它们都关闭,并把读取的数据发送到一个输出管道,这称为扇入
我们可以修改上述pipeline,使得可以同时运行两个sq实例从管道读取值,这需要引入一个merge函数:

func main() {
    in := gen(2, 3)

    // Distribute the sq work across two goroutines that both read from in.
    c1 := sq(in)
    c2 := sq(in)

    // Consume the merged output from c1 and c2.
    for n := range merge(c1, c2) {
        fmt.Println(n) // 4 then 9, or 9 then 4
    }
}

merge函数读取多个管道的值并将它们汇聚到一个管道。其实现如下:

func merge(cs ...<-chan int) <-chan int {
    var wg sync.WaitGroup
    out := make(chan int)

    // 为每一个管道开启一个go协程,每个go协程分别从相应的管道读取值
    // 当负责的管道关闭后,go协程将退出
    output := func(c <-chan int) {
        for n := range c {
            out <- n
        }
        wg.Done()
    }
    wg.Add(len(cs))
    for _, c := range cs {
        go output(c)
    }

    // 开启一个额外的go协程负责将管道关闭
    go func() {
        wg.Wait()
        close(out)
    }()
    return out
}

存在的问题

上述实现中,如果下游的go协程从管道中接收到一个无法处理的值,并且导致它们退出,那么上游的go协程将因为往一个永远没有go协程读取的管道中发送值而永远阻塞。
解决方法之一是将无缓冲管道改为缓冲管道,并且缓冲大小至少为可能往管道发送的数据量。但是这要求事先知道可能发送的最大数据量。以下介绍一种无需提前知道可能发送数据量的解决办法。

显式取消

设置一个done通道,在创建阶段时显示地将done通道传入,利用select语句实现对任务的取消。

main函数:

func main() {
    // 设置done通道,当done通道关闭时所有工作取消并应该退出
    done := make(chan struct{})
    defer close(done)          

    in := gen(done, 2, 3)

    // sq将开启新的go协程从同管道in读取值并计算平方值,结果将发到管道c1, c2中
    c1 := sq(done, in)
    c2 := sq(done, in)

    // merge将几个管道的输出汇聚到一起
    out := merge(done, c1, c2)
    fmt.Println(<-out) // 4 or 9

    // 由defer语句,运行结束后done将关闭 
}

merge函数:

func merge(done <-chan struct{}, cs ...<-chan int) <-chan int {
    var wg sync.WaitGroup
    out := make(chan int)

    // merge为cs的每个管道开启一个go协程,并将读取的数据汇聚到out管道中。
    // 同时利用select语句接收任务结束的广播消息,进而不再继续监听管道。
    output := func(c <-chan int) {
        defer wg.Done()
        for n := range c {
            select {
            case out <- n:
            case <-done:	// done通道关闭后merge不再从上游管道接收值
                return
            }
        }
    }
    // ... the rest is unchanged ...

sq函数:

// sq函数同样利用select接收任务结束的通知
func sq(done <-chan struct{}, in <-chan int) <-chan int {
    out := make(chan int)
    go func() {
        defer close(out)
        for n := range in {
            select {
            case out <- n * n:
            case <-done:	// done通道关闭后将不再往out通道发送值
                return
            }
        }
    }()
    return out
}

故障分析:如果下游的sq函数遇到一个无法处理的值而退出,它将关闭其输出通道(defer close(out))。sq函数的输出通道作为merge的输入通道,当输入通道关闭时,merge创建的负责监听的go协程将退出。当所有监听的go协程都退出后,由上文所述,WaitGroupwait方法将返回,这个go协程将关闭out通道,这进一步导致了main函数将done通道关闭,进而所有go协程都将退出。这就避免了上游的go协程永远阻塞的情况。

实例

用MD5加密算法对一个目录下的所有文件进行digest。一个比较简单的方法是串行地对每个文件进行digest。

type result struct {
    path string
    sum  [md5.Size]byte
    err  error
}

sumFiles返回两个通道:一个用于接收结果,一个用于接收filepath.Walk的错误。walk函数开启一个新的go协程来处理每个文件。在Walk方法中,select语句监听done通道,当done通道关闭时,Walk方法将退出,不再继续遍历剩下的文件。外部的匿名函数则在done通道关闭后往error通道发送一个错误。sumFiles最后开启了一个额外的go协程,当WaitGroupwait方法返回后,它将关闭通道c

func sumFiles(done <-chan struct{}, root string) (<-chan result, <-chan error) {
    // 为目录下的每个文件开启一个go协程计算digest
    c := make(chan result)
    // errc负责发送错误给调用者
    errc := make(chan error, 1)
    go func() {
        var wg sync.WaitGroup
        err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
            if err != nil {
                return err
            }
            if !info.Mode().IsRegular() {
                return nil
            }
            wg.Add(1)
            go func() {
                data, err := ioutil.ReadFile(path)
                select {
                case c <- result{path, md5.Sum(data), err}:
                case <-done:
                }
                wg.Done()
            }()
            // Abort the walk if done is closed.
            select {
            case <-done:
                return errors.New("walk canceled")
            default:
                return nil
            }
        })
        // 额外的go协程负责关闭输出通道
        go func() {
            wg.Wait()
            close(c)
        }()
        // 无需select语句,因为errc通道有缓冲
        errc <- err
    }()
    return c, errc
}
func MD5All(root string) (map[string][md5.Size]byte, error) {
    // MD5All在返回前将关闭done通道,进而导致所有工作goroutine退出
    done := make(chan struct{})
    defer close(done)          

    c, errc := sumFiles(done, root)

    m := make(map[string][md5.Size]byte)
    for r := range c {
        if r.err != nil {
            return nil, r.err
        }
        m[r.path] = r.sum
    }
    if err := <-errc; err != nil {
        return nil, err
    }
    return m, nil
}

限制开启的go协程数

上述方法为目录中的每一个文件创建一个go协程,如果一个目录有许多文件,这将开启大量的go协程。我们可以创建一个固定数量的go协程。为此,我们不再在Walk方法直接创建go协程,我们让Walk方法将一个目录下的所有文件的路径输出到paths通道中,再开启固定数量的digesterpaths通道中读取文件路径,并对该路径下的文件进行digest。现在该pipeline有三个阶段:1. 遍历目录;2. digester读取并digest;3. 收集计算的digest。

func walkFiles(done <-chan struct{}, root string) (<-chan string, <-chan error) {
    paths := make(chan string)
    errc := make(chan error, 1)
    go func() {
        // Walk返回时将paths通道关闭
        defer close(paths)
        // No select needed for this send, since errc is buffered.
        errc <- filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
            if err != nil {
                return err
            }
            if !info.Mode().IsRegular() {
                return nil
            }
            select {
            // 将路径发到paths通道中
            case paths <- path:
            case <-done:
                return errors.New("walk canceled")
            }
            return nil
        })
    }()
    return paths, errc
}

中间阶段由digester读取路径,对路径下的文件进行digest:

func digester(done <-chan struct{}, paths <-chan string, c chan<- result) {
    for path := range paths {
        data, err := ioutil.ReadFile(path)
        select {
        case c <- result{path, md5.Sum(data), err}:
        case <-done:
            return
        }
    }
}

MD5All函数内开启固定数量的digester

// 开启固定数量的digester
    c := make(chan result)
    var wg sync.WaitGroup
    const numDigesters = 20
    wg.Add(numDigesters)
    for i := 0; i < numDigesters; i++ {
        go func() {
            digester(done, paths, c)
            wg.Done()
        }()
    }
    go func() {
        wg.Wait()
        close(c)
    }()

最后一个阶段从接收结果的通道中读取,退出前检查是否有错误发生:

m := make(map[string][md5.Size]byte)
    for r := range c {
        if r.err != nil {
            return nil, r.err
        }
        m[r.path] = r.sum
    }
    // Check whether the Walk failed.
    if err := <-errc; err != nil {
        return nil, err
    }
    return m, nil
}

总结

本文介绍了go的流式数据pipeline,以及如何构建一个正确的流水线。

参考资料

Go Concurrency Patterns: Pipelines and cancellation

Further reading:

Go Concurrency Patterns presents the basics of Go’s concurrency primitives and several ways to apply them.
Advanced Go Concurrency Patterns (video) covers more complex uses of Go’s primitives, especially select.
Douglas McIlroy’s paper Squinting at Power Series shows how Go-like concurrency provides elegant support for complex calculations.