Grpc 源码学习

grpc 是谷歌推出的一个高性能、开源和通用的 RPC 框架,基于 HTTP2,使用 Protocol Buffer 来定义消息结构。

架构​

image.png

定义消息结构

syntax = "proto3";

package pb;
//定义一个Chunker服务
service Chunker {
    //提供一个Chunker给外部调用
    rpc Chunker(Empty) returns (stream Chunk) {}
}

message Empty{}

message Chunk {
    bytes chunk = 1;
}

服务端​

//服务提供者
type chunkerSrv []byte //定义一个类型这个类型必须实现Chunker函数

func (c chunkerSrv) Chunker(_ *pb.Empty, srv pb.Chunker_ChunkerServer) error {
	//........
}

func main() {
	listen, err := net.Listen("tcp", ":8888")
	if err != nil {
		log.Fatal(err)
	}
	s := grpc.NewServer()
	blob := make([]byte, 1024*1024*1024*2) //2G
	rand.Read(blob)
	pb.RegisterChunkerServer(s, chunkerSrv(blob)) //注册服务
	log.Println("serving on localhost:8888")
	log.Fatal(s.Serve(listen))
}

GRPC服务注册过程​

服务注册主要是这个函数 pb.RegisterChunkerServer(s, chunkerSrv(blob)) 这个函数是 proto 文件自动生成的​

// chunkerSrv类型实现了ChunkerServer这个接口
type ChunkerServer interface { //proto自动生成
	Chunker(*Empty, Chunker_ChunkerServer) error
}

func RegisterChunkerServer(s *grpc.Server, srv ChunkerServer) {
	s.RegisterService(&_Chunker_serviceDesc, srv) //35行
}

// 一个rpc服务的定义类
var _Chunker_serviceDesc = grpc.ServiceDesc{
	ServiceName: "pb.Chunker",          //proto文件中定义的包名.服务名
	HandlerType: (*ChunkerServer)(nil), //这个服务函数提供者的类型
	Methods:     []grpc.MethodDesc{},   //服务提供的方法集合 如果不是流式传输会定义在这里
	Streams: []grpc.StreamDesc{ //流式传输定义在这里
		{
			StreamName:    "Chunker",                //函数名
			Handler:       _Chunker_Chunker_Handler, //处理函数
			ServerStreams: true,                     // 服务端流式 RPC
		},
	},
	Metadata: "chunker.proto", //元数据
}

type StreamHandler func(srv interface{}, stream ServerStream) error

// 流式rpc服务定义
type StreamDesc struct {
	//函数名
	StreamName    string
	Handler       StreamHandler // 函数处理器
	ServerStreams bool          // 服务端流式 RPC
	ClientStreams bool          // 客户端流式 RPC
}

// 服务端调用的处理函数
func _Chunker_Chunker_Handler(srv interface{}, stream grpc.ServerStream) error {
	m := new(Empty)
	if err := stream.RecvMsg(m); err != nil {
		return err
	}
	//srv是我们注册的chunkerSrv类型,调用他的Chunker方法来供请求使用
	return srv.(ChunkerServer).Chunker(m, &chunkerChunkerServer{stream})
}

// 注册的逻辑
func (s *Server) RegisterService(sd *ServiceDesc, ss interface{}) {
	//判断服务提供者和定义好的服务是不是同一类型
	if ss != nil {
		ht := reflect.TypeOf(sd.HandlerType).Elem()
		st := reflect.TypeOf(ss)
		if !st.Implements(ht) {
			logger.Fatalf("grpc: Server.RegisterService found the handler of type %v that does not satisfy %v", st, ht)
		}
	}
	//进行注册
	s.register(sd, ss)
}

