Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BREAKING] Implicit cast rules for integer literals, and parameterized ANY for binding #10194

Merged
merged 14 commits into from
Jan 10, 2024
Merged
Prev Previous commit
Next Next commit
Clean up the way that the new ANY rework works by adding extra parame…
…ters that can be passed to ANY - namely the target cast score and the target type
  • Loading branch information
Mytherin committed Jan 9, 2024
commit b49492459707d7958f2b5089c95dd3134b4262b3
15 changes: 15 additions & 0 deletions src/common/extra_type_info.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -330,4 +330,19 @@ bool ArrayTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const {
return child_type == other.child_type && size == other.size;
}

//===--------------------------------------------------------------------===//
// Any Type Info
//===--------------------------------------------------------------------===//
AnyTypeInfo::AnyTypeInfo() : ExtraTypeInfo(ExtraTypeInfoType::ANY_TYPE_INFO) {
}

AnyTypeInfo::AnyTypeInfo(LogicalType target_type_p, idx_t cast_score_p)
: ExtraTypeInfo(ExtraTypeInfoType::ANY_TYPE_INFO), target_type(std::move(target_type_p)), cast_score(cast_score_p) {
}

bool AnyTypeInfo::EqualsInternal(ExtraTypeInfo *other_p) const {
auto &other = other_p->Cast<AnyTypeInfo>();
return target_type == other.target_type && cast_score == other.cast_score;
}

} // namespace duckdb
26 changes: 26 additions & 0 deletions src/common/types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1429,6 +1429,32 @@ LogicalType LogicalType::ARRAY(const LogicalType &child) {
return LogicalType(LogicalTypeId::ARRAY, std::move(info));
}

//===--------------------------------------------------------------------===//
// Any Type
//===--------------------------------------------------------------------===//
LogicalType LogicalType::ANY_PARAMS(LogicalType target, idx_t cast_score) {
auto type_info = make_shared<AnyTypeInfo>(std::move(target), cast_score);
return LogicalType(LogicalTypeId::ANY, std::move(type_info));
}

LogicalType AnyType::GetTargetType(const LogicalType &type) {
D_ASSERT(type.id() == LogicalTypeId::ANY);
auto info = type.AuxInfo();
if (!info) {
return LogicalType::ANY;
}
return info->Cast<AnyTypeInfo>().target_type;
}

idx_t AnyType::GetCastScore(const LogicalType &type) {
D_ASSERT(type.id() == LogicalTypeId::ANY);
auto info = type.AuxInfo();
if (!info) {
return 5;
}
return info->Cast<AnyTypeInfo>().cast_score;
}

//===--------------------------------------------------------------------===//
// Logical Type
//===--------------------------------------------------------------------===//
Expand Down
10 changes: 2 additions & 8 deletions src/core_functions/aggregate/distributive/approx_count.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -110,12 +110,6 @@ static void ApproxCountDistinctUpdateFunction(Vector inputs[], AggregateInputDat
HyperLogLog::AddToLogs(vdata, count, indices, counts, reinterpret_cast<HyperLogLog ***>(states), sdata.sel);
}

unique_ptr<FunctionData> ApproxCountDistinctAnyBind(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
function.arguments[0] = LogicalType::VARCHAR;
return nullptr;
}

AggregateFunction GetApproxCountDistinctFunction(const LogicalType &input_type) {
auto fun = AggregateFunction(
{input_type}, LogicalTypeId::BIGINT, AggregateFunction::StateSize<ApproxDistinctCountState>,
Expand All @@ -124,7 +118,7 @@ AggregateFunction GetApproxCountDistinctFunction(const LogicalType &input_type)
AggregateFunction::StateCombine<ApproxDistinctCountState, ApproxCountDistinctFunction>,
AggregateFunction::StateFinalize<ApproxDistinctCountState, int64_t, ApproxCountDistinctFunction>,
ApproxCountDistinctSimpleUpdateFunction,
input_type.id() == LogicalTypeId::ANY ? ApproxCountDistinctAnyBind : nullptr,
nullptr,
AggregateFunction::StateDestroy<ApproxDistinctCountState, ApproxCountDistinctFunction>);
fun.null_handling = FunctionNullHandling::SPECIAL_HANDLING;
return fun;
Expand All @@ -146,7 +140,7 @@ AggregateFunctionSet ApproxCountDistinctFun::GetFunctions() {
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::TIMESTAMP));
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::TIMESTAMP_TZ));
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::BLOB));
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::ANY));
approx_count.AddFunction(GetApproxCountDistinctFunction(LogicalType::ANY_PARAMS(LogicalType::VARCHAR, 150)));
return approx_count;
}

