Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Maintain prepared statement parameter types explicitly instead of converting into literals #12759

Merged
merged 6 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/include/duckdb/main/capi/capi_internal.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "duckdb/main/appender.hpp"
#include "duckdb/common/case_insensitive_map.hpp"
#include "duckdb/main/client_context.hpp"
#include "duckdb/planner/expression/bound_parameter_data.hpp"

#include <cstring>
#include <cassert>
Expand All @@ -33,7 +34,7 @@ struct DatabaseData {

struct PreparedStatementWrapper {
//! Map of name -> values
case_insensitive_map_t<Value> values;
case_insensitive_map_t<BoundParameterData> values;
unique_ptr<PreparedStatement> statement;
};

Expand Down
30 changes: 16 additions & 14 deletions src/include/duckdb/main/client_context.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,25 @@

#include "duckdb/catalog/catalog_entry/schema_catalog_entry.hpp"
#include "duckdb/catalog/catalog_set.hpp"
#include "duckdb/common/enums/pending_execution_result.hpp"
#include "duckdb/common/atomic.hpp"
#include "duckdb/common/deque.hpp"
#include "duckdb/common/enums/pending_execution_result.hpp"
#include "duckdb/common/enums/prepared_statement_mode.hpp"
#include "duckdb/common/error_data.hpp"
#include "duckdb/common/pair.hpp"
#include "duckdb/common/unordered_set.hpp"
#include "duckdb/common/winapi.hpp"
#include "duckdb/main/client_config.hpp"
#include "duckdb/main/client_context_state.hpp"
#include "duckdb/main/client_properties.hpp"
#include "duckdb/main/external_dependencies.hpp"
#include "duckdb/main/pending_query_result.hpp"
#include "duckdb/main/prepared_statement.hpp"
#include "duckdb/main/settings.hpp"
#include "duckdb/main/stream_query_result.hpp"
#include "duckdb/main/table_description.hpp"
#include "duckdb/transaction/transaction_context.hpp"
#include "duckdb/main/pending_query_result.hpp"
#include "duckdb/common/atomic.hpp"
#include "duckdb/main/client_config.hpp"
#include "duckdb/main/external_dependencies.hpp"
#include "duckdb/common/error_data.hpp"
#include "duckdb/common/enums/prepared_statement_mode.hpp"
#include "duckdb/main/client_properties.hpp"
#include "duckdb/main/client_context_state.hpp"
#include "duckdb/main/settings.hpp"
#include "duckdb/planner/expression/bound_parameter_data.hpp"

namespace duckdb {
class Appender;
Expand All @@ -52,7 +53,7 @@ class ClientContextState;

struct PendingQueryParameters {
//! Prepared statement parameters (if any)
optional_ptr<case_insensitive_map_t<Value>> parameters;
optional_ptr<case_insensitive_map_t<BoundParameterData>> parameters;
//! Whether or not a stream result should be allowed
bool allow_stream_result = false;
};
Expand Down Expand Up @@ -142,7 +143,8 @@ class ClientContext : public enable_shared_from_this<ClientContext> {
//! It is possible that the prepared statement will be re-bound. This will generally happen if the catalog is
//! modified in between the prepared statement being bound and the prepared statement being run.
DUCKDB_API unique_ptr<QueryResult> Execute(const string &query, shared_ptr<PreparedStatementData> &prepared,
case_insensitive_map_t<Value> &values, bool allow_stream_result = true);
case_insensitive_map_t<BoundParameterData> &values,
bool allow_stream_result = true);
DUCKDB_API unique_ptr<QueryResult> Execute(const string &query, shared_ptr<PreparedStatementData> &prepared,
const PendingQueryParameters &parameters);

Expand Down Expand Up @@ -227,7 +229,7 @@ class ClientContext : public enable_shared_from_this<ClientContext> {
//! Internally prepare a SQL statement. Caller must hold the context_lock.
shared_ptr<PreparedStatementData>
CreatePreparedStatement(ClientContextLock &lock, const string &query, unique_ptr<SQLStatement> statement,
optional_ptr<case_insensitive_map_t<Value>> values = nullptr,
optional_ptr<case_insensitive_map_t<BoundParameterData>> values = nullptr,
PreparedStatementMode mode = PreparedStatementMode::PREPARE_ONLY);
unique_ptr<PendingQueryResult> PendingStatementInternal(ClientContextLock &lock, const string &query,
unique_ptr<SQLStatement> statement,
Expand Down Expand Up @@ -268,7 +270,7 @@ class ClientContext : public enable_shared_from_this<ClientContext> {

shared_ptr<PreparedStatementData>
CreatePreparedStatementInternal(ClientContextLock &lock, const string &query, unique_ptr<SQLStatement> statement,
optional_ptr<case_insensitive_map_t<Value>> values);
optional_ptr<case_insensitive_map_t<BoundParameterData>> values);

private:
//! Lock on using the ClientContext in parallel
Expand Down
5 changes: 3 additions & 2 deletions src/include/duckdb/main/prepared_statement.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "duckdb/main/pending_query_result.hpp"
#include "duckdb/common/error_data.hpp"
#include "duckdb/common/case_insensitive_map.hpp"
#include "duckdb/planner/expression/bound_parameter_data.hpp"

namespace duckdb {
class ClientContext;
Expand Down Expand Up @@ -76,14 +77,14 @@ class PreparedStatement {
DUCKDB_API unique_ptr<PendingQueryResult> PendingQuery(vector<Value> &values, bool allow_stream_result = true);

//! Create a pending query result of the prepared statement with the given set named arguments
DUCKDB_API unique_ptr<PendingQueryResult> PendingQuery(case_insensitive_map_t<Value> &named_values,
DUCKDB_API unique_ptr<PendingQueryResult> PendingQuery(case_insensitive_map_t<BoundParameterData> &named_values,
bool allow_stream_result = true);

//! Execute the prepared statement with the given set of values
DUCKDB_API unique_ptr<QueryResult> Execute(vector<Value> &values, bool allow_stream_result = true);

//! Execute the prepared statement with the given set of named+unnamed values
DUCKDB_API unique_ptr<QueryResult> Execute(case_insensitive_map_t<Value> &named_values,
DUCKDB_API unique_ptr<QueryResult> Execute(case_insensitive_map_t<BoundParameterData> &named_values,
bool allow_stream_result = true);

//! Execute the prepared statement with the given set of arguments
Expand Down
4 changes: 2 additions & 2 deletions src/include/duckdb/main/prepared_statement_data.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ class PreparedStatementData {
public:
void CheckParameterCount(idx_t parameter_count);
//! Whether or not the prepared statement data requires the query to rebound for the given parameters
bool RequireRebind(ClientContext &context, optional_ptr<case_insensitive_map_t<Value>> values);
bool RequireRebind(ClientContext &context, optional_ptr<case_insensitive_map_t<BoundParameterData>> values);
//! Bind a set of values to the prepared statement data
DUCKDB_API void Bind(case_insensitive_map_t<Value> values);
DUCKDB_API void Bind(case_insensitive_map_t<BoundParameterData> values);
//! Get the expected SQL Type of the bound parameter
DUCKDB_API LogicalType GetType(const string &identifier);
//! Try to get the expected SQL Type of the bound parameter
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ struct BoundParameterData {
}
explicit BoundParameterData(Value val) : value(std::move(val)), return_type(value.type()) {
}
BoundParameterData(Value val, LogicalType type_p) : value(std::move(val)), return_type(std::move(type_p)) {
}

private:
Value value;
Expand Down
4 changes: 2 additions & 2 deletions src/main/capi/prepared-c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ duckdb_type duckdb_param_type(duckdb_prepared_statement prepared_statement, idx_
// See if this is the case and we still have a value registered for it
auto it = wrapper->values.find(identifier);
if (it != wrapper->values.end()) {
return ConvertCPPTypeToC(it->second.type());
return ConvertCPPTypeToC(it->second.return_type.id());
}
return DUCKDB_TYPE_INVALID;
}
Expand All @@ -162,7 +162,7 @@ duckdb_state duckdb_bind_value(duckdb_prepared_statement prepared_statement, idx
return DuckDBError;
}
auto identifier = duckdb_parameter_name_internal(prepared_statement, param_idx);
wrapper->values[identifier] = *value;
wrapper->values[identifier] = duckdb::BoundParameterData(*value);
return DuckDBSuccess;
}

Expand Down
10 changes: 6 additions & 4 deletions src/main/client_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ static bool IsExplainAnalyze(SQLStatement *statement) {
shared_ptr<PreparedStatementData>
ClientContext::CreatePreparedStatementInternal(ClientContextLock &lock, const string &query,
unique_ptr<SQLStatement> statement,
optional_ptr<case_insensitive_map_t<Value>> values) {
optional_ptr<case_insensitive_map_t<BoundParameterData>> values) {
StatementType statement_type = statement->type;
auto result = make_shared_ptr<PreparedStatementData>(statement_type);

Expand Down Expand Up @@ -369,7 +369,8 @@ ClientContext::CreatePreparedStatementInternal(ClientContextLock &lock, const st

shared_ptr<PreparedStatementData>
ClientContext::CreatePreparedStatement(ClientContextLock &lock, const string &query, unique_ptr<SQLStatement> statement,
optional_ptr<case_insensitive_map_t<Value>> values, PreparedStatementMode mode) {
optional_ptr<case_insensitive_map_t<BoundParameterData>> values,
PreparedStatementMode mode) {
// check if any client context state could request a rebind
bool can_request_rebind = false;
for (auto const &s : registered_state) {
Expand Down Expand Up @@ -420,7 +421,7 @@ QueryProgress ClientContext::GetQueryProgress() {
}

void BindPreparedStatementParameters(PreparedStatementData &statement, const PendingQueryParameters &parameters) {
case_insensitive_map_t<Value> owned_values;
case_insensitive_map_t<BoundParameterData> owned_values;
if (parameters.parameters) {
auto &params = *parameters.parameters;
for (auto &val : params) {
Expand Down Expand Up @@ -712,7 +713,8 @@ unique_ptr<QueryResult> ClientContext::Execute(const string &query, shared_ptr<P
}

unique_ptr<QueryResult> ClientContext::Execute(const string &query, shared_ptr<PreparedStatementData> &prepared,
case_insensitive_map_t<Value> &values, bool allow_stream_result) {
case_insensitive_map_t<BoundParameterData> &values,
bool allow_stream_result) {
PendingQueryParameters parameters;
parameters.parameters = &values;
parameters.allow_stream_result = allow_stream_result;
Expand Down
8 changes: 4 additions & 4 deletions src/main/prepared_statement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ case_insensitive_map_t<LogicalType> PreparedStatement::GetExpectedParameterTypes
return expected_types;
}

unique_ptr<QueryResult> PreparedStatement::Execute(case_insensitive_map_t<Value> &named_values,
unique_ptr<QueryResult> PreparedStatement::Execute(case_insensitive_map_t<BoundParameterData> &named_values,
bool allow_stream_result) {
auto pending = PendingQuery(named_values, allow_stream_result);
if (pending->HasError()) {
Expand All @@ -86,15 +86,15 @@ unique_ptr<QueryResult> PreparedStatement::Execute(vector<Value> &values, bool a
}

unique_ptr<PendingQueryResult> PreparedStatement::PendingQuery(vector<Value> &values, bool allow_stream_result) {
case_insensitive_map_t<Value> named_values;
case_insensitive_map_t<BoundParameterData> named_values;
for (idx_t i = 0; i < values.size(); i++) {
auto &val = values[i];
named_values[std::to_string(i + 1)] = val;
named_values[std::to_string(i + 1)] = BoundParameterData(val);
}
return PendingQuery(named_values, allow_stream_result);
}

unique_ptr<PendingQueryResult> PreparedStatement::PendingQuery(case_insensitive_map_t<Value> &named_values,
unique_ptr<PendingQueryResult> PreparedStatement::PendingQuery(case_insensitive_map_t<BoundParameterData> &named_values,
bool allow_stream_result) {
if (!success) {
auto exception = InvalidInputException("Attempting to execute an unsuccessfully prepared statement!");
Expand Down
11 changes: 6 additions & 5 deletions src/main/prepared_statement_data.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ void StartTransactionInCatalog(ClientContext &context, const string &catalog_nam
Transaction::Get(context, *database);
}

bool PreparedStatementData::RequireRebind(ClientContext &context, optional_ptr<case_insensitive_map_t<Value>> values) {
bool PreparedStatementData::RequireRebind(ClientContext &context,
optional_ptr<case_insensitive_map_t<BoundParameterData>> values) {
idx_t count = values ? values->size() : 0;
CheckParameterCount(count);
if (!unbound_statement) {
Expand All @@ -49,7 +50,7 @@ bool PreparedStatementData::RequireRebind(ClientContext &context, optional_ptr<c
if (lookup == values->end()) {
break;
}
if (lookup->second.type() != it.second->return_type) {
if (lookup->second.GetValue().type() != it.second->return_type) {
return true;
}
}
Expand All @@ -68,7 +69,7 @@ bool PreparedStatementData::RequireRebind(ClientContext &context, optional_ptr<c
return false;
}

void PreparedStatementData::Bind(case_insensitive_map_t<Value> values) {
void PreparedStatementData::Bind(case_insensitive_map_t<BoundParameterData> values) {
// set parameters
D_ASSERT(!unbound_statement || unbound_statement->n_param == properties.parameter_count);
CheckParameterCount(values.size());
Expand All @@ -81,13 +82,13 @@ void PreparedStatementData::Bind(case_insensitive_map_t<Value> values) {
throw BinderException("Could not find parameter with identifier %s", identifier);
}
D_ASSERT(it.second);
auto &value = lookup->second;
auto value = lookup->second.GetValue();
if (!value.DefaultTryCastAs(it.second->return_type)) {
throw BinderException(
"Type mismatch for binding parameter with identifier %s, expected type %s but got type %s", identifier,
it.second->return_type.ToString().c_str(), value.type().ToString().c_str());
}
it.second->SetValue(value);
it.second->SetValue(std::move(value));
}
}

Expand Down
11 changes: 9 additions & 2 deletions src/planner/binder/expression/bind_parameter_expression.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "duckdb/parser/expression/parameter_expression.hpp"
#include "duckdb/planner/binder.hpp"
#include "duckdb/planner/expression/bound_cast_expression.hpp"
#include "duckdb/planner/expression/bound_constant_expression.hpp"
#include "duckdb/planner/expression/bound_parameter_expression.hpp"
#include "duckdb/planner/expression_binder.hpp"
Expand All @@ -19,10 +20,16 @@ BindResult ExpressionBinder::BindExpression(ParameterExpression &expr, idx_t dep
if (param_data_it != parameter_data.end()) {
// it has! emit a constant directly
auto &data = param_data_it->second;
auto return_type = binder.parameters->GetReturnType(parameter_id);
bool is_literal =
return_type.id() == LogicalTypeId::INTEGER_LITERAL || return_type.id() == LogicalTypeId::STRING_LITERAL;
auto constant = make_uniq<BoundConstantExpression>(data.GetValue());
constant->alias = expr.alias;
constant->return_type = binder.parameters->GetReturnType(parameter_id);
return BindResult(std::move(constant));
if (is_literal) {
return BindResult(std::move(constant));
}
auto cast = BoundCastExpression::AddCastToType(context, std::move(constant), return_type);
return BindResult(std::move(cast));
}

auto bound_parameter = binder.parameters->BindParameterExpression(expr);
Expand Down
30 changes: 23 additions & 7 deletions src/planner/binder/statement/bind_execute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "duckdb/main/client_data.hpp"
#include "duckdb/execution/expression_executor.hpp"
#include "duckdb/common/case_insensitive_map.hpp"
#include "duckdb/planner/expression/bound_constant_expression.hpp"

namespace duckdb {

Expand All @@ -30,24 +31,39 @@ BoundStatement Binder::Bind(ExecuteStatement &stmt) {

auto &mapped_named_values = stmt.named_values;
// bind any supplied parameters
case_insensitive_map_t<Value> bind_values;
case_insensitive_map_t<BoundParameterData> bind_values;
auto constant_binder = Binder::CreateBinder(context);
constant_binder->SetCanContainNulls(true);
for (auto &pair : mapped_named_values) {
bool is_literal = pair.second->type == ExpressionType::VALUE_CONSTANT;

ConstantBinder cbinder(*constant_binder, context, "EXECUTE statement");
auto bound_expr = cbinder.Bind(pair.second);

Value value = ExpressionExecutor::EvaluateScalar(context, *bound_expr, true);
bind_values[pair.first] = std::move(value);
BoundParameterData parameter_data;
if (is_literal) {
auto &constant = bound_expr->Cast<BoundConstantExpression>();
LogicalType return_type;
if (constant.return_type == LogicalTypeId::VARCHAR &&
StringType::GetCollation(constant.return_type).empty()) {
return_type = LogicalTypeId::STRING_LITERAL;
} else if (constant.return_type.IsIntegral()) {
return_type = LogicalType::INTEGER_LITERAL(constant.value);
} else {
return_type = constant.value.type();
}
parameter_data = BoundParameterData(std::move(constant.value), std::move(return_type));
} else {
auto value = ExpressionExecutor::EvaluateScalar(context, *bound_expr, true);
parameter_data = BoundParameterData(std::move(value));
}
bind_values[pair.first] = std::move(parameter_data);
}
unique_ptr<LogicalOperator> rebound_plan;

if (prepared->RequireRebind(context, &bind_values)) {
// catalog was modified or statement does not have clear types: rebind the statement before running the execute
Planner prepared_planner(context);
for (auto &pair : bind_values) {
prepared_planner.parameter_data.emplace(std::make_pair(pair.first, BoundParameterData(pair.second)));
}
prepared_planner.parameter_data = bind_values;
prepared = prepared_planner.PrepareSQLStatement(entry->second->unbound_statement->Copy());
rebound_plan = std::move(prepared_planner.plan);
D_ASSERT(prepared->properties.bound_all_parameters);
Expand Down
24 changes: 24 additions & 0 deletions test/api/capi/test_capi_prepared.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,30 @@ TEST_CASE("Test prepared statements with named parameters in C API", "[capi]") {
duckdb_destroy_prepare(&stmt);
}

TEST_CASE("Maintain prepared statement types", "[capi]") {
CAPITester tester;
duckdb::unique_ptr<CAPIResult> result;
duckdb_result res;
duckdb_prepared_statement stmt = nullptr;
duckdb_state status;

// open the database in in-memory mode
REQUIRE(tester.OpenDatabase(nullptr));

status = duckdb_prepare(tester.connection, "select cast(111 as short) * $1", &stmt);
REQUIRE(status == DuckDBSuccess);
REQUIRE(stmt != nullptr);

status = duckdb_bind_int64(stmt, 1, 1665);
REQUIRE(status == DuckDBSuccess);

status = duckdb_execute_prepared(stmt, &res);
REQUIRE(status == DuckDBSuccess);
REQUIRE(duckdb_value_int64(&res, 0, 0) == 184815);
duckdb_destroy_result(&res);
duckdb_destroy_prepare(&stmt);
}

TEST_CASE("Prepared streaming result", "[capi]") {
CAPITester tester;

Expand Down
Loading
Loading