My previous write-up on gRPC only gave an overall introduction to the framework and its features; some implementation details and concepts were still unclear to me, and the material I found elsewhere was largely the same boilerplate. So in this post I work through the gRPC source code to answer those questions myself.
Goals
- Understand the overall architecture of the grpc project
- Understand gRPC connection pooling and the resolution work done on the client/server side
- Understand the design patterns behind features such as gRPC interceptors
Server-side processing flow
lis, err := net.Listen("tcp", port)
if err != nil {
    logs.Error("failed to listen, err is %s", err)
}
s := grpc.NewServer()
pb.RegisterUserServiceServer(s, &server{})
if err := s.Serve(lis); err != nil {
    logs.Error("failed to serve, err is %s", err)
}
1. Create a TCP listener on the port
lis, err := net.Listen("tcp", port)
2. Create the server
s := grpc.NewServer()
This method creates the Server struct and populates its fields, such as the interceptors; the interceptor flow is covered in detail later.
The Server struct looks like this:
type Server struct {
    opts serverOptions
    mu   sync.Mutex // guards following
    lis  map[net.Listener]bool
    conns map[transport.ServerTransport]bool
    serve bool
    drain bool
    cv    *sync.Cond          // signaled when connections close for GracefulStop
    m     map[string]*service // service name -> service info
    events trace.EventLog
    quit   *grpcsync.Event
    done   *grpcsync.Event
    channelzRemoveOnce sync.Once
    serveWG            sync.WaitGroup // counts active Serve goroutines for GracefulStop
    channelzID int64 // channelz unique identification number
    czData     *channelzData
}
The most important fields here are the maps lis, conns, and m: lis tracks the TCP listeners, and m maps each service name to its service info.
grpc.NewServer accepts variadic arguments of type opt ...ServerOption, such as interceptors.
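As a quick illustration of such an option, here is a minimal sketch of wiring a unary interceptor into grpc.NewServer. grpc.UnaryInterceptor is the real ServerOption constructor; loggingInterceptor is a hypothetical example of mine:

import (
    "context"
    "log"

    "google.golang.org/grpc"
)

// loggingInterceptor runs around every unary RPC on this server.
func loggingInterceptor(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
    log.Printf("unary call: %s", info.FullMethod) // e.g. "/service.UserService/login"
    resp, err := handler(ctx, req)                // invoke the real method handler
    log.Printf("unary call %s done, err=%v", info.FullMethod, err)
    return resp, err
}

// later, when constructing the server:
// s := grpc.NewServer(grpc.UnaryInterceptor(loggingInterceptor))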
3. Register the service
func (s *Server) RegisterService(sd *ServiceDesc, ss interface{}) {
    ht := reflect.TypeOf(sd.HandlerType).Elem()
    st := reflect.TypeOf(ss)
    if !st.Implements(ht) {
        grpclog.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v", st, ht)
    }
    s.register(sd, ss)
}

func (s *Server) register(sd *ServiceDesc, ss interface{}) {
    s.mu.Lock()
    defer s.mu.Unlock()
    s.printf("RegisterService(%q)", sd.ServiceName)
    if s.serve {
        grpclog.Fatalf("grpc: Server.RegisterService after Server.Serve for %q", sd.ServiceName)
    }
    if _, ok := s.m[sd.ServiceName]; ok {
        grpclog.Fatalf("grpc: Server.RegisterService found duplicate service registration for %q", sd.ServiceName)
    }
    srv := &service{
        server: ss,
        md:     make(map[string]*MethodDesc),
        sd:     make(map[string]*StreamDesc),
        mdata:  sd.Metadata,
    }
    for i := range sd.Methods {
        d := &sd.Methods[i]
        srv.md[d.MethodName] = d
    }
    for i := range sd.Streams {
        d := &sd.Streams[i]
        srv.sd[d.StreamName] = d
    }
    s.m[sd.ServiceName] = srv
}
The code is straightforward: it injects the service's information into the Server struct. The service is registered under its name, and each of its method names is put into a map, so that the matching handler can later be looked up by method name.
The ServiceDesc struct is defined as:
type ServiceDesc struct {
    ServiceName string
    // The pointer to the service interface. Used to check whether the user
    // provided implementation satisfies the interface requirements.
    HandlerType interface{}
    Methods     []MethodDesc
    Streams     []StreamDesc
    Metadata    interface{}
}
In practice, the generated ServiceDesc (with its MethodDescs) looks like the following:
var _UserService_serviceDesc = grpc.ServiceDesc{
    ServiceName: "service.UserService",
    HandlerType: (*UserServiceServer)(nil),
    Methods: []grpc.MethodDesc{
        {
            MethodName: "login",
            Handler:    _UserService_Login_Handler,
        },
        {
            MethodName: "getUserInfo",
            Handler:    _UserService_GetUserInfo_Handler,
        },
        {
            MethodName: "updateUserInfo",
            Handler:    _UserService_UpdateUserInfo_Handler,
        },
        {
            MethodName: "uploadProfilePic",
            Handler:    _UserService_UploadProfilePic_Handler,
        },
    },
    Streams:  []grpc.StreamDesc{},
    Metadata: "userService.proto",
}
The Streams field seen here can be ignored for now; it relates to gRPC's streaming mode.
4. Actually starting the gRPC server
The steps above only register the service and open the TCP listener; the gRPC server has not really started yet. The entry point that starts it is Serve.
The code is as follows:
func (s *Server) Serve(lis net.Listener) error {
    // ... non-essential code omitted ...
    var tempDelay time.Duration // how long to sleep on accept failure
    for {
        rawConn, err := lis.Accept()
        if err != nil {
            if ne, ok := err.(interface {
                Temporary() bool
            }); ok && ne.Temporary() {
                if tempDelay == 0 {
                    tempDelay = 5 * time.Millisecond
                } else {
                    tempDelay *= 2
                }
                if max := 1 * time.Second; tempDelay > max {
                    tempDelay = max
                }
                s.mu.Lock()
                s.printf("Accept error: %v; retrying in %v", err, tempDelay)
                s.mu.Unlock()
                timer := time.NewTimer(tempDelay)
                select {
                case <-timer.C:
                case <-s.quit.Done():
                    timer.Stop()
                    return nil
                }
                continue
            }
            s.mu.Lock()
            s.printf("done serving; Accept = %v", err)
            s.mu.Unlock()
            if s.quit.HasFired() {
                return nil
            }
            return err
        }
        tempDelay = 0
        // Start a new goroutine to deal with rawConn so we don't stall this Accept
        // loop goroutine.
        //
        // Make sure we account for the goroutine so GracefulStop doesn't nil out
        // s.conns before this conn can be added.
        s.serveWG.Add(1)
        go func() {
            s.handleRawConn(rawConn)
            s.serveWG.Done()
        }()
    }
}
The rough logic is a for loop that accepts connections on the registered TCP listener. If accepting fails with a temporary error, the loop sleeps before retrying (starting at 5ms, doubling on each consecutive failure, capped at 1s). Each accepted connection is then handled in its own goroutine, and a WaitGroup counts those goroutines (this matters for graceful shutdown). So the core processing logic is the handleRawConn method at the end.
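Pulled out on its own, the backoff policy is easy to see. This is my own condensed restatement of the tempDelay logic above, not a gRPC API:

import "time"

// nextDelay reproduces the accept-failure backoff: start at 5ms, double on
// every consecutive failure, and cap at 1s.
func nextDelay(tempDelay time.Duration) time.Duration {
    if tempDelay == 0 {
        return 5 * time.Millisecond
    }
    tempDelay *= 2
    if max := 1 * time.Second; tempDelay > max {
        return max
    }
    return tempDelay
}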
func (s *Server) handleRawConn(rawConn net.Conn) {
    if s.quit.HasFired() {
        rawConn.Close()
        return
    }
    rawConn.SetDeadline(time.Now().Add(s.opts.connectionTimeout))
    conn, authInfo, err := s.useTransportAuthenticator(rawConn)
    if err != nil {
        // ErrConnDispatched means that the connection was dispatched away from
        // gRPC; those connections should be left open.
        if err != credentials.ErrConnDispatched {
            s.mu.Lock()
            s.errorf("ServerHandshake(%q) failed: %v", rawConn.RemoteAddr(), err)
            s.mu.Unlock()
            channelz.Warningf(s.channelzID, "grpc: Server.Serve failed to complete security handshake from %q: %v", rawConn.RemoteAddr(), err)
            rawConn.Close()
        }
        rawConn.SetDeadline(time.Time{})
        return
    }
    // Finish handshaking (HTTP2)
    st := s.newHTTP2Transport(conn, authInfo)
    if st == nil {
        return
    }
    rawConn.SetDeadline(time.Time{})
    if !s.addConn(st) {
        return
    }
    go func() {
        s.serveStreams(st)
        s.removeConn(st)
    }()
}
In handleRawConn the server first runs the auth handshake (a gRPC feature, though many projects never use it), then opens an HTTP/2 transport st on the connection. newHTTP2Transport performs the HTTP/2 handshake and builds the transport, which you can verify by capturing the traffic with Wireshark. Following how st is used next, the key code is s.serveStreams(st) inside the newly started goroutine, and from there the logic becomes quite clear:
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) {
    sm := stream.Method()
    if sm != "" && sm[0] == '/' {
        sm = sm[1:]
    }
    pos := strings.LastIndex(sm, "/")
    if pos == -1 {
        if trInfo != nil {
            trInfo.tr.LazyLog(&fmtStringer{"Malformed method name %q", []interface{}{sm}}, true)
            trInfo.tr.SetError()
        }
        errDesc := fmt.Sprintf("malformed method name: %q", stream.Method())
        if err := t.WriteStatus(stream, status.New(codes.ResourceExhausted, errDesc)); err != nil {
            if trInfo != nil {
                trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
                trInfo.tr.SetError()
            }
            channelz.Warningf(s.channelzID, "grpc: Server.handleStream failed to write status: %v", err)
        }
        if trInfo != nil {
            trInfo.tr.Finish()
        }
        return
    }
    service := sm[:pos]
    method := sm[pos+1:]
    srv, knownService := s.m[service]
    if knownService {
        if md, ok := srv.md[method]; ok {
            s.processUnaryRPC(t, stream, srv, md, trInfo)
            return
        }
        if sd, ok := srv.sd[method]; ok {
            s.processStreamingRPC(t, stream, srv, sd, trInfo)
            return
        }
    }
    // Unknown service, or known server unknown method.
    if unknownDesc := s.opts.unknownStreamDesc; unknownDesc != nil {
        s.processStreamingRPC(t, stream, nil, unknownDesc, trInfo)
        return
    }
    var errDesc string
    if !knownService {
        errDesc = fmt.Sprintf("unknown service %v", service)
    } else {
        errDesc = fmt.Sprintf("unknown method %v for service %v", method, service)
    }
    if trInfo != nil {
        trInfo.tr.LazyPrintf("%s", errDesc)
        trInfo.tr.SetError()
    }
    if err := t.WriteStatus(stream, status.New(codes.Unimplemented, errDesc)); err != nil {
        if trInfo != nil {
            trInfo.tr.LazyLog(&fmtStringer{"%v", []interface{}{err}}, true)
            trInfo.tr.SetError()
        }
        channelz.Warningf(s.channelzID, "grpc: Server.handleStream failed to write status: %v", err)
    }
    if trInfo != nil {
        trInfo.tr.Finish()
    }
}
A gRPC request path has the form service/method, so the code first locates the last '/' and splits out the service name and method name. It then uses the Server struct described earlier to look up the service and the method, and dispatches to s.processUnaryRPC(t, stream, srv, md, trInfo) or s.processStreamingRPC(t, stream, srv, sd, trInfo). A condensed sketch of this routing step follows.
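This sketch mirrors the lookup with plain Go maps standing in for gRPC's internal service/method tables; the types and function are mine, not gRPC's:

import (
    "fmt"
    "strings"
)

// route resolves a full method name such as "/service.UserService/login"
// against a two-level map like the one built during RegisterService.
func route(services map[string]map[string]func(), fullMethod string) (func(), error) {
    sm := strings.TrimPrefix(fullMethod, "/")
    pos := strings.LastIndex(sm, "/")
    if pos == -1 {
        return nil, fmt.Errorf("malformed method name: %q", fullMethod)
    }
    service, method := sm[:pos], sm[pos+1:]
    methods, ok := services[service]
    if !ok {
        return nil, fmt.Errorf("unknown service %v", service)
    }
    h, ok := methods[method]
    if !ok {
        return nil, fmt.Errorf("unknown method %v for service %v", method, service)
    }
    return h, nil
}

Back in the real code, we look only at the simpler unary RPC path; inside processUnaryRPC the handler is finally invoked: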
ctx := NewContextWithServerTransportStream(stream.Context(), stream)
reply, appErr := md.Handler(srv.server, ctx, df, s.opts.unaryInt)
…and the response is then written back:
if err := s.sendResponse(t, stream, reply, cp, opts, comp); err != nil {
md.Handler is exactly the handler from the MethodDesc registered when the server was created. Its last argument, s.opts.unaryInt, is the server-side interceptor, which was stored into the Server struct back in step 2, in NewServer:
func NewServer(opt ...ServerOption) *Server {
    opts := defaultServerOptions
    for _, o := range opt {
        o.apply(&opts)
    }
    s := &Server{
        lis:    make(map[net.Listener]bool),
        opts:   opts,
        conns:  make(map[transport.ServerTransport]bool),
        m:      make(map[string]*service),
        quit:   grpcsync.NewEvent(),
        done:   grpcsync.NewEvent(),
        czData: new(channelzData),
    }
    chainUnaryServerInterceptors(s)
    chainStreamServerInterceptors(s)
    s.cv = sync.NewCond(&s.mu)
    if EnableTracing {
        _, file, line, _ := runtime.Caller(1)
        s.events = trace.NewEventLog("grpc.Server", fmt.Sprintf("%s:%d", file, line))
    }
    if channelz.IsOn() {
        s.channelzID = channelz.RegisterServer(&channelzServer{s}, "")
    }
    return s
}
Each method name maps to one handler; the generated handler logic looks like this:
func _UserService_Login_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
    in := new(LoginRequest)
    if err := dec(in); err != nil {
        return nil, err
    }
    if interceptor == nil {
        return srv.(UserServiceServer).Login(ctx, in)
    }
    info := &grpc.UnaryServerInfo{
        Server:     srv,
        FullMethod: "/service.UserService/Login",
    }
    handler := func(ctx context.Context, req interface{}) (interface{}, error) {
        return srv.(UserServiceServer).Login(ctx, req.(*LoginRequest))
    }
    return interceptor(ctx, in, info, handler)
}
As you can see, if no interceptor is configured, the method implemented by the service is called directly and its result returned; otherwise the interceptor is invoked to do the processing and produce the result.
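For completeness, srv.(UserServiceServer).Login resolves to the application's own implementation of the generated interface. A minimal sketch, assuming the pb package generated from userService.proto (the import path is hypothetical):

import (
    "context"

    pb "your/module/service" // hypothetical import path of the generated code
)

// server implements the generated UserServiceServer interface; this is the
// value passed to pb.RegisterUserServiceServer(s, &server{}) at startup.
type server struct{}

// Login is the business method the generated handler dispatches to.
func (s *server) Login(ctx context.Context, in *pb.LoginRequest) (*pb.UserInfoDTO, error) {
    // validate credentials, load user data, etc. (omitted)
    return &pb.UserInfoDTO{}, nil
}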
The sendResponse code is shown below: it first encodes the message (proto by default, though a custom codec can be used), then compresses it, and finally writes it back to the client.
func (s *Server) sendResponse(t transport.ServerTransport, stream *transport.Stream, msg interface{}, cp Compressor, opts *transport.Options, comp encoding.Compressor) error {
    data, err := encode(s.getCodec(stream.ContentSubtype()), msg)
    if err != nil {
        channelz.Error(s.channelzID, "grpc: server failed to encode response: ", err)
        return err
    }
    compData, err := compress(data, cp, comp)
    if err != nil {
        channelz.Error(s.channelzID, "grpc: server failed to compress response: ", err)
        return err
    }
    hdr, payload := msgHeader(data, compData)
    // TODO(dfawley): should we be checking len(data) instead?
    if len(payload) > s.opts.maxSendMessageSize {
        return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", len(payload), s.opts.maxSendMessageSize)
    }
    err = t.Write(stream, hdr, payload, opts)
    if err == nil && s.opts.statsHandler != nil {
        s.opts.statsHandler.HandleRPC(stream.Context(), outPayload(false, msg, data, payload, time.Now()))
    }
    return err
}
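The hdr built by msgHeader is gRPC's standard 5-byte message prefix: one compressed-flag byte followed by the payload length as a big-endian uint32. A small illustrative sketch of building that prefix (my own helper, not the internal msgHeader signature):

import "encoding/binary"

// framePrefix builds the 5-byte gRPC message prefix that precedes every
// message on the wire: a flag byte, then the big-endian payload length.
func framePrefix(compressed bool, payload []byte) []byte {
    hdr := make([]byte, 5)
    if compressed {
        hdr[0] = 1
    }
    binary.BigEndian.PutUint32(hdr[1:], uint32(len(payload)))
    return hdr
}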
Summary
On the server side, the overall picture is: listen on a port, then dispatch each incoming request to the matching handler by method name.
Client-side processing flow
The client initiates a gRPC request with code like the following.
1. Establish the connection
conn, err := grpc.Dial(address, grpc.WithInsecure(), grpc.WithBlock())
Dial returns a *ClientConn:
func Dial(target string, opts ...DialOption) (*ClientConn, error) {
    return DialContext(context.Background(), target, opts...)
}
The core of connection establishment is the DialContext method:
func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
    cc := &ClientConn{
        target:            target,
        csMgr:             &connectivityStateManager{},
        conns:             make(map[*addrConn]struct{}),
        dopts:             defaultDialOptions(),
        blockingpicker:    newPickerWrapper(),
        czData:            new(channelzData),
        firstResolveEvent: grpcsync.NewEvent(),
    }
    cc.retryThrottler.Store((*retryThrottler)(nil))
    cc.ctx, cc.cancel = context.WithCancel(context.Background())
    for _, opt := range opts {
        opt.apply(&cc.dopts)
    }
    chainUnaryClientInterceptors(cc)
    chainStreamClientInterceptors(cc)
    defer func() {
        if err != nil {
            cc.Close()
        }
    }()
    if channelz.IsOn() {
        if cc.dopts.channelzParentID != 0 {
            cc.channelzID = channelz.RegisterChannel(&channelzChannel{cc}, cc.dopts.channelzParentID, target)
            channelz.AddTraceEvent(cc.channelzID, 0, &channelz.TraceEventDesc{
                Desc:     "Channel Created",
                Severity: channelz.CtINFO,
                Parent: &channelz.TraceEventDesc{
                    Desc:     fmt.Sprintf("Nested Channel(id:%d) created", cc.channelzID),
                    Severity: channelz.CtINFO,
                },
            })
        } else {
            cc.channelzID = channelz.RegisterChannel(&channelzChannel{cc}, 0, target)
            channelz.Info(cc.channelzID, "Channel Created")
        }
        cc.csMgr.channelzID = cc.channelzID
    }
    if !cc.dopts.insecure {
        if cc.dopts.copts.TransportCredentials == nil && cc.dopts.copts.CredsBundle == nil {
            return nil, errNoTransportSecurity
        }
        if cc.dopts.copts.TransportCredentials != nil && cc.dopts.copts.CredsBundle != nil {
            return nil, errTransportCredsAndBundle
        }
    } else {
        if cc.dopts.copts.TransportCredentials != nil || cc.dopts.copts.CredsBundle != nil {
            return nil, errCredentialsConflict
        }
        for _, cd := range cc.dopts.copts.PerRPCCredentials {
            if cd.RequireTransportSecurity() {
                return nil, errTransportCredentialsMissing
            }
        }
    }
    if cc.dopts.defaultServiceConfigRawJSON != nil {
        scpr := parseServiceConfig(*cc.dopts.defaultServiceConfigRawJSON)
        if scpr.Err != nil {
            return nil, fmt.Errorf("%s: %v", invalidDefaultServiceConfigErrPrefix, scpr.Err)
        }
        cc.dopts.defaultServiceConfig, _ = scpr.Config.(*ServiceConfig)
    }
    cc.mkp = cc.dopts.copts.KeepaliveParams
    if cc.dopts.copts.Dialer == nil {
        cc.dopts.copts.Dialer = func(ctx context.Context, addr string) (net.Conn, error) {
            network, addr := parseDialTarget(addr)
            return (&net.Dialer{}).DialContext(ctx, network, addr)
        }
        if cc.dopts.withProxy {
            cc.dopts.copts.Dialer = newProxyDialer(cc.dopts.copts.Dialer)
        }
    }
    if cc.dopts.copts.UserAgent != "" {
        cc.dopts.copts.UserAgent += " " + grpcUA
    } else {
        cc.dopts.copts.UserAgent = grpcUA
    }
    if cc.dopts.timeout > 0 {
        var cancel context.CancelFunc
        ctx, cancel = context.WithTimeout(ctx, cc.dopts.timeout)
        defer cancel()
    }
    defer func() {
        select {
        case <-ctx.Done():
            conn, err = nil, ctx.Err()
        default:
        }
    }()
    scSet := false
    if cc.dopts.scChan != nil {
        // Try to get an initial service config.
        select {
        case sc, ok := <-cc.dopts.scChan:
            if ok {
                cc.sc = &sc
                scSet = true
            }
        default:
        }
    }
    if cc.dopts.bs == nil {
        cc.dopts.bs = backoff.DefaultExponential
    }
    // Determine the resolver to use.
    cc.parsedTarget = grpcutil.ParseTarget(cc.target)
    channelz.Infof(cc.channelzID, "parsed scheme: %q", cc.parsedTarget.Scheme)
    resolverBuilder := cc.getResolver(cc.parsedTarget.Scheme)
    if resolverBuilder == nil {
        // If resolver builder is still nil, the parsed target's scheme is
        // not registered. Fallback to default resolver and set Endpoint to
        // the original target.
        channelz.Infof(cc.channelzID, "scheme %q not registered, fallback to default scheme", cc.parsedTarget.Scheme)
        cc.parsedTarget = resolver.Target{
            Scheme:   resolver.GetDefaultScheme(),
            Endpoint: target,
        }
        resolverBuilder = cc.getResolver(cc.parsedTarget.Scheme)
        if resolverBuilder == nil {
            return nil, fmt.Errorf("could not get resolver for default scheme: %q", cc.parsedTarget.Scheme)
        }
    }
    creds := cc.dopts.copts.TransportCredentials
    if creds != nil && creds.Info().ServerName != "" {
        cc.authority = creds.Info().ServerName
    } else if cc.dopts.insecure && cc.dopts.authority != "" {
        cc.authority = cc.dopts.authority
    } else {
        // Use endpoint from "scheme://authority/endpoint" as the default
        // authority for ClientConn.
        cc.authority = cc.parsedTarget.Endpoint
    }
    if cc.dopts.scChan != nil && !scSet {
        // Blocking wait for the initial service config.
        select {
        case sc, ok := <-cc.dopts.scChan:
            if ok {
                cc.sc = &sc
            }
        case <-ctx.Done():
            return nil, ctx.Err()
        }
    }
    if cc.dopts.scChan != nil {
        go cc.scWatcher()
    }
    var credsClone credentials.TransportCredentials
    if creds := cc.dopts.copts.TransportCredentials; creds != nil {
        credsClone = creds.Clone()
    }
    cc.balancerBuildOpts = balancer.BuildOptions{
        DialCreds:        credsClone,
        CredsBundle:      cc.dopts.copts.CredsBundle,
        Dialer:           cc.dopts.copts.Dialer,
        ChannelzParentID: cc.channelzID,
        Target:           cc.parsedTarget,
    }
    // Build the resolver.
    rWrapper, err := newCCResolverWrapper(cc, resolverBuilder)
    if err != nil {
        return nil, fmt.Errorf("failed to build resolver: %v", err)
    }
    cc.mu.Lock()
    cc.resolverWrapper = rWrapper
    cc.mu.Unlock()
    // A blocking dial blocks until the clientConn is ready.
    if cc.dopts.block {
        for {
            s := cc.GetState()
            if s == connectivity.Ready {
                break
            } else if cc.dopts.copts.FailOnNonTempDialError && s == connectivity.TransientFailure {
                if err = cc.blockingpicker.connectionError(); err != nil {
                    terr, ok := err.(interface {
                        Temporary() bool
                    })
                    if ok && !terr.Temporary() {
                        return nil, err
                    }
                }
            }
            if !cc.WaitForStateChange(ctx, s) {
                // ctx got timeout or canceled.
                return nil, ctx.Err()
            }
        }
    }
    return cc, nil
}
The core logic of DialContext is as follows.
1.1 Initialize the ClientConn struct
You can see that the ClientConn struct is initialized first:
cc := &ClientConn{
    target:            target,
    csMgr:             &connectivityStateManager{},
    conns:             make(map[*addrConn]struct{}),
    dopts:             defaultDialOptions(),
    blockingpicker:    newPickerWrapper(),
    czData:            new(channelzData),
    firstResolveEvent: grpcsync.NewEvent(),
}
target is the address (ip/port), and dopts holds the dial options, such as the commonly used WithInsecure and WithBlock.
The code below merges the opts passed to Dial into the default options; in essence, each opt is a function that sets some field on the options struct:
for _, opt := range opts {
    opt.apply(&cc.dopts)
}
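This is the classic functional-options pattern. A self-contained sketch of the same mechanism with hypothetical names (gRPC's real DialOption is an interface with an apply method, but the idea is identical):

// options is the struct being configured.
type options struct {
    block    bool
    insecure bool
}

// option mirrors grpc.DialOption: a value that knows how to mutate options.
type option func(*options)

// withBlock and withInsecure mirror grpc.WithBlock / grpc.WithInsecure.
func withBlock() option    { return func(o *options) { o.block = true } }
func withInsecure() option { return func(o *options) { o.insecure = true } }

func newClient(opts ...option) *options {
    o := &options{}
    for _, opt := range opts {
        opt(o) // same shape as opt.apply(&cc.dopts)
    }
    return o
}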
The dialOptions struct that defaultDialOptions fills in looks like this:
type dialOptions struct {
    unaryInt        UnaryClientInterceptor
    streamInt       StreamClientInterceptor
    chainUnaryInts  []UnaryClientInterceptor
    chainStreamInts []StreamClientInterceptor
    cp              Compressor
    dc              Decompressor
    bs              internalbackoff.Strategy
    block           bool
    insecure        bool
    timeout         time.Duration
    scChan          <-chan ServiceConfig
    authority       string
    copts           transport.ConnectOptions
    callOptions     []CallOption
    // This is used by v1 balancer dial option WithBalancer to support v1
    // balancer, and also by WithBalancerName dial option.
    balancerBuilder             balancer.Builder
    channelzParentID            int64
    disableServiceConfig        bool
    disableRetry                bool
    disableHealthCheck          bool
    healthCheckFunc             internal.HealthChecker
    minConnectTimeout           func() time.Duration
    defaultServiceConfig        *ServiceConfig // defaultServiceConfig is parsed from defaultServiceConfigRawJSON.
    defaultServiceConfigRawJSON *string
    // This is used by ccResolverWrapper to backoff between successive calls to
    // resolver.ResolveNow(). The user will have no need to configure this, but
    // we need to be able to configure this in tests.
    resolveNowBackoff func(int) time.Duration
    resolvers         []resolver.Builder
    withProxy         bool
}
Here you can see the block, insecure, and other fields; methods like WithInsecure ultimately just assign values to the corresponding fields of the default options.
DialContext also initializes the connectivityStateManager and the pickerWrapper (newPickerWrapper).
connectivityStateManager manages connection state; each connection is in one of the states IDLE, CONNECTING, READY, TRANSIENT_FAILURE, SHUTDOWN, or Invalid-State.
newPickerWrapper relates to load balancing; it wraps SubConns, each representing the address list of a server's multiple instances. The details deserve a deeper dive of their own.
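These states are exposed through the public connectivity package. A small sketch of blocking until the connection is READY, essentially what a blocking dial (WithBlock) does internally; waitReady is my own helper name:

import (
    "context"

    "google.golang.org/grpc"
    "google.golang.org/grpc/connectivity"
)

// waitReady loops over the connection state until READY or context expiry.
func waitReady(ctx context.Context, cc *grpc.ClientConn) error {
    for {
        s := cc.GetState()
        if s == connectivity.Ready {
            return nil
        }
        // WaitForStateChange blocks until the state leaves s, or returns
        // false when ctx is done.
        if !cc.WaitForStateChange(ctx, s) {
            return ctx.Err()
        }
    }
}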
1.2 Build the interceptor chain
chainUnaryClientInterceptors builds the interceptor chain:
func chainUnaryClientInterceptors(cc *ClientConn) {
    interceptors := cc.dopts.chainUnaryInts
    // Prepend dopts.unaryInt to the chaining interceptors if it exists, since unaryInt will
    // be executed before any other chained interceptors.
    if cc.dopts.unaryInt != nil {
        interceptors = append([]UnaryClientInterceptor{cc.dopts.unaryInt}, interceptors...)
    }
    var chainedInt UnaryClientInterceptor
    if len(interceptors) == 0 {
        chainedInt = nil
    } else if len(interceptors) == 1 {
        chainedInt = interceptors[0]
    } else {
        chainedInt = func(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, invoker UnaryInvoker, opts ...CallOption) error {
            return interceptors[0](ctx, method, req, reply, cc, getChainUnaryInvoker(interceptors, 0, invoker), opts...)
        }
    }
    cc.dopts.unaryInt = chainedInt
}
You can see that all the interceptors are merged into a single UnaryClientInterceptor, wired up in a chain-of-responsibility pattern: cc.dopts.unaryInt essentially becomes the first interceptor, and the invoker handed to each interceptor is the next one in the chain, so running the first interceptor triggers the second, and so on down the line.
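The ordering is easy to observe with two client interceptors. grpc.WithChainUnaryInterceptor is the real dial option; the two interceptor bodies here are hypothetical examples:

import (
    "context"
    "log"

    "google.golang.org/grpc"
)

func first(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    log.Println("first: before")
    err := invoker(ctx, method, req, reply, cc, opts...) // chains into `second`
    log.Println("first: after")
    return err
}

func second(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
    log.Println("second: before")
    err := invoker(ctx, method, req, reply, cc, opts...) // chains into the real invoke
    log.Println("second: after")
    return err
}

// conn, err := grpc.Dial(address, grpc.WithInsecure(),
//     grpc.WithChainUnaryInterceptor(first, second))

The logs come out as first/second "before", then second/first "after", confirming the chain order.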
1.3 Apply the remaining option logic
Based on option values such as timeout and authority, DialContext sets up timeouts, credential checks, and related parameters.
2. Create the client from the connection
userServiceClient := pb.NewUserServiceClient(conn)
NewUserServiceClient is a method in the pb-generated Go file; it simply wraps conn.
3. Make the RPC call
func (c *userServiceClient) Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*UserInfoDTO, error) {
    out := new(UserInfoDTO)
    err := c.cc.Invoke(ctx, "/service.UserService/login", in, out, opts...)
    if err != nil {
        return nil, err
    }
    return out, nil
}
The call is handled by cc's Invoke method, where cc is the conn wrapped inside the client.
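Since Invoke is a public method on *grpc.ClientConn, the generated stub could in principle be bypassed entirely. A hedged sketch using the method path from this document (the pb types are assumed from the generated package):

import (
    "context"

    "google.golang.org/grpc"
)

// callLogin performs the RPC without the generated client by handing the full
// method path straight to ClientConn.Invoke.
func callLogin(ctx context.Context, conn *grpc.ClientConn, in *pb.LoginRequest) (*pb.UserInfoDTO, error) {
    out := new(pb.UserInfoDTO)
    if err := conn.Invoke(ctx, "/service.UserService/login", in, out); err != nil {
        return nil, err
    }
    return out, nil
}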
The Invoke method looks like this:
func (cc *ClientConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...CallOption) error {
    // allow interceptor to see all applicable call options, which means those
    // configured as defaults from dial option as well as per-call options
    opts = combine(cc.dopts.callOptions, opts)
    if cc.dopts.unaryInt != nil {
        return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...)
    }
    return invoke(ctx, method, args, reply, cc, opts...)
}
And finally the invoke function:
func invoke(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error {
    cs, err := newClientStream(ctx, unaryStreamDesc, cc, method, opts...)
    if err != nil {
        return err
    }
    if err := cs.SendMsg(req); err != nil {
        return err
    }
    return cs.RecvMsg(reply)
}
You can see a SendMsg and a RecvMsg call.
3.1 SendMsg
The SendMsg code is as follows:
func (cs *clientStream) SendMsg(m interface{}) (err error) {
    defer func() {
        if err != nil && err != io.EOF {
            // Call finish on the client stream for errors generated by this SendMsg
            // call, as these indicate problems created by this client. (Transport
            // errors are converted to an io.EOF error in csAttempt.sendMsg; the real
            // error will be returned from RecvMsg eventually in that case, or be
            // retried.)
            cs.finish(err)
        }
    }()
    if cs.sentLast {
        return status.Errorf(codes.Internal, "SendMsg called after CloseSend")
    }
    if !cs.desc.ClientStreams {
        cs.sentLast = true
    }
    // load hdr, payload, data
    hdr, payload, data, err := prepareMsg(m, cs.codec, cs.cp, cs.comp)
    if err != nil {
        return err
    }
    // TODO(dfawley): should we be checking len(data) instead?
    if len(payload) > *cs.callInfo.maxSendMessageSize {
        return status.Errorf(codes.ResourceExhausted, "trying to send message larger than max (%d vs. %d)", len(payload), *cs.callInfo.maxSendMessageSize)
    }
    msgBytes := data // Store the pointer before setting to nil. For binary logging.
    op := func(a *csAttempt) error {
        err := a.sendMsg(m, hdr, payload, data)
        // nil out the message and uncomp when replaying; they are only needed for
        // stats which is disabled for subsequent attempts.
        m, data = nil, nil
        return err
    }
    err = cs.withRetry(op, func() { cs.bufferForRetryLocked(len(hdr)+len(payload), op) })
    if cs.binlog != nil && err == nil {
        cs.binlog.Log(&binarylog.ClientMessage{
            OnClientSide: true,
            Message:      msgBytes,
        })
    }
    return
}
It first prepares the data with prepareMsg, then sends it via csAttempt's sendMsg:
func (a *csAttempt) sendMsg(m interface{}, hdr, payld, data []byte) error {
    cs := a.cs
    if a.trInfo != nil {
        a.mu.Lock()
        if a.trInfo.tr != nil {
            a.trInfo.tr.LazyLog(&payload{sent: true, msg: m}, true)
        }
        a.mu.Unlock()
    }
    if err := a.t.Write(a.s, hdr, payld, &transport.Options{Last: !cs.desc.ClientStreams}); err != nil {
        if !cs.desc.ClientStreams {
            // For non-client-streaming RPCs, we return nil instead of EOF on error
            // because the generated code requires it. finish is not called; RecvMsg()
            // will call it with the stream's status independently.
            return nil
        }
        return io.EOF
    }
    if a.statsHandler != nil {
        a.statsHandler.HandleRPC(cs.ctx, outPayload(true, m, data, payld, time.Now()))
    }
    if channelz.IsOn() {
        a.t.IncrMsgSent()
    }
    return nil
}
The t in a.t is the Transport, i.e. the underlying HTTP/2 client transport.
3.2 RecvMsg
func (cs *clientStream) RecvMsg(m interface{}) error {
    if cs.binlog != nil && !cs.serverHeaderBinlogged {
        // Call Header() to binary log header if it's not already logged.
        cs.Header()
    }
    var recvInfo *payloadInfo
    if cs.binlog != nil {
        recvInfo = &payloadInfo{}
    }
    err := cs.withRetry(func(a *csAttempt) error {
        return a.recvMsg(m, recvInfo)
    }, cs.commitAttemptLocked)
    if cs.binlog != nil && err == nil {
        cs.binlog.Log(&binarylog.ServerMessage{
            OnClientSide: true,
            Message:      recvInfo.uncompressedBytes,
        })
    }
    if err != nil || !cs.desc.ServerStreams {
        // err != nil or non-server-streaming indicates end of stream.
        cs.finish(err)
        if cs.binlog != nil {
            // finish will not log Trailer. Log Trailer here.
            logEntry := &binarylog.ServerTrailer{
                OnClientSide: true,
                Trailer:      cs.Trailer(),
                Err:          err,
            }
            if logEntry.Err == io.EOF {
                logEntry.Err = nil
            }
            if peer, ok := peer.FromContext(cs.Context()); ok {
                logEntry.PeerAddr = peer.Addr
            }
            cs.binlog.Log(logEntry)
        }
    }
    return err
}
This likewise calls csAttempt's recvMsg, and after several layers of calls it reaches the parser's recvMsg method:
func (p *parser) recvMsg(maxReceiveMessageSize int) (pf payloadFormat, msg []byte, err error) {
    if _, err := p.r.Read(p.header[:]); err != nil {
        return 0, nil, err
    }
    pf = payloadFormat(p.header[0])
    length := binary.BigEndian.Uint32(p.header[1:])
    if length == 0 {
        return pf, nil, nil
    }
    if int64(length) > int64(maxInt) {
        return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max length allowed on current machine (%d vs. %d)", length, maxInt)
    }
    if int(length) > maxReceiveMessageSize {
        return 0, nil, status.Errorf(codes.ResourceExhausted, "grpc: received message larger than max (%d vs. %d)", length, maxReceiveMessageSize)
    }
    // TODO(bradfitz,zhaoq): garbage. reuse buffer after proto decoding instead
    // of making it for each message:
    msg = make([]byte, int(length))
    if _, err := p.r.Read(msg); err != nil {
        if err == io.EOF {
            err = io.ErrUnexpectedEOF
        }
        return 0, nil, err
    }
    return pf, msg, nil
}
Note the Read call on the first line of the method: r is an io.Reader, the same shape as making an HTTP call in Go or Java and then reading the response body.
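To make the framing concrete, here is a self-contained sketch that reads one length-prefixed gRPC message from an io.Reader, mirroring parser.recvMsg (I use io.ReadFull for simplicity and skip the size-limit checks):

import (
    "encoding/binary"
    "io"
)

// readFrame reads one gRPC message frame: a 1-byte compressed flag, a 4-byte
// big-endian length, then that many payload bytes.
func readFrame(r io.Reader) (compressed bool, msg []byte, err error) {
    var header [5]byte
    if _, err := io.ReadFull(r, header[:]); err != nil {
        return false, nil, err
    }
    length := binary.BigEndian.Uint32(header[1:])
    if length == 0 {
        return header[0] == 1, nil, nil
    }
    msg = make([]byte, length)
    if _, err := io.ReadFull(r, msg); err != nil {
        return false, nil, err
    }
    return header[0] == 1, msg, nil
}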
Summary
By tracing the call flow, this post walked through the gRPC source at a high level and built a rough picture of how gRPC works internally. Many details and other features were not explored in depth, so treat this as an introduction to the gRPC source. When time allows, I will dig further into the code framework, the network connection handling, and other internals.