Skip to content

Commit

Permalink
grpc: Move some stats handler calls to gRPC layer, and add local addr…
Browse files Browse the repository at this point in the history
…ess to peer.Peer (#6716)
  • Loading branch information
zasweq authored Oct 25, 2023
1 parent 6e14274 commit c76d75f
Show file tree
Hide file tree
Showing 8 changed files with 135 additions and 161 deletions.
65 changes: 35 additions & 30 deletions internal/transport/handler_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,25 @@ func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []s
return nil, errors.New(msg)
}

var localAddr net.Addr
if la := r.Context().Value(http.LocalAddrContextKey); la != nil {
localAddr, _ = la.(net.Addr)
}
var authInfo credentials.AuthInfo
if r.TLS != nil {
authInfo = credentials.TLSInfo{State: *r.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}
}
p := peer.Peer{
Addr: strAddr(r.RemoteAddr),
LocalAddr: localAddr,
AuthInfo: authInfo,
}
st := &serverHandlerTransport{
rw: w,
req: r,
closedCh: make(chan struct{}),
writes: make(chan func()),
peer: p,
contentType: contentType,
contentSubtype: contentSubtype,
stats: stats,
Expand Down Expand Up @@ -134,6 +148,8 @@ type serverHandlerTransport struct {

headerMD metadata.MD

peer peer.Peer

closeOnce sync.Once
closedCh chan struct{} // closed on Close

Expand Down Expand Up @@ -165,7 +181,13 @@ func (ht *serverHandlerTransport) Close(err error) {
})
}

func (ht *serverHandlerTransport) RemoteAddr() net.Addr { return strAddr(ht.req.RemoteAddr) }
func (ht *serverHandlerTransport) Peer() *peer.Peer {
return &peer.Peer{
Addr: ht.peer.Addr,
LocalAddr: ht.peer.LocalAddr,
AuthInfo: ht.peer.AuthInfo,
}
}

// strAddr is a net.Addr backed by either a TCP "ip:port" string, or
// the empty string if unknown.
Expand Down Expand Up @@ -347,10 +369,8 @@ func (ht *serverHandlerTransport) WriteHeader(s *Stream, md metadata.MD) error {
return err
}

