示例:并发时钟服务器

本节介绍 net 包,它提供构建客户端和服务器程序的组件,这些程序通过 TCP、UDP 或者 UNIX 套接字进行通信。网络服务 net/http 包是在 net 包的基础上构建的。

时钟服务器

这个示例是一个时钟服务器,它以每秒一次的频率向客户端发送当前时间:

package main

import (
    "io"
    "log"
    "net"
    "time"
)

func main() {
    listener, err := net.Listen("tcp", "localhost:8000")
    if err != nil {
        log.Fatal(err)
    }
    for {
        conn, err := listener.Accept()
        if err != nil {
            log.Print(err) // 例如,连接终止
            continue
        }
        handleConn(conn) // 一次处理一个连接
    }
}

func handleConn(c net.Conn) {
    defer c.Close()
    for {
        _, err := io.WriteString(c, time.Now().Format("2006/01/02 15:04:05\r\n"))
        if err != nil {
            return // 例如,连接断开
        }
        time.Sleep(1 * time.Second)
    }
}

Listen 函数创建一个 net.Listener 对象,它在一个网络端口上监听进来的连接,这里是 TCP 端口 localhost:8000。监听器的 Accept 方法被阻塞,知道有连接请求进来,然后返回 net.Conn 对象来代表一个连接。
handleConn 函数处理一个完整的客户端连接。在循环里,它将 time.Now() 获取的当前时间发送给客户端。因为 net.Conn 满足 io.Writer 接口,所以可以直接向它进行写入。当写入失败时循环结束,很多时候是客户端断开连接,这是 handleConn 函数使用延迟(defer)的 Close 调用关闭自己这边的连接,然后继续等待下一个连接请求。
为了连接到服务器,还需要一个 socket 客户端,这里可以先使用系统的 telnet 来进行验证:

$ telnet localhost 8000

这里可以开两个 telnet 尝试进行连接,只有第一个可以连接上,而其他的连接会阻塞。当把第一个客户端的连接断开后,服务端会重新返回到 main 函数的 for 循环中等待新的连接。此时之前阻塞的一个连接就能连接进来,继续显示时间。服务端程序暂时先这样,先来实现一个 socket 客户端程序。

客户端 netcat

下面的客户端使用 net.Dial 实现了 Go 版本的 netcat 程序,用来连接 TCP服务器:

package main

import (
    "io"
    "log"
    "net"
    "os"
)

func main() {
    conn, err := net.Dial("tcp", "localhost:8000")
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()
    mustCopy(os.Stdout, conn)
}

func mustCopy(dst io.Writer, src io.Reader) {
    if _, err := io.Copy(dst, src); err != nil {
        log.Fatal(err)
    }
}

这个程序从网络连接中读取,然后写到标准输出,直到到达 EOF 或者出错。

支持并发的服务器

如果打开多个客户端,同时只有一个客户端能正常工作。第二个客户端必须等到第一个结束才能正常工作,这是因为服务器是顺序的,一次只能处理一个客户请求。让服务器支持并发只需要一个很小的改变:在调用 handleConn 的地方添加一个 go 关键字,使它在自己的 goroutine 内执行:

for {
    conn, err := listener.Accept()
    if err != nil {
        log.Print(err) // 例如,连接终止
        continue
    }
    go handleConn(conn) // 并发处理连接
}

现在的版本,多个客户端可以同时接入并正常工作了。

示例:并发回声服务器

上面的时钟服务器每个连接使用一个 goroutine。下面要实现的这个回声服务器,每个连接使用多个 goroutine 来处理。大多数的回声服务器仅仅将读到的内容写回去,所以可以使用下面简单的 handleConn 版本:

func handleConn(c net.Conn) {
    io.Copy(c, c)
    c.Close()
}

有趣的回声服务端

下面的这个版本可以重复3次,第一个全大写,第二次正常,第三次全消息:

// reverb1
package main

import (
    "bufio"
    "fmt"
    "io"
    "log"
    "net"
    "strings"
    "time"
)

func echo(c net.Conn, shout string, delay time.Duration) {
    fmt.Fprintln(c, "\t", strings.ToUpper(shout))
    time.Sleep(delay)
    fmt.Fprintln(c, "\t", shout)
    time.Sleep(delay)
    fmt.Fprintln(c, "\t", strings.ToLower(shout))
}

func handleConn(c net.Conn) {
    input := bufio.NewScanner(c)
    for input.Scan() {
        echo(c, input.Text(), 1*time.Second)
    }
    // 注意:忽略 input.Err() 中可能的错误
    c.Close()
}

func handleConn0(c net.Conn) {
    io.Copy(c, c)
    c.Close()
}

