From 81edfd1ec2a0e34c3bad9f341a22110b7817e175 Mon Sep 17 00:00:00 2001 From: Alex Qyoun-ae <4062971+MazterQyou@users.noreply.github.com> Date: Wed, 9 Oct 2024 17:09:21 +0400 Subject: [PATCH] chore(cubesql): Do not call async Node functions while planning --- .../cubesql/src/compile/query_engine.rs | 79 ++++++++++- .../cubesql/src/compile/rewrite/rewriter.rs | 8 +- rust/cubesql/cubesql/src/compile/router.rs | 48 +------ .../cubesql/cubesql/src/sql/compiler_cache.rs | 132 +++++++++++------- rust/cubesql/cubesql/src/sql/postgres/shim.rs | 21 ++- 5 files changed, 180 insertions(+), 108 deletions(-) diff --git a/rust/cubesql/cubesql/src/compile/query_engine.rs b/rust/cubesql/cubesql/src/compile/query_engine.rs index ac49e217b88fc..b4ff360d657c6 100644 --- a/rust/cubesql/cubesql/src/compile/query_engine.rs +++ b/rust/cubesql/cubesql/src/compile/query_engine.rs @@ -1,5 +1,8 @@ use crate::compile::engine::df::planner::CubeQueryPlanner; -use std::{backtrace::Backtrace, collections::HashMap, future::Future, pin::Pin, sync::Arc}; +use std::{ + backtrace::Backtrace, collections::HashMap, future::Future, pin::Pin, sync::Arc, + time::SystemTime, +}; use crate::{ compile::{ @@ -43,6 +46,7 @@ use datafusion::{ sql::{parser::Statement as DFStatement, planner::SqlToRel}, variable::VarType, }; +use uuid::Uuid; #[async_trait::async_trait] pub trait QueryEngine { @@ -74,6 +78,11 @@ pub trait QueryEngine { fn sanitize_statement(&self, stmt: &Self::AstStatementType) -> Self::AstStatementType; + async fn get_compiler_id_and_refresh_cache_if_needed( + &self, + state: Arc, + ) -> Result; + async fn plan( &self, stmt: Self::AstStatementType, @@ -82,6 +91,28 @@ pub trait QueryEngine { meta: Arc, state: Arc, ) -> CompilationResult { + let compiler_id = self + .get_compiler_id_and_refresh_cache_if_needed(state.clone()) + .await?; + + let planning_start = SystemTime::now(); + if let Some(span_id) = span_id.as_ref() { + if let Some(auth_context) = state.auth_context() { + self.transport_ref() + .log_load_state( + Some(span_id.clone()), + auth_context, + state.get_load_request_meta(), + "SQL API Query Planning".to_string(), + serde_json::json!({ + "query": span_id.query_key.clone(), + }), + ) + .await + .map_err(|e| CompilationError::internal(e.to_string()))?; + } + } + let ctx = self.create_session_ctx(state.clone())?; let cube_ctx = self.create_cube_ctx(state.clone(), meta.clone(), ctx.clone())?; @@ -140,7 +171,7 @@ pub trait QueryEngine { let mut finalized_graph = self .compiler_cache_ref() .rewrite( - state.auth_context().unwrap(), + compiler_id, cube_ctx.clone(), converter.take_egraph(), &query_params.unwrap(), @@ -186,7 +217,13 @@ pub trait QueryEngine { let mut rewriter = Rewriter::new(finalized_graph, cube_ctx.clone()); let result = rewriter - .find_best_plan(root, state.auth_context().unwrap(), qtrace, span_id.clone()) + .find_best_plan( + root, + compiler_id, + state.auth_context().unwrap(), + qtrace, + span_id.clone(), + ) .await .map_err(|e| match e.cause { CubeErrorCauseType::Internal(_) => CompilationError::Internal( @@ -233,12 +270,31 @@ pub trait QueryEngine { // TODO: We should find what optimizers will be safety to use for OLAP queries guard.optimizer.rules = vec![]; } - if let Some(span_id) = span_id { + if let Some(span_id) = &span_id { span_id.set_is_data_query(true).await; } }; log::debug!("Rewrite: {:#?}", rewrite_plan); + + if let Some(span_id) = span_id.as_ref() { + if let Some(auth_context) = state.auth_context() { + self.transport_ref() + .log_load_state( + Some(span_id.clone()), + auth_context, + state.get_load_request_meta(), + "SQL API Query Planning Success".to_string(), + serde_json::json!({ + "query": span_id.query_key.clone(), + "duration": planning_start.elapsed().unwrap().as_millis() as u64, + }), + ) + .await + .map_err(|e| CompilationError::internal(e.to_string()))?; + } + } + let rewrite_plan = Self::evaluate_wrapped_sql( self.transport_ref().clone(), Arc::new(state.get_load_request_meta()), @@ -493,6 +549,21 @@ impl QueryEngine for SqlQueryEngine { fn sanitize_statement(&self, stmt: &Self::AstStatementType) -> Self::AstStatementType { SensitiveDataSanitizer::new().replace(stmt.clone()) } + + async fn get_compiler_id_and_refresh_cache_if_needed( + &self, + state: Arc, + ) -> Result { + self.compiler_cache_ref() + .get_compiler_id_and_refresh_if_needed( + state.auth_context().ok_or_else(|| { + CompilationError::internal("Unable to get auth context".to_string()) + })?, + state.protocol.clone(), + ) + .await + .map_err(|e| CompilationError::internal(e.to_string())) + } } fn is_olap_query(parent: &LogicalPlan) -> Result { diff --git a/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs b/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs index 4841e5b3ffc1d..90f993d16cf51 100644 --- a/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs +++ b/rust/cubesql/cubesql/src/compile/rewrite/rewriter.rs @@ -31,6 +31,7 @@ use std::{ sync::Arc, time::Duration, }; +use uuid::Uuid; pub struct Rewriter { graph: EGraph, @@ -229,7 +230,7 @@ impl Rewriter { pub async fn run_rewrite_to_completion( &mut self, - auth_context: AuthContextRef, + compiler_id: Uuid, qtrace: &mut Option, ) -> Result, CubeError> { let cube_context = self.cube_context.clone(); @@ -243,7 +244,7 @@ impl Rewriter { .server .compiler_cache .rewrite_rules( - auth_context.clone(), + compiler_id, cube_context.session_state.protocol.clone(), false, ) @@ -311,6 +312,7 @@ impl Rewriter { pub async fn find_best_plan( &mut self, root: Id, + compiler_id: Uuid, auth_context: AuthContextRef, qtrace: &mut Option, span_id: Option>, @@ -326,7 +328,7 @@ impl Rewriter { .server .compiler_cache .rewrite_rules( - auth_context.clone(), + compiler_id, cube_context.session_state.protocol.clone(), true, ) diff --git a/rust/cubesql/cubesql/src/compile/router.rs b/rust/cubesql/cubesql/src/compile/router.rs index 53b3bb0a50a6c..af8f5449e9934 100644 --- a/rust/cubesql/cubesql/src/compile/router.rs +++ b/rust/cubesql/cubesql/src/compile/router.rs @@ -3,7 +3,7 @@ use crate::compile::{ StatusFlags, }; use sqlparser::ast; -use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc, time::SystemTime}; +use std::{collections::HashMap, future::Future, pin::Pin, sync::Arc}; use crate::{ compile::{ @@ -61,50 +61,8 @@ impl QueryRouter { qtrace: &mut Option, span_id: Option>, ) -> CompilationResult { - let planning_start = SystemTime::now(); - if let Some(span_id) = span_id.as_ref() { - if let Some(auth_context) = self.state.auth_context() { - self.session_manager - .server - .transport - .log_load_state( - Some(span_id.clone()), - auth_context, - self.state.get_load_request_meta(), - "SQL API Query Planning".to_string(), - serde_json::json!({ - "query": span_id.query_key.clone(), - }), - ) - .await - .map_err(|e| CompilationError::internal(e.to_string()))?; - } - } - let result = self - .create_df_logical_plan(stmt.clone(), qtrace, span_id.clone()) - .await?; - - if let Some(span_id) = span_id.as_ref() { - if let Some(auth_context) = self.state.auth_context() { - self.session_manager - .server - .transport - .log_load_state( - Some(span_id.clone()), - auth_context, - self.state.get_load_request_meta(), - "SQL API Query Planning Success".to_string(), - serde_json::json!({ - "query": span_id.query_key.clone(), - "duration": planning_start.elapsed().unwrap().as_millis() as u64, - }), - ) - .await - .map_err(|e| CompilationError::internal(e.to_string()))?; - } - } - - return Ok(result); + self.create_df_logical_plan(stmt.clone(), qtrace, span_id.clone()) + .await } pub async fn plan( diff --git a/rust/cubesql/cubesql/src/sql/compiler_cache.rs b/rust/cubesql/cubesql/src/sql/compiler_cache.rs index 220d62cd713ad..1c18b617810ed 100644 --- a/rust/cubesql/cubesql/src/sql/compiler_cache.rs +++ b/rust/cubesql/cubesql/src/sql/compiler_cache.rs @@ -22,20 +22,20 @@ use uuid::Uuid; pub trait CompilerCache: Send + Sync + Debug { async fn rewrite_rules( &self, - ctx: AuthContextRef, + compiler_id: Uuid, protocol: DatabaseProtocol, eval_stable_functions: bool, ) -> Result>>, CubeError>; async fn meta( &self, - ctx: AuthContextRef, + compiler_id: Uuid, protocol: DatabaseProtocol, ) -> Result, CubeError>; async fn parameterized_rewrite( &self, - ctx: AuthContextRef, + compiler_id: Uuid, cube_context: Arc, input_plan: EGraph, qtrace: &mut Option, @@ -43,12 +43,18 @@ pub trait CompilerCache: Send + Sync + Debug { async fn rewrite( &self, - ctx: AuthContextRef, + compiler_id: Uuid, cube_context: Arc, input_plan: EGraph, param_values: &HashMap, qtrace: &mut Option, ) -> Result, CubeError>; + + async fn get_compiler_id_and_refresh_if_needed( + &self, + ctx: AuthContextRef, + protocol: DatabaseProtocol, + ) -> Result; } #[derive(Debug)] @@ -73,11 +79,11 @@ crate::di_service!(CompilerCacheImpl, [CompilerCache]); impl CompilerCache for CompilerCacheImpl { async fn rewrite_rules( &self, - ctx: AuthContextRef, + compiler_id: Uuid, protocol: DatabaseProtocol, eval_stable_functions: bool, ) -> Result>>, CubeError> { - let cache_entry = self.get_cache_entry(ctx.clone(), protocol).await?; + let cache_entry = self.get_cache_entry(compiler_id, protocol).await?; let rewrite_rules = { cache_entry @@ -108,22 +114,22 @@ impl CompilerCache for CompilerCacheImpl { async fn meta( &self, - ctx: AuthContextRef, + compiler_id: Uuid, protocol: DatabaseProtocol, ) -> Result, CubeError> { - let cache_entry = self.get_cache_entry(ctx.clone(), protocol).await?; + let cache_entry = self.get_cache_entry(compiler_id, protocol).await?; Ok(cache_entry.meta_context.clone()) } async fn parameterized_rewrite( &self, - ctx: AuthContextRef, + compiler_id: Uuid, cube_context: Arc, parameterized_graph: EGraph, qtrace: &mut Option, ) -> Result, CubeError> { let cache_entry = self - .get_cache_entry(ctx.clone(), cube_context.session_state.protocol.clone()) + .get_cache_entry(compiler_id, cube_context.session_state.protocol.clone()) .await?; let graph_key = egraph_hash(¶meterized_graph, None); @@ -134,7 +140,7 @@ impl CompilerCache for CompilerCacheImpl { } else { let mut rewriter = Rewriter::new(parameterized_graph, cube_context); let rewrite_entry = rewriter - .run_rewrite_to_completion(ctx.clone(), qtrace) + .run_rewrite_to_completion(compiler_id, qtrace) .await?; rewrites_cache_lock.put(graph_key, rewrite_entry.clone()); Ok(rewrite_entry) @@ -143,7 +149,7 @@ impl CompilerCache for CompilerCacheImpl { async fn rewrite( &self, - ctx: AuthContextRef, + compiler_id: Uuid, cube_context: Arc, input_plan: EGraph, param_values: &HashMap, @@ -152,10 +158,12 @@ impl CompilerCache for CompilerCacheImpl { if !self.config_obj.enable_rewrite_cache() { let mut rewriter = Rewriter::new(input_plan, cube_context); rewriter.add_param_values(param_values)?; - return Ok(rewriter.run_rewrite_to_completion(ctx, qtrace).await?); + return Ok(rewriter + .run_rewrite_to_completion(compiler_id, qtrace) + .await?); } let cache_entry = self - .get_cache_entry(ctx.clone(), cube_context.session_state.protocol.clone()) + .get_cache_entry(compiler_id, cube_context.session_state.protocol.clone()) .await?; let graph_key = egraph_hash(&input_plan, Some(param_values)); @@ -165,18 +173,61 @@ impl CompilerCache for CompilerCacheImpl { Ok(plan.clone()) } else { let graph = if self.config_obj.enable_parameterized_rewrite_cache() { - self.parameterized_rewrite(ctx.clone(), cube_context.clone(), input_plan, qtrace) + self.parameterized_rewrite(compiler_id, cube_context.clone(), input_plan, qtrace) .await? } else { input_plan }; let mut rewriter = Rewriter::new(graph, cube_context); rewriter.add_param_values(param_values)?; - let final_plan = rewriter.run_rewrite_to_completion(ctx, qtrace).await?; + let final_plan = rewriter + .run_rewrite_to_completion(compiler_id, qtrace) + .await?; rewrites_cache_lock.put(graph_key, final_plan.clone()); Ok(final_plan) } } + + async fn get_compiler_id_and_refresh_if_needed( + &self, + ctx: AuthContextRef, + protocol: DatabaseProtocol, + ) -> Result { + let compiler_id = self.transport.compiler_id(ctx.clone()).await?; + let has_entry = { + self.compiler_id_to_entry + .lock() + .await + .contains(&(compiler_id, protocol.clone())) + }; + if has_entry { + return Ok(compiler_id); + } + + let meta_context = self.transport.meta(ctx).await?; + let compiler_id = { + let mut compiler_id_to_entry = self.compiler_id_to_entry.lock().await; + if !compiler_id_to_entry.contains(&(meta_context.compiler_id, protocol.clone())) { + let cache_entry = Arc::new(CompilerCacheEntry { + meta_context: meta_context.clone(), + rewrite_rules: RWLockAsync::new(HashMap::new()), + parameterized_cache: MutexAsync::new(LruCache::new( + NonZeroUsize::new(self.config_obj.query_cache_size()).unwrap(), + )), + queries_cache: MutexAsync::new(LruCache::new( + NonZeroUsize::new(self.config_obj.query_cache_size()).unwrap(), + )), + }); + compiler_id_to_entry.put( + (meta_context.compiler_id.clone(), protocol.clone()), + cache_entry.clone(), + ); + } + meta_context.compiler_id + }; + + Ok(compiler_id) + } } impl CompilerCacheImpl { @@ -193,44 +244,19 @@ impl CompilerCacheImpl { pub async fn get_cache_entry( &self, - ctx: AuthContextRef, + compiler_id: Uuid, protocol: DatabaseProtocol, ) -> Result, CubeError> { - let compiler_id = self.transport.compiler_id(ctx.clone()).await?; - let cache_entry = { - self.compiler_id_to_entry - .lock() - .await - .get(&(compiler_id, protocol.clone())) - .cloned() - }; - // Double checked locking - let cache_entry = if let Some(cache_entry) = cache_entry { - cache_entry - } else { - let meta_context = self.transport.meta(ctx.clone()).await?; - let mut compiler_id_to_entry = self.compiler_id_to_entry.lock().await; - compiler_id_to_entry - .get(&(meta_context.compiler_id, protocol.clone())) - .cloned() - .unwrap_or_else(|| { - let cache_entry = Arc::new(CompilerCacheEntry { - meta_context: meta_context.clone(), - rewrite_rules: RWLockAsync::new(HashMap::new()), - parameterized_cache: MutexAsync::new(LruCache::new( - NonZeroUsize::new(self.config_obj.query_cache_size()).unwrap(), - )), - queries_cache: MutexAsync::new(LruCache::new( - NonZeroUsize::new(self.config_obj.query_cache_size()).unwrap(), - )), - }); - compiler_id_to_entry.put( - (meta_context.compiler_id.clone(), protocol.clone()), - cache_entry.clone(), - ); - cache_entry - }) - }; - Ok(cache_entry) + self.compiler_id_to_entry + .lock() + .await + .get(&(compiler_id, protocol.clone())) + .cloned() + .ok_or_else(|| { + CubeError::internal(format!( + "Compiler cache entry for {:?} not found", + compiler_id, + )) + }) } } diff --git a/rust/cubesql/cubesql/src/sql/postgres/shim.rs b/rust/cubesql/cubesql/src/sql/postgres/shim.rs index 8de9ac405c5ff..734c9464b4c5a 100644 --- a/rust/cubesql/cubesql/src/sql/postgres/shim.rs +++ b/rust/cubesql/cubesql/src/sql/postgres/shim.rs @@ -240,6 +240,18 @@ impl AsyncPostgresShim { return Ok(()); } + async fn get_compiler_id_and_refresh_cache_if_needed(&self) -> Result { + self.session + .session_manager + .server + .compiler_cache + .get_compiler_id_and_refresh_if_needed( + self.auth_context()?, + self.session.state.protocol.clone(), + ) + .await + } + pub async fn run_on( fast_shutdown_interruptor: CancellationToken, semifast_shutdown_interruptor: CancellationToken, @@ -1021,11 +1033,12 @@ impl AsyncPostgresShim { source_statement.bind(body.to_bind_values(¶meters)?)?; drop(statements_guard); + let compiler_id = self.get_compiler_id_and_refresh_cache_if_needed().await?; let meta = self .session .server .compiler_cache - .meta(self.auth_context()?, self.session.state.protocol.clone()) + .meta(compiler_id, self.session.state.protocol.clone()) .await?; let plan = convert_statement_to_cube_query( @@ -1113,11 +1126,12 @@ impl AsyncPostgresShim { .map(|param| param.coltype.to_pg_tid()) .collect(); + let compiler_id = self.get_compiler_id_and_refresh_cache_if_needed().await?; let meta = self .session .server .compiler_cache - .meta(self.auth_context()?, self.session.state.protocol.clone()) + .meta(compiler_id, self.session.state.protocol.clone()) .await?; let stmt_replacer = StatementPlaceholderReplacer::new(); @@ -1709,11 +1723,12 @@ impl AsyncPostgresShim { qtrace: &mut Option, span_id: Option>, ) -> Result<(), ConnectionError> { + let compiler_id = self.get_compiler_id_and_refresh_cache_if_needed().await?; let meta = self .session .server .compiler_cache - .meta(self.auth_context()?, self.session.state.protocol.clone()) + .meta(compiler_id, self.session.state.protocol.clone()) .await?; let statements =