Skip to content

Commit

Permalink
Merge pull request #9392 from lnkuiper/parquet_encryption
Browse files Browse the repository at this point in the history
Parquet Encryption
  • Loading branch information
Mytherin authored Nov 8, 2023
2 parents fa158b2 + 2ece048 commit 4915dd7
Show file tree
Hide file tree
Showing 51 changed files with 13,782 additions and 233 deletions.
24 changes: 16 additions & 8 deletions extension/parquet/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,29 @@ cmake_minimum_required(VERSION 3.5)
project(ParquetExtension)

include_directories(
include ../../third_party/parquet ../../third_party/snappy
../../third_party/miniz ../../third_party/thrift
../../third_party/zstd/include)
include
../../third_party/parquet
../../third_party/thrift
../../third_party/snappy
../../third_party/zstd/include
../../third_party/mbedtls
../../third_party/mbedtls/include)

set(PARQUET_EXTENSION_FILES
column_reader.cpp
column_writer.cpp
parquet_crypto.cpp
parquet_extension.cpp
parquet_metadata.cpp
parquet_reader.cpp
parquet_statistics.cpp
parquet_timestamp.cpp
parquet_writer.cpp
parquet_statistics.cpp
serialize_parquet.cpp
zstd_file_system.cpp
column_reader.cpp)
zstd_file_system.cpp)

