Skip to content

Commit

Permalink
alts: Read max number of concurrent ALTS handshakes from environment …
Browse files Browse the repository at this point in the history
…variable. (#6267)

* Read max number of concurrent ALTS handshakes from environment variable.

* Refactor to use new envconfig file.

* Remove impossible if condition in acquire().

* Use weighted semaphore.

* Add e2e test for concurrent ALTS handshakes.

* Separate into client and server semaphores.

* Use TryAcquire instead of Acquire.

* Attempt to fix go.sum error.

* Run go mod tidy compat=1.17.

* Update go.mod for examples subdirectory.

* Run go mod tidy -compat=1.17 on examples subdirectory.

* Update go.mod in subdirectories.

* Update go.mod in security/advancedtls/examples.

* Missed another go.mod update.

* Do not upgrade glog because it requires Golang 1.19.

* Fix glog version in examples/go.sum.

* More glog cleanup.

* Fix glog issue in gcp/observability/go.sum.

* Move ALTS env var into envconfig.go.

* Fix go.mod files.

* Revert go.sum files.

* Revert interop/observability/go.mod change.

* Run go mod tidy -compat=1.17 on examples/.

* Run gofmt.

* Add comment describing test init function.
  • Loading branch information
matthewstevenson88 authored Jun 8, 2023
1 parent 2ac1aae commit 907bdaa
Show file tree
Hide file tree
Showing 8 changed files with 95 additions and 73 deletions.
91 changes: 65 additions & 26 deletions credentials/alts/alts_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
"github.com/golang/protobuf/proto"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials/alts/internal/handshaker"
"google.golang.org/grpc/credentials/alts/internal/handshaker/service"
altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
Expand All @@ -51,6 +52,14 @@ type s struct {
grpctest.Tester
}

func init() {
// The vmOnGCP global variable MUST be forced to true. Otherwise, if
// this test is run anywhere except on a GCP VM, then an ALTS handshake
// will immediately fail.
once.Do(func() {})
vmOnGCP = true
}

func Test(t *testing.T) {
grpctest.RunSubTests(t, s{})
}
Expand Down Expand Up @@ -308,14 +317,6 @@ func (s) TestCheckRPCVersions(t *testing.T) {
// server, where both client and server offload to a local, fake handshaker
// service.
func (s) TestFullHandshake(t *testing.T) {
// The vmOnGCP global variable MUST be reset to true after the client
// or server credentials have been created, but before the ALTS
// handshake begins. If vmOnGCP is not reset and this test is run
// anywhere except for a GCP VM, then the ALTS handshake will
// immediately fail.
once.Do(func() {})
vmOnGCP = true

// Start the fake handshaker service and the server.
var wait sync.WaitGroup
defer wait.Wait()
Expand All @@ -325,26 +326,41 @@ func (s) TestFullHandshake(t *testing.T) {
defer stopServer()

// Ping the server, authenticating with ALTS.
clientCreds := NewClientCreds(&ClientOptions{HandshakerServiceAddress: handshakerAddress})
conn, err := grpc.Dial(serverAddress, grpc.WithTransportCredentials(clientCreds))
if err != nil {
t.Fatalf("grpc.Dial(%v) failed: %v", serverAddress, err)
establishAltsConnection(t, handshakerAddress, serverAddress)

// Close open connections to the fake handshaker service.
if err := service.CloseForTesting(); err != nil {
t.Errorf("service.CloseForTesting() failed: %v", err)
}
defer conn.Close()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestLongTimeout)
defer cancel()
c := testgrpc.NewTestServiceClient(conn)
for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) {
_, err = c.UnaryCall(ctx, &testpb.SimpleRequest{})
if err == nil {
break
}
if code := status.Code(err); code == codes.Unavailable {
// The server is not ready yet. Try again.
continue
}
t.Fatalf("c.UnaryCall() failed: %v", err)
}

// TestConcurrentHandshakes performs a several, concurrent ALTS handshakes
// between a test client and server, where both client and server offload to a
// local, fake handshaker service.
func (s) TestConcurrentHandshakes(t *testing.T) {
// Set the max number of concurrent handshakes to 3, so that we can
// test the handshaker behavior when handshakes are queued by
// performing more than 3 concurrent handshakes (specifically, 10).
handshaker.ResetConcurrentHandshakeSemaphoreForTesting(3)

// Start the fake handshaker service and the server.
var wait sync.WaitGroup
defer wait.Wait()
stopHandshaker, handshakerAddress := startFakeHandshakerService(t, &wait)
defer stopHandshaker()
stopServer, serverAddress := startServer(t, handshakerAddress, &wait)
defer stopServer()

// Ping the server, authenticating with ALTS.
var waitForConnections sync.WaitGroup
for i := 0; i < 10; i++ {
waitForConnections.Add(1)
go func() {
establishAltsConnection(t, handshakerAddress, serverAddress)
waitForConnections.Done()
}()
}
waitForConnections.Wait()

// Close open connections to the fake handshaker service.
if err := service.CloseForTesting(); err != nil {
Expand All @@ -366,6 +382,29 @@ func versions(minMajor, minMinor, maxMajor, maxMinor uint32) *altspb.RpcProtocol
}
}

func establishAltsConnection(t *testing.T, handshakerAddress, serverAddress string) {
clientCreds := NewClientCreds(&ClientOptions{HandshakerServiceAddress: handshakerAddress})
conn, err := grpc.Dial(serverAddress, grpc.WithTransportCredentials(clientCreds))
if err != nil {
t.Fatalf("grpc.Dial(%v) failed: %v", serverAddress, err)
}
defer conn.Close()
ctx, cancel := context.WithTimeout(context.Background(), defaultTestLongTimeout)
defer cancel()
c := testgrpc.NewTestServiceClient(conn)
for ; ctx.Err() == nil; <-time.After(defaultTestShortTimeout) {
_, err = c.UnaryCall(ctx, &testpb.SimpleRequest{})
if err == nil {
break
}
if code := status.Code(err); code == codes.Unavailable {
// The server is not ready yet. Try again.
continue
}
t.Fatalf("c.UnaryCall() failed: %v", err)
}
}

func startFakeHandshakerService(t *testing.T, wait *sync.WaitGroup) (stop func(), address string) {
listener, err := testutils.LocalTCPListener()
if err != nil {
Expand Down
57 changes: 16 additions & 41 deletions credentials/alts/internal/handshaker/handshaker.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,8 @@ import (
"fmt"
"io"
"net"
"sync"

"golang.org/x/sync/semaphore"
grpc "google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
Expand All @@ -35,15 +35,13 @@ import (
"google.golang.org/grpc/credentials/alts/internal/conn"
altsgrpc "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
"google.golang.org/grpc/internal/envconfig"
)

const (
// The maximum byte size of receive frames.
frameLimit = 64 * 1024 // 64 KB
rekeyRecordProtocolName = "ALTSRP_GCM_AES128_REKEY"
// maxPendingHandshakes represents the maximum number of concurrent
// handshakes.
maxPendingHandshakes = 100
)

var (
Expand All @@ -59,9 +57,9 @@ var (
return conn.NewAES128GCMRekey(s, keyData)
},
}
// control number of concurrent created (but not closed) handshakers.
mu sync.Mutex
concurrentHandshakes = int64(0)
// control number of concurrent created (but not closed) handshakes.
clientHandshakes = semaphore.NewWeighted(int64(envconfig.ALTSMaxConcurrentHandshakes))
serverHandshakes = semaphore.NewWeighted(int64(envconfig.ALTSMaxConcurrentHandshakes))
// errDropped occurs when maxPendingHandshakes is reached.
errDropped = errors.New("maximum number of concurrent ALTS handshakes is reached")
// errOutOfBound occurs when the handshake service returns a consumed
Expand All @@ -77,30 +75,6 @@ func init() {
}
}

func acquire() bool {
mu.Lock()
// If we need n to be configurable, we can pass it as an argument.
n := int64(1)
success := maxPendingHandshakes-concurrentHandshakes >= n
if success {
concurrentHandshakes += n
}
mu.Unlock()
return success
}

func release() {
mu.Lock()
// If we need n to be configurable, we can pass it as an argument.
n := int64(1)
concurrentHandshakes -= n
if concurrentHandshakes < 0 {
mu.Unlock()
panic("bad release")
}
mu.Unlock()
}

// ClientHandshakerOptions contains the client handshaker options that can
// provided by the caller.
type ClientHandshakerOptions struct {
Expand Down Expand Up @@ -134,10 +108,6 @@ func DefaultServerHandshakerOptions() *ServerHandshakerOptions {
return &ServerHandshakerOptions{}
}

// TODO: add support for future local and remote endpoint in both client options
// and server options (server options struct does not exist now. When
// caller can provide endpoints, it should be created.

// altsHandshaker is used to complete an ALTS handshake between client and
// server. This handshaker talks to the ALTS handshaker service in the metadata
// server.
Expand Down Expand Up @@ -185,10 +155,10 @@ func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn,
// ClientHandshake starts and completes a client ALTS handshake for GCP. Once
// done, ClientHandshake returns a secure connection.
func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if !acquire() {
if !clientHandshakes.TryAcquire(1) {
return nil, nil, errDropped
}
defer release()
defer clientHandshakes.Release(1)

if h.side != core.ClientSide {
return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client handshaker")
Expand Down Expand Up @@ -238,10 +208,10 @@ func (h *altsHandshaker) ClientHandshake(ctx context.Context) (net.Conn, credent
// ServerHandshake starts and completes a server ALTS handshake for GCP. Once
// done, ServerHandshake returns a secure connection.
func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error) {
if !acquire() {
if !serverHandshakes.TryAcquire(1) {
return nil, nil, errDropped
}
defer release()
defer serverHandshakes.Release(1)

if h.side != core.ServerSide {
return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server handshaker")
Expand All @@ -264,8 +234,6 @@ func (h *altsHandshaker) ServerHandshake(ctx context.Context) (net.Conn, credent
}

// Prepare server parameters.
// TODO: currently only ALTS parameters are provided. Might need to use
// more options in the future.
params := make(map[int32]*altspb.ServerHandshakeParameters)
params[int32(altspb.HandshakeProtocol_ALTS)] = &altspb.ServerHandshakeParameters{
RecordProtocols: recordProtocols,
Expand Down Expand Up @@ -391,3 +359,10 @@ func (h *altsHandshaker) Close() {
h.stream.CloseSend()
}
}

// ResetConcurrentHandshakeSemaphoreForTesting resets the handshake semaphores
// to allow numberOfAllowedHandshakes concurrent handshakes each.
func ResetConcurrentHandshakeSemaphoreForTesting(numberOfAllowedHandshakes int64) {
clientHandshakes = semaphore.NewWeighted(numberOfAllowedHandshakes)
serverHandshakes = semaphore.NewWeighted(numberOfAllowedHandshakes)
}
13 changes: 7 additions & 6 deletions credentials/alts/internal/handshaker/handshaker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ import (
core "google.golang.org/grpc/credentials/alts/internal"
altspb "google.golang.org/grpc/credentials/alts/internal/proto/grpc_gcp"
"google.golang.org/grpc/credentials/alts/internal/testutil"
"google.golang.org/grpc/internal/envconfig"
"google.golang.org/grpc/internal/grpctest"
)

Expand Down Expand Up @@ -134,7 +135,7 @@ func (s) TestClientHandshake(t *testing.T) {
numberOfHandshakes int
}{
{0 * time.Millisecond, 1},
{100 * time.Millisecond, 10 * maxPendingHandshakes},
{100 * time.Millisecond, 10 * int(envconfig.ALTSMaxConcurrentHandshakes)},
} {
errc := make(chan error)
stat.Reset()
Expand Down Expand Up @@ -182,8 +183,8 @@ func (s) TestClientHandshake(t *testing.T) {
}

// Ensure that there are no concurrent calls more than the limit.
if stat.MaxConcurrentCalls > maxPendingHandshakes {
t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, maxPendingHandshakes)
if stat.MaxConcurrentCalls > int(envconfig.ALTSMaxConcurrentHandshakes) {
t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, envconfig.ALTSMaxConcurrentHandshakes)
}
}
}
Expand All @@ -194,7 +195,7 @@ func (s) TestServerHandshake(t *testing.T) {
numberOfHandshakes int
}{
{0 * time.Millisecond, 1},
{100 * time.Millisecond, 10 * maxPendingHandshakes},
{100 * time.Millisecond, 10 * int(envconfig.ALTSMaxConcurrentHandshakes)},
} {
errc := make(chan error)
stat.Reset()
Expand Down Expand Up @@ -239,8 +240,8 @@ func (s) TestServerHandshake(t *testing.T) {
}

// Ensure that there are no concurrent calls more than the limit.
if stat.MaxConcurrentCalls > maxPendingHandshakes {
t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, maxPendingHandshakes)
if stat.MaxConcurrentCalls > int(envconfig.ALTSMaxConcurrentHandshakes) {
t.Errorf("Observed %d concurrent handshakes; want <= %d", stat.MaxConcurrentCalls, envconfig.ALTSMaxConcurrentHandshakes)
}
}
}
Expand Down
1 change: 1 addition & 0 deletions examples/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ require (
github.com/envoyproxy/go-control-plane v0.11.1-0.20230524094728-9239064ad72f // indirect
github.com/envoyproxy/protoc-gen-validate v0.10.1 // indirect
golang.org/x/net v0.9.0 // indirect
golang.org/x/sync v0.1.0 // indirect
golang.org/x/sys v0.7.0 // indirect
golang.org/x/text v0.9.0 // indirect
google.golang.org/appengine v1.6.7 // indirect
Expand Down
1 change: 1 addition & 0 deletions examples/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -1005,6 +1005,7 @@ golang.org/x/sync v0.0.0-20220601150217-0de741cfad7f/go.mod h1:RxMgew5VJxzue5/jJ
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220819030929-7fc1605a5dde/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220929204114-8fcdb60fdcc0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.1.0 h1:wsuoTGHzEhffawBOhz5CYhcrV4IdKZbEyZjBMuTp12o=
golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
Expand Down
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ require (
github.com/google/uuid v1.3.0
golang.org/x/net v0.9.0
golang.org/x/oauth2 v0.7.0
golang.org/x/sync v0.0.0-20190423024810-112230192c58
golang.org/x/sys v0.7.0
google.golang.org/genproto/googleapis/rpc v0.0.0-20230525234030-28d5490b6b19
google.golang.org/protobuf v1.30.0
Expand Down
1 change: 1 addition & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ golang.org/x/oauth2 v0.7.0 h1:qe6s0zUXlPX80/dITx3440hWZ7GwMwgDDyrSGTPJG/g=
golang.org/x/oauth2 v0.7.0/go.mod h1:hPLQkd9LyjfXTiRohC/41GhcFqxisoUQ99sCUOHO9x4=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20190423024810-112230192c58 h1:8gQV6CLnAEikrhgkHFbMAEhagSSnXWGV915qUMm9mrU=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
Expand Down
3 changes: 3 additions & 0 deletions internal/envconfig/envconfig.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ var (
// pick_first LB policy, which can be enabled by setting the environment
// variable "GRPC_EXPERIMENTAL_PICKFIRST_LB_CONFIG" to "true".
PickFirstLBConfig = boolFromEnv("GRPC_EXPERIMENTAL_PICKFIRST_LB_CONFIG", false)
// ALTSMaxConcurrentHandshakes is the maximum number of concurrent ALTS
// handshakes that can be performed.
ALTSMaxConcurrentHandshakes = uint64FromEnv("GRPC_ALTS_MAX_CONCURRENT_HANDSHAKES", 100, 1, 100)
)

func boolFromEnv(envVar string, def bool) bool {
Expand Down

0 comments on commit 907bdaa

Please sign in to comment.