func main() {
    listener, err := net.Listen("tcp", "localhost:8000")
    if err != nil {
        log.Fatal(err)
    }
    for {
        conn, err := listener.Accept()
        if err != nil {
            log.Print(err) // 例如,连接终止
            continue
        }
        go handleConn(conn) // 并发处理连接
    }
}

在上一个示例中,已经知道需要使用 go 关键字调用 handleConn 函数。不过在这个例子中,重点不是处理多个客户端的连接,所以这里不是重点。

升级客户端

现在来升级一下客户端,使它可以在终端上向服务器输入,还可以将服务器的回复复制到输出,这里提供了另一个使用并发的机会:

package main

import (
    "io"
    "log"
    "net"
    "os"
)

func mustCopy(dst io.Writer, src io.Reader) {
    if _, err := io.Copy(dst, src); err != nil {
        log.Fatal(err)
    }
}

func main() {
    conn, err := net.Dial("tcp", ":8000")
    if err != nil {
        log.Fatal(err)
    }
    defer conn.Close()
    go mustCopy(os.Stdout, conn)
    mustCopy(conn, os.Stdin)
}

优化服务端

使用上面的服务端版本,如果有多个连续的输入,新输入的内容不会马上返回,而是要等待之前输入的内容全部返回后才会处理之后的内容。要想做的更好,需要更多的 goroutine。再一次,在调用 echo 时需要加入 go 关键字:

// reverb2
func handleConn(c net.Conn) {
    input := bufio.NewScanner(c)
    for input.Scan() {
        go echo(c, input.Text(), 1*time.Second)
    }
    // 注意:忽略 input.Err() 中可能的错误
    c.Close()
}

这个改进的版本,回声也是并发的,在时间上互相重合。

小结

这就是使服务器变成并发所要做的,不仅处理来自多个客户端的链接,还包括在一个连接处理中,使用多个 go 关键字。在这个例子里,单个客户端连接也可以同时发起多个请求。在最初的版本里,没有使用 go 调用 echo,所以处理单个客户端的请求不是并发的,只有前一个处理完才会继续处理下一个。之后改进的版本,使用 go 调用 echo,这里对每一个请求的处理都是并发的了。
然而,在添加这些 go 关键字的同时,必须要仔细考虑方法 net.Conn 的并发调用是不是安全的,对大多数类型来讲,这都是不安全的。

接收完回声再结束

之前的客户端在主 goroutine 中将输入复制到服务器中,这样的客户端在输入接收后立即退出,即使后台的 goroutine 还在继续。为了让程序等待后台的 goroutine 在完成后再退出,使用一个通道来同步两个 goroutine:

func main() {
    conn, err := net.Dial("tcp", ":8000")
    if err != nil {
        log.Fatal(err)
    }
    done := make(chan struct{})
    go func() {
        io.Copy(os.Stdout, conn) // 注意:忽略错误
        log.Println("done")
        done <- struct{}{} // 通知主 goroutine 的信号
    }()
    mustCopy(conn, os.Stdin)
    conn.Close()
    <-done // 等待后台 goroutine 完成
}

当用户关闭标准输入流(Windows系统使用Ctrl+Z)时,mustCopy 返回,主 goroutine 调用 conn.Close() 来关闭两端网络连接。关闭写半边的连接会导致服务器看到 EOF。关闭读半边的连接导致后台 goroutine 调用 io.Copy 返回 “read from closed connection” 错误,所以这个版本里去掉了打印错误日志。

客户端优化

上面这个版本使用起来的效果和之前的版本并没有太大的差别,几乎看不到差别。虽然多了等待连接关闭,但是依然不会等待接收完毕所有服务器的返回。不过这步解决了等待 goroutine 运行完毕后,主 goroutine 才会结束。使用下面的 TCP 链接,就可以实现接收完毕所有信息后,goroutine 才会结束。在 net 包中,conn 接口有一个具体的类型 *net.TCPConn,它代表一个 TCP 连接:

tcpAddr, err := net.ResolveTCPAddr("tcp", ":8000")
if err != nil {
    log.Fatal(err)
}
conn, err := net.DialTCP("tcp", nil, tcpAddr)
if err != nil {
    log.Fatal(err)
}

TCP 链接由两半边组成,可以通过 CloseRead 和 CloseWrite 方法分别关闭。修改主 goroutine,仅仅关闭连接的写半边,这样程序可以继续执行输出来自 reverb1 服务器的回声,即使标准输入已经关闭:

package main

import (
    "io"
    "log"
    "net"
    "os"
)

