From 15540b0a3171fe4618b3b757d703796a7837693b Mon Sep 17 00:00:00 2001 From: joshua-spacetime Date: Tue, 9 Jun 2026 12:59:52 -0700 Subject: [PATCH 1/2] Parameterized query plans --- crates/core/src/host/module_host.rs | 7 +- crates/core/src/subscription/delta.rs | 10 +- crates/core/src/subscription/metrics.rs | 2 +- crates/core/src/subscription/mod.rs | 21 +- .../subscription/module_subscription_actor.rs | 15 +- .../module_subscription_manager.rs | 103 +++-- crates/core/src/subscription/query.rs | 17 +- crates/core/src/subscription/subscription.rs | 6 +- crates/execution/src/pipelined.rs | 55 ++- crates/expr/src/check.rs | 3 +- crates/expr/src/expr.rs | 72 +++- crates/expr/src/lib.rs | 20 +- crates/expr/src/rls.rs | 8 +- crates/expr/src/statement.rs | 4 +- crates/physical-plan/src/compile.rs | 1 + crates/physical-plan/src/dml.rs | 23 +- crates/physical-plan/src/plan.rs | 377 ++++++++++++++++-- crates/physical-plan/src/rules.rs | 36 +- crates/query/src/lib.rs | 8 +- crates/subscription/src/lib.rs | 104 +++-- 20 files changed, 712 insertions(+), 180 deletions(-) diff --git a/crates/core/src/host/module_host.rs b/crates/core/src/host/module_host.rs index 9ce704426f7..3be9b535ad3 100644 --- a/crates/core/src/host/module_host.rs +++ b/crates/core/src/host/module_host.rs @@ -53,7 +53,7 @@ use spacetimedb_datastore::traits::{IsolationLevel, Program, TxData}; pub use spacetimedb_durability::{DurabilityExited, DurableOffset}; use spacetimedb_execution::pipelined::{PipelinedProject, ViewProject}; use spacetimedb_execution::RelValue; -use spacetimedb_expr::expr::CollectViews; +use spacetimedb_expr::expr::{BindEnv, CollectViews}; use spacetimedb_lib::db::raw_def::v9::Lifecycle; use spacetimedb_lib::http::{Request as HttpRequest, Response as HttpResponse}; use spacetimedb_lib::identity::{AuthCtx, RequestId}; @@ -3272,13 +3272,14 @@ impl ModuleHost { plans, _, table_name, - _, + requires_sender_binding, ) = compile_subscription(query, &schema_tx, auth)?; + let bind_env = BindEnv::for_sender_binding(requires_sender_binding, auth.caller()); // Optimize each fragment. let optimized = plans .into_iter() - .map(|plan| plan.optimize(auth)) + .map(|plan| plan.optimize(auth).map(|plan| plan.bind_params(&bind_env))) .collect::, _>>()?; check_row_limit( diff --git a/crates/core/src/subscription/delta.rs b/crates/core/src/subscription/delta.rs index b2747a933e2..fd7c91088fa 100644 --- a/crates/core/src/subscription/delta.rs +++ b/crates/core/src/subscription/delta.rs @@ -2,6 +2,7 @@ use crate::host::module_host::UpdatesRelValue; use anyhow::Result; use spacetimedb_data_structures::map::{HashCollectionExt as _, HashMap}; use spacetimedb_execution::{Datastore, DeltaStore, RelValue, Row}; +use spacetimedb_expr::expr::BindEnv; use spacetimedb_lib::metrics::ExecutionMetrics; use spacetimedb_primitives::ColList; use spacetimedb_sats::product_value::InvalidFieldError; @@ -20,6 +21,7 @@ pub fn eval_delta<'a, Tx: Datastore + DeltaStore>( tx: &'a Tx, metrics: &mut ExecutionMetrics, plan: &SubscriptionPlan, + bind_env: &BindEnv, ) -> Result>> { metrics.delta_queries_evaluated += 1; @@ -42,12 +44,12 @@ pub fn eval_delta<'a, Tx: Datastore + DeltaStore>( if !plan.is_join() { // Single table plans will never return redundant rows, // so there's no need to track row counts. - plan.for_each_insert(tx, metrics, &mut |row| { + plan.for_each_insert(bind_env, tx, metrics, &mut |row| { inserts.push(maybe_project(row)?); Ok(()) })?; - plan.for_each_delete(tx, metrics, &mut |row| { + plan.for_each_delete(bind_env, tx, metrics, &mut |row| { deletes.push(maybe_project(row)?); Ok(()) })?; @@ -57,7 +59,7 @@ pub fn eval_delta<'a, Tx: Datastore + DeltaStore>( let mut insert_counts = HashMap::new(); let mut delete_counts = HashMap::new(); - plan.for_each_insert(tx, metrics, &mut |row| { + plan.for_each_insert(bind_env, tx, metrics, &mut |row| { let row = maybe_project(row)?; let n = insert_counts.entry(row).or_default(); if *n > 0 { @@ -67,7 +69,7 @@ pub fn eval_delta<'a, Tx: Datastore + DeltaStore>( Ok(()) })?; - plan.for_each_delete(tx, metrics, &mut |row| { + plan.for_each_delete(bind_env, tx, metrics, &mut |row| { let row = maybe_project(row)?; match insert_counts.get_mut(&row) { // We have not seen an insert for this row. diff --git a/crates/core/src/subscription/metrics.rs b/crates/core/src/subscription/metrics.rs index b367a57056b..1247236ad5a 100644 --- a/crates/core/src/subscription/metrics.rs +++ b/crates/core/src/subscription/metrics.rs @@ -64,7 +64,7 @@ fn extract_columns( extract_columns(expr, schema, columns); } } - PhysicalExpr::Value(_) => {} + PhysicalExpr::Value(_) | PhysicalExpr::Param(..) => {} } } diff --git a/crates/core/src/subscription/mod.rs b/crates/core/src/subscription/mod.rs index 40d6d712504..1614e119371 100644 --- a/crates/core/src/subscription/mod.rs +++ b/crates/core/src/subscription/mod.rs @@ -256,14 +256,27 @@ pub fn execute_plans( ) -> Result<(ws_v1::DatabaseUpdate, ExecutionMetrics, Vec), DBError> { plans .par_iter() - .flat_map_iter(|plan| plan.plans_fragments().map(|fragment| (plan.sql(), fragment))) - .filter(|(_, plan)| { + .flat_map_iter(|plan| { + plan.plans_fragments() + .map(|fragment| (plan.sql(), plan.bind_env(), fragment)) + }) + .filter(|(_, _, plan)| { // Since subscriptions only support selects and inner joins, // we filter out any plans that read from an empty table. plan.table_ids().all(|table_id| tx.row_count(table_id) > 0) }) - .map(|(sql, plan)| (sql, plan, plan.subscribed_table_id(), plan.subscribed_table_name())) - .map(|(sql, plan, table_id, table_name)| (sql, plan.optimized_physical_plan().clone(), table_id, table_name)) + .map(|(sql, bind_env, plan)| { + ( + sql, + bind_env, + plan, + plan.subscribed_table_id(), + plan.subscribed_table_name(), + ) + }) + .map(|(sql, bind_env, plan, table_id, table_name)| { + (sql, plan.bound_optimized_physical_plan(bind_env), table_id, table_name) + }) .map(|(sql, plan, table_id, table_name)| (sql, plan.optimize(auth), table_id, table_name)) .map(|(sql, plan, table_id, table_name)| { plan.and_then(|plan| { diff --git a/crates/core/src/subscription/module_subscription_actor.rs b/crates/core/src/subscription/module_subscription_actor.rs index aede13262df..aa4392c088e 100644 --- a/crates/core/src/subscription/module_subscription_actor.rs +++ b/crates/core/src/subscription/module_subscription_actor.rs @@ -427,7 +427,9 @@ impl ModuleSubscriptions { tx, |plan, tx| { plan.plans_fragments() - .map(|plan_fragment| estimate_rows_scanned(tx, plan_fragment.optimized_physical_plan())) + .map(|plan_fragment| { + estimate_rows_scanned(tx, &plan_fragment.bound_optimized_physical_plan(plan.bind_env())) + }) .fold(0, |acc, rows_scanned| acc.saturating_add(rows_scanned)) }, auth, @@ -438,8 +440,7 @@ impl ModuleSubscriptions { let plans = query .plans_fragments() - .map(|fragment| fragment.optimized_physical_plan()) - .cloned() + .map(|fragment| fragment.bound_optimized_physical_plan(query.bind_env())) .map(|plan| plan.optimize(auth)) .collect::, _>>()?; @@ -533,7 +534,9 @@ impl ModuleSubscriptions { tx, |plan, tx| { plan.plans_fragments() - .map(|plan_fragment| estimate_rows_scanned(tx, plan_fragment.optimized_physical_plan())) + .map(|plan_fragment| { + estimate_rows_scanned(tx, &plan_fragment.bound_optimized_physical_plan(plan.bind_env())) + }) .fold(0, |acc, rows_scanned| acc.saturating_add(rows_scanned)) }, auth, @@ -1538,7 +1541,9 @@ impl ModuleSubscriptions { &tx, |plan, tx| { plan.plans_fragments() - .map(|plan_fragment| estimate_rows_scanned(tx, plan_fragment.optimized_physical_plan())) + .map(|plan_fragment| { + estimate_rows_scanned(tx, &plan_fragment.bound_optimized_physical_plan(plan.bind_env())) + }) .fold(0, |acc, rows_scanned| acc.saturating_add(rows_scanned)) }, &auth, diff --git a/crates/core/src/subscription/module_subscription_manager.rs b/crates/core/src/subscription/module_subscription_manager.rs index 0700846fc74..07907e71862 100644 --- a/crates/core/src/subscription/module_subscription_manager.rs +++ b/crates/core/src/subscription/module_subscription_manager.rs @@ -27,7 +27,7 @@ use spacetimedb_data_structures::map::{ }; use spacetimedb_datastore::locking_tx_datastore::state_view::StateView; use spacetimedb_durability::TxOffset; -use spacetimedb_expr::expr::CollectViews; +use spacetimedb_expr::expr::{BindEnv, CollectViews}; use spacetimedb_lib::metrics::ExecutionMetrics; use spacetimedb_lib::{AlgebraicValue, ConnectionId, Identity, ProductValue}; use spacetimedb_primitives::{ColId, IndexId, TableId, ViewId}; @@ -45,7 +45,7 @@ use tokio::sync::{mpsc, oneshot}; /// Identity is insufficient because different ConnectionIds can use the same Identity. /// TODO: Determine if ConnectionId is sufficient for uniquely identifying a client. type ClientId = (Identity, ConnectionId); -type Query = Arc; +type Query = Arc; type Client = Arc; type SwitchedTableUpdate = ws_v1::FormatSwitch, ws_v1::TableUpdate>; @@ -60,14 +60,18 @@ type ClientQueryId = ws_v1::QueryId; type SubscriptionId = (ClientId, ClientQueryId); type SubscriptionIdV2 = (ClientId, ClientQuerySetId); +/// The unbound, reusable subscription query plan. +/// +/// Runtime values such as `:sender` are represented as formal parameters inside these plans. +/// TODO: Intern and share `PlanTemplate`s across bound [`PlanInstance`]s with the same template hash. +/// Today, [`PlanInstance::new`] still allocates a fresh template for each cached subscription instance. #[derive(Debug)] -pub struct Plan { - hash: QueryHash, +pub struct PlanTemplate { sql: String, plans: Vec, } -impl CollectViews for Plan { +impl CollectViews for PlanTemplate { fn collect_views(&self, views: &mut HashSet) { for plan in &self.plans { plan.collect_views(views); @@ -75,10 +79,42 @@ impl CollectViews for Plan { } } -impl Plan { - /// Create a new subscription plan to be cached - pub fn new(plans: Vec, hash: QueryHash, text: String) -> Self { - Self { plans, hash, sql: text } +impl PlanTemplate { + pub fn new(plans: Vec, text: String) -> Self { + Self { sql: text, plans } + } +} + +/// A concrete subscription query instance with runtime parameter bindings. +#[derive(Debug)] +pub struct PlanInstance { + hash: QueryHash, + template: Arc, + bind_env: BindEnv, +} + +/// A subscription plan tracked by the subscription manager. +pub type Plan = PlanInstance; + +impl CollectViews for PlanInstance { + fn collect_views(&self, views: &mut HashSet) { + self.template.collect_views(views); + } +} + +impl PlanInstance { + /// Create a new subscription plan instance to be cached. + pub fn new(plans: Vec, hash: QueryHash, text: String, bind_env: BindEnv) -> Self { + Self::from_template(Arc::new(PlanTemplate::new(plans, text)), hash, bind_env) + } + + /// Create a new subscription plan instance from an existing plan template. + pub fn from_template(template: Arc, hash: QueryHash, bind_env: BindEnv) -> Self { + Self { + hash, + template, + bind_env, + } } /// Returns the query hash for this subscription @@ -89,18 +125,19 @@ impl Plan { /// A subscription query return rows from a single table. /// This method returns the id of that table. pub fn subscribed_table_id(&self) -> TableId { - self.plans[0].subscribed_table_id() + self.template.plans[0].subscribed_table_id() } /// A subscription query return rows from a single table. /// This method returns the name of that table. pub fn subscribed_table_name(&self) -> &TableName { - self.plans[0].subscribed_table_name() + self.template.plans[0].subscribed_table_name() } /// Returns the index ids from which this subscription reads pub fn index_ids(&self) -> impl Iterator + use<> { - self.plans + self.template + .plans .iter() .flat_map(|plan| plan.index_ids()) .collect::>() @@ -109,7 +146,8 @@ impl Plan { /// Returns the table ids from which this subscription reads pub fn table_ids(&self) -> impl Iterator + '_ { - self.plans + self.template + .plans .iter() .flat_map(|plan| plan.table_ids()) .collect::>() @@ -120,9 +158,10 @@ impl Plan { fn search_args(&self) -> impl Iterator + use<> { let mut args = HashSet::new(); for arg in self + .template .plans .iter() - .flat_map(|subscription| subscription.optimized_physical_plan().search_args()) + .flat_map(|subscription| subscription.bound_optimized_physical_plan(&self.bind_env).search_args()) { args.insert(arg); } @@ -132,22 +171,30 @@ impl Plan { /// Returns the plan fragments that comprise this subscription. /// Will only return one element unless there is a table with multiple RLS rules. pub fn plans_fragments(&self) -> impl Iterator + '_ { - self.plans.iter() + self.template.plans.iter() } /// Returns the join edges for this plan, if any. pub fn join_edges(&self) -> impl Iterator + '_ { - self.plans.iter().filter_map(|plan| plan.join_edge()) + self.template + .plans + .iter() + .filter_map(|plan| plan.bound_join_edge(&self.bind_env)) } /// The `SQL` text of this subscription. pub fn sql(&self) -> &str { - &self.sql + &self.template.sql + } + + /// Runtime parameter bindings for this subscription instance. + pub fn bind_env(&self) -> &BindEnv { + &self.bind_env } /// Does this plan return rows from an event table? pub fn returns_event_table(&self) -> bool { - self.plans.iter().any(|p| p.returns_event_table()) + self.template.plans.iter().any(|p| p.returns_event_table()) } } @@ -1480,15 +1527,15 @@ impl SubscriptionManager { }) .fold(FoldState::default(), |mut acc, (qstate, plan, _hash)| { let table_name = plan.subscribed_table_name().clone(); - match eval_delta(tx, &mut acc.metrics, plan) { + match eval_delta(tx, &mut acc.metrics, plan, qstate.query.bind_env()) { Err(err) => { tracing::error!( message = "Query errored during tx update", - sql = qstate.query.sql, + sql = qstate.query.sql(), reason = ?err, ); let err = DBError::WithSql { - sql: qstate.query.sql.as_str().into(), + sql: qstate.query.sql().into(), error: Box::new(err.into()), } .to_string() @@ -1637,15 +1684,15 @@ impl SubscriptionManager { let clients_for_query = qstate.all_v1_clients(); - match eval_delta(tx, &mut acc.metrics, plan) { + match eval_delta(tx, &mut acc.metrics, plan, qstate.query.bind_env()) { Err(err) => { tracing::error!( message = "Query errored during tx update", - sql = qstate.query.sql, + sql = qstate.query.sql(), reason = ?err, ); let err = DBError::WithSql { - sql: qstate.query.sql.as_str().into(), + sql: qstate.query.sql().into(), error: Box::new(err.into()), } .to_string() @@ -2191,6 +2238,7 @@ mod tests { use std::{sync::Arc, time::Duration}; use spacetimedb_client_api_messages::websocket::{v1 as ws_v1, v2 as ws_v2}; + use spacetimedb_expr::expr::BindEnv; use spacetimedb_lib::AlgebraicValue; use spacetimedb_lib::{error::ResultTest, identity::AuthCtx, AlgebraicType, ConnectionId, Identity, Timestamp}; use spacetimedb_primitives::{ColId, TableId}; @@ -2231,9 +2279,10 @@ mod tests { fn compile_plan_with_auth(db: &RelationalDB, sql: &str, auth: AuthCtx) -> ResultTest> { with_read_only(db, |tx| { let tx = SchemaViewer::new(&*tx, &auth); - let (plans, has_param) = SubscriptionPlan::compile(sql, &tx, &auth).unwrap(); - let hash = QueryHash::from_string(sql, auth.caller(), has_param); - Ok(Arc::new(Plan::new(plans, hash, sql.into()))) + let (plans, requires_sender_binding) = SubscriptionPlan::compile(sql, &tx, &auth).unwrap(); + let hash = QueryHash::from_string(sql, auth.caller(), requires_sender_binding); + let bind_env = BindEnv::for_sender_binding(requires_sender_binding, auth.caller()); + Ok(Arc::new(Plan::new(plans, hash, sql.into(), bind_env))) }) } diff --git a/crates/core/src/subscription/query.rs b/crates/core/src/subscription/query.rs index bea30f96b7f..ba4c6bd243f 100644 --- a/crates/core/src/subscription/query.rs +++ b/crates/core/src/subscription/query.rs @@ -7,6 +7,7 @@ use once_cell::sync::Lazy; use regex::Regex; use spacetimedb_datastore::locking_tx_datastore::state_view::StateView; use spacetimedb_execution::Datastore; +use spacetimedb_expr::expr::BindEnv; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_subscription::SubscriptionPlan; @@ -31,9 +32,10 @@ pub fn compile_read_only_query(auth: &AuthCtx, tx: &Tx, input: &str) -> Result

