Skip to content

Commit

Permalink
Add ClientTransportFilter (#10646)
Browse files Browse the repository at this point in the history
* Add ClientTransportFilter
  • Loading branch information
joybestourous authored Jan 3, 2024
1 parent 7692a9f commit 91d15ce
Show file tree
Hide file tree
Showing 20 changed files with 205 additions and 14 deletions.
51 changes: 51 additions & 0 deletions api/src/main/java/io/grpc/ClientTransportFilter.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
/*
* Copyright 2023 The gRPC Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package io.grpc;

/**
* Listens on the client transport life-cycle events. These filters do not have the capability
* to modify the channels or transport life-cycle event behavior, but they can be useful hooks
* for transport observability. Multiple filters may be registered to the client.
*
* @since 1.61.0
*/
@ExperimentalApi("https://gitub.com/grpc/grpc-java/issues/10652")
public abstract class ClientTransportFilter {
/**
* Called when a transport is ready to accept traffic (when a connection has been established).
* The default implementation is a no-op.
*
* @param transportAttrs current transport attributes
*
* @return new transport attributes. Default implementation returns the passed-in attributes
* intact.
*/
public Attributes transportReady(Attributes transportAttrs) {
return transportAttrs;
}

/**
* Called when a transport completed shutting down. All resources have been released.
* All streams have either been closed or transferred off this transport.
* Default implementation is a no-op
*
* @param transportAttrs the effective transport attributes, which is what is returned by {@link
* #transportReady} of the last executed filter.
*/
public void transportTerminated(Attributes transportAttrs) {
}
}
6 changes: 6 additions & 0 deletions api/src/main/java/io/grpc/ForwardingChannelBuilder2.java
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ public T intercept(ClientInterceptor... interceptors) {
return thisT();
}

@Override
public T addTransportFilter(ClientTransportFilter transportFilter) {
delegate().addTransportFilter(transportFilter);
return thisT();
}

@Override
public T userAgent(String userAgent) {
delegate().userAgent(userAgent);
Expand Down
12 changes: 12 additions & 0 deletions api/src/main/java/io/grpc/ManagedChannelBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,18 @@ public T offloadExecutor(Executor executor) {
*/
public abstract T intercept(ClientInterceptor... interceptors);

/**
* Adds a {@link ClientTransportFilter}. The order of filters being added is the order they will
* be executed
*
* @return this
* @since 1.60.0
*/
@ExperimentalApi("https://github.com/grpc/grpc-java/issues/10652")
public T addTransportFilter(ClientTransportFilter filter) {
throw new UnsupportedOperationException();
}

/**
* Provides a custom {@code User-Agent} for the application.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -758,6 +758,7 @@ private void checkSecurityPolicy(IBinder binder) {
// triggers), could have shut us down.
if (!isShutdown()) {
setState(TransportState.READY);
attributes = clientTransportListener.filterTransport(attributes);
clientTransportListener.transportReady();
}
}
Expand Down
19 changes: 18 additions & 1 deletion core/src/main/java/io/grpc/internal/InternalSubchannel.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import io.grpc.ChannelLogger;
import io.grpc.ChannelLogger.ChannelLogLevel;
import io.grpc.ClientStreamTracer;
import io.grpc.ClientTransportFilter;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
import io.grpc.EquivalentAddressGroup;
Expand Down Expand Up @@ -77,6 +78,8 @@ final class InternalSubchannel implements InternalInstrumented<ChannelStats>, Tr
private final ChannelTracer channelTracer;
private final ChannelLogger channelLogger;

private final List<ClientTransportFilter> transportFilters;

/**
* All field must be mutated in the syncContext.
*/
Expand Down Expand Up @@ -159,7 +162,8 @@ protected void handleNotInUse() {
ClientTransportFactory transportFactory, ScheduledExecutorService scheduledExecutor,
Supplier<Stopwatch> stopwatchSupplier, SynchronizationContext syncContext, Callback callback,
InternalChannelz channelz, CallTracer callsTracer, ChannelTracer channelTracer,
InternalLogId logId, ChannelLogger channelLogger) {
InternalLogId logId, ChannelLogger channelLogger,
List<ClientTransportFilter> transportFilters) {
Preconditions.checkNotNull(addressGroups, "addressGroups");
Preconditions.checkArgument(!addressGroups.isEmpty(), "addressGroups is empty");
checkListHasNoNulls(addressGroups, "addressGroups contains null entry");
Expand All @@ -180,6 +184,7 @@ protected void handleNotInUse() {
this.channelTracer = Preconditions.checkNotNull(channelTracer, "channelTracer");
this.logId = Preconditions.checkNotNull(logId, "logId");
this.channelLogger = Preconditions.checkNotNull(channelLogger, "channelLogger");
this.transportFilters = transportFilters;
}

ChannelLogger getChannelLogger() {
Expand Down Expand Up @@ -539,6 +544,15 @@ private class TransportListener implements ManagedClientTransport.Listener {
this.transport = transport;
}

@Override
public Attributes filterTransport(Attributes attributes) {
for (ClientTransportFilter filter : transportFilters) {
attributes = Preconditions.checkNotNull(filter.transportReady(attributes),
"Filter %s returned null", filter);
}
return attributes;
}

@Override
public void transportReady() {
channelLogger.log(ChannelLogLevel.INFO, "READY");
Expand Down Expand Up @@ -607,6 +621,9 @@ public void transportTerminated() {
channelLogger.log(ChannelLogLevel.INFO, "{0} Terminated", transport.getLogId());
channelz.removeClientSocket(transport);
handleTransportInUseState(transport, false);
for (ClientTransportFilter filter : transportFilters) {
filter.transportTerminated(transport.getAttributes());
}
syncContext.execute(new Runnable() {
@Override
public void run() {
Expand Down
15 changes: 13 additions & 2 deletions core/src/main/java/io/grpc/internal/ManagedChannelImpl.java
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import io.grpc.ClientInterceptor;
import io.grpc.ClientInterceptors;
import io.grpc.ClientStreamTracer;
import io.grpc.ClientTransportFilter;
import io.grpc.CompressorRegistry;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
Expand Down Expand Up @@ -209,6 +210,8 @@ public void uncaughtException(Thread t, Throwable e) {
* {@link RealChannel}.
*/
private final Channel interceptorChannel;

private final List<ClientTransportFilter> transportFilters;
@Nullable private final String userAgent;

// Only null after channel is terminated. Must be assigned from the syncContext.
Expand Down Expand Up @@ -661,6 +664,7 @@ ClientStream newSubstream(
channel = builder.binlog.wrapChannel(channel);
}
this.interceptorChannel = ClientInterceptors.intercept(channel, interceptors);
this.transportFilters = new ArrayList<>(builder.transportFilters);
this.stopwatchSupplier = checkNotNull(stopwatchSupplier, "stopwatchSupplier");
if (builder.idleTimeoutMillis == IDLE_TIMEOUT_MILLIS_DISABLE) {
this.idleTimeoutMillis = builder.idleTimeoutMillis;
Expand Down Expand Up @@ -1566,7 +1570,8 @@ void onStateChange(InternalSubchannel is, ConnectivityStateInfo newState) {
callTracerFactory.create(),
subchannelTracer,
subchannelLogId,
subchannelLogger);
subchannelLogger,
transportFilters);
oobChannelTracer.reportEvent(new ChannelTrace.Event.Builder()
.setDescription("Child Subchannel created")
.setSeverity(ChannelTrace.Event.Severity.CT_INFO)
Expand Down Expand Up @@ -1990,7 +1995,8 @@ void onNotInUse(InternalSubchannel is) {
callTracerFactory.create(),
subchannelTracer,
subchannelLogId,
subchannelLogger);
subchannelLogger,
transportFilters);

channelTracer.reportEvent(new ChannelTrace.Event.Builder()
.setDescription("Child Subchannel started")
Expand Down Expand Up @@ -2148,6 +2154,11 @@ public void transportReady() {
// Don't care
}

@Override
public Attributes filterTransport(Attributes attributes) {
return attributes;
}

@Override
public void transportInUse(final boolean inUse) {
inUseStateAggregator.updateObjectInUse(delayedTransport, inUse);
Expand Down
20 changes: 15 additions & 5 deletions core/src/main/java/io/grpc/internal/ManagedChannelImplBuilder.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package io.grpc.internal;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
Expand All @@ -27,6 +28,7 @@
import io.grpc.CallCredentials;
import io.grpc.ChannelCredentials;
import io.grpc.ClientInterceptor;
import io.grpc.ClientTransportFilter;
import io.grpc.CompressorRegistry;
import io.grpc.DecompressorRegistry;
import io.grpc.EquivalentAddressGroup;
Expand Down Expand Up @@ -137,6 +139,8 @@ public static ManagedChannelBuilder<?> forTarget(String target) {
private final List<ClientInterceptor> interceptors = new ArrayList<>();
NameResolverRegistry nameResolverRegistry = NameResolverRegistry.getDefaultRegistry();

final List<ClientTransportFilter> transportFilters = new ArrayList<>();

final String target;
@Nullable
final ChannelCredentials channelCredentials;
Expand Down Expand Up @@ -267,11 +271,11 @@ public ManagedChannelImplBuilder(
String target, @Nullable ChannelCredentials channelCreds, @Nullable CallCredentials callCreds,
ClientTransportFactoryBuilder clientTransportFactoryBuilder,
@Nullable ChannelBuilderDefaultPortProvider channelBuilderDefaultPortProvider) {
this.target = Preconditions.checkNotNull(target, "target");
this.target = checkNotNull(target, "target");
this.channelCredentials = channelCreds;
this.callCredentials = callCreds;
this.clientTransportFactoryBuilder = Preconditions
.checkNotNull(clientTransportFactoryBuilder, "clientTransportFactoryBuilder");
this.clientTransportFactoryBuilder = checkNotNull(clientTransportFactoryBuilder,
"clientTransportFactoryBuilder");
this.directServerAddress = null;

if (channelBuilderDefaultPortProvider != null) {
Expand Down Expand Up @@ -323,8 +327,8 @@ public ManagedChannelImplBuilder(SocketAddress directServerAddress, String autho
this.target = makeTargetStringForDirectAddress(directServerAddress);
this.channelCredentials = channelCreds;
this.callCredentials = callCreds;
this.clientTransportFactoryBuilder = Preconditions
.checkNotNull(clientTransportFactoryBuilder, "clientTransportFactoryBuilder");
this.clientTransportFactoryBuilder = checkNotNull(clientTransportFactoryBuilder,
"clientTransportFactoryBuilder");
this.directServerAddress = directServerAddress;
NameResolverRegistry reg = new NameResolverRegistry();
reg.register(new DirectAddressNameResolverProvider(directServerAddress,
Expand Down Expand Up @@ -374,6 +378,12 @@ public ManagedChannelImplBuilder intercept(ClientInterceptor... interceptors) {
return intercept(Arrays.asList(interceptors));
}

@Override
public ManagedChannelImplBuilder addTransportFilter(ClientTransportFilter hook) {
transportFilters.add(checkNotNull(hook, "transport filter"));
return this;
}

@Deprecated
@Override
public ManagedChannelImplBuilder nameResolverFactory(NameResolver.Factory resolverFactory) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package io.grpc.internal;

import io.grpc.Attributes;
import io.grpc.Status;
import javax.annotation.CheckReturnValue;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -104,5 +105,11 @@ interface Listener {
* at least one stream.
*/
void transportInUse(boolean inUse);

/**
* Called just before {@link #transportReady} to allow direct modification of transport
* Attributes.
*/
Attributes filterTransport(Attributes attributes);
}
}
5 changes: 5 additions & 0 deletions core/src/main/java/io/grpc/internal/OobChannel.java
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,11 @@ public void transportReady() {
// Don't care
}

@Override
public Attributes filterTransport(Attributes attributes) {
return attributes;
}

@Override
public void transportInUse(boolean inUse) {
// Don't care
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
import io.grpc.internal.TestUtils.MockClientTransportInfo;
import java.net.SocketAddress;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.concurrent.BlockingQueue;
Expand Down Expand Up @@ -1360,7 +1361,8 @@ private void createInternalSubchannel(EquivalentAddressGroup ... addrs) {
channelz, CallTracer.getDefaultFactory().create(),
subchannelTracer,
logId,
new ChannelLoggerImpl(subchannelTracer, fakeClock.getTimeProvider()));
new ChannelLoggerImpl(subchannelTracer, fakeClock.getTimeProvider()),
Collections.emptyList());
}

private void assertNoCallbackInvoke() {
Expand Down
44 changes: 44 additions & 0 deletions core/src/test/java/io/grpc/internal/ManagedChannelImplTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@
import io.grpc.ClientInterceptors;
import io.grpc.ClientStreamTracer;
import io.grpc.ClientStreamTracer.StreamInfo;
import io.grpc.ClientTransportFilter;
import io.grpc.CompositeChannelCredentials;
import io.grpc.ConnectivityState;
import io.grpc.ConnectivityStateInfo;
Expand Down Expand Up @@ -139,6 +140,7 @@
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import javax.annotation.Nullable;
Expand Down Expand Up @@ -4240,6 +4242,48 @@ public void createResolvingOobChannel() throws Exception {
}
}

@Test
public void transportFilters() {

final AtomicInteger readyCallbackCalled = new AtomicInteger(0);
final AtomicInteger terminationCallbackCalled = new AtomicInteger(0);
ClientTransportFilter transportFilter = new ClientTransportFilter() {
@Override
public Attributes transportReady(Attributes transportAttrs) {
readyCallbackCalled.incrementAndGet();
return transportAttrs;
}

@Override
public void transportTerminated(Attributes transportAttrs) {
terminationCallbackCalled.incrementAndGet();
}
};

channelBuilder.addTransportFilter(transportFilter);
assertEquals(0, readyCallbackCalled.get());

createChannel();
final Subchannel subchannel =
createSubchannelSafely(helper, addressGroup, Attributes.EMPTY, subchannelStateListener);
requestConnectionSafely(helper, subchannel);
verify(mockTransportFactory)
.newClientTransport(
any(SocketAddress.class), any(ClientTransportOptions.class), any(ChannelLogger.class));
MockClientTransportInfo transportInfo = transports.poll();
ManagedClientTransport.Listener transportListener = transportInfo.listener;

transportListener.filterTransport(Attributes.EMPTY);
transportListener.transportReady();
assertEquals(1, readyCallbackCalled.get());
assertEquals(0, terminationCallbackCalled.get());

transportListener.transportShutdown(Status.OK);

transportListener.transportTerminated();
assertEquals(1, terminationCallbackCalled.get());
}

private static final class FakeBackoffPolicyProvider implements BackoffPolicy.Provider {
@Override
public BackoffPolicy get() {
Expand Down
Loading

0 comments on commit 91d15ce

Please sign in to comment.