Skip to content

Commit

Permalink
[C++] Support iterable types in RowEncodeTrait (#1212)
Browse files Browse the repository at this point in the history
  • Loading branch information
PragmaTwice authored Dec 6, 2023
1 parent c6da8ec commit 71121ca
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 12 deletions.
75 changes: 66 additions & 9 deletions src/fury/encoder/row_encode_trait.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include "fury/meta/type_traits.h"
#include "fury/row/writer.h"
#include <string_view>
#include <type_traits>
#include <utility>

namespace fury {
Expand Down Expand Up @@ -67,7 +68,19 @@ inline constexpr bool IsString =
meta::IsOneOf<T, std::string, std::string_view>::value;

template <typename T>
inline constexpr bool IsClassButNotBuiltin = std::is_class_v<T> && !IsString<T>;
inline constexpr bool IsArray = meta::IsIterable<T> && !IsString<T>;

template <typename T>
inline constexpr bool IsClassButNotBuiltin =
std::is_class_v<T> && !(IsString<T> || IsArray<T>);

// Looks up the Arrow type of the `index`-th field in the schema of a row
// writer. The result has whatever value category `type()` returns.
inline decltype(auto) GetChildType(RowWriter &writer, int index) {
  auto &&schema = writer.schema();
  return schema->field(index)->type();
}

// Looks up the Arrow type of the elements written by an array writer.
//
// The second parameter exists only so that this overload has the same shape
// as the RowWriter one and can be called generically; every slot of an array
// shares the list's single item type (field 0), so the index is ignored. It is
// left unnamed to avoid unused-parameter warnings.
//
// NOTE(review): if field()/type() return references, the result refers into
// the writer's list type and must not outlive the writer - confirm.
inline decltype(auto) GetChildType(ArrayWriter &writer, int /*index*/) {
  return writer.type()->field(0)->type();
}

} // namespace details

Expand Down Expand Up @@ -106,8 +119,10 @@ struct RowEncodeTrait<
return details::ArrowSchemaBasicType<std::remove_cv_t<T>>::value();
}

template <typename V>
static void Write(V &&, const T &value, RowWriter &writer, int index) {
template <typename V, typename W,
std::enable_if_t<meta::IsOneOf<W, RowWriter, ArrayWriter>::value,
int> = 0>
static void Write(V &&, const T &value, W &writer, int index) {
writer.Write(index, value);
}
};
Expand All @@ -117,8 +132,10 @@ struct RowEncodeTrait<
T, std::enable_if_t<details::IsString<std::remove_cv_t<T>>>> {
static auto Type() { return arrow::utf8(); }

template <typename V>
static void Write(V &&, const T &value, RowWriter &writer, int index) {
template <typename V, typename W,
std::enable_if_t<meta::IsOneOf<W, RowWriter, ArrayWriter>::value,
int> = 0>
static void Write(V &&, const T &value, W &writer, int index) {
writer.WriteString(index, value);
}
};
Expand Down Expand Up @@ -165,13 +182,14 @@ struct RowEncodeTrait<
std::make_index_sequence<FieldInfo::Size>());
}

template <typename V>
static void Write(V &&visitor, const T &value, RowWriter &writer, int index) {
template <typename V, typename W,
std::enable_if_t<meta::IsOneOf<W, RowWriter, ArrayWriter>::value,
int> = 0>
static void Write(V &&visitor, const T &value, W &writer, int index) {
auto offset = writer.cursor();

auto inner_writer = std::make_unique<RowWriter>(
arrow::schema(writer.schema()->field(index)->type()->fields()),
&writer);
arrow::schema(details::GetChildType(writer, index)->fields()), &writer);

inner_writer->Reset();
RowEncodeTrait<T>::Write(std::forward<V>(visitor), value,
Expand All @@ -184,6 +202,45 @@ struct RowEncodeTrait<
}
};

