Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix various issues related to Decimal and Arrow #942

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion yt/yt/library/decimal/decimal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -891,10 +891,28 @@ TStringBuf TDecimal::WriteBinary256(int precision, TValue256 value, char* buffer
CheckDecimalIntBits<TValue256>(precision);
YT_VERIFY(bufferLength >= resultLength);

DecimalIntegerToBinaryUnchecked(std::move(value), buffer);
DecimalIntegerToBinaryUnchecked(value, buffer);
return TStringBuf{buffer, sizeof(TValue256)};
}

TStringBuf TDecimal::WriteBinary256Variadic(int precision, TValue256 value, char* buffer, size_t bufferLength)
{
const size_t resultLength = GetValueBinarySize(precision);
switch (resultLength) {
case 4:
return WriteBinary32(precision, *reinterpret_cast<i32*>(value.Parts.data()), buffer, bufferLength);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

reinterpret_cast

😎

case 8:
return WriteBinary64(precision, *reinterpret_cast<i64*>(value.Parts.data()), buffer, bufferLength);
case 16:
return WriteBinary128(precision, *reinterpret_cast<TValue128*>(value.Parts.data()), buffer, bufferLength);
case 32:
return WriteBinary256(precision, value, buffer, bufferLength);
default:
THROW_ERROR_EXCEPTION("Invalid precision %v", precision);
}
}


template <typename T>
Y_FORCE_INLINE void CheckBufferLength(int precision, size_t bufferLength)
{
Expand Down
3 changes: 3 additions & 0 deletions yt/yt/library/decimal/decimal.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class TDecimal
};
static_assert(sizeof(TValue128) == 2 * sizeof(ui64));

//! Lower-endian representation of 256-bit decimal value.
struct TValue256
{
std::array<ui32, 8> Parts;
Expand Down Expand Up @@ -64,6 +65,8 @@ class TDecimal

// Writes either 32-bit, 64-bit or 128-bit binary value depending on precision, provided a TValue128.
static TStringBuf WriteBinary128Variadic(int precision, TValue128 value, char* buffer, size_t bufferLength);
// Writes either 32-bit, 64-bit, 128-bit or 256-bit binary value depending on precision, provided a TValue256.
static TStringBuf WriteBinary256Variadic(int precision, TValue256 value, char* buffer, size_t bufferLength);

static i32 ParseBinary32(int precision, TStringBuf buffer);
static i64 ParseBinary64(int precision, TStringBuf buffer);
Expand Down
77 changes: 57 additions & 20 deletions yt/yt/library/formats/arrow_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,28 @@ void ThrowOnError(const arrow::Status& status)
}
}

template <class TUnderlyingValueType>
TStringBuf SerializeDecimalBinary(const TStringBuf& value, int precision, char* buffer, size_t bufferLength)
{
// NB: Arrow wire representation of Decimal128 is little-endian and (obviously) 128 bit,
// while YT in-memory representation of Decimal is big-endian, variadic-length of either 32 bit, 64 bit or 128 bit,
// and MSB-flipped to ensure lexical sorting order.
// Representation of Decimal256 is similar, but the upper limit for a length is 256 bit.
TUnderlyingValueType decimalValue;
YT_VERIFY(value.size() == sizeof(decimalValue));
std::memcpy(&decimalValue, value.data(), value.size());

TStringBuf decimalBinary;
if constexpr (std::is_same_v<TUnderlyingValueType, TDecimal::TValue128>) {
decimalBinary = TDecimal::WriteBinary128Variadic(precision, decimalValue, buffer, bufferLength);
} else if constexpr (std::is_same_v<TUnderlyingValueType, TDecimal::TValue256>) {
decimalBinary = TDecimal::WriteBinary256Variadic(precision, decimalValue, buffer, bufferLength);
} else {
static_assert(std::is_same_v<TUnderlyingValueType, TDecimal::TValue256>, "Unexpected decimal type");
}
return decimalBinary;
}

////////////////////////////////////////////////////////////////////////////////

class TArraySimpleVisitor
Expand Down Expand Up @@ -291,28 +313,12 @@ class TArraySimpleVisitor
}

