diff --git a/e2e_test/udf/bug_fixes/17560_udaf_as_win_func.slt b/e2e_test/udf/bug_fixes/17560_udaf_as_win_func.slt new file mode 100644 index 0000000000000..3d8e3279a8b42 --- /dev/null +++ b/e2e_test/udf/bug_fixes/17560_udaf_as_win_func.slt @@ -0,0 +1,19 @@ +# https://github.com/risingwavelabs/risingwave/issues/17560 + +statement ok +create aggregate sum00(value int) returns int language python as $$ +def create_state(): + return 0 +def accumulate(state, value): + return state + value +def retract(state, value): + return state - value +def finish(state): + return state +$$; + +query ii +select t.value, sum00(weight) OVER (PARTITION BY value) from (values (1, 1), (null, 2), (3, 3)) as t(value, weight); +---- +1 1 +3 3 diff --git a/src/frontend/src/expr/window_function.rs b/src/frontend/src/expr/window_function.rs index 70f6a79866fb3..8f2e6c66728dd 100644 --- a/src/frontend/src/expr/window_function.rs +++ b/src/frontend/src/expr/window_function.rs @@ -16,11 +16,11 @@ use itertools::Itertools; use risingwave_common::bail_not_implemented; use risingwave_common::types::DataType; use risingwave_expr::aggregate::AggKind; -use risingwave_expr::sig::FUNCTION_REGISTRY; use risingwave_expr::window_function::{Frame, WindowFuncKind}; use super::{Expr, ExprImpl, OrderBy, RwResult}; use crate::error::{ErrorCode, RwError}; +use crate::expr::infer_type; /// A window function performs a calculation across a set of table rows that are somehow related to /// the current row, according to the window spec `OVER (PARTITION BY .. ORDER BY ..)`. @@ -45,10 +45,10 @@ impl WindowFunction { kind: WindowFuncKind, partition_by: Vec, order_by: OrderBy, - args: Vec, + mut args: Vec, frame: Option, ) -> RwResult { - let return_type = Self::infer_return_type(&kind, &args)?; + let return_type = Self::infer_return_type(&kind, &mut args)?; Ok(Self { kind, args, @@ -59,7 +59,7 @@ impl WindowFunction { }) } - fn infer_return_type(kind: &WindowFuncKind, args: &[ExprImpl]) -> RwResult { + fn infer_return_type(kind: &WindowFuncKind, args: &mut [ExprImpl]) -> RwResult { use WindowFuncKind::*; match (kind, args) { (RowNumber, []) => Ok(DataType::Int64), @@ -87,13 +87,13 @@ impl WindowFunction { ); } - (Aggregate(AggKind::Builtin(agg_kind)), args) => { - let arg_types = args.iter().map(ExprImpl::return_type).collect::>(); - let return_type = FUNCTION_REGISTRY.get_return_type(*agg_kind, &arg_types)?; - Ok(return_type) - } + (Aggregate(agg_kind), args) => Ok(match agg_kind { + AggKind::Builtin(kind) => infer_type((*kind).into(), args)?, + AggKind::UserDefined(udf) => udf.return_type.as_ref().unwrap().into(), + AggKind::WrapScalar(expr) => expr.return_type.as_ref().unwrap().into(), + }), - _ => { + (_, args) => { let args = args .iter() .map(|e| format!("{}", e.return_type()))