diff --git a/gcp/observability/go.mod b/gcp/observability/go.mod index e730895d3aa4..2f9b5e0bc002 100644 --- a/gcp/observability/go.mod +++ b/gcp/observability/go.mod @@ -39,3 +39,5 @@ require ( ) replace google.golang.org/grpc => ../../ + +replace google.golang.org/grpc/stats/opencensus => ../../stats/opencensus diff --git a/gcp/observability/go.sum b/gcp/observability/go.sum index 109fef979a6b..b60536b2d605 100644 --- a/gcp/observability/go.sum +++ b/gcp/observability/go.sum @@ -1056,8 +1056,6 @@ google.golang.org/genproto v0.0.0-20221202195650-67e5cbc046fd/go.mod h1:cTsE614G google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f h1:BWUVssLB0HVOSY78gIdvk1dTVYtT1y8SBWtPYuTJ/6w= google.golang.org/genproto v0.0.0-20230110181048-76db0878b65f/go.mod h1:RGgjbofJ8xD9Sq1VVhDM1Vok1vRONV+rg+CjzG4SZKM= google.golang.org/grpc/cmd/protoc-gen-go-grpc v1.1.0/go.mod h1:6Kw0yEErY5E/yWrBtf03jp27GLLJujG4z/JK95pnjjw= -google.golang.org/grpc/stats/opencensus v0.0.0-20230221205128-8702a2ebf4b0 h1:v7h+HONu0plE0b3y9fBiOWlsqTdQQ5A9l9Ag2LXbEoE= -google.golang.org/grpc/stats/opencensus v0.0.0-20230221205128-8702a2ebf4b0/go.mod h1:l7+BYcyrDJFQo8nh4v8h5TJ6VfQ9QGBfFqVO7xoqQzI= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/gcp/observability/logging.go b/gcp/observability/logging.go index 6c6d8bf0aada..7be05975146c 100644 --- a/gcp/observability/logging.go +++ b/gcp/observability/logging.go @@ -29,6 +29,7 @@ import ( gcplogging "cloud.google.com/go/logging" "github.com/google/uuid" + "go.opencensus.io/trace" "google.golang.org/grpc" binlogpb "google.golang.org/grpc/binarylog/grpc_binarylog_v1" @@ -36,6 +37,7 @@ import ( "google.golang.org/grpc/internal/binarylog" iblog 
"google.golang.org/grpc/internal/binarylog" "google.golang.org/grpc/internal/grpcutil" + "google.golang.org/grpc/stats/opencensus" ) var lExporter loggingExporter @@ -237,13 +239,16 @@ type methodLoggerBuilder interface { } type binaryMethodLogger struct { - callID, serviceName, methodName, authority string + callID, serviceName, methodName, authority, projectID string - mlb methodLoggerBuilder - exporter loggingExporter + mlb methodLoggerBuilder + exporter loggingExporter + clientSide bool } -func (bml *binaryMethodLogger) Log(c iblog.LogEntryConfig) { +// buildGCPLoggingEntry converts the binary log entry into a gcp logging +// entry. +func (bml *binaryMethodLogger) buildGCPLoggingEntry(ctx context.Context, c iblog.LogEntryConfig) gcplogging.Entry { binLogEntry := bml.mlb.Build(c) grpcLogEntry := &grpcLogEntry{ @@ -305,9 +310,6 @@ func (bml *binaryMethodLogger) Log(c iblog.LogEntryConfig) { setPeerIfPresent(binLogEntry, grpcLogEntry) case binlogpb.GrpcLogEntry_EVENT_TYPE_CANCEL: grpcLogEntry.Type = eventTypeCancel - default: - logger.Infof("Unknown event type: %v", binLogEntry.Type) - return } grpcLogEntry.ServiceName = bml.serviceName grpcLogEntry.MethodName = bml.methodName @@ -318,8 +320,25 @@ func (bml *binaryMethodLogger) Log(c iblog.LogEntryConfig) { Severity: 100, Payload: grpcLogEntry, } + if bml.clientSide { + // client side span, populated through opencensus trace package. + if span := trace.FromContext(ctx); span != nil { + sc := span.SpanContext() + gcploggingEntry.Trace = "projects/" + bml.projectID + "/traces/" + fmt.Sprintf("%x", sc.TraceID) + gcploggingEntry.SpanID = fmt.Sprintf("%x", sc.SpanID) + } + } else { + // server side span, populated through stats/opencensus package. 
+ if tID, sID, ok := opencensus.GetTraceAndSpanID(ctx); ok { + gcploggingEntry.Trace = "projects/" + bml.projectID + "/traces/" + fmt.Sprintf("%x", tID) + gcploggingEntry.SpanID = fmt.Sprintf("%x", sID) + } + } + return gcploggingEntry +} - bml.exporter.EmitGcpLoggingEntry(gcploggingEntry) +func (bml *binaryMethodLogger) Log(ctx context.Context, c iblog.LogEntryConfig) { + bml.exporter.EmitGcpLoggingEntry(bml.buildGCPLoggingEntry(ctx, c)) } type eventConfig struct { @@ -336,7 +355,9 @@ type eventConfig struct { type binaryLogger struct { EventConfigs []eventConfig + projectID string exporter loggingExporter + clientSide bool } func (bl *binaryLogger) GetMethodLogger(methodName string) iblog.MethodLogger { @@ -352,9 +373,11 @@ func (bl *binaryLogger) GetMethodLogger(methodName string) iblog.MethodLogger { } return &binaryMethodLogger{ - exporter: bl.exporter, - mlb: iblog.NewTruncatingMethodLogger(eventConfig.HeaderBytes, eventConfig.MessageBytes), - callID: uuid.NewString(), + exporter: bl.exporter, + mlb: iblog.NewTruncatingMethodLogger(eventConfig.HeaderBytes, eventConfig.MessageBytes), + callID: uuid.NewString(), + projectID: bl.projectID, + clientSide: bl.clientSide, } } } @@ -372,7 +395,8 @@ func parseMethod(method string) (string, string, error) { return method[:pos], method[pos+1:], nil } -func registerClientRPCEvents(clientRPCEvents []clientRPCEvents, exporter loggingExporter) { +func registerClientRPCEvents(config *config, exporter loggingExporter) { + clientRPCEvents := config.CloudLogging.ClientRPCEvents if len(clientRPCEvents) == 0 { return } @@ -405,11 +429,14 @@ func registerClientRPCEvents(clientRPCEvents []clientRPCEvents, exporter logging clientSideLogger := &binaryLogger{ EventConfigs: eventConfigs, exporter: exporter, + projectID: config.ProjectID, + clientSide: true, } internal.AddGlobalDialOptions.(func(opt ...grpc.DialOption))(internal.WithBinaryLogger.(func(bl binarylog.Logger) grpc.DialOption)(clientSideLogger)) } -func 
registerServerRPCEvents(serverRPCEvents []serverRPCEvents, exporter loggingExporter) { +func registerServerRPCEvents(config *config, exporter loggingExporter) { + serverRPCEvents := config.CloudLogging.ServerRPCEvents if len(serverRPCEvents) == 0 { return } @@ -442,6 +469,8 @@ func registerServerRPCEvents(serverRPCEvents []serverRPCEvents, exporter logging serverSideLogger := &binaryLogger{ EventConfigs: eventConfigs, exporter: exporter, + projectID: config.ProjectID, + clientSide: false, } internal.AddGlobalServerOptions.(func(opt ...grpc.ServerOption))(internal.BinaryLogger.(func(bl binarylog.Logger) grpc.ServerOption)(serverSideLogger)) } @@ -456,9 +485,8 @@ func startLogging(ctx context.Context, config *config) error { return fmt.Errorf("unable to create CloudLogging exporter: %v", err) } - cl := config.CloudLogging - registerClientRPCEvents(cl.ClientRPCEvents, lExporter) - registerServerRPCEvents(cl.ServerRPCEvents, lExporter) + registerClientRPCEvents(config, lExporter) + registerServerRPCEvents(config, lExporter) return nil } diff --git a/gcp/observability/logging_test.go b/gcp/observability/logging_test.go index a02233387407..a42b1da550fd 100644 --- a/gcp/observability/logging_test.go +++ b/gcp/observability/logging_test.go @@ -68,6 +68,8 @@ type fakeLoggingExporter struct { mu sync.Mutex entries []*grpcLogEntry + + idsSeen []*traceAndSpanIDString } func (fle *fakeLoggingExporter) EmitGcpLoggingEntry(entry gcplogging.Entry) { @@ -76,6 +78,13 @@ func (fle *fakeLoggingExporter) EmitGcpLoggingEntry(entry gcplogging.Entry) { if entry.Severity != 100 { fle.t.Errorf("entry.Severity is not 100, this should be hardcoded") } + + ids := &traceAndSpanIDString{ + traceID: entry.Trace, + spanID: entry.SpanID, + } + fle.idsSeen = append(fle.idsSeen, ids) + grpcLogEntry, ok := entry.Payload.(*grpcLogEntry) if !ok { fle.t.Errorf("payload passed in isn't grpcLogEntry") diff --git a/gcp/observability/observability_test.go b/gcp/observability/observability_test.go index 
87e9668eefdf..fa8cba1d38df 100644 --- a/gcp/observability/observability_test.go +++ b/gcp/observability/observability_test.go @@ -21,19 +21,24 @@ package observability import ( "context" "encoding/json" + "errors" "fmt" "io" "os" + "strings" "sync" "testing" "time" + "github.com/google/go-cmp/cmp" "go.opencensus.io/stats/view" "go.opencensus.io/trace" "google.golang.org/grpc/internal/envconfig" + "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpctest" "google.golang.org/grpc/internal/leakcheck" "google.golang.org/grpc/internal/stubserver" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/metadata" "google.golang.org/grpc/test/grpc_testing" ) @@ -80,6 +85,8 @@ type fakeOpenCensusExporter struct { // Number of spans SeenSpans int + idCh *testutils.Channel + t *testing.T mu sync.RWMutex } @@ -102,7 +109,39 @@ func (fe *fakeOpenCensusExporter) ExportView(vd *view.Data) { } } +type traceAndSpanID struct { + spanName string + traceID trace.TraceID + spanID trace.SpanID +} + +type traceAndSpanIDString struct { + traceID string + spanID string +} + +// idsToString is a helper that converts from generated trace and span IDs to +// the string version stored in trace message events. (hex 16 lowercase encoded, +// and extra data attached to trace id). +func idsToString(tasi traceAndSpanID, projectID string) traceAndSpanIDString { + return traceAndSpanIDString{ + traceID: "projects/" + projectID + "/traces/" + fmt.Sprintf("%x", tasi.traceID), + spanID: fmt.Sprintf("%x", tasi.spanID), + } +} + func (fe *fakeOpenCensusExporter) ExportSpan(vd *trace.SpanData) { + if fe.idCh != nil { + // This is what export span sees representing the trace/span ID which + // will populate different contexts throughout the system, convert in + // caller to string version as the logging code does. 
+ fe.idCh.Send(traceAndSpanID{ + spanName: vd.Name, + traceID: vd.TraceID, + spanID: vd.SpanID, + }) + } + fe.mu.Lock() defer fe.mu.Unlock() fe.SeenSpans++ @@ -487,3 +526,532 @@ func (s) TestStartErrorsThenEnd(t *testing.T) { } End() } + +// TestLoggingLinkedWithTraceClientSide tests that client side logs get the +// trace and span id corresponding to the created Call Level Span for the RPC. +func (s) TestLoggingLinkedWithTraceClientSide(t *testing.T) { + fle := &fakeLoggingExporter{ + t: t, + } + oldNewLoggingExporter := newLoggingExporter + defer func() { + newLoggingExporter = oldNewLoggingExporter + }() + + newLoggingExporter = func(ctx context.Context, config *config) (loggingExporter, error) { + return fle, nil + } + + idCh := testutils.NewChannel() + + fe := &fakeOpenCensusExporter{ + t: t, + idCh: idCh, + } + oldNewExporter := newExporter + defer func() { + newExporter = oldNewExporter + }() + + newExporter = func(config *config) (tracingMetricsExporter, error) { + return fe, nil + } + + const projectID = "project-id" + tracesAndLogsConfig := &config{ + ProjectID: projectID, + CloudLogging: &cloudLogging{ + ClientRPCEvents: []clientRPCEvents{ + { + Methods: []string{"*"}, + MaxMetadataBytes: 30, + MaxMessageBytes: 30, + }, + }, + }, + CloudTrace: &cloudTrace{ + SamplingRate: 1.0, + }, + } + cleanup, err := setupObservabilitySystemWithConfig(tracesAndLogsConfig) + if err != nil { + t.Fatalf("error setting up observability %v", err) + } + defer cleanup() + ss := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, in *grpc_testing.SimpleRequest) (*grpc_testing.SimpleResponse, error) { + return &grpc_testing.SimpleResponse{}, nil + }, + FullDuplexCallF: func(stream grpc_testing.TestService_FullDuplexCallServer) error { + _, err := stream.Recv() + if err != io.EOF { + return err + } + return nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := 
context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + // Spawn a goroutine to receive the trace and span ids received by the + // exporter corresponding to a Unary RPC. + readerErrCh := testutils.NewChannel() + unaryDone := grpcsync.NewEvent() + go func() { + var traceAndSpanIDs []traceAndSpanID + val, err := idCh.Receive(ctx) + if err != nil { + readerErrCh.Send(fmt.Errorf("error while waiting for IDs: %v", err)) + } + + tasi, ok := val.(traceAndSpanID) + if !ok { + readerErrCh.Send(fmt.Errorf("received wrong type from channel: %T", val)) + } + traceAndSpanIDs = append(traceAndSpanIDs, tasi) + + val, err = idCh.Receive(ctx) + if err != nil { + readerErrCh.Send(fmt.Errorf("error while waiting for IDs: %v", err)) + } + + tasi, ok = val.(traceAndSpanID) + if !ok { + readerErrCh.Send(fmt.Errorf("received wrong type from channel: %T", val)) + } + traceAndSpanIDs = append(traceAndSpanIDs, tasi) + + val, err = idCh.Receive(ctx) + if err != nil { + readerErrCh.Send(fmt.Errorf("error while waiting for IDs: %v", err)) + } + tasi, ok = val.(traceAndSpanID) + if !ok { + readerErrCh.Send(fmt.Errorf("received wrong type from channel: %T", val)) + } + traceAndSpanIDs = append(traceAndSpanIDs, tasi) + <-unaryDone.Done() + var tasiSent traceAndSpanIDString + for _, tasi := range traceAndSpanIDs { + if strings.HasPrefix(tasi.spanName, "Sent.") { + tasiSent = idsToString(tasi, projectID) + continue + } + } + + fle.mu.Lock() + for _, tasiSeen := range fle.idsSeen { + if diff := cmp.Diff(tasiSeen, &tasiSent, cmp.AllowUnexported(traceAndSpanIDString{})); diff != "" { + readerErrCh.Send(errors.New("got unexpected id, should be a client span")) + } + } + + fle.entries = nil + fle.mu.Unlock() + readerErrCh.Send(nil) + }() + if _, err := ss.Client.UnaryCall(ctx, &grpc_testing.SimpleRequest{Payload: &grpc_testing.Payload{Body: testOkPayload}}); err != nil { + t.Fatalf("Unexpected error from UnaryCall: %v", err) + } + unaryDone.Fire() + if chErr, err := 
readerErrCh.Receive(ctx); chErr != nil || err != nil { + if err != nil { + t.Fatalf("Should have received something from error channel: %v", err) + } + if chErr != nil { + t.Fatalf("Should have received a nil error from channel, instead received: %v", chErr) + } + } +} + +// TestLoggingLinkedWithTraceServerSide tests that server side logs get the +// trace and span id corresponding to the created Server Span for the RPC. +func (s) TestLoggingLinkedWithTraceServerSide(t *testing.T) { + fle := &fakeLoggingExporter{ + t: t, + } + oldNewLoggingExporter := newLoggingExporter + defer func() { + newLoggingExporter = oldNewLoggingExporter + }() + + newLoggingExporter = func(ctx context.Context, config *config) (loggingExporter, error) { + return fle, nil + } + + idCh := testutils.NewChannel() + + fe := &fakeOpenCensusExporter{ + t: t, + idCh: idCh, + } + oldNewExporter := newExporter + defer func() { + newExporter = oldNewExporter + }() + + newExporter = func(config *config) (tracingMetricsExporter, error) { + return fe, nil + } + + const projectID = "project-id" + tracesAndLogsConfig := &config{ + ProjectID: projectID, + CloudLogging: &cloudLogging{ + ServerRPCEvents: []serverRPCEvents{ + { + Methods: []string{"*"}, + MaxMetadataBytes: 30, + MaxMessageBytes: 30, + }, + }, + }, + CloudTrace: &cloudTrace{ + SamplingRate: 1.0, + }, + } + cleanup, err := setupObservabilitySystemWithConfig(tracesAndLogsConfig) + if err != nil { + t.Fatalf("error setting up observability %v", err) + } + defer cleanup() + ss := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, in *grpc_testing.SimpleRequest) (*grpc_testing.SimpleResponse, error) { + return &grpc_testing.SimpleResponse{}, nil + }, + FullDuplexCallF: func(stream grpc_testing.TestService_FullDuplexCallServer) error { + _, err := stream.Recv() + if err != io.EOF { + return err + } + return nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + 
+ ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + // Spawn a goroutine to receive the trace and span ids received by the + // exporter corresponding to a Unary RPC. + readerErrCh := testutils.NewChannel() + unaryDone := grpcsync.NewEvent() + go func() { + var traceAndSpanIDs []traceAndSpanID + val, err := idCh.Receive(ctx) + if err != nil { + readerErrCh.Send(fmt.Errorf("error while waiting for IDs: %v", err)) + } + + tasi, ok := val.(traceAndSpanID) + if !ok { + readerErrCh.Send(fmt.Errorf("received wrong type from channel: %T", val)) + } + traceAndSpanIDs = append(traceAndSpanIDs, tasi) + + val, err = idCh.Receive(ctx) + if err != nil { + readerErrCh.Send(fmt.Errorf("error while waiting for IDs: %v", err)) + } + + tasi, ok = val.(traceAndSpanID) + if !ok { + readerErrCh.Send(fmt.Errorf("received wrong type from channel: %T", val)) + } + traceAndSpanIDs = append(traceAndSpanIDs, tasi) + + val, err = idCh.Receive(ctx) + if err != nil { + readerErrCh.Send(fmt.Errorf("error while waiting for IDs: %v", err)) + } + tasi, ok = val.(traceAndSpanID) + if !ok { + readerErrCh.Send(fmt.Errorf("received wrong type from channel: %T", val)) + } + traceAndSpanIDs = append(traceAndSpanIDs, tasi) + <-unaryDone.Done() + var tasiServer traceAndSpanIDString + for _, tasi := range traceAndSpanIDs { + if strings.HasPrefix(tasi.spanName, "grpc.") { + tasiServer = idsToString(tasi, projectID) + continue + } + } + + fle.mu.Lock() + for _, tasiSeen := range fle.idsSeen { + if diff := cmp.Diff(tasiSeen, &tasiServer, cmp.AllowUnexported(traceAndSpanIDString{})); diff != "" { + readerErrCh.Send(errors.New("got unexpected id, should be a server span")) + } + } + + fle.entries = nil + fle.mu.Unlock() + readerErrCh.Send(nil) + }() + if _, err := ss.Client.UnaryCall(ctx, &grpc_testing.SimpleRequest{Payload: &grpc_testing.Payload{Body: testOkPayload}}); err != nil { + t.Fatalf("Unexpected error from UnaryCall: %v", err) + } + unaryDone.Fire() + if 
chErr, err := readerErrCh.Receive(ctx); chErr != nil || err != nil { + if err != nil { + t.Fatalf("Should have received something from error channel: %v", err) + } + if chErr != nil { + t.Fatalf("Should have received a nil error from channel, instead received: %v", chErr) + } + } +} + +// TestLoggingLinkedWithTrace tests that client and server side logs get the +// trace and span id corresponding to either the Call Level Span or Server Span +// (no determinism, so can only assert one or the other), for Unary and +// Streaming RPCs. +func (s) TestLoggingLinkedWithTrace(t *testing.T) { + fle := &fakeLoggingExporter{ + t: t, + } + oldNewLoggingExporter := newLoggingExporter + defer func() { + newLoggingExporter = oldNewLoggingExporter + }() + + newLoggingExporter = func(ctx context.Context, config *config) (loggingExporter, error) { + return fle, nil + } + + idCh := testutils.NewChannel() + + fe := &fakeOpenCensusExporter{ + t: t, + idCh: idCh, + } + oldNewExporter := newExporter + defer func() { + newExporter = oldNewExporter + }() + + newExporter = func(config *config) (tracingMetricsExporter, error) { + return fe, nil + } + + const projectID = "project-id" + tracesAndLogsConfig := &config{ + ProjectID: projectID, + CloudLogging: &cloudLogging{ + ClientRPCEvents: []clientRPCEvents{ + { + Methods: []string{"*"}, + MaxMetadataBytes: 30, + MaxMessageBytes: 30, + }, + }, + ServerRPCEvents: []serverRPCEvents{ + { + Methods: []string{"*"}, + MaxMetadataBytes: 30, + MaxMessageBytes: 30, + }, + }, + }, + CloudTrace: &cloudTrace{ + SamplingRate: 1.0, + }, + } + cleanup, err := setupObservabilitySystemWithConfig(tracesAndLogsConfig) + if err != nil { + t.Fatalf("error setting up observability %v", err) + } + defer cleanup() + ss := &stubserver.StubServer{ + UnaryCallF: func(ctx context.Context, in *grpc_testing.SimpleRequest) (*grpc_testing.SimpleResponse, error) { + return &grpc_testing.SimpleResponse{}, nil + }, + FullDuplexCallF: func(stream 
grpc_testing.TestService_FullDuplexCallServer) error { + _, err := stream.Recv() + if err != io.EOF { + return err + } + return nil + }, + } + if err := ss.Start(nil); err != nil { + t.Fatalf("Error starting endpoint server: %v", err) + } + defer ss.Stop() + + ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout) + defer cancel() + + // Spawn a goroutine to receive the trace and span ids received by the + // exporter corresponding to a Unary RPC. + readerErrCh := testutils.NewChannel() + unaryDone := grpcsync.NewEvent() + go func() { + var traceAndSpanIDs []traceAndSpanID + val, err := idCh.Receive(ctx) + if err != nil { + readerErrCh.Send(fmt.Errorf("error while waiting for IDs: %v", err)) + } + + tasi, ok := val.(traceAndSpanID) + if !ok { + readerErrCh.Send(fmt.Errorf("received wrong type from channel: %T", val)) + } + traceAndSpanIDs = append(traceAndSpanIDs, tasi) + + val, err = idCh.Receive(ctx) + if err != nil { + readerErrCh.Send(fmt.Errorf("error while waiting for IDs: %v", err)) + } + + tasi, ok = val.(traceAndSpanID) + if !ok { + readerErrCh.Send(fmt.Errorf("received wrong type from channel: %T", val)) + } + traceAndSpanIDs = append(traceAndSpanIDs, tasi) + + val, err = idCh.Receive(ctx) + if err != nil { + readerErrCh.Send(fmt.Errorf("error while waiting for IDs: %v", err)) + } + tasi, ok = val.(traceAndSpanID) + if !ok { + readerErrCh.Send(fmt.Errorf("received wrong type from channel: %T", val)) + } + traceAndSpanIDs = append(traceAndSpanIDs, tasi) + <-unaryDone.Done() + var tasiSent traceAndSpanIDString + var tasiServer traceAndSpanIDString + for _, tasi := range traceAndSpanIDs { + if strings.HasPrefix(tasi.spanName, "Sent.") { + tasiSent = idsToString(tasi, projectID) + continue + } + if strings.HasPrefix(tasi.spanName, "grpc.") { + tasiServer = idsToString(tasi, projectID) + } + } + + fle.mu.Lock() + for _, tasiSeen := range fle.idsSeen { + if diff := cmp.Diff(tasiSeen, &tasiSent, cmp.AllowUnexported(traceAndSpanIDString{})); 
diff != "" { + if diff2 := cmp.Diff(tasiSeen, &tasiServer, cmp.AllowUnexported(traceAndSpanIDString{})); diff2 != "" { + readerErrCh.Send(errors.New("got unexpected id, should be client or server span")) + } + } + } + + fle.entries = nil + fle.mu.Unlock() + readerErrCh.Send(nil) + }() + if _, err := ss.Client.UnaryCall(ctx, &grpc_testing.SimpleRequest{Payload: &grpc_testing.Payload{Body: testOkPayload}}); err != nil { + t.Fatalf("Unexpected error from UnaryCall: %v", err) + } + unaryDone.Fire() + if chErr, err := readerErrCh.Receive(ctx); chErr != nil || err != nil { + if err != nil { + t.Fatalf("Should have received something from error channel: %v", err) + } + if chErr != nil { + t.Fatalf("Should have received a nil error from channel, instead received: %v", chErr) + } + } + + fle.mu.Lock() + fle.idsSeen = nil + fle.mu.Unlock() + + // Test streaming. Spawn a goroutine to receive the trace and span ids + // received by the exporter corresponding to a streaming RPC. + readerErrCh = testutils.NewChannel() + streamDone := grpcsync.NewEvent() + go func() { + var traceAndSpanIDs []traceAndSpanID + + val, err := idCh.Receive(ctx) + if err != nil { + readerErrCh.Send(fmt.Errorf("error while waiting for IDs: %v", err)) + } + + tasi, ok := val.(traceAndSpanID) + if !ok { + readerErrCh.Send(fmt.Errorf("received wrong type from channel: %T", val)) + } + traceAndSpanIDs = append(traceAndSpanIDs, tasi) + val, err = idCh.Receive(ctx) + if err != nil { + readerErrCh.Send(fmt.Errorf("error while waiting for IDs: %v", err)) + } + + tasi, ok = val.(traceAndSpanID) + if !ok { + readerErrCh.Send(fmt.Errorf("received wrong type from channel: %T", val)) + } + traceAndSpanIDs = append(traceAndSpanIDs, tasi) + + val, err = idCh.Receive(ctx) + if err != nil { + readerErrCh.Send(fmt.Errorf("error while waiting for IDs: %v", err)) + } + tasi, ok = val.(traceAndSpanID) + if !ok { + readerErrCh.Send(fmt.Errorf("received wrong type from channel: %T", val)) + } + traceAndSpanIDs = 
append(traceAndSpanIDs, tasi) + <-streamDone.Done() + var tasiSent traceAndSpanIDString + var tasiServer traceAndSpanIDString + for _, tasi := range traceAndSpanIDs { + if strings.HasPrefix(tasi.spanName, "Sent.") { + tasiSent = idsToString(tasi, projectID) + continue + } + if strings.HasPrefix(tasi.spanName, "grpc.") { + tasiServer = idsToString(tasi, projectID) + } + } + + fle.mu.Lock() + for _, tasiSeen := range fle.idsSeen { + if diff := cmp.Diff(tasiSeen, &tasiSent, cmp.AllowUnexported(traceAndSpanIDString{})); diff != "" { + if diff2 := cmp.Diff(tasiSeen, &tasiServer, cmp.AllowUnexported(traceAndSpanIDString{})); diff2 != "" { + readerErrCh.Send(errors.New("got unexpected id, should be client or server span")) + } + } + } + + fle.entries = nil + fle.mu.Unlock() + readerErrCh.Send(nil) + }() + + stream, err := ss.Client.FullDuplexCall(ctx) + if err != nil { + t.Fatalf("ss.Client.FullDuplexCall failed: %v", err) + } + + stream.CloseSend() + if _, err = stream.Recv(); err != io.EOF { + t.Fatalf("unexpected error: %v, expected an EOF error", err) + } + streamDone.Fire() + + if chErr, err := readerErrCh.Receive(ctx); chErr != nil || err != nil { + if err != nil { + t.Fatalf("Should have received something from error channel: %v", err) + } + if chErr != nil { + t.Fatalf("Should have received a nil error from channel, instead received: %v", chErr) + } + } +} diff --git a/internal/binarylog/binarylog.go b/internal/binarylog/binarylog.go index 809d73ccafb0..af03a40d990b 100644 --- a/internal/binarylog/binarylog.go +++ b/internal/binarylog/binarylog.go @@ -28,8 +28,10 @@ import ( "google.golang.org/grpc/internal/grpcutil" ) -// Logger is the global binary logger. It can be used to get binary logger for -// each method. +var grpclogLogger = grpclog.Component("binarylog") + +// Logger specifies MethodLoggers for method names with a Log call that +// takes a context. 
type Logger interface { GetMethodLogger(methodName string) MethodLogger } @@ -40,8 +42,6 @@ type Logger interface { // It is used to get a MethodLogger for each individual method. var binLogger Logger -var grpclogLogger = grpclog.Component("binarylog") - // SetLogger sets the binary logger. // // Only call this at init time. diff --git a/internal/binarylog/binarylog_test.go b/internal/binarylog/binarylog_test.go index 05138f8f309f..47f6a541e767 100644 --- a/internal/binarylog/binarylog_test.go +++ b/internal/binarylog/binarylog_test.go @@ -98,7 +98,6 @@ func (s) TestGetMethodLogger(t *testing.T) { t.Errorf("in: %q, method logger is nil, want non-nil", tc.in) continue } - if ml.headerMaxLen != tc.hdr || ml.messageMaxLen != tc.msg { t.Errorf("in: %q, want header: %v, message: %v, got header: %v, message: %v", tc.in, tc.hdr, tc.msg, ml.headerMaxLen, ml.messageMaxLen) } diff --git a/internal/binarylog/method_logger.go b/internal/binarylog/method_logger.go index d71e441778f4..56fcf008d3de 100644 --- a/internal/binarylog/method_logger.go +++ b/internal/binarylog/method_logger.go @@ -19,6 +19,7 @@ package binarylog import ( + "context" "net" "strings" "sync/atomic" @@ -49,7 +50,7 @@ var idGen callIDGenerator // MethodLogger is the sub-logger for each method. type MethodLogger interface { - Log(LogEntryConfig) + Log(context.Context, LogEntryConfig) } // TruncatingMethodLogger is a method logger that truncates headers and messages @@ -98,7 +99,7 @@ func (ml *TruncatingMethodLogger) Build(c LogEntryConfig) *binlogpb.GrpcLogEntry } // Log creates a proto binary log entry, and logs it to the sink. 
-func (ml *TruncatingMethodLogger) Log(c LogEntryConfig) { +func (ml *TruncatingMethodLogger) Log(ctx context.Context, c LogEntryConfig) { ml.sink.Write(ml.Build(c)) } diff --git a/internal/binarylog/method_logger_test.go b/internal/binarylog/method_logger_test.go index 5d1e09a39658..11255bb338b4 100644 --- a/internal/binarylog/method_logger_test.go +++ b/internal/binarylog/method_logger_test.go @@ -20,6 +20,7 @@ package binarylog import ( "bytes" + "context" "fmt" "net" "testing" @@ -335,7 +336,7 @@ func (s) TestLog(t *testing.T) { for i, tc := range testCases { buf.Reset() tc.want.SequenceIdWithinCall = uint64(i + 1) - ml.Log(tc.config) + ml.Log(context.Background(), tc.config) inSink := new(binlogpb.GrpcLogEntry) if err := proto.Unmarshal(buf.Bytes()[4:], inSink); err != nil { t.Errorf("failed to unmarshal bytes in sink to proto: %v", err) diff --git a/server.go b/server.go index 8d573dc6075b..087b9ad7c1f6 100644 --- a/server.go +++ b/server.go @@ -1253,7 +1253,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. logEntry.PeerAddr = peer.Addr } for _, binlog := range binlogs { - binlog.Log(logEntry) + binlog.Log(ctx, logEntry) } } @@ -1333,7 +1333,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Message: d, } for _, binlog := range binlogs { - binlog.Log(cm) + binlog.Log(stream.Context(), cm) } } if trInfo != nil { @@ -1366,7 +1366,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Header: h, } for _, binlog := range binlogs { - binlog.Log(sh) + binlog.Log(stream.Context(), sh) } } st := &binarylog.ServerTrailer{ @@ -1374,7 +1374,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Err: appErr, } for _, binlog := range binlogs { - binlog.Log(st) + binlog.Log(stream.Context(), st) } } return appErr @@ -1416,8 +1416,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. 
Err: appErr, } for _, binlog := range binlogs { - binlog.Log(sh) - binlog.Log(st) + binlog.Log(stream.Context(), sh) + binlog.Log(stream.Context(), st) } } return err @@ -1431,8 +1431,8 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Message: reply, } for _, binlog := range binlogs { - binlog.Log(sh) - binlog.Log(sm) + binlog.Log(stream.Context(), sh) + binlog.Log(stream.Context(), sm) } } if channelz.IsOn() { @@ -1450,7 +1450,7 @@ func (s *Server) processUnaryRPC(t transport.ServerTransport, stream *transport. Err: appErr, } for _, binlog := range binlogs { - binlog.Log(st) + binlog.Log(stream.Context(), st) } } return t.WriteStatus(stream, statusOK) @@ -1587,7 +1587,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp logEntry.PeerAddr = peer.Addr } for _, binlog := range ss.binlogs { - binlog.Log(logEntry) + binlog.Log(stream.Context(), logEntry) } } @@ -1665,7 +1665,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp Err: appErr, } for _, binlog := range ss.binlogs { - binlog.Log(st) + binlog.Log(stream.Context(), st) } } t.WriteStatus(ss.s, appStatus) @@ -1683,7 +1683,7 @@ func (s *Server) processStreamingRPC(t transport.ServerTransport, stream *transp Err: appErr, } for _, binlog := range ss.binlogs { - binlog.Log(st) + binlog.Log(stream.Context(), st) } } return t.WriteStatus(ss.s, statusOK) diff --git a/stats/opencensus/e2e_test.go b/stats/opencensus/e2e_test.go index 251a75cdac54..943713f154d2 100644 --- a/stats/opencensus/e2e_test.go +++ b/stats/opencensus/e2e_test.go @@ -1477,11 +1477,11 @@ func (s) TestSpan(t *testing.T) { hasRemoteParent: false, }, } + fe.mu.Lock() + defer fe.mu.Unlock() if diff := cmp.Diff(fe.seenSpans, wantSI); diff != "" { t.Fatalf("got unexpected spans, diff (-got, +want): %v", diff) } - fe.mu.Lock() - defer fe.mu.Unlock() if err := validateTraceAndSpanIDs(fe.seenSpans); err != nil { t.Fatalf("Error in runtime data assertions: 
%v", err) } diff --git a/stats/opencensus/opencensus.go b/stats/opencensus/opencensus.go index 350cebfb4aca..fc7ee341ea63 100644 --- a/stats/opencensus/opencensus.go +++ b/stats/opencensus/opencensus.go @@ -152,13 +152,24 @@ func setRPCInfo(ctx context.Context, ri *rpcInfo) context.Context { return context.WithValue(ctx, rpcInfoKey{}, ri) } -// getSpanWithMsgCount returns the rpcInfo stored in the context, or nil +// getRPCInfo returns the rpcInfo stored in the context, or nil // if there isn't one. func getRPCInfo(ctx context.Context) *rpcInfo { ri, _ := ctx.Value(rpcInfoKey{}).(*rpcInfo) return ri } +// GetTraceAndSpanID returns the trace and span ID of the span in the context. +// Returns true if IDs present and false if IDs not present. +func GetTraceAndSpanID(ctx context.Context) (trace.TraceID, trace.SpanID, bool) { + ri, ok := ctx.Value(rpcInfoKey{}).(*rpcInfo) + if !ok || ri == nil || ri.ti == nil { + return trace.TraceID{}, trace.SpanID{}, false + } + sc := ri.ti.span.SpanContext() + return sc.TraceID, sc.SpanID, true +} + type clientStatsHandler struct { to TraceOptions } diff --git a/stats/opencensus/trace.go b/stats/opencensus/trace.go index 2c8a93551fa7..afd5b4fd8912 100644 --- a/stats/opencensus/trace.go +++ b/stats/opencensus/trace.go @@ -40,6 +40,8 @@ type traceInfo struct { func (csh *clientStatsHandler) traceTagRPC(ctx context.Context, rti *stats.RPCTagInfo) (context.Context, *traceInfo) { // TODO: get consensus on whether this method name of "s.m" is correct. mn := "Attempt." + strings.Replace(removeLeadingSlash(rti.FullMethodName), "/", ".", -1) + // Returned context is ignored because we will populate the context with data + // that wraps the span instead. 
_, span := trace.StartSpan(ctx, mn, trace.WithSampler(csh.to.TS), trace.WithSpanKind(trace.SpanKindClient)) tcBin := propagation.Binary(span.SpanContext()) diff --git a/stream.go b/stream.go index 0926cd4f5755..d1226a4120f8 100644 --- a/stream.go +++ b/stream.go @@ -361,7 +361,7 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client } } for _, binlog := range cs.binlogs { - binlog.Log(logEntry) + binlog.Log(cs.ctx, logEntry) } } @@ -809,7 +809,7 @@ func (cs *clientStream) Header() (metadata.MD, error) { } cs.serverHeaderBinlogged = true for _, binlog := range cs.binlogs { - binlog.Log(logEntry) + binlog.Log(cs.ctx, logEntry) } } return m, nil @@ -890,7 +890,7 @@ func (cs *clientStream) SendMsg(m interface{}) (err error) { Message: data, } for _, binlog := range cs.binlogs { - binlog.Log(cm) + binlog.Log(cs.ctx, cm) } } return err @@ -914,7 +914,7 @@ func (cs *clientStream) RecvMsg(m interface{}) error { Message: recvInfo.uncompressedBytes, } for _, binlog := range cs.binlogs { - binlog.Log(sm) + binlog.Log(cs.ctx, sm) } } if err != nil || !cs.desc.ServerStreams { @@ -935,7 +935,7 @@ func (cs *clientStream) RecvMsg(m interface{}) error { logEntry.PeerAddr = peer.Addr } for _, binlog := range cs.binlogs { - binlog.Log(logEntry) + binlog.Log(cs.ctx, logEntry) } } } @@ -962,7 +962,7 @@ func (cs *clientStream) CloseSend() error { OnClientSide: true, } for _, binlog := range cs.binlogs { - binlog.Log(chc) + binlog.Log(cs.ctx, chc) } } // We never returned an error here for reasons. 
@@ -1004,7 +1004,7 @@ func (cs *clientStream) finish(err error) { OnClientSide: true, } for _, binlog := range cs.binlogs { - binlog.Log(c) + binlog.Log(cs.ctx, c) } } if err == nil { @@ -1573,7 +1573,7 @@ func (ss *serverStream) SendHeader(md metadata.MD) error { } ss.serverHeaderBinlogged = true for _, binlog := range ss.binlogs { - binlog.Log(sh) + binlog.Log(ss.ctx, sh) } } return err @@ -1646,14 +1646,14 @@ func (ss *serverStream) SendMsg(m interface{}) (err error) { } ss.serverHeaderBinlogged = true for _, binlog := range ss.binlogs { - binlog.Log(sh) + binlog.Log(ss.ctx, sh) } } sm := &binarylog.ServerMessage{ Message: data, } for _, binlog := range ss.binlogs { - binlog.Log(sm) + binlog.Log(ss.ctx, sm) } } if len(ss.statsHandler) != 0 { @@ -1701,7 +1701,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { if len(ss.binlogs) != 0 { chc := &binarylog.ClientHalfClose{} for _, binlog := range ss.binlogs { - binlog.Log(chc) + binlog.Log(ss.ctx, chc) } } return err @@ -1729,7 +1729,7 @@ func (ss *serverStream) RecvMsg(m interface{}) (err error) { Message: payInfo.uncompressedBytes, } for _, binlog := range ss.binlogs { - binlog.Log(cm) + binlog.Log(ss.ctx, cm) } } return nil diff --git a/test/end2end_test.go b/test/end2end_test.go index 9de88ebcc46f..a6d992286894 100644 --- a/test/end2end_test.go +++ b/test/end2end_test.go @@ -6799,7 +6799,7 @@ type mockMethodLogger struct { events uint64 } -func (mml *mockMethodLogger) Log(binarylog.LogEntryConfig) { +func (mml *mockMethodLogger) Log(context.Context, binarylog.LogEntryConfig) { atomic.AddUint64(&mml.events, 1) }