template <class TUnderlyingValueType>
TUnversionedValue MakeDecimalBinaryValue(const TStringBuf& value, i64 columnId, int precision)
TUnversionedValue MakeDecimalBinaryValue(const TStringBuf& arrowValue, i64 columnId, int precision)
{
// NB: Arrow wire representation of Decimal128 is little-endian and (obviously) 128 bit,
// while YT in-memory representation of Decimal is big-endian, variadic-length of either 32 bit, 64 bit or 128 bit,
// and MSB-flipped to ensure lexical sorting order.
// Representation of Decimal256 is similar, but only 256 bits.
TUnderlyingValueType decimalValue;
YT_VERIFY(value.size() == sizeof(decimalValue));
std::memcpy(&decimalValue, value.data(), value.size());

const auto maxByteCount = sizeof(decimalValue);
const auto maxByteCount = sizeof(TUnderlyingValueType);
char* buffer = BufferForStringLikeValues_->Preallocate(maxByteCount);
TStringBuf decimalBinary;
if constexpr (std::is_same_v<TUnderlyingValueType, TDecimal::TValue128>) {
decimalBinary = TDecimal::WriteBinary128Variadic(precision, decimalValue, buffer, maxByteCount);
} else if constexpr (std::is_same_v<TUnderlyingValueType, TDecimal::TValue256>) {
decimalBinary = TDecimal::WriteBinary256(precision, decimalValue, buffer, maxByteCount);
} else {
static_assert(std::is_same_v<TUnderlyingValueType, TDecimal::TValue256>, "Unexpected decimal type");
}
auto decimalBinary = SerializeDecimalBinary<TUnderlyingValueType>(arrowValue, precision, buffer, maxByteCount);
BufferForStringLikeValues_->Advance(decimalBinary.size());

return MakeUnversionedStringValue(decimalBinary, columnId);
}
};
Expand Down Expand Up @@ -456,6 +462,20 @@ class TArrayCompositeVisitor
return ParseStruct();
}

arrow::Status Visit(const arrow::Decimal128Type& type) override
{
return ParseStringLikeArray<arrow::Decimal128Array>([&] (const TStringBuf& value) {
WriteDecimalBinary<TDecimal::TValue128>(value, type.precision());
});
}

arrow::Status Visit(const arrow::Decimal256Type& type) override
{
return ParseStringLikeArray<arrow::Decimal256Array>([&] (const TStringBuf& value) {
WriteDecimalBinary<TDecimal::TValue256>(value, type.precision());
});
}

private:
const int RowIndex_;

Expand Down Expand Up @@ -505,13 +525,21 @@ class TArrayCompositeVisitor

template <typename ArrayType>
arrow::Status ParseStringLikeArray()
{
return ParseStringLikeArray<ArrayType>([&] (const TStringBuf& value) {
Writer_->WriteBinaryString(value);
});
}

template <typename ArrayType>
arrow::Status ParseStringLikeArray(auto writeStringValue)
{
auto array = std::static_pointer_cast<ArrayType>(Array_);
if (array->IsNull(RowIndex_)) {
Writer_->WriteEntity();
} else {
auto element = array->GetView(RowIndex_);
Writer_->WriteBinaryString(TStringBuf(element.data(), element.size()));
writeStringValue(TStringBuf(element.data(), element.size()));
}
return arrow::Status::OK();
}
Expand Down Expand Up @@ -610,6 +638,15 @@ class TArrayCompositeVisitor
}
return arrow::Status::OK();
}

template <class TUnderlyingType>
void WriteDecimalBinary(TStringBuf arrowValue, int precision)
{
const auto maxByteCount = sizeof(TUnderlyingType);
char buffer[maxByteCount];
auto decimalBinary = SerializeDecimalBinary<TUnderlyingType>(arrowValue, precision, buffer, maxByteCount);
Writer_->WriteBinaryString(decimalBinary);
}
};

////////////////////////////////////////////////////////////////////////////////
Expand Down
67 changes: 58 additions & 9 deletions yt/yt/library/formats/unittests/arrow_parser_ut.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,26 @@ std::string MakeDecimalArrows(std::vector<TString> values, std::vector<std::tupl
return MakeOutputFromRecordBatch(recordBatch);
}

std::string MakeDecimalListArrow(std::vector<TString> values)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TString -> std::string

