Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(query): fix register function working with nullable scalar #17217

Merged
merged 6 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 9 additions & 16 deletions src/query/expression/src/register_vectorize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,8 +283,7 @@ pub fn passthrough_nullable_1_arg<I1: ArgType, O: ArgType>(

match out {
Value::Column(out) => Value::Column(NullableColumn::new(out, args_validity)),
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(Some(out)),
_ => Value::Scalar(None),
Value::Scalar(out) => Value::Scalar(Some(out)),
}
}
_ => Value::Scalar(None),
Expand All @@ -308,15 +307,15 @@ pub fn passthrough_nullable_2_arg<I1: ArgType, I2: ArgType, O: ArgType>(
if let Some(validity) = ctx.validity.as_ref() {
args_validity = &args_validity & validity;
}

ctx.validity = Some(args_validity.clone());
match (arg1.value(), arg2.value()) {
(Some(arg1), Some(arg2)) => {
let out = func(arg1, arg2, ctx);

match out {
Value::Column(out) => Value::Column(NullableColumn::new(out, args_validity)),
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(Some(out)),
_ => Value::Scalar(None),
Value::Scalar(out) => Value::Scalar(Some(out)),
}
}
_ => Value::Scalar(None),
Expand Down Expand Up @@ -352,8 +351,7 @@ pub fn passthrough_nullable_3_arg<I1: ArgType, I2: ArgType, I3: ArgType, O: ArgT

match out {
Value::Column(out) => Value::Column(NullableColumn::new(out, args_validity)),
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(Some(out)),
_ => Value::Scalar(None),
Value::Scalar(out) => Value::Scalar(Some(out)),
}
}
_ => Value::Scalar(None),
Expand Down Expand Up @@ -397,8 +395,7 @@ pub fn passthrough_nullable_4_arg<

match out {
Value::Column(out) => Value::Column(NullableColumn::new(out, args_validity)),
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(Some(out)),
_ => Value::Scalar(None),
Value::Scalar(out) => Value::Scalar(Some(out)),
}
}
_ => Value::Scalar(None),
Expand Down Expand Up @@ -427,8 +424,7 @@ pub fn combine_nullable_1_arg<I1: ArgType, O: ArgType>(
out.column,
&args_validity & &out.validity,
)),
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(out),
_ => Value::Scalar(None),
Value::Scalar(out) => Value::Scalar(out),
}
}
_ => Value::Scalar(None),
Expand Down Expand Up @@ -465,8 +461,7 @@ pub fn combine_nullable_2_arg<I1: ArgType, I2: ArgType, O: ArgType>(
out.column,
&args_validity & &out.validity,
)),
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(out),
_ => Value::Scalar(None),
Value::Scalar(out) => Value::Scalar(out),
}
}
_ => Value::Scalar(None),
Expand Down Expand Up @@ -505,8 +500,7 @@ pub fn combine_nullable_3_arg<I1: ArgType, I2: ArgType, I3: ArgType, O: ArgType>
out.column,
&args_validity & &out.validity,
)),
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(out),
_ => Value::Scalar(None),
Value::Scalar(out) => Value::Scalar(out),
}
}
_ => Value::Scalar(None),
Expand Down Expand Up @@ -552,8 +546,7 @@ pub fn combine_nullable_4_arg<I1: ArgType, I2: ArgType, I3: ArgType, I4: ArgType
out.column,
&args_validity & &out.validity,
)),
Value::Scalar(out) if args_validity.get_bit(0) => Value::Scalar(out),
_ => Value::Scalar(None),
Value::Scalar(out) => Value::Scalar(out),
}
}
_ => Value::Scalar(None),
Expand Down
40 changes: 40 additions & 0 deletions src/query/functions/tests/it/scalars/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,3 +271,43 @@ fn list_all_builtin_functions() {
fn check_ambiguity() {
BUILTIN_FUNCTIONS.check_ambiguity()
}

#[test]
fn test_if_function() -> Result<()> {
use databend_common_expression::types::*;
use databend_common_expression::FromData;
use databend_common_expression::Scalar;
let raw_expr = parser::parse_raw_expr("if(eq(n,1), sum_sid + 1,100)", &[
("n", UInt8Type::data_type()),
("sum_sid", Int32Type::data_type().wrap_nullable()),
]);
let expr = type_check::check(&raw_expr, &BUILTIN_FUNCTIONS)?;
let block = DataBlock::new(
vec![
BlockEntry {
data_type: UInt8Type::data_type(),
value: Value::Column(UInt8Type::from_data(vec![2_u8, 1])),
},
BlockEntry {
data_type: Int32Type::data_type().wrap_nullable(),
value: Value::Scalar(Scalar::Number(NumberScalar::Int32(2400_i32))),
},
],
2,
);
let func_ctx = FunctionContext::default();
let evaluator = Evaluator::new(&block, &func_ctx, &BUILTIN_FUNCTIONS);
let result = evaluator.run(&expr).unwrap();
let result = result
.as_column()
.unwrap()
.clone()
.as_nullable()
.unwrap()
.clone();

let bm = Bitmap::from_iter([true, true]);
assert_eq!(result.validity, bm);
assert_eq!(result.column, Int64Type::from_data(vec![100, 2401]));
Ok(())
}
61 changes: 60 additions & 1 deletion tests/sqllogictests/suites/query/cte/basic_r_cte.test
Original file line number Diff line number Diff line change
Expand Up @@ -227,5 +227,64 @@ select cte1.a from cte1;
8
9


statement ok
create table train(
train_id varchar(8) not null ,
departure_station varchar(32) not null,
arrival_station varchar(32) not null,
seat_count int not null
);

statement ok
create table passenger(
passenger_id varchar(16) not null,
departure_station varchar(32) not null,
arrival_station varchar(32) not null
);

statement ok
create table city(city varchar(32));

statement ok
insert into city
with t as (select 1 n union select 2 union select 3 union select 4 union select 5)
,t1 as(select row_number()over() rn from t ,t t2,t t3)
select concat('城市',rn::varchar) city from t1 where rn<=5;

statement ok
insert into train
select concat('G',row_number()over()::varchar),c1.city,c2.city, n from city c1, city c2, (select 600 n union select 800 union select 1200 union select 1600) a ;

statement ok
insert into passenger
select concat('P',substr((100000000+row_number()over())::varchar,2)),c1.city,c2.city from city c1, city c2 ,city c3, city c4, city c5,
city c6, (select 1 n union select 2 union select 3 union select 4) c7,(select 1 n union select 2) c8;


query III
with
t0 as (
select
train_id,
seat_count,
sum(seat_count) over (
partition by departure_station, arrival_station order by train_id
) ::int sum_sid
from
train
)
select
sum(case when n=1 then sum_sid+1 else 0 end::int),
sum(sum_sid),
sum(seat_count)
from
t0,(select 1 n union all select 2);
----
261700 523200 210000

statement ok
use default;

statement ok
drop table t1;
drop database db;
1 change: 1 addition & 0 deletions tests/suites/0_stateless/19_fuzz/19_0005_fuzz_cte.result
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
OK
25 changes: 25 additions & 0 deletions tests/suites/0_stateless/19_fuzz/19_0005_fuzz_cte.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env bash

CURDIR=$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)
. "$CURDIR"/../../../shell_env.sh


rows=3

echo "" > /tmp/fuzz_a.txt
echo "" > /tmp/fuzz_b.txt

for i in `seq 1 ${rows}`;do
echo """with t0(sum_sid) as (select sum(number) over(partition by number order by number)
from numbers(3)) select n, if(n =1, sum_sid +1, 0) from t0, (select 1 n union all select 2) order by 1,2;
""" | $BENDSQL_CLIENT_CONNECT >> /tmp/fuzz_a.txt
done


for i in `seq 1 ${rows}`;do
echo """with t0(sum_sid) as (select sum(number) over(partition by number order by number)
from numbers(3)) select n, if(n =1, sum_sid +1, 0) from t0, (select 1 n union all select 2) order by 1,2;
""" | $BENDSQL_CLIENT_CONNECT >> /tmp/fuzz_b.txt
done

diff /tmp/fuzz_a.txt /tmp/fuzz_b.txt && echo "OK"
Loading