diff --git a/balancer/rls/internal/balancer.go b/balancer/rls/internal/balancer.go new file mode 100644 index 000000000000..7c4a4817466a --- /dev/null +++ b/balancer/rls/internal/balancer.go @@ -0,0 +1,211 @@ +// +build go1.10 + +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package rls + +import ( + "sync" + + "google.golang.org/grpc" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/grpcsync" + "google.golang.org/grpc/resolver" +) + +var ( + _ balancer.Balancer = (*rlsBalancer)(nil) + _ balancer.V2Balancer = (*rlsBalancer)(nil) + + // For overriding in tests. + newRLSClientFunc = newRLSClient +) + +// rlsBalancer implements the RLS LB policy. +type rlsBalancer struct { + done *grpcsync.Event + cc balancer.ClientConn + opts balancer.BuildOptions + + // Mutex protects all the state maintained by the LB policy. + // TODO(easwars): Once we add the cache, we will also have another lock for + // the cache alone. + mu sync.Mutex + lbCfg *lbConfig // Most recently received service config. + rlsCC *grpc.ClientConn // ClientConn to the RLS server. + rlsC *rlsClient // RLS client wrapper. + + ccUpdateCh chan *balancer.ClientConnState +} + +// run is a long running goroutine which handles all the updates that the +// balancer wishes to handle. The appropriate updateHandler will push the update +// on to a channel that this goroutine will select on, thereby the handling of +// the update will happen asynchronously. +func (lb *rlsBalancer) run() { + for { + // TODO(easwars): Handle other updates like subConn state changes, RLS + // responses from the server etc. + select { + case u := <-lb.ccUpdateCh: + lb.handleClientConnUpdate(u) + case <-lb.done.Done(): + return + } + } +} + +// handleClientConnUpdate handles updates to the service config. +// If the RLS server name or the RLS RPC timeout changes, it updates the control +// channel accordingly. +// TODO(easwars): Handle updates to other fields in the service config. +func (lb *rlsBalancer) handleClientConnUpdate(ccs *balancer.ClientConnState) { + grpclog.Infof("rls: service config: %+v", ccs.BalancerConfig) + lb.mu.Lock() + defer lb.mu.Unlock() + + if lb.done.HasFired() { + grpclog.Warning("rls: received service config after balancer close") + return + } + + newCfg := ccs.BalancerConfig.(*lbConfig) + if lb.lbCfg.Equal(newCfg) { + grpclog.Info("rls: new service config matches existing config") + return + } + + lb.updateControlChannel(newCfg) + lb.lbCfg = newCfg +} + +// UpdateClientConnState pushes the received ClientConnState update on the +// update channel which will be processed asynchronously by the run goroutine. +// Implements balancer.V2Balancer interface. +func (lb *rlsBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error { + select { + case lb.ccUpdateCh <- &ccs: + case <-lb.done.Done(): + } + return nil +} + +// ResolverErr implements balancer.V2Balancer interface. +func (lb *rlsBalancer) ResolverError(error) { + // ResolverError is called by gRPC when the name resolver reports an error. + // TODO(easwars): How do we handle this? + grpclog.Fatal("rls: ResolverError is not yet unimplemented") +} + +// UpdateSubConnState implements balancer.V2Balancer interface. +func (lb *rlsBalancer) UpdateSubConnState(_ balancer.SubConn, _ balancer.SubConnState) { + grpclog.Fatal("rls: UpdateSubConnState is not yet implemented") +} + +// Cleans up the resources allocated by the LB policy including the clientConn +// to the RLS server. +// Implements balancer.Balancer and balancer.V2Balancer interfaces. +func (lb *rlsBalancer) Close() { + lb.mu.Lock() + defer lb.mu.Unlock() + + lb.done.Fire() + if lb.rlsCC != nil { + lb.rlsCC.Close() + } +} + +// HandleSubConnStateChange implements balancer.Balancer interface. +func (lb *rlsBalancer) HandleSubConnStateChange(_ balancer.SubConn, _ connectivity.State) { + grpclog.Fatal("UpdateSubConnState should be called instead of HandleSubConnStateChange") +} + +// HandleResolvedAddrs implements balancer.Balancer interface. +func (lb *rlsBalancer) HandleResolvedAddrs(_ []resolver.Address, _ error) { + grpclog.Fatal("UpdateClientConnState should be called instead of HandleResolvedAddrs") +} + +// updateControlChannel updates the RLS client if required. +// Caller must hold lb.mu. +func (lb *rlsBalancer) updateControlChannel(newCfg *lbConfig) { + oldCfg := lb.lbCfg + if newCfg.lookupService == oldCfg.lookupService && newCfg.lookupServiceTimeout == oldCfg.lookupServiceTimeout { + return + } + + // Use RPC timeout from new config, if different from existing one. + timeout := oldCfg.lookupServiceTimeout + if timeout != newCfg.lookupServiceTimeout { + timeout = newCfg.lookupServiceTimeout + } + + if newCfg.lookupService == oldCfg.lookupService { + // This is the case where only the timeout has changed. We will continue + // to use the existing clientConn. but will create a new rlsClient with + // the new timeout. + lb.rlsC = newRLSClientFunc(lb.rlsCC, lb.opts.Target.Endpoint, timeout) + return + } + + // This is the case where the RLS server name has changed. We need to create + // a new clientConn and close the old one. + var dopts []grpc.DialOption + if dialer := lb.opts.Dialer; dialer != nil { + dopts = append(dopts, grpc.WithContextDialer(dialer)) + } + dopts = append(dopts, dialCreds(lb.opts)) + + cc, err := grpc.Dial(newCfg.lookupService, dopts...) + if err != nil { + grpclog.Errorf("rls: dialRLS(%s, %v): %v", newCfg.lookupService, lb.opts, err) + // An error from a non-blocking dial indicates something serious. We + // should continue to use the old control channel if one exists, and + // return so that the rest of the config updates can be processes. + return + } + if lb.rlsCC != nil { + lb.rlsCC.Close() + } + lb.rlsCC = cc + lb.rlsC = newRLSClientFunc(cc, lb.opts.Target.Endpoint, timeout) +} + +func dialCreds(opts balancer.BuildOptions) grpc.DialOption { + // The control channel should use the same authority as that of the parent + // channel. This ensures that the identify of the RLS server and that of the + // backend is the same, so if the RLS config is injected by an attacker, it + // cannot cause leakage of private information contained in headers set by + // the application. + server := opts.Target.Authority + switch { + case opts.DialCreds != nil: + if err := opts.DialCreds.OverrideServerName(server); err != nil { + grpclog.Warningf("rls: OverrideServerName(%s) = (%v), using Insecure", server, err) + return grpc.WithInsecure() + } + return grpc.WithTransportCredentials(opts.DialCreds) + case opts.CredsBundle != nil: + return grpc.WithTransportCredentials(opts.CredsBundle.TransportCredentials()) + default: + grpclog.Warning("rls: no credentials available, using Insecure") + return grpc.WithInsecure() + } +} diff --git a/balancer/rls/internal/balancer_test.go b/balancer/rls/internal/balancer_test.go new file mode 100644 index 000000000000..990372d0e98f --- /dev/null +++ b/balancer/rls/internal/balancer_test.go @@ -0,0 +1,228 @@ +// +build go1.10 + +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package rls + +import ( + "net" + "testing" + "time" + + "google.golang.org/grpc" + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/balancer/rls/internal/testutils/fakeserver" + "google.golang.org/grpc/credentials" + "google.golang.org/grpc/internal/grpctest" + "google.golang.org/grpc/internal/testutils" + "google.golang.org/grpc/testdata" +) + +type s struct { + grpctest.Tester +} + +func Test(t *testing.T) { + grpctest.RunSubTests(t, s{}) +} + +type listenerWrapper struct { + net.Listener + connCh *testutils.Channel +} + +// Accept waits for and returns the next connection to the listener. +func (l *listenerWrapper) Accept() (net.Conn, error) { + c, err := l.Listener.Accept() + if err != nil { + return nil, err + } + l.connCh.Send(c) + return c, nil +} + +func setupwithListener(t *testing.T, opts ...grpc.ServerOption) (*fakeserver.Server, *listenerWrapper, func()) { + t.Helper() + + l, err := net.Listen("tcp", "localhost:0") + if err != nil { + t.Fatalf("net.Listen(tcp, localhost:0): %v", err) + } + lw := &listenerWrapper{ + Listener: l, + connCh: testutils.NewChannel(), + } + + server, cleanup, err := fakeserver.Start(lw, opts...) + if err != nil { + t.Fatalf("fakeserver.Start(): %v", err) + } + t.Logf("Fake RLS server started at %s ...", server.Address) + + return server, lw, cleanup +} + +type testBalancerCC struct { + balancer.ClientConn +} + +// TestUpdateControlChannelFirstConfig tests the scenario where the LB policy +// receives its first service config and verifies that a control channel to the +// RLS server specified in the serviceConfig is established. +func (s) TestUpdateControlChannelFirstConfig(t *testing.T) { + server, lis, cleanup := setupwithListener(t) + defer cleanup() + + bb := balancer.Get(rlsBalancerName) + if bb == nil { + t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName) + } + rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{}).(balancer.V2Balancer) + defer rlsB.Close() + t.Log("Built RLS LB policy ...") + + lbCfg := &lbConfig{lookupService: server.Address} + t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg) + rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg}) + + if _, err := lis.connCh.Receive(); err != nil { + t.Fatal("Timeout expired when waiting for LB policy to create control channel") + } + + // TODO: Verify channel connectivity state once control channel connectivity + // state monitoring is in place. + + // TODO: Verify RLS RPC can be made once we integrate with the picker. +} + +// TestUpdateControlChannelSwitch tests the scenario where a control channel +// exists and the LB policy receives a new serviceConfig with a different RLS +// server name. Verifies that the new control channel is created and the old one +// is closed (the leakchecker takes care of this). +func (s) TestUpdateControlChannelSwitch(t *testing.T) { + server1, lis1, cleanup1 := setupwithListener(t) + defer cleanup1() + + server2, lis2, cleanup2 := setupwithListener(t) + defer cleanup2() + + bb := balancer.Get(rlsBalancerName) + if bb == nil { + t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName) + } + rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{}).(balancer.V2Balancer) + defer rlsB.Close() + t.Log("Built RLS LB policy ...") + + lbCfg := &lbConfig{lookupService: server1.Address} + t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg) + rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg}) + + if _, err := lis1.connCh.Receive(); err != nil { + t.Fatal("Timeout expired when waiting for LB policy to create control channel") + } + + lbCfg = &lbConfig{lookupService: server2.Address} + t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg) + rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg}) + + if _, err := lis2.connCh.Receive(); err != nil { + t.Fatal("Timeout expired when waiting for LB policy to create control channel") + } + + // TODO: Verify channel connectivity state once control channel connectivity + // state monitoring is in place. + + // TODO: Verify RLS RPC can be made once we integrate with the picker. +} + +// TestUpdateControlChannelTimeout tests the scenario where the LB policy +// receives a service config update with a different lookupServiceTimeout, but +// the lookupService itself remains unchanged. It verifies that the LB policy +// does not create a new control channel in this case. +func (s) TestUpdateControlChannelTimeout(t *testing.T) { + server, lis, cleanup := setupwithListener(t) + defer cleanup() + + bb := balancer.Get(rlsBalancerName) + if bb == nil { + t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName) + } + rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{}).(balancer.V2Balancer) + defer rlsB.Close() + t.Log("Built RLS LB policy ...") + + lbCfg := &lbConfig{lookupService: server.Address, lookupServiceTimeout: 1 * time.Second} + t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg) + rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg}) + if _, err := lis.connCh.Receive(); err != nil { + t.Fatal("Timeout expired when waiting for LB policy to create control channel") + } + + lbCfg = &lbConfig{lookupService: server.Address, lookupServiceTimeout: 2 * time.Second} + t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg) + rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg}) + if _, err := lis.connCh.Receive(); err != testutils.ErrRecvTimeout { + t.Fatal("LB policy created new control channel when only lookupServiceTimeout changed") + } + + // TODO: Verify channel connectivity state once control channel connectivity + // state monitoring is in place. + + // TODO: Verify RLS RPC can be made once we integrate with the picker. +} + +// TestUpdateControlChannelWithCreds tests the scenario where the control +// channel is to established with credentials from the parent channel. +func (s) TestUpdateControlChannelWithCreds(t *testing.T) { + sCreds, err := credentials.NewServerTLSFromFile(testdata.Path("server1.pem"), testdata.Path("server1.key")) + if err != nil { + t.Fatalf("credentials.NewServerTLSFromFile(server1.pem, server1.key) = %v", err) + } + cCreds, err := credentials.NewClientTLSFromFile(testdata.Path("ca.pem"), "") + if err != nil { + t.Fatalf("credentials.NewClientTLSFromFile(ca.pem) = %v", err) + } + + server, lis, cleanup := setupwithListener(t, grpc.Creds(sCreds)) + defer cleanup() + + bb := balancer.Get(rlsBalancerName) + if bb == nil { + t.Fatalf("balancer.Get(%s) = nil", rlsBalancerName) + } + rlsB := bb.Build(&testBalancerCC{}, balancer.BuildOptions{ + DialCreds: cCreds, + }).(balancer.V2Balancer) + defer rlsB.Close() + t.Log("Built RLS LB policy ...") + + lbCfg := &lbConfig{lookupService: server.Address} + t.Logf("Sending service config %+v to RLS LB policy ...", lbCfg) + rlsB.UpdateClientConnState(balancer.ClientConnState{BalancerConfig: lbCfg}) + + if _, err := lis.connCh.Receive(); err != nil { + t.Fatal("Timeout expired when waiting for LB policy to create control channel") + } + + // TODO: Verify channel connectivity state once control channel connectivity + // state monitoring is in place. + + // TODO: Verify RLS RPC can be made once we integrate with the picker. +} diff --git a/balancer/rls/internal/builder.go b/balancer/rls/internal/builder.go index ddb5e3cf0bb2..c38babff4d3d 100644 --- a/balancer/rls/internal/builder.go +++ b/balancer/rls/internal/builder.go @@ -21,16 +21,35 @@ // Package rls implements the RLS LB policy. package rls +import ( + "google.golang.org/grpc/balancer" + "google.golang.org/grpc/internal/grpcsync" +) + const rlsBalancerName = "rls" +func init() { + balancer.Register(&rlsBB{}) +} + // rlsBB helps build RLS load balancers and parse the service config to be // passed to the RLS load balancer. -type rlsBB struct { - // TODO(easwars): Implement the Build() method and register the builder. -} +type rlsBB struct{} // Name returns the name of the RLS LB policy and helps implement the // balancer.Balancer interface. func (*rlsBB) Name() string { return rlsBalancerName } + +func (*rlsBB) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { + lb := &rlsBalancer{ + done: grpcsync.NewEvent(), + cc: cc, + opts: opts, + lbCfg: &lbConfig{}, + ccUpdateCh: make(chan *balancer.ClientConnState), + } + go lb.run() + return lb +} diff --git a/balancer/rls/internal/client.go b/balancer/rls/internal/client.go index b6fe22572949..0e8a1c932f11 100644 --- a/balancer/rls/internal/client.go +++ b/balancer/rls/internal/client.go @@ -43,7 +43,6 @@ const grpcTargetType = "grpc" // throttling and asks this client to make an RPC call only after checking with // the throttler. type rlsClient struct { - cc *grpc.ClientConn stub rlspb.RouteLookupServiceClient // origDialTarget is the original dial target of the user and sent in each // RouteLookup RPC made to the RLS server. @@ -55,7 +54,6 @@ type rlsClient struct { func newRLSClient(cc *grpc.ClientConn, dialTarget string, rpcTimeout time.Duration) *rlsClient { return &rlsClient{ - cc: cc, stub: rlspb.NewRouteLookupServiceClient(cc), origDialTarget: dialTarget, rpcTimeout: rpcTimeout, diff --git a/balancer/rls/internal/client_test.go b/balancer/rls/internal/client_test.go index 386267b9033d..1a1a75d1be98 100644 --- a/balancer/rls/internal/client_test.go +++ b/balancer/rls/internal/client_test.go @@ -1,3 +1,5 @@ +// +build go1.10 + /* * * Copyright 2020 gRPC authors. @@ -30,25 +32,26 @@ import ( rlspb "google.golang.org/grpc/balancer/rls/internal/proto/grpc_lookup_v1" "google.golang.org/grpc/balancer/rls/internal/testutils/fakeserver" "google.golang.org/grpc/codes" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/status" ) const ( - defaultDialTarget = "dummy" - defaultRPCTimeout = 5 * time.Second - defaultTestTimeout = 1 * time.Second + defaultDialTarget = "dummy" + defaultRPCTimeout = 5 * time.Second ) func setup(t *testing.T) (*fakeserver.Server, *grpc.ClientConn, func()) { t.Helper() - server, sCleanup, err := fakeserver.Start() + server, sCleanup, err := fakeserver.Start(nil) if err != nil { t.Fatalf("Failed to start fake RLS server: %v", err) } cc, cCleanup, err := server.ClientConn() if err != nil { + sCleanup() t.Fatalf("Failed to get a ClientConn to the RLS server: %v", err) } @@ -59,7 +62,7 @@ func setup(t *testing.T) (*fakeserver.Server, *grpc.ClientConn, func()) { } // TestLookupFailure verifies the case where the RLS server returns an error. -func TestLookupFailure(t *testing.T) { +func (s) TestLookupFailure(t *testing.T) { server, cc, cleanup := setup(t) defer cleanup() @@ -68,64 +71,50 @@ func TestLookupFailure(t *testing.T) { rlsClient := newRLSClient(cc, defaultDialTarget, defaultRPCTimeout) - errCh := make(chan error) + errCh := testutils.NewChannel() rlsClient.lookup("", nil, func(targets []string, headerData string, err error) { if err == nil { - errCh <- errors.New("rlsClient.lookup() succeeded, should have failed") + errCh.Send(errors.New("rlsClient.lookup() succeeded, should have failed")) return } if len(targets) != 0 || headerData != "" { - errCh <- fmt.Errorf("rlsClient.lookup() = (%v, %s), want (nil, \"\")", targets, headerData) + errCh.Send(fmt.Errorf("rlsClient.lookup() = (%v, %s), want (nil, \"\")", targets, headerData)) return } - errCh <- nil + errCh.Send(nil) }) - timer := time.NewTimer(defaultTestTimeout) - select { - case <-timer.C: - t.Fatal("Timeout when expecting a routeLookup callback") - case err := <-errCh: - timer.Stop() - if err != nil { - t.Fatal(err) - } + if e, err := errCh.Receive(); err != nil || e != nil { + t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err) } } // TestLookupDeadlineExceeded tests the case where the RPC deadline associated // with the lookup expires. -func TestLookupDeadlineExceeded(t *testing.T) { +func (s) TestLookupDeadlineExceeded(t *testing.T) { _, cc, cleanup := setup(t) defer cleanup() // Give the Lookup RPC a small deadline, but don't setup the fake server to - // return anything. So the Lookup call will block and eventuall expire. + // return anything. So the Lookup call will block and eventually expire. rlsClient := newRLSClient(cc, defaultDialTarget, 100*time.Millisecond) - errCh := make(chan error) + errCh := testutils.NewChannel() rlsClient.lookup("", nil, func(_ []string, _ string, err error) { if st, ok := status.FromError(err); !ok || st.Code() != codes.DeadlineExceeded { - errCh <- fmt.Errorf("rlsClient.lookup() returned error: %v, want %v", err, codes.DeadlineExceeded) + errCh.Send(fmt.Errorf("rlsClient.lookup() returned error: %v, want %v", err, codes.DeadlineExceeded)) return } - errCh <- nil + errCh.Send(nil) }) - timer := time.NewTimer(defaultTestTimeout) - select { - case <-timer.C: - t.Fatal("Timeout when expecting a routeLookup callback") - case err := <-errCh: - timer.Stop() - if err != nil { - t.Fatal(err) - } + if e, err := errCh.Receive(); err != nil || e != nil { + t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err) } } // TestLookupSuccess verifies the successful Lookup API case. -func TestLookupSuccess(t *testing.T) { +func (s) TestLookupSuccess(t *testing.T) { server, cc, cleanup := setup(t) defer cleanup() @@ -148,33 +137,29 @@ func TestLookupSuccess(t *testing.T) { rlsClient := newRLSClient(cc, defaultDialTarget, defaultRPCTimeout) - errCh := make(chan error) + errCh := testutils.NewChannel() rlsClient.lookup(rlsReqPath, rlsReqKeyMap, func(targets []string, hd string, err error) { if err != nil { - errCh <- fmt.Errorf("rlsClient.Lookup() failed: %v", err) + errCh.Send(fmt.Errorf("rlsClient.Lookup() failed: %v", err)) return } if !cmp.Equal(targets, wantRespTargets) || hd != wantHeaderData { - errCh <- fmt.Errorf("rlsClient.lookup() = (%v, %s), want (%v, %s)", targets, hd, wantRespTargets, wantHeaderData) + errCh.Send(fmt.Errorf("rlsClient.lookup() = (%v, %s), want (%v, %s)", targets, hd, wantRespTargets, wantHeaderData)) return } - errCh <- nil + errCh.Send(nil) }) // Make sure that the fake server received the expected RouteLookupRequest // proto. - timer := time.NewTimer(defaultTestTimeout) - select { - case gotLookupRequest := <-server.RequestChan: - if !timer.Stop() { - <-timer.C - } - if diff := cmp.Diff(wantLookupRequest, gotLookupRequest, cmp.Comparer(proto.Equal)); diff != "" { - t.Fatalf("RouteLookupRequest diff (-want, +got):\n%s", diff) - } - case <-timer.C: + req, err := server.RequestChan.Receive() + if err != nil { t.Fatalf("Timed out wile waiting for a RouteLookupRequest") } + gotLookupRequest := req.(*rlspb.RouteLookupRequest) + if diff := cmp.Diff(wantLookupRequest, gotLookupRequest, cmp.Comparer(proto.Equal)); diff != "" { + t.Fatalf("RouteLookupRequest diff (-want, +got):\n%s", diff) + } // We setup the fake server to return this response when it receives a // request. @@ -185,14 +170,7 @@ func TestLookupSuccess(t *testing.T) { }, } - timer = time.NewTimer(defaultTestTimeout) - select { - case <-timer.C: - t.Fatal("Timeout when expecting a routeLookup callback") - case err := <-errCh: - timer.Stop() - if err != nil { - t.Fatal(err) - } + if e, err := errCh.Receive(); err != nil || e != nil { + t.Fatalf("lookup error: %v, error receiving from channel: %v", e, err) } } diff --git a/balancer/rls/internal/config.go b/balancer/rls/internal/config.go index e1e36e445fa6..816ab093a650 100644 --- a/balancer/rls/internal/config.go +++ b/balancer/rls/internal/config.go @@ -71,6 +71,40 @@ type lbConfig struct { cpConfig map[string]json.RawMessage } +func (lbCfg *lbConfig) Equal(other *lbConfig) bool { + return lbCfg.kbMap.Equal(other.kbMap) && + lbCfg.lookupService == other.lookupService && + lbCfg.lookupServiceTimeout == other.lookupServiceTimeout && + lbCfg.maxAge == other.maxAge && + lbCfg.staleAge == other.staleAge && + lbCfg.cacheSizeBytes == other.cacheSizeBytes && + lbCfg.rpStrategy == other.rpStrategy && + lbCfg.defaultTarget == other.defaultTarget && + lbCfg.cpName == other.cpName && + lbCfg.cpTargetField == other.cpTargetField && + cpConfigEqual(lbCfg.cpConfig, other.cpConfig) +} + +func cpConfigEqual(am, bm map[string]json.RawMessage) bool { + if (bm == nil) != (am == nil) { + return false + } + if len(bm) != len(am) { + return false + } + + for k, jsonA := range am { + jsonB, ok := bm[k] + if !ok { + return false + } + if !bytes.Equal(jsonA, jsonB) { + return false + } + } + return true +} + // This struct resembles the JSON respresentation of the loadBalancing config // and makes it easier to unmarshal. type lbConfigJSON struct { diff --git a/balancer/rls/internal/config_test.go b/balancer/rls/internal/config_test.go index 6285f0cdadb8..9200a29d8a3c 100644 --- a/balancer/rls/internal/config_test.go +++ b/balancer/rls/internal/config_test.go @@ -49,20 +49,22 @@ func init() { balancer.Register(&dummyBB{}) } -func (lbCfg *lbConfig) Equal(other *lbConfig) bool { - // This only ignores the keyBuilderMap field because its internals are not - // exported, and hence not possible to specify in the want section of the - // test. - return lbCfg.lookupService == other.lookupService && - lbCfg.lookupServiceTimeout == other.lookupServiceTimeout && - lbCfg.maxAge == other.maxAge && - lbCfg.staleAge == other.staleAge && - lbCfg.cacheSizeBytes == other.cacheSizeBytes && - lbCfg.rpStrategy == other.rpStrategy && - lbCfg.defaultTarget == other.defaultTarget && - lbCfg.cpName == other.cpName && - lbCfg.cpTargetField == other.cpTargetField && - cmp.Equal(lbCfg.cpConfig, other.cpConfig) +// testEqual reports whether the lbCfgs a and b are equal. This is to be used +// only from tests. This ignores the keyBuilderMap field because its internals +// are not exported, and hence not possible to specify in the want section of +// the test. This is fine because we already have tests to make sure that the +// keyBuilder is parsed properly from the service config. +func testEqual(a, b *lbConfig) bool { + return a.lookupService == b.lookupService && + a.lookupServiceTimeout == b.lookupServiceTimeout && + a.maxAge == b.maxAge && + a.staleAge == b.staleAge && + a.cacheSizeBytes == b.cacheSizeBytes && + a.rpStrategy == b.rpStrategy && + a.defaultTarget == b.defaultTarget && + a.cpName == b.cpName && + a.cpTargetField == b.cpTargetField && + cmp.Equal(a.cpConfig, b.cpConfig) } func TestParseConfig(t *testing.T) { @@ -152,7 +154,7 @@ func TestParseConfig(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { lbCfg, err := builder.ParseConfig(test.input) - if err != nil || !cmp.Equal(lbCfg, test.wantCfg) { + if err != nil || !testEqual(lbCfg.(*lbConfig), test.wantCfg) { t.Errorf("ParseConfig(%s) = {%+v, %v}, want {%+v, nil}", string(test.input), lbCfg, err, test.wantCfg) } }) diff --git a/balancer/rls/internal/picker_test.go b/balancer/rls/internal/picker_test.go index a58db3b332fa..b14a9fad340a 100644 --- a/balancer/rls/internal/picker_test.go +++ b/balancer/rls/internal/picker_test.go @@ -28,13 +28,13 @@ import ( "testing" "time" - "google.golang.org/grpc/internal/grpcrand" - "github.com/google/go-cmp/cmp" "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/rls/internal/cache" "google.golang.org/grpc/balancer/rls/internal/keys" rlspb "google.golang.org/grpc/balancer/rls/internal/proto/grpc_lookup_v1" + "google.golang.org/grpc/internal/grpcrand" + "google.golang.org/grpc/internal/testutils" "google.golang.org/grpc/metadata" ) @@ -502,7 +502,7 @@ func TestPick(t *testing.T) { for _, test := range tests { t.Run(test.desc, func(t *testing.T) { - rlsCh := make(chan error, 1) + rlsCh := testutils.NewChannel() randID := grpcrand.Intn(math.MaxInt32) // We instantiate a fakeChildPicker which will return a fakeSubConn // with configured id. Either the childPicker or the defaultPicker @@ -525,18 +525,18 @@ func TestPick(t *testing.T) { shouldThrottle: func() bool { return test.throttle }, startRLS: func(path string, km keys.KeyMap) { if !test.newRLSRequest { - rlsCh <- errors.New("RLS request attempted when none was expected") + rlsCh.Send(errors.New("RLS request attempted when none was expected")) return } if path != rpcPath { - rlsCh <- fmt.Errorf("RLS request initiated for rpcPath %s, want %s", path, rpcPath) + rlsCh.Send(fmt.Errorf("RLS request initiated for rpcPath %s, want %s", path, rpcPath)) return } if km.Str != wantKeyMapStr { - rlsCh <- fmt.Errorf("RLS request initiated with keys %v, want %v", km.Str, wantKeyMapStr) + rlsCh.Send(fmt.Errorf("RLS request initiated with keys %v, want %v", km.Str, wantKeyMapStr)) return } - rlsCh <- nil + rlsCh.Send(nil) }, defaultPick: func(info balancer.PickInfo) (balancer.PickResult, error) { if !test.useDefaultPick { @@ -569,15 +569,8 @@ func TestPick(t *testing.T) { // If the test specified that a new RLS request should be made, // verify it. if test.newRLSRequest { - timer := time.NewTimer(defaultTestTimeout) - select { - case err := <-rlsCh: - timer.Stop() - if err != nil { - t.Fatal(err) - } - case <-timer.C: - t.Fatal("Timeout waiting for RLS request to be sent out") + if rlsErr, err := rlsCh.Receive(); err != nil || rlsErr != nil { + t.Fatalf("startRLS() = %v, error receiving from channel: %v", rlsErr, err) } } }) diff --git a/balancer/rls/internal/testutils/fakeserver/fakeserver.go b/balancer/rls/internal/testutils/fakeserver/fakeserver.go index 1cdf81550243..93947da4ccef 100644 --- a/balancer/rls/internal/testutils/fakeserver/fakeserver.go +++ b/balancer/rls/internal/testutils/fakeserver/fakeserver.go @@ -22,6 +22,7 @@ package fakeserver import ( "context" + "errors" "fmt" "net" "time" @@ -29,9 +30,14 @@ import ( "google.golang.org/grpc" rlsgrpc "google.golang.org/grpc/balancer/rls/internal/proto/grpc_lookup_v1" rlspb "google.golang.org/grpc/balancer/rls/internal/proto/grpc_lookup_v1" + "google.golang.org/grpc/internal/testutils" ) -const defaultDialTimeout = 5 * time.Second +const ( + defaultDialTimeout = 5 * time.Second + defaultRPCTimeout = 5 * time.Second + defaultChannelBufferSize = 50 +) // Response wraps the response protobuf (xds/LRS) and error that the Server // should send out to the client through a call to stream.Send() @@ -43,29 +49,31 @@ type Response struct { // Server is a fake implementation of RLS. It exposes channels to send/receive // RLS requests and responses. type Server struct { - RequestChan chan *rlspb.RouteLookupRequest + RequestChan *testutils.Channel ResponseChan chan Response Address string } -// Start makes a new Server and gets it to start listening on a local port for -// gRPC requests. The returned cancel function should be invoked by the caller -// upon completion of the test. -func Start() (*Server, func(), error) { - lis, err := net.Listen("tcp", "localhost:0") - if err != nil { - return nil, func() {}, fmt.Errorf("net.Listen() failed: %v", err) +// Start makes a new Server which uses the provided net.Listener. If lis is nil, +// it creates a new net.Listener on a local port. The returned cancel function +// should be invoked by the caller upon completion of the test. +func Start(lis net.Listener, opts ...grpc.ServerOption) (*Server, func(), error) { + if lis == nil { + var err error + lis, err = net.Listen("tcp", "localhost:0") + if err != nil { + return nil, func() {}, fmt.Errorf("net.Listen() failed: %v", err) + } } - s := &Server{ // Give the channels a buffer size of 1 so that we can setup // expectations for one lookup call, without blocking. - RequestChan: make(chan *rlspb.RouteLookupRequest, 1), + RequestChan: testutils.NewChannelWithSize(defaultChannelBufferSize), ResponseChan: make(chan Response, 1), Address: lis.Addr().String(), } - server := grpc.NewServer() + server := grpc.NewServer(opts...) rlsgrpc.RegisterRouteLookupServiceServer(server, s) go server.Serve(lis) @@ -74,9 +82,17 @@ func Start() (*Server, func(), error) { // RouteLookup implements the RouteLookupService. func (s *Server) RouteLookup(ctx context.Context, req *rlspb.RouteLookupRequest) (*rlspb.RouteLookupResponse, error) { - s.RequestChan <- req - resp := <-s.ResponseChan - return resp.Resp, resp.Err + s.RequestChan.Send(req) + + // The leakchecker fails if we don't exit out of here in a reasonable time. + timer := time.NewTimer(defaultRPCTimeout) + select { + case <-timer.C: + return nil, errors.New("default RPC timeout exceeded") + case resp := <-s.ResponseChan: + timer.Stop() + return resp.Resp, resp.Err + } } // ClientConn returns a grpc.ClientConn connected to the fakeServer. diff --git a/internal/testutils/channel.go b/internal/testutils/channel.go new file mode 100644 index 000000000000..35f67ea285c1 --- /dev/null +++ b/internal/testutils/channel.go @@ -0,0 +1,68 @@ +/* + * + * Copyright 2020 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package testutils + +import ( + "errors" + "time" +) + +// ErrRecvTimeout is an error to indicate that a receive operation on the +// channel timed out. +var ErrRecvTimeout = errors.New("timed out when waiting for value on channel") + +const ( + // DefaultChanRecvTimeout is the default timeout for receive operations on the + // underlying channel. + DefaultChanRecvTimeout = 1 * time.Second + // DefaultChanBufferSize is the default buffer size of the underlying channel. + DefaultChanBufferSize = 1 +) + +// Channel wraps a generic channel and provides a timed receive operation. +type Channel struct { + ch chan interface{} +} + +// Send sends value on the underlying channel. +func (cwt *Channel) Send(value interface{}) { + cwt.ch <- value +} + +// Receive returns the value received on the underlying channel, or +// ErrRecvTimeout if DefaultChanRecvTimeout amount of time elapses. +func (cwt *Channel) Receive() (interface{}, error) { + timer := time.NewTimer(DefaultChanRecvTimeout) + select { + case <-timer.C: + return nil, ErrRecvTimeout + case got := <-cwt.ch: + timer.Stop() + return got, nil + } +} + +// NewChannel returns a new Channel. +func NewChannel() *Channel { + return NewChannelWithSize(DefaultChanBufferSize) +} + +// NewChannelWithSize returns a new Channel with a buffer of bufSize. +func NewChannelWithSize(bufSize int) *Channel { + return &Channel{ch: make(chan interface{}, bufSize)} +}