{
// Create a single column with one value, which is a list containing all the #values.
// Type of the list is Decimal128(10, 3).
auto* pool = arrow::default_memory_pool();
auto decimalBuilder = std::make_shared<arrow::Decimal128Builder>(std::make_shared<arrow::Decimal128Type>(10, 3), pool);
auto listBuilder = std::make_unique<arrow::ListBuilder>(pool, decimalBuilder);

Verify(listBuilder->Append());
for (const auto& value : values) {
Verify(decimalBuilder->Append(arrow::Decimal128(std::string(value))));
}
std::shared_ptr<arrow::Array> listArray;
Verify(listBuilder->Finish(&listArray));
auto arrowSchema = arrow::schema({arrow::field("list", listArray->type())});
std::vector<std::shared_ptr<arrow::Array>> columns = {listArray};
auto recordBatch = arrow::RecordBatch::Make(arrowSchema, columns[0]->length(), columns);
return MakeOutputFromRecordBatch(recordBatch);
}

////////////////////////////////////////////////////////////////////////////////

TEST(TArrowParserTest, Simple)
Expand Down Expand Up @@ -550,6 +570,9 @@ TEST(TArrowParserTest, DecimalVariousPrecisions)
TColumnSchema("decimal128_10_3", DecimalLogicalType(10, 3)),
TColumnSchema("decimal128_35_3", DecimalLogicalType(35, 3)),
TColumnSchema("decimal128_38_3", DecimalLogicalType(38, 3)),
TColumnSchema("decimal256_10_3", DecimalLogicalType(10, 3)),
TColumnSchema("decimal256_35_3", DecimalLogicalType(35, 3)),
TColumnSchema("decimal256_38_3", DecimalLogicalType(38, 3)),
TColumnSchema("decimal256_76_3", DecimalLogicalType(76, 3)),
});

Expand All @@ -559,7 +582,7 @@ TEST(TArrowParserTest, DecimalVariousPrecisions)

auto parser = CreateParserForArrow(&collectedRows);

parser->Read(MakeDecimalArrows(values, {{128, 10, 3}, {128, 35, 3}, {128, 38, 3}, {256, 76, 3}}));
parser->Read(MakeDecimalArrows(values, {{128, 10, 3}, {128, 35, 3}, {128, 38, 3}, {256, 10, 3}, {256, 35, 3}, {256, 38, 3}, {256, 76, 3}}));
parser->Finish();

auto collectStrings = [&] (TStringBuf columnName) {
Expand All @@ -570,29 +593,55 @@ TEST(TArrowParserTest, DecimalVariousPrecisions)
return result;
};

std::vector<TString> expectedValues128_10_3 =
std::vector<TString> expectedValues_10_3 =
{"\x80\x00\x00\x00\x00\x00\x0c\x45"s, "\x80\x00\x00\x00\x00\x00\x00\x00"s, "\x7f\xff\xff\xff\xff\xff\xf5\x62"s, "\x80\x00\x00\x02\x54\x0b\xe3\xff"s};
std::vector<TString> expectedValues128_35_3 =
std::vector<TString> expectedValues_35_3 =
{
"\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x45"s, "\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"s,
"\x7f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xf5\x62"s, "\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x54\x0b\xe3\xff"s,
};
std::vector<TString> expectedValues128_38_3 =
std::vector<TString> expectedValues_38_3 =
{
"\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x45"s, "\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"s,
"\x7f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xf5\x62"s, "\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x54\x0b\xe3\xff"s
};
std::vector<TString> expectedValues256_76_3 =
std::vector<TString> expectedValues_76_3 =
{
"\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x0c\x45"s,
"\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"s,
"\x7f\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xff\xf5\x62"s,
"\x80\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02\x54\x0b\xe3\xff"s,
};
ASSERT_EQ(expectedValues128_10_3, collectStrings("decimal128_10_3"));
ASSERT_EQ(expectedValues128_35_3, collectStrings("decimal128_35_3"));
ASSERT_EQ(expectedValues128_38_3, collectStrings("decimal128_38_3"));
ASSERT_EQ(expectedValues256_76_3, collectStrings("decimal256_76_3"));
ASSERT_EQ(expectedValues_10_3, collectStrings("decimal128_10_3"));
ASSERT_EQ(expectedValues_35_3, collectStrings("decimal128_35_3"));
ASSERT_EQ(expectedValues_38_3, collectStrings("decimal128_38_3"));
ASSERT_EQ(expectedValues_10_3, collectStrings("decimal256_10_3"));
ASSERT_EQ(expectedValues_35_3, collectStrings("decimal256_35_3"));
ASSERT_EQ(expectedValues_38_3, collectStrings("decimal256_38_3"));
ASSERT_EQ(expectedValues_76_3, collectStrings("decimal256_76_3"));
}

