Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BREAKING] Implicit cast rules for integer literals, and parameterized ANY for binding #10194

Merged
merged 14 commits into from
Jan 10, 2024
Merged
Prev Previous commit
Next Next commit
Add the INTEGER_LITERAL class - similar to the STRING_LITERAL we use …
…different rules when binding integer literals, namely they can adapt themselves to adjacent typesnumeric types
  • Loading branch information
Mytherin committed Jan 9, 2024
commit 815e2c4ffef879b90d048bf43ecf7f176a6e4975
15 changes: 15 additions & 0 deletions src/common/enum_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2175,6 +2175,10 @@ const char* EnumUtil::ToChars<ExtraTypeInfoType>(ExtraTypeInfoType value) {
return "AGGREGATE_STATE_TYPE_INFO";
case ExtraTypeInfoType::ARRAY_TYPE_INFO:
return "ARRAY_TYPE_INFO";
case ExtraTypeInfoType::ANY_TYPE_INFO:
return "ANY_TYPE_INFO";
case ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO:
return "INTEGER_LITERAL_TYPE_INFO";
default:
throw NotImplementedException(StringUtil::Format("Enum value: '%d' not implemented", value));
}
Expand Down Expand Up @@ -2212,6 +2216,12 @@ ExtraTypeInfoType EnumUtil::FromString<ExtraTypeInfoType>(const char *value) {
if (StringUtil::Equals(value, "ARRAY_TYPE_INFO")) {
return ExtraTypeInfoType::ARRAY_TYPE_INFO;
}
if (StringUtil::Equals(value, "ANY_TYPE_INFO")) {
return ExtraTypeInfoType::ANY_TYPE_INFO;
}
if (StringUtil::Equals(value, "INTEGER_LITERAL_TYPE_INFO")) {
return ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO;
}
throw NotImplementedException(StringUtil::Format("Enum value: '%s' not implemented", value));
}

