Skip to content

Commit

Permalink
Add a flat map writer key limit
Browse files Browse the repository at this point in the history
  • Loading branch information
sdruzkin authored and ARUNACHALAM THIRUPATHI committed Jun 30, 2022
1 parent d1b20f1 commit cd71cf0
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
import java.util.Set;

import static com.facebook.presto.orc.OrcWriterOptions.DEFAULT_MAX_COMPRESSION_BUFFER_SIZE;
import static com.facebook.presto.orc.OrcWriterOptions.DEFAULT_MAX_FLATTENED_MAP_KEY_COUNT;
import static com.facebook.presto.orc.OrcWriterOptions.DEFAULT_MAX_STRING_STATISTICS_LIMIT;
import static com.facebook.presto.orc.OrcWriterOptions.DEFAULT_PRESERVE_DIRECT_ENCODING_STRIPE_COUNT;
import static com.google.common.base.Preconditions.checkArgument;
import static io.airlift.units.DataSize.Unit.BYTE;
import static java.lang.Math.toIntExact;
import static java.util.Objects.requireNonNull;
Expand All @@ -43,6 +45,7 @@ public class ColumnWriterOptions
private final CompressionBufferPool compressionBufferPool;
private final Set<Integer> flattenedNodes;
private final boolean mapStatisticsEnabled;
private final int maxFlattenedMapKeyCount;

public ColumnWriterOptions(
CompressionKind compressionKind,
Expand All @@ -56,11 +59,14 @@ public ColumnWriterOptions(
int preserveDirectEncodingStripeCount,
CompressionBufferPool compressionBufferPool,
Set<Integer> flattenedNodes,
boolean mapStatisticsEnabled)
boolean mapStatisticsEnabled,
int maxFlattenedMapKeyCount)
{
checkArgument(maxFlattenedMapKeyCount > 0, "maxFlattenedMapKeyCount must be positive: %s", maxFlattenedMapKeyCount);
requireNonNull(compressionMaxBufferSize, "compressionMaxBufferSize is null");

this.compressionKind = requireNonNull(compressionKind, "compressionKind is null");
this.compressionLevel = requireNonNull(compressionLevel, "compressionLevel is null");
requireNonNull(compressionMaxBufferSize, "compressionMaxBufferSize is null");
this.compressionMaxBufferSize = toIntExact(compressionMaxBufferSize.toBytes());
this.stringStatisticsLimit = requireNonNull(stringStatisticsLimit, "stringStatisticsLimit is null");
this.integerDictionaryEncodingEnabled = integerDictionaryEncodingEnabled;
Expand All @@ -71,6 +77,7 @@ public ColumnWriterOptions(
this.compressionBufferPool = requireNonNull(compressionBufferPool, "compressionBufferPool is null");
this.flattenedNodes = requireNonNull(flattenedNodes, "flattenedNodes is null");
this.mapStatisticsEnabled = mapStatisticsEnabled;
this.maxFlattenedMapKeyCount = maxFlattenedMapKeyCount;
}

public CompressionKind getCompressionKind()
Expand Down Expand Up @@ -133,6 +140,11 @@ public boolean isMapStatisticsEnabled()
return mapStatisticsEnabled;
}

/**
 * Returns the flattened map key limit these options were created with.
 * The constructor rejects non-positive values, so the result is always greater than zero.
 */
public int getMaxFlattenedMapKeyCount()
{
    return this.maxFlattenedMapKeyCount;
}

/**
* Create a copy of this ColumnWriterOptions, but disable string and integer dictionary encodings.
*/
Expand All @@ -158,7 +170,8 @@ public Builder toBuilder()
.setPreserveDirectEncodingStripeCount(getPreserveDirectEncodingStripeCount())
.setCompressionBufferPool(getCompressionBufferPool())
.setFlattenedNodes(getFlattenedNodes())
.setMapStatisticsEnabled(isMapStatisticsEnabled());
.setMapStatisticsEnabled(isMapStatisticsEnabled())
.setMaxFlattenedMapKeyCount(getMaxFlattenedMapKeyCount());
}

public static Builder builder()
Expand All @@ -180,6 +193,7 @@ public static class Builder
private CompressionBufferPool compressionBufferPool = new LastUsedCompressionBufferPool();
private Set<Integer> flattenedNodes = ImmutableSet.of();
private boolean mapStatisticsEnabled;
private int maxFlattenedMapKeyCount = DEFAULT_MAX_FLATTENED_MAP_KEY_COUNT;

private Builder() {}

Expand Down Expand Up @@ -255,6 +269,12 @@ public Builder setMapStatisticsEnabled(boolean mapStatisticsEnabled)
return this;
}