TEST(TArrowParserTest, ListOfDecimals)
{
auto tableSchema = New<TTableSchema>(std::vector<TColumnSchema>{
TColumnSchema("list", ListLogicalType(DecimalLogicalType(10, 3))),
});

TCollectingValueConsumer collectedRows(tableSchema);

std::vector<TString> values = {"3.141", "0.000", "-2.718", "9999999.999"};

auto parser = CreateParserForArrow(&collectedRows);

parser->Read(MakeDecimalListArrow(values));
parser->Finish();

auto firstList = ConvertTo<std::vector<TString>>(GetComposite(collectedRows.GetRowValue(0, "list")));
std::vector<TString> secondList = {
"\x80\x00\x00\x00\x00\x00\x0c\x45"s, "\x80\x00\x00\x00\x00\x00\x00\x00"s,
"\x7f\xff\xff\xff\xff\xff\xf5\x62"s, "\x80\x00\x00\x02\x54\x0b\xe3\xff"s
};
ASSERT_EQ(firstList, secondList);
}

TEST(TArrowParserTest, BlockingInput)
Expand Down
57 changes: 48 additions & 9 deletions yt/yt/tests/integration/formats/test_arrow_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from yt_commands import authors, create, read_table, write_table, map, merge, get

from yt_type_helpers import optional_type, list_type
from yt_type_helpers import optional_type, list_type, decimal_type

import pytest

Expand All @@ -15,6 +15,13 @@
GOODBYE_WORLD = b"\xd0\x9f\xd0\xbe\xd0\xba\xd0\xb0, \xd0\xbc\xd0\xb8\xd1\x80!"


def serialize_arrow_table(table):
sink = pa.BufferOutputStream()
with pa.RecordBatchStreamWriter(sink, table.schema) as writer:
writer.write(table)
return bytes(sink.getvalue())


def parse_list_to_arrow():
data = [
{'string': 'one', 'list_strings': ['bar', 'foo']},
Expand All @@ -28,14 +35,7 @@ def parse_list_to_arrow():

table = pa.Table.from_pandas(pd.DataFrame(data), schema=schema)

sink = pa.BufferOutputStream()

with pa.RecordBatchStreamWriter(sink, table.schema) as writer:
writer.write(table)

buffer = sink.getvalue()

return bytes(buffer)
return serialize_arrow_table(table)


def parse_arrow_stream(data):
Expand Down Expand Up @@ -180,6 +180,45 @@ def test_write_arrow(self, optimize_for):

assert read_table("//tmp/table1") == read_table("//tmp/table2")

@authors("max42")
def test_write_arrow_decimal(self, optimize_for):
schema = [
{"name": "d128", "type_v3": decimal_type(10, 3)},
{"name": "d256", "type_v3": decimal_type(10, 3)},
{"name": "ld128", "type_v3": list_type(decimal_type(10, 3))},
{"name": "ld256", "type_v3": list_type(decimal_type(10, 3))}
]

fields = [
pa.field("d128", pa.decimal128(10, 3)),
pa.field("d256", pa.decimal256(10, 3)),
pa.field("ld128", pa.list_(pa.decimal128(10, 3))),
pa.field("ld256", pa.list_(pa.decimal256(10, 3))),
]
arrow_schema = pa.schema(fields)

d_list = [3.141, 0.000, -2.718, 9999999.999]
ld_list = [[d] for d in d_list]

d_arr = pa.array(d_list)
ld_arr = pa.array(ld_list)
arrow_table = pa.Table.from_arrays([d_arr, d_arr, ld_arr, ld_arr], schema=arrow_schema)

create("table", "//tmp/table", attributes={"schema": schema, "optimize_for": optimize_for})

format = yson.YsonString(b"arrow")
write_table("//tmp/table", serialize_arrow_table(arrow_table), is_raw=True, input_format=format)

rows = list(yson.loads(read_table("//tmp/table", output_format=yson.loads(b"<decimal_mode=text>yson")),
yson_type="list_fragment"))

assert rows == [
{'d128': "3.141", 'd256': "3.141", 'ld128': ["3.141"], 'ld256': ["3.141"]},
{'d128': "0.000", 'd256': "0.000", 'ld128': ["0.000"], 'ld256': ["0.000"]},
{'d128': "-2.718", 'd256': "-2.718", 'ld128': ["-2.718"], 'ld256': ["-2.718"]},
{'d128': "9999999.999", 'd256': "9999999.999", 'ld128': ["9999999.999"], 'ld256': ["9999999.999"]}
]

@authors("nadya02")
def test_write_arrow_complex(self, optimize_for):
schema = [
Expand Down