/// RowEncodeTrait specialization for iterable, non-string types (those for
/// which details::IsArray holds). A value is encoded as an Arrow list whose
/// item type is provided by the element type's own RowEncodeTrait.
template <typename T>
struct RowEncodeTrait<T,
                      std::enable_if_t<details::IsArray<std::remove_cv_t<T>>>> {
  /// Returns arrow::list(...) wrapping the Arrow type of the element type
  /// (meta::GetValueType<T>).
  static auto Type() {
    return arrow::list(RowEncodeTrait<meta::GetValueType<T>>::Type());
  }

  /// Writes the elements of `value` into `writer` in iteration order, using
  /// slot indices 0, 1, 2, ...
  ///
  /// @param visitor passed on to the element trait for every element.
  /// @param value   the iterable to encode.
  /// @param writer  destination array writer; callers in this file Reset() it
  ///                with the element count before calling.
  template <typename V>
  static void Write(V &&visitor, const T &value, ArrayWriter &writer) {
    int index = 0;
    for (const auto &element : value) {
      // The visitor is intentionally passed as an lvalue here: forwarding it
      // on every iteration would allow an rvalue visitor to be moved from
      // repeatedly.
      RowEncodeTrait<meta::GetValueType<T>>::Write(visitor, element, writer,
                                                   index);
      ++index;
    }
  }

  /// Writes `value` as a nested array into slot `index` of a parent writer
  /// (either a RowWriter or an ArrayWriter).
  ///
  /// @param visitor receives ownership of the nested ArrayWriter through
  ///                Visit<std::remove_cv_t<T>>() once the elements are written.
  /// @param value   the iterable to encode.
  /// @param writer  parent writer that the nested array is appended to.
  /// @param index   slot in the parent writer that refers to the nested array.
  template <typename V, typename W,
            std::enable_if_t<meta::IsOneOf<W, RowWriter, ArrayWriter>::value,
                             int> = 0>
  static void Write(V &&visitor, const T &value, W &writer, int index) {
    // Remember where the nested data starts so that its extent can be
    // recorded in the parent slot afterwards.
    auto offset = writer.cursor();

    auto inner_writer = std::make_unique<ArrayWriter>(
        std::dynamic_pointer_cast<arrow::ListType>(
            details::GetChildType(writer, index)),
        &writer);

    // std::size works for containers with a size() member as well as for
    // built-in arrays, both of which satisfy meta::IsIterable.
    inner_writer->Reset(std::size(value));
    RowEncodeTrait<T>::Write(visitor, value, *inner_writer);

    writer.SetOffsetAndSize(index, offset, writer.cursor() - offset);

    // Last use of the visitor: only here is it forwarded with its original
    // value category.
    std::forward<V>(visitor).template Visit<std::remove_cv_t<T>>(
        std::move(inner_writer));
  }
};

} // namespace encoder

} // namespace fury
103 changes: 102 additions & 1 deletion src/fury/encoder/row_encode_trait_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,11 @@
*/

#include "gtest/gtest.h"
#include <memory>
#include <type_traits>

#include "fury/encoder/row_encode_trait.h"
#include "src/fury/row/writer.h"
#include "fury/row/writer.h"

