Skip to content

Commit

Permalink
Make Endpoint cache InetSocketAddress more aggressively / Address…
Browse files Browse the repository at this point in the history
… the comments from @ikhoon
  • Loading branch information
trustin committed May 23, 2023
1 parent 6ea5bee commit ad3e99d
Show file tree
Hide file tree
Showing 7 changed files with 115 additions and 81 deletions.
35 changes: 28 additions & 7 deletions core/src/main/java/com/linecorp/armeria/client/Endpoint.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.linecorp.armeria.client;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;
import static java.util.Objects.requireNonNull;

import java.io.UnsupportedEncodingException;
Expand All @@ -30,7 +31,6 @@
import java.net.URISyntaxException;
import java.net.URLDecoder;
import java.net.UnknownHostException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Paths;
import java.util.Comparator;
import java.util.List;
Expand Down Expand Up @@ -161,17 +161,20 @@ public static Endpoint of(SocketAddress addr) {
}

if (addr instanceof DomainSocketAddress) {
return of(((DomainSocketAddress) addr).authority());
final DomainSocketAddress domainSocketAddr = (DomainSocketAddress) addr;
final Endpoint endpoint = unsafeCreate(domainSocketAddr.authority(), 0);
endpoint.socketAddress = domainSocketAddr;
return endpoint;
}

checkArgument(addr instanceof InetSocketAddress,
"unsupported address: %s", addr);

final InetSocketAddress inetAddr = (InetSocketAddress) addr;
@SuppressWarnings("resource")
final Endpoint endpoint = of(inetAddr.getHostString(), inetAddr.getPort());
return inetAddr.isUnresolved() ? endpoint
: endpoint.withIpAddr(inetAddr.getAddress().getHostAddress());
final String ipAddr = inetAddr.isUnresolved() ? null : inetAddr.getAddress().getHostAddress();
final Endpoint endpoint = of(inetAddr.getHostString(), inetAddr.getPort()).withIpAddr(ipAddr);
endpoint.socketAddress = inetAddr;
return endpoint;
}

/**
Expand Down Expand Up @@ -239,6 +242,8 @@ private enum HostType {
private CompletableFuture<Endpoint> selectFuture;
@Nullable
private CompletableFuture<List<Endpoint>> whenReadyFuture;
@Nullable
private InetSocketAddress socketAddress;
private int hashCode;

private Endpoint(String host, @Nullable String ipAddr, int port, int weight, HostType hostType,
Expand Down Expand Up @@ -535,6 +540,8 @@ public Endpoint withIpAddr(@Nullable String ipAddr) {
return this;
}

checkState(!isDomainSocket(), "A domain socket endpoint can't have an IP address.");

if (NetUtil.isValidIpV4Address(ipAddr)) {
return withIpAddr(ipAddr, StandardProtocolFamily.INET);
}
Expand Down Expand Up @@ -799,7 +806,21 @@ public URI toUri(Scheme scheme, @Nullable String path) {
* @see #hasPort()
* @see #isDomainSocket()
*/
@UnstableApi
public InetSocketAddress toSocketAddress(int defaultPort) {
final InetSocketAddress socketAddress = this.socketAddress;
if (socketAddress != null) {
return socketAddress;
}

final InetSocketAddress newSocketAddress = toSocketAddress0(defaultPort);
if (hasPort() || isDomainSocket()) {
this.socketAddress = newSocketAddress;
}
return newSocketAddress;
}