if(NOT CLANG_TIDY)
# parquet/thrift/snappy
set(PARQUET_EXTENSION_FILES
${PARQUET_EXTENSION_FILES}
../../third_party/parquet/parquet_constants.cpp
Expand All @@ -28,7 +34,10 @@ if(NOT CLANG_TIDY)
../../third_party/thrift/thrift/transport/TTransportException.cpp
../../third_party/thrift/thrift/transport/TBufferTransports.cpp
../../third_party/snappy/snappy.cc
../../third_party/snappy/snappy-sinksource.cc
../../third_party/snappy/snappy-sinksource.cc)
# zstd
set(PARQUET_EXTENSION_FILES
${PARQUET_EXTENSION_FILES}
../../third_party/zstd/decompress/zstd_ddict.cpp
../../third_party/zstd/decompress/huf_decompress.cpp
../../third_party/zstd/decompress/zstd_decompress.cpp
Expand All @@ -53,7 +62,6 @@ if(NOT CLANG_TIDY)
endif()

build_static_extension(parquet ${PARQUET_EXTENSION_FILES})

set(PARAMETERS "-warnings")
build_loadable_extension(parquet ${PARAMETERS} ${PARQUET_EXTENSION_FILES})

Expand Down
12 changes: 5 additions & 7 deletions extension/parquet/column_reader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ void ColumnReader::PrepareRead(parquet_filter_t &filter) {
bss_decoder.reset();
block.reset();
PageHeader page_hdr;
page_hdr.read(protocol);
reader.Read(page_hdr, *protocol);

switch (page_hdr.type) {
case PageType::DATA_PAGE_V2:
Expand Down Expand Up @@ -287,7 +287,7 @@ void ColumnReader::PreparePageV2(PageHeader &page_hdr) {
uncompressed = true;
}
if (uncompressed) {
trans.read(block->ptr, page_hdr.compressed_page_size);
reader.ReadData(*protocol, block->ptr, page_hdr.compressed_page_size);
return;
}

Expand All @@ -299,7 +299,7 @@ void ColumnReader::PreparePageV2(PageHeader &page_hdr) {
auto compressed_bytes = page_hdr.compressed_page_size - uncompressed_bytes;

AllocateCompressed(compressed_bytes);
trans.read(compressed_buffer.ptr, compressed_bytes);
reader.ReadData(*protocol, compressed_buffer.ptr, compressed_bytes);

DecompressInternal(chunk->meta_data.codec, compressed_buffer.ptr, compressed_bytes, block->ptr + uncompressed_bytes,
page_hdr.uncompressed_page_size - uncompressed_bytes);
Expand All @@ -318,19 +318,17 @@ void ColumnReader::AllocateCompressed(idx_t size) {
}

void ColumnReader::PreparePage(PageHeader &page_hdr) {
auto &trans = reinterpret_cast<ThriftFileTransport &>(*protocol->getTransport());

AllocateBlock(page_hdr.uncompressed_page_size + 1);
if (chunk->meta_data.codec == CompressionCodec::UNCOMPRESSED) {
if (page_hdr.compressed_page_size != page_hdr.uncompressed_page_size) {
throw std::runtime_error("Page size mismatch");
}
trans.read((uint8_t *)block->ptr, page_hdr.compressed_page_size);
reader.ReadData(*protocol, block->ptr, page_hdr.compressed_page_size);
return;
}

AllocateCompressed(page_hdr.compressed_page_size + 1);
trans.read((uint8_t *)compressed_buffer.ptr, page_hdr.compressed_page_size);
reader.ReadData(*protocol, compressed_buffer.ptr, page_hdr.compressed_page_size);

DecompressInternal(chunk->meta_data.codec, compressed_buffer.ptr, page_hdr.compressed_page_size, block->ptr,
page_hdr.uncompressed_page_size);
Expand Down
8 changes: 4 additions & 4 deletions extension/parquet/column_writer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
#include "duckdb/common/mutex.hpp"
#include "duckdb/common/operator/comparison_operators.hpp"
#include "duckdb/common/serializer/buffered_file_writer.hpp"
#include "duckdb/common/serializer/memory_stream.hpp"
#include "duckdb/common/serializer/write_stream.hpp"
#include "duckdb/common/string_map_set.hpp"
#include "duckdb/common/types/chunk_collection.hpp"
#include "duckdb/common/types/date.hpp"
#include "duckdb/common/types/hugeint.hpp"
#include "duckdb/common/types/string_heap.hpp"
#include "duckdb/common/types/time.hpp"
#include "duckdb/common/types/timestamp.hpp"
#include "duckdb/common/serializer/write_stream.hpp"
#include "duckdb/common/serializer/memory_stream.hpp"
#endif

#include "miniz_wrapper.hpp"
Expand Down Expand Up @@ -678,11 +678,11 @@ void BasicColumnWriter::FinalizeWrite(ColumnWriterState &state_p) {
for (auto &write_info : state.write_info) {
D_ASSERT(write_info.page_header.uncompressed_page_size > 0);
auto header_start_offset = column_writer.GetTotalWritten();
write_info.page_header.write(writer.GetProtocol());
writer.Write(write_info.page_header);
// total uncompressed size in the column chunk includes the header size (!)
total_uncompressed_size += column_writer.GetTotalWritten() - header_start_offset;
total_uncompressed_size += write_info.page_header.uncompressed_page_size;
column_writer.WriteData(write_info.compressed_data, write_info.compressed_size);
writer.WriteData(write_info.compressed_data, write_info.compressed_size);
}
column_chunk.meta_data.total_compressed_size = column_writer.GetTotalWritten() - start_offset;
column_chunk.meta_data.total_uncompressed_size = total_uncompressed_size;
Expand Down
26 changes: 26 additions & 0 deletions extension/parquet/include/parquet.json
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,26 @@
],
"pointer_type": "none"
},
{
"class": "ParquetEncryptionConfig",
"includes": [
"parquet_crypto.hpp"
],
"members": [
{
"id": 100,
"name": "footer_key",
"type": "string"
},
{
"id": 101,
"name": "column_keys",
"type": "unordered_map<string, string>"
}
],
"pointer_type": "shared_ptr",
"constructor": ["$ClientContext"]
},
{
"class": "ParquetOptions",
"includes": [
Expand All @@ -53,6 +73,12 @@
"id": 103,
"name": "schema",
"type": "vector<ParquetColumnDefinition>"
},
{
"id": 104,
"name": "encryption_config",
"type": "shared_ptr<ParquetEncryptionConfig>",
"default": "nullptr"
}
],
"pointer_type": "none"
Expand Down
87 changes: 87 additions & 0 deletions extension/parquet/include/parquet_crypto.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
//===----------------------------------------------------------------------===//
// DuckDB
//
// parquet_crypto.hpp
//
//
//===----------------------------------------------------------------------===//

#pragma once

#include "parquet_types.h"

#ifndef DUCKDB_AMALGAMATION
#include "duckdb/storage/object_cache.hpp"
#endif

namespace duckdb {

using duckdb_apache::thrift::TBase;
using duckdb_apache::thrift::protocol::TProtocol;

class BufferedFileWriter;

//! Cache entry holding named encryption keys registered by the user.
//! Stored in the ClientContext's ObjectCache so keys added via the
//! PARQUET_KEY pragma are reachable by later reads/writes in the session.
class ParquetKeys : public ObjectCacheEntry {
public:
//! Fetch (or lazily create) the key cache for this client context
static ParquetKeys &Get(ClientContext &context);

public:
//! Register a key under the given name (overwrites an existing entry)
void AddKey(const string &key_name, const string &key);
//! Whether a key with this name has been registered
bool HasKey(const string &key_name) const;
//! Look up a registered key by name
const string &GetKey(const string &key_name) const;

public:
//! ObjectCache type tag used to locate this entry in the cache
static string ObjectType();
string GetObjectType() override;

private:
//! Mapping from key name to raw key material
unordered_map<string, string> keys;
};

//! Per-read/write encryption configuration: which named key protects the
//! footer and, optionally, per-column keys. Key names are resolved against
//! the session's ParquetKeys cache via the stored ClientContext.
class ParquetEncryptionConfig {
public:
explicit ParquetEncryptionConfig(ClientContext &context);
//! Construct from a user-supplied Value (e.g. the encryption_config
//! parameter of a Parquet read/copy) — parses footer/column key names
ParquetEncryptionConfig(ClientContext &context, const Value &arg);

public:
//! Factory wrapper around the Value-based constructor
static shared_ptr<ParquetEncryptionConfig> Create(ClientContext &context, const Value &arg);
//! Name of the key that encrypts/decrypts the file footer
const string &GetFooterKey() const;

public:
//! (De)serialization hooks so the config survives plan serialization
void Serialize(Serializer &serializer) const;
static shared_ptr<ParquetEncryptionConfig> Deserialize(Deserializer &deserializer);

private:
//! Context used to resolve key names to key material at use time
ClientContext &context;
//! Name of the key used for the footer
string footer_key;
//! Mapping from column name to key name
unordered_map<string, string> column_keys;
};

//! Static helpers implementing encrypted Thrift/buffer I/O for Parquet
//! modular encryption. All methods return the number of bytes consumed
//! or produced on the wire.
class ParquetCrypto {
public:
//! Encrypted modules
//! Size of the little-endian length prefix preceding an encrypted module
static constexpr uint32_t LENGTH_BYTES = 4;
//! Size of the per-module nonce — NOTE(review): 12 bytes matches the
//! AES-GCM nonce length in the Parquet modular encryption spec; confirm
//! against parquet_crypto.cpp
static constexpr uint32_t NONCE_BYTES = 12;
//! Size of the authentication tag appended to each encrypted module
static constexpr uint32_t TAG_BYTES = 16;

//! Block size we encrypt/decrypt
static constexpr uint32_t CRYPTO_BLOCK_SIZE = 4096;

public:
//! Decrypt and read a Thrift object from the transport protocol
static uint32_t Read(TBase &object, TProtocol &iprot, const string &key);
//! Encrypt and write a Thrift object to the transport protocol
static uint32_t Write(const TBase &object, TProtocol &oprot, const string &key);
//! Decrypt and read a buffer
static uint32_t ReadData(TProtocol &iprot, const data_ptr_t buffer, const uint32_t buffer_size, const string &key);
//! Encrypt and write a buffer to a file
static uint32_t WriteData(TProtocol &oprot, const const_data_ptr_t buffer, const uint32_t buffer_size,
const string &key);

public:
//! Pragma handler: registers a named key (ParquetKeys::AddKey) from
//! function parameters supplied by the user
static void AddKey(ClientContext &context, const FunctionParameters &parameters);
};

} // namespace duckdb
7 changes: 7 additions & 0 deletions extension/parquet/include/parquet_reader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class Allocator;
class ClientContext;
class BaseStatistics;
class TableFilterSet;
class ParquetEncryptionConfig;

struct ParquetReaderPrefetchConfig {
// Percentage of data in a row group span that should be scanned for enabling whole group prefetch
Expand Down Expand Up @@ -86,6 +87,8 @@ struct ParquetOptions {

bool binary_as_string = false;
bool file_row_number = false;
shared_ptr<ParquetEncryptionConfig> encryption_config;

MultiFileReaderOptions file_options;
vector<ParquetColumnDefinition> schema;

Expand Down Expand Up @@ -125,6 +128,10 @@ class ParquetReader {

const duckdb_parquet::format::FileMetaData *GetFileMetadata();

uint32_t Read(duckdb_apache::thrift::TBase &object, TProtocol &iprot);
uint32_t ReadData(duckdb_apache::thrift::protocol::TProtocol &iprot, const data_ptr_t buffer,
const uint32_t buffer_size);

unique_ptr<BaseStatistics> ReadStatistics(const string &name);
static LogicalType DeriveLogicalType(const SchemaElement &s_ele, bool binary_as_string);

Expand Down
8 changes: 7 additions & 1 deletion extension/parquet/include/parquet_writer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
namespace duckdb {
class FileSystem;
class FileOpener;
class ParquetEncryptionConfig;

class Serializer;
class Deserializer;
Expand Down Expand Up @@ -62,7 +63,8 @@ class ParquetWriter {
public:
ParquetWriter(FileSystem &fs, string file_name, vector<LogicalType> types, vector<string> names,
duckdb_parquet::format::CompressionCodec::type codec, ChildFieldIDs field_ids,
const vector<pair<string, string>> &kv_metadata);
const vector<pair<string, string>> &kv_metadata,
shared_ptr<ParquetEncryptionConfig> encryption_config);

public:
void PrepareRowGroup(ColumnDataCollection &buffer, PreparedRowGroup &result);
Expand All @@ -88,6 +90,9 @@ class ParquetWriter {

static CopyTypeSupport TypeIsSupported(const LogicalType &type);

uint32_t Write(const duckdb_apache::thrift::TBase &object);
uint32_t WriteData(const const_data_ptr_t buffer, const uint32_t buffer_size);

private:
static CopyTypeSupport DuckDBTypeToParquetTypeInternal(const LogicalType &duckdb_type,
duckdb_parquet::format::Type::type &type);
Expand All @@ -96,6 +101,7 @@ class ParquetWriter {
vector<string> column_names;
duckdb_parquet::format::CompressionCodec::type codec;
ChildFieldIDs field_ids;
shared_ptr<ParquetEncryptionConfig> encryption_config;

unique_ptr<BufferedFileWriter> writer;
shared_ptr<duckdb_apache::thrift::protocol::TProtocol> protocol;
Expand Down
37 changes: 18 additions & 19 deletions extension/parquet/parquet_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,34 @@
for x in [
'extension/parquet/include',
'third_party/parquet',
'third_party/snappy',
'third_party/thrift',
'third_party/snappy',
'third_party/zstd/include',
'third_party/mbedtls',
'third_party/mbedtls/include',
]
]
# source files
source_files = [
os.path.sep.join(x.split('/'))
for x in [
'extension/parquet/parquet_extension.cpp',
'extension/parquet/column_reader.cpp',
'extension/parquet/column_writer.cpp',
'extension/parquet/parquet_crypto.cpp',
'extension/parquet/parquet_extension.cpp',
'extension/parquet/parquet_metadata.cpp',
'extension/parquet/parquet_reader.cpp',
'extension/parquet/parquet_statistics.cpp',
'extension/parquet/parquet_timestamp.cpp',
'extension/parquet/parquet_writer.cpp',
'extension/parquet/serialize_parquet.cpp',
'extension/parquet/zstd_file_system.cpp',
]
]
# parquet/thrift/snappy
source_files += [
os.path.sep.join(x.split('/'))
for x in [
'third_party/parquet/parquet_constants.cpp',
'third_party/parquet/parquet_types.cpp',
'third_party/thrift/thrift/protocol/TProtocol.cpp',
Expand All @@ -40,11 +56,6 @@
'third_party/zstd/common/zstd_common.cpp',
'third_party/zstd/common/error_private.cpp',
'third_party/zstd/common/xxhash.cpp',
]
]
source_files += [
os.path.sep.join(x.split('/'))
for x in [
'third_party/zstd/compress/fse_compress.cpp',
'third_party/zstd/compress/hist.cpp',
'third_party/zstd/compress/huf_compress.cpp',
Expand All @@ -59,15 +70,3 @@
'third_party/zstd/compress/zstd_opt.cpp',
]
]
source_files += [
os.path.sep.join(x.split('/'))
for x in [
'extension/parquet/parquet_reader.cpp',
'extension/parquet/parquet_timestamp.cpp',
'extension/parquet/parquet_writer.cpp',
'extension/parquet/column_reader.cpp',
'extension/parquet/parquet_statistics.cpp',
'extension/parquet/parquet_metadata.cpp',
'extension/parquet/zstd_file_system.cpp',
]
]
Loading

0 comments on commit 4915dd7

Please sign in to comment.