Skip to content

Commit

Permalink
Change InteopTester to take a ManagedChannel directly instead of usin…
Browse files Browse the repository at this point in the history
…g all the args to build a channel

Create a helper class for building okhttp based channels in android tester app
  • Loading branch information
zsurocking committed Feb 29, 2016
1 parent 643bb2c commit 5d025ed
Show file tree
Hide file tree
Showing 4 changed files with 158 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,7 @@
import com.google.protobuf.nano.EmptyProtos;
import com.google.protobuf.nano.MessageNano;

import android.annotation.TargetApi;
import android.net.SSLCertificateSocketFactory;
import android.os.AsyncTask;
import android.os.Build;
import android.support.annotation.Nullable;
import android.util.Log;

import static junit.framework.Assert.assertEquals;
Expand Down Expand Up @@ -68,31 +64,21 @@
import io.grpc.stub.StreamObserver;
import io.grpc.testing.StreamRecorder;

import java.io.InputStream;
import java.io.PrintWriter;
import java.io.StringWriter;
import java.lang.RuntimeException;
import java.lang.reflect.Method;
import java.security.KeyStore;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.ArrayBlockingQueue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.security.auth.x500.X500Principal;

/**
* Implementation of the integration tests, as an AsyncTask.
*/
public final class InteropTester extends AsyncTask<Void, Void, String> {
final static String SUCCESS_MESSAGE = "Succeed!!!";
final static String LOG_TAG = "GrpcTest";

private ManagedChannel channel;
private TestServiceGrpc.TestServiceBlockingStub blockingStub;
Expand All @@ -112,7 +98,7 @@ public void onNext(Messages.StreamingOutputCallResponse value) {

@Override
public void onError(Throwable t) {
Log.e(TesterActivity.LOG_TAG, "Encounter an error", t);
Log.e(LOG_TAG, "Encounter an error", t);
responses.add(t);
}

Expand All @@ -124,39 +110,11 @@ public void onCompleted() {


public InteropTester(String testCase,
String host,
int port,
@Nullable String serverHostOverride,
boolean useTls,
@Nullable InputStream testCa,
@Nullable String androidSocketFactoryTls,
ManagedChannel channel,
TestListener listener) {
this.testCase = testCase;
this.listener = listener;

ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forAddress(host, port);
if (serverHostOverride != null) {
// Force the hostname to match the cert the server uses.
channelBuilder.overrideAuthority(serverHostOverride);
}
if (useTls) {
try {
SSLSocketFactory factory;
if (androidSocketFactoryTls != null) {
factory = getSslCertificateSocketFactory(testCa, androidSocketFactoryTls);
} else {
factory = getSslSocketFactory(testCa);
}
((OkHttpChannelBuilder) channelBuilder).negotiationType(NegotiationType.TLS);
((OkHttpChannelBuilder) channelBuilder).sslSocketFactory(factory);
} catch (Exception e) {
throw new RuntimeException(e);
}
} else {
channelBuilder.usePlaintext(true);
}

channel = channelBuilder.build();
this.channel = channel;
blockingStub = TestServiceGrpc.newBlockingStub(channel);
asyncStub = TestServiceGrpc.newStub(channel);
}
Expand Down Expand Up @@ -194,7 +152,7 @@ public void shutdown() {
}

public void runTest(String testCase) throws Exception {
Log.i(TesterActivity.LOG_TAG, "Running test " + testCase);
Log.i(LOG_TAG, "Running test " + testCase);
if ("all".equals(testCase)) {
runTest("empty_unary");
runTest("large_unary");
Expand Down Expand Up @@ -778,61 +736,6 @@ private static void assertCodeEquals(io.grpc.Status expected, io.grpc.Status act
}
}

private SSLSocketFactory getSslSocketFactory(@Nullable InputStream testCa) throws Exception {
if (testCa == null) {
return (SSLSocketFactory) SSLSocketFactory.getDefault();
}

SSLContext context = SSLContext.getInstance("TLS");
context.init(null, getTrustManagers(testCa) , null);
return context.getSocketFactory();
}

@TargetApi(14)
private SSLCertificateSocketFactory getSslCertificateSocketFactory(
@Nullable InputStream testCa, String androidSocketFatoryTls) throws Exception {
if (Build.VERSION.SDK_INT < Build.VERSION_CODES.ICE_CREAM_SANDWICH /* API level 14 */) {
throw new RuntimeException(
"android_socket_factory_tls doesn't work with API level less than 14.");
}
SSLCertificateSocketFactory factory = (SSLCertificateSocketFactory)
SSLCertificateSocketFactory.getDefault(5000 /* Timeout in ms*/);
// Use HTTP/2.0
byte[] h2 = "h2".getBytes();
byte[][] protocols = new byte[][]{h2};
if (androidSocketFatoryTls.equals("alpn")) {
Method setAlpnProtocols =
factory.getClass().getDeclaredMethod("setAlpnProtocols", byte[][].class);
setAlpnProtocols.invoke(factory, new Object[] { protocols });
} else if (androidSocketFatoryTls.equals("npn")) {
Method setNpnProtocols =
factory.getClass().getDeclaredMethod("setNpnProtocols", byte[][].class);
setNpnProtocols.invoke(factory, new Object[]{protocols});
} else {
throw new RuntimeException("Unknown protocol: " + androidSocketFatoryTls);
}

if (testCa != null) {
factory.setTrustManagers(getTrustManagers(testCa));
}

return factory;
}

private TrustManager[] getTrustManagers(InputStream testCa) throws Exception {
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
ks.load(null);
CertificateFactory cf = CertificateFactory.getInstance("X.509");
X509Certificate cert = (X509Certificate) cf.generateCertificate(testCa);
X500Principal principal = cert.getSubjectX500Principal();
ks.setCertificateEntry(principal.getName("RFC2253"), cert);
// Set up trust manager factory to use our key store.
TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(ks);
return trustManagerFactory.getTrustManagers();
}

public interface TestListener {
void onPreTest();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,6 @@

public class TesterActivity extends AppCompatActivity
implements ProviderInstaller.ProviderInstallListener {
final static String LOG_TAG = "GrpcTest";
private List<Button> buttons;
private EditText hostEdit;
private EditText portEdit;
Expand Down Expand Up @@ -111,8 +110,9 @@ private void startTest(String testCase) {
int port = TextUtils.isEmpty(portStr) ? 8080 : Integer.valueOf(portStr);

// TODO (madongfly) support server_host_override, useTls and useTestCa in the App UI.
new InteropTester(testCase, host, port, "foo.test.google.fr", true,
getResources().openRawResource(R.raw.ca), null,
new InteropTester(testCase,
TesterOkHttpChannelBuilder.build(host, port, "foo.test.google.fr", true,
getResources().openRawResource(R.raw.ca), null),
new InteropTester.TestListener() {
@Override public void onPreTest() {
resultText.setText("Testing...");
Expand All @@ -135,7 +135,7 @@ public void onProviderInstalled() {
public void onProviderInstallFailed(int errorCode, Intent recoveryIntent) {
// The provider is helpful, but it is possible to succeed without it.
// Hope that the system-provided libraries are new enough.
Log.w(LOG_TAG, "Failed installing security provider, error code: " + errorCode);
Log.w(InteropTester.LOG_TAG, "Failed installing security provider, error code: " + errorCode);
enableButtons(true);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -80,17 +80,18 @@ public void onCreate(Bundle args) {
} catch (GooglePlayServicesRepairableException e) {
// The provider is helpful, but it is possible to succeed without it.
// Hope that the system-provided libraries are new enough.
Log.w(TesterActivity.LOG_TAG, "Failed installing security provider", e);
Log.w(InteropTester.LOG_TAG, "Failed installing security provider", e);
} catch (GooglePlayServicesNotAvailableException e) {
// The provider is helpful, but it is possible to succeed without it.
// Hope that the system-provided libraries are new enough.
Log.w(TesterActivity.LOG_TAG, "Failed installing security provider", e);
Log.w(InteropTester.LOG_TAG, "Failed installing security provider", e);
}
}

try {
new InteropTester(testCase, host, port, serverHostOverride, useTls, testCa,
androidSocketFactoryTls,
new InteropTester(testCase,
TesterOkHttpChannelBuilder.build(host, port, serverHostOverride, useTls, testCa,
androidSocketFactoryTls),
new InteropTester.TestListener() {
@Override
public void onPreTest() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
/*
* Copyright 2016, Google Inc. All rights reserved.
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are
* met:
*
* * Redistributions of source code must retain the above copyright
* notice, this list of conditions and the following disclaimer.
* * Redistributions in binary form must reproduce the above
* copyright notice, this list of conditions and the following disclaimer
* in the documentation and/or other materials provided with the
* distribution.
*
* * Neither the name of Google Inc. nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

package io.grpc.android.integrationtest;

import com.google.common.base.Preconditions;

import android.annotation.TargetApi;
import android.net.SSLCertificateSocketFactory;
import android.os.Build;
import android.support.annotation.Nullable;

import io.grpc.ManagedChannel;
import io.grpc.ManagedChannelBuilder;
import io.grpc.okhttp.OkHttpChannelBuilder;
import io.grpc.okhttp.NegotiationType;

import java.io.InputStream;
import java.lang.RuntimeException;
import java.lang.reflect.Method;
import java.security.KeyStore;
import java.security.cert.CertificateFactory;
import java.security.cert.X509Certificate;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.security.auth.x500.X500Principal;

/**
* A helper class to create a OkHttp based channel.
*/
public class TesterOkHttpChannelBuilder {
static public ManagedChannel build(String host, int port, @Nullable String serverHostOverride,
boolean useTls, @Nullable InputStream testCa, @Nullable String androidSocketFactoryTls) {
ManagedChannelBuilder<?> channelBuilder = ManagedChannelBuilder.forAddress(host, port);
if (serverHostOverride != null) {
// Force the hostname to match the cert the server uses.
channelBuilder.overrideAuthority(serverHostOverride);
}
if (useTls) {
try {
SSLSocketFactory factory;
if (androidSocketFactoryTls != null) {
factory = getSslCertificateSocketFactory(testCa, androidSocketFactoryTls);
} else {
factory = getSslSocketFactory(testCa);
}
((OkHttpChannelBuilder) channelBuilder).negotiationType(NegotiationType.TLS);
((OkHttpChannelBuilder) channelBuilder).sslSocketFactory(factory);
} catch (Exception e) {
throw new RuntimeException(e);
}
} else {
channelBuilder.usePlaintext(true);
}
return channelBuilder.build();
}

static private SSLSocketFactory getSslSocketFactory(@Nullable InputStream testCa) throws Exception {
if (testCa == null) {
return (SSLSocketFactory) SSLSocketFactory.getDefault();
}

SSLContext context = SSLContext.getInstance("TLS");
context.init(null, getTrustManagers(testCa) , null);
return context.getSocketFactory();
}

@TargetApi(14)
static private SSLCertificateSocketFactory getSslCertificateSocketFactory(
@Nullable InputStream testCa, String androidSocketFatoryTls) throws Exception {
if (Build.VERSION.SDK_INT < Build.VERSION_CODES.ICE_CREAM_SANDWICH /* API level 14 */) {
throw new RuntimeException(
"android_socket_factory_tls doesn't work with API level less than 14.");
}
SSLCertificateSocketFactory factory = (SSLCertificateSocketFactory)
SSLCertificateSocketFactory.getDefault(5000 /* Timeout in ms*/);
// Use HTTP/2.0
byte[] h2 = "h2".getBytes();
byte[][] protocols = new byte[][]{h2};
if (androidSocketFatoryTls.equals("alpn")) {
Method setAlpnProtocols =
factory.getClass().getDeclaredMethod("setAlpnProtocols", byte[][].class);
setAlpnProtocols.invoke(factory, new Object[] { protocols });
} else if (androidSocketFatoryTls.equals("npn")) {
Method setNpnProtocols =
factory.getClass().getDeclaredMethod("setNpnProtocols", byte[][].class);
setNpnProtocols.invoke(factory, new Object[]{protocols});
} else {
throw new RuntimeException("Unknown protocol: " + androidSocketFatoryTls);
}

if (testCa != null) {
factory.setTrustManagers(getTrustManagers(testCa));
}

return factory;
}

static private TrustManager[] getTrustManagers(InputStream testCa) throws Exception {
KeyStore ks = KeyStore.getInstance(KeyStore.getDefaultType());
ks.load(null);
CertificateFactory cf = CertificateFactory.getInstance("X.509");
X509Certificate cert = (X509Certificate) cf.generateCertificate(testCa);
X500Principal principal = cert.getSubjectX500Principal();
ks.setCertificateEntry(principal.getName("RFC2253"), cert);
// Set up trust manager factory to use our key store.
TrustManagerFactory trustManagerFactory =
TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
trustManagerFactory.init(ks);
return trustManagerFactory.getTrustManagers();
}
}

0 comments on commit 5d025ed

Please sign in to comment.