func main() {
    tcpAddr, err := net.ResolveTCPAddr("tcp", ":8000")
    if err != nil {
        log.Fatal(err)
    }
    conn, err := net.DialTCP("tcp", nil, tcpAddr)
    if err != nil {
        log.Fatal(err)
    }
    done := make(chan struct{})
    go func() {
        io.Copy(os.Stdout, conn) // 注意:忽略错误
        log.Println("done")
        done <- struct{}{} // 通知主 goroutine 的信号
    }()
    mustCopy(conn, os.Stdin)
    conn.CloseWrite()
    <-done // 等待后台 goroutine 完成
}

func mustCopy(dst io.Writer, src io.Reader) {
    if _, err := io.Copy(dst, src); err != nil {
        log.Fatal(err)
    }
}

现在只对第一个回声服务器版本 reverb1 有效,对于之后改进的可以并发处理同一个客户端多个请求的 reverb2 服务器,服务端还需要做一些修改。

服务端优化

在 reverb2 服务器的版本中,因为对于每一个连接,每一次回声的请求都会生成一个新的 goroutine 进行处理。为了知道什么时候最后一个 goroutine 结束(有时候不一定是最后启动的那个),需要在每一个 goroutine 启动千递增计数,在每一个 goroutine 结束时递减计数。这需要一个特殊设计的计数器,它可以被多个 goroutine 安全地操作,然后又一个方法一直等到他变为 0。这个计数器类型是 sync.WaitGroup。下面是完整的服务器代码:

package main

import (
    "bufio"
    "fmt"
    "log"
    "net"
    "strings"
    "sync"
    "time"
)

var wg sync.WaitGroup // 工作 goroutine 的个数

func echo(c net.Conn, shout string, delay time.Duration) {
    defer wg.Done()
    fmt.Fprintln(c, "\t", strings.ToUpper(shout))
    time.Sleep(delay)
    fmt.Fprintln(c, "\t", shout)
    time.Sleep(delay)
    fmt.Fprintln(c, "\t", strings.ToLower(shout))
}

func handleConn(c net.Conn) {
    input := bufio.NewScanner(c)
    for input.Scan() {
        wg.Add(1)
        go echo(c, input.Text(), 2*time.Second)
    }
    // 注意:忽略 input.Err() 中可能的错误
    wg.Wait()
    c.Close()
}

func main() {
    listener, err := net.Listen("tcp", "localhost:8000")
    if err != nil {
        log.Fatal(err)
    }
    for {
        conn, err := listener.Accept()
        if err != nil {
            log.Print(err) // 例如,连接终止
            continue
        }
        go handleConn(conn) // 并发处理连接
    }
}

注意 Add 和 Done 方法的不对称性。Add 递增计数器,它必须工作在 goroutine 开始之前执行,而不是在中间。另外,Add 有一个参数,但 Done 没有,它等价于 Add(-1)。使用 defer 来确保计数器在任何情况下都可以递减。在不知道迭代次数的情况下,上面的代码结构是通用的,符合习惯的并行循环模式。

超时断开

下面的版本增加了超时断开的功能。这样服务端和客户端就各有两个断开连接的情况了,原本只有一种。
服务端原本只要被动等待客户端断开就可以了,这个逻辑原本原本放在主 goroutine 中。现在服务端超时需要主动断开,客户端断开了,需要被动断开,这2个逻辑都需要一个单独的 goroutine,而主 goroutine 则阻塞接收这两个情况的通道,任意一个通道有数据,就断开并退出。
客户端原本只需要响应接收标准输入的 Ctrl+Z 然后断开写半边的连接,这个逻辑也需要从主 goroutine 放到一个新的 goroutine 中。另外一种断开的连接是被动响应服务端的断开连接然后客户端也退出。这里还要稍微在复杂一点,如果是服务端的超时断开,则直接断开。如果是客户端的主动断开,则还需要继续等待服务端的断开,然后再退出。
这里用到了大量的 select 多路复用:

package main

import (
    "flag"
    "fmt"
    "log"
    "sync"

    "gopl/ch5/links"
)

var count = make(chan int) // 统计一共爬取了多个页面

// 令牌 tokens 是一个计数信号量
// 确保并发请求限制在 20 个以内
var tokens = make(chan struct{}, 20)

func crawl(url string, depth int) urllist {
    fmt.Println(depth, <-count, url)
    tokens <- struct{}{} // 获取令牌
    list, err := links.Extract(url)
    <-tokens // 释放令牌
    if err != nil {
        log.Print(err)
    }
    return urllist{list, depth + 1}
}

var depth int

func init() {
    flag.IntVar(&depth, "depth", -1, "深度限制") // 小于0就是不限制递归深度,0就是只爬取当前页面
}

type urllist struct {
    urls  []string
    depth int
}

