Skip to content

Commit

Permalink
grpc: Add perTargetDialOption type and global list (#7234)
Browse files Browse the repository at this point in the history
  • Loading branch information
zasweq authored May 21, 2024
1 parent 2d2f417 commit aea78bd
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 37 deletions.
58 changes: 26 additions & 32 deletions clientconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,16 @@ func NewClient(target string, opts ...DialOption) (conn *ClientConn, err error)
for _, opt := range opts {
opt.apply(&cc.dopts)
}

// Determine the resolver to use.
if err := cc.initParsedTargetAndResolverBuilder(); err != nil {
return nil, err
}

for _, opt := range globalPerTargetDialOptions {
opt.DialOptionForTarget(cc.parsedTarget.URL).apply(&cc.dopts)
}

chainUnaryClientInterceptors(cc)
chainStreamClientInterceptors(cc)

Expand All @@ -168,25 +178,16 @@ func NewClient(target string, opts ...DialOption) (conn *ClientConn, err error)
}
cc.mkp = cc.dopts.copts.KeepaliveParams

// Register ClientConn with channelz.
cc.channelzRegistration(target)

// TODO: Ideally it should be impossible to error from this function after
// channelz registration. This will require removing some channelz logs
// from the following functions that can error. Errors can be returned to
// the user, and successful logs can be emitted here, after the checks have
// passed and channelz is subsequently registered.

