Skip to content

Commit

Permalink
Added parallelization support to StreamIterator (#234)
Browse files Browse the repository at this point in the history
* Added parallelization support to StreamIterator

* PR updates again

* Added in comments around thread safety
  • Loading branch information
adahn6 authored and MikeGost committed Oct 10, 2018
1 parent 6d73202 commit ec8d9f8
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,20 @@ public static <Type> Optional<Type> nth(final Iterable<Type> types, final long i
return Optional.ofNullable(result);
}

/**
* Create a {@link StreamIterable} that uses parallelization
*
* @param source
* The {@link Iterable} to use as source
* @param <Type>
* The type of the source {@link Iterable}
* @return The corresponding {@link StreamIterable}
*/
public static <Type> StreamIterable<Type> parallelStream(final Iterable<Type> source)
{
return new StreamIterable<>(source, true);
}

public static <T> void print(final Iterable<T> input, final String name)
{
System.out.println(toString(input, name));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* <p>
* <code>
* Iterables.stream(someIterable).map(...).filter(...).collect();
* </code>
* </code> Note: StreamIterable is not thread safe with parallelization usage.
*
* @author matthieun
* @param <T>
Expand All @@ -24,22 +24,37 @@
public class StreamIterable<T> implements Iterable<T>
{
private final Iterable<T> source;
private boolean parallel = false;

protected StreamIterable(final Iterable<T> source)
{
this.source = source;
}

/**
* Test whether all elements from iterable match the given predicate
* Construct a new StreamIterable.
*
* @param source
* The source iterable to construct the StreamIterable from
* @param parallel
* Controls whether to use parallelization or not when streaming
*/
protected StreamIterable(final Iterable<T> source, final boolean parallel)
{
this.source = source;
this.parallel = parallel;
}

/**
* Test whether all elements from iterable match the given predicate.
*
* @param predicate
* Predicate to test
* @return {@code true} when given predicate is true for all entities in iterable, else false
*/
public boolean allMatch(final Predicate<T> predicate)
{
return StreamSupport.stream(this.source.spliterator(), false).allMatch(predicate);
return StreamSupport.stream(this.source.spliterator(), this.parallel).allMatch(predicate);
}

/**
Expand All @@ -51,7 +66,7 @@ public boolean allMatch(final Predicate<T> predicate)
*/
public boolean anyMatch(final Predicate<T> predicate)
{
return StreamSupport.stream(this.source.spliterator(), false).anyMatch(predicate);
return StreamSupport.stream(this.source.spliterator(), this.parallel).anyMatch(predicate);
}

/**
Expand Down Expand Up @@ -89,6 +104,28 @@ public SortedSet<T> collectToSortedSet()
return Iterables.asSortedSet(this.source);
}

/**
* Disable parallelization in streams from this StreamIterator
*
* @return The StreamIterator with parallelization disabled
*/
public StreamIterable<T> disableParallelization()
{
this.parallel = false;
return this;
}

/**
* Enable parallelization in streams from this StreamIterator
*
* @return The StreamIterator with parallelization enabled
*/
public StreamIterable<T> enableParallelization()
{
this.parallel = true;
return this;
}

/**
* Filter an {@link Iterable}
*
Expand All @@ -98,7 +135,7 @@ public SortedSet<T> collectToSortedSet()
*/
public StreamIterable<T> filter(final Predicate<T> filter)
{
return new StreamIterable<>(Iterables.filter(this.source, filter));
return new StreamIterable<>(Iterables.filter(this.source, filter), this.parallel);
}

/**
Expand All @@ -116,7 +153,8 @@ public <IdentifierType> StreamIterable<T> filter(final Set<IdentifierType> filte
final Function<T, IdentifierType> identifier)
{
return new StreamIterable<>(
new FilteredIterable<T, IdentifierType>(this.source, filterSet, identifier));
new FilteredIterable<T, IdentifierType>(this.source, filterSet, identifier),
this.parallel);
}

/**
Expand All @@ -130,7 +168,7 @@ public <IdentifierType> StreamIterable<T> filter(final Set<IdentifierType> filte
*/
public <V> StreamIterable<V> flatMap(final Function<T, Iterable<? extends V>> flatMap)
{
return new StreamIterable<>(Iterables.translateMulti(this.source, flatMap));
return new StreamIterable<>(Iterables.translateMulti(this.source, flatMap), this.parallel);
}

@Override
Expand All @@ -150,7 +188,7 @@ public Iterator<T> iterator()
*/
public <V> StreamIterable<V> map(final Function<T, V> map)
{
return new StreamIterable<>(Iterables.translate(this.source, map));
return new StreamIterable<>(Iterables.translate(this.source, map), this.parallel);
}

/**
Expand All @@ -164,6 +202,7 @@ public <V> StreamIterable<V> map(final Function<T, V> map)
*/
public StreamIterable<T> truncate(final int startIndex, final int indexFromEnd)
{
return new StreamIterable<>(Iterables.truncate(this.source, startIndex, indexFromEnd));
return new StreamIterable<>(Iterables.truncate(this.source, startIndex, indexFromEnd),
this.parallel);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,25 @@

import static java.util.Arrays.asList;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.IntStream;

import org.junit.Assert;
import org.junit.Test;
import org.openstreetmap.atlas.utilities.scalars.Duration;
import org.openstreetmap.atlas.utilities.time.Time;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* @author Sagar Rohankar
* @author mgostintsev
*/
public class StreamIterableTest
{
private static final Logger logger = LoggerFactory.getLogger(StreamIterableTest.class);

@Test
public void testAllMatch() throws Exception
{
Expand All @@ -22,6 +32,18 @@ public void testAllMatch() throws Exception
Assert.assertFalse(streamIterable.allMatch(n -> n % 2 == 0));
}

@Test
public void testAllMatchParallel() throws Exception
{
final StreamIterable<Integer> streamIterable = new StreamIterable<>(asList(1, 2, 3, 4),
true);

// true -> all numbers are less than 5
Assert.assertTrue(streamIterable.allMatch(n -> n < 5));
// false -> all numbers are even
Assert.assertFalse(streamIterable.allMatch(n -> n % 2 == 0));
}

@Test
public void testAnyMatch()
{
Expand All @@ -33,4 +55,39 @@ public void testAnyMatch()
Assert.assertFalse(streamIterable.allMatch(n -> n % 5 == 0));
}

@Test
public void testAnyMatchParallel()
{
final StreamIterable<Integer> streamIterable = new StreamIterable<>(asList(6, 12, 18),
true);

// true -> any number that is divisible by 9
Assert.assertTrue(streamIterable.anyMatch(n -> n % 9 == 0));
// false -> any number that is divisible by 5
Assert.assertFalse(streamIterable.allMatch(n -> n % 5 == 0));
}

@Test
public void testParallelPerformance()
{
final List<Integer> numbers = new ArrayList<>();
IntStream.range(0, 10000000).forEach(number ->
{
numbers.add(number);
});
final StreamIterable<Integer> streamIterable = new StreamIterable<>(numbers)
.disableParallelization();
Time currentTime = Time.now();
streamIterable.anyMatch(n -> n > 100000);
final Duration sequentialDuration = currentTime.elapsedSince();
logger.debug("Sequential duration was {} ms", sequentialDuration.asMilliseconds());

currentTime = Time.now();
streamIterable.enableParallelization();
streamIterable.anyMatch(n -> n > 100000);
final Duration parallelDuration = currentTime.elapsedSince();
logger.debug("Parallel duration was {} ms", parallelDuration.asMilliseconds());
Assert.assertTrue(parallelDuration.isLessThan(sequentialDuration));
}

}

0 comments on commit ec8d9f8

Please sign in to comment.