Skip to content

Commit

Permalink
chore: expose NewClient method to end users (#7010)
Browse files Browse the repository at this point in the history
  • Loading branch information
bruuuuuuuce authored Mar 7, 2024
1 parent c31fce8 commit c808322
Show file tree
Hide file tree
Showing 11 changed files with 125 additions and 41 deletions.
22 changes: 16 additions & 6 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,11 @@ func (dcs *defaultConfigSelector) SelectConfig(rpcInfo iresolver.RPCInfo) (*ires
}, nil
}

// newClient returns a new client in idle mode.
func newClient(target string, opts ...DialOption) (conn *ClientConn, err error) {
func newClient(target, defaultScheme string, opts ...DialOption) (conn *ClientConn, err error) {
cc := &ClientConn{
target: target,
conns: make(map[*addrConn]struct{}),
dopts: defaultDialOptions(),
dopts: defaultDialOptions(defaultScheme),
czData: new(channelzData),
}

Expand Down Expand Up @@ -191,6 +190,11 @@ func newClient(target string, opts ...DialOption) (conn *ClientConn, err error)
return cc, nil
}

// NewClient returns a new client in idle mode.
func NewClient(target string, opts ...DialOption) (conn *ClientConn, err error) {
return newClient(target, "dns", opts...)
}

// DialContext creates a client connection to the given target. By default, it's
// a non-blocking dial (the function won't wait for connections to be
// established, and connecting happens in the background). To make it a blocking
Expand All @@ -208,7 +212,8 @@ func newClient(target string, opts ...DialOption) (conn *ClientConn, err error)
// https://github.com/grpc/grpc/blob/master/doc/naming.md.
// e.g. to use dns resolver, a "dns:///" prefix should be applied to the target.
func DialContext(ctx context.Context, target string, opts ...DialOption) (conn *ClientConn, err error) {
cc, err := newClient(target, opts...)
// At the end of this method, we kick the channel out of idle, rather than waiting for the first rpc.
cc, err := newClient(target, "passthrough", opts...)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -1740,8 +1745,13 @@ func (cc *ClientConn) parseTargetAndFindResolver() error {
// We are here because the user's dial target did not contain a scheme or
// specified an unregistered scheme. We should fallback to the default
// scheme, except when a custom dialer is specified in which case, we should
// always use passthrough scheme.
defScheme := resolver.GetDefaultScheme()
// always use passthrough scheme. For either case, we need to respect any overridden
// global defaults set by the user.
defScheme := cc.dopts.defScheme
if internal.UserSetDefaultScheme {
defScheme = resolver.GetDefaultScheme()
}

channelz.Infof(logger, cc.channelzID, "fallback to scheme %q", defScheme)
canonicalTarget := defScheme + ":///" + cc.target

Expand Down
96 changes: 80 additions & 16 deletions clientconn_parsed_target_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,34 +28,87 @@ import (

"github.com/google/go-cmp/cmp"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/testutils"

"google.golang.org/grpc/resolver"
)

func generateTarget(scheme string, target string) resolver.Target {
return resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("%s:///%s", scheme, target))}
}

// This is here just in case another test calls the SetDefaultScheme method.
func resetInitialResolverState() {
resolver.SetDefaultScheme("passthrough")
internal.UserSetDefaultScheme = false
}

func (s) TestParsedTarget_Success_WithoutCustomDialer(t *testing.T) {
defScheme := resolver.GetDefaultScheme()
resetInitialResolverState()
dialScheme := resolver.GetDefaultScheme()
newClientScheme := "dns"
tests := []struct {
target string
wantParsed resolver.Target
target string
wantDialParse resolver.Target
wantNewClientParse resolver.Target
}{
// No scheme is specified.
{target: "://a/b", wantParsed: resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("%s:///%s", defScheme, "://a/b"))}},
{target: "a//b", wantParsed: resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("%s:///%s", defScheme, "a//b"))}},
{
target: "://a/b",
wantDialParse: generateTarget(dialScheme, "://a/b"),
wantNewClientParse: generateTarget(newClientScheme, "://a/b"),
},
{
target: "a//b",
wantDialParse: generateTarget(dialScheme, "a//b"),
wantNewClientParse: generateTarget(newClientScheme, "a//b"),
},

