Skip to content

Commit

Permalink
expression: fix the issue that comparison between Decimal may cause o…
Browse files Browse the repository at this point in the history
…verflow and report `Can't compare`. (pingcap#3097)
LittleFall authored Sep 16, 2021
1 parent f28af36 commit 08443b9
Showing 7 changed files with 370 additions and 171 deletions.
94 changes: 47 additions & 47 deletions dbms/src/Core/DecimalComparison.h
Original file line number Diff line number Diff line change
@@ -108,6 +108,53 @@ class DecimalComparison
return applyWithScale(a, b, shift);
}

template <bool scale_left, bool scale_right>
static NO_INLINE UInt8 apply(A a, B b, CompareInt scale [[maybe_unused]])
{
CompareInt x = static_cast<CompareInt>(a);
CompareInt y = static_cast<CompareInt>(b);

if constexpr (_check_overflow)
{
bool invalid = false;

if constexpr (sizeof(A) > sizeof(CompareInt))
invalid |= (A(x) != a);
if constexpr (sizeof(B) > sizeof(CompareInt))
invalid |= (B(y) != b);
if constexpr (std::is_unsigned_v<A>)
invalid |= (x < 0);
if constexpr (std::is_unsigned_v<B>)
invalid |= (y < 0);

if (invalid)
throw Exception("Can't compare", ErrorCodes::DECIMAL_OVERFLOW);
}

if constexpr (scale_left && scale_right)
throw DB::Exception("Assumption broken: there should only one side need to be multiplied in decimal comparison.", ErrorCodes::LOGICAL_ERROR);
if constexpr (!scale_left && !scale_right)
return Op::apply(x, y);

// overflow means absolute value must be greater.
// we use this variable to mark whether the right side is greater than left side by overflow.
int right_side_greater_by_overflow = 0;
if constexpr (scale_left)
{
int sign = boost::math::sign(x);
right_side_greater_by_overflow = -sign * common::mulOverflow(x, scale, x); // x will be changed.
}
if constexpr (scale_right)
{
int sign = boost::math::sign(y);
right_side_greater_by_overflow = sign * common::mulOverflow(y, scale, y); // y will be changed.
}

if (right_side_greater_by_overflow)
return Op::apply(0, right_side_greater_by_overflow);
return Op::apply(x, y);
}

private:
struct Shift
{
@@ -264,53 +311,6 @@ class DecimalComparison
return c_res;
}

template <bool scale_left, bool scale_right>
static NO_INLINE UInt8 apply(A a, B b, CompareInt scale [[maybe_unused]])
{
CompareInt x = static_cast<CompareInt>(a);
CompareInt y = static_cast<CompareInt>(b);

if constexpr (_check_overflow)
{
bool overflow = false;

if constexpr (sizeof(A) > sizeof(CompareInt))
overflow |= (A(x) != a);
if constexpr (sizeof(B) > sizeof(CompareInt))
overflow |= (B(y) != b);
if constexpr (std::is_unsigned_v<A>)
overflow |= (x < 0);
if constexpr (std::is_unsigned_v<B>)
overflow |= (y < 0);

if constexpr (scale_left)
{
if constexpr (std::is_same_v<CompareInt, Int256>)
x = x * scale;
else
overflow |= common::mulOverflow(x, scale, x);
}
if constexpr (scale_right)
{
if constexpr (std::is_same_v<CompareInt, Int256>)
y = y * scale;
else
overflow |= common::mulOverflow(y, scale, y);
}
if (overflow)
throw Exception("Can't compare", ErrorCodes::DECIMAL_OVERFLOW);
}
else
{
if constexpr (scale_left)
x *= scale;
if constexpr (scale_right)
y *= scale;
}

return Op::apply(x, y);
}