func main() {
    // 负责 count 值自增的 goroutine
    go func() {
        var i int
        for {
            i++
            count <- i
        }
    }()

    flag.Parse()
    worklist := make(chan urllist)
    // 等待发送到任务列表的数量
    // 因为需要在 goroutine 里修改,需要换成并发安全的计数器
    var n sync.WaitGroup
    starturls := flag.Args()
    if len(flag.Args()) == 0 {
        starturls = []string{"http://lab.scrapyd.cn/"}
    }

    // 从命令行参数开始
    n.Add(1)
    go func() { worklist <- urllist{starturls, 0} }()
    // 等待全部worklist处理完,就关闭worklist
    go func() {
        n.Wait()
        close(worklist)
    }()

    // 并发爬取 Web
    seen := make(map[string]bool)
    for list := range worklist {
        // 处理完一个worklist后才能让 n 计数器减1
        // 而处理 worklist 又是很多个 goroutine,所以需要再用一个计数器
        var n2 sync.WaitGroup
        for _, link := range list.urls {
            if !seen[link] {
                seen[link] = true
                n2.Add(1)
                go func(url string, listDepth int) {
                    nextList := crawl(url, listDepth)
                    // 如果 depth>0 说明有深度限制
                    // 如果当前的深度已经达到(或超过)深度限制,则爬取完这个连接后,不需要再继续爬取,直接返回
                    if depth >= 0 && listDepth >= depth {
                        // 超出递归深度的页面,在爬取完之后,也输出 URL
                        // for _, nextUrl := range nextList.urls {
                        //  fmt.Println(nextList.depth, "stop", nextUrl)
                        // }
                        n2.Done() // 所有退出的情况都要减计数器n2的值,但是一定要在向通道发送之前
                        return
                    }
                    n.Add(1)             // 添加任务前,计数加1
                    n2.Done()            // 先确保计数器n加1了,再减计数器n2的值
                    worklist <- nextList // 新的任务加入管道必须在最后,之后再一次for循环迭代的时候,才会接收这个值
                }(link, list.depth)
            }
        }
        n2.Wait()
        n.Done()
        // 把计数器的操作也放到 goroutine 中,这样可以继续下一次 for 循环的迭代
        // go func() {
        //  n2.Wait()
        //  n.Done()
        // }()
    }
}

示例:聊天服务器

实现一个聊天服务器,它可以在几个用户之间相互广播文本消息。
这个程序中有四种 goroutine:

  • 主 goroutine,就是 main 函数
  • 广播(broadcaster)goroutine。非常好的展示了 select 用法,因为它需要处理三种不同类型的消息
  • 每一个连接里有一个连接处理(handleConn)goroutine
  • 每一个连接里还有一个客户写入(clientWriter)goroutine

主函数

主函数的工作是监听端口,接受连接请求。对每一个连接,它创建一个新的 handleConn。就像之前的并发回声服务器中那样:

func main() {
    listener, err := net.Listen("tcp", ":8000")
    if err != nil {
        log.Fatal(err)
    }
    go broadcaster()
    for {
        conn, err := listener.Accept()
        if err != nil {
            log.Print(err)
            continue
        }
        go handleConn(conn)
    }
}

广播器

广播器,它的变量 clients 会记录当前连接的客户集合。其记录的内容是每一个客户端对外发送消息的通道:

// 广播器
type client chan<- string // 对外发送消息的通道

var (
    entering = make(chan client)
    leaving  = make(chan client)
    messages = make(chan string) // 所有接受的客户消息
)

func broadcaster() {
    clients := make(map[client]bool) // 所有连接的客户端集合
    for {
        select {
        case msg := <-messages:
            // 把所有接收的消息广播给所有的客户
            // 发送消息通道
            for cli := range clients {
                cli <- msg
            }
        case cli := <-entering:
            clients[cli] = true
        case cli := <-leaving:
            delete(clients, cli)
            close(cli)
        }
    }
}

广播器监听两个全局的通道 entering 和 leaving。通过它们通知有客户进入和离开,如果从一个通道中接收到事件,它将更新 clients 集合。如果是客户离开,还会关闭对应客户对外发送消息的通道。
广播器还监听 messages 通道,所有的客户都会将要广播的消息发送到这个通道。当收到一个消息后,就会把消息广播给所有客户。

客户端处理函数

handleConn 函数创建一个对外发送消息的新通道,然后通过 entering 通道通知广播器新客户进入。接着,要读取客户发来的每一条消息,通过 messages 通道将每一条消息发送给广播器,发送时再每条消息前面加上发送者的ID作为前缀。一旦客户端将消息读取完毕,handleConn 通过 leaving 通道通知客户离开,然后关闭连接:

// 客户端处理函数: handleConn
func handleConn(conn net.Conn) {
    ch := make(chan string) // 对外发送客户消息的通道
    go clientWriter(conn, ch)

    who := conn.RemoteAddr().String()
    ch <- "You are " + who           // 这条单发给自己
    messages <- who + " has arrived" // 这条进行进行广播,但是自己还没加到广播列表中
    entering <- ch                   // 然后把自己加到广播列表中

    input := bufio.NewScanner(conn)
    for input.Scan() {
        messages <- who + ": " + input.Text()
    }
    // 注意,忽略input.Err()中可能的错误

    leaving <- ch
    messages <- who + " has left"
    conn.Close()
}

另外,handleConn 函数还为每一个客户创建了写入(clientWriter)goroutine,每个客户都从自己的通道中接收消息发送给客户端的网络连接。在广播器收到 leaving 通知并关闭这个接收消息的通道后,clientWriter 会结束通道的遍历后运行结束:

// 客户端处理函数: clientWriter
func clientWriter(conn net.Conn, ch <-chan string) {
    for msg := range ch {
        // 在消息结尾使用 \r\n ,提升平台兼容
        fmt.Fprintf(conn, "%s\r\n", msg) // 注意,忽略网络层面的错误
    }
}

给客户端发送的消息字符串需要用"\n"结尾。如果换成"\r\n"结尾,平台的兼容性应该会更好。至少windows上的telnet客户端可以直接使用了。

使用客户端进行聊天

完整的源码就是上面的四段代码,拼在一起就能运行了。
和之前使用回声服务器一样,可以用 telnet 或者也可以用之前写的 netcat 作为客户端来聊天。
当有 n 个客户 session 在连接的时候,程序并发运行着 2n+2 个相互通信的 goroutine,它不需要隐式的加锁操作也能做到并发安全。clients map 被限制在广播器这一个 goroutine 中,所以不会被并发的访问。唯一被多个 goroutine 共享的变量是通道以及 net.Conn 的实例,它们也都是并发安全的。

聊天服务器功能扩展

上面的聊天服务器提供了一个很好的架构,现在再在其之上扩展功能就很方便了。

通知当前的用户列表

在新用户到来之后,告知该新用户当前在聊天室的所有的用户列表。每个用户加入后,系统都会自动生成一个用户名(基于用户的网络连接,之后会添加设置用户名的功能),就是要把这些存在的用户名打印出来。
所有的用户列表只在广播器的 clients map 中,但是这个 map 又不包括用户名。所以先要修改数据类型,把每个连接的数据结构加上一个新的用户名字段:

type client chan<- string // 对外发送消息的通道
type clientInfo struct {
    name string
    ch   client
}

原本使用 client 作为元素的通道和 map,现在全部也都要换成 clientInfo 作为元素。像新用户发送当前用户列表的任务也在广播器中完成:

// 广播器
type client chan<- string // 对外发送消息的通道
type clientInfo struct {
    name string
    ch   client
}

var (
    entering = make(chan clientInfo)
    leaving  = make(chan clientInfo)
    messages = make(chan string) // 所有接受的客户消息
)

func broadcaster() {
    clients := make(map[clientInfo]bool) // 所有连接的客户端集合
    for {
        select {
        case msg := <-messages:
            // 把所有接收的消息广播给所有的客户
            // 发送消息通道
            for cli := range clients {
                cli.ch <- msg
            }
        case cli := <-entering:
            // 在每一个新用户到来的时候,通知当前存在的用户
            var users []string
            for cli := range clients {
                users = append(users, cli.name)
            }
            if len(users) > 0 {
                cli.ch <- fmt.Sprintf("Other users in room: %s", strings.Join(users, "; "))
            } else {
                cli.ch <- "You are the only user in this room."
            }

            clients[cli] = true
        case cli := <-leaving:
            delete(clients, cli)
            close(cli.ch)
        }
    }
}

客户端处理函数还需要做少量的修改,主要是因为数据结构变了。原本给 entering 和 leaving 通道发送的是 ch。现在要发送封装好 who 的结构体。客户端处理函数的代码略,之后的扩展中会贴出来:

cli := clientInfo{who, ch}
entering <- cli

断掉长时间空闲的客户端

如果在一段时间里,客户端没有任何输入,服务器就将客户端断开。之前的逻辑是,客户端处理函数会一直在阻塞在 input.Scan() 这里等待客户端输入。只要在另外一个 goroutine 中调用 conn.Close(),就可以让当前阻塞的读操作变成非阻塞,就像 input.Scan() 输入完成的读操作一样。不过这么做的话会有一点小问题,原本在主 goroutine 的结尾有一个conn.Close()操作,现在在定时的 goroutine 中还需要有一个关闭的操作。如果因为定时而结束的,就会有两次关闭操作。
这里关闭的是 socket 连接,本质上就是文件句柄。尝试多次关闭貌似不会有什么问题,不过要解决这个问题也不难。一种是把响应用户输入的操作也放到 goroutine 中。现有有两个 goroutine 在运行,主 goroutine 则只要一直阻塞,通过一个通道等待其中任何一个 goroutine 完成后发送的信号即可。这样关闭的操作只在主 goroutine 中操作。下面的是客户端处理函数,包括上一个功能里修改的部分:

