Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added parallelization support to StreamIterator #234

Merged
merged 3 commits into from
Oct 10, 2018
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,20 @@ public static <Type> Optional<Type> nth(final Iterable<Type> types, final long i
return Optional.ofNullable(result);
}

/**
* Create a {@link StreamIterable} that uses parallelization
*
* @param source
* The {@link Iterable} to use as source
* @param <Type>
* The type of the source {@link Iterable}
* @return The corresponding {@link StreamIterable}
*/
public static <Type> StreamIterable<Type> parallelStream(final Iterable<Type> source)
{
return new StreamIterable<>(source, true);
}

public static <T> void print(final Iterable<T> input, final String name)
{
System.out.println(toString(input, name));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,37 @@
public class StreamIterable<T> implements Iterable<T>
{
private final Iterable<T> source;
private boolean parallel = false;

protected StreamIterable(final Iterable<T> source)
{
this.source = source;
}

/**
* Test whether all elements from iterable match the given predicate
* Construct a new StreamIterable.
*
* @param source
* The source iterable to construct the StreamIterable from
* @param parallel
* Controls whether to use parallelization or not when streaming
*/
protected StreamIterable(final Iterable<T> source, final boolean parallel)
{
this.source = source;
this.parallel = parallel;
}

/**
* Test whether all elements from iterable match the given predicate.
*
* @param predicate
* Predicate to test
* @return {@code true} when given predicate is true for all entities in iterable, else false
*/
public boolean allMatch(final Predicate<T> predicate)
{
return StreamSupport.stream(this.source.spliterator(), false).allMatch(predicate);
return StreamSupport.stream(this.source.spliterator(), this.parallel).allMatch(predicate);
}

/**
Expand All @@ -51,7 +66,7 @@ public boolean allMatch(final Predicate<T> predicate)
*/
public boolean anyMatch(final Predicate<T> predicate)
{
return StreamSupport.stream(this.source.spliterator(), false).anyMatch(predicate);
return StreamSupport.stream(this.source.spliterator(), this.parallel).anyMatch(predicate);
}

/**
Expand Down Expand Up @@ -89,6 +104,28 @@ public SortedSet<T> collectToSortedSet()
return Iterables.asSortedSet(this.source);
}

/**
* Disable parallelization in streams from this StreamIterator
*
* @return The StreamIterator with parallelization disabled
*/
public StreamIterable<T> disableParallelization()
{
this.parallel = false;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thoughts on withParallelization() and withoutParallelization() that toggles the flag and return the StreamIterable object back? Might be easier to read inline.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

agreed!

Copy link
Contributor

@lucaspcram lucaspcram Oct 10, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was also planning on mentioning this, but you beat me to it. I do have one thought about this solution - it debilitates thread safety for shared instances of StreamIterable. The safest possible form of this API would be to keep the this.parallel variable but completely remove the enableParallelization and disableParallelization methods.

All that being said, I suppose it does not make too much sense for multiple threads to ever share a StreamIterable instance in practice - and the race condition that is introduced would simply involve one thread unintentionally using the parallel version of a StreamIterable method when it preempts another thread that just called enableParallelization() on the shared StreamIterable instance.

If we are fine with keeping this mutability (which I ultimately am, sorry for being pedantic), I would also prefer @MikeGost 's version for ease of use.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

OK. it's also pretty easy to just have it return a new StreamIterable object instead of changing the switch on the current object-- slightly worse for memory usage, but probably nbd in the grand scheme of things, and definitely more thread safe. Opinions?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would keep it as is, echoing @lucaspcram's point above about having multiple threads sharing a StreamIterable not making much sense. However, if we have a fairly large StreamIterable, that could potentially be a significant memory usage impact. Might be worth putting a note about this not being thread-safe in the parallelization case.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with @MikeGost . Let's leave as is, but just note in the docstring that the class is not completely safe when shared between multiple threads (which of course, would be very strange way to use the class anyway).

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍 cool, I'm on board with that. I added in a note to the docstring-- i feel like a shorter warning is more likely to be seen than a long explanation, so I kept it simple.

return this;
}

/**
* Enable parallelization in streams from this StreamIterator
*
* @return The StreamIterator with parallelization enabled
*/
public StreamIterable<T> enableParallelization()
{
this.parallel = true;
return this;
}

/**
* Filter an {@link Iterable}
*
Expand All @@ -98,7 +135,7 @@ public SortedSet<T> collectToSortedSet()
*/
public StreamIterable<T> filter(final Predicate<T> filter)
{
return new StreamIterable<>(Iterables.filter(this.source, filter));
return new StreamIterable<>(Iterables.filter(this.source, filter), this.parallel);
}

/**
Expand All @@ -116,7 +153,8 @@ public <IdentifierType> StreamIterable<T> filter(final Set<IdentifierType> filte
final Function<T, IdentifierType> identifier)
{
return new StreamIterable<>(
new FilteredIterable<T, IdentifierType>(this.source, filterSet, identifier));
new FilteredIterable<T, IdentifierType>(this.source, filterSet, identifier),
this.parallel);
}

/**
Expand All @@ -130,7 +168,7 @@ public <IdentifierType> StreamIterable<T> filter(final Set<IdentifierType> filte
*/
public <V> StreamIterable<V> flatMap(final Function<T, Iterable<? extends V>> flatMap)
{
return new StreamIterable<>(Iterables.translateMulti(this.source, flatMap));
return new StreamIterable<>(Iterables.translateMulti(this.source, flatMap), this.parallel);
}

@Override
Expand All @@ -150,7 +188,7 @@ public Iterator<T> iterator()
*/
public <V> StreamIterable<V> map(final Function<T, V> map)
{
return new StreamIterable<>(Iterables.translate(this.source, map));
return new StreamIterable<>(Iterables.translate(this.source, map), this.parallel);
}

/**
Expand All @@ -164,6 +202,7 @@ public <V> StreamIterable<V> map(final Function<T, V> map)
*/
public StreamIterable<T> truncate(final int startIndex, final int indexFromEnd)
{
return new StreamIterable<>(Iterables.truncate(this.source, startIndex, indexFromEnd));
return new StreamIterable<>(Iterables.truncate(this.source, startIndex, indexFromEnd),
this.parallel);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,25 @@

import static java.util.Arrays.asList;

import java.util.ArrayList;
import java.util.List;
import java.util.stream.IntStream;

import org.junit.Assert;
import org.junit.Test;
import org.openstreetmap.atlas.utilities.scalars.Duration;
import org.openstreetmap.atlas.utilities.time.Time;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* @author Sagar Rohankar
* @author mgostintsev
*/
public class StreamIterableTest
{
private static final Logger logger = LoggerFactory.getLogger(StreamIterableTest.class);

@Test
public void testAllMatch() throws Exception
{
Expand All @@ -22,6 +32,18 @@ public void testAllMatch() throws Exception
Assert.assertFalse(streamIterable.allMatch(n -> n % 2 == 0));
}

@Test
public void testAllMatchParallel() throws Exception
{
final StreamIterable<Integer> streamIterable = new StreamIterable<>(asList(1, 2, 3, 4),
true);

// true -> all numbers are less than 5
Assert.assertTrue(streamIterable.allMatch(n -> n < 5));
// false -> all numbers are even
Assert.assertFalse(streamIterable.allMatch(n -> n % 2 == 0));
}

@Test
public void testAnyMatch()
{
Expand All @@ -33,4 +55,39 @@ public void testAnyMatch()
Assert.assertFalse(streamIterable.allMatch(n -> n % 5 == 0));
}

@Test
public void testAnyMatchParallel()
{
final StreamIterable<Integer> streamIterable = new StreamIterable<>(asList(6, 12, 18),
true);

// true -> any number that is divisible by 9
Assert.assertTrue(streamIterable.anyMatch(n -> n % 9 == 0));
// false -> any number that is divisible by 5
Assert.assertFalse(streamIterable.allMatch(n -> n % 5 == 0));
}

@Test
public void testParallelPerformance()
{
final List<Integer> numbers = new ArrayList<>();
IntStream.range(0, 10000000).forEach(number ->
{
numbers.add(number);
});
final StreamIterable<Integer> streamIterable = new StreamIterable<>(numbers)
.disableParallelization();
Time currentTime = Time.now();
streamIterable.anyMatch(n -> n > 100000);
final Duration sequentialDuration = currentTime.elapsedSince();
logger.debug("Sequential duration was {} ms", sequentialDuration.asMilliseconds());

currentTime = Time.now();
streamIterable.enableParallelization();
streamIterable.anyMatch(n -> n > 100000);
final Duration parallelDuration = currentTime.elapsedSince();
logger.debug("Parallel duration was {} ms", parallelDuration.asMilliseconds());
Assert.assertTrue(parallelDuration.isLessThan(sequentialDuration));
}

}