// An unregistered scheme is specified.
{target: "a:///", wantParsed: resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("%s:///%s", defScheme, "a:///"))}},
{target: "a:b", wantParsed: resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("%s:///%s", defScheme, "a:b"))}},
{
target: "a:///",
wantDialParse: generateTarget(dialScheme, "a:///"),
wantNewClientParse: generateTarget(newClientScheme, "a:///"),
},
{
target: "a:b",
wantDialParse: generateTarget(dialScheme, "a:b"),
wantNewClientParse: generateTarget(newClientScheme, "a:b"),
},

// A registered scheme is specified.
{target: "dns://a.server.com/google.com", wantParsed: resolver.Target{URL: *testutils.MustParseURL("dns://a.server.com/google.com")}},
{target: "unix-abstract:/ a///://::!@#$%25^&*()b", wantParsed: resolver.Target{URL: *testutils.MustParseURL("unix-abstract:/ a///://::!@#$%25^&*()b")}},
{target: "unix-abstract:passthrough:abc", wantParsed: resolver.Target{URL: *testutils.MustParseURL("unix-abstract:passthrough:abc")}},
{target: "passthrough:///unix:///a/b/c", wantParsed: resolver.Target{URL: *testutils.MustParseURL("passthrough:///unix:///a/b/c")}},
{
target: "dns://a.server.com/google.com",
wantDialParse: resolver.Target{URL: *testutils.MustParseURL("dns://a.server.com/google.com")},
wantNewClientParse: resolver.Target{URL: *testutils.MustParseURL("dns://a.server.com/google.com")},
},
{
target: "unix-abstract:/ a///://::!@#$%25^&*()b",
wantDialParse: resolver.Target{URL: *testutils.MustParseURL("unix-abstract:/ a///://::!@#$%25^&*()b")},
wantNewClientParse: resolver.Target{URL: *testutils.MustParseURL("unix-abstract:/ a///://::!@#$%25^&*()b")},
},
{
target: "unix-abstract:passthrough:abc",
wantDialParse: resolver.Target{URL: *testutils.MustParseURL("unix-abstract:passthrough:abc")},
wantNewClientParse: resolver.Target{URL: *testutils.MustParseURL("unix-abstract:passthrough:abc")},
},
{
target: "passthrough:///unix:///a/b/c",
wantDialParse: resolver.Target{URL: *testutils.MustParseURL("passthrough:///unix:///a/b/c")},
wantNewClientParse: resolver.Target{URL: *testutils.MustParseURL("passthrough:///unix:///a/b/c")},
},

// Cases for `scheme:absolute-path`.
{target: "dns:/a/b/c", wantParsed: resolver.Target{URL: *testutils.MustParseURL("dns:/a/b/c")}},
{target: "unregistered:/a/b/c", wantParsed: resolver.Target{URL: *testutils.MustParseURL(fmt.Sprintf("%s:///%s", defScheme, "unregistered:/a/b/c"))}},
{
target: "dns:/a/b/c",
wantDialParse: resolver.Target{URL: *testutils.MustParseURL("dns:/a/b/c")},
wantNewClientParse: resolver.Target{URL: *testutils.MustParseURL("dns:/a/b/c")},
},
{
target: "unregistered:/a/b/c",
wantDialParse: generateTarget(dialScheme, "unregistered:/a/b/c"),
wantNewClientParse: generateTarget(newClientScheme, "unregistered:/a/b/c"),
},
}