func (s *Server) register(sd *ServiceDesc, ss interface{}) {
	s.mu.Lock()
	defer s.mu.Unlock()
	s.printf("RegisterService(%q)", sd.ServiceName)
	if s.serve {
		logger.Fatalf("grpc: Server.RegisterService after Server.Serve for %q", sd.ServiceName)
	}
	if _, ok := s.services[sd.ServiceName]; ok {
		logger.Fatalf("grpc: Server.RegisterService found duplicate service registration for %q", sd.ServiceName)
	}
	info := &serviceInfo{ //一个服务的对象
		serviceImpl: ss,                           //服务的具体实现类
		methods:     make(map[string]*MethodDesc), //普通传输的方法
		streams:     make(map[string]*StreamDesc), //流式传输方法
		mdata:       sd.Metadata,
	}
	for i := range sd.Methods {
		d := &sd.Methods[i]
		info.methods[d.MethodName] = d
	}
	for i := range sd.Streams {
		d := &sd.Streams[i]
		info.streams[d.StreamName] = d
	}
	s.services[sd.ServiceName] = info //完成注册
}

监听请求​

通过 s.Serve(listen) 方法来监听请求​

func (s *Server) Serve(lis net.Listener) error {
	//......略
	for {
		rawConn, err := lis.Accept()
		if err != nil {
			//......略
			//如果发生错误,执行重试
		}
		go func() {
			s.handleRawConn(lis.Addr().String(), rawConn) //处理请求
			s.serveWG.Done()
		}()
	}
}

func (s *Server) handleRawConn(lisAddr string, rawConn net.Conn) {
	//........略

	// 使用http2传输
	st := s.newHTTP2Transport(conn, authInfo)
	if st == nil {
		return
	}

	rawConn.SetDeadline(time.Time{})
	if !s.addConn(lisAddr, st) {
		return
	}
	go func() {
		s.serveStreams(st)
		s.removeConn(lisAddr, st)
	}()
}

func (s *Server) serveStreams(st transport.ServerTransport) {
	defer st.Close()
	var wg sync.WaitGroup

	var roundRobinCounter uint32
	st.HandleStreams(func(stream *transport.Stream) {
		wg.Add(1)
		if s.opts.numServerWorkers > 0 {
			data := &serverWorkerData{st: st, wg: &wg, stream: stream}
			select {
			case s.serverWorkerChannels[atomic.AddUint32(&roundRobinCounter, 1)%s.opts.numServerWorkers] <- data:
			default:
				// If all stream workers are busy, fallback to the default code path.
				go func() {
					s.handleStream(st, stream, s.traceInfo(st, stream)) //真正的处理逻辑
					wg.Done()
				}()
			}
		} else {
			go func() {
				defer wg.Done()
				s.handleStream(st, stream, s.traceInfo(st, stream)) //真正的处理逻辑
			}()
		}
	}, func(ctx context.Context, method string) context.Context {
		if !EnableTracing {
			return ctx
		}
		tr := trace.New("grpc.Recv."+methodFamily(method), method)
		return trace.NewContext(ctx, tr)
	})
	wg.Wait()
}

处理请求​

//处理逻辑
func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Stream, trInfo *traceInfo) {
	sm := stream.Method()
	if sm != "" && sm[0] == '/' {
		sm = sm[1:]
	}
	//.....略
	service := sm[:pos]  //从stream里面获取请求的服务
	method := sm[pos+1:] //请求的方法名

	srv, knownService := s.services[service]
	if knownService { //如果服务在服务器中注册了
		if md, ok := srv.methods[method]; ok {
			s.processUnaryRPC(t, stream, srv, md, trInfo) //执行普通rpc
			return
		}
		if sd, ok := srv.streams[method]; ok {
			s.processStreamingRPC(t, stream, srv, sd, trInfo) //执行流式rpc
			return
		}
	}
	//。。。。略 如果此服务没有注册的异常处理
}