template <bool scale_left, bool scale_right>
static void NO_INLINE vectorVector(const ArrayA & a, const ArrayB & b, PaddedPODArray<UInt8> & c, CompareInt scale [[maybe_unused]])
{
2 changes: 1 addition & 1 deletion dbms/src/Core/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
add_executable (exception exception.cpp)
add_executable (exception exception.cpp gtest_decimal_comparison.cpp)
target_link_libraries (exception clickhouse_common_io)

add_executable (string_pool string_pool.cpp)
158 changes: 158 additions & 0 deletions dbms/src/Core/tests/gtest_decimal_comparison.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
#include <Core/DecimalComparison.h>
#include <TestUtils/FunctionTestUtils.h>
#include <common/arithmeticOverflow.h>

namespace DB
{
namespace tests
{
class TestDecimalComparisonUtils : public ::testing::Test
{
};

TEST_F(TestDecimalComparisonUtils, DecimalComparisonApply)
try
{
{
// case 1: original issue case, 1/2 < 3 ?
using A = Decimal32;
using B = Decimal32;
A a(500'000'000);
B b(3);
int res = DecimalComparison<A, B, LessOp, true>::apply<false, true>(a, b, 1'000'000'000);
ASSERT_EQ(res, 1);
}

{
// case 2: original issue case, 1*10^64 <= 0.1 ?
using A = Decimal256;
using B = Decimal32;
boost::multiprecision::checked_int256_t origin_a{"12345678911234567891123456789112345678911234567891123456789112345"};
A a(origin_a);
B b(1);
int res = DecimalComparison<A, B, LessOrEqualsOp, true>::apply<true, false>(a, b, 10);
ASSERT_EQ(res, 0);
}

{
// case 3: path check, will apply<true, true> throw?
using A = Decimal32;
using B = Decimal32;
A a(500'000'000);
B b(3);
auto func_call = [&] {
DecimalComparison<A, B, LessOp, true>::apply<true, true>(a, b, 1'000'000'000);
};
ASSERT_THROW(func_call(), DB::Exception);
}

{
// case 4: path check, x is positive and will overflow.
using A = Decimal32;
using B = Decimal32;
using Type = DecimalComparison<A, B, LessOp, true>;
A a(500'000'000);
B b(0);
Type::CompareInt scale = 1'000'000'000;

int res = Type::apply<true, false>(a, b, scale);
ASSERT_EQ(res, 0);

Type::CompareInt a_promoted = static_cast<Type::CompareInt>(a);
int overflowed = common::mulOverflow(a_promoted, scale, a_promoted);
ASSERT_EQ(overflowed, 1);
}

{
// case 5: path check, x is negative and will overflow.
using A = Decimal32;
using B = Decimal32;
using Type = DecimalComparison<A, B, LessOp, true>;
A a(-500'000'000);
B b(0);
Type::CompareInt scale = 1'000'000'000;

int res = Type::apply<true, false>(a, b, scale);
ASSERT_EQ(res, 1);

Type::CompareInt a_promoted = static_cast<Type::CompareInt>(a);
int overflowed = common::mulOverflow(a_promoted, scale, a_promoted);
ASSERT_EQ(overflowed, 1);
}

{
// case 6: path check, y is positive and will overflow.
using A = Decimal32;
using B = Decimal32;
using Type = DecimalComparison<A, B, LessOp, true>;
A a(-500'000'000);
B b(500'000'000);
Type::CompareInt scale = 1'000'000'000;

int res = Type::apply<false, true>(a, b, scale);
ASSERT_EQ(res, 1);

Type::CompareInt b_promoted = static_cast<Type::CompareInt>(b);
int overflowed = common::mulOverflow(b_promoted, scale, b_promoted);
ASSERT_EQ(overflowed, 1);
}

{
// case 7: path check, y is negative and will overflow.
using A = Decimal32;
using B = Decimal32;
using Type = DecimalComparison<A, B, LessOp, true>;
A a(-500'000'000);
B b(-500'000'000);
Type::CompareInt scale = 1'000'000'000;

int res = Type::apply<false, true>(a, b, scale);
ASSERT_EQ(res, 0);

Type::CompareInt b_promoted = static_cast<Type::CompareInt>(b);
int overflowed = common::mulOverflow(b_promoted, scale, b_promoted);
ASSERT_EQ(overflowed, 1);
}

{
// case 8: path check, no overflow.
using A = Decimal64;
using B = Decimal32;
using Type = DecimalComparison<A, B, LessOp, true>;
A a(500'000'000'000'000);
B b(500'000'000);
Type::CompareInt scale = 1'000'000'000;

int res = Type::apply<false, true>(a, b, scale);
ASSERT_EQ(res, 1);

Type::CompareInt b_promoted = static_cast<Type::CompareInt>(b);
int overflowed = common::mulOverflow(b_promoted, scale, b_promoted);
ASSERT_EQ(overflowed, 0);
}

{
// case 9: path check, overflow of int256.
using A = Decimal256;
using B = Decimal256;
using Type = DecimalComparison<A, B, LessOp, true>;
boost::multiprecision::checked_int256_t origin_a{"-12345678911234567891123456789112345678911234567891123456789112345"};
boost::multiprecision::checked_int256_t origin_b{"114514"};
boost::multiprecision::checked_int256_t origin_scale{"1000000000000000000000000000000000000000000000000"};
A a(origin_a);
B b(origin_b);
Type::CompareInt scale(origin_scale);

int res = Type::apply<true, false>(a, b, scale);
ASSERT_EQ(res, 1);

Type::CompareInt a_promoted = static_cast<Type::CompareInt>(a);
int overflowed = common::mulOverflow(a_promoted, scale, a_promoted);
ASSERT_EQ(overflowed, 1);
}
}
CATCH

} // namespace tests

} // namespace DB
57 changes: 31 additions & 26 deletions dbms/src/DataTypes/DataTypeDecimal.h
Original file line number Diff line number Diff line change
@@ -9,7 +9,6 @@

namespace DB
{

namespace ErrorCodes
{
extern const int ARGUMENT_OUT_OF_BOUND;
@@ -43,9 +42,13 @@ class DataTypeDecimal : public IDataType
static constexpr size_t maxPrecision() { return maxDecimalPrecision<T>(); }

// If scale is omitted, the default is 0. If precision is omitted, the default is 10.
DataTypeDecimal() : DataTypeDecimal(10, 0) {}
DataTypeDecimal()
: DataTypeDecimal(10, 0)
{}

DataTypeDecimal(size_t precision_, size_t scale_) : precision(precision_), scale(scale_)
DataTypeDecimal(size_t precision_, size_t scale_)
: precision(precision_)
, scale(scale_)
{
if (precision > decimal_max_prec || scale > precision || scale > decimal_max_scale)
{
@@ -129,7 +132,7 @@ class DataTypeDecimal : public IDataType
template <typename U>
typename T::NativeType scaleFactorFor(const DataTypeDecimal<U> & x) const
{
if (scale < x.getScale())
if (getScale() < x.getScale())
{
return 1;
}
@@ -164,8 +167,8 @@ inline DataTypePtr createDecimal(UInt64 prec, UInt64 scale)

if (static_cast<UInt64>(scale) > prec)
throw Exception("Negative scales and scales larger than precision are not supported. precision:" + DB::toString(prec)
+ ", scale:" + DB::toString(scale),
ErrorCodes::ARGUMENT_OUT_OF_BOUND);
+ ", scale:" + DB::toString(scale),
ErrorCodes::ARGUMENT_OUT_OF_BOUND);

if (prec <= maxDecimalPrecision<Decimal32>())
{
@@ -195,15 +198,17 @@ inline bool IsDecimalDataType(const DataTypePtr & type)
}
template <typename T, typename U>
typename std::enable_if_t<(sizeof(T) >= sizeof(U)), const DataTypeDecimal<T>> decimalResultType(
const DataTypeDecimal<T> & tx, const DataTypeDecimal<U> & ty)
const DataTypeDecimal<T> & tx,
const DataTypeDecimal<U> & ty)
{
UInt32 scale = (tx.getScale() > ty.getScale() ? tx.getScale() : ty.getScale());
return DataTypeDecimal<T>(maxDecimalPrecision<T>(), scale);
}

template <typename T, typename U>
typename std::enable_if_t<(sizeof(T) < sizeof(U)), const DataTypeDecimal<U>> decimalResultType(
const DataTypeDecimal<T> & tx, const DataTypeDecimal<U> & ty)
const DataTypeDecimal<T> & tx,
const DataTypeDecimal<U> & ty)
{
UInt32 scale = (tx.getScale() > ty.getScale() ? tx.getScale() : ty.getScale());
return DataTypeDecimal<U>(maxDecimalPrecision<U>(), scale);
@@ -239,24 +244,24 @@ inline UInt32 leastDecimalPrecisionFor(TypeIndex int_type)
{
switch (int_type)
{
case TypeIndex::Int8:
[[fallthrough]];
case TypeIndex::UInt8:
return 3;
case TypeIndex::Int16:
[[fallthrough]];
case TypeIndex::UInt16:
return 5;
case TypeIndex::Int32:
[[fallthrough]];
case TypeIndex::UInt32:
return 10;
case TypeIndex::Int64:
return 19;
case TypeIndex::UInt64:
return 20;
default:
break;
case TypeIndex::Int8:
[[fallthrough]];
case TypeIndex::UInt8:
return 3;
case TypeIndex::Int16:
[[fallthrough]];
case TypeIndex::UInt16:
return 5;
case TypeIndex::Int32:
[[fallthrough]];
case TypeIndex::UInt32:
return 10;
case TypeIndex::Int64:
return 19;
case TypeIndex::UInt64:
return 20;
default:
break;
}
return 0;
}
Loading
Oops, something went wrong.

0 comments on commit 08443b9

Please sign in to comment.