func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*Stream)) {
// With this transport type there will be exactly 1 stream: this HTTP request.

ctx := ht.req.Context()
var cancel context.CancelFunc
if ht.timeoutSet {
ctx, cancel = context.WithTimeout(ctx, ht.timeout)
Expand All @@ -370,34 +390,19 @@ func (ht *serverHandlerTransport) HandleStreams(startStream func(*Stream)) {
ht.Close(errors.New("request is done processing"))
}()

ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
req := ht.req

s := &Stream{
id: 0, // irrelevant
requestRead: func(int) {},
cancel: cancel,
buf: newRecvBuffer(),
st: ht,
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
}
pr := &peer.Peer{
Addr: ht.RemoteAddr(),
}
if req.TLS != nil {
pr.AuthInfo = credentials.TLSInfo{State: *req.TLS, CommonAuthInfo: credentials.CommonAuthInfo{SecurityLevel: credentials.PrivacyAndIntegrity}}
}
ctx = metadata.NewIncomingContext(ctx, ht.headerMD)
s.ctx = peer.NewContext(ctx, pr)
for _, sh := range ht.stats {
s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
inHeader := &stats.InHeader{
FullMethod: s.method,
RemoteAddr: ht.RemoteAddr(),
Compression: s.recvCompress,
}
sh.HandleRPC(s.ctx, inHeader)
id: 0, // irrelevant
ctx: ctx,
requestRead: func(int) {},
cancel: cancel,
buf: newRecvBuffer(),
st: ht,
method: req.URL.Path,
recvCompress: req.Header.Get("grpc-encoding"),
contentSubtype: ht.contentSubtype,
headerWireLength: 0, // won't have access to header wire length until golang/go#18997.
}
s.trReader = &transportReader{
reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf, freeBuffer: func(*bytes.Buffer) {}},
Expand Down
10 changes: 5 additions & 5 deletions internal/transport/handler_server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ func (s) TestHandlerTransport_HandleStreams(t *testing.T) {
st.ht.WriteStatus(s, status.New(codes.OK, ""))
}
st.ht.HandleStreams(
func(s *Stream) { go handleStream(s) },
context.Background(), func(s *Stream) { go handleStream(s) },
)
wantHeader := http.Header{
"Date": nil,
Expand Down Expand Up @@ -347,7 +347,7 @@ func handleStreamCloseBodyTest(t *testing.T, statusCode codes.Code, msg string)
st.ht.WriteStatus(s, status.New(statusCode, msg))
}
st.ht.HandleStreams(
func(s *Stream) { go handleStream(s) },
context.Background(), func(s *Stream) { go handleStream(s) },
)
wantHeader := http.Header{
"Date": nil,
Expand Down Expand Up @@ -396,7 +396,7 @@ func (s) TestHandlerTransport_HandleStreams_Timeout(t *testing.T) {
ht.WriteStatus(s, status.New(codes.DeadlineExceeded, "too slow"))
}
ht.HandleStreams(
func(s *Stream) { go runStream(s) },
context.Background(), func(s *Stream) { go runStream(s) },
)
wantHeader := http.Header{
"Date": nil,
Expand Down Expand Up @@ -448,7 +448,7 @@ func (s) TestHandlerTransport_HandleStreams_WriteStatusWrite(t *testing.T) {
func testHandlerTransportHandleStreams(t *testing.T, handleStream func(st *handleStreamTest, s *Stream)) {
st := newHandleStreamTest(t)
st.ht.HandleStreams(
func(s *Stream) { go handleStream(st, s) },
context.Background(), func(s *Stream) { go handleStream(st, s) },
)
}

Expand Down Expand Up @@ -481,7 +481,7 @@ func (s) TestHandlerTransport_HandleStreams_ErrDetails(t *testing.T) {
hst.ht.WriteStatus(s, st)
}
hst.ht.HandleStreams(
func(s *Stream) { go handleStream(s) },
context.Background(), func(s *Stream) { go handleStream(s) },
)
wantHeader := http.Header{
"Date": nil,
Expand Down
82 changes: 27 additions & 55 deletions internal/transport/http2_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,15 +69,12 @@ var serverConnectionCounter uint64
// http2Server implements the ServerTransport interface with HTTP2.
type http2Server struct {
lastRead int64 // Keep this field 64-bit aligned. Accessed atomically.
ctx context.Context
done chan struct{}
conn net.Conn
loopy *loopyWriter
readerDone chan struct{} // sync point to enable testing.
writerDone chan struct{} // sync point to enable testing.
remoteAddr net.Addr
localAddr net.Addr
authInfo credentials.AuthInfo // auth info about the connection
peer peer.Peer
inTapHandle tap.ServerInHandle
framer *framer
// The max number of concurrent streams.
Expand Down Expand Up @@ -243,13 +240,15 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
}

done := make(chan struct{})
peer := peer.Peer{
Addr: conn.RemoteAddr(),
LocalAddr: conn.LocalAddr(),
AuthInfo: authInfo,
}
t := &http2Server{
ctx: setConnection(context.Background(), rawConn),
done: done,
conn: conn,
remoteAddr: conn.RemoteAddr(),
localAddr: conn.LocalAddr(),
authInfo: authInfo,
peer: peer,
framer: framer,
readerDone: make(chan struct{}),
writerDone: make(chan struct{}),
Expand All @@ -267,8 +266,6 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
bufferPool: newBufferPool(),
}
t.logger = prefixLoggerForServerTransport(t)
// Add peer information to the http2server context.
t.ctx = peer.NewContext(t.ctx, t.getPeer())

t.controlBuf = newControlBuffer(t.done)
if dynamicWindow {
Expand All @@ -277,15 +274,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,
updateFlowControl: t.updateFlowControl,
}
}
for _, sh := range t.stats {
t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{
RemoteAddr: t.remoteAddr,
LocalAddr: t.localAddr,
})
connBegin := &stats.ConnBegin{}
sh.HandleConn(t.ctx, connBegin)
}
t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.remoteAddr, t.localAddr))
t.channelzID, err = channelz.RegisterNormalSocket(t, config.ChannelzParentID, fmt.Sprintf("%s -> %s", t.peer.Addr, t.peer.LocalAddr))
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -342,7 +331,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport,

// operateHeaders takes action on the decoded headers. Returns an error if fatal
// error encountered and transport needs to close, otherwise returns nil.
func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(*Stream)) error {
func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeadersFrame, handle func(*Stream)) error {
// Acquire max stream ID lock for entire duration
t.maxStreamMu.Lock()
defer t.maxStreamMu.Unlock()
Expand All @@ -369,10 +358,11 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(

buf := newRecvBuffer()
s := &Stream{
id: streamID,
st: t,
buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)},
id: streamID,
st: t,
buf: buf,
fc: &inFlow{limit: uint32(t.initialWindowSize)},
headerWireLength: int(frame.Header().Length),
}
var (
// if false, content-type was missing or invalid
Expand Down Expand Up @@ -511,9 +501,9 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.state = streamReadDone
}
if timeoutSet {
s.ctx, s.cancel = context.WithTimeout(t.ctx, timeout)
s.ctx, s.cancel = context.WithTimeout(ctx, timeout)
} else {
s.ctx, s.cancel = context.WithCancel(t.ctx)
s.ctx, s.cancel = context.WithCancel(ctx)
}

// Attach the received metadata to the context.
Expand Down Expand Up @@ -592,18 +582,6 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
s.requestRead = func(n int) {
t.adjustWindow(s, uint32(n))
}
for _, sh := range t.stats {
s.ctx = sh.TagRPC(s.ctx, &stats.RPCTagInfo{FullMethodName: s.method})
inHeader := &stats.InHeader{
FullMethod: s.method,
RemoteAddr: t.remoteAddr,
LocalAddr: t.localAddr,
Compression: s.recvCompress,
WireLength: int(frame.Header().Length),
Header: mdata.Copy(),
}
sh.HandleRPC(s.ctx, inHeader)
}
s.ctxDone = s.ctx.Done()
s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone)
s.trReader = &transportReader{
Expand All @@ -629,7 +607,7 @@ func (t *http2Server) operateHeaders(frame *http2.MetaHeadersFrame, handle func(
// HandleStreams receives incoming streams using the given handler. This is
// typically run in a separate goroutine.
// traceCtx attaches trace to ctx and returns the new context.
func (t *http2Server) HandleStreams(handle func(*Stream)) {
func (t *http2Server) HandleStreams(ctx context.Context, handle func(*Stream)) {
defer close(t.readerDone)
for {
t.controlBuf.throttle()
Expand Down Expand Up @@ -664,7 +642,7 @@ func (t *http2Server) HandleStreams(handle func(*Stream)) {
}
switch frame := frame.(type) {
case *http2.MetaHeadersFrame:
if err := t.operateHeaders(frame, handle); err != nil {
if err := t.operateHeaders(ctx, frame, handle); err != nil {
t.Close(err)
break
}
Expand Down Expand Up @@ -1242,10 +1220,6 @@ func (t *http2Server) Close(err error) {
for _, s := range streams {
s.cancel()
}
for _, sh := range t.stats {
connEnd := &stats.ConnEnd{}
sh.HandleConn(t.ctx, connEnd)
}
}

// deleteStream deletes the stream s from transport's active streams.
Expand Down Expand Up @@ -1311,10 +1285,6 @@ func (t *http2Server) closeStream(s *Stream, rst bool, rstCode http2.ErrCode, eo
})
}

func (t *http2Server) RemoteAddr() net.Addr {
return t.remoteAddr
}

func (t *http2Server) Drain(debugData string) {
t.mu.Lock()
defer t.mu.Unlock()
Expand Down Expand Up @@ -1397,11 +1367,11 @@ func (t *http2Server) ChannelzMetric() *channelz.SocketInternalMetric {
LastMessageReceivedTimestamp: time.Unix(0, atomic.LoadInt64(&t.czData.lastMsgRecvTime)),
LocalFlowControlWindow: int64(t.fc.getSize()),
SocketOptions: channelz.GetSocketOption(t.conn),
LocalAddr: t.localAddr,
RemoteAddr: t.remoteAddr,
LocalAddr: t.peer.LocalAddr,
RemoteAddr: t.peer.Addr,
// RemoteName :
}
if au, ok := t.authInfo.(credentials.ChannelzSecurityInfo); ok {
if au, ok := t.peer.AuthInfo.(credentials.ChannelzSecurityInfo); ok {
s.Security = au.GetSecurityValue()
}
s.RemoteFlowControlWindow = t.getOutFlowWindow()
Expand Down Expand Up @@ -1433,10 +1403,12 @@ func (t *http2Server) getOutFlowWindow() int64 {
}
}

func (t *http2Server) getPeer() *peer.Peer {
// Peer returns the peer of the transport.
func (t *http2Server) Peer() *peer.Peer {
return &peer.Peer{
Addr: t.remoteAddr,
AuthInfo: t.authInfo, // Can be nil
Addr: t.peer.Addr,
LocalAddr: t.peer.LocalAddr,
AuthInfo: t.peer.AuthInfo, // Can be nil
}
}

Expand All @@ -1461,6 +1433,6 @@ func GetConnection(ctx context.Context) net.Conn {
// SetConnection adds the connection to the context to be able to get
// information about the destination ip and port for an incoming RPC. This also
// allows any unary or streaming interceptors to see the connection.
func setConnection(ctx context.Context, conn net.Conn) context.Context {
func SetConnection(ctx context.Context, conn net.Conn) context.Context {
return context.WithValue(ctx, connectionKey{}, conn)
}
Loading

0 comments on commit c76d75f

Please sign in to comment.