/**
 * Sets the flattened map key limit to pass to the built {@link ColumnWriterOptions}.
 * When this method is not called the builder uses DEFAULT_MAX_FLATTENED_MAP_KEY_COUNT.
 * The value is not checked here; the ColumnWriterOptions constructor invoked by
 * {@link #build()} rejects values that are not positive.
 *
 * @param maxFlattenedMapKeyCount the key limit to record in this builder
 * @return this builder, to allow call chaining
 */
public Builder setMaxFlattenedMapKeyCount(int maxFlattenedMapKeyCount)
{
this.maxFlattenedMapKeyCount = maxFlattenedMapKeyCount;
return this;
}

public ColumnWriterOptions build()
{
return new ColumnWriterOptions(
Expand All @@ -269,7 +289,8 @@ public ColumnWriterOptions build()
preserveDirectEncodingStripeCount,
compressionBufferPool,
flattenedNodes,
mapStatisticsEnabled);
mapStatisticsEnabled,
maxFlattenedMapKeyCount);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ public OrcWriter(
.setCompressionBufferPool(compressionBufferPool)
.setFlattenedNodes(flattenedNodes)
.setMapStatisticsEnabled(options.isMapStatisticsEnabled())
.setMaxFlattenedMapKeyCount(options.getMaxFlattenedMapKeyCount())
.build();
recordValidation(validation -> validation.setCompression(compressionKind));
recordValidation(validation -> validation.setFlattenedNodes(flattenedNodes));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ public class OrcWriterOptions
public static final DataSize DEFAULT_DWRF_STRIPE_CACHE_MAX_SIZE = new DataSize(8, MEGABYTE);
public static final DwrfStripeCacheMode DEFAULT_DWRF_STRIPE_CACHE_MODE = INDEX_AND_FOOTER;
public static final int DEFAULT_PRESERVE_DIRECT_ENCODING_STRIPE_COUNT = 0;
public static final int DEFAULT_MAX_FLATTENED_MAP_KEY_COUNT = 20000;

private final OrcWriterFlushPolicy flushPolicy;
private final int rowGroupMaxRowCount;
Expand All @@ -65,6 +66,7 @@ public class OrcWriterOptions
private final Optional<DwrfStripeCacheOptions> dwrfWriterOptions;
private final int preserveDirectEncodingStripeCount;
private final boolean mapStatisticsEnabled;
private final int maxFlattenedMapKeyCount;

/**
* Contains indexes of columns (not nodes!) for which writer should use flattened encoding, e.g. flat maps.
Expand All @@ -89,7 +91,8 @@ private OrcWriterOptions(
boolean ignoreDictionaryRowGroupSizes,
int preserveDirectEncodingStripeCount,
Set<Integer> flattenedColumns,
boolean mapStatisticsEnabled)
boolean mapStatisticsEnabled,
int maxFlattenedMapKeyCount)
{
requireNonNull(flushPolicy, "flushPolicy is null");
checkArgument(rowGroupMaxRowCount >= 1, "rowGroupMaxRowCount must be at least 1");
Expand All @@ -102,6 +105,7 @@ private OrcWriterOptions(
requireNonNull(streamLayoutFactory, "streamLayoutFactory is null");
requireNonNull(dwrfWriterOptions, "dwrfWriterOptions is null");
requireNonNull(flattenedColumns, "flattenedColumns is null");
checkArgument(maxFlattenedMapKeyCount > 0, "maxFlattenedMapKeyCount must be positive: %s", maxFlattenedMapKeyCount);

this.flushPolicy = flushPolicy;
this.rowGroupMaxRowCount = rowGroupMaxRowCount;
Expand All @@ -121,6 +125,7 @@ private OrcWriterOptions(
this.preserveDirectEncodingStripeCount = preserveDirectEncodingStripeCount;
this.flattenedColumns = flattenedColumns;
this.mapStatisticsEnabled = mapStatisticsEnabled;
this.maxFlattenedMapKeyCount = maxFlattenedMapKeyCount;
}

public OrcWriterFlushPolicy getFlushPolicy()
Expand Down Expand Up @@ -213,6 +218,11 @@ public boolean isMapStatisticsEnabled()
return mapStatisticsEnabled;
}

/**
 * Returns the flattened map key limit configured for the writer.
 * The constructor rejects non-positive values, so the result is always greater than zero.
 */
public int getMaxFlattenedMapKeyCount()
{
    return this.maxFlattenedMapKeyCount;
}

@Override
public String toString()
{
Expand All @@ -235,6 +245,7 @@ public String toString()
.add("preserveDirectEncodingStripeCount", preserveDirectEncodingStripeCount)
.add("flattenedColumns", flattenedColumns)
.add("mapStatisticsEnabled", mapStatisticsEnabled)
.add("maxFlattenedMapKeyCount", maxFlattenedMapKeyCount)
.toString();
}

Expand Down Expand Up @@ -270,6 +281,7 @@ public static class Builder
private int preserveDirectEncodingStripeCount = DEFAULT_PRESERVE_DIRECT_ENCODING_STRIPE_COUNT;
private Set<Integer> flattenedColumns = ImmutableSet.of();
private boolean mapStatisticsEnabled;
private int maxFlattenedMapKeyCount = DEFAULT_MAX_FLATTENED_MAP_KEY_COUNT;

public Builder withFlushPolicy(OrcWriterFlushPolicy flushPolicy)
{
Expand Down Expand Up @@ -393,6 +405,12 @@ public Builder withMapStatisticsEnabled(boolean mapStatisticsEnabled)
return this;
}

/**
 * Sets the flattened map key limit to pass to the built {@link OrcWriterOptions}.
 * When this method is not called the builder uses DEFAULT_MAX_FLATTENED_MAP_KEY_COUNT.
 * The value is not checked here; the OrcWriterOptions constructor invoked by
 * {@link #build()} rejects values that are not positive.
 *
 * @param maxFlattenedMapKeyCount the key limit to record in this builder
 * @return this builder, to allow call chaining
 */
public Builder withMaxFlattenedMapKeyCount(int maxFlattenedMapKeyCount)
{
this.maxFlattenedMapKeyCount = maxFlattenedMapKeyCount;
return this;
}

public OrcWriterOptions build()
{
Optional<DwrfStripeCacheOptions> dwrfWriterOptions;
Expand Down Expand Up @@ -421,7 +439,8 @@ public OrcWriterOptions build()
ignoreDictionaryRowGroupSizes,
preserveDirectEncodingStripeCount,
flattenedColumns,
mapStatisticsEnabled);
mapStatisticsEnabled,
maxFlattenedMapKeyCount);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ public class MapFlatColumnWriter
private final PresentOutputStream presentStream;
private final CompressedMetadataWriter metadataWriter;
private final KeyManager keyManager;
private final int maxFlattenedMapKeyCount;