( } let tx = SchemaViewer::new(tx, auth); - let (plans, has_param) = SubscriptionPlan::compile(input, &tx, auth)?; + let (plans, requires_sender_binding) = SubscriptionPlan::compile(input, &tx, auth)?; + let bind_env = BindEnv::for_sender_binding(requires_sender_binding, auth.caller()); - if auth.bypass_rls() || has_param { - return Ok(Plan::new(plans, hash_with_param, input.to_owned())); + if auth.bypass_rls() || requires_sender_binding { + return Ok(Plan::new(plans, hash_with_param, input.to_owned(), bind_env)); } - Ok(Plan::new(plans, hash, input.to_owned())) + Ok(Plan::new(plans, hash, input.to_owned(), bind_env)) } diff --git a/crates/core/src/subscription/subscription.rs b/crates/core/src/subscription/subscription.rs index a9f3bd12f62..fa59d1f89f2 100644 --- a/crates/core/src/subscription/subscription.rs +++ b/crates/core/src/subscription/subscription.rs @@ -4,6 +4,7 @@ use crate::db::relational_db::RelationalDB; use crate::error::DBError; use crate::sql::ast::SchemaViewer; use spacetimedb_datastore::locking_tx_datastore::state_view::StateView; +use spacetimedb_expr::expr::BindEnv; use spacetimedb_lib::db::auth::StTableType; use spacetimedb_lib::identity::AuthCtx; use spacetimedb_schema::schema::TableSchema; @@ -27,11 +28,12 @@ where .map(|schema| { let sql = format!("SELECT * FROM {}", schema.table_name); let tx = SchemaViewer::new(tx, auth); - SubscriptionPlan::compile(&sql, &tx, auth).map(|(plans, has_param)| { + SubscriptionPlan::compile(&sql, &tx, auth).map(|(plans, requires_sender_binding)| { Plan::new( plans, - QueryHash::from_string(&sql, auth.caller(), auth.bypass_rls() || has_param), + QueryHash::from_string(&sql, auth.caller(), auth.bypass_rls() || requires_sender_binding), sql, + BindEnv::for_sender_binding(requires_sender_binding, auth.caller()), ) }) }) diff --git a/crates/execution/src/pipelined.rs b/crates/execution/src/pipelined.rs index c9462fc8cad..2abd30c94a5 100644 --- a/crates/execution/src/pipelined.rs +++ b/crates/execution/src/pipelined.rs @@ -8,14 +8,23 @@ use itertools::Either; use spacetimedb_expr::expr::AggType; use spacetimedb_lib::{metrics::ExecutionMetrics, query::Delta, sats::size_of::SizeOf, AlgebraicValue, ProductValue}; use spacetimedb_physical_plan::plan::{ - HashJoin, IxJoin, IxScan, PhysicalExpr, PhysicalPlan, ProjectField, ProjectListPlan, ProjectPlan, Sarg, Semi, - TableScan, TupleField, + HashJoin, IxJoin, IxScan, PhysicalExpr, PhysicalPlan, ProjectField, ProjectListPlan, ProjectPlan, Sarg, SargValue, + Semi, TableScan, TupleField, }; use spacetimedb_primitives::{ColId, ColList, IndexId, TableId}; use spacetimedb_sats::product; use crate::{Datastore, DeltaStore, Row, Tuple}; +fn expect_bound_sarg_value(value: SargValue) -> AlgebraicValue { + match value { + SargValue::Literal(value) => value, + SargValue::Param(id, _) => panic!( + "unbound query parameter {id:?} reached pipelined index scan; bind parameters before constructing a PipelinedProject" + ), + } +} + /// An executor for explicit column projections. /// Note, this plan can only be constructed from the http api, /// which is not considered performance critical. @@ -579,14 +588,17 @@ impl From for PipelinedIxDeltaScanRange { arg: Sarg::Eq(_, v), delta: Some(delta), .. - } => Self { - table_id: schema.table_id, - index_id, - prefix: prefix.into_iter().map(|(_, v)| v).collect(), - lower: Bound::Included(v.clone()), - upper: Bound::Included(v), - delta, - }, + } => { + let v = expect_bound_sarg_value(v); + Self { + table_id: schema.table_id, + index_id, + prefix: prefix.into_iter().map(|(_, v)| v).collect(), + lower: Bound::Included(v.clone()), + upper: Bound::Included(v), + delta, + } + } IxScan { schema, index_id, @@ -706,7 +718,7 @@ impl From for PipelinedIxDeltaScanEq { } => Self { table_id: schema.table_id, index_id, - point: combine_prefix_and_last(prefix, last), + point: combine_prefix_and_last(prefix, expect_bound_sarg_value(last)), delta, }, IxScan { .. } => unreachable!(), @@ -773,14 +785,17 @@ impl From for PipelinedIxScanRange { index_id, prefix, arg: Sarg::Eq(_, v), - } => Self { - table_id: schema.table_id, - index_id, - limit, - prefix: prefix.into_iter().map(|(_, v)| v).collect(), - lower: Bound::Included(v.clone()), - upper: Bound::Included(v), - }, + } => { + let v = expect_bound_sarg_value(v); + Self { + table_id: schema.table_id, + index_id, + limit, + prefix: prefix.into_iter().map(|(_, v)| v).collect(), + lower: Bound::Included(v.clone()), + upper: Bound::Included(v), + } + } IxScan { schema, limit, @@ -909,7 +924,7 @@ impl From for PipelinedIxScanEq { table_id: schema.table_id, index_id, limit, - point: combine_prefix_and_last(prefix, last), + point: combine_prefix_and_last(prefix, expect_bound_sarg_value(last)), }, IxScan { .. } => unreachable!(), } diff --git a/crates/expr/src/check.rs b/crates/expr/src/check.rs index 01e151f4357..ca3bb0304f2 100644 --- a/crates/expr/src/check.rs +++ b/crates/expr/src/check.rs @@ -157,10 +157,9 @@ impl TypeChecker for SubChecker { } /// Parse and type check a subscription query -pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult<(ProjectName, bool)> { +pub fn parse_and_type_sub(sql: &str, tx: &impl SchemaView, _auth: &AuthCtx) -> TypingResult<(ProjectName, bool)> { let ast = parse_subscription(sql)?; let has_param = ast.has_parameter(); - let ast = ast.resolve_sender(auth.caller()); expect_table_type(SubChecker::type_ast(ast, tx)?).map(|plan| (plan, has_param)) } diff --git a/crates/expr/src/expr.rs b/crates/expr/src/expr.rs index c9753e5c195..ee6143c419b 100644 --- a/crates/expr/src/expr.rs +++ b/crates/expr/src/expr.rs @@ -1,11 +1,73 @@ use spacetimedb_data_structures::map::HashSet; -use spacetimedb_lib::{query::Delta, AlgebraicType, AlgebraicValue}; +use spacetimedb_lib::{query::Delta, AlgebraicType, AlgebraicValue, Identity}; use spacetimedb_primitives::{TableId, ViewId}; use spacetimedb_sats::raw_identifier::RawIdentifier; use spacetimedb_schema::{identifier::Identifier, schema::TableOrViewSchema}; use spacetimedb_sql_parser::ast::{BinOp, LogOp}; use std::sync::Arc; +/// A formal parameter slot in a typed query plan. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub struct ParamId(pub u16); + +impl ParamId { + /// The only parameter currently supported by SQL syntax: `:sender`. + pub const SENDER: Self = Self(0); +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum ParamBinding { + Sender(Identity), +} + +impl ParamBinding { + fn value_for_type(&self, ty: &AlgebraicType) -> Option { + match self { + Self::Sender(sender) if ty.is_identity() => Some((*sender).into()), + // Preserve existing `:sender` behavior for legacy filters that compare it to bytes. + Self::Sender(sender) if ty.is_bytes() => Some(AlgebraicValue::Bytes( + sender.to_be_byte_array().to_vec().into_boxed_slice(), + )), + Self::Sender(_) => None, + } + } +} + +/// Runtime parameter bindings for a parameterized query plan. +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct BindEnv { + values: Vec, +} + +impl BindEnv { + pub fn empty() -> Self { + Self::default() + } + + pub fn sender(sender: Identity) -> Self { + Self { + values: vec![ParamBinding::Sender(sender)], + } + } + + pub fn for_sender_binding(requires_sender_binding: bool, sender: Identity) -> Self { + if requires_sender_binding { + Self::sender(sender) + } else { + Self::empty() + } + } + + pub fn get(&self, id: ParamId, ty: &AlgebraicType) -> Option { + self.values.get(id.0 as usize)?.value_for_type(ty) + } + + pub fn expect(&self, id: ParamId, ty: &AlgebraicType, context: &str) -> AlgebraicValue { + self.get(id, ty) + .unwrap_or_else(|| panic!("missing or mistyped binding for query parameter {id:?} while {context}")) + } +} + pub trait CollectViews { fn collect_views(&self, views: &mut HashSet); } @@ -383,6 +445,8 @@ pub enum Expr { LogOp(LogOp, Box, Box), /// A typed literal expression Value(AlgebraicValue, AlgebraicType), + /// A typed runtime parameter. + Param(ParamId, AlgebraicType), /// A field projection Field(FieldProject), } @@ -396,7 +460,7 @@ impl Expr { a.visit(f); b.visit(f); } - Self::Value(..) | Self::Field(..) => {} + Self::Value(..) | Self::Param(..) | Self::Field(..) => {} } } @@ -408,7 +472,7 @@ impl Expr { a.visit_mut(f); b.visit_mut(f); } - Self::Value(..) | Self::Field(..) => {} + Self::Value(..) | Self::Param(..) | Self::Field(..) => {} } } @@ -426,7 +490,7 @@ impl Expr { pub fn ty(&self) -> &AlgebraicType { match self { Self::BinOp(..) | Self::LogOp(..) => &AlgebraicType::Bool, - Self::Value(_, ty) | Self::Field(FieldProject { ty, .. }) => ty, + Self::Value(_, ty) | Self::Param(_, ty) | Self::Field(FieldProject { ty, .. }) => ty, } } } diff --git a/crates/expr/src/lib.rs b/crates/expr/src/lib.rs index 2a6f02f5a21..6a9fcb9570f 100644 --- a/crates/expr/src/lib.rs +++ b/crates/expr/src/lib.rs @@ -9,7 +9,7 @@ use errors::{DuplicateName, InvalidLiteral, InvalidOp, InvalidWildcard, Unexpect use ethnum::i256; use ethnum::u256; use expr::AggType; -use expr::{Expr, FieldProject, ProjectList, ProjectName, RelExpr}; +use expr::{Expr, FieldProject, ParamId, ProjectList, ProjectName, RelExpr}; use spacetimedb_data_structures::map::HashCollectionExt as _; use spacetimedb_data_structures::map::HashSet; use spacetimedb_lib::ser::Serialize; @@ -19,7 +19,7 @@ use spacetimedb_sats::algebraic_type::fmt::fmt_algebraic_type; use spacetimedb_sats::algebraic_value::ser::ValueSerializer; use spacetimedb_sats::uuid::Uuid; use spacetimedb_schema::schema::ColumnSchema; -use spacetimedb_sql_parser::ast::{self, BinOp, ProjectElem, SqlExpr, SqlIdent, SqlLiteral}; +use spacetimedb_sql_parser::ast::{self, BinOp, Parameter, ProjectElem, SqlExpr, SqlIdent, SqlLiteral}; use spacetimedb_sql_parser::parser::recursion; use std::{ops::Deref, str::FromStr}; @@ -99,6 +99,13 @@ fn _type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>, d parse(&v, ty).map_err(|_| InvalidLiteral::new(v.into_string(), ty))?, ty.clone(), )), + (SqlExpr::Param(Parameter::Sender), Some(ty)) if ty.is_identity() || ty.is_bytes() => { + Ok(Expr::Param(ParamId::SENDER, ty.clone())) + } + (SqlExpr::Param(Parameter::Sender), Some(ty)) => { + Err(UnexpectedType::new(&AlgebraicType::identity(), ty).into()) + } + (SqlExpr::Param(Parameter::Sender), None) => Err(Unresolved::Literal.into()), (SqlExpr::Field(SqlIdent(table), SqlIdent(field)), expected) => { let table_type = vars.deref().get(&*table).ok_or_else(|| Unresolved::var(&table))?; let ColumnSchema { col_pos, col_type, .. } = table_type @@ -123,7 +130,9 @@ fn _type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>, d let b = _type_expr(vars, *b, Some(&AlgebraicType::Bool), depth + 1)?; Ok(Expr::LogOp(op, Box::new(a), Box::new(b))) } - (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) if matches!(&*a, SqlExpr::Lit(_)) => { + (SqlExpr::Bin(a, b, op), None | Some(AlgebraicType::Bool)) + if matches!(&*a, SqlExpr::Lit(_) | SqlExpr::Param(_)) => + { let b = _type_expr(vars, *b, None, depth + 1)?; let a = _type_expr(vars, *a, Some(b.ty()), depth + 1)?; if !op_supports_type(op, a.ty()) { @@ -140,9 +149,8 @@ fn _type_expr(vars: &Relvars, expr: SqlExpr, expected: Option<&AlgebraicType>, d Ok(Expr::BinOp(op, Box::new(a), Box::new(b))) } (SqlExpr::Bin(..) | SqlExpr::Log(..), Some(ty)) => Err(UnexpectedType::new(&AlgebraicType::Bool, ty).into()), - // Both unqualified names as well as parameters are syntactic constructs. - // Unqualified names are qualified and parameters are resolved before type checking. - (SqlExpr::Var(_) | SqlExpr::Param(_), _) => unreachable!(), + // Unqualified names are syntactic constructs qualified before type checking. + (SqlExpr::Var(_), _) => unreachable!(), } } diff --git a/crates/expr/src/rls.rs b/crates/expr/src/rls.rs index 7382bd5e630..b7db1f5c92e 100644 --- a/crates/expr/src/rls.rs +++ b/crates/expr/src/rls.rs @@ -487,7 +487,7 @@ mod tests { use crate::{ check::{parse_and_type_sub, test_utils::build_module_def, SchemaView}, - expr::{Expr, FieldProject, LeftDeepJoin, ProjectName, RelExpr, Relvar}, + expr::{Expr, FieldProject, LeftDeepJoin, ParamId, ProjectName, RelExpr, Relvar}, }; use super::resolve_views_for_sub; @@ -602,7 +602,7 @@ mod tests { field: 0, ty: AlgebraicType::identity(), })), - Box::new(Expr::Value(Identity::ONE.into(), AlgebraicType::identity())) + Box::new(Expr::Param(ParamId::SENDER, AlgebraicType::identity())) ) ), "users".into() @@ -649,7 +649,7 @@ mod tests { field: 0, ty: AlgebraicType::identity(), })), - Box::new(Expr::Value(Identity::ONE.into(), AlgebraicType::identity())), + Box::new(Expr::Param(ParamId::SENDER, AlgebraicType::identity())), ), )), Expr::BinOp( @@ -700,7 +700,7 @@ mod tests { field: 0, ty: AlgebraicType::identity(), })), - Box::new(Expr::Value(Identity::ONE.into(), AlgebraicType::identity())), + Box::new(Expr::Param(ParamId::SENDER, AlgebraicType::identity())), ), )), Expr::BinOp( diff --git a/crates/expr/src/statement.rs b/crates/expr/src/statement.rs index b7422dd031c..8083d813706 100644 --- a/crates/expr/src/statement.rs +++ b/crates/expr/src/statement.rs @@ -450,8 +450,8 @@ impl TypeChecker for SqlChecker { } } -pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView, auth: &AuthCtx) -> TypingResult { - match parse_sql(sql)?.resolve_sender(auth.caller()) { +pub fn parse_and_type_sql(sql: &str, tx: &impl SchemaView, _auth: &AuthCtx) -> TypingResult { + match parse_sql(sql)? { SqlAst::Select(ast) => Ok(Statement::Select(SqlChecker::type_ast(ast, tx)?)), SqlAst::Insert(insert) => Ok(Statement::DML(DML::Insert(type_insert(insert, tx)?))), SqlAst::Delete(delete) => Ok(Statement::DML(DML::Delete(type_delete(delete, tx)?))), diff --git a/crates/physical-plan/src/compile.rs b/crates/physical-plan/src/compile.rs index 1bda4fd65cd..e5b59cca384 100644 --- a/crates/physical-plan/src/compile.rs +++ b/crates/physical-plan/src/compile.rs @@ -21,6 +21,7 @@ fn compile_expr(expr: Expr, var: &mut impl VarLabel) -> PhysicalExpr { PhysicalExpr::BinOp(op, a, b) } Expr::Value(v, _) => PhysicalExpr::Value(v), + Expr::Param(id, ty) => PhysicalExpr::Param(id, ty), Expr::Field(proj) => PhysicalExpr::Field(compile_field_project(var, proj)), } } diff --git a/crates/physical-plan/src/dml.rs b/crates/physical-plan/src/dml.rs index f6a0aa1aa73..fa6eb88626b 100644 --- a/crates/physical-plan/src/dml.rs +++ b/crates/physical-plan/src/dml.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use anyhow::Result; use spacetimedb_expr::{ - expr::{ProjectName, RelExpr, Relvar}, + expr::{BindEnv, ProjectName, RelExpr, Relvar}, statement::{TableDelete, TableInsert, TableUpdate}, }; use spacetimedb_lib::{identity::AuthCtx, AlgebraicValue, ProductValue}; @@ -27,6 +27,15 @@ impl MutationPlan { Self::Update(plan) => Ok(Self::Update(plan.optimize(auth)?)), } } + + /// Replace runtime parameters with bound values. + pub fn bind_params(self, bind_env: &BindEnv) -> Self { + match self { + Self::Insert(..) => self, + Self::Delete(plan) => Self::Delete(plan.bind_params(bind_env)), + Self::Update(plan) => Self::Update(plan.bind_params(bind_env)), + } + } } /// A plan for inserting rows into a table @@ -57,6 +66,12 @@ impl DeletePlan { Ok(Self { table, filter }) } + fn bind_params(self, bind_env: &BindEnv) -> Self { + let Self { table, filter } = self; + let filter = filter.bind_params(bind_env); + Self { table, filter } + } + /// Logical to physical conversion pub(crate) fn compile(delete: TableDelete) -> Self { let TableDelete { table, filter } = delete; @@ -91,6 +106,12 @@ impl UpdatePlan { Ok(Self { columns, table, filter }) } + fn bind_params(self, bind_env: &BindEnv) -> Self { + let Self { table, columns, filter } = self; + let filter = filter.bind_params(bind_env); + Self { columns, table, filter } + } + /// Logical to physical conversion pub(crate) fn compile(update: TableUpdate) -> Self { let TableUpdate { table, columns, filter } = update; diff --git a/crates/physical-plan/src/plan.rs b/crates/physical-plan/src/plan.rs index 0db3e7910ce..c273924a7ed 100644 --- a/crates/physical-plan/src/plan.rs +++ b/crates/physical-plan/src/plan.rs @@ -3,10 +3,12 @@ use derive_more::From; use either::Either; use spacetimedb_data_structures::map::HashSet; use spacetimedb_expr::{ - expr::{AggType, CollectViews}, + expr::{AggType, BindEnv, CollectViews, ParamId}, StatementSource, }; -use spacetimedb_lib::{identity::AuthCtx, query::Delta, sats::size_of::SizeOf, AlgebraicValue, ProductValue}; +use spacetimedb_lib::{ + identity::AuthCtx, query::Delta, sats::size_of::SizeOf, AlgebraicType, AlgebraicValue, ProductValue, +}; use spacetimedb_primitives::{ColId, ColSet, IndexId, TableId, ViewId}; use spacetimedb_schema::schema::{IndexSchema, TableSchema}; use spacetimedb_sql_parser::ast::{BinOp, LogOp}; @@ -128,6 +130,21 @@ impl ProjectPlan { Self::None(plan) | Self::Name(plan, ..) => plan.reads_from_event_table(), } } + + /// Replace runtime parameters with bound values. + pub fn bind_params(self, bind_env: &BindEnv) -> Self { + match self { + Self::None(plan) => Self::None(plan.bind_params(bind_env)), + Self::Name(plan, label, pos) => Self::Name(plan.bind_params(bind_env), label, pos), + } + } + + /// Returns whether this plan contains a runtime parameter. + pub fn requires_param(&self, id: ParamId) -> bool { + match self { + Self::None(plan) | Self::Name(plan, ..) => plan.requires_param(id), + } + } } /// Physical plans always terminate with a projection. @@ -165,6 +182,28 @@ pub enum ProjectListPlan { } impl ProjectListPlan { + /// Replace runtime parameters with bound values. + pub fn bind_params(self, bind_env: &BindEnv) -> Self { + match self { + Self::Name(plans) => Self::Name(plans.into_iter().map(|plan| plan.bind_params(bind_env)).collect()), + Self::List(plans, fields) => Self::List( + plans.into_iter().map(|plan| plan.bind_params(bind_env)).collect(), + fields, + ), + Self::Limit(plan, n) => Self::Limit(Box::new((*plan).bind_params(bind_env)), n), + Self::Agg(plans, agg) => Self::Agg(plans.into_iter().map(|plan| plan.bind_params(bind_env)).collect(), agg), + } + } + + /// Returns whether this plan contains a runtime parameter. + pub fn requires_param(&self, id: ParamId) -> bool { + match self { + Self::Name(plans) => plans.iter().any(|plan| plan.requires_param(id)), + Self::List(plans, _) | Self::Agg(plans, _) => plans.iter().any(|plan| plan.requires_param(id)), + Self::Limit(plan, _) => plan.requires_param(id), + } + } + pub fn optimize(self, auth: &AuthCtx) -> Result { match self { Self::Name(plan) => Ok(Self::Name( @@ -339,6 +378,15 @@ impl PhysicalPlan { ok } + /// Returns whether this plan contains a runtime parameter. + pub fn requires_param(&self, id: ParamId) -> bool { + self.any(&|plan| match plan { + Self::IxScan(scan, _) => scan.arg.requires_param(id), + Self::Filter(_, expr) => expr.requires_param(id), + Self::TableScan(..) | Self::IxJoin(..) | Self::HashJoin(..) | Self::NLJoin(..) => false, + }) + } + /// Applies `f` recursively to all subplans pub fn map(self, f: &impl Fn(Self) -> Self) -> Self { match f(self) { @@ -363,6 +411,69 @@ impl PhysicalPlan { } } + /// Replace runtime parameters with bound values. + pub fn bind_params(self, bind_env: &BindEnv) -> Self { + match self { + Self::TableScan(..) => self, + Self::IxScan(mut scan, label) => { + scan.arg = scan.arg.bind_params(bind_env); + Self::IxScan(scan, label) + } + Self::IxJoin( + IxJoin { + lhs, + rhs, + rhs_label, + rhs_index, + rhs_prefix, + rhs_field, + unique, + lhs_field, + rhs_delta, + }, + semi, + ) => Self::IxJoin( + IxJoin { + lhs: Box::new(lhs.bind_params(bind_env)), + rhs, + rhs_label, + rhs_index, + rhs_prefix, + rhs_field, + unique, + lhs_field, + rhs_delta, + }, + semi, + ), + Self::HashJoin( + HashJoin { + lhs, + rhs, + lhs_field, + rhs_field, + unique, + }, + semi, + ) => Self::HashJoin( + HashJoin { + lhs: Box::new(lhs.bind_params(bind_env)), + rhs: Box::new(rhs.bind_params(bind_env)), + lhs_field, + rhs_field, + unique, + }, + semi, + ), + Self::NLJoin(lhs, rhs) => { + Self::NLJoin(Box::new(lhs.bind_params(bind_env)), Box::new(rhs.bind_params(bind_env))) + } + Self::Filter(input, expr) => { + Self::Filter(Box::new(input.bind_params(bind_env)), expr.bind_params(bind_env)) + } + } + } + /// Applies `f` to a subplan if `ok` returns a match. /// Recurses until an `ok` match is found. pub fn map_if( @@ -474,9 +585,9 @@ impl PhysicalPlan { /// 3. Turn filters into index scans if possible /// 4. Determine index and semijoins /// 5. Compute positions for tuple labels - pub fn optimize(self, auth: &AuthCtx, reqs: Vec