Skip to content

Commit

Permalink
DNS resolving with timeout (#6917)
Browse files Browse the repository at this point in the history
  • Loading branch information
and1truong authored Mar 5, 2024
1 parent 815e2e2 commit f7c5e6a
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 11 deletions.
29 changes: 19 additions & 10 deletions internal/resolver/dns/dns_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ import (
// addresses from SRV records. Must not be changed after init time.
var EnableSRVLookups = false

// ResolvingTimeout specifies the maximum duration for a DNS resolution request.
// If the timeout expires before a response is received, the request will be canceled.
//
// It is recommended to set this value at application startup. Avoid modifying this variable
// after initialization as it's not thread-safe for concurrent modification.
var ResolvingTimeout = 30 * time.Second

var logger = grpclog.Component("dns")

func init() {
Expand Down Expand Up @@ -221,18 +228,18 @@ func (d *dnsResolver) watcher() {
}
}

func (d *dnsResolver) lookupSRV() ([]resolver.Address, error) {
func (d *dnsResolver) lookupSRV(ctx context.Context) ([]resolver.Address, error) {
if !EnableSRVLookups {
return nil, nil
}
var newAddrs []resolver.Address
_, srvs, err := d.resolver.LookupSRV(d.ctx, "grpclb", "tcp", d.host)
_, srvs, err := d.resolver.LookupSRV(ctx, "grpclb", "tcp", d.host)
if err != nil {
err = handleDNSError(err, "SRV") // may become nil
return nil, err
}
for _, s := range srvs {
lbAddrs, err := d.resolver.LookupHost(d.ctx, s.Target)
lbAddrs, err := d.resolver.LookupHost(ctx, s.Target)
if err != nil {
err = handleDNSError(err, "A") // may become nil
if err == nil {
Expand Down Expand Up @@ -269,8 +276,8 @@ func handleDNSError(err error, lookupType string) error {
return err
}

func (d *dnsResolver) lookupTXT() *serviceconfig.ParseResult {
ss, err := d.resolver.LookupTXT(d.ctx, txtPrefix+d.host)
func (d *dnsResolver) lookupTXT(ctx context.Context) *serviceconfig.ParseResult {
ss, err := d.resolver.LookupTXT(ctx, txtPrefix+d.host)
if err != nil {
if envconfig.TXTErrIgnore {
return nil
Expand All @@ -297,8 +304,8 @@ func (d *dnsResolver) lookupTXT() *serviceconfig.ParseResult {
return d.cc.ParseServiceConfig(sc)
}

func (d *dnsResolver) lookupHost() ([]resolver.Address, error) {
addrs, err := d.resolver.LookupHost(d.ctx, d.host)
func (d *dnsResolver) lookupHost(ctx context.Context) ([]resolver.Address, error) {
addrs, err := d.resolver.LookupHost(ctx, d.host)
if err != nil {
err = handleDNSError(err, "A")
return nil, err
Expand All @@ -316,8 +323,10 @@ func (d *dnsResolver) lookupHost() ([]resolver.Address, error) {
}

func (d *dnsResolver) lookup() (*resolver.State, error) {
srv, srvErr := d.lookupSRV()
addrs, hostErr := d.lookupHost()
ctx, cancel := context.WithTimeout(d.ctx, ResolvingTimeout)
defer cancel()
srv, srvErr := d.lookupSRV(ctx)
addrs, hostErr := d.lookupHost(ctx)
if hostErr != nil && (srvErr != nil || len(srv) == 0) {
return nil, hostErr
}
Expand All @@ -327,7 +336,7 @@ func (d *dnsResolver) lookup() (*resolver.State, error) {
state = grpclbstate.Set(state, &grpclbstate.State{BalancerAddresses: srv})
}
if !d.disableServiceConfig {
state.ServiceConfig = d.lookupTXT()
state.ServiceConfig = d.lookupTXT(ctx)
}
return &state, nil
}
Expand Down
43 changes: 43 additions & 0 deletions internal/resolver/dns/dns_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import (
dnsinternal "google.golang.org/grpc/internal/resolver/dns/internal"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/resolver"
dnspublic "google.golang.org/grpc/resolver/dns"
"google.golang.org/grpc/serviceconfig"

_ "google.golang.org/grpc" // To initialize internal.ParseServiceConfig
Expand Down Expand Up @@ -1215,3 +1216,45 @@ func (s) TestReportError(t *testing.T) {
}
}
}

// Override the default dns.ResolvingTimeout with a test duration.
func overrideResolveTimeoutDuration(t *testing.T, dur time.Duration) {
t.Helper()

origDur := dns.ResolvingTimeout
dnspublic.SetResolvingTimeout(dur)

t.Cleanup(func() { dnspublic.SetResolvingTimeout(origDur) })
}

// Test verifies that the DNS resolver gets timeout error when net.Resolver
// takes too long to resolve a target.
func (s) TestResolveTimeout(t *testing.T) {
// Set DNS resolving timeout duration to 7ms
timeoutDur := 7 * time.Millisecond
overrideResolveTimeoutDuration(t, timeoutDur)

ctx, cancel := context.WithTimeout(context.Background(), defaultTestTimeout)
defer cancel()

// We are trying to resolve hostname which takes infinity time to resolve.
const target = "infinity"

// Define a testNetResolver with lookupHostCh, an unbuffered channel,
// so we can block the resolver until reaching timeout.
tr := &testNetResolver{
lookupHostCh: testutils.NewChannelWithSize(0),
hostLookupTable: map[string][]string{target: {"1.2.3.4"}},
}
overrideNetResolver(t, tr)

_, _, errCh := buildResolverWithTestClientConn(t, target)
select {
case <-ctx.Done():
t.Fatal("Timeout when waiting for the DNS resolver to timeout")
case err := <-errCh:
if err == nil || !strings.Contains(err.Error(), "context deadline exceeded") {
t.Fatalf(`Expected to see Timeout error; got: %v`, err)
}
}
}
5 changes: 4 additions & 1 deletion internal/resolver/dns/fake_net_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ type testNetResolver struct {

func (tr *testNetResolver) LookupHost(ctx context.Context, host string) ([]string, error) {
if tr.lookupHostCh != nil {
tr.lookupHostCh.Send(nil)
if err := tr.lookupHostCh.SendContext(ctx, nil); err != nil {
return nil, err
}
}

tr.mu.Lock()
Expand All @@ -50,6 +52,7 @@ func (tr *testNetResolver) LookupHost(ctx context.Context, host string) ([]strin
if addrs, ok := tr.hostLookupTable[host]; ok {
return addrs, nil
}

return nil, &net.DNSError{
Err: "hostLookup error",
Name: host,
Expand Down
18 changes: 18 additions & 0 deletions resolver/dns/dns_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,28 @@
package dns

import (
"time"

"google.golang.org/grpc/internal/resolver/dns"
"google.golang.org/grpc/resolver"
)

// SetResolvingTimeout sets the maximum duration for DNS resolution requests.
//
// This function affects the global timeout used by all channels using the DNS
// name resolver scheme.
//
// It must be called only at application startup, before any gRPC calls are
// made. Modifying this value after initialization is not thread-safe.
//
// The default value is 30 seconds. Setting the timeout too low may result in
// premature timeouts during resolution, while setting it too high may lead to
// unnecessary delays in service discovery. Choose a value appropriate for your
// specific needs and network environment.
func SetResolvingTimeout(timeout time.Duration) {
dns.ResolvingTimeout = timeout
}

// NewBuilder creates a dnsBuilder which is used to factory DNS resolvers.
//
// Deprecated: import grpc and use resolver.Get("dns") instead.
Expand Down

0 comments on commit f7c5e6a

Please sign in to comment.