// Pre-create a value block with a single null value to avoid creating a block
// region for null values.
Expand Down Expand Up @@ -139,12 +140,14 @@ public MapFlatColumnWriter(
checkArgument(keyNodeIndex > 0, "keyNodeIndex is invalid: %s", keyNodeIndex);
checkArgument(valueNodeIndex > 0, "valueNodeIndex is invalid: %s", valueNodeIndex);
requireNonNull(keyStatisticsBuilderSupplier, "keyStatisticsBuilderSupplier is null");
checkArgument(columnWriterOptions.getMaxFlattenedMapKeyCount() > 0, "maxFlattenedMapKeyCount must be positive: %s", columnWriterOptions.getMaxFlattenedMapKeyCount());

this.nodeIndex = nodeIndex;
this.keyNodeIndex = keyNodeIndex;
this.valueNodeIndex = valueNodeIndex;
this.keyType = requireNonNull(keyType, "keyType is null");
this.nullValueBlock = createNullValueBlock(requireNonNull(valueType, "valueType is null"));
this.maxFlattenedMapKeyCount = columnWriterOptions.getMaxFlattenedMapKeyCount();

this.columnWriterOptions = requireNonNull(columnWriterOptions, "columnWriterOptions is null");
this.dwrfEncryptor = requireNonNull(dwrfEncryptor, "dwrfEncryptor is null");
Expand Down Expand Up @@ -418,6 +421,9 @@ public void reset()

private MapFlatValueWriter createNewValueWriter(DwrfProto.KeyInfo dwrfKey)
{
checkState(valueWriters.size() < maxFlattenedMapKeyCount - 1,
"Map column writer for node %s reached max allowed number of keys %s", nodeIndex, maxFlattenedMapKeyCount);

int valueWriterIdx = valueWriters.size();
int sequence = valueWriterIdx + SEQUENCE_START_INDEX;
ColumnWriter columnWriter = valueWriterFactory.apply(sequence);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ public void testProperties()
CompressionBufferPool compressionBufferPool = new CompressionBufferPool.LastUsedCompressionBufferPool();
Set<Integer> flattenedNodes = ImmutableSet.of(1, 5);
boolean mapStatisticsEnabled = true;
int maxFlattenedMapKeyCount = 27;

ColumnWriterOptions options = ColumnWriterOptions.builder()
.setCompressionKind(compressionKind)
Expand All @@ -58,6 +59,7 @@ public void testProperties()
.setCompressionBufferPool(compressionBufferPool)
.setFlattenedNodes(flattenedNodes)
.setMapStatisticsEnabled(mapStatisticsEnabled)
.setMaxFlattenedMapKeyCount(maxFlattenedMapKeyCount)
.build();

boolean checkDisabledDictionaryEncoding = false;
Expand All @@ -73,6 +75,7 @@ public void testProperties()
assertEquals(actual.getCompressionBufferPool(), compressionBufferPool);
assertEquals(actual.getFlattenedNodes(), flattenedNodes);
assertEquals(actual.isMapStatisticsEnabled(), mapStatisticsEnabled);
assertEquals(actual.getMaxFlattenedMapKeyCount(), maxFlattenedMapKeyCount);

if (checkDisabledDictionaryEncoding) {
assertFalse(actual.isStringDictionaryEncodingEnabled());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,23 @@
*/
package com.facebook.presto.orc;

import com.facebook.presto.common.Page;
import com.facebook.presto.common.block.BlockBuilder;
import com.facebook.presto.common.block.MapBlockBuilder;
import com.facebook.presto.common.type.MapType;
import com.facebook.presto.common.type.SqlVarbinary;
import com.facebook.presto.common.type.Type;
import com.facebook.presto.orc.metadata.CompressionKind;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import org.testng.annotations.Test;

import java.time.LocalDate;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;
import java.util.Random;

import static com.facebook.presto.common.type.BigintType.BIGINT;
Expand All @@ -35,6 +42,7 @@
import static com.facebook.presto.common.type.VarbinaryType.VARBINARY;
import static com.facebook.presto.common.type.VarcharType.VARCHAR;
import static com.facebook.presto.orc.OrcTester.arrayType;
import static com.facebook.presto.orc.OrcTester.createOrcWriter;
import static com.facebook.presto.orc.OrcTester.mapType;
import static com.facebook.presto.orc.OrcTester.rowType;
import static com.google.common.collect.Iterables.cycle;
Expand Down Expand Up @@ -213,6 +221,54 @@ private static <K, V> void runTest(Type mapType, K key1, K key2, K key3, V value
tester.testRoundTrip(mapType, newArrayList(limit(random(map1, map2, map3, map4, map5, nullValue, EMPTY_MAP), ROWS)));
}

/**
 * Checks that a flat map writer configured with a key limit accepts a page with
 * fewer distinct keys than the limit and fails with {@link IllegalStateException}
 * when a page carries as many distinct keys as the configured limit.
 */
@Test
public void testMaxKeyLimit()
        throws Exception
{
    int keyLimit = 3;
    MapType flatMapType = (MapType) mapType(INTEGER, INTEGER);

    // column 0 is written with the flattened map encoding and the reduced key limit
    OrcWriterOptions limitedOptions = OrcWriterOptions.builder()
            .withFlattenedColumns(ImmutableSet.of(0))
            .withMaxFlattenedMapKeyCount(keyLimit)
            .build();

    try (TempFile tempFile = new TempFile();
            OrcWriter orcWriter = createOrcWriter(
                    tempFile.getFile(),
                    OrcEncoding.DWRF,
                    CompressionKind.ZLIB,
                    Optional.empty(),
                    ImmutableList.of(flatMapType),
                    limitedOptions,
                    new NoOpOrcWriterStats())) {
        // one key fewer than the configured limit: the write must go through
        Page belowLimitPage = createMapPageForKeyLimitTest(flatMapType, keyLimit - 1);
        orcWriter.write(belowLimitPage);

        // as many keys as the configured limit: the writer is expected to reject the page
        expectThrows(
                IllegalStateException.class,
                () -> orcWriter.write(createMapPageForKeyLimitTest(flatMapType, keyLimit)));
    }
}

/**
 * Builds a page holding one map entry whose keys are {@code 0 .. keyCount - 1},
 * each key mapped to a value equal to itself.
 *
 * @param type map type used to create the block builder and to write keys and values
 * @param keyCount number of key/value pairs to put into the single map entry
 * @return a page wrapping the built map block
 */
private static Page createMapPageForKeyLimitTest(MapType type, int keyCount)
{
    MapBlockBuilder entryBuilder = (MapBlockBuilder) type.createBlockBuilder(null, 10);
    BlockBuilder keys = entryBuilder.getKeyBlockBuilder();
    BlockBuilder values = entryBuilder.getValueBlockBuilder();
    Type keyType = type.getKeyType();
    Type valueType = type.getValueType();

    // all pairs are written between a single begin/close pair, i.e. into one map entry
    entryBuilder.beginDirectEntry();
    int next = 0;
    while (next < keyCount) {
        keyType.writeLong(keys, next);
        valueType.writeLong(values, next);
        next++;
    }
    entryBuilder.closeEntry();

    return new Page(entryBuilder.build());
}

private static <T> Iterable<T> random(T... elements)
{
Random rnd = new Random(LocalDate.now().toEpochDay());
Expand Down
Loading

0 comments on commit cd71cf0

Please sign in to comment.