Skip to content

Commit

Permalink
picker_wrapper: improve handling of context errors
Browse files Browse the repository at this point in the history
pickerWrapper had logic very similar to status.FromContextError() for
transforming Context errors to status errors. This patch removes the
duplication by delegating to the status library. Besides removing the
code duplication, the status library is arguably more robust because it
doesn't rely on ctx.Error() to only ever return two types of errors.

I believe this patch and the previous one stand on their own, but, FWIW,
they're also motivating by me wanting to experiment in the CockroachDB
codebase with using a custom implementation of context.Context whose
Err() method can return better errors than the stdlib context.Context.
These errors would still wrap context.Canceled.  Such an implementation
would technically break the documentation of context.Context, which
seems to exhaustively list the sentinel error that context.Context can
return. Still, as grpc#4977 showed, most
code should support wrapped context errors. This patch moves from "most
code" to "all code" in gRPC. I haven't checked which of the callsites
I've touched use contexts that might be inherited from a gRPC client, as
opposed to contexts derived inside gRPC from context.Background (which
contexts would not be affected by whatever I do outside of gRPC), but
unifying all the context error handling code seems like a good idea to
me universally.
  • Loading branch information
andreimatei committed Jan 7, 2022
1 parent dba1390 commit 3a8aa01
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 14 deletions.
9 changes: 1 addition & 8 deletions picker_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,15 +105,8 @@ func (pw *pickerWrapper) pick(ctx context.Context, failfast bool, info balancer.
var errStr string
if lastPickErr != nil {
errStr = "latest balancer error: " + lastPickErr.Error()
} else {
errStr = ctx.Err().Error()
}
switch ctx.Err() {
case context.DeadlineExceeded:
return nil, nil, status.Error(codes.DeadlineExceeded, errStr)
case context.Canceled:
return nil, nil, status.Error(codes.Canceled, errStr)
}
return nil, nil, status.FromContextError(ctx.Err(), errStr).Err()
case <-ch:
}
continue
Expand Down
16 changes: 11 additions & 5 deletions status/status.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,23 @@ func Code(err error) codes.Code {
// FromContextError converts a context error or wrapped context error into a
// Status. It returns a Status with codes.OK if err is nil, or a Status with
// codes.Unknown if err is non-nil and not a context error.
func FromContextError(err error) *Status {
//
// If msg != "" and err != nil, msg is used as the Status' message. If msg == "",
// err.Error() is used.
func FromContextError(err error, msg string) *Status {
if err == nil {
return nil
}
if msg == "" {
msg = err.Error()
}
if errors.Is(err, context.DeadlineExceeded) {
return New(codes.DeadlineExceeded, err.Error())
return New(codes.DeadlineExceeded, msg)
}
if errors.Is(err, context.Canceled) {
return New(codes.Canceled, err.Error())
return New(codes.Canceled, msg)
}
return New(codes.Unknown, err.Error())
return New(codes.Unknown, msg)
}

// MustFromContextError is like FromContextError, except that it expects err to
Expand All @@ -140,5 +146,5 @@ func MustFromContextError(err error) error {
if err == nil {
return status.New(codes.Internal, "Expected non-nil context error").Err()
}
return FromContextError(err).Err()
return FromContextError(err, "").Err()
}
5 changes: 4 additions & 1 deletion status/status_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -358,17 +358,20 @@ func mustMarshalAny(msg proto.Message) *apb.Any {
func (s) TestFromContextError(t *testing.T) {
testCases := []struct {
in error
msg string
want *Status
}{
{in: nil, want: New(codes.OK, "")},
{in: nil, msg: "ignored", want: New(codes.OK, "")},
{in: context.DeadlineExceeded, want: New(codes.DeadlineExceeded, context.DeadlineExceeded.Error())},
{in: context.Canceled, want: New(codes.Canceled, context.Canceled.Error())},
{in: errors.New("other"), want: New(codes.Unknown, "other")},
{in: errors.New("other"), msg: "my msg", want: New(codes.Unknown, "my msg")},
{in: fmt.Errorf("wrapped: %w", context.DeadlineExceeded), want: New(codes.DeadlineExceeded, "wrapped: "+context.DeadlineExceeded.Error())},
{in: fmt.Errorf("wrapped: %w", context.Canceled), want: New(codes.Canceled, "wrapped: "+context.Canceled.Error())},
}
for _, tc := range testCases {
got := FromContextError(tc.in)
got := FromContextError(tc.in, tc.msg)
if got.Code() != tc.want.Code() || got.Message() != tc.want.Message() {
t.Errorf("FromContextError(%v) = %v; want %v", tc.in, got, tc.want)
}
Expand Down

0 comments on commit 3a8aa01

Please sign in to comment.