// 客户端处理函数: handleConn
func handleConn(conn net.Conn) {
    ch := make(chan string) // 对外发送客户消息的通道
    go clientWriter(conn, ch)

    who := conn.RemoteAddr().String()
    cli := clientInfo{who, ch}       // 打包好用户名和通道
    ch <- "You are " + who           // 这条单发给自己
    messages <- who + " has arrived" // 这条进行进行广播,但是自己还没加到广播列表中
    entering <- cli                  // 然后把自己加到广播列表中

    done := make(chan struct{}, 2) // 等待下面两个 goroutine 其中一个执行完成。使用缓冲通道防止 goroutine 泄漏
    // 计算超时的goroutine
    inputSignal := make(chan struct{}) // 有任何输入,就发送一个信号
    timeout := 15 * time.Second        // 客户端空闲的超时时间
    go func() {
        timer := time.NewTimer(timeout)
        for {
            select {
            case <-inputSignal:
                timer.Reset(timeout)
            case <-timer.C:
                // 超时,断开连接
                done <- struct{}{}
                return
            }
        }
    }()

    go func() {
        input := bufio.NewScanner(conn)
        for input.Scan() {
            inputSignal <- struct{}{}
            if len(strings.TrimSpace(input.Text())) == 0 { // 禁止发送纯空白字符
                continue
            }
            messages <- who + ": " + input.Text()
        }
        // 注意,忽略input.Err()中可能的错误
        done <- struct{}{}
    }()

    <-done
    leaving <- cli
    messages <- who + " has left"
    conn.Close()
}

这里还简单加了一个限制客户端发送空消息的功能,在 input.Scan() 循环中。空消息不会发送广播,但是可以重置定时器的时间。

客户端可以输入名字

在客户端连接后,不立刻进入聊天室,而是先输入一个名字。考虑到名字不能和已有的名字重复,而现有的名字都保存在广播器里的 clients 这个 map 中。所以客户端输入的名字需要在 clients 中查找一下是否已经有人用了。现在有了按名字进行查找的需求,clients 类型更适合使用一个以名字为 key 的 map 而不是原本的集合。这个 map 的 value 就是向该客户发送消息的通道,也就是最初这个集合的 key 的值:

clients := make(map[string]client) // 所有连接的客户端集合

客户端处理函数
在客户端处理函数的开头,需要增加注册用户名的过程。用户名注册的处理过程比较复杂,所以单独封装到了一个函数 clientRegiste 中:

// 客户端处理函数
func handleConn(conn net.Conn) {
    who := clientRegiste(conn) // 新增这一行,注册获取用户名

    ch := make(chan string) // 对外发送客户消息的通道
    go clientWriter(conn, ch)

    // who := conn.RemoteAddr().String() // 去掉这一行
    // 之后的代码不变
}

这里使用一个交互的方式来获取用户名,代替原本通过连接的信息自动生成。这个函数是串行的,只有在返回用户名后,才会继续执行下去。之后的代码和之前是一样的。
在 clientRegiste 函数中,不停的和终端进行交互,处理收到的消息,如果用户名可用,继续执行之后的流程。如果用户名不可用,则提示用户继续处理:

// 客户端处理函数 clientRegiste
// 注册用户名
func clientRegiste(conn net.Conn) (who string) {
    ch := make(chan bool)
    fmt.Fprint(conn, "input nickname: ") // 注意,忽略网络层面的错误
    input := bufio.NewScanner(conn)
    for input.Scan() {
        if len(strings.TrimSpace(input.Text())) == 0 { // 禁止发送纯空白字符
            continue
        }
        who = input.Text()
        register <- registeInfo{who, ch}
        if <-ch {
            break
        }
        fmt.Fprintf(conn, "name %q is existed\r\ntry other name: ", who)
    }
    // 注意,忽略input.Err()中可能的错误
    return who
}

这里只有最简单的功能,还可以增加输入超时,以及尝试次数的限制。所以把这个函数独立出来完成功能,更方便之后对注册函数进行扩展。
函数的主要逻辑就是 input.Scan() 的循环,这和 handleConn 中的循环十分相似。如果之后再加上输入超时,这两段的处理逻辑只有极小部分的差别,所以这部分代码也可以单独写一个函数。这里避免过早的优化,暂时就先这样,看着也比较清晰。之后要添加超时功能的时候,再把这部分重复的代码独立出来。这部分优化最后完整的代码里会有。

