Skip to content

Commit

Permalink
fix: respect max messages per batch option (#512)
Browse files Browse the repository at this point in the history
* fix: respect max messages per batch option

* fix formatting
  • Loading branch information
hannahrogers-google authored Nov 17, 2022
1 parent ade1022 commit f436045
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import com.google.common.annotations.VisibleForTesting;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.concurrent.GuardedBy;
import org.apache.spark.sql.connector.read.InputPartition;
import org.apache.spark.sql.connector.read.PartitionReaderFactory;
import org.apache.spark.sql.connector.read.streaming.MicroBatchStream;
Expand All @@ -37,6 +38,9 @@ public class PslMicroBatchStream extends BaseDataStream implements MicroBatchStr
private final PerTopicHeadOffsetReader headOffsetReader;
private final PslReadDataSourceOptions options;

@GuardedBy("this")
private SparkSourceOffset lastEndOffset = null;

@VisibleForTesting
PslMicroBatchStream(
CursorClient cursorClient,
Expand All @@ -61,8 +65,17 @@ public class PslMicroBatchStream extends BaseDataStream implements MicroBatchStr
}

@Override
public SparkSourceOffset latestOffset() {
return PslSparkUtils.toSparkSourceOffset(headOffsetReader.getHeadOffset());
public synchronized SparkSourceOffset latestOffset() {
SparkSourceOffset newStartingOffset = (lastEndOffset == null) ? initialOffset() : lastEndOffset;
SparkSourceOffset headOffset =
PslSparkUtils.toSparkSourceOffset(headOffsetReader.getHeadOffset());
lastEndOffset =
PslSparkUtils.getSparkEndOffset(
headOffset,
newStartingOffset,
options.maxMessagesPerBatch(),
headOffset.getPartitionOffsetMap().size());
return lastEndOffset;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ public class PslMicroBatchStreamTest {
private static final PslReadDataSourceOptions OPTIONS =
PslReadDataSourceOptions.builder()
.setSubscriptionPath(UnitTestExamples.exampleSubscriptionPath())
.setMaxMessagesPerBatch(500)
.build();
private final CursorClient cursorClient = mock(CursorClient.class);
private final MultiPartitionCommitter committer = mock(MultiPartitionCommitter.class);
Expand All @@ -52,15 +53,26 @@ public class PslMicroBatchStreamTest {

@Test
public void testLatestOffset() {
// Called when retreiving the initial offset on the first 'latestOffset' call
when(partitionCountReader.getPartitionCount()).thenReturn(2);
when(cursorClient.listPartitionCursors(UnitTestExamples.exampleSubscriptionPath()))
.thenReturn(ApiFutures.immediateFuture(ImmutableMap.of()));
when(headOffsetReader.getHeadOffset()).thenReturn(createPslSourceOffset(301L, 200L));
// First return head offsets that will not exceed maxMessagesPerBatch, then exceed the limit
when(headOffsetReader.getHeadOffset())
.thenReturn(createPslSourceOffset(301L, 200L))
.thenReturn(createPslSourceOffset(1000L, 250L));
assertThat(((SparkSourceOffset) stream.latestOffset()).getPartitionOffsetMap())
.containsExactly(
Partition.of(0L),
SparkPartitionOffset.create(Partition.of(0L), 300L),
Partition.of(1L),
SparkPartitionOffset.create(Partition.of(1L), 199L));
assertThat(((SparkSourceOffset) stream.latestOffset()).getPartitionOffsetMap())
.containsExactly(
Partition.of(0L),
SparkPartitionOffset.create(Partition.of(0L), 800L),
Partition.of(1L),
SparkPartitionOffset.create(Partition.of(1L), 249L));
}

@Test
Expand Down

0 comments on commit f436045

Please sign in to comment.