private InetSocketAddress toSocketAddress0(int defaultPort) {
if (isDomainSocket()) {
final String decodedHost;
try {
Expand All @@ -810,7 +831,7 @@ public InetSocketAddress toSocketAddress(int defaultPort) {
}

assert decodedHost.startsWith("unix:") : decodedHost;
return DomainSocketAddress.of(Paths.get(decodedHost.substring(5))); // Strip "unix:"
return DomainSocketAddress.of(Paths.get(decodedHost.substring(5)));
}

final int port = hasPort() ? this.port : defaultPort;
Expand Down
93 changes: 30 additions & 63 deletions core/src/main/java/com/linecorp/armeria/client/HttpChannelPool.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,9 @@
*/
package com.linecorp.armeria.client;

import java.io.UnsupportedEncodingException;
import java.lang.reflect.Array;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.net.URLDecoder;
import java.net.UnknownHostException;
import java.nio.channels.UnsupportedAddressTypeException;
import java.util.ArrayDeque;
import java.util.ArrayList;
Expand Down Expand Up @@ -73,7 +69,6 @@
import io.netty.handler.proxy.Socks4ProxyHandler;
import io.netty.handler.proxy.Socks5ProxyHandler;
import io.netty.handler.ssl.SslContext;
import io.netty.util.NetUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import reactor.core.scheduler.NonBlocking;
Expand Down Expand Up @@ -381,7 +376,7 @@ private void connect(SessionProtocol desiredProtocol, PoolKey key, ChannelAcquis
timingsBuilder.socketConnectStart();

// Fail immediately if it is certain that the remote address doesn't support the desired protocol.
final SocketAddress remoteAddress = key.remoteAddress;
final SocketAddress remoteAddress = key.toRemoteAddress();
if (SessionProtocolNegotiationCache.isUnsupported(remoteAddress, desiredProtocol)) {
notifyConnect(desiredProtocol, key,
eventLoop.newFailedFuture(
Expand Down Expand Up @@ -642,73 +637,24 @@ public void close() {
static final class PoolKey {
final Endpoint endpoint;
final ProxyConfig proxyConfig;
final SocketAddress remoteAddress;

private final int hashCode;
private final String strVal;

PoolKey(Endpoint endpoint, ProxyConfig proxyConfig) {
this.endpoint = endpoint;
this.proxyConfig = proxyConfig;
remoteAddress = toRemoteAddress(endpoint, proxyConfig);
hashCode = endpoint.hashCode() * 31 + proxyConfig.hashCode();
strVal = generateString(endpoint, proxyConfig);
}

private static String generateString(Endpoint endpoint, ProxyConfig proxyConfig) {
final String host = endpoint.host();
final String ipAddr = endpoint.ipAddr();
final int port = endpoint.port();
final boolean isDomainSocket = endpoint.isDomainSocket();
final String proxyConfigStr = proxyConfig.proxyType() != ProxyType.DIRECT ? proxyConfig.toString()
: null;
try (TemporaryThreadLocals tempThreadLocals = TemporaryThreadLocals.acquire()) {
final StringBuilder buf = tempThreadLocals.stringBuilder();
buf.append('{').append(host);
if (ipAddr != null) {
buf.append('/').append(ipAddr);
}
if (!isDomainSocket) {
buf.append(':').append(port);
}
if (proxyConfigStr != null) {
buf.append(" via ");
buf.append(proxyConfigStr);
}
buf.append('}');
return buf.toString();
}
}

private static SocketAddress toRemoteAddress(Endpoint endpoint, ProxyConfig proxyConfig) {
final String host = endpoint.host();
final String ipAddr = endpoint.ipAddr();
if (ipAddr != null) {
try {
return new InetSocketAddress(
InetAddress.getByAddress(host, NetUtil.createByteArrayFromIpAddressString(ipAddr)),
endpoint.port());
} catch (UnknownHostException e) {
// Should never reach here because `Endpoint` validates the IP address.
throw new Error(e);
}
SocketAddress toRemoteAddress() {
final InetSocketAddress remoteAddr = endpoint.toSocketAddress(-1);
if (remoteAddr instanceof com.linecorp.armeria.common.util.DomainSocketAddress) {
return ((com.linecorp.armeria.common.util.DomainSocketAddress) remoteAddr).asNettyAddress();
}

// ipAddr can be null for domain sockets.
if (endpoint.isDomainSocket()) {
final String path;
try {
path = URLDecoder.decode(host.substring(7), "UTF-8");
} catch (UnsupportedEncodingException e) {
// Should never reach here.
throw new Error(e);
}
return new DomainSocketAddress(path);
}
assert !remoteAddr.isUnresolved() || proxyConfig.proxyType().isForwardProxy()
: remoteAddr + ", " + proxyConfig;

// ipAddr can be null for forward proxies.
assert proxyConfig.proxyType().isForwardProxy() : proxyConfig;
return InetSocketAddress.createUnresolved(host, endpoint.port());
return remoteAddr;
}

@Override
Expand All @@ -734,7 +680,28 @@ public int hashCode() {

@Override
public String toString() {
return strVal;
final String host = endpoint.host();
final String ipAddr = endpoint.ipAddr();
final int port = endpoint.port();
final boolean isDomainSocket = endpoint.isDomainSocket();
final String proxyConfigStr = proxyConfig.proxyType() != ProxyType.DIRECT ? proxyConfig.toString()
: null;
try (TemporaryThreadLocals tempThreadLocals = TemporaryThreadLocals.acquire()) {
final StringBuilder buf = tempThreadLocals.stringBuilder();
buf.append('{').append(host);
if (ipAddr != null) {
buf.append('/').append(ipAddr);
}
if (!isDomainSocket) {
buf.append(':').append(port);
}
if (proxyConfigStr != null) {
buf.append(" via ");
buf.append(proxyConfigStr);
}
buf.append('}');
return buf.toString();
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,9 @@ public static boolean isDomainSocketAddress(InetAddress addr) {
private String authority;
@Nullable
private Endpoint endpoint;
@Nullable
@SuppressWarnings("NullableOnContainingClass") // ErrorProne false positive
private io.netty.channel.unix.DomainSocketAddress nettyAddress;

private DomainSocketAddress(Path path) {
super(toInetAddress(path), 1);
Expand Down Expand Up @@ -132,8 +135,16 @@ public String authority() {
*
* @return the converted Netty address
*/
public io.netty.channel.unix.DomainSocketAddress toNettyAddress() {
return new io.netty.channel.unix.DomainSocketAddress(path.toFile());
public io.netty.channel.unix.DomainSocketAddress asNettyAddress() {
final io.netty.channel.unix.DomainSocketAddress nettyAddress = this.nettyAddress;
if (nettyAddress != null) {
return nettyAddress;
}

final io.netty.channel.unix.DomainSocketAddress newNettyAddress =
new io.netty.channel.unix.DomainSocketAddress(path.toFile());
this.nettyAddress = newNettyAddress;
return newNettyAddress;
}

/**
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/java/com/linecorp/armeria/server/Server.java
Original file line number Diff line number Diff line change
Expand Up @@ -578,7 +578,7 @@ private ChannelFuture doStart(ServerPort port) {
if (port.isDomainSocket()) {
if (transportType.supportsDomainSockets()) {
// Convert to Netty's DomainSocketAddress type.
localAddress = ((DomainSocketAddress) port.localAddress()).toNettyAddress();
localAddress = ((DomainSocketAddress) port.localAddress()).asNettyAddress();
channelType = transportType.domainServerChannelType();
} else {
throw new IllegalStateException(
Expand Down
39 changes: 39 additions & 0 deletions core/src/test/java/com/linecorp/armeria/client/EndpointTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,45 @@ void conversionToSocketAddress() throws Exception {
DomainSocketAddress.of(Paths.get("/foo.sock")));
}

@Test
void socketAddressCache() {
final Endpoint endpointWithPort = Endpoint.of("foo", 42);
assertThat(endpointWithPort.toSocketAddress(-1))
.isSameAs(endpointWithPort.toSocketAddress(1));

final Endpoint endpointWithoutPort = Endpoint.of("foo");
assertThat(endpointWithoutPort.toSocketAddress(80))
.isNotSameAs(endpointWithoutPort.toSocketAddress(80));

final Endpoint endpointWithDomainSocket = Endpoint.of("unix%3A%2Ffoo.sock");
assertThat(endpointWithDomainSocket.toSocketAddress(-1))
.isSameAs(endpointWithDomainSocket.toSocketAddress(1));
}

@Test
void socketAddressPrecache() throws Exception {
final DomainSocketAddress domainSocketAddress = DomainSocketAddress.of(Paths.get("/foo.sock"));
assertThat(Endpoint.of(domainSocketAddress).toSocketAddress(-1))
.isSameAs(domainSocketAddress);

final InetSocketAddress unresolvedSocketAddress = InetSocketAddress.createUnresolved("foo", 42);
assertThat(Endpoint.of(unresolvedSocketAddress).toSocketAddress(-1))
.isSameAs(unresolvedSocketAddress);

final InetSocketAddress resolvedSocketAddress = new InetSocketAddress(
InetAddress.getByAddress("foo", new byte[] { 127, 0, 0, 1 }), 42);
assertThat(Endpoint.of(resolvedSocketAddress).toSocketAddress(-1))
.isSameAs(resolvedSocketAddress);
}

@Test
void domainSocketWithIpAddr() {
final Endpoint endpoint = Endpoint.of("unix%3A%2Ffoo.sock");
assertThatThrownBy(() -> endpoint.withIpAddr("127.0.0.1"))
.isInstanceOf(IllegalStateException.class)
.hasMessageContaining("domain socket");
}

@Test
void setAndGetAttr() {
final Endpoint endpointA = Endpoint.parse("a");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,10 @@ void domainKeyGeneration() {
}

private static InetSocketAddress toRemoteAddress(Endpoint endpoint) throws UnknownHostException {
final String ipAddr;
if (endpoint.hasIpAddr()) {
ipAddr = endpoint.ipAddr();
assert ipAddr != null;
} else {
ipAddr = "127.0.0.1"; // Do not resolve the host name but just use local address for test.
if (!endpoint.hasIpAddr()) {
endpoint = endpoint.withIpAddr("127.0.0.1");
}
return toRemoteAddress(endpoint.host(), ipAddr, endpoint.port());
return endpoint.toSocketAddress(-1);
}

private static InetSocketAddress toRemoteAddress(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
receivedBuffers.add((ByteBuf) msg);
}
});
final Channel ch = b.connect(domainSocketAddress().toNettyAddress())
final Channel ch = b.connect(domainSocketAddress().asNettyAddress())
.syncUninterruptibly()
.channel();

Expand Down

0 comments on commit ad3e99d

Please sign in to comment.