广播器
在广播器的 select 里要加一个分支,用来处理用户名的请求。收到请求后,判断是否已经存在,把结果返回给 clientRegiste。因为 clients 是只有广播器可见的,所以这里要使用通道传递过来,判断后再用通道把结果传回去。这样可以保证 clients 变量只在这一个 goroutine 里被使用(包括修改)。另外,每个客户端的注册都使用一个通道将注册信息发送给广播器,但是广播器返回的内容,需要对每个客户端使用不同的通道。所以这里,广播器新创建了专门用于注册交互的数据结构:

type registeInfo struct {
    name string
    ch   chan<- bool
}

var register = make(chan registeInfo) // 注册用户名的通道

客户注册的函数创建一个布尔型的通道,加上用户的名字封装到 registeInfo 结构体中。然后广播器判断后,把结果通道 registeInfo 里的 ch 字段这个通道,把结果返回给对应的客户注册函数。
下面是广播器 broadcaster 的代码,主要是 select 新增了一个分支,处理注册用户名:

// 广播器
type client chan<- string // 对外发送消息的通道
type clientInfo struct {
    name string
    ch   client
}

var (
    entering = make(chan clientInfo)
    leaving  = make(chan clientInfo)
    messages = make(chan string) // 所有接受的客户消息
)

type registeInfo struct {
    name string
    ch   chan<- bool
}

var register = make(chan registeInfo) // 注册用户名的通道

func broadcaster() {
    clients := make(map[string]client) // 所有连接的客户端集合
    for {
        select {
        case msg := <-messages:
            // 把所有接收的消息广播给所有的客户
            // 发送消息通道
            for _, cli := range clients {
                cli <- msg
            }
        case user := <-register:
            // 先判断新用户名是否有重复
            _, ok := clients[user.name]
            user.ch <- !ok
        case cliSt := <-entering:
            // 在每一个新用户到来的时候,通知当前存在的用户
            var users []string
            for user := range clients {
                users = append(users, user)
            }
            if len(users) > 0 {
                cliSt.ch <- fmt.Sprintf("Other users in room: %s", strings.Join(users, "; "))
            } else {
                cliSt.ch <- "You are the only user in this room."
            }

            clients[cliSt.name] = cliSt.ch
        case cliSt := <-leaving:
            delete(clients, cliSt.name)
            close(cliSt.ch)
        }
    }
}

预防客户端延迟影响

最后还有一个问题,就是客户端可能会卡或者延迟,但是客户端的问题不能影响到服务器的正常运行。不过我没法实现一个这样的有延迟的客户端,默认操作系统应该就已经非常友好的帮我们处理掉了,把从网络上接收到的数据暂存在缓冲区里(对于TCP连接还有乱序重组和超时重传,这些我们都不需要关心了),等待程序去读取。代码里接收的操作应该是直接从缓冲区读取,这时服务的已经发送完毕了。所以现在只能照着下面的思路写了:

任何客户程序读取数据的时间很长最终会造成所有的客户卡住。修改广播器,使它满足如果一个向客户写入的通道没有准备好接受它,那么跳过这条消息。还可以给每一个向客户发送消息的通道增加缓冲,这样大多数的消息不会丢弃;广播器在这个通道上应该使用非阻塞的发送方式。

客户端处理函数中创建的发送消息的通道改用有缓冲区的通道:

// 客户端处理函数
func handleConn(conn net.Conn) {
    defer conn.Close() // 退出时关闭客户端连接,现在有分支了,并且可能会提前退出

    who, ok := clientRegiste(conn) // 注册获取用户名
    if !ok { // 用户名未注册成功
        fmt.Fprintln(conn, "\r\nName registe failed...")
        return
    }

    ch := make(chan string, 10) // 有缓冲区,对外发送客户消息的通道
    go clientWriter(conn, ch)

    // 省略后面的代码
}

然后广播器的 select 对应的 messages 通道的分支,改成非阻塞的方式:

select {
case msg := <-messages:
    // 把所有接收的消息广播给所有的客户
    // 发送消息通道
    for name, cli := range clients {
        select {
        case cli <- msg:
        default:
            fmt.Fprintf(os.Stderr, "send message failed: %s: %s\n", name, msg)
        }
    }
// 其他分支略过
}

下面是聊天服务器最后完整的代码。这里的改变还包括了上一节最后提到的注册用户名时的输入的超时。已经两次用到了输入超时,分别在 handleConn 和 clientRegiste 中,这里也就把这部分代码单独写了一个函数 inputWithTimeout。完整代码如下:

package main

import (
    "bufio"
    "fmt"
    "log"
    "net"
    "os"
    "strings"
    "time"
)

