Skip to content

Commit

Permalink
xds: use the new cert-provider instances if present (#8494)
Browse files Browse the repository at this point in the history
  • Loading branch information
sanjaypujare authored Sep 8, 2021
1 parent f71eedf commit 2260381
Show file tree
Hide file tree
Showing 12 changed files with 474 additions and 95 deletions.
95 changes: 56 additions & 39 deletions xds/src/main/java/io/grpc/xds/ClientXdsClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -456,10 +456,6 @@ static void validateCommonTlsContext(
if (commonTlsContext.hasTlsParams()) {
throw new ResourceInvalidException("common-tls-context with tls_params is not supported");
}
if (commonTlsContext.hasValidationContext()) {
throw new ResourceInvalidException(
"common-tls-context with validation_context is not supported");
}
if (commonTlsContext.hasValidationContextSdsSecretConfig()) {
throw new ResourceInvalidException(
"common-tls-context with validation_context_sds_secret_config is not supported");
Expand All @@ -473,54 +469,50 @@ static void validateCommonTlsContext(
"common-tls-context with validation_context_certificate_provider_instance is not"
+ " supported");
}
String certInstanceName = null;
if (!commonTlsContext.hasTlsCertificateCertificateProviderInstance()) {
String certInstanceName = getIdentityCertInstanceName(commonTlsContext);
if (certInstanceName == null) {
if (server) {
throw new ResourceInvalidException(
"tls_certificate_certificate_provider_instance is required in downstream-tls-context");
"tls_certificate_provider_instance is required in downstream-tls-context");
}
if (commonTlsContext.getTlsCertificatesCount() > 0) {
throw new ResourceInvalidException(
"common-tls-context with tls_certificates is not supported");
"tls_certificate_provider_instance is unset");
}
if (commonTlsContext.getTlsCertificateSdsSecretConfigsCount() > 0) {
throw new ResourceInvalidException(
"common-tls-context with tls_certificate_sds_secret_configs is not supported");
"tls_certificate_provider_instance is unset");
}
if (commonTlsContext.hasTlsCertificateCertificateProvider()) {
throw new ResourceInvalidException(
"common-tls-context with tls_certificate_certificate_provider is not supported");
}
} else {
certInstanceName = commonTlsContext.getTlsCertificateCertificateProviderInstance()
.getInstanceName();
}
if (certInstanceName != null) {
if (certProviderInstances == null || !certProviderInstances.contains(certInstanceName)) {
throw new ResourceInvalidException(
"CertificateProvider instance name '" + certInstanceName
+ "' not defined in the bootstrap file.");
"tls_certificate_provider_instance is unset");
}
} else if (certProviderInstances == null || !certProviderInstances.contains(certInstanceName)) {
throw new ResourceInvalidException(
"CertificateProvider instance name '" + certInstanceName
+ "' not defined in the bootstrap file.");
}
String rootCaInstanceName = null;
if (!commonTlsContext.hasCombinedValidationContext()) {
String rootCaInstanceName = getRootCertInstanceName(commonTlsContext);
if (rootCaInstanceName == null) {
if (!server) {
throw new ResourceInvalidException(
"combined_validation_context is required in upstream-tls-context");
"ca_certificate_provider_instance is required in upstream-tls-context");
}
} else {
CommonTlsContext.CombinedCertificateValidationContext combinedCertificateValidationContext
= commonTlsContext.getCombinedValidationContext();
if (!combinedCertificateValidationContext.hasValidationContextCertificateProviderInstance()) {
if (certProviderInstances == null || !certProviderInstances.contains(rootCaInstanceName)) {
throw new ResourceInvalidException(
"validation_context_certificate_provider_instance is required in"
+ " combined_validation_context");
}
rootCaInstanceName = combinedCertificateValidationContext
.getValidationContextCertificateProviderInstance().getInstanceName();
if (combinedCertificateValidationContext.hasDefaultValidationContext()) {
CertificateValidationContext certificateValidationContext
= combinedCertificateValidationContext.getDefaultValidationContext();
"ca_certificate_provider_instance name '" + rootCaInstanceName
+ "' not defined in the bootstrap file.");
}
CertificateValidationContext certificateValidationContext = null;
if (commonTlsContext.hasValidationContext()) {
certificateValidationContext = commonTlsContext.getValidationContext();
} else if (commonTlsContext.hasCombinedValidationContext() && commonTlsContext
.getCombinedValidationContext().hasDefaultValidationContext()) {
certificateValidationContext = commonTlsContext.getCombinedValidationContext()
.getDefaultValidationContext();
}
if (certificateValidationContext != null) {
if (certificateValidationContext.getMatchSubjectAltNamesCount() > 0 && server) {
throw new ResourceInvalidException(
"match_subject_alt_names only allowed in upstream_tls_context");
Expand All @@ -547,13 +539,38 @@ static void validateCommonTlsContext(
}
}
}
if (rootCaInstanceName != null) {
if (certProviderInstances == null || !certProviderInstances.contains(rootCaInstanceName)) {
throw new ResourceInvalidException(
"ValidationContextProvider instance name '" + rootCaInstanceName
+ "' not defined in the bootstrap file.");
}

private static String getIdentityCertInstanceName(CommonTlsContext commonTlsContext) {
if (commonTlsContext.hasTlsCertificateProviderInstance()) {
return commonTlsContext.getTlsCertificateProviderInstance().getInstanceName();
} else if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) {
return commonTlsContext.getTlsCertificateCertificateProviderInstance().getInstanceName();
}
return null;
}

private static String getRootCertInstanceName(CommonTlsContext commonTlsContext) {
if (commonTlsContext.hasValidationContext()) {
if (commonTlsContext.getValidationContext().hasCaCertificateProviderInstance()) {
return commonTlsContext.getValidationContext().getCaCertificateProviderInstance()
.getInstanceName();
}
} else if (commonTlsContext.hasCombinedValidationContext()) {
CommonTlsContext.CombinedCertificateValidationContext combinedCertificateValidationContext
= commonTlsContext.getCombinedValidationContext();
if (combinedCertificateValidationContext.hasDefaultValidationContext()
&& combinedCertificateValidationContext.getDefaultValidationContext()
.hasCaCertificateProviderInstance()) {
return combinedCertificateValidationContext.getDefaultValidationContext()
.getCaCertificateProviderInstance().getInstanceName();
} else if (combinedCertificateValidationContext
.hasValidationContextCertificateProviderInstance()) {
return combinedCertificateValidationContext
.getValidationContextCertificateProviderInstance().getInstanceName();
}
}
return null;
}

private static void checkForUniqueness(Set<FilterChainMatch> uniqueSet,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import io.envoyproxy.envoy.config.core.v3.Node;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext;
import io.grpc.Internal;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.xds.Bootstrapper.CertificateProviderInfo;
Expand Down Expand Up @@ -94,27 +93,12 @@ public CertProviderClientSslContextProvider getProvider(
@Nullable Map<String, CertificateProviderInfo> certProviders) {
checkNotNull(upstreamTlsContext, "upstreamTlsContext");
CommonTlsContext commonTlsContext = upstreamTlsContext.getCommonTlsContext();
CommonTlsContext.CertificateProviderInstance rootCertInstance = null;
CertificateValidationContext staticCertValidationContext = null;
if (commonTlsContext.hasCombinedValidationContext()) {
CombinedCertificateValidationContext combinedValidationContext =
commonTlsContext.getCombinedValidationContext();
if (combinedValidationContext.hasValidationContextCertificateProviderInstance()) {
rootCertInstance =
combinedValidationContext.getValidationContextCertificateProviderInstance();
}
if (combinedValidationContext.hasDefaultValidationContext()) {
staticCertValidationContext = combinedValidationContext.getDefaultValidationContext();
}
} else if (commonTlsContext.hasValidationContextCertificateProviderInstance()) {
rootCertInstance = commonTlsContext.getValidationContextCertificateProviderInstance();
} else if (commonTlsContext.hasValidationContext()) {
staticCertValidationContext = commonTlsContext.getValidationContext();
}
CommonTlsContext.CertificateProviderInstance certInstance = null;
if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) {
certInstance = commonTlsContext.getTlsCertificateCertificateProviderInstance();
}
CertificateValidationContext staticCertValidationContext = getStaticValidationContext(
commonTlsContext);
CommonTlsContext.CertificateProviderInstance rootCertInstance = getRootCertProviderInstance(
commonTlsContext);
CommonTlsContext.CertificateProviderInstance certInstance = getCertProviderInstance(
commonTlsContext);
return new CertProviderClientSslContextProvider(
node,
certProviders,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import io.envoyproxy.envoy.config.core.v3.Node;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext;
import io.grpc.Internal;
import io.grpc.netty.GrpcSslContexts;
import io.grpc.xds.Bootstrapper.CertificateProviderInfo;
Expand Down Expand Up @@ -97,27 +96,12 @@ public CertProviderServerSslContextProvider getProvider(
@Nullable Map<String, CertificateProviderInfo> certProviders) {
checkNotNull(downstreamTlsContext, "downstreamTlsContext");
CommonTlsContext commonTlsContext = downstreamTlsContext.getCommonTlsContext();
CommonTlsContext.CertificateProviderInstance rootCertInstance = null;
CertificateValidationContext staticCertValidationContext = null;
if (commonTlsContext.hasCombinedValidationContext()) {
CombinedCertificateValidationContext combinedValidationContext =
commonTlsContext.getCombinedValidationContext();
if (combinedValidationContext.hasValidationContextCertificateProviderInstance()) {
rootCertInstance =
combinedValidationContext.getValidationContextCertificateProviderInstance();
}
if (combinedValidationContext.hasDefaultValidationContext()) {
staticCertValidationContext = combinedValidationContext.getDefaultValidationContext();
}
} else if (commonTlsContext.hasValidationContextCertificateProviderInstance()) {
rootCertInstance = commonTlsContext.getValidationContextCertificateProviderInstance();
} else if (commonTlsContext.hasValidationContext()) {
staticCertValidationContext = commonTlsContext.getValidationContext();
}
CommonTlsContext.CertificateProviderInstance certInstance = null;
if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) {
certInstance = commonTlsContext.getTlsCertificateCertificateProviderInstance();
}
CertificateValidationContext staticCertValidationContext = getStaticValidationContext(
commonTlsContext);
CommonTlsContext.CertificateProviderInstance rootCertInstance = getRootCertProviderInstance(
commonTlsContext);
CommonTlsContext.CertificateProviderInstance certInstance = getCertProviderInstance(
commonTlsContext);
return new CertProviderServerSslContextProvider(
node,
certProviders,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@

import io.envoyproxy.envoy.config.core.v3.Node;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateValidationContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CertificateProviderInstance;
import io.grpc.xds.Bootstrapper.CertificateProviderInfo;
import io.grpc.xds.EnvoyServerProtoData.BaseTlsContext;
import io.grpc.xds.internal.sds.CommonTlsContextUtil;
import io.grpc.xds.internal.sds.DynamicSslContextProvider;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
Expand Down Expand Up @@ -88,6 +90,52 @@ private static CertificateProviderInfo getCertProviderConfig(
return certProviders != null ? certProviders.get(pluginInstanceName) : null;
}

@Nullable
protected static CertificateProviderInstance getCertProviderInstance(
CommonTlsContext commonTlsContext) {
if (commonTlsContext.hasTlsCertificateProviderInstance()) {
return CommonTlsContextUtil.convert(commonTlsContext.getTlsCertificateProviderInstance());
} else if (commonTlsContext.hasTlsCertificateCertificateProviderInstance()) {
return commonTlsContext.getTlsCertificateCertificateProviderInstance();
}
return null;
}

@Nullable
protected static CertificateValidationContext getStaticValidationContext(
CommonTlsContext commonTlsContext) {
if (commonTlsContext.hasValidationContext()) {
return commonTlsContext.getValidationContext();
} else if (commonTlsContext.hasCombinedValidationContext()) {
CommonTlsContext.CombinedCertificateValidationContext combinedValidationContext =
commonTlsContext.getCombinedValidationContext();
if (combinedValidationContext.hasDefaultValidationContext()) {
return combinedValidationContext.getDefaultValidationContext();
}
}
return null;
}

@Nullable
protected static CommonTlsContext.CertificateProviderInstance getRootCertProviderInstance(
CommonTlsContext commonTlsContext) {
CertificateValidationContext certValidationContext = getStaticValidationContext(
commonTlsContext);
if (certValidationContext != null && certValidationContext.hasCaCertificateProviderInstance()) {
return CommonTlsContextUtil.convert(certValidationContext.getCaCertificateProviderInstance());
}
if (commonTlsContext.hasCombinedValidationContext()) {
CommonTlsContext.CombinedCertificateValidationContext combinedValidationContext =
commonTlsContext.getCombinedValidationContext();
if (combinedValidationContext.hasValidationContextCertificateProviderInstance()) {
return combinedValidationContext.getValidationContextCertificateProviderInstance();
}
} else if (commonTlsContext.hasValidationContextCertificateProviderInstance()) {
return commonTlsContext.getValidationContextCertificateProviderInstance();
}
return null;
}

@Override
public final void updateCertificate(PrivateKey key, List<X509Certificate> certChain) {
savedKey = key;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@

package io.grpc.xds.internal.sds;

import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CertificateProviderPluginInstance;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext;
import io.envoyproxy.envoy.extensions.transport_sockets.tls.v3.CommonTlsContext.CombinedCertificateValidationContext;

/** Class for utility functions for {@link CommonTlsContext}. */
final class CommonTlsContextUtil {
public final class CommonTlsContextUtil {

private CommonTlsContextUtil() {}

Expand All @@ -38,4 +39,15 @@ private static boolean hasCertProviderValidationContext(CommonTlsContext commonT
}
return commonTlsContext.hasValidationContextCertificateProviderInstance();
}

/**
* Converts {@link CertificateProviderPluginInstance} to
* {@link CommonTlsContext.CertificateProviderInstance}.
*/
public static CommonTlsContext.CertificateProviderInstance convert(
CertificateProviderPluginInstance pluginInstance) {
return CommonTlsContext.CertificateProviderInstance.newBuilder()
.setInstanceName(pluginInstance.getInstanceName())
.setCertificateName(pluginInstance.getCertificateName()).build();
}
}
Loading

0 comments on commit 2260381

Please sign in to comment.