namespace fury {

Expand Down Expand Up @@ -146,6 +147,106 @@ TEST(RowEncodeTrait, NestedStruct) {
ASSERT_EQ(y_schema->field(2)->type()->name(), "bool");
}

// A flat std::vector<int> maps to list<int32> and every element can be read
// back from the encoded array data.
TEST(RowEncodeTrait, SimpleArray) {
  using Values = std::vector<int>;
  using Trait = encoder::RowEncodeTrait<Values>;

  Values values{10, 20, 30};

  auto list_type = Trait::Type();

  ASSERT_EQ(list_type->name(), "list");
  ASSERT_EQ(list_type->field(0)->type()->name(), "int32");

  ArrayWriter writer(std::dynamic_pointer_cast<arrow::ListType>(list_type));
  writer.Reset(values.size());

  Trait::Write(encoder::EmptyWriteVisitor{}, values, writer);

  auto encoded = writer.CopyToArrayData();
  const int expected[] = {10, 20, 30};
  for (int i = 0; i < 3; ++i) {
    ASSERT_EQ(encoded->GetInt32(i), expected[i]);
  }
}

// A vector of structs maps to list<struct>; each element is readable as a
// nested row with the original field values.
TEST(RowEncodeTrait, StructInArray) {
  using Values = std::vector<A>;
  using Trait = encoder::RowEncodeTrait<Values>;

  Values values{{233, 1.1, false}, {234, 3.14, true}};

  auto list_type = Trait::Type();

  ASSERT_EQ(list_type->name(), "list");
  ASSERT_EQ(list_type->field(0)->type()->name(), "struct");

  ArrayWriter writer(std::dynamic_pointer_cast<arrow::ListType>(list_type));
  writer.Reset(values.size());

  Trait::Write(encoder::EmptyWriteVisitor{}, values, writer);

  auto encoded = writer.CopyToArrayData();

  // First element.
  auto first = encoded->GetStruct(0);
  ASSERT_EQ(first->GetInt32(0), 233);
  ASSERT_FLOAT_EQ(first->GetFloat(1), 1.1);
  ASSERT_EQ(first->GetBoolean(2), false);

  // Second element.
  auto second = encoded->GetStruct(1);
  ASSERT_EQ(second->GetInt32(0), 234);
  ASSERT_FLOAT_EQ(second->GetFloat(1), 3.14);
  ASSERT_EQ(second->GetBoolean(2), true);
}

// Test fixture type: a struct with one scalar member and one iterable member,
// used by the ArrayInStruct test.
struct E {
int a;
std::vector<int> b;
};

// Registers the fields `a` and `b` of E via the FURY_FIELD_INFO macro.
// NOTE(review): presumably this is what lets RowEncodeTrait<E> derive the
// struct schema - confirm against fury/meta/field_info.h.
FURY_FIELD_INFO(E, a, b);

// A struct with a vector member maps to struct<int32, list>; the nested list
// is readable through the encoded row.
TEST(RowEncodeTrait, ArrayInStruct) {
  using Trait = encoder::RowEncodeTrait<E>;

  E input{233, {10, 20, 30}};

  auto struct_type = Trait::Type();

  ASSERT_EQ(struct_type->name(), "struct");
  ASSERT_EQ(struct_type->field(0)->type()->name(), "int32");
  ASSERT_EQ(struct_type->field(1)->type()->name(), "list");

  RowWriter writer(Trait::Schema());
  writer.Reset();

  Trait::Write(encoder::EmptyWriteVisitor{}, input, writer);

  auto row = writer.ToRow();
  ASSERT_EQ(row->GetInt32(0), 233);

  const int expected[] = {10, 20, 30};
  for (int i = 0; i < 3; ++i) {
    ASSERT_EQ(row->GetArray(1)->GetInt32(i), expected[i]);
  }
}

// A vector of vectors maps to list<list>; every inner element is readable
// through the nested array accessors.
TEST(RowEncodeTrait, ArrayInArray) {
  using Values = std::vector<std::vector<int>>;
  using Trait = encoder::RowEncodeTrait<Values>;

  Values values{{10}, {20, 30}, {40, 50, 60}};

  auto list_type = Trait::Type();

  ASSERT_EQ(list_type->name(), "list");
  ASSERT_EQ(list_type->field(0)->type()->name(), "list");

  ArrayWriter writer(std::dynamic_pointer_cast<arrow::ListType>(list_type));
  writer.Reset(values.size());

  Trait::Write(encoder::EmptyWriteVisitor{}, values, writer);

  auto encoded = writer.CopyToArrayData();
  const int outer_count = static_cast<int>(values.size());
  for (int outer = 0; outer < outer_count; ++outer) {
    const int inner_count = static_cast<int>(values[outer].size());
    for (int inner = 0; inner < inner_count; ++inner) {
      ASSERT_EQ(encoded->GetArray(outer)->GetInt32(inner),
                values[outer][inner]);
    }
  }
}

} // namespace test

} // namespace fury
Expand Down
4 changes: 2 additions & 2 deletions src/fury/encoder/row_encoder_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
#include <type_traits>

#include "fury/encoder/row_encode_trait.h"
#include "src/fury/encoder/row_encoder.h"
#include "src/fury/row/writer.h"
#include "fury/encoder/row_encoder.h"
#include "fury/row/writer.h"