Expand Down Expand Up @@ -3129,6 +3139,8 @@ const char* EnumUtil::ToChars<LogicalTypeId>(LogicalTypeId value) {
return "BIT";
case LogicalTypeId::STRING_LITERAL:
return "STRING_LITERAL";
case LogicalTypeId::INTEGER_LITERAL:
return "INTEGER_LITERAL";
case LogicalTypeId::UHUGEINT:
return "UHUGEINT";
case LogicalTypeId::HUGEINT:
Expand Down Expand Up @@ -3257,6 +3269,9 @@ LogicalTypeId EnumUtil::FromString<LogicalTypeId>(const char *value) {
if (StringUtil::Equals(value, "STRING_LITERAL")) {
return LogicalTypeId::STRING_LITERAL;
}
if (StringUtil::Equals(value, "INTEGER_LITERAL")) {
return LogicalTypeId::INTEGER_LITERAL;
}
if (StringUtil::Equals(value, "UHUGEINT")) {
return LogicalTypeId::UHUGEINT;
}
Expand Down
17 changes: 16 additions & 1 deletion src/common/extra_type_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -337,12 +337,27 @@ AnyTypeInfo::AnyTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::ANY_TYPE_INFO) {
}

AnyTypeInfo::AnyTypeInfo(LogicalType target_type_p, idx_t cast_score_p)
: ExtraTypeInfo(ExtraTypeInfoType::ANY_TYPE_INFO), target_type(std::move(target_type_p)), cast_score(cast_score_p) {
: ExtraTypeInfo(ExtraTypeInfoType::ANY_TYPE_INFO), target_type(std::move(target_type_p)), cast_score(cast_score_p) {
}

bool AnyTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const {
auto &other = other_p->Cast<AnyTypeInfo>();
return target_type == other.target_type && cast_score == other.cast_score;
}

//===--------------------------------------------------------------------===//
// Any Type Info
//===--------------------------------------------------------------------===//
IntegerLiteralTypeInfo::IntegerLiteralTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO) {
}

IntegerLiteralTypeInfo::IntegerLiteralTypeInfo(Value constant_value_p)
: ExtraTypeInfo(ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO), constant_value(std::move(constant_value_p)) {
}

bool IntegerLiteralTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const {
auto &other = other_p->Cast<IntegerLiteralTypeInfo>();
return constant_value == other.constant_value;
}

} // namespace duckdb
66 changes: 60 additions & 6 deletions src/common/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,7 @@ PhysicalType LogicalType::GetInternalType() {
case LogicalTypeId::INVALID:
case LogicalTypeId::UNKNOWN:
case LogicalTypeId::STRING_LITERAL:
case LogicalTypeId::INTEGER_LITERAL:
return PhysicalType::INVALID;
case LogicalTypeId::USER:
return PhysicalType::UNKNOWN;
Expand Down Expand Up @@ -634,6 +635,8 @@ bool LogicalType::GetDecimalProperties(uint8_t &width, uint8_t &scale) const {
width = DecimalType::GetWidth(*this);
scale = DecimalType::GetScale(*this);
break;
case LogicalTypeId::INTEGER_LITERAL:
return IntegerLiteral::GetType(*this).GetDecimalProperties(width, scale);
default:
// Nonsense values to ensure initialization
width = 255u;
Expand Down Expand Up @@ -677,6 +680,7 @@ static LogicalType DecimalSizeCheck(const LogicalType &left, const LogicalType &

static LogicalType CombineNumericTypes(const LogicalType &left, const LogicalType &right) {
D_ASSERT(left.id() != right.id());
// for integer literals - grab the inner type
if (left.id() > right.id()) {
// this method is symmetric
// arrange it so the left type is smaller to limit the number of options we need to check
Expand Down Expand Up @@ -708,11 +712,15 @@ static LogicalType CombineNumericTypes(const LogicalType &left, const LogicalTyp
throw InternalException("Cannot combine these numeric types (%s & %s)", left.ToString(), right.ToString());
}

static LogicalType ReturnType(const LogicalType &type) {
if (type.id() == LogicalTypeId::STRING_LITERAL) {
LogicalType LogicalType::NormalizeType(const LogicalType &type) {
switch (type.id()) {
case LogicalTypeId::STRING_LITERAL:
return LogicalType::VARCHAR;
case LogicalTypeId::INTEGER_LITERAL:
return IntegerLiteral::GetType(type);
default:
return type;
}
return type;
}

template <class OP>
Expand All @@ -728,10 +736,10 @@ static bool CombineUnequalTypes(const LogicalType &left, const LogicalType &righ
LogicalTypeId other_types[] = {LogicalTypeId::UNKNOWN, LogicalTypeId::SQLNULL, LogicalTypeId::STRING_LITERAL};
for (auto &other_type : other_types) {
if (left.id() == other_type) {
result = ReturnType(right);
result = LogicalType::NormalizeType(right);
return true;
} else if (right.id() == other_type) {
result = ReturnType(left);
result = LogicalType::NormalizeType(left);
return true;
}
}
Expand Down Expand Up @@ -759,6 +767,13 @@ static bool CombineUnequalTypes(const LogicalType &left, const LogicalType &righ
}
return true;
}
// for integer literals - rerun the operation with the underlying type
if (left.id() == LogicalTypeId::INTEGER_LITERAL) {
return OP::Operation(IntegerLiteral::GetType(left), right, result);
}
if (right.id() == LogicalTypeId::INTEGER_LITERAL) {
return OP::Operation(left, IntegerLiteral::GetType(right), result);
}
// for unsigned/signed comparisons we have a few fallbacks
if (left.IsNumeric() && right.IsNumeric()) {
result = CombineNumericTypes(left, right);
Expand All @@ -784,6 +799,10 @@ static bool CombineEqualTypes(const LogicalType &left, const LogicalType &right,
// two string literals convert to varchar
result = LogicalType::VARCHAR;
return true;
case LogicalTypeId::INTEGER_LITERAL:
// for integer literals we pick the highest type of the provided integer literal type
result = LogicalType::ForceMaxLogicalType(IntegerLiteral::GetType(left), IntegerLiteral::GetType(right));
return true;
case LogicalTypeId::ENUM:
// If both types are different ENUMs we do a string comparison.
result = left == right ? left : LogicalType::VARCHAR;
Expand Down Expand Up @@ -920,13 +939,13 @@ bool LogicalType::TryGetMaxLogicalType(ClientContext &context, const LogicalType
}

static idx_t GetLogicalTypeScore(const LogicalType &type) {
return idx_t(type.id());
switch (type.id()) {
case LogicalTypeId::INVALID:
case LogicalTypeId::SQLNULL:
case LogicalTypeId::UNKNOWN:
case LogicalTypeId::ANY:
case LogicalTypeId::STRING_LITERAL:
case LogicalTypeId::INTEGER_LITERAL:
return 0;
// numerics
case LogicalTypeId::BOOLEAN:
Expand Down Expand Up @@ -1455,6 +1474,41 @@ idx_t AnyType::GetCastScore(const LogicalType &type) {
return info->Cast<AnyTypeInfo>().cast_score;
}

//===--------------------------------------------------------------------===//
// Integer Literal Type
//===--------------------------------------------------------------------===//
LogicalType IntegerLiteral::GetType(const LogicalType &type) {
D_ASSERT(type.id() == LogicalTypeId::INTEGER_LITERAL);
auto info = type.AuxInfo();
D_ASSERT(info && info->type == ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO);
return info->Cast<IntegerLiteralTypeInfo>().constant_value.type();
}

bool IntegerLiteral::FitsInType(const LogicalType &type, const LogicalType &target) {
D_ASSERT(type.id() == LogicalTypeId::INTEGER_LITERAL);
// we can always cast integer literals to float and double
if (target.id() == LogicalTypeId::FLOAT || target.id() == LogicalTypeId::DOUBLE) {
return true;
}
if (!target.IsIntegral()) {
return false;
}
// we can cast to integral types if the constant value fits within that type
auto info = type.AuxInfo();
D_ASSERT(info && info->type == ExtraTypeInfoType::INTEGER_LITERAL_TYPE_INFO);
auto &literal_info = info->Cast<IntegerLiteralTypeInfo>();
Value copy = literal_info.constant_value;
return copy.DefaultTryCastAs(target);
}

LogicalType LogicalType::INTEGER_LITERAL(const Value &constant) {
if (!constant.type().IsIntegral()) {
throw InternalException("INTEGER_LITERAL can only be made from literals of integer types");
}
auto type_info = make_shared<IntegerLiteralTypeInfo>(constant);
return LogicalType(LogicalTypeId::INTEGER_LITERAL, std::move(type_info));
}

//===--------------------------------------------------------------------===//
// Logical Type
//===--------------------------------------------------------------------===//
Expand Down
3 changes: 1 addition & 2 deletions src/core_functions/aggregate/distributive/approx_count.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -117,8 +117,7 @@ AggregateFunction GetApproxCountDistinctFunction(const LogicalType &input_type)
ApproxCountDistinctUpdateFunction,
AggregateFunction::StateCombine<ApproxDistinctCountState, ApproxCountDistinctFunction>,
AggregateFunction::StateFinalize<ApproxDistinctCountState, int64_t, ApproxCountDistinctFunction>,
ApproxCountDistinctSimpleUpdateFunction,
nullptr,
ApproxCountDistinctSimpleUpdateFunction, nullptr,
AggregateFunction::StateDestroy<ApproxDistinctCountState, ApproxCountDistinctFunction>);
fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
return fun;
Expand Down
5 changes: 3 additions & 2 deletions src/core_functions/aggregate/distributive/entropy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -147,8 +147,9 @@ AggregateFunction GetEntropyFunctionInternal(PhysicalType type) {
return AggregateFunction::UnaryAggregateDestructor<EntropyState<double>, double, double, EntropyFunction>(
LogicalType::DOUBLE, LogicalType::DOUBLE);
case PhysicalType::VARCHAR: {
return AggregateFunction::UnaryAggregateDestructor<EntropyState<string>, string_t, double, EntropyFunctionString>(
LogicalType::ANY_PARAMS(LogicalType::VARCHAR, 150), LogicalType::DOUBLE);
return AggregateFunction::UnaryAggregateDestructor<EntropyState<string>, string_t, double,
EntropyFunctionString>(
LogicalType::ANY_PARAMS(LogicalType::VARCHAR, 150), LogicalType::DOUBLE);
}

default:
Expand Down
3 changes: 2 additions & 1 deletion src/core_functions/aggregate/distributive/string_agg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,8 @@ unique_ptr<FunctionData> StringAggDeserialize(Deserializer &deserializer, Aggreg
AggregateFunctionSet StringAggFun::GetFunctions() {
AggregateFunctionSet string_agg;
AggregateFunction string_agg_param(
{LogicalType::ANY_PARAMS(LogicalType::VARCHAR)}, LogicalType::VARCHAR, AggregateFunction::StateSize<StringAggState>,
{LogicalType::ANY_PARAMS(LogicalType::VARCHAR)}, LogicalType::VARCHAR,
AggregateFunction::StateSize<StringAggState>,
AggregateFunction::StateInitialize<StringAggState, StringAggFunction>,
AggregateFunction::UnaryScatterUpdate<StringAggState, string_t, StringAggFunction>,
AggregateFunction::StateCombine<StringAggState, StringAggFunction>,
Expand Down
3 changes: 2 additions & 1 deletion src/core_functions/aggregate/holistic/mode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,8 @@ AggregateFunction GetModeAggregate(const LogicalType &type) {
return GetTypedModeFunction<interval_t, interval_t>(type);

case PhysicalType::VARCHAR:
return GetTypedModeFunction<string_t, string, ModeAssignmentString>(LogicalType::ANY_PARAMS(LogicalType::VARCHAR, 150));
return GetTypedModeFunction<string_t, string, ModeAssignmentString>(
LogicalType::ANY_PARAMS(LogicalType::VARCHAR, 150));

default:
throw NotImplementedException("Unimplemented mode aggregate");
Expand Down
11 changes: 7 additions & 4 deletions src/core_functions/aggregate/holistic/quantile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1655,10 +1655,13 @@ AggregateFunction GetQuantileDecimalAggregate(const vector<LogicalType> &argumen
}

vector<LogicalType> GetQuantileTypes() {
return {LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER, LogicalType::BIGINT,
LogicalType::HUGEINT, LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::DATE,
LogicalType::TIMESTAMP, LogicalType::TIME, LogicalType::TIMESTAMP_TZ, LogicalType::TIME_TZ,
LogicalType::INTERVAL, LogicalType::ANY_PARAMS(LogicalType::VARCHAR, 150)};
return {LogicalType::TINYINT, LogicalType::SMALLINT,
LogicalType::INTEGER, LogicalType::BIGINT,
LogicalType::HUGEINT, LogicalType::FLOAT,
LogicalType::DOUBLE, LogicalType::DATE,
LogicalType::TIMESTAMP, LogicalType::TIME,
LogicalType::TIMESTAMP_TZ, LogicalType::TIME_TZ,
LogicalType::INTERVAL, LogicalType::ANY_PARAMS(LogicalType::VARCHAR, 150)};
}

AggregateFunctionSet MedianFun::GetFunctions() {
Expand Down
4 changes: 1 addition & 3 deletions src/core_functions/scalar/list/list_value.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,7 @@ static unique_ptr<FunctionData> ListValueBind(ClientContext &context, ScalarFunc
child_type.ToString(), arg_type.ToString());
}
}
if (child_type.id() == LogicalTypeId::STRING_LITERAL) {
child_type = LogicalType::VARCHAR;
}
child_type = LogicalType::NormalizeType(child_type);

// this is more for completeness reasons
bound_function.varargs = child_type;
Expand Down
22 changes: 15 additions & 7 deletions src/function/cast_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -325,19 +325,27 @@ int64_t CastRules::ImplicitCast(const LogicalType &from, const LogicalType &to)
// string literals can be cast to any type for low cost as long as the type is valid
// i.e. we cannot cast to LIST(ANY) as we don't know what "ANY" should be
// we cannot cast to DECIMAL without precision/width specified
// etc...
// the exception is the ANY type - for the ANY type we just cast to VARCHAR
// but we prefer casting to VARCHAR
if (to.id() != LogicalType::ANY) {
if (!LogicalTypeIsValid(to)) {
return -1;
}
if (!LogicalTypeIsValid(to)) {
return -1;
}
if (to.id() == LogicalTypeId::VARCHAR && to.GetAlias().empty()) {
return 1;
}
return 20;
}
if (from.id() == LogicalTypeId::INTEGER_LITERAL) {
// the integer literal has an underlying type - this type always matches
if (IntegerLiteral::GetType(from).id() == to.id()) {
return 0;
}
// integer literals can be cast to any other integer type for a low cost, but only if the literal fits
if (IntegerLiteral::FitsInType(from, to)) {
// to avoid ties we prefer BIGINT, INT, ...
return TargetTypeCost(to) - 90;
}
// in any other case we use the casting rules of the preferred type of the literal
return CastRules::ImplicitCast(IntegerLiteral::GetType(from), to);
}
if (from.GetAlias() != to.GetAlias()) {
// if aliases are different, an implicit cast is not possible
return -1;
Expand Down
8 changes: 5 additions & 3 deletions src/function/function_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -249,15 +249,17 @@ LogicalType PrepareTypeForCast(const LogicalType &type) {
}

void FunctionBinder::CastToFunctionArguments(SimpleFunction &function, vector<unique_ptr<Expression>> &children) {
for(auto &arg : function.arguments) {
for (auto &arg : function.arguments) {
arg = PrepareTypeForCast(arg);
}
function.varargs = PrepareTypeForCast(function.varargs);

for (idx_t i = 0; i < children.size(); i++) {
auto target_type = i < function.arguments.size() ? function.arguments[i] : function.varargs;
if (target_type.id() == LogicalTypeId::STRING_LITERAL) {
throw InternalException("Function %s returned a STRING_LITERAL type - use VARCHAR instead", function.name);
if (target_type.id() == LogicalTypeId::STRING_LITERAL || target_type.id() == LogicalTypeId::INTEGER_LITERAL) {
throw InternalException(
"Function %s returned a STRING_LITERAL or INTEGER_LITERAL type - return an explicit type instead",
function.name);
}
target_type.Verify();
// don't cast lambda children, they get removed before execution
Expand Down
17 changes: 17 additions & 0 deletions src/include/duckdb/common/extra_type_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ enum class ExtraTypeInfoType : uint8_t {
AGGREGATE_STATE_TYPE_INFO = 8,
ARRAY_TYPE_INFO = 9,
ANY_TYPE_INFO = 10,
INTEGER_LITERAL_TYPE_INFO = 11
};

struct ExtraTypeInfo {
Expand Down Expand Up @@ -217,4 +218,20 @@ struct AnyTypeInfo : public ExtraTypeInfo {
AnyTypeInfo();
};

struct IntegerLiteralTypeInfo : public ExtraTypeInfo {
IntegerLiteralTypeInfo(Value constant_value);

Value constant_value;

public:
void Serialize(Serializer &serializer) const override;
static shared_ptr<ExtraTypeInfo> Deserialize(Deserializer &source);

protected:
bool EqualsInternal(ExtraTypeInfo *other_p) const override;

private:
IntegerLiteralTypeInfo();
};

} // namespace duckdb
Loading