// Determine the resolver to use.
if err := cc.parseTargetAndFindResolver(); err != nil {
channelz.RemoveEntry(cc.channelz.ID)
return nil, err
}
if err = cc.determineAuthority(); err != nil {
channelz.RemoveEntry(cc.channelz.ID)
if err = cc.initAuthority(); err != nil {
return nil, err
}

// Register ClientConn with channelz. Note that this is only done after
// channel creation cannot fail.
cc.channelzRegistration(target)
channelz.Infof(logger, cc.channelz, "parsed dial target is: %#v", cc.parsedTarget)
channelz.Infof(logger, cc.channelz, "Channel authority set to %q", cc.authority)

cc.csMgr = newConnectivityStateManager(cc.ctx, cc.channelz)
cc.pickerWrapper = newPickerWrapper(cc.dopts.copts.StatsHandlers)

Expand Down Expand Up @@ -587,11 +588,11 @@ type ClientConn struct {

// The following are initialized at dial time, and are read-only after that.
target string // User's dial target.
parsedTarget resolver.Target // See parseTargetAndFindResolver().
authority string // See determineAuthority().
parsedTarget resolver.Target // See initParsedTargetAndResolverBuilder().
authority string // See initAuthority().
dopts dialOptions // Default and user specified dial options.
channelz *channelz.Channel // Channelz object.
resolverBuilder resolver.Builder // See parseTargetAndFindResolver().
resolverBuilder resolver.Builder // See initParsedTargetAndResolverBuilder().
idlenessMgr *idle.Manager

// The following provide their own synchronization, and therefore don't
Expand Down Expand Up @@ -1673,22 +1674,19 @@ func (cc *ClientConn) connectionError() error {
return cc.lastConnectionError
}

// parseTargetAndFindResolver parses the user's dial target and stores the
// parsed target in `cc.parsedTarget`.
// initParsedTargetAndResolverBuilder parses the user's dial target and stores
// the parsed target in `cc.parsedTarget`.
//
// The resolver to use is determined based on the scheme in the parsed target
// and the same is stored in `cc.resolverBuilder`.
//
// Doesn't grab cc.mu as this method is expected to be called only at Dial time.
func (cc *ClientConn) parseTargetAndFindResolver() error {
channelz.Infof(logger, cc.channelz, "original dial target is: %q", cc.target)
func (cc *ClientConn) initParsedTargetAndResolverBuilder() error {
logger.Infof("original dial target is: %q", cc.target)

var rb resolver.Builder
parsedTarget, err := parseTarget(cc.target)
if err != nil {
channelz.Infof(logger, cc.channelz, "dial target %q parse failed: %v", cc.target, err)
} else {
channelz.Infof(logger, cc.channelz, "parsed dial target is: %#v", parsedTarget)
if err == nil {
rb = cc.getResolver(parsedTarget.URL.Scheme)
if rb != nil {
cc.parsedTarget = parsedTarget
Expand All @@ -1707,15 +1705,12 @@ func (cc *ClientConn) parseTargetAndFindResolver() error {
defScheme = resolver.GetDefaultScheme()
}

channelz.Infof(logger, cc.channelz, "fallback to scheme %q", defScheme)
canonicalTarget := defScheme + ":///" + cc.target

parsedTarget, err = parseTarget(canonicalTarget)
if err != nil {
channelz.Infof(logger, cc.channelz, "dial target %q parse failed: %v", canonicalTarget, err)
return err
}
channelz.Infof(logger, cc.channelz, "parsed dial target is: %+v", parsedTarget)
rb = cc.getResolver(parsedTarget.URL.Scheme)
if rb == nil {
return fmt.Errorf("could not get resolver for default scheme: %q", parsedTarget.URL.Scheme)
Expand Down Expand Up @@ -1805,7 +1800,7 @@ func encodeAuthority(authority string) string {
// credentials do not match the authority configured through the dial option.
//
// Doesn't grab cc.mu as this method is expected to be called only at Dial time.
func (cc *ClientConn) determineAuthority() error {
func (cc *ClientConn) initAuthority() error {
dopts := cc.dopts
// Historically, we had two options for users to specify the serverName or
// authority for a channel. One was through the transport credentials
Expand Down Expand Up @@ -1838,6 +1833,5 @@ func (cc *ClientConn) determineAuthority() error {
} else {
cc.authority = encodeAuthority(endpoint)
}
channelz.Infof(logger, cc.channelz, "Channel authority set to %q", cc.authority)
return nil
}
33 changes: 32 additions & 1 deletion default_dial_option_server_option_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ package grpc

import (
"fmt"
"net/url"
"strings"
"testing"

Expand All @@ -40,6 +41,7 @@ func (s) TestAddGlobalDialOptions(t *testing.T) {
// Set and check the DialOptions
opts := []DialOption{WithTransportCredentials(insecure.NewCredentials()), WithTransportCredentials(insecure.NewCredentials()), WithTransportCredentials(insecure.NewCredentials())}
internal.AddGlobalDialOptions.(func(opt ...DialOption))(opts...)
defer internal.ClearGlobalDialOptions()
for i, opt := range opts {
if globalDialOptions[i] != opt {
t.Fatalf("Unexpected global dial option at index %d: %v != %v", i, globalDialOptions[i], opt)
Expand All @@ -64,21 +66,50 @@ func (s) TestAddGlobalDialOptions(t *testing.T) {
func (s) TestDisableGlobalOptions(t *testing.T) {
// Set transport credentials as a global option.
internal.AddGlobalDialOptions.(func(opt ...DialOption))(WithTransportCredentials(insecure.NewCredentials()))
defer internal.ClearGlobalDialOptions()
// Dial with the disable global options dial option. This dial should fail
// due to the global dial options with credentials not being picked up due
// to global options being disabled.
noTSecStr := "no transport security set"
if _, err := Dial("fake", internal.DisableGlobalDialOptions.(func() DialOption)()); !strings.Contains(fmt.Sprint(err), noTSecStr) {
t.Fatalf("Dialing received unexpected error: %v, want error containing \"%v\"", err, noTSecStr)
}
internal.ClearGlobalDialOptions()
}

type testPerTargetDialOption struct{}

func (do *testPerTargetDialOption) DialOptionForTarget(parsedTarget url.URL) DialOption {
if parsedTarget.Scheme == "passthrough" {
return WithTransportCredentials(insecure.NewCredentials()) // credentials provided, should pass NewClient.
}
return EmptyDialOption{} // no credentials, should fail NewClient
}

// TestGlobalPerTargetDialOption configures a global per target dial option that
// produces transport credentials for channels using "passthrough" scheme.
// Channels that use the passthrough scheme should be successfully created due
// to picking up transport credentials, whereas other channels should fail at
// creation due to not having transport credentials.
func (s) TestGlobalPerTargetDialOption(t *testing.T) {
internal.AddGlobalPerTargetDialOptions.(func(opt any))(&testPerTargetDialOption{})
defer internal.ClearGlobalPerTargetDialOptions()
noTSecStr := "no transport security set"
if _, err := NewClient("dns:///fake"); !strings.Contains(fmt.Sprint(err), noTSecStr) {
t.Fatalf("Dialing received unexpected error: %v, want error containing \"%v\"", err, noTSecStr)
}
cc, err := NewClient("passthrough:///nice")
if err != nil {
t.Fatalf("Dialing with insecure credentials failed: %v", err)
}
cc.Close()
}

func (s) TestAddGlobalServerOptions(t *testing.T) {
const maxRecvSize = 998765
// Set and check the ServerOptions
opts := []ServerOption{Creds(insecure.NewCredentials()), MaxRecvMsgSize(maxRecvSize)}
internal.AddGlobalServerOptions.(func(opt ...ServerOption))(opts...)
defer internal.ClearGlobalServerOptions()
for i, opt := range opts {
if globalServerOptions[i] != opt {
t.Fatalf("Unexpected global server option at index %d: %v != %v", i, globalServerOptions[i], opt)
Expand Down
22 changes: 22 additions & 0 deletions dialoptions.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package grpc
import (
"context"
"net"
"net/url"
"time"

"google.golang.org/grpc/backoff"
Expand All @@ -43,6 +44,14 @@ func init() {
internal.ClearGlobalDialOptions = func() {
globalDialOptions = nil
}
internal.AddGlobalPerTargetDialOptions = func(opt any) {
if ptdo, ok := opt.(perTargetDialOption); ok {
globalPerTargetDialOptions = append(globalPerTargetDialOptions, ptdo)
}
}
internal.ClearGlobalPerTargetDialOptions = func() {
globalPerTargetDialOptions = nil
}
internal.WithBinaryLogger = withBinaryLogger
internal.JoinDialOptions = newJoinDialOption
internal.DisableGlobalDialOptions = newDisableGlobalDialOptions
Expand Down Expand Up @@ -89,6 +98,19 @@ type DialOption interface {

var globalDialOptions []DialOption

// perTargetDialOption takes a parsed target and returns a dial option to apply.
//
// This gets called after NewClient() parses the target, and allows per target
// configuration set through a returned DialOption. The DialOption will not take
// effect if specifies a resolver builder, as that Dial Option is factored in
// while parsing target.
type perTargetDialOption interface {
// DialOption returns a Dial Option to apply.
DialOptionForTarget(parsedTarget url.URL) DialOption
}

var globalPerTargetDialOptions []perTargetDialOption

// EmptyDialOption does not alter the dial configuration. It can be embedded in
// another structure to build custom dial options.
//
Expand Down
20 changes: 16 additions & 4 deletions internal/internal.go
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,14 @@ var (
// This is used in the 1.0 release of gcp/observability, and thus must not be
// deleted or changed.
ClearGlobalDialOptions func()

// AddGlobalPerTargetDialOptions adds a PerTargetDialOption that will be
// configured for newly created ClientConns.
AddGlobalPerTargetDialOptions any // func (opt any)
// ClearGlobalPerTargetDialOptions clears the slice of global late apply
// dial options.
ClearGlobalPerTargetDialOptions func()

// JoinDialOptions combines the dial options passed as arguments into a
// single dial option.
JoinDialOptions any // func(...grpc.DialOption) grpc.DialOption
Expand All @@ -126,7 +134,8 @@ var (
// deleted or changed.
BinaryLogger any // func(binarylog.Logger) grpc.ServerOption

// SubscribeToConnectivityStateChanges adds a grpcsync.Subscriber to a provided grpc.ClientConn
// SubscribeToConnectivityStateChanges adds a grpcsync.Subscriber to a
// provided grpc.ClientConn.
SubscribeToConnectivityStateChanges any // func(*grpc.ClientConn, grpcsync.Subscriber)

// NewXDSResolverWithConfigForTesting creates a new xds resolver builder using
Expand Down Expand Up @@ -195,14 +204,17 @@ var (
// resource name.
TriggerXDSResourceNameNotFoundClient any // func(string, string) error

// FromOutgoingContextRaw returns the un-merged, intermediary contents of metadata.rawMD.
// FromOutgoingContextRaw returns the un-merged, intermediary contents of
// metadata.rawMD.
FromOutgoingContextRaw any // func(context.Context) (metadata.MD, [][]string, bool)

// UserSetDefaultScheme is set to true if the user has overridden the default resolver scheme.
// UserSetDefaultScheme is set to true if the user has overridden the
// default resolver scheme.
UserSetDefaultScheme bool = false
)

// HealthChecker defines the signature of the client-side LB channel health checking function.
// HealthChecker defines the signature of the client-side LB channel health
// checking function.
//
// The implementation is expected to create a health checking RPC stream by
// calling newStream(), watch for the health status of serviceName, and report
Expand Down

0 comments on commit aea78bd

Please sign in to comment.