From 7ed45230f37b6852fd07c6439f919ad1ea9b4367 Mon Sep 17 00:00:00 2001 From: Sam Gass Date: Tue, 9 Oct 2018 13:56:22 -0700 Subject: [PATCH 1/3] Added parallelization support to StreamIterator --- .../utilities/collections/Iterables.java | 14 ++++++ .../utilities/collections/StreamIterable.java | 49 ++++++++++++++++--- .../collections/StreamIterableTest.java | 26 +++++++++- 3 files changed, 80 insertions(+), 9 deletions(-) diff --git a/src/main/java/org/openstreetmap/atlas/utilities/collections/Iterables.java b/src/main/java/org/openstreetmap/atlas/utilities/collections/Iterables.java index a79398724c..b97d0c5a89 100644 --- a/src/main/java/org/openstreetmap/atlas/utilities/collections/Iterables.java +++ b/src/main/java/org/openstreetmap/atlas/utilities/collections/Iterables.java @@ -632,6 +632,20 @@ public static Optional nth(final Iterable types, final long i return Optional.ofNullable(result); } + /** + * Create a {@link StreamIterable} that uses parallelization + * + * @param source + * The {@link Iterable} to use as source + * @param + * The type of the source {@link Iterable} + * @return The corresponding {@link StreamIterable} + */ + public static StreamIterable parallelStream(final Iterable source) + { + return new StreamIterable<>(source, true); + } + public static void print(final Iterable input, final String name) { System.out.println(toString(input, name)); diff --git a/src/main/java/org/openstreetmap/atlas/utilities/collections/StreamIterable.java b/src/main/java/org/openstreetmap/atlas/utilities/collections/StreamIterable.java index 8c31ca279c..5551b73f34 100644 --- a/src/main/java/org/openstreetmap/atlas/utilities/collections/StreamIterable.java +++ b/src/main/java/org/openstreetmap/atlas/utilities/collections/StreamIterable.java @@ -24,6 +24,7 @@ public class StreamIterable implements Iterable { private final Iterable source; + private boolean parallel = false; protected StreamIterable(final Iterable source) { @@ -31,7 +32,21 @@ protected StreamIterable(final Iterable source) } /** - * Test whether all elements from iterable match the given predicate + * Construct a new StreamIterable. + * + * @param source + * The source iterable to construct the StreamIterable from + * @param parallel + * Controls whether to use parallelization or not when streaming + */ + protected StreamIterable(final Iterable source, final boolean parallel) + { + this.source = source; + this.parallel = parallel; + } + + /** + * Test whether all elements from iterable match the given predicate. * * @param predicate * Predicate to test @@ -39,7 +54,7 @@ protected StreamIterable(final Iterable source) */ public boolean allMatch(final Predicate predicate) { - return StreamSupport.stream(this.source.spliterator(), false).allMatch(predicate); + return StreamSupport.stream(this.source.spliterator(), this.parallel).allMatch(predicate); } /** @@ -51,7 +66,7 @@ public boolean allMatch(final Predicate predicate) */ public boolean anyMatch(final Predicate predicate) { - return StreamSupport.stream(this.source.spliterator(), false).anyMatch(predicate); + return StreamSupport.stream(this.source.spliterator(), this.parallel).anyMatch(predicate); } /** @@ -89,6 +104,22 @@ public SortedSet collectToSortedSet() return Iterables.asSortedSet(this.source); } + /** + * Disable parallelization in streams from this StreamIterator + */ + public void disableParallelization() + { + this.parallel = false; + } + + /** + * Enable parallelization in streams from this StreamIterator + */ + public void enableParallelization() + { + this.parallel = true; + } + /** * Filter an {@link Iterable} * @@ -98,7 +129,7 @@ public SortedSet collectToSortedSet() */ public StreamIterable filter(final Predicate filter) { - return new StreamIterable<>(Iterables.filter(this.source, filter)); + return new StreamIterable<>(Iterables.filter(this.source, filter), this.parallel); } /** @@ -116,7 +147,8 @@ public StreamIterable filter(final Set filte final Function identifier) { return new StreamIterable<>( - new FilteredIterable(this.source, filterSet, identifier)); + new FilteredIterable(this.source, filterSet, identifier), + this.parallel); } /** @@ -130,7 +162,7 @@ public StreamIterable filter(final Set filte */ public StreamIterable flatMap(final Function> flatMap) { - return new StreamIterable<>(Iterables.translateMulti(this.source, flatMap)); + return new StreamIterable<>(Iterables.translateMulti(this.source, flatMap), this.parallel); } @Override @@ -150,7 +182,7 @@ public Iterator iterator() */ public StreamIterable map(final Function map) { - return new StreamIterable<>(Iterables.translate(this.source, map)); + return new StreamIterable<>(Iterables.translate(this.source, map), this.parallel); } /** @@ -164,6 +196,7 @@ public StreamIterable map(final Function map) */ public StreamIterable truncate(final int startIndex, final int indexFromEnd) { - return new StreamIterable<>(Iterables.truncate(this.source, startIndex, indexFromEnd)); + return new StreamIterable<>(Iterables.truncate(this.source, startIndex, indexFromEnd), + this.parallel); } } diff --git a/src/test/java/org/openstreetmap/atlas/utilities/collections/StreamIterableTest.java b/src/test/java/org/openstreetmap/atlas/utilities/collections/StreamIterableTest.java index 4d6f5b0a0b..d6b7ea6ed6 100644 --- a/src/test/java/org/openstreetmap/atlas/utilities/collections/StreamIterableTest.java +++ b/src/test/java/org/openstreetmap/atlas/utilities/collections/StreamIterableTest.java @@ -22,10 +22,34 @@ public void testAllMatch() throws Exception Assert.assertFalse(streamIterable.allMatch(n -> n % 2 == 0)); } + @Test + public void testAllMatchParallel() throws Exception + { + final StreamIterable streamIterable = new StreamIterable<>(asList(1, 2, 3, 4)); + + // true -> all numbers are less than 5 + Assert.assertTrue(streamIterable.allMatch(n -> n < 5)); + // false -> all numbers are even + Assert.assertFalse(streamIterable.allMatch(n -> n % 2 == 0)); + } + @Test public void testAnyMatch() { - final StreamIterable streamIterable = new StreamIterable<>(asList(6, 12, 18)); + final StreamIterable streamIterable = new StreamIterable<>(asList(6, 12, 18), + true); + + // true -> any number that is divisible by 9 + Assert.assertTrue(streamIterable.anyMatch(n -> n % 9 == 0)); + // false -> any number that is divisible by 5 + Assert.assertFalse(streamIterable.allMatch(n -> n % 5 == 0)); + } + + @Test + public void testAnyMatchParallel() + { + final StreamIterable streamIterable = new StreamIterable<>(asList(6, 12, 18), + true); // true -> any number that is divisible by 9 Assert.assertTrue(streamIterable.anyMatch(n -> n % 9 == 0)); From e239cc141e1fa8da28cbdb337d0396455c0a43fd Mon Sep 17 00:00:00 2001 From: Sam Gass Date: Wed, 10 Oct 2018 09:45:20 -0700 Subject: [PATCH 2/3] PR updates again --- .../utilities/collections/StreamIterable.java | 10 ++++- .../collections/StreamIterableTest.java | 39 +++++++++++++++++-- 2 files changed, 44 insertions(+), 5 deletions(-) diff --git a/src/main/java/org/openstreetmap/atlas/utilities/collections/StreamIterable.java b/src/main/java/org/openstreetmap/atlas/utilities/collections/StreamIterable.java index 5551b73f34..60d46eb5aa 100644 --- a/src/main/java/org/openstreetmap/atlas/utilities/collections/StreamIterable.java +++ b/src/main/java/org/openstreetmap/atlas/utilities/collections/StreamIterable.java @@ -106,18 +106,24 @@ public SortedSet collectToSortedSet() /** * Disable parallelization in streams from this StreamIterator + * + * @return The StreamIterator with parallelization disabled */ - public void disableParallelization() + public StreamIterable disableParallelization() { this.parallel = false; + return this; } /** * Enable parallelization in streams from this StreamIterator + * + * @return The StreamIterator with parallelization enabled */ - public void enableParallelization() + public StreamIterable enableParallelization() { this.parallel = true; + return this; } /** diff --git a/src/test/java/org/openstreetmap/atlas/utilities/collections/StreamIterableTest.java b/src/test/java/org/openstreetmap/atlas/utilities/collections/StreamIterableTest.java index d6b7ea6ed6..749ab7da93 100644 --- a/src/test/java/org/openstreetmap/atlas/utilities/collections/StreamIterableTest.java +++ b/src/test/java/org/openstreetmap/atlas/utilities/collections/StreamIterableTest.java @@ -2,8 +2,16 @@ import static java.util.Arrays.asList; +import java.util.ArrayList; +import java.util.List; +import java.util.stream.IntStream; + import org.junit.Assert; import org.junit.Test; +import org.openstreetmap.atlas.utilities.scalars.Duration; +import org.openstreetmap.atlas.utilities.time.Time; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * @author Sagar Rohankar @@ -11,6 +19,8 @@ */ public class StreamIterableTest { + private static final Logger logger = LoggerFactory.getLogger(StreamIterableTest.class); + @Test public void testAllMatch() throws Exception { @@ -25,7 +35,8 @@ public void testAllMatch() throws Exception @Test public void testAllMatchParallel() throws Exception { - final StreamIterable streamIterable = new StreamIterable<>(asList(1, 2, 3, 4)); + final StreamIterable streamIterable = new StreamIterable<>(asList(1, 2, 3, 4), + true); // true -> all numbers are less than 5 Assert.assertTrue(streamIterable.allMatch(n -> n < 5)); @@ -36,8 +47,7 @@ public void testAllMatchParallel() throws Exception @Test public void testAnyMatch() { - final StreamIterable streamIterable = new StreamIterable<>(asList(6, 12, 18), - true); + final StreamIterable streamIterable = new StreamIterable<>(asList(6, 12, 18)); // true -> any number that is divisible by 9 Assert.assertTrue(streamIterable.anyMatch(n -> n % 9 == 0)); @@ -57,4 +67,27 @@ public void testAnyMatchParallel() Assert.assertFalse(streamIterable.allMatch(n -> n % 5 == 0)); } + @Test + public void testParallelPerformance() + { + final List numbers = new ArrayList<>(); + IntStream.range(0, 10000000).forEach(number -> + { + numbers.add(number); + }); + final StreamIterable streamIterable = new StreamIterable<>(numbers) + .disableParallelization(); + Time currentTime = Time.now(); + streamIterable.anyMatch(n -> n > 100000); + final Duration sequentialDuration = currentTime.elapsedSince(); + logger.debug("Sequential duration was {} ms", sequentialDuration.asMilliseconds()); + + currentTime = Time.now(); + streamIterable.enableParallelization(); + streamIterable.anyMatch(n -> n > 100000); + final Duration parallelDuration = currentTime.elapsedSince(); + logger.debug("Parallel duration was {} ms", parallelDuration.asMilliseconds()); + Assert.assertTrue(parallelDuration.isLessThan(sequentialDuration)); + } + } From 2d15583d085d6968073a57e29b35e51dc97df679 Mon Sep 17 00:00:00 2001 From: Sam Gass Date: Wed, 10 Oct 2018 11:02:18 -0700 Subject: [PATCH 3/3] Added in comments around thread safety --- .../atlas/utilities/collections/StreamIterable.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/main/java/org/openstreetmap/atlas/utilities/collections/StreamIterable.java b/src/main/java/org/openstreetmap/atlas/utilities/collections/StreamIterable.java index 60d46eb5aa..87e036b171 100644 --- a/src/main/java/org/openstreetmap/atlas/utilities/collections/StreamIterable.java +++ b/src/main/java/org/openstreetmap/atlas/utilities/collections/StreamIterable.java @@ -15,7 +15,7 @@ *

* * Iterables.stream(someIterable).map(...).filter(...).collect(); - * + * Note: StreamIterable is not thread safe with parallelization usage. * * @author matthieun * @param