Expand Down
13 changes: 2 additions & 11 deletions src/core_functions/aggregate/distributive/entropy.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -120,12 +120,6 @@ AggregateFunction GetEntropyFunction(const LogicalType &input_type, const Logica
return fun;
}

static unique_ptr<FunctionData> EntropyVarcharBind(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
function.arguments[0] = LogicalType::VARCHAR;
return nullptr;
}

AggregateFunction GetEntropyFunctionInternal(PhysicalType type) {
switch (type) {
case PhysicalType::UINT16:
Expand Down Expand Up @@ -153,11 +147,8 @@ AggregateFunction GetEntropyFunctionInternal(PhysicalType type) {
return AggregateFunction::UnaryAggregateDestructor<EntropyState<double>, double, double, EntropyFunction>(
LogicalType::DOUBLE, LogicalType::DOUBLE);
case PhysicalType::VARCHAR: {
AggregateFunction result =
AggregateFunction::UnaryAggregateDestructor<EntropyState<string>, string_t, double, EntropyFunctionString>(
LogicalType::ANY, LogicalType::DOUBLE);
result.bind = EntropyVarcharBind;
return result;
return AggregateFunction::UnaryAggregateDestructor<EntropyState<string>, string_t, double, EntropyFunctionString>(
LogicalType::ANY_PARAMS(LogicalType::VARCHAR, 150), LogicalType::DOUBLE);
}

default:
Expand Down
3 changes: 1 addition & 2 deletions src/core_functions/aggregate/distributive/string_agg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@ struct StringAggFunction {

unique_ptr<FunctionData> StringAggBind(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
function.arguments[0] = LogicalType::VARCHAR;
if (arguments.size() == 1) {
// single argument: default to comma
return make_uniq<StringAggBindData>(",");
Expand Down Expand Up @@ -156,7 +155,7 @@ unique_ptr<FunctionData> StringAggDeserialize(Deserializer &deserializer, Aggreg
AggregateFunctionSet StringAggFun::GetFunctions() {
AggregateFunctionSet string_agg;
AggregateFunction string_agg_param(
{LogicalType::ANY}, LogicalType::VARCHAR, AggregateFunction::StateSize<StringAggState>,
{LogicalType::ANY_PARAMS(LogicalType::VARCHAR)}, LogicalType::VARCHAR, AggregateFunction::StateSize<StringAggState>,
AggregateFunction::StateInitialize<StringAggState, StringAggFunction>,
AggregateFunction::UnaryScatterUpdate<StringAggState, string_t, StringAggFunction>,
AggregateFunction::StateCombine<StringAggState, StringAggFunction>,
Expand Down
11 changes: 1 addition & 10 deletions src/core_functions/aggregate/holistic/mode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,21 +308,12 @@ struct ModeFunction {
}
};

static unique_ptr<FunctionData> ModeVarcharBind(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
function.arguments[0] = LogicalType::VARCHAR;
return nullptr;
}

template <typename INPUT_TYPE, typename KEY_TYPE, typename ASSIGN_OP = ModeAssignmentStandard>
AggregateFunction GetTypedModeFunction(const LogicalType &type) {
using STATE = ModeState<KEY_TYPE>;
using OP = ModeFunction<KEY_TYPE, ASSIGN_OP>;
auto return_type = type.id() == LogicalTypeId::ANY ? LogicalType::VARCHAR : type;
auto func = AggregateFunction::UnaryAggregateDestructor<STATE, INPUT_TYPE, INPUT_TYPE, OP>(type, return_type);
if (type.id() == LogicalTypeId::ANY) {
func.bind = ModeVarcharBind;
}
func.window = AggregateFunction::UnaryWindow<STATE, INPUT_TYPE, INPUT_TYPE, OP>;
return func;
}
Expand Down Expand Up @@ -359,7 +350,7 @@ AggregateFunction GetModeAggregate(const LogicalType &type) {
return GetTypedModeFunction<interval_t, interval_t>(type);

case PhysicalType::VARCHAR:
return GetTypedModeFunction<string_t, string, ModeAssignmentString>(LogicalType::ANY);
return GetTypedModeFunction<string_t, string, ModeAssignmentString>(LogicalType::ANY_PARAMS(LogicalType::VARCHAR, 150));

default:
throw NotImplementedException("Unimplemented mode aggregate");
Expand Down
9 changes: 1 addition & 8 deletions src/core_functions/aggregate/holistic/quantile.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1224,7 +1224,6 @@ AggregateFunction GetContinuousQuantileListAggregateFunction(const LogicalType &
return GetTypedContinuousQuantileListAggregateFunction<int64_t, double>(type, LogicalType::DOUBLE);
case LogicalTypeId::HUGEINT:
return GetTypedContinuousQuantileListAggregateFunction<hugeint_t, double>(type, LogicalType::DOUBLE);

case LogicalTypeId::FLOAT:
return GetTypedContinuousQuantileListAggregateFunction<float, float>(type, type);
case LogicalTypeId::DOUBLE:
Expand All @@ -1242,8 +1241,6 @@ AggregateFunction GetContinuousQuantileListAggregateFunction(const LogicalType &
default:
throw NotImplementedException("Unimplemented discrete quantile DECIMAL list aggregate");
}
break;

case LogicalTypeId::DATE:
return GetTypedContinuousQuantileListAggregateFunction<date_t, timestamp_t>(type, LogicalType::TIMESTAMP);
case LogicalTypeId::TIMESTAMP:
Expand All @@ -1252,7 +1249,6 @@ AggregateFunction GetContinuousQuantileListAggregateFunction(const LogicalType &
case LogicalTypeId::TIME:
case LogicalTypeId::TIME_TZ:
return GetTypedContinuousQuantileListAggregateFunction<dtime_t, dtime_t>(type, type);

default:
throw NotImplementedException("Unimplemented discrete quantile list aggregate");
}
Expand Down Expand Up @@ -1415,9 +1411,6 @@ struct MedianAbsoluteDeviationOperation : public QuantileOperation {

unique_ptr<FunctionData> BindMedian(ClientContext &context, AggregateFunction &function,
vector<unique_ptr<Expression>> &arguments) {
if (function.arguments[0].id() == LogicalTypeId::ANY) {
function.arguments[0] = LogicalType::VARCHAR;
}
return make_uniq<QuantileBindData>(Value::DECIMAL(int16_t(5), 2, 1));
}

Expand Down Expand Up @@ -1665,7 +1658,7 @@ vector<LogicalType> GetQuantileTypes() {
return {LogicalType::TINYINT, LogicalType::SMALLINT, LogicalType::INTEGER, LogicalType::BIGINT,
LogicalType::HUGEINT, LogicalType::FLOAT, LogicalType::DOUBLE, LogicalType::DATE,
LogicalType::TIMESTAMP, LogicalType::TIME, LogicalType::TIMESTAMP_TZ, LogicalType::TIME_TZ,
LogicalType::INTERVAL, LogicalType::ANY};
LogicalType::INTERVAL, LogicalType::ANY_PARAMS(LogicalType::VARCHAR, 150)};
}

AggregateFunctionSet MedianFun::GetFunctions() {
Expand Down
51 changes: 23 additions & 28 deletions src/core_functions/aggregate/nested/histogram.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,8 @@ unique_ptr<FunctionData> HistogramBindFunction(ClientContext &context, Aggregate
arguments[0]->return_type.id() == LogicalTypeId::MAP) {
throw NotImplementedException("Unimplemented type for histogram %s", arguments[0]->return_type.ToString());
}
if (function.arguments[0].id() == LogicalTypeId::ANY) {
// add varchar cast for ANY
function.arguments[0] = LogicalType::VARCHAR;
}

auto struct_type = LogicalType::MAP(arguments[0]->return_type, LogicalType::UBIGINT);
auto child_type = function.arguments[0].id() == LogicalTypeId::ANY ? LogicalType::VARCHAR : function.arguments[0];
auto struct_type = LogicalType::MAP(child_type, LogicalType::UBIGINT);

function.return_type = struct_type;
return make_uniq<VariableReturnBindData>(function.return_type);
Expand Down Expand Up @@ -188,47 +184,46 @@ AggregateFunction GetMapType(const LogicalType &type) {

template <bool IS_ORDERED = true>
AggregateFunction GetHistogramFunction(const LogicalType &type) {

switch (type.id()) {
case LogicalType::BOOLEAN:
case LogicalTypeId::BOOLEAN:
return GetMapType<HistogramFunctor, bool, IS_ORDERED>(type);
case LogicalType::UTINYINT:
case LogicalTypeId::UTINYINT:
return GetMapType<HistogramFunctor, uint8_t, IS_ORDERED>(type);
case LogicalType::USMALLINT:
case LogicalTypeId::USMALLINT:
return GetMapType<HistogramFunctor, uint16_t, IS_ORDERED>(type);
case LogicalType::UINTEGER:
case LogicalTypeId::UINTEGER:
return GetMapType<HistogramFunctor, uint32_t, IS_ORDERED>(type);
case LogicalType::UBIGINT:
case LogicalTypeId::UBIGINT:
return GetMapType<HistogramFunctor, uint64_t, IS_ORDERED>(type);
case LogicalType::TINYINT:
case LogicalTypeId::TINYINT:
return GetMapType<HistogramFunctor, int8_t, IS_ORDERED>(type);
case LogicalType::SMALLINT:
case LogicalTypeId::SMALLINT:
return GetMapType<HistogramFunctor, int16_t, IS_ORDERED>(type);
case LogicalType::INTEGER:
case LogicalTypeId::INTEGER:
return GetMapType<HistogramFunctor, int32_t, IS_ORDERED>(type);
case LogicalType::BIGINT:
case LogicalTypeId::BIGINT:
return GetMapType<HistogramFunctor, int64_t, IS_ORDERED>(type);
case LogicalType::FLOAT:
case LogicalTypeId::FLOAT:
return GetMapType<HistogramFunctor, float, IS_ORDERED>(type);
case LogicalType::DOUBLE:
case LogicalTypeId::DOUBLE:
return GetMapType<HistogramFunctor, double, IS_ORDERED>(type);
case LogicalType::TIMESTAMP:
case LogicalTypeId::TIMESTAMP:
return GetMapType<HistogramFunctor, timestamp_t, IS_ORDERED>(type);
case LogicalType::TIMESTAMP_TZ:
case LogicalTypeId::TIMESTAMP_TZ:
return GetMapType<HistogramFunctor, timestamp_tz_t, IS_ORDERED>(type);
case LogicalType::TIMESTAMP_S:
case LogicalTypeId::TIMESTAMP_SEC:
return GetMapType<HistogramFunctor, timestamp_sec_t, IS_ORDERED>(type);
case LogicalType::TIMESTAMP_MS:
case LogicalTypeId::TIMESTAMP_MS:
return GetMapType<HistogramFunctor, timestamp_ms_t, IS_ORDERED>(type);
case LogicalType::TIMESTAMP_NS:
case LogicalTypeId::TIMESTAMP_NS:
return GetMapType<HistogramFunctor, timestamp_ns_t, IS_ORDERED>(type);
case LogicalType::TIME:
case LogicalTypeId::TIME:
return GetMapType<HistogramFunctor, dtime_t, IS_ORDERED>(type);
case LogicalType::TIME_TZ:
case LogicalTypeId::TIME_TZ:
return GetMapType<HistogramFunctor, dtime_tz_t, IS_ORDERED>(type);
case LogicalType::DATE:
case LogicalTypeId::DATE:
return GetMapType<HistogramFunctor, date_t, IS_ORDERED>(type);
case LogicalType::ANY:
case LogicalTypeId::ANY:
return GetMapType<HistogramStringFunctor, string, IS_ORDERED>(type);
default:
throw InternalException("Unimplemented histogram aggregate");
Expand Down Expand Up @@ -256,7 +251,7 @@ AggregateFunctionSet HistogramFun::GetFunctions() {
fun.AddFunction(GetHistogramFunction<>(LogicalType::TIME));
fun.AddFunction(GetHistogramFunction<>(LogicalType::TIME_TZ));
fun.AddFunction(GetHistogramFunction<>(LogicalType::DATE));
fun.AddFunction(GetHistogramFunction<>(LogicalType::ANY));
fun.AddFunction(GetHistogramFunction<>(LogicalType::ANY_PARAMS(LogicalType::VARCHAR)));
return fun;
}

Expand Down
2 changes: 1 addition & 1 deletion src/function/cast_rules.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ static int64_t TargetTypeCost(const LogicalType &type) {
case LogicalTypeId::ARRAY:
return 160;
case LogicalTypeId::ANY:
return 5;
return int64_t(AnyType::GetCastScore(type));
default:
return 110;
}
Expand Down
15 changes: 15 additions & 0 deletions src/function/function_binder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,22 @@ LogicalTypeComparisonResult RequiresCast(const LogicalType &source_type, const L
return LogicalTypeComparisonResult::DIFFERENT_TYPES;
}

LogicalType PrepareTypeForCast(const LogicalType &type) {
if (type.id() == LogicalTypeId::ANY) {
return AnyType::GetTargetType(type);
}
if (type.id() == LogicalTypeId::LIST) {
return LogicalType::LIST(PrepareTypeForCast(ListType::GetChildType(type)));
}
return type;
}

void FunctionBinder::CastToFunctionArguments(SimpleFunction &function, vector<unique_ptr<Expression>> &children) {
for(auto &arg : function.arguments) {
arg = PrepareTypeForCast(arg);
}
function.varargs = PrepareTypeForCast(function.varargs);

for (idx_t i = 0; i < children.size(); i++) {
auto target_type = i < function.arguments.size() ? function.arguments[i] : function.varargs;
if (target_type.id() == LogicalTypeId::STRING_LITERAL) {
Expand Down
18 changes: 18 additions & 0 deletions src/include/duckdb/common/extra_type_info.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ enum class ExtraTypeInfoType : uint8_t {
USER_TYPE_INFO = 7,
AGGREGATE_STATE_TYPE_INFO = 8,
ARRAY_TYPE_INFO = 9,
ANY_TYPE_INFO = 10,
};

struct ExtraTypeInfo {
Expand Down Expand Up @@ -199,4 +200,21 @@ struct ArrayTypeInfo : public ExtraTypeInfo {
bool EqualsInternal(ExtraTypeInfo *other_p) const override;
};

struct AnyTypeInfo : public ExtraTypeInfo {
AnyTypeInfo(LogicalType target_type, idx_t cast_score);

LogicalType target_type;
idx_t cast_score;

public:
void Serialize(Serializer &serializer) const override;
static shared_ptr<ExtraTypeInfo> Deserialize(Deserializer &source);

protected:
bool EqualsInternal(ExtraTypeInfo *other_p) const override;

private:
AnyTypeInfo();
};

} // namespace duckdb
7 changes: 7 additions & 0 deletions src/include/duckdb/common/types.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,8 @@ struct LogicalType {
// an array of unknown size (only used for binding)
DUCKDB_API static LogicalType ARRAY(const LogicalType &child); // NOLINT
DUCKDB_API static LogicalType ENUM(Vector &ordered_data, idx_t size); // NOLINT
// ANY but with special rules (default is LogicalType::ANY, 5)
DUCKDB_API static LogicalType ANY_PARAMS(LogicalType target, idx_t cast_score = 5); // NOLINT
// DEPRECATED - provided for backwards compatibility
DUCKDB_API static LogicalType ENUM(const string &enum_name, Vector &ordered_data, idx_t size); // NOLINT
DUCKDB_API static LogicalType USER(const string &user_type_name); // NOLINT
Expand Down Expand Up @@ -456,6 +458,11 @@ struct AggregateStateType {
DUCKDB_API static const aggregate_state_t &GetStateType(const LogicalType &type);
};

struct AnyType {
DUCKDB_API static LogicalType GetTargetType(const LogicalType &type);
DUCKDB_API static idx_t GetCastScore(const LogicalType &type);
};

// **DEPRECATED**: Use EnumUtil directly instead.
DUCKDB_API string LogicalTypeIdToString(LogicalTypeId type);

Expand Down
Loading