// 执行流式rpc
func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transport.Stream, info *serviceInfo, sd *StreamDesc, trInfo *traceInfo) (err error) {
	//.....略  一些前置处理
	ss := &serverStream{
		ctx:                   ctx,
		t:                     t,
		s:                     stream,
		p:                     &parser{r: stream},
		codec:                 s.getCodec(stream.ContentSubtype()),
		maxReceiveMessageSize: s.opts.maxReceiveMessageSize, //配置最大接收消息限制
		maxSendMessageSize:    s.opts.maxSendMessageSize,    //配置最大发送消息限制
		trInfo:                trInfo,
		statsHandler:          sh,
	}

	var appErr error
	var server interface{}
	if info != nil {
		server = info.serviceImpl //这个就是我们注册好的函数提供者
	}
	if s.opts.streamInt == nil {
		appErr = sd.Handler(server, ss) //执行Handler函数 也是我们传进来的
	} else {
		info := &StreamServerInfo{
			FullMethod:     stream.Method(),
			IsClientStream: sd.ClientStreams,
			IsServerStream: sd.ServerStreams,
		}
		appErr = s.opts.streamInt(server, ss, info, sd.Handler)
	}
	//..... 略 返回值和错误处理
	return err
}

客户端​

eventti 请求调用 cmdb 的节点搜索 rpc 服务为例​

syntax = "proto3";

package rpc;

service NodeService {
  rpc SearchNode(ByJqlAndUsernameRequest) returns (common.SimpleJsonResponse) {};
}
// GrpcClient grpc客户端
type GrpcClient struct {
	Conn *grpc.ClientConn
}

// NewGrpcClient 新建grpc客户端
func NewGrpcClient() (*GrpcClient, error) {
	// 连接服务端接口
	conn, err := grpc.Dial(common.AppConfigInstance.GwayCfg.Address, grpc.WithUserAgent("sky-eventti"), grpc.WithInsecure())
	if err != nil {
		logs.Error(err)
		return nil, err
	}
	res := &GrpcClient{
		Conn: conn,
	}
	return res, nil
}
// SearchEventTypeMapping cmdb查询所有事件类型
func SearchEventTypeMapping(gc *GrpcClient) ([]models.EventTypeVO, error) {
	//....... 略
	client := rpc.NewNodeServiceClient(gc.Conn)
	defer gc.Close()
	reply, err := client.SearchNode(context.Background(), &rpc.ByJqlAndUsernameRequest{
		Jql: "label = eventTypeMapping",
	})
	if err != nil {
		logs.Error(err)
		return nil, err

	}
	jsonStr := reply.GetJsonStr()
	//.......略
}

// proto生成的函数
func (c *nodeServiceClient) SearchNode(ctx context.Context, in *ByJqlAndUsernameRequest, opts ...grpc.CallOption) (*common.SimpleJsonResponse, error) {
	out := new(common.SimpleJsonResponse)
	err := c.cc.Invoke(ctx, "/rpc.NodeService/SearchNode", in, out, opts...)
	if err != nil {
		return nil, err
	}
	return out, nil
}

// grpc连接的抽象接口 (grpc.ClientConn实现了它)
type ClientConnInterface interface {
	// Invoke performs a unary RPC and returns after the response is received
	// into reply.
	Invoke(ctx context.Context, method string, args interface{}, reply interface{}, opts ...CallOption) error
	// NewStream begins a streaming RPC.
	NewStream(ctx context.Context, desc *StreamDesc, method string, opts ...CallOption) (ClientStream, error)
}

func (cc *ClientConn) Invoke(ctx context.Context, method string, args, reply interface{}, opts ...CallOption) error {
	// 合并拦截器参数与请求参数
	opts = combine(cc.dopts.callOptions, opts)
	// 拦截器不为空,调用拦截器
	if cc.dopts.unaryInt != nil {
		return cc.dopts.unaryInt(ctx, method, args, reply, cc, invoke, opts...)
	}
	// 发起调用
	return invoke(ctx, method, args, reply, cc, opts...)
}

func invoke(ctx context.Context, method string, req, reply interface{}, cc *ClientConn, opts ...CallOption) error {
	// 创建客户端流
	cs, err := newClientStream(ctx, unaryStreamDesc, cc, method, opts...)
	if err != nil {
		return err
	}
	// 发送数据
	if err := cs.SendMsg(req); err != nil {
		return err
	}
	// 接收数据
	return cs.RecvMsg(reply)
}

以上就是 grpc 从服务注册然后到请求服务的整个流程。​