Skip to content

Commit

Permalink
[C++] Support iterable types for RowEncoder (#1215)
Browse files Browse the repository at this point in the history
  • Loading branch information
PragmaTwice authored Dec 7, 2023
1 parent 71121ca commit fa7c7a1
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/fury/encoder/row_encode_trait.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ template <typename C> struct DefaultWriteVisitor {

DefaultWriteVisitor(C &cont) : cont(cont) {}

template <typename> void Visit(std::unique_ptr<RowWriter> writer) {
template <typename, typename T> void Visit(std::unique_ptr<T> writer) {
cont.push_back(std::move(writer));
}
};
Expand Down
70 changes: 63 additions & 7 deletions src/fury/encoder/row_encoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,37 +17,93 @@
#pragma once

#include "fury/encoder/row_encode_trait.h"
#include "src/fury/row/writer.h"
#include <memory>
#include <type_traits>

namespace fury {

namespace encoder {

namespace details {

template <typename T, typename Enabled = void> struct GetWriterTypeImpl;

template <typename T>
struct GetWriterTypeImpl<T,
std::enable_if_t<details::IsClassButNotBuiltin<T>>> {
using type = RowWriter;
};

template <typename T>
struct GetWriterTypeImpl<T, std::enable_if_t<details::IsArray<T>>> {
using type = ArrayWriter;
};

template <typename T> using GetWriterType = typename GetWriterTypeImpl<T>::type;

template <typename T, std::enable_if_t<
std::is_same_v<GetWriterType<T>, RowWriter>, int> = 0>
auto GetSchemaOrType() {
return RowEncodeTrait<T>::Schema();
}

template <
typename T,
std::enable_if_t<std::is_same_v<GetWriterType<T>, ArrayWriter>, int> = 0>
auto GetSchemaOrType() {
return std::dynamic_pointer_cast<arrow::ListType>(RowEncodeTrait<T>::Type());
}

} // namespace details

template <typename T> struct RowEncoder {
static_assert(std::is_class_v<T>, "currently only class types are supported");
static_assert(details::IsClassButNotBuiltin<T> || details::IsArray<T>,
"only class types and iterable types are supported");

using WriterType = details::GetWriterType<T>;

RowEncoder()
: writer_(std::make_unique<RowWriter>(RowEncodeTrait<T>::Schema())) {
: writer_(std::make_unique<WriterType>(details::GetSchemaOrType<T>())) {}

template <typename U = WriterType,
std::enable_if_t<std::is_same_v<U, RowWriter>, int> = 0>
void Encode(const T &value) {
writer_->Reset();
RowEncodeTrait<T>::Write(DefaultWriteVisitor{children_}, value,
GetWriter());
}

template <typename U = WriterType,
std::enable_if_t<std::is_same_v<U, ArrayWriter>, int> = 0>
void Encode(const T &value) {
writer_->Reset(value.size());
RowEncodeTrait<T>::Write(DefaultWriteVisitor{children_}, value,
GetWriter());
}

RowWriter &GetWriter() const { return *writer_.get(); }
const std::vector<std::unique_ptr<RowWriter>> &GetChildren() const {
WriterType &GetWriter() const { return *writer_.get(); }
const std::vector<std::unique_ptr<Writer>> &GetChildren() const {
return children_;
}
const arrow::Schema &GetSchema() const { return *writer_->schema().get(); }

template <typename U = WriterType,
std::enable_if_t<std::is_same_v<U, RowWriter>, int> = 0>
const arrow::Schema &GetSchema() const {
return *writer_->schema().get();
}

template <typename U = WriterType,
std::enable_if_t<std::is_same_v<U, ArrayWriter>, int> = 0>
const arrow::ListType &GetType() const {
return *writer_->type().get();
}

void ResetChildren() { children_.clear(); }

private:
std::unique_ptr<RowWriter> writer_;
std::vector<std::unique_ptr<RowWriter>> children_;
std::unique_ptr<WriterType> writer_;
std::vector<std::unique_ptr<Writer>> children_;
};

} // namespace encoder
Expand Down
63 changes: 63 additions & 0 deletions src/fury/encoder/row_encoder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,68 @@ TEST(RowEncoder, Simple) {
ASSERT_FLOAT_EQ(y_row->GetFloat(0), 1.23);
}

// Test-only aggregate: a list of A values plus a bool. It gives the array
// test below a struct that itself contains an iterable field.
struct C {
std::vector<A> x;
bool y;
};

// Passes C and its fields x and y to FURY_FIELD_INFO.
// NOTE(review): the macro is defined elsewhere; presumably it is what makes
// C's fields visible to the encoder — confirm against its definition.
FURY_FIELD_INFO(C, x, y);

// Encodes a std::vector<C> through RowEncoder and checks both the reported
// Arrow list type and the values read back from the encoded data.
TEST(RowEncoder, SimpleArray) {
// Element 0 holds two A entries and y == false; element 1 holds three A
// entries and y == true.
std::vector<C> v{C{{{1, "a"}, {2, "b"}}, false},
C{{{1.1, "x"}, {2.2, "y"}, {3.3, "z"}}, true}};

encoder::RowEncoder<decltype(v)> enc;

// Expected type shape, as asserted below:
//   list<struct{x: list<struct{float, utf8}>, y: bool}>
auto &type = enc.GetType();
ASSERT_EQ(type.name(), "list");
ASSERT_EQ(type.field(0)->type()->name(), "struct");
ASSERT_EQ(type.field(0)->type()->field(0)->name(), "x");
ASSERT_EQ(type.field(0)->type()->field(1)->name(), "y");
ASSERT_EQ(type.field(0)->type()->field(0)->type()->name(), "list");
ASSERT_EQ(type.field(0)->type()->field(0)->type()->field(0)->type()->name(),
"struct");
// First field of the innermost struct (the elements of C::x).
ASSERT_EQ(type.field(0)
->type()
->field(0)
->type()
->field(0)
->type()
->field(0)
->type()
->name(),
"float");
// Second field of the innermost struct.
ASSERT_EQ(type.field(0)
->type()
->field(0)
->type()
->field(0)
->type()
->field(1)
->type()
->name(),
"utf8");
ASSERT_EQ(type.field(0)->type()->field(1)->type()->name(), "bool");

enc.Encode(v);

// Read the encoded array back: GetStruct(i) selects element i of v,
// GetArray(0) its x field, and the inner GetStruct(j) the j-th A entry.
auto data = enc.GetWriter().CopyToArrayData();
// Float field of each A entry. The values 1 and 2 are compared with
// ASSERT_EQ; the fractional values use ASSERT_FLOAT_EQ.
ASSERT_EQ(data->GetStruct(0)->GetArray(0)->GetStruct(0)->GetFloat(0), 1);
ASSERT_EQ(data->GetStruct(0)->GetArray(0)->GetStruct(1)->GetFloat(0), 2);
ASSERT_FLOAT_EQ(data->GetStruct(1)->GetArray(0)->GetStruct(0)->GetFloat(0),
1.1);
ASSERT_FLOAT_EQ(data->GetStruct(1)->GetArray(0)->GetStruct(1)->GetFloat(0),
2.2);
ASSERT_FLOAT_EQ(data->GetStruct(1)->GetArray(0)->GetStruct(2)->GetFloat(0),
3.3);
// String field of each A entry.
ASSERT_EQ(data->GetStruct(0)->GetArray(0)->GetStruct(0)->GetString(1), "a");
ASSERT_EQ(data->GetStruct(0)->GetArray(0)->GetStruct(1)->GetString(1), "b");
ASSERT_EQ(data->GetStruct(1)->GetArray(0)->GetStruct(0)->GetString(1), "x");
ASSERT_EQ(data->GetStruct(1)->GetArray(0)->GetStruct(1)->GetString(1), "y");
ASSERT_EQ(data->GetStruct(1)->GetArray(0)->GetStruct(2)->GetString(1), "z");
// Bool field y of each C element.
ASSERT_EQ(data->GetStruct(0)->GetBoolean(1), false);
ASSERT_EQ(data->GetStruct(1)->GetBoolean(1), true);
}

} // namespace test2
} // namespace fury
4 changes: 2 additions & 2 deletions src/fury/row/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,13 @@ class Writer {
}
}

virtual ~Writer() = default;

protected:
explicit Writer(int bytes_before_bitmap);

explicit Writer(Writer *parent_writer, int bytes_before_bitmap);

virtual ~Writer() = default;

std::shared_ptr<Buffer> buffer_;

// The offset of the global buffer where we start to WriteString this
Expand Down

0 comments on commit fa7c7a1

Please sign in to comment.