func main() {
    listener, err := net.Listen("tcp", ":8000")
    if err != nil {
        log.Fatal(err)
    }
    go broadcaster()
    for {
        conn, err := listener.Accept()
        if err != nil {
            log.Print(err)
            continue
        }
        go handleConn(conn)
    }
}

// 广播器
type client chan<- string // 对外发送消息的通道
type clientInfo struct {
    name string
    ch   client
}

var (
    entering = make(chan clientInfo)
    leaving  = make(chan clientInfo)
    messages = make(chan string) // 所有接受的客户消息
)

type registeInfo struct {
    name string
    ch   chan<- bool
}

var register = make(chan registeInfo) // 注册用户名的通道

func broadcaster() {
    clients := make(map[string]client) // 所有连接的客户端集合
    for {
        select {
        case msg := <-messages:
            // 把所有接收的消息广播给所有的客户
            // 发送消息通道
            for name, cli := range clients {
                select {
                case cli <- msg:
                default:
                    fmt.Fprintf(os.Stderr, "send message failed: %s: %s\n", name, msg)
                }
            }
        case user := <-register:
            // 先判断新用户名是否有重复
            _, ok := clients[user.name]
            user.ch <- !ok
        case cliSt := <-entering:
            // 在每一个新用户到来的时候,通知当前存在的用户
            var users []string
            for user := range clients {
                users = append(users, user)
            }
            if len(users) > 0 {
                cliSt.ch <- fmt.Sprintf("Other users in room: %s", strings.Join(users, "; "))
            } else {
                cliSt.ch <- "You are the only user in this room."
            }

            clients[cliSt.name] = cliSt.ch
        case cliSt := <-leaving:
            delete(clients, cliSt.name)
            close(cliSt.ch)
        }
    }
}

// 客户端处理函数
func handleConn(conn net.Conn) {
    defer conn.Close() // 退出时关闭客户端连接,现在有分支了,并且可能会提前退出

    who, ok := clientRegiste(conn) // 注册获取用户名
    if !ok {                       // 用户名未注册成功
        fmt.Fprintln(conn, "\r\nName registe failed...")
        return
    }

    ch := make(chan string, 10) // 有缓冲区,对外发送客户消息的通道
    go clientWriter(conn, ch)

    cli := clientInfo{who, ch}       // 打包好用户名和通道
    ch <- "You are " + who           // 这条单发给自己
    messages <- who + " has arrived" // 现在这条广播自己也能收到
    entering <- cli

    inputFunc := func(sig chan<- struct{}) {
        input := bufio.NewScanner(conn)
        for input.Scan() {
            sig <- struct{}{}                              // 向 sig 发送信号,会重新开始计时
            if len(strings.TrimSpace(input.Text())) == 0 { // 禁止发送纯空白字符
                continue
            }
            messages <- who + ": " + input.Text()
        }
        // 注意,忽略input.Err()中可能的错误
    }
    inputWithTimeout(conn, 300*time.Second, inputFunc)

    leaving <- cli
    messages <- who + " has left"
}

func clientWriter(conn net.Conn, ch <-chan string) {
    for msg := range ch {
        // windows 需要 \r 了正常显示
        fmt.Fprintln(conn, msg+"\r") // 注意,忽略网络层面的错误
    }
}

// 注册用户名
func clientRegiste(conn net.Conn) (who string, ok bool) {
    inputFunc := func(sig chan<- struct{}) {
        input := bufio.NewScanner(conn)
        ch := make(chan bool)
        fmt.Fprint(conn, "input nickname: ") // 注意,忽略网络层面的错误
        for input.Scan() {
            if len(strings.TrimSpace(input.Text())) == 0 { // 禁止发送纯空白字符
                continue
            }
            who = input.Text()
            register <- registeInfo{who, ch}
            if <-ch {
                ok = true
                break
            }
            fmt.Fprintf(conn, "name %q is existed\r\ntry other name: ", who)
        }
        // 注意,忽略input.Err()中可能的错误
    }
    inputWithTimeout(conn, 15*time.Second, inputFunc)
    return who, ok
}

// 为 input.Scan 封装超时退出的功能
func inputWithTimeout(conn net.Conn, timeout time.Duration, input func(sig chan<- struct{})) {
    done := make(chan struct{}, 2)
    inputSignal := make(chan struct{})
    go func() {
        timer := time.NewTimer(timeout)
        for {
            select {
            case <-inputSignal:
                timer.Reset(timeout)
            case <-timer.C:
                // 超时,断开连接
                done <- struct{}{}
                return
            }
        }
    }()

    go func() {
        input(inputSignal)
        done <- struct{}{}
    }()

    <-done
}