namespace fury {

Expand Down
26 changes: 26 additions & 0 deletions src/fury/meta/type_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#pragma once

#include <iterator>
#include <type_traits>

namespace fury {
Expand Down Expand Up @@ -71,6 +72,31 @@ template <typename T, typename... Args>
using EnableIfIsOneOf =
typename std::enable_if<IsOneOf<T, Args...>::value, T>::type;

namespace details {
// Make std::begin / std::end visible to the unqualified calls below, so that
// lookup considers the std overloads as well as any found through ADL.
using std::begin;
using std::end;

// Preferred overload (exact match on `int`). It only participates when, for an
// lvalue of type T, the result of begin() can be dereferenced and
// pre-incremented, and can be compared against end() with `!=`.
template <typename T,
          typename = std::void_t<
              decltype(*begin(std::declval<T &>())),
              decltype(++std::declval<decltype(begin(std::declval<T &>())) &>()),
              decltype(begin(std::declval<T &>()) != end(std::declval<T &>()))>>
std::true_type IsIterableImpl(int);

// Fallback overload chosen for every other T.
template <typename T> std::false_type IsIterableImpl(...);

// Yields the type obtained by dereferencing begin() of a T lvalue, with the
// reference removed (cv-qualifiers are kept).
template <typename T> struct GetValueTypeImpl {
  using type = std::remove_reference_t<decltype(*begin(std::declval<T &>()))>;
};
} // namespace details

// True when T can be traversed with begin()/end(), as detected by
// details::IsIterableImpl (which returns std::true_type in that case and
// std::false_type otherwise).
template <typename T>
inline constexpr bool IsIterable =
    std::is_same_v<decltype(details::IsIterableImpl<T>(0)), std::true_type>;

// Element type of an iterable T: the type produced by dereferencing its
// begin() iterator, with the reference stripped (see
// details::GetValueTypeImpl).
template <typename T>
using GetValueType = typename details::GetValueTypeImpl<T>::type;

} // namespace meta

} // namespace fury
17 changes: 17 additions & 0 deletions src/fury/meta/type_traits_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
*/

#include "gtest/gtest.h"
#include <deque>
#include <initializer_list>
#include <list>

#include "fury/meta/field_info.h"
#include "src/fury/meta/type_traits.h"
Expand Down Expand Up @@ -60,6 +63,20 @@ TEST(Meta, IsUnique) {
static_assert(!IsUnique<1, false, true, &A::x, 1>::value);
}

// Compile-time checks for meta::IsIterable. Both directions are covered so
// that a trait which is unconditionally true would not pass.
TEST(Meta, IsIterable) {
  // Standard containers, including nested ones.
  static_assert(IsIterable<std::vector<int>>);
  static_assert(IsIterable<std::vector<std::vector<int>>>);
  static_assert(IsIterable<std::deque<float>>);
  static_assert(IsIterable<std::list<int>>);
  static_assert(IsIterable<std::set<int>>);
  static_assert(IsIterable<std::map<int, std::vector<unsigned>>>);

  // Built-in arrays and initializer lists.
  static_assert(IsIterable<struct A[10]>);
  static_assert(IsIterable<float[2][2]>);
  static_assert(IsIterable<std::initializer_list<A>>);

  // Strings are iterable too; excluding them is the encoder's job.
  static_assert(IsIterable<std::string>);
  static_assert(IsIterable<std::string_view>);

  // Types without begin()/end() must be rejected.
  static_assert(!IsIterable<int>);
  static_assert(!IsIterable<double>);
  static_assert(!IsIterable<int *>);
}

} // namespace test

} // namespace fury
Expand Down
2 changes: 2 additions & 0 deletions src/fury/row/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,8 @@ class ArrayWriter : public Writer {

// Returns the distance between the current cursor and starting_offset_.
// NOTE(review): presumably the number of bytes this writer has produced since
// it was last reset - confirm against Reset().
int size() { return cursor() - starting_offset_; }

std::shared_ptr<arrow::ListType> type() { return type_; }

private:
std::shared_ptr<arrow::ListType> type_;
int element_size_;
Expand Down

0 comments on commit 71121ca

Please sign in to comment.