for _, test := range tests {
Expand All @@ -66,8 +119,18 @@ func (s) TestParsedTarget_Success_WithoutCustomDialer(t *testing.T) {
}
defer cc.Close()

if !cmp.Equal(cc.parsedTarget, test.wantParsed) {
t.Errorf("cc.parsedTarget for dial target %q = %+v, want %+v", test.target, cc.parsedTarget, test.wantParsed)
if !cmp.Equal(cc.parsedTarget, test.wantDialParse) {
t.Errorf("cc.parsedTarget for dial target %q = %+v, want %+v", test.target, cc.parsedTarget, test.wantDialParse)
}

cc, err = NewClient(test.target, WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
t.Fatalf("NewClient(%q) failed: %v", test.target, err)
}
defer cc.Close()

if !cmp.Equal(cc.parsedTarget, test.wantNewClientParse) {
t.Errorf("cc.parsedTarget for newClient target %q = %+v, want %+v", test.target, cc.parsedTarget, test.wantNewClientParse)
}
})
}
Expand All @@ -93,6 +156,7 @@ func (s) TestParsedTarget_Failure_WithoutCustomDialer(t *testing.T) {
}

func (s) TestParsedTarget_WithCustomDialer(t *testing.T) {
resetInitialResolverState()
defScheme := resolver.GetDefaultScheme()
tests := []struct {
target string
Expand Down
10 changes: 5 additions & 5 deletions credentials/google/google_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ import (
"testing"

"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal"
icredentials "google.golang.org/grpc/internal/credentials"
"google.golang.org/grpc/internal/grpctest"
"google.golang.org/grpc/internal/xds"
"google.golang.org/grpc/resolver"
)

Expand Down Expand Up @@ -109,31 +109,31 @@ func (s) TestClientHandshakeBasedOnClusterName(t *testing.T) {
{
name: "with non-CFE cluster name",
ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, "lalala").Attributes,
Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "lalala").Attributes,
}),
// non-CFE backends should use alts.
wantTyp: "alts",
},
{
name: "with CFE cluster name",
ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, "google_cfe_bigtable.googleapis.com").Attributes,
Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "google_cfe_bigtable.googleapis.com").Attributes,
}),
// CFE should use tls.
wantTyp: "tls",
},
{
name: "with xdstp CFE cluster name",
ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://traffic-director-c2p.xds.googleapis.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes,
Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://traffic-director-c2p.xds.googleapis.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes,
}),
// CFE should use tls.
wantTyp: "tls",
},
{
name: "with xdstp non-CFE cluster name",
ctx: icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://other.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes,
Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, "xdstp://other.com/envoy.config.cluster.v3.Cluster/google_cfe_bigtable.googleapis.com").Attributes,
}),
// non-CFE should use atls.
wantTyp: "alts",
Expand Down
4 changes: 2 additions & 2 deletions credentials/google/xds.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import (
"strings"

"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/xds"
)

const cfeClusterNamePrefix = "google_cfe_"
Expand Down Expand Up @@ -63,7 +63,7 @@ func clusterName(ctx context.Context) string {
if chi.Attributes == nil {
return ""
}
cluster, _ := internal.GetXDSHandshakeClusterName(chi.Attributes)
cluster, _ := xds.GetXDSHandshakeClusterName(chi.Attributes)
return cluster
}

Expand Down
4 changes: 2 additions & 2 deletions credentials/google/xds_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,15 +23,15 @@ import (
"testing"

"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal"
icredentials "google.golang.org/grpc/internal/credentials"
"google.golang.org/grpc/internal/xds"
"google.golang.org/grpc/resolver"
)

func (s) TestIsDirectPathCluster(t *testing.T) {
c := func(cluster string) context.Context {
return icredentials.NewClientHandshakeInfoContext(context.Background(), credentials.ClientHandshakeInfo{
Attributes: internal.SetXDSHandshakeClusterName(resolver.Address{}, cluster).Attributes,
Attributes: xds.SetXDSHandshakeClusterName(resolver.Address{}, cluster).Attributes,
})
}

Expand Down
4 changes: 3 additions & 1 deletion dialoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ type dialOptions struct {
resolvers []resolver.Builder
idleTimeout time.Duration
recvBufferPool SharedBufferPool
defScheme string
}

// DialOption configures how we set up the connection.
Expand Down Expand Up @@ -631,7 +632,7 @@ func withHealthCheckFunc(f internal.HealthChecker) DialOption {
})
}

func defaultDialOptions() dialOptions {
func defaultDialOptions(defScheme string) dialOptions {
return dialOptions{
copts: transport.ConnectOptions{
ReadBufferSize: defaultReadBufSize,
Expand All @@ -643,6 +644,7 @@ func defaultDialOptions() dialOptions {
healthCheckFunc: internal.HealthCheckFunc,
idleTimeout: 30 * time.Minute,
recvBufferPool: nopBufferPool{},
defScheme: defScheme,
}
}

Expand Down
3 changes: 3 additions & 0 deletions internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,9 @@ var (

// FromOutgoingContextRaw returns the un-merged, intermediary contents of metadata.rawMD.
FromOutgoingContextRaw any // func(context.Context) (metadata.MD, [][]string, bool)

// UserSetDefaultScheme is set to true if the user has overridden the default resolver scheme.
UserSetDefaultScheme bool = false
)

// HealthChecker defines the signature of the client-side LB channel health checking function.
Expand Down
4 changes: 3 additions & 1 deletion internal/xds_handshake_cluster.go → internal/xds/xds.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@
* limitations under the License.
*/

package internal
// Package xds contains methods to Get/Set handshake cluster names. It is separated
// out from the top level /internal package to avoid circular dependencies.
package xds

import (
"google.golang.org/grpc/attributes"
Expand Down
7 changes: 5 additions & 2 deletions resolver/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ import (

"google.golang.org/grpc/attributes"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/serviceconfig"
)

Expand Down Expand Up @@ -63,16 +64,18 @@ func Get(scheme string) Builder {
}

// SetDefaultScheme sets the default scheme that will be used. The default
// default scheme is "passthrough".
// scheme is initially set to "passthrough".
//
// NOTE: this function must only be called during initialization time (i.e. in
// an init() function), and is not thread-safe. The scheme set last overrides
// previously set values.
func SetDefaultScheme(scheme string) {
defaultScheme = scheme
internal.UserSetDefaultScheme = true
}

// GetDefaultScheme gets the default scheme that will be used.
// GetDefaultScheme gets the default scheme that will be used by grpc.Dial. If
// SetDefaultScheme is never called, the default scheme used by grpc.NewClient is "dns" instead.
func GetDefaultScheme() string {
return defaultScheme
}
Expand Down
6 changes: 3 additions & 3 deletions xds/internal/balancer/clusterimpl/balancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@ import (
"google.golang.org/grpc/balancer/base"
"google.golang.org/grpc/balancer/roundrobin"
"google.golang.org/grpc/connectivity"
"google.golang.org/grpc/internal"
"google.golang.org/grpc/internal/balancer/stub"
"google.golang.org/grpc/internal/grpctest"
internalserviceconfig "google.golang.org/grpc/internal/serviceconfig"
"google.golang.org/grpc/internal/testutils"
"google.golang.org/grpc/internal/xds"
"google.golang.org/grpc/resolver"
xdsinternal "google.golang.org/grpc/xds/internal"
"google.golang.org/grpc/xds/internal/testutils/fakeclient"
Expand Down Expand Up @@ -464,7 +464,7 @@ func (s) TestClusterNameInAddressAttributes(t *testing.T) {
if got, want := addrs1[0].Addr, testBackendAddrs[0].Addr; got != want {
t.Fatalf("sc is created with addr %v, want %v", got, want)
}
cn, ok := internal.GetXDSHandshakeClusterName(addrs1[0].Attributes)
cn, ok := xds.GetXDSHandshakeClusterName(addrs1[0].Attributes)
if !ok || cn != testClusterName {
t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn, ok, testClusterName)
}
Expand Down Expand Up @@ -495,7 +495,7 @@ func (s) TestClusterNameInAddressAttributes(t *testing.T) {
t.Fatalf("sc is created with addr %v, want %v", got, want)
}
// New addresses should have the new cluster name.
cn2, ok := internal.GetXDSHandshakeClusterName(addrs2[0].Attributes)
cn2, ok := xds.GetXDSHandshakeClusterName(addrs2[0].Attributes)
if !ok || cn2 != testClusterName2 {
t.Fatalf("sc is created with addr with cluster name %v, %v, want cluster name %v", cn2, ok, testClusterName2)
}
Expand Down
Loading

0